def __init__(self, uid, arbor=None, root=False):
    """
    Initialize a TreeNode with at least its halo catalog ID and
    its level in the tree.

    Parameters
    ----------
    uid : int
        Halo catalog ID, stored as ``self.uid``.
    arbor : optional, Arbor
        The arbor that holds this node. Kept as a weak proxy.
        If None, ``self.arbor`` is set to None.
        Default: None.
    root : optional, bool
        If True, mark the node as a root (``self.root = -1``) and
        give it its own field container. Otherwise ``self.root``
        is None.
        Default: False.
    """
    self.uid = uid
    # weakref.proxy(None) raises TypeError, so only wrap a real arbor;
    # this keeps the documented default (arbor=None) usable.
    self.arbor = weakref.proxy(arbor) if arbor is not None else None
    if root:
        self.root = -1
        self._field_data = FieldContainer(arbor)
    else:
        self.root = None
def __init__(self, filename):
    """
    Build an Arbor from an input file.

    Parameters
    ----------
    filename : string
        Path to the input file. Stored as ``self.filename``; its
        final path component is stored as ``self.basename``.
    """
    self.filename = filename
    self.basename = os.path.basename(filename)

    # parameter file first, then units
    for step in ("_parse_parameter_file", "_set_units"):
        getattr(self, step)()

    # field storage and the two field i/o handlers
    self._field_data = FieldContainer(self)
    self._node_io = self._tree_field_io_class(self)
    self._root_io = self._root_field_io_class(self)

    # remaining setup, in this order
    for step in ("_get_data_files", "_setup_fields",
                 "_set_default_selector"):
        getattr(self, step)()
def _setup_fields(self):
    """
    Create the field container, reset the derived and analysis
    field lists, and run the field_info setup stages in order.
    """
    self.field_data = FieldContainer(self)
    self.derived_field_list, self.analysis_field_list = [], []

    # stages run in this fixed order: known fields, aliases,
    # derived fields, vector fields
    for stage in ("known_fields", "aliases",
                  "derived_fields", "vector_fields"):
        getattr(self.field_info, "setup_%s" % stage)()
class Arbor(object, metaclass=RegisteredArbor):
    """
    Base class for all Arbor classes.

    Loads a merger-tree output file or a series of halo catalogs
    and create trees, stored in an array in
    :func:`~ytree.data_structures.arbor.Arbor.trees`.
    Arbors can be saved in a universal format with
    :func:`~ytree.data_structures.arbor.Arbor.save_arbor`.
    Also, provide some convenience functions for creating
    YTArrays and YTQuantities and a cosmology calculator.
    """

    _field_info_class = FieldInfoContainer
    _root_field_io_class = DefaultRootFieldIO
    _tree_field_io_class = TreeFieldIO

    def __init__(self, filename):
        """
        Initialize an Arbor given an input file.
        """
        self.filename = filename
        self.basename = os.path.basename(filename)
        self._parse_parameter_file()
        self._set_units()
        self._field_data = FieldContainer(self)
        self._node_io = self._tree_field_io_class(self)
        self._root_io = self._root_field_io_class(self)
        self._get_data_files()
        self._setup_fields()
        self._set_default_selector()

    def _get_data_files(self):
        """
        Get all files that hold field data and make them known
        to the i/o system.
        """
        pass

    def _parse_parameter_file(self):
        """
        Read relevant parameters from parameter file or file header
        and detect fields.
        """
        raise NotImplementedError

    def _plant_trees(self):
        """
        Create the list of root tree nodes.
        """
        raise NotImplementedError

    def is_setup(self, tree_node):
        """
        Return True if arrays of uids and descendent uids have
        been read in. Setup has also completed if tree is already
        grown.
        """
        return self.is_grown(tree_node) or \
            tree_node._uids is not None

    def _setup_tree(self, tree_node, **kwargs):
        """
        Create arrays of uids and desc_uids and attach them to the
        root node.

        Any keyword arguments are passed through to the node i/o
        handler's ``_read_fields``.
        """
        # skip if this is not a root or if already setup
        if self.is_setup(tree_node):
            return

        idtype = np.int64
        fields, _ = \
            self.field_info.resolve_field_dependencies(["uid", "desc_uid"])
        halo_id_f, desc_id_f = fields
        dtypes = {halo_id_f: idtype, desc_id_f: idtype}
        field_data = self._node_io._read_fields(tree_node, fields,
                                                dtypes=dtypes, **kwargs)
        tree_node._uids = field_data[halo_id_f]
        tree_node._desc_uids = field_data[desc_id_f]
        tree_node._tree_size = tree_node._uids.size

    def is_grown(self, tree_node):
        """
        Return True if a tree has been fully assembled, i.e.,
        the hierarchy of ancestor tree nodes has been built.
        """
        return tree_node.root != -1

    def _grow_tree(self, tree_node, **kwargs):
        """
        Create an array of TreeNodes hanging off the root node
        and assemble the tree structure.
        """
        # skip this if not a root or if already grown
        if self.is_grown(tree_node):
            return

        self._setup_tree(tree_node, **kwargs)

        uids = tree_node.uids
        nhalos = uids.size
        # np.object was removed from NumPy; the builtin is equivalent
        nodes = np.empty(nhalos, dtype=object)
        nodes[0] = tree_node
        for i in range(1, nhalos):
            nodes[i] = TreeNode(uids[i], arbor=self)
        tree_node._nodes = nodes

        # Add tree information to nodes
        uidmap = {}
        for i, node in enumerate(nodes):
            node.treeid = i
            node.root = tree_node
            uidmap[uids[i]] = i

        # Link ancestor/descendents
        # Separate loop for trees like lhalotree where descendent
        # can follow in order
        desc_uids = tree_node.desc_uids
        for i, node in enumerate(nodes):
            descid = desc_uids[i]
            if descid != -1:
                desc = nodes[uidmap[descid]]
                desc.add_ancestor(node)
                node.descendent = desc

    def _node_io_loop(self, func, *args, **kwargs):
        """
        Call the provided function over a list of nodes.

        If possible, group nodes by common data files to speed
        things up. This should work like __iter__, except we call
        a function instead of yielding.

        Parameters
        ----------
        func : function
            Function to be called on an array of nodes.
        pbar : optional, string or yt.funcs.TqdmProgressBar
            A progress bar to be updated with each iteration.
            If a string, a progress bar will be created and the
            finish function will be called. If a progress bar is
            provided, the finish function will not be called.
            Default: None (no progress bar).
        root_nodes : optional, array of root TreeNodes
            Array of nodes over which the function will be called.
            If None, the list will be self.trees (i.e., all
            root_nodes).
            Default: None.
        store : optional, string
            If not None, any return value captured from the function
            will be stored in an attribute with this name associated
            with the TreeNode.
            Default: None.
        """
        pbar = kwargs.pop("pbar", None)
        root_nodes = kwargs.pop("root_nodes", None)
        if root_nodes is None:
            root_nodes = self.trees
        store = kwargs.pop("store", None)

        data_files, node_list = self._node_io_loop_prepare(root_nodes)
        # len() rather than .size so plain lists of nodes work too
        nnodes = sum(len(nodes) for nodes in node_list)

        finish = True
        if pbar is None:
            pbar = fake_pbar("", nnodes)
        elif not isinstance(pbar, TqdmProgressBar):
            pbar = get_pbar(pbar, nnodes)
        else:
            finish = False

        for data_file, nodes in zip(data_files, node_list):
            self._node_io_loop_start(data_file)
            try:
                for node in nodes:
                    rval = func(node, *args, **kwargs)
                    if store is not None:
                        setattr(node, store, rval)
                    pbar.update(1)
            finally:
                # always pair start with finish, even if func raises
                self._node_io_loop_finish(data_file)

        if finish:
            pbar.finish()

    def _node_io_loop_start(self, data_file):
        """
        Hook called before looping over the nodes of a data file.
        Does nothing in the base class.
        """
        pass

    def _node_io_loop_finish(self, data_file):
        """
        Hook called after looping over the nodes of a data file.
        Does nothing in the base class.
        """
        pass

    def _node_io_loop_prepare(self, root_nodes):
        """
        This is called at the beginning of _node_io_loop.

        In different frontends, this can be used to group nodes by
        common data files.

        Return
        ------
        list of data files and a list of node arrays
            Each data file corresponds to an array of nodes.
        """
        return [None], [root_nodes]

    def __iter__(self):
        """
        Iterate over all items in the tree list.

        If possible, group nodes by common data files to speed
        things up.
        """
        data_files, node_list = self._node_io_loop_prepare(self.trees)

        for data_file, nodes in zip(data_files, node_list):
            self._node_io_loop_start(data_file)
            try:
                for node in nodes:
                    yield node
            finally:
                # also runs when the consumer stops iterating early
                self._node_io_loop_finish(data_file)

    _trees = None

    @property
    def trees(self):
        """
        Array containing all trees in the arbor.
        """
        if self._trees is None:
            self._plant_trees()
        return self._trees

    def __repr__(self):
        return self.basename

    def __getitem__(self, key):
        return self.query(key)

    def query(self, key):
        """
        If given a string, return an array of field values for the
        roots of all trees.
        If given an integer, return a tree from the list of trees.

        Raises
        ------
        SyntaxError
            If key is "tree" or "prog".
        """
        if isinstance(key, str):
            if key in ("tree", "prog"):
                raise SyntaxError("Argument must be a field or integer.")

            self._root_io.get_fields(self, fields=[key])
            # analysis fields are removed from the container on return
            if self.field_info[key].get("type") == "analysis":
                return self._field_data.pop(key)
            return self._field_data[key]
        return self.trees[key]

    def __len__(self):
        """
        Return length of tree list.
        """
        return self.trees.size

    _field_info = None

    @property
    def field_info(self):
        """
        A dictionary containing information for each available field.
        """
        if self._field_info is None and \
          self._field_info_class is not None:
            self._field_info = self._field_info_class(self)
        return self._field_info

    @property
    def size(self):
        """
        Return length of tree list.
        """
        return self.trees.size

    _unit_registry = None

    @property
    def unit_registry(self):
        """
        Unit system registry.
        """
        if self._unit_registry is None:
            self._unit_registry = UnitRegistry()
        return self._unit_registry

    @unit_registry.setter
    def unit_registry(self, value):
        self._unit_registry = value
        # cached array/quantity factories are bound to the old registry
        self._arr = None
        self._quan = None

    _hubble_constant = None

    @property
    def hubble_constant(self):
        """
        Value of the Hubble parameter.
        """
        return self._hubble_constant

    @hubble_constant.setter
    def hubble_constant(self, value):
        self._hubble_constant = value
        # reset the unit registry lut while preserving other changes
        self.unit_registry = UnitRegistry.from_json(
            self.unit_registry.to_json())
        self.unit_registry.modify("h", self.hubble_constant)

    _box_size = None

    @property
    def box_size(self):
        """
        The simulation box size.
        """
        return self._box_size

    @box_size.setter
    def box_size(self, value):
        self._box_size = value
        # set unitary as soon as we know the box size
        self.unit_registry.add(
            "unitary", float(self.box_size.in_base()), length)

    def _setup_fields(self):
        """
        Reset the derived and analysis field lists and run the
        field_info setup stages.
        """
        self.derived_field_list = []
        self.analysis_field_list = []
        self.field_info.setup_known_fields()
        self.field_info.setup_aliases()
        self.field_info.setup_derived_fields()
        self.field_info.setup_vector_fields()

    def _set_units(self):
        """
        Set "cm" units for explicitly comoving.
        Note, we are using comoving units all the time since
        we are dealing with data at multiple redshifts.
        """
        # go through the property so the registry is created on demand;
        # self._unit_registry is None until something has touched it
        registry = self.unit_registry
        for my_unit in ["m", "pc", "AU", "au"]:
            new_unit = "%scm" % my_unit
            registry.add(
                new_unit, registry.lut[my_unit][0],
                length, registry.lut[my_unit][3])

        self.cosmology = Cosmology(
            hubble_constant=self.hubble_constant,
            omega_matter=self.omega_matter,
            omega_lambda=self.omega_lambda,
            unit_registry=self.unit_registry)

    def set_selector(self, selector, *args, **kwargs):
        r"""
        Sets the tree node selector to be used.

        This sets the manner in which halo progenitors are
        chosen from a list of ancestors. The most obvious example
        is to select the most massive ancestor.

        Parameters
        ----------
        selector : string
            Name of the selector to be used.

        Any additional arguments and keywords to be provided to
        the selector function should follow.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> a.set_selector("max_field_value", "mass")

        """
        self.selector = tree_node_selector_registry.find(
            selector, *args, **kwargs)

    _arr = None

    @property
    def arr(self):
        """
        Create a YTArray using the Arbor's unit registry.
        """
        if self._arr is not None:
            return self._arr
        self._arr = functools.partial(YTArray,
                                      registry=self.unit_registry)
        return self._arr

    _quan = None

    @property
    def quan(self):
        """
        Create a YTQuantity using the Arbor's unit registry.
        """
        if self._quan is not None:
            return self._quan
        self._quan = functools.partial(YTQuantity,
                                       registry=self.unit_registry)
        return self._quan

    def _set_default_selector(self):
        """
        Set the default tree node selector as maximum mass.
        """
        self.set_selector("max_field_value", "mass")

    def select_halos(self, criteria, trees=None, select_from="tree",
                     fields=None):
        """
        Select halos from the arbor based on a set of criteria
        given as a string.

        Parameters
        ----------
        criteria: string
            A string that will eval to a Numpy-like selection
            operation performed on a TreeNode object called "tree".
            Example: 'tree["tree", "redshift"] > 1'
        trees : optional, list or array of TreeNodes
            A list or array of TreeNode objects in which to search.
            If none given, the search is performed over the full
            arbor.
        select_from : optional, "tree" or "prog"
            Determines whether to perform the search over the full
            tree or just the main progenitors. Note, the value
            given must be consistent with what appears in the
            criteria string. For example, a criteria string of
            'tree["tree", "redshift"] > 1' cannot be used when
            setting select_from to "prog".
            Default: "tree".
        fields : optional, list of strings
            Use to provide a list of fields required by the
            criteria evaluation. If given, fields will be
            preloaded in an optimized way and the search will
            go faster.
            Default: None.

        Returns
        -------
        halos : array of TreeNodes
            A flat array of all TreeNodes meeting the criteria.

        Raises
        ------
        SyntaxError
            If select_from is neither "tree" nor "prog".
        RuntimeError
            If the evaluated filter and the selected node array
            differ in size.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tree_0_0_0.dat")
        >>> halos = a.select_halos('tree["tree", "redshift"] > 1',
        ...                        fields=["redshift"])
        >>>
        >>> halos = a.select_halos('tree["prog", "mass"].to("Msun") >= 1e10',
        ...                        select_from="prog", fields=["mass"])

        """
        if select_from not in ["tree", "prog"]:
            raise SyntaxError(
                "Keyword \"select_from\" must be either \"tree\" or \"prog\".")

        if trees is None:
            trees = self.trees

        if fields is None:
            fields = []

        self._node_io_loop(self._setup_tree, root_nodes=trees,
                           pbar="Setting up trees")
        if fields:
            self._node_io_loop(
                self._node_io.get_fields,
                pbar="Getting fields",
                root_nodes=trees, fields=fields, root_only=False)

        halos = []
        # size the bar by the trees actually searched, not the full arbor
        pbar = get_pbar("Selecting halos", len(trees))
        # the loop variable must be called "tree": criteria refers to it
        for tree in trees:
            # NOTE(security): eval runs arbitrary code; criteria must
            # come from a trusted source.
            my_filter = np.asarray(eval(criteria))
            if my_filter.size != tree[select_from].size:
                raise RuntimeError(
                    ("Filter array and tree array sizes do not match. " +
                     "Make sure select_from (\"%s\") matches criteria (\"%s\").") %
                    (select_from, criteria))
            halos.extend(tree[select_from][my_filter])
            pbar.update(1)
        pbar.finish()
        return np.array(halos)

    def add_analysis_field(self, name, units):
        r"""
        Add an empty field to be filled by analysis operations.

        Parameters
        ----------
        name : string
            Field name.
        units : string
            Field units.

        Raises
        ------
        ArborFieldAlreadyExists
            If a field with this name already exists.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tree_0_0_0.dat")
        >>> a.add_analysis_field("robots", "Msun * kpc")
        >>> # Set field for some halo.
        >>> a[0]["tree"][7]["robots"] = 1979.816
        """
        if name in self.field_info:
            raise ArborFieldAlreadyExists(name, arbor=self)

        self.analysis_field_list.append(name)
        self.field_info[name] = {"type": "analysis",
                                 "units": units}

    def _drop_overridden_field(self, name):
        """
        Warn that an existing field is being overridden and remove
        its name from the field list it is registered in.
        """
        ftype = self.field_info[name].get("type", "on-disk")
        if ftype in ["alias", "derived"]:
            fl = self.derived_field_list
        else:
            fl = self.field_list
        # Logger.warn is a deprecated alias; warning() is the
        # supported spelling
        mylog.warning(
            ("Overriding field \"%s\" that already " +
             "exists as %s field.") % (name, ftype))
        fl.remove(name)

    def add_alias_field(self, alias, field, units=None,
                        force_add=True):
        r"""
        Add a field as an alias to another field.

        Parameters
        ----------
        alias : string
            Alias name.
        field : string
            The field to be aliased.
        units : optional, string
            Units in which the field will be returned.
        force_add : optional, bool
            If True, add field even if it already exists and warn the
            user and raise an exception if dependencies do not exist.
            If False, silently do nothing in both instances.
            Default: True.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tree_0_0_0.dat")
        >>> # "Mvir" exists on disk
        >>> a.add_alias_field("mass", "Mvir", units="Msun")
        >>> print (a["mass"])

        """
        if alias in self.field_info:
            if not force_add:
                return
            self._drop_overridden_field(alias)

        if field not in self.field_info:
            if force_add:
                raise ArborFieldDependencyNotFound(
                    field, alias, arbor=self)
            return

        if units is None:
            units = self.field_info[field].get("units")
        self.derived_field_list.append(alias)
        self.field_info[alias] = \
            {"type": "alias", "units": units,
             "dependencies": [field]}
        # record the alias on the aliased field as well
        if "aliases" not in self.field_info[field]:
            self.field_info[field]["aliases"] = []
        self.field_info[field]["aliases"].append(alias)

    def add_derived_field(self, name, function,
                          units=None, description=None,
                          vector_field=False, force_add=True):
        r"""
        Add a field that is a function of other fields.

        Parameters
        ----------
        name : string
            Field name.
        function : callable
            The function to be called to generate the field.
            This function should take two arguments, the
            arbor and the data structure containing the
            dependent fields. See below for an example.
        units : optional, string
            The units in which the field will be returned.
        description : optional, string
            A short description of the field.
        vector_field: optional, bool
            If True, field is an xyz vector.
            Default: False.
        force_add : optional, bool
            If True, add field even if it already exists and warn the
            user and raise an exception if dependencies do not exist.
            If False, silently do nothing in both instances.
            Default: True.

        Raises
        ------
        RuntimeError
            If calling the field function raises a TypeError.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tree_0_0_0.dat")
        >>> def _redshift(field, data):
        ...     return 1. / data["scale"] - 1
        ...
        >>> a.add_derived_field("redshift", _redshift)
        >>> print (a["redshift"])

        """
        if name in self.field_info:
            if not force_add:
                return
            self._drop_overridden_field(name)

        if units is None:
            units = ""
        info = {"name": name,
                "type": "derived",
                "function": function,
                "units": units,
                "vector_field": vector_field,
                "description": description}

        # call the function once on a fake container to find out
        # which fields it depends on
        fc = FakeFieldContainer(self, name=name)
        try:
            rv = function(info, fc)
        except TypeError as e:
            # chain explicitly; the original had an unreachable
            # "raise e" after this raise
            raise RuntimeError(
"""

Field function syntax in ytree has changed.  Field functions must
now take two arguments, as in the following:
def my_field(field, data):
    return data['mass']

Check the TypeError exception above for more details.
""") from e
        except ArborFieldDependencyNotFound:
            if force_add:
                raise
            return

        rv.convert_to_units(units)
        info["dependencies"] = list(fc.keys())

        self.derived_field_list.append(name)
        self.field_info[name] = info

    @classmethod
    def _is_valid(cls, *args, **kwargs):
        """
        Check if input file works with a specific Arbor class.
        This is used with :func:`~ytree.data_structures.arbor.load` function.
        """
        return False

    def save_arbor(self, filename="arbor", fields=None, trees=None,
                   max_file_size=524288):
        r"""
        Save the arbor to a file.

        The saved arbor can be re-loaded as an arbor.

        Parameters
        ----------
        filename : optional, string
            Output file keyword. If filename ends in ".h5",
            the main header file will be just that. If not,
            filename will be <filename>/<basename>.h5.
            Default: "arbor".
        fields : optional, list of strings
            The fields to be saved. If not given, all
            fields will be saved. The list passed in is not
            modified.
        trees : optional, list or array of TreeNodes
            The trees to be saved. If not given, all trees in
            the arbor are saved. Every saved tree becomes a
            root in the output.
            Default: None.
        max_file_size : optional, int
            Maximum number of nodes per data file. A new data
            file is started when adding a tree would exceed this.
            A single tree larger than this gets a file to itself.
            Default: 524288.

        Returns
        -------
        header_filename : string
            The filename of the saved arbor.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> fn = a.save_arbor()
        >>> # reload it
        >>> a2 = ytree.load(fn)

        """
        if trees is None:
            all_trees = True
            trees = self.trees
            roots = trees
        else:
            all_trees = False
            # assemble unique tree roots for getting fields
            trees = np.asarray(trees)
            roots = []
            root_uids = []
            for tree in trees:
                if tree.root == -1:
                    my_root = tree
                else:
                    my_root = tree.root
                if my_root.uid not in root_uids:
                    roots.append(my_root)
                    root_uids.append(my_root.uid)
            roots = np.array(roots)
            del root_uids

        if fields is None or \
          (isinstance(fields, str) and fields == "all"):
            # If a field has an alias, get that instead.
            fields = []
            for field in self.field_list + self.analysis_field_list:
                fields.extend(
                    self.field_info[field].get("aliases", [field]))
        else:
            # work on a copy so the caller's list is left untouched
            fields = list(fields)
            fields.extend([f for f in ["uid", "desc_uid"]
                           if f not in fields])

        ds = {}
        for attr in ["hubble_constant",
                     "omega_matter",
                     "omega_lambda"]:
            if hasattr(self, attr):
                ds[attr] = getattr(self, attr)
        extra_attrs = {"box_size": self.box_size,
                       "arbor_type": "YTreeArbor",
                       "unit_registry_json": self.unit_registry.to_json()}

        self._node_io_loop(self._setup_tree, root_nodes=roots,
                           pbar="Setting up trees")
        if all_trees:
            self._root_io.get_fields(self, fields=fields)

        # determine file layout
        nn = 0  # node count
        nt = 0  # tree count
        nnodes = []
        ntrees = []
        tree_size = np.array([tree.tree_size for tree in trees])
        for ts in tree_size:
            nn += ts
            nt += 1
            # only close a file that already holds an earlier tree;
            # otherwise a first tree larger than max_file_size would
            # produce an empty file ahead of it
            if nn > max_file_size and nt > 1:
                nnodes.append(nn-ts)
                ntrees.append(nt-1)
                nn = ts
                nt = 1
        if nn > 0:
            nnodes.append(nn)
            ntrees.append(nt)
        nfiles = len(nnodes)
        nnodes = np.array(nnodes)
        ntrees = np.array(ntrees)
        tree_end_index = ntrees.cumsum()
        tree_start_index = tree_end_index - ntrees

        # write header file
        fieldnames = [field.replace("/", "_") for field in fields]
        myfi = {}
        rdata = {}
        rtypes = {}
        for field, fieldname in zip(fields, fieldnames):
            fi = self.field_info[field]
            myfi[fieldname] = \
                dict((key, fi[key])
                     for key in ["units", "description"]
                     if key in fi)
            if all_trees:
                rdata[fieldname] = self._field_data[field]
            else:
                rdata[fieldname] = self.arr([t[field] for t in trees])
            rtypes[fieldname] = "data"
        # all saved trees will be roots
        if not all_trees:
            rdata["desc_uid"][:] = -1
        extra_attrs["field_info"] = json.dumps(myfi)
        extra_attrs["total_files"] = nfiles
        extra_attrs["total_trees"] = trees.size
        extra_attrs["total_nodes"] = tree_size.sum()
        hdata = {"tree_start_index": tree_start_index,
                 "tree_end_index": tree_end_index,
                 "tree_size": ntrees}
        hdata.update(rdata)
        htypes = dict((f, "index") for f in hdata)
        htypes.update(rtypes)

        filename = _determine_output_filename(filename, ".h5")
        header_filename = "%s.h5" % filename
        save_as_dataset(ds, header_filename, hdata,
                        field_types=htypes,
                        extra_attrs=extra_attrs)

        # write data files
        ftypes = dict((f, "data") for f in fieldnames)
        for i in range(nfiles):
            my_nodes = trees[tree_start_index[i]:tree_end_index[i]]
            self._node_io_loop(
                self._node_io.get_fields,
                pbar="Getting fields [%d/%d]" % (i+1, nfiles),
                root_nodes=my_nodes, fields=fields, root_only=False)
            fdata = dict((field, np.empty(nnodes[i]))
                         for field in fieldnames)
            my_tree_size = tree_size[tree_start_index[i]:tree_end_index[i]]
            my_tree_end = my_tree_size.cumsum()
            my_tree_start = my_tree_end - my_tree_size
            pbar = get_pbar("Creating field arrays [%d/%d]" %
                            (i+1, nfiles), len(fields)*nnodes[i])
            c = 0
            for field, fieldname in zip(fields, fieldnames):
                for di, node in enumerate(my_nodes):
                    if node.is_root:
                        ndata = node._field_data[field]
                    else:
                        ndata = node["tree", field]
                        if field == "desc_uid":
                            # make sure it's a root when loaded
                            ndata[0] = -1
                    fdata[fieldname][
                        my_tree_start[di]:my_tree_end[di]] = ndata
                    c += my_tree_size[di]
                    pbar.update(c)
            pbar.finish()
            fdata["tree_start_index"] = my_tree_start
            fdata["tree_end_index"] = my_tree_end
            fdata["tree_size"] = my_tree_size
            for ft in ["tree_start_index",
                       "tree_end_index",
                       "tree_size"]:
                ftypes[ft] = "index"
            my_filename = "%s_%04d.h5" % (filename, i)
            save_as_dataset({}, my_filename, fdata,
                            field_types=ftypes)

        return header_filename
class TreeNode:
    """
    Class for objects stored in Arbors.

    Each TreeNode represents a halo in a tree. A TreeNode knows
    its halo ID, the level in the tree, and its global ID in the
    Arbor that holds it. It also has a list of its ancestors.
    Fields can be queried for it, its progenitor list, and the
    tree beneath.
    """

    _link = None

    def __init__(self, uid, arbor=None, root=False):
        """
        Initialize a TreeNode with at least its halo catalog ID and
        its level in the tree.

        Parameters
        ----------
        uid : int
            Halo catalog ID.
        arbor : optional, Arbor
            The arbor holding this node, kept as a weak proxy.
            If None, ``self.arbor`` is None.
            Default: None.
        root : optional, bool
            If True, mark this node as a root and give it a
            field container.
            Default: False.
        """
        self.uid = uid
        # weakref.proxy(None) raises TypeError, so only wrap a real arbor
        self.arbor = weakref.proxy(arbor) if arbor is not None else None
        if root:
            self.root = -1
            self._field_data = FieldContainer(arbor)
        else:
            self.root = None

    _tree_id = None  # used by CatalogArbor

    @property
    def tree_id(self):
        """
        Return the index of this node in a list of all nodes in the tree.
        """
        if self.is_root:
            return 0
        elif self._link is not None:
            return self._link.tree_id
        else:
            return self._tree_id

    @tree_id.setter
    def tree_id(self, value):
        """
        Set the tree_id manually in CatalogArbors.
        """
        self._tree_id = value

    @property
    def is_root(self):
        """
        Is this node the last in the tree?
        """
        return self.root in [-1, self]

    def find_root(self):
        """
        Find the root node.
        """
        if self.is_root:
            return self
        root = self.root
        if root is not None:
            return root
        return self.walk_to_root()

    def walk_to_root(self):
        """
        Walk descendents until root.
        """
        my_node = self
        while not my_node.is_root:
            if my_node.descendent in (-1, None):
                break
            my_node = my_node.descendent
        return my_node

    def clear_fields(self):
        """
        If a root node, delete field data.
        If not root node, do nothing.
        """
        if not self.is_root:
            return
        self._field_data.clear()

    _descendent = None  # used by CatalogArbor

    @property
    def descendent(self):
        """
        Return the descendent node.
        """
        if self.is_root:
            return None

        # set in CatalogArbor._plant_trees
        if self._descendent is not None:
            return self._descendent

        # conventional Arbor object
        desc_link = self._link.descendent
        return self.arbor._generate_tree_node(self.root, desc_link)

    _ancestors = None  # used by CatalogArbor

    @property
    def ancestors(self):
        """
        Return a generator of ancestor nodes.

        This is a generator, so it is never None; a node without
        ancestors yields nothing.
        """
        self.arbor._grow_tree(self)

        # conventional Arbor object
        if self._link is not None:
            for link in self._link.ancestors:
                yield self.arbor._generate_tree_node(self.root, link)
            return

        # set in CatalogArbor._plant_trees
        if self._ancestors is not None:
            for ancestor in self._ancestors:
                yield ancestor

    _uids = None

    @property
    def uids(self):
        """
        Array of uids for all nodes in the tree.
        """
        if not self.is_root:
            return None
        if self._uids is None:
            self.arbor._build_attr("_uids", self)
        return self._uids

    _desc_uids = None

    @property
    def desc_uids(self):
        """
        Array of descendent uids for all nodes in the tree.
        """
        if not self.is_root:
            return None
        if self._desc_uids is None:
            self.arbor._build_attr("_desc_uids", self)
        return self._desc_uids

    _tree_size = None

    @property
    def tree_size(self):
        """
        Number of nodes in the tree.
        """
        if self._tree_size is not None:
            return self._tree_size
        if self.is_root:
            self.arbor._setup_tree(self)
            # pass back to the arbor to avoid calculating again
            self.arbor._store_node_info(self, '_tree_size')
        else:
            # count the nodes without building a throwaway list
            self._tree_size = sum(1 for _ in self["tree"])
        return self._tree_size

    _link_storage = None

    @property
    def _links(self):
        """
        Array of NodeLink objects with the ancestor/descendent structure.

        This is only used by conventional Arbor objects, i.e., not
        CatalogArbor objects.
        """
        if not self.is_root:
            return None
        if self._link_storage is None:
            self.arbor._build_attr("_link_storage", self)
        return self._link_storage

    def __setitem__(self, key, value):
        """
        Set analysis field value for this node.
        """
        if self.is_root:
            root = self
            tree_id = 0
            # if root, set the value in the arbor field storage
            self.arbor._field_data[key][self._arbor_index] = value
        else:
            root = self.root
            tree_id = self.tree_id
        self.arbor._node_io.get_fields(self, fields=[key],
                                       root_only=False)
        data = root._field_data[key]
        data[tree_id] = value

    def __getitem__(self, key):
        """
        Return field values or tree/prog generators.
        """
        return self.query(key)

    def query(self, key):
        """
        Return field values for this TreeNode, progenitor list, or tree.

        Parameters
        ----------
        key : string or tuple
            If a single string, it can be either a field to be queried or
            one of "forest", "prog", or "tree". If a field, then return
            the value of the field for this TreeNode. If "forest", "prog",
            or "tree", then return an iterator over the TreeNodes in the
            forest, progenitor list, or tree.

            If a tuple, this must be (string, string), where the first
            argument is one of "forest", "prog", or "tree" and the second
            is a field name. The field values for the forest, progenitor
            list, or tree are returned.

        Raises
        ------
        SyntaxError
            If a tuple key does not have exactly two items, if its first
            item is not one of "forest", "prog", or "tree", if its second
            item is not a string, or if a non-tuple key is not a string.

        Examples
        --------
        >>> # virial mass for this halo
        >>> print (my_tree["mvir"].to("Msun/h"))

        >>> # all TreeNodes in the progenitor list
        >>> print (list(my_tree["prog"]))
        >>> # all TreeNodes in the entire tree
        >>> print (list(my_tree["tree"]))

        >>> # virial masses for the progenitor list
        >>> print (my_tree["prog", "mvir"].to("Msun/h"))

        Returns
        -------
        float, ndarray/unyt_array, or iterator of TreeNodes

        """
        arr_types = ("forest", "prog", "tree")
        if isinstance(key, tuple):
            if len(key) != 2:
                raise SyntaxError("Must be either 1 or 2 arguments.")
            ftype, field = key
            if ftype not in arr_types:
                raise SyntaxError("First argument must be one of %s." %
                                  str(arr_types))
            if not isinstance(field, str):
                raise SyntaxError("Second argument must be a string.")

            self.arbor._node_io.get_fields(self, fields=[field],
                                           root_only=False)
            indices = getattr(self, "_%s_field_indices" % ftype)
            data_object = self.find_root()
            return data_object._field_data[field][indices]

        if not isinstance(key, str):
            raise SyntaxError("Single argument must be a string.")

        # return the progenitor list or tree nodes in a list
        if key in arr_types:
            self.arbor._setup_tree(self)
            return getattr(self, "_%s_nodes" % key)

        # return field value for this node
        self.arbor._node_io.get_fields(self, fields=[key],
                                       root_only=self.is_root)
        data_object = self.find_root()
        return data_object._field_data[key][self.tree_id]

    def __repr__(self):
        """
        Call me TreeNode.
        """
        return "TreeNode[%d]" % self.uid

    _ffi = slice(None)

    @property
    def _forest_field_indices(self):
        """
        Return default slice to select the whole forest.
        """
        return self._ffi

    @property
    def _forest_nodes(self):
        """
        An iterator over all TreeNodes in the forest.

        This is different from _tree_nodes in that we don't walk through
        the ancestors lists. We just yield every TreeNode there is.
        """
        self.arbor._grow_tree(self)
        root = self.root
        for link in root._links:
            yield self.arbor._generate_tree_node(self.root, link)

    @property
    def _tree_nodes(self):
        """
        An iterator over all TreeNodes in the tree beneath,
        starting with this TreeNode.

        For internal use only. Use the following instead:

        >>> for my_node in my_tree['tree']:
        ...     print (my_node)

        Examples
        --------
        >>> for my_node in my_tree._tree_nodes:
        ...     print (my_node)
        """
        self.arbor._grow_tree(self)
        yield self
        # ancestors is a generator and can never be None, so no
        # None check is needed before iterating it
        for ancestor in self.ancestors:
            for a_node in ancestor._tree_nodes:
                yield a_node

    _tfi = None

    @property
    def _tree_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the tree beneath, starting with this TreeNode.
        """
        if self._tfi is not None:
            return self._tfi

        self.arbor._grow_tree(self)
        self._tfi = np.array([node.tree_id for node in self._tree_nodes])
        return self._tfi

    @property
    def _prog_nodes(self):
        """
        An iterator over all TreeNodes in the progenitor list,
        starting with this TreeNode.

        For internal use only. Use the following instead:

        >>> for my_node in my_tree['prog']:
        ...     print (my_node)

        Examples
        --------
        >>> for my_node in my_tree._prog_nodes:
        ...     print (my_node)
        """
        self.arbor._grow_tree(self)
        my_node = self
        while my_node is not None:
            yield my_node
            # the arbor's selector picks which ancestor to follow
            ancestors = list(my_node.ancestors)
            if ancestors:
                my_node = my_node.arbor.selector(ancestors)
            else:
                my_node = None

    _pfi = None

    @property
    def _prog_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the progenitor list, starting with this TreeNode.
        """
        if self._pfi is not None:
            return self._pfi

        self.arbor._grow_tree(self)
        self._pfi = np.array([node.tree_id for node in self._prog_nodes])
        return self._pfi

    def save_tree(self, filename=None, fields=None):
        r"""
        Save the tree to a file.

        The saved tree can be re-loaded as an arbor.

        Parameters
        ----------
        filename : optional, string
            Output file keyword. Main header file will be named
            <filename>/<filename>.h5.
            Default: "tree_<uid>".
        fields : optional, list of strings
            The fields to be saved. If not given, all
            fields will be saved.

        Returns
        -------
        filename : string
            The filename of the saved arbor.

        Examples
        --------
        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> # save the first tree
        >>> fn = a[0].save_tree()
        >>> # reload it
        >>> a2 = ytree.load(fn)

        """
        if filename is None:
            filename = "tree_%d" % self.uid

        return self.arbor.save_arbor(
            filename=filename,
            fields=fields,
            trees=[self])
class TreeNode(object):
    """
    Class for objects stored in Arbors.

    Each TreeNode represents a halo in a tree.  A TreeNode knows its
    halo ID, the level in the tree, and its global ID in the Arbor
    that holds it.  It also has a list of its ancestors.  Fields can
    be queried for it, its progenitor list, and the tree beneath.
    """

    def __init__(self, uid, arbor=None, root=False):
        """
        Initialize a TreeNode with at least its halo catalog ID and
        its level in the tree.

        Parameters
        ----------
        uid : int
            The halo catalog ID of this node.
        arbor : optional, Arbor
            The Arbor holding this node. Only a weak proxy is stored
            on the node. If None, ``self.arbor`` is None.
        root : optional, bool
            If True, mark this node as a root (``root = -1``) and
            give it its own field storage.
        """
        self.uid = uid
        # weakref.proxy(None) raises TypeError, which made the
        # documented default of arbor=None unusable.
        self.arbor = None if arbor is None else weakref.proxy(arbor)
        if root:
            self.root = -1
            self.treeid = 0
            self.descendent = None
            self._field_data = FieldContainer(arbor)
        else:
            self.root = None

    @property
    def is_root(self):
        """
        True if ``root`` is -1 or is this node itself.
        """
        return self.root in [-1, self]

    def find_root(self):
        """
        Find the root node.

        Follow ``descendent`` pointers until reaching a node that is
        a root or whose descendent is -1.
        """
        my_node = self
        while not my_node.is_root:
            if my_node.descendent == -1:
                break
            my_node = my_node.descendent
        return my_node

    def clear_fields(self):
        """
        If a root node, delete field data.
        If not root node, do nothing.
        """
        if not self.is_root:
            return
        self._field_data.clear()

    def reset(self):
        """
        Reset all data structures.

        Clear field data, set ``root`` back to -1 (root node) or
        None (any other node), and set the cached index/node arrays
        to None.
        """
        self.clear_fields()
        attrs = ["_tfi", "_tn", "_pfi", "_pn"]
        if self.is_root:
            self.root = -1
            # these arrays are only stored on root nodes
            attrs.extend(["_nodes", "_desc_uids", "_uids"])
        else:
            self.root = None
        for attr in attrs:
            setattr(self, attr, None)

    def add_ancestor(self, ancestor):
        """
        Add another TreeNode to the list of ancestors.

        Parameters
        ----------
        ancestor : TreeNode
            The ancestor TreeNode.
        """
        if self._ancestors is None:
            self._ancestors = []
        self._ancestors.append(ancestor)

    _ancestors = None

    @property
    def ancestors(self):
        """
        The list of ancestors, or None if none have been added.

        If ``root`` is -1, the arbor is asked to grow the tree first.
        """
        if self.root == -1:
            self.arbor._grow_tree(self)
        return self._ancestors

    _uids = None

    @property
    def uids(self):
        """
        The ``_uids`` attribute of a root node, asking the arbor to
        set up the tree first if it is unset.  None for non-roots.
        """
        if not self.is_root:
            return None
        if self._uids is None:
            self.arbor._setup_tree(self)
        return self._uids

    _desc_uids = None

    @property
    def desc_uids(self):
        """
        The ``_desc_uids`` attribute of a root node, asking the arbor
        to set up the tree first if it is unset.  None for non-roots.
        """
        if not self.is_root:
            return None
        if self._desc_uids is None:
            self.arbor._setup_tree(self)
        return self._desc_uids

    _tree_size = None

    @property
    def tree_size(self):
        """
        Number of nodes in the tree beneath this node.

        For a non-root node this is the size of ``self["tree"]``.
        For a root node the arbor is asked to set up the tree if
        ``_tree_size`` is unset.
        """
        if not self.is_root:
            return self["tree"].size
        if self._tree_size is None:
            self.arbor._setup_tree(self)
        return self._tree_size

    _nodes = None

    @property
    def nodes(self):
        """
        The ``_nodes`` attribute of a root node, after asking the
        arbor to grow the tree.  None for non-roots.
        """
        if not self.is_root:
            return None
        self.arbor._grow_tree(self)
        return self._nodes

    def __setitem__(self, key, value):
        """
        Set the value of field ``key`` for this node in the root
        node's field storage.
        """
        if self.root == -1:
            root = self
            treeid = 0
        else:
            root = self.root
            treeid = self.treeid
        # make sure the field array exists before writing into it
        self.arbor._node_io.get_fields(self, fields=[key],
                                       root_only=False)
        data = root._field_data[key]
        data[treeid] = value

    def __getitem__(self, key):
        """
        Shortcut for :meth:`query`.
        """
        return self.query(key)

    def query(self, key):
        """
        Return field values for this TreeNode, progenitor list, or tree.

        Parameters
        ----------
        key : string or tuple
            If a single string, it can be either a field to be queried or
            one of "tree" or "prog".  If a field, then return the value of
            the field for this TreeNode.  If "tree" or "prog", then return
            the array of TreeNodes in the tree or progenitor list.

            If a tuple, this must be (string, string), where the first
            argument is either "tree" or "prog" and the second is a
            field name.  The field values for the tree or the progenitor
            list are returned.  An integer second argument is not
            accepted; to get the nth TreeNode, index the array returned
            for "tree" or "prog" instead.

        Raises
        ------
        SyntaxError
            If a tuple key does not have exactly two elements, its first
            element is not "prog" or "tree", its second element is not a
            string, or a non-tuple key is not a string.

        Examples
        --------
        >>> # virial mass for this halo
        >>> print (my_tree["mvir"].to("Msun/h"))

        >>> # all TreeNodes in the progenitor list
        >>> print (my_tree["prog"])
        >>> # all TreeNodes in the entire tree
        >>> print (my_tree["tree"])

        >>> # virial masses for the progenitor list
        >>> print (my_tree["prog", "mvir"].to("Msun/h"))

        >>> # the 3rd TreeNode in the progenitor list
        >>> print (my_tree["prog"][2])

        Returns
        -------
        float, ndarray/YTArray, or array of TreeNodes
        """
        arr_types = ("prog", "tree")
        if isinstance(key, tuple):
            if len(key) != 2:
                raise SyntaxError(
                    "Must be either 1 or 2 arguments.")
            ftype, field = key
            if ftype not in arr_types:
                raise SyntaxError(
                    "First argument must be one of %s." % str(arr_types))
            if not isinstance(field, str):
                raise SyntaxError("Second argument must be a string.")

            self.arbor._node_io.get_fields(self, fields=[field],
                                           root_only=False)
            indices = getattr(self, "_%s_field_indices" % ftype)
            return self.root._field_data[field][indices]

        else:
            if not isinstance(key, str):
                raise SyntaxError("Single argument must be a string.")

            # return the progenitor list or tree nodes in an array
            if key in arr_types:
                self.arbor._setup_tree(self)
                return getattr(self, "_%s_nodes" % key)

            # return field value for this node
            self.arbor._node_io.get_fields(self, fields=[key],
                                           root_only=self.is_root)
            if self.is_root:
                data_object = self
            else:
                data_object = self.root
            return data_object._field_data[key][self.treeid]

    def __repr__(self):
        """
        Call me TreeNode.
        """
        return "TreeNode[%d]" % self.uid

    _tfi = None

    @property
    def _tree_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the tree beneath, starting with this TreeNode.
        """
        if self._tfi is None:
            self._set_tree_attrs()
        return self._tfi

    _tn = None

    @property
    def _tree_nodes(self):
        """
        Return an array of all TreeNodes in the tree beneath,
        starting with this TreeNode.
        """
        if self._tn is None:
            self._set_tree_attrs()
        return self._tn

    def _set_tree_attrs(self):
        """
        Prepare the TreeNode array and field indices for the tree.
        """
        self.arbor._grow_tree(self)
        tfi = []
        tn = []
        for my_node in self.twalk():
            tfi.append(my_node.treeid)
            tn.append(my_node)
        self._tfi = np.array(tfi)
        self._tn = np.array(tn)

    _pfi = None

    @property
    def _prog_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the progenitor list, starting with this TreeNode.
        """
        if self._pfi is None:
            self._set_prog_attrs()
        return self._pfi

    _pn = None

    @property
    def _prog_nodes(self):
        """
        Return an array of all TreeNodes in the progenitor list,
        starting with this TreeNode.
        """
        if self._pn is None:
            self._set_prog_attrs()
        return self._pn

    def _set_prog_attrs(self):
        """
        Prepare the progenitor list array and field indices.
        """
        self.arbor._grow_tree(self)
        lfi = []
        ln = []
        for my_node in self.pwalk():
            lfi.append(my_node.treeid)
            ln.append(my_node)
        self._pfi = np.array(lfi)
        self._pn = np.array(ln)

    def twalk(self):
        r"""
        An iterator over all TreeNodes in the tree beneath,
        starting with this TreeNode.

        Examples
        --------

        >>> for my_node in my_tree.twalk():
        ...     print (my_node)

        """
        self.arbor._grow_tree(self)
        yield self
        if self.ancestors is None:
            return
        for ancestor in self.ancestors:
            for a_node in ancestor.twalk():
                yield a_node

    def pwalk(self):
        r"""
        An iterator over all TreeNodes in the progenitor list,
        starting with this TreeNode.

        At each step the arbor's ``selector`` picks the next node
        from the current node's ancestors.

        Examples
        --------

        >>> for my_node in my_tree.pwalk():
        ...     print (my_node)

        """
        self.arbor._grow_tree(self)
        my_node = self
        while my_node is not None:
            yield my_node
            if my_node.ancestors is None:
                my_node = None
            else:
                my_node = my_node.arbor.selector(my_node.ancestors)

    def save_tree(self, filename=None, fields=None):
        r"""
        Save the tree to a file.

        The saved tree can be re-loaded as an arbor.

        Parameters
        ----------
        filename : optional, string
            Output file keyword.  Main header file will be named
            <filename>/<filename>.h5.
            Default: "tree_<uid>".
        fields : optional, list of strings
            The fields to be saved.  If not given, all
            fields will be saved.

        Returns
        -------
        filename : string
            The filename of the saved arbor.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> # save the first tree
        >>> fn = a[0].save_tree()
        >>> # reload it
        >>> a2 = ytree.load(fn)

        """
        if filename is None:
            filename = "tree_%d" % self.uid

        return self.arbor.save_arbor(
            filename=filename, fields=fields, trees=[self])
class TreeNode:
    """
    Class for objects stored in Arbors.

    Each TreeNode represents a halo in a tree.  A TreeNode knows its
    halo ID, the level in the tree, and its global ID in the Arbor
    that holds it.  It also has a list of its ancestors.  Fields can
    be queried for it, its progenitor list, and the tree beneath.
    """

    _link = None

    def __init__(self, uid, arbor=None, root=False):
        """
        Initialize a TreeNode with at least its halo catalog ID and
        its level in the tree.

        Parameters
        ----------
        uid : int
            The halo catalog ID of this node.
        arbor : optional, Arbor
            The Arbor holding this node. Only a weak proxy is stored
            on the node. If None, ``self.arbor`` is None.
        root : optional, bool
            If True, mark this node as a root (``root = -1``) and
            give it its own field storage.
        """
        self.uid = uid
        # weakref.proxy(None) raises TypeError, which made the
        # documented default of arbor=None unusable.
        self.arbor = None if arbor is None else weakref.proxy(arbor)
        if root:
            self.root = -1
            self.field_data = FieldContainer(arbor)
        else:
            self.root = None

    _tree_id = None # used by CatalogArbor

    @property
    def tree_id(self):
        """
        Return the index of this node in a list of all nodes in the tree.
        """
        if self.is_root:
            return 0
        elif self._link is not None:
            return self._link.tree_id
        else:
            return self._tree_id

    @tree_id.setter
    def tree_id(self, value):
        """
        Set the tree_id manually in CatalogArbors.
        """
        self._tree_id = value

    @property
    def is_root(self):
        """
        Is this node the last in the tree?

        True if ``root`` is -1 or is this node itself.
        """
        return self.root in [-1, self]

    def find_root(self):
        """
        Find the root node.

        Return this node if it is a root, else its ``root`` pointer
        if set, else walk the descendents.
        """
        if self.is_root:
            return self
        root = self.root
        if root is not None:
            return root
        return self.walk_to_root()

    def walk_to_root(self):
        """
        Walk descendents until root.

        Stops at the first node that is a root or whose descendent
        is -1 or None.
        """
        my_node = self
        while not my_node.is_root:
            if my_node.descendent in (-1, None):
                break
            my_node = my_node.descendent
        return my_node

    def clear_fields(self):
        """
        If a root node, delete field data.
        If not root node, do nothing.
        """
        if not self.is_root:
            return
        self.field_data.clear()

    _descendent = None # used by CatalogArbor

    @property
    def descendent(self):
        """
        Return the descendent node.

        Returns None for a root node, for a node whose link has no
        descendent, and for a node with neither an explicit
        descendent nor a link.
        """
        if self.is_root:
            return None

        # set in CatalogArbor._plant_trees
        if self._descendent is not None:
            return self._descendent

        # conventional Arbor object
        link = self._link
        if link is None:
            # Nothing to follow. Previously this raised AttributeError
            # on None, even though walk_to_root already treats a None
            # descendent as the end of the walk.
            return None
        desc_link = link.descendent
        if desc_link is None:
            return None
        return self.arbor._generate_tree_node(self.root, desc_link)

    _ancestors = None # used by CatalogArbor

    @property
    def ancestors(self):
        """
        Return a generator of ancestor nodes.

        This property is a generator function, so it always returns
        a generator object (possibly empty) and never None.
        """
        self.arbor._grow_tree(self)

        # conventional Arbor object
        if self._link is not None:
            for link in self._link.ancestors:
                yield self.arbor._generate_tree_node(self.root, link)
            return

        # If tree is not setup yet, the ancestor nodes will not have
        # root pointers yet.
        need_root = not self.arbor.is_setup(self)
        if need_root:
            root = self.walk_to_root()

        # set in CatalogArbor._plant_trees
        if self._ancestors is not None:
            for ancestor in self._ancestors:
                if need_root:
                    ancestor.root = root
                yield ancestor

    _uids = None

    @property
    def uids(self):
        """
        Array of uids for all nodes in the tree.

        None for non-root nodes.
        """
        if not self.is_root:
            return None
        if self._uids is None:
            self.arbor._build_attr("_uids", self)
        return self._uids

    _desc_uids = None

    @property
    def desc_uids(self):
        """
        Array of descendent uids for all nodes in the tree.

        None for non-root nodes.
        """
        if not self.is_root:
            return None
        if self._desc_uids is None:
            self.arbor._build_attr("_desc_uids", self)
        return self._desc_uids

    _tree_size = None

    @property
    def tree_size(self):
        """
        Number of nodes in the tree.

        The value is cached in ``_tree_size`` after the first call.
        """
        if self._tree_size is not None:
            return self._tree_size
        if self.is_root:
            self.arbor._setup_tree(self)
            # pass back to the arbor to avoid calculating again
            self.arbor._store_node_info(self, '_tree_size')
        else:
            self._tree_size = len(list(self["tree"]))
        return self._tree_size

    _link_storage = None

    @property
    def _links(self):
        """
        Array of NodeLink objects with the ancestor/descendent structure.

        This is only used by conventional Arbor objects, i.e., not
        CatalogArbor objects.  None for non-root nodes.
        """
        if not self.is_root:
            return None
        if self._link_storage is None:
            self.arbor._build_attr("_link_storage", self)
        return self._link_storage

    def __setitem__(self, key, value):
        """
        Set analysis field value for this node.

        Raises
        ------
        ArborUnsettableField
            If the field's type is not 'analysis' or 'analysis_saved'.
        """
        fi = self.arbor.field_info[key]
        ftype = fi.get('type')
        if ftype not in ['analysis', 'analysis_saved']:
            raise ArborUnsettableField(key, self.arbor)

        vector_fieldname = fi.get("vector_fieldname", None)
        has_vector_field = vector_fieldname is not None

        if self.is_root:
            root = self
            tree_id = 0
            # if root, set the value in the arbor field storage
            self.arbor[key][self._arbor_index] = value
            # drop the stored vector field for this component
            if has_vector_field and vector_fieldname in self.arbor.field_data:
                del self.arbor.field_data[vector_fieldname]
        else:
            root = self.root
            tree_id = self.tree_id
        self.arbor._node_io.get_fields(self, fields=[key],
                                       root_only=False)
        data = root.field_data[key]
        data[tree_id] = value
        # drop the stored vector field for this component
        if has_vector_field and vector_fieldname in root.field_data:
            del root.field_data[vector_fieldname]

    def __getitem__(self, key):
        """
        Return field values or tree/prog generators.
        """
        return self.query(key)

    def query(self, key):
        """
        Return field values for this TreeNode, progenitor list, or tree.

        Parameters
        ----------
        key : string or tuple
            If a single string, it can be either a field to be queried or
            one of "forest", "tree", or "prog".  If a field, then return
            the value of the field for this TreeNode.  If "forest", "tree",
            or "prog", then return a generator of the TreeNodes in the
            forest, tree, or progenitor list.

            If a tuple, this must be (string, string), where the first
            argument is one of "forest", "tree", or "prog" and the second
            is a field name.  The field values for that selection are
            returned.  An integer second argument is not accepted; use
            ``get_node`` to get the nth TreeNode.

        Raises
        ------
        SyntaxError
            If a tuple key does not have exactly two elements, its first
            element is not a valid selector, its second element is not a
            string, or a non-tuple key is not a string.

        Examples
        --------
        >>> # virial mass for this halo
        >>> print (my_tree["mvir"].to("Msun/h"))

        >>> # all TreeNodes in the progenitor list
        >>> print (list(my_tree["prog"]))
        >>> # all TreeNodes in the entire tree
        >>> print (list(my_tree["tree"]))

        >>> # virial masses for the progenitor list
        >>> print (my_tree["prog", "mvir"].to("Msun/h"))

        >>> # the 3rd TreeNode in the progenitor list
        >>> print (my_tree.get_node("prog", 2))

        Returns
        -------
        float, ndarray/unyt_array, or generator of TreeNodes
        """
        arr_types = ("forest", "prog", "tree")
        if isinstance(key, tuple):
            if len(key) != 2:
                raise SyntaxError("Must be either 1 or 2 arguments.")
            ftype, field = key
            if ftype not in arr_types:
                raise SyntaxError(
                    f"First argument must be one of {str(arr_types)}.")
            if not isinstance(field, str):
                raise SyntaxError("Second argument must be a string.")

            self.arbor._node_io.get_fields(self, fields=[field],
                                           root_only=False)
            indices = getattr(self, f"_{ftype}_field_indices")
            data_object = self.find_root()
            return data_object.field_data[field][indices]

        else:
            if not isinstance(key, str):
                raise SyntaxError("Single argument must be a string.")

            # return a generator over the forest, tree, or prog nodes
            if key in arr_types:
                self.arbor._setup_tree(self)
                return getattr(self, f"_{key}_nodes")

            # return field value for this node
            self.arbor._node_io.get_fields(self, fields=[key],
                                           root_only=self.is_root)
            data_object = self.find_root()
            return data_object.field_data[key][self.tree_id]

    def __repr__(self):
        """
        Call me TreeNode.
        """
        return f"TreeNode[{self.uid}]"

    def get_node(self, selector, index):
        """
        Get a single TreeNode from a tree.

        Use this to get the nth TreeNode from a forest, tree, or
        progenitor list for which the calling TreeNode is the head.

        Parameters
        ----------
        selector : str ("forest", "tree", or "prog")
            The tree selector from which to get the TreeNode. This
            should be "forest", "tree", or "prog".

        index : int
            The index of the desired TreeNode in the forest, tree,
            or progenitor list.

        Returns
        -------
        node: :class:`~ytree.data_structures.tree_node.TreeNode`

        Raises
        ------
        RuntimeError
            If no field indices exist for the given selector.

        Examples
        --------
        >>> import ytree
        >>> a = ytree.load("tiny_ctrees/locations.dat")
        >>> my_tree = a[0]
        >>> # get 6th TreeNode in the progenitor list
        >>> my_node = my_tree.get_node('prog', 5)

        """
        self.arbor._setup_tree(self)
        self.arbor._grow_tree(self)
        indices = getattr(self, f"_{selector}_field_indices", None)
        if indices is None:
            raise RuntimeError("Bad selector.")

        my_link = self.root._links[indices][index]
        return self.arbor._generate_tree_node(self.root, my_link)

    def get_leaf_nodes(self, selector=None):
        """
        Get all leaf nodes from the tree of which this is the head.

        This returns a generator of all leaf nodes belonging to this tree.
        A leaf node is a node that has no ancestors.

        Parameters
        ----------
        selector : optional, str ("forest", "tree", or "prog")
            The tree selector from which leaf nodes will be found. If none
            given, this will be set to "forest" if the calling node is a
            root node and "tree" otherwise.

        Returns
        -------
        leaf_nodes : a generator of
            :class:`~ytree.data_structures.tree_node.TreeNode` objects.

        Examples
        --------
        >>> import ytree
        >>> a = ytree.load("tiny_ctrees/locations.dat")
        >>> my_tree = a[0]
        >>> for leaf in my_tree.get_leaf_nodes():
        ...     print (leaf["mass"])

        """
        if selector is None:
            if self.is_root:
                selector = "forest"
            else:
                selector = "tree"

        uids = self[selector, "uid"]
        desc_uids = self[selector, "desc_uid"]
        # a leaf is a node whose uid is nobody's descendent uid;
        # np.isin replaces np.in1d, deprecated in NumPy 2.0
        lids = np.where(~np.isin(uids, desc_uids))[0]
        for lid in lids:
            yield self.get_node(selector, lid)

    def get_root_nodes(self):
        """
        Get all root nodes from the forest to which this node belongs.

        This returns a generator of all root nodes in the forest. A root
        node is a node that has no descendents.

        Returns
        -------
        root_nodes : a generator of
            :class:`~ytree.data_structures.tree_node.TreeNode` objects.

        Examples
        --------
        >>> import ytree
        >>> a = ytree.load("consistent_trees_hdf5/soa/forest.h5",
        ...                access="forest")
        >>> my_tree = a[0]
        >>> for root in my_tree.get_root_nodes():
        ...     print (root["mass"])

        """
        selector = "forest"

        desc_uids = self[selector, "desc_uid"]
        # a desc_uid of -1 marks a node with no descendent
        rids = np.where(desc_uids == -1)[0]
        for rid in rids:
            yield self.get_node(selector, rid)

    _ffi = slice(None)

    @property
    def _forest_field_indices(self):
        """
        Return default slice to select the whole forest.
        """
        return self._ffi

    @property
    def _forest_nodes(self):
        """
        An iterator over all TreeNodes in the forest.

        This is different from _tree_nodes in that we don't walk through
        the ancestors lists. We just yield every TreeNode there is.
        """
        self.arbor._grow_tree(self)
        root = self.root
        for link in root._links:
            yield self.arbor._generate_tree_node(self.root, link)

    @property
    def _tree_nodes(self):
        """
        An iterator over all TreeNodes in the tree beneath,
        starting with this TreeNode.

        For internal use only. Use the following instead:

        >>> for my_node in my_tree['tree']:
        ...     print (my_node)

        Examples
        --------

        >>> for my_node in my_tree._tree_nodes:
        ...     print (my_node)

        """
        self.arbor._grow_tree(self)
        yield self
        # self.ancestors is always a generator (never None), so no
        # None check is needed before iterating.
        for ancestor in self.ancestors:
            for a_node in ancestor._tree_nodes:
                yield a_node

    _tfi = None

    @property
    def _tree_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the tree beneath, starting with this TreeNode.

        The array is cached in ``_tfi`` after the first call.
        """
        if self._tfi is not None:
            return self._tfi

        self.arbor._grow_tree(self)
        self._tfi = np.array([node.tree_id for node in self._tree_nodes])
        return self._tfi

    @property
    def _prog_nodes(self):
        """
        An iterator over all TreeNodes in the progenitor list,
        starting with this TreeNode.

        At each step the arbor's ``selector`` picks the next node
        from the current node's ancestors.

        For internal use only. Use the following instead:

        >>> for my_node in my_tree['prog']:
        ...     print (my_node)

        Examples
        --------

        >>> for my_node in my_tree._prog_nodes:
        ...     print (my_node)

        """
        self.arbor._grow_tree(self)
        my_node = self
        while my_node is not None:
            yield my_node
            # materialize so an empty ancestor generator is falsy
            ancestors = list(my_node.ancestors)
            if ancestors:
                my_node = my_node.arbor.selector(ancestors)
            else:
                my_node = None

    _pfi = None

    @property
    def _prog_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the progenitor list, starting with this TreeNode.

        The array is cached in ``_pfi`` after the first call.
        """
        if self._pfi is not None:
            return self._pfi

        self.arbor._grow_tree(self)
        self._pfi = np.array([node.tree_id for node in self._prog_nodes])
        return self._pfi

    def save_tree(self, filename=None, fields=None):
        r"""
        Save the tree to a file.

        The saved tree can be re-loaded as an arbor.

        Parameters
        ----------
        filename : optional, string
            Output file keyword.  Main header file will be named
            <filename>/<filename>.h5.
            Default: "tree_<uid>".
        fields : optional, list of strings
            The fields to be saved.  If not given, all
            fields will be saved.

        Returns
        -------
        filename : string
            The filename of the saved arbor.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> # save the first tree
        >>> fn = a[0].save_tree()
        >>> # reload it
        >>> a2 = ytree.load(fn)

        """
        if filename is None:
            filename = f"tree_{self.uid}"

        return self.arbor.save_arbor(
            filename=filename, fields=fields, trees=[self])