Beispiel #1
0
 def __init__(self, uid, arbor=None, root=False):
     """
     Initialize a TreeNode with at least its halo catalog ID and
     its level in the tree.

     Parameters
     ----------
     uid : int
         The halo's unique ID.
     arbor : optional, Arbor
         The Arbor holding this node.  Only a weak proxy is stored, so
         the node does not keep the Arbor alive.  An object that is
         already a weak proxy is stored as is; None is stored as None.
         Default: None.
     root : optional, bool
         If True, set root to -1 and give the node its own field
         data container.
         Default: False.
     """
     self.uid = uid
     if arbor is None or isinstance(arbor, weakref.ProxyTypes):
         # weakref.proxy(None) raises TypeError, and proxy objects
         # cannot themselves be weakly referenced, so keep these as-is.
         self.arbor = arbor
     else:
         self.arbor = weakref.proxy(arbor)
     if root:
         self.root = -1
         self._field_data = FieldContainer(arbor)
     else:
         self.root = None
Beispiel #2
0
    def __init__(self, filename):
        """
        Initialize an Arbor given an input file.

        Parameters
        ----------
        filename : string
            Path to the input file.  Its basename is also kept.

        The setup steps below run in a fixed order.  Later steps
        presumably rely on state created by earlier ones (for example,
        unit setup likely needs parameters read from the file), so do
        not reorder them.
        """

        self.filename = filename
        self.basename = os.path.basename(filename)
        # Read header/parameter information from the input file.
        self._parse_parameter_file()
        self._set_units()
        # Container for field data held by the Arbor itself.
        self._field_data = FieldContainer(self)
        # I/O handlers are built from class attributes so that each
        # subclass can plug in its own implementations.
        self._node_io = self._tree_field_io_class(self)
        self._root_io = self._root_field_io_class(self)
        self._get_data_files()
        self._setup_fields()
        self._set_default_selector()
Beispiel #3
0
 def _setup_fields(self):
     """
     Prepare field storage and run the field_info setup passes.

     A fresh field data container and empty derived/analysis field
     lists are created first.  Then the field_info setup methods are
     called in a fixed order: known fields, aliases, derived fields,
     and vector fields.
     """
     self.field_data = FieldContainer(self)
     self.derived_field_list, self.analysis_field_list = [], []
     # Order matters: later passes may refer to fields defined by
     # earlier ones.
     for setup_step in ("setup_known_fields",
                        "setup_aliases",
                        "setup_derived_fields",
                        "setup_vector_fields"):
         getattr(self.field_info, setup_step)()
Beispiel #4
0
class Arbor(object, metaclass=RegisteredArbor):
    """
    Base class for all Arbor classes.

    Loads a merger-tree output file or a series of halo catalogs
    and creates trees, stored in an array in
    :func:`~ytree.data_structures.arbor.Arbor.trees`.
    Arbors can be saved in a universal format with
    :func:`~ytree.data_structures.arbor.Arbor.save_arbor`.  It also
    provides some convenience functions for creating YTArrays and
    YTQuantities, and a cosmology calculator.

    Subclasses must implement ``_parse_parameter_file`` and
    ``_plant_trees`` (both raise NotImplementedError here).  They may
    override the ``_get_data_files`` and ``_node_io_loop_*`` hooks.
    """

    _field_info_class = FieldInfoContainer
    _root_field_io_class = DefaultRootFieldIO
    _tree_field_io_class = TreeFieldIO

    def __init__(self, filename):
        """
        Initialize an Arbor given an input file.

        The setup steps run in a fixed order, and later steps rely on
        state created by earlier ones.  For example, ``_set_units``
        reads hubble_constant, omega_matter and omega_lambda, which
        presumably are set by ``_parse_parameter_file``.  Do not
        reorder them.
        """

        self.filename = filename
        self.basename = os.path.basename(filename)
        self._parse_parameter_file()
        self._set_units()
        self._field_data = FieldContainer(self)
        self._node_io = self._tree_field_io_class(self)
        self._root_io = self._root_field_io_class(self)
        self._get_data_files()
        self._setup_fields()
        self._set_default_selector()

    def _get_data_files(self):
        """
        Get all files that hold field data and make them known
        to the i/o system.
        """
        pass

    def _parse_parameter_file(self):
        """
        Read relevant parameters from parameter file or file header
        and detect fields.
        """
        raise NotImplementedError

    def _plant_trees(self):
        """
        Create the list of root tree nodes.
        """
        raise NotImplementedError

    def is_setup(self, tree_node):
        """
        Return True if arrays of uids and descendent uids have
        been read in. Setup has also completed if tree is already
        grown.
        """
        return self.is_grown(tree_node) or \
          tree_node._uids is not None

    def _setup_tree(self, tree_node, **kwargs):
        """
        Create arrays of uids and desc_uids and attach them to the
        root node.

        Extra keyword arguments are passed to the node i/o reader.
        """
        # skip if this is not a root or if already setup
        if self.is_setup(tree_node):
            return

        idtype      = np.int64
        fields, _ = \
          self.field_info.resolve_field_dependencies(["uid", "desc_uid"])
        halo_id_f, desc_id_f = fields
        dtypes      = {halo_id_f: idtype, desc_id_f: idtype}
        field_data  = self._node_io._read_fields(tree_node, fields,
                                                 dtypes=dtypes, **kwargs)
        tree_node._uids      = field_data[halo_id_f]
        tree_node._desc_uids = field_data[desc_id_f]
        tree_node._tree_size = tree_node._uids.size

    def is_grown(self, tree_node):
        """
        Return True if a tree has been fully assembled, i.e.,
        the hierarchy of ancestor tree nodes has been built.
        """
        return tree_node.root != -1

    def _grow_tree(self, tree_node, **kwargs):
        """
        Create an array of TreeNodes hanging off the root node
        and assemble the tree structure.
        """
        # skip this if not a root or if already grown
        if self.is_grown(tree_node):
            return

        self._setup_tree(tree_node, **kwargs)
        nhalos   = tree_node.uids.size
        # np.object was removed in NumPy 1.24; the builtin object is
        # the equivalent dtype.
        nodes    = np.empty(nhalos, dtype=object)
        nodes[0] = tree_node
        for i in range(1, nhalos):
            nodes[i] = TreeNode(tree_node.uids[i], arbor=self)
        tree_node._nodes = nodes

        # Add tree information to nodes
        uidmap = {}
        for i, node in enumerate(nodes):
            node.treeid = i
            node.root   = tree_node
            uidmap[tree_node.uids[i]] = i

        # Link ancestor/descendents
        # Separate loop for trees like lhalotree where descendent
        # can follow in order
        for i, node in enumerate(nodes):
            descid      = tree_node.desc_uids[i]
            if descid != -1:
                desc = nodes[uidmap[descid]]
                desc.add_ancestor(node)
                node.descendent = desc

    def _node_io_loop(self, func, *args, **kwargs):
        """
        Call the provided function over a list of nodes.

        If possible, group nodes by common data files to speed
        things up.  This should work like __iter__, except we call
        a function instead of yielding.

        Parameters
        ----------
        func : function
            Function to be called on an array of nodes.
        pbar : optional, string or yt.funcs.TqdmProgressBar
            A progress bar to be updated with each iteration.
            If a string, a progress bar will be created and the
            finish function will be called. If a progress bar is
            provided, the finish function will not be called.
            Default: None (no progress bar).
        root_nodes : optional, list or array of root TreeNodes
            Nodes over which the function will be called.
            If None, the list will be self.trees (i.e., all
            root_nodes).
            Default: None.
        store : optional, string
            If not None, any return value captured from the function
            will be stored in an attribute with this name associated
            with the TreeNode.
            Default: None.

        Any other positional and keyword arguments are passed on to
        func for every node.
        """

        pbar = kwargs.pop("pbar", None)
        root_nodes = kwargs.pop("root_nodes", None)
        if root_nodes is None:
            root_nodes = self.trees
        store = kwargs.pop("store", None)
        data_files, node_list = self._node_io_loop_prepare(root_nodes)
        # len() works for both Python lists and 1-D node arrays.
        nnodes = sum(len(nodes) for nodes in node_list)

        finish = True
        if pbar is None:
            pbar = fake_pbar("", nnodes)
        elif not isinstance(pbar, TqdmProgressBar):
            pbar = get_pbar(pbar, nnodes)
        else:
            finish = False

        for data_file, nodes in zip(data_files, node_list):
            self._node_io_loop_start(data_file)
            for node in nodes:
                rval = func(node, *args, **kwargs)
                if store is not None:
                    setattr(node, store, rval)
                pbar.update(1)
            self._node_io_loop_finish(data_file)

        if finish:
            pbar.finish()

    def _node_io_loop_start(self, data_file):
        """
        Hook called before the nodes of one data file are processed by
        _node_io_loop or __iter__.  Does nothing by default.
        """
        pass

    def _node_io_loop_finish(self, data_file):
        """
        Hook called after the nodes of one data file are processed by
        _node_io_loop or __iter__.  Does nothing by default.
        """
        pass

    def _node_io_loop_prepare(self, root_nodes):
        """
        This is called at the beginning of _node_io_loop.

        In different frontends, this can be used to group nodes by
        common data files.

        Return
        ------
        list of data files and a list of node arrays

        Each data file corresponds to an array of nodes.
        """

        return [None], [root_nodes]

    def __iter__(self):
        """
        Iterate over all items in the tree list.

        If possible, group nodes by common data files to speed
        things up.
        """

        data_files, node_list = self._node_io_loop_prepare(self.trees)

        for data_file, nodes in zip(data_files, node_list):
            self._node_io_loop_start(data_file)
            for node in nodes:
                yield node
            self._node_io_loop_finish(data_file)

    _trees = None
    @property
    def trees(self):
        """
        Array containing all trees in the arbor.

        Created on first access by calling _plant_trees.
        """
        if self._trees is None:
            self._plant_trees()
        return self._trees

    def __repr__(self):
        return self.basename

    def __getitem__(self, key):
        return self.query(key)

    def query(self, key):
        """
        If given a string, return an array of field values for the
        roots of all trees.
        If given an integer, return a tree from the list of trees.

        Analysis fields are removed from the field cache when
        returned, so they are re-read on every query.
        """
        if isinstance(key, str):
            if key in ("tree", "prog"):
                raise SyntaxError("Argument must be a field or integer.")
            self._root_io.get_fields(self, fields=[key])
            if self.field_info[key].get("type") == "analysis":
                return self._field_data.pop(key)
            return self._field_data[key]
        return self.trees[key]

    def __len__(self):
        """
        Return length of tree list.
        """
        return self.trees.size

    _field_info = None
    @property
    def field_info(self):
        """
        A dictionary containing information for each available field.
        """
        if self._field_info is None and \
          self._field_info_class is not None:
            self._field_info = self._field_info_class(self)
        return self._field_info

    @property
    def size(self):
        """
        Return length of tree list.
        """
        return self.trees.size

    _unit_registry = None
    @property
    def unit_registry(self):
        """
        Unit system registry, created on first access.
        """
        if self._unit_registry is None:
            self._unit_registry = UnitRegistry()
        return self._unit_registry

    @unit_registry.setter
    def unit_registry(self, value):
        self._unit_registry = value
        # cached array/quantity factories are bound to the old registry
        self._arr = None
        self._quan = None

    _hubble_constant = None
    @property
    def hubble_constant(self):
        """
        Value of the Hubble parameter.
        """
        return self._hubble_constant

    @hubble_constant.setter
    def hubble_constant(self, value):
        self._hubble_constant = value
        # reset the unit registry lut while preserving other changes
        self.unit_registry = UnitRegistry.from_json(
            self.unit_registry.to_json())
        self.unit_registry.modify("h", self.hubble_constant)

    _box_size = None
    @property
    def box_size(self):
        """
        The simulation box size.
        """
        return self._box_size

    @box_size.setter
    def box_size(self, value):
        self._box_size = value
        # set unitary as soon as we know the box size
        self.unit_registry.add(
            "unitary", float(self.box_size.in_base()), length)

    def _setup_fields(self):
        """
        Reset the derived and analysis field lists and run the
        field_info setup passes in order: known fields, aliases,
        derived fields, and vector fields.
        """
        self.derived_field_list = []
        self.analysis_field_list = []
        self.field_info.setup_known_fields()
        self.field_info.setup_aliases()
        self.field_info.setup_derived_fields()
        self.field_info.setup_vector_fields()

    def _set_units(self):
        """
        Set "cm" units for explicitly comoving.
        Note, we are using comoving units all the time since
        we are dealing with data at multiple redshifts.

        Adds comoving versions of the m, pc, AU and au length units
        (e.g., "pccm") and creates the cosmology calculator.
        """
        # Go through the property so the registry is created if no
        # earlier step has done so yet.
        registry = self.unit_registry
        for my_unit in ["m", "pc", "AU", "au"]:
            new_unit = "%scm" % my_unit
            registry.add(
                new_unit, registry.lut[my_unit][0],
                length, registry.lut[my_unit][3])

        self.cosmology = Cosmology(
            hubble_constant=self.hubble_constant,
            omega_matter=self.omega_matter,
            omega_lambda=self.omega_lambda,
            unit_registry=self.unit_registry)

    def set_selector(self, selector, *args, **kwargs):
        r"""
        Sets the tree node selector to be used.

        This sets the manner in which halo progenitors are
        chosen from a list of ancestors.  The most obvious example
        is to select the most massive ancestor.

        Parameters
        ----------
        selector : string
            Name of the selector to be used.

        Any additional arguments and keywords to be provided to
        the selector function should follow.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> a.set_selector("max_field_value", "mass")

        """
        self.selector = tree_node_selector_registry.find(
            selector, *args, **kwargs)

    _arr = None
    @property
    def arr(self):
        """
        Create a YTArray using the Arbor's unit registry.
        """
        if self._arr is not None:
            return self._arr
        self._arr = functools.partial(YTArray,
                                      registry=self.unit_registry)
        return self._arr

    _quan = None
    @property
    def quan(self):
        """
        Create a YTQuantity using the Arbor's unit registry.
        """
        if self._quan is not None:
            return self._quan
        self._quan = functools.partial(YTQuantity,
                                       registry=self.unit_registry)
        return self._quan

    def _set_default_selector(self):
        """
        Set the default tree node selector as maximum mass.
        """
        self.set_selector("max_field_value", "mass")

    def select_halos(self, criteria, trees=None, select_from="tree",
                     fields=None):
        """
        Select halos from the arbor based on a set of criteria given as a string.


        Parameters
        ----------

        criteria: string
            A string that will eval to a Numpy-like selection operation
            performed on a TreeNode object called "tree".
            Example: 'tree["tree", "redshift"] > 1'
            The string is passed to eval(), so never use criteria
            from an untrusted source.
        trees : optional, list or array of TreeNodes
            A list or array of TreeNode objects in which to search. If none given,
            the search is performed over the full arbor.
        select_from : optional, "tree" or "prog"
            Determines whether to perform the search over the full tree or just
            the main progenitors. Note, the value given must be consistent with
            what appears in the criteria string. For example, a criteria
            string of 'tree["tree", "redshift"] > 1' cannot be used when setting
            select_from to "prog".
            Default: "tree".
        fields : optional, list of strings
            Use to provide a list of fields required by the criteria evaluation.
            If given, fields will be preloaded in an optimized way and the search
            will go faster.
            Default: None.

        Returns
        -------

        halos : array of TreeNodes
            A flat array of all TreeNodes meeting the criteria.

        Raises
        ------
        SyntaxError
            If select_from is not "tree" or "prog".
        RuntimeError
            If the criteria result does not match the size of the
            selected tree array.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tree_0_0_0.dat")
        >>> halos = a.select_halos('tree["tree", "redshift"] > 1',
        ...                        fields=["redshift"])
        >>>
        >>> halos = a.select_halos('tree["prog", "mass"].to("Msun") >= 1e10',
        ...                        select_from="prog", fields=["mass"])

        """

        if select_from not in ["tree", "prog"]:
            raise SyntaxError(
                "Keyword \"select_from\" must be either \"tree\" or \"prog\".")

        if trees is None:
            trees = self.trees

        if fields is None:
            fields = []

        self._node_io_loop(self._setup_tree, root_nodes=trees,
                           pbar="Setting up trees")
        if fields:
            self._node_io_loop(
                self._node_io.get_fields,
                pbar="Getting fields",
                root_nodes=trees, fields=fields, root_only=False)


        halos = []
        # size the progress bar by the trees actually searched
        pbar = get_pbar("Selecting halos", len(trees))
        for tree in trees:
            # NOTE: criteria refers to the loop variable "tree" and is
            # run through eval(); it must never come from untrusted input.
            my_filter = np.asarray(eval(criteria))
            if my_filter.size != tree[select_from].size:
                raise RuntimeError(
                    ("Filter array and tree array sizes do not match. " +
                     "Make sure select_from (\"%s\") matches criteria (\"%s\").") %
                    (select_from, criteria))
            halos.extend(tree[select_from][my_filter])
            pbar.update(1)
        pbar.finish()
        return np.array(halos)

    def add_analysis_field(self, name, units):
        r"""
        Add an empty field to be filled by analysis operations.

        Parameters
        ----------
        name : string
            Field name.
        units : string
            Field units.

        Raises
        ------
        ArborFieldAlreadyExists
            If a field with this name already exists.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tree_0_0_0.dat")
        >>> a.add_analysis_field("robots", "Msun * kpc")
        >>> # Set field for some halo.
        >>> a[0]["tree"][7]["robots"] = 1979.816
        """

        if name in self.field_info:
            raise ArborFieldAlreadyExists(name, arbor=self)

        self.analysis_field_list.append(name)
        self.field_info[name] = {"type": "analysis",
                                 "units": units}

    def _remove_field_for_override(self, name):
        """
        Warn that an existing field is being overridden and remove it
        from the field list that holds it.

        The field_info entry is left for the caller to overwrite.
        """
        ftype = self.field_info[name].get("type", "on-disk")
        if ftype in ["alias", "derived"]:
            fl = self.derived_field_list
        elif ftype == "analysis":
            # analysis fields live in their own list (see
            # add_analysis_field); looking in field_list would fail
            fl = self.analysis_field_list
        else:
            fl = self.field_list
        mylog.warning(
            ("Overriding field \"%s\" that already " +
             "exists as %s field.") % (name, ftype))
        if name in fl:
            fl.remove(name)

    def add_alias_field(self, alias, field, units=None,
                        force_add=True):
        r"""
        Add a field as an alias to another field.

        Parameters
        ----------
        alias : string
            Alias name.
        field : string
            The field to be aliased.
        units : optional, string
            Units in which the field will be returned.  If not given,
            the units of the aliased field are used.
        force_add : optional, bool
            If True, add field even if it already exists and warn the
            user and raise an exception if dependencies do not exist.
            If False, silently do nothing in both instances.
            Default: True.

        Raises
        ------
        ArborFieldDependencyNotFound
            If field does not exist and force_add is True.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tree_0_0_0.dat")
        >>> # "Mvir" exists on disk
        >>> a.add_alias_field("mass", "Mvir", units="Msun")
        >>> print (a["mass"])

        """

        exists = alias in self.field_info
        if exists and not force_add:
            return

        # Check the dependency before touching any field lists so a
        # failed call leaves the existing field intact.
        if field not in self.field_info:
            if force_add:
                raise ArborFieldDependencyNotFound(
                    field, alias, arbor=self)
            return

        if exists:
            self._remove_field_for_override(alias)

        if units is None:
            units = self.field_info[field].get("units")
        self.derived_field_list.append(alias)
        self.field_info[alias] = \
          {"type": "alias", "units": units,
           "dependencies": [field]}
        # Record every alias of the field, not just the first one.
        field_entry = self.field_info[field]
        if "aliases" not in field_entry:
            field_entry["aliases"] = []
        if alias not in field_entry["aliases"]:
            field_entry["aliases"].append(alias)

    def add_derived_field(self, name, function,
                          units=None, description=None,
                          vector_field=False, force_add=True):
        r"""
        Add a field that is a function of other fields.

        Parameters
        ----------
        name : string
            Field name.
        function : callable
            The function to be called to generate the field.
            This function should take two arguments, the
            arbor and the data structure containing the
            dependent fields.  See below for an example.
        units : optional, string
            The units in which the field will be returned.
        description : optional, string
            A short description of the field.
        vector_field: optional, bool
            If True, field is an xyz vector.
            Default: False.
        force_add : optional, bool
            If True, add field even if it already exists and warn the
            user and raise an exception if dependencies do not exist.
            If False, silently do nothing in both instances.
            Default: True.

        Raises
        ------
        RuntimeError
            If calling function raises TypeError, usually because it
            uses the old one-argument signature.
        ArborFieldDependencyNotFound
            If a dependency is missing and force_add is True.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tree_0_0_0.dat")
        >>> def _redshift(field, data):
        ...     return 1. / data["scale"] - 1
        ...
        >>> a.add_derived_field("redshift", _redshift)
        >>> print (a["redshift"])

        """

        exists = name in self.field_info
        if exists and not force_add:
            return

        if units is None:
            units = ""
        info = {"name": name, "type": "derived", "function": function,
                "units": units, "vector_field": vector_field,
                "description": description}

        # Call the function on a placeholder container to validate it
        # and record which fields it depends on.
        fc = FakeFieldContainer(self, name=name)
        try:
            rv = function(info, fc)
        except TypeError as e:
            raise RuntimeError(
"""

Field function syntax in ytree has changed. Field functions must
now take two arguments, as in the following:
def my_field(field, data):
    return data['mass']

Check the TypeError exception above for more details.
""") from e

        except ArborFieldDependencyNotFound:
            if force_add:
                raise
            return

        rv.convert_to_units(units)
        info["dependencies"] = list(fc.keys())

        # Only replace an existing field once the new definition has
        # been validated.
        if exists:
            self._remove_field_for_override(name)
        self.derived_field_list.append(name)
        self.field_info[name] = info

    @classmethod
    def _is_valid(cls, *args, **kwargs):
        """
        Check if input file works with a specific Arbor class.
        This is used with :func:`~ytree.data_structures.arbor.load` function.
        """
        return False

    def save_arbor(self, filename="arbor", fields=None, trees=None,
                   max_file_size=524288):
        r"""
        Save the arbor to a file.

        The saved arbor can be re-loaded as an arbor.

        Parameters
        ----------
        filename : optional, string
            Output file keyword.  If filename ends in ".h5",
            the main header file will be just that.  If not,
            filename will be <filename>/<basename>.h5.
            Default: "arbor".
        fields : optional, list of strings
            The fields to be saved.  If not given (or "all"), all
            fields will be saved, using a field's aliases in place of
            the field itself when it has any.  Otherwise, "uid" and
            "desc_uid" are added if missing.  The list passed in is
            not modified.
        trees : optional, list or array of TreeNodes
            The trees to be saved.  Nodes that are not roots are saved
            as roots of new trees (their desc_uid is written as -1).
            If not given, all trees are saved.
            Default: None.
        max_file_size : optional, int
            Maximum number of nodes per data file.  Trees are never
            split across files, so a single tree larger than this gets
            a data file of its own.
            Default: 524288.

        Returns
        -------
        header_filename : string
            The filename of the saved arbor.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> fn = a.save_arbor()
        >>> # reload it
        >>> a2 = ytree.load(fn)

        """

        if trees is None:
            all_trees = True
            trees = self.trees
            roots = trees
        else:
            all_trees = False
            # assemble unique tree roots for getting fields
            trees = np.asarray(trees)
            roots = []
            root_uids = set()
            for tree in trees:
                if tree.root == -1:
                    my_root = tree
                else:
                    my_root = tree.root
                if my_root.uid not in root_uids:
                    roots.append(my_root)
                    root_uids.add(my_root.uid)
            roots = np.array(roots)
            del root_uids

        if fields is None or (isinstance(fields, str) and fields == "all"):
            # If a field has an alias, get that instead.
            fields = []
            for field in self.field_list + self.analysis_field_list:
                fields.extend(
                    self.field_info[field].get("aliases", [field]))
        else:
            # work on a copy so the caller's list is left untouched
            fields = list(fields)
            fields.extend([f for f in ["uid", "desc_uid"]
                           if f not in fields])

        ds = {}
        for attr in ["hubble_constant",
                     "omega_matter",
                     "omega_lambda"]:
            if hasattr(self, attr):
                ds[attr] = getattr(self, attr)
        extra_attrs = {"box_size": self.box_size,
                       "arbor_type": "YTreeArbor",
                       "unit_registry_json": self.unit_registry.to_json()}

        self._node_io_loop(self._setup_tree, root_nodes=roots,
                           pbar="Setting up trees")
        if all_trees:
            self._root_io.get_fields(self, fields=fields)

        # Determine file layout: pack whole trees into files of at most
        # max_file_size nodes.  A new file is only started when the
        # current one already holds a tree, so an oversized first tree
        # does not produce an empty file.
        nn = 0 # node count in the current file
        nt = 0 # tree count in the current file
        nnodes = []
        ntrees = []
        tree_size = np.array([tree.tree_size for tree in trees])
        for ts in tree_size:
            if nt > 0 and nn + ts > max_file_size:
                nnodes.append(nn)
                ntrees.append(nt)
                nn = 0
                nt = 0
            nn += ts
            nt += 1
        if nn > 0:
            nnodes.append(nn)
            ntrees.append(nt)
        nfiles = len(nnodes)
        nnodes = np.array(nnodes)
        ntrees = np.array(ntrees)
        tree_end_index   = ntrees.cumsum()
        tree_start_index = tree_end_index - ntrees

        # write header file
        fieldnames = [field.replace("/", "_") for field in fields]
        myfi = {}
        rdata = {}
        rtypes = {}
        for field, fieldname in zip(fields, fieldnames):
            fi = self.field_info[field]
            myfi[fieldname] = {key: fi[key]
                               for key in ["units", "description"]
                               if key in fi}
            if all_trees:
                rdata[fieldname] = self._field_data[field]
            else:
                rdata[fieldname] = self.arr([t[field] for t in trees])
            rtypes[fieldname] = "data"
        # all saved trees will be roots
        if not all_trees:
            rdata["desc_uid"][:] = -1
        extra_attrs["field_info"] = json.dumps(myfi)
        extra_attrs["total_files"] = nfiles
        extra_attrs["total_trees"] = trees.size
        extra_attrs["total_nodes"] = tree_size.sum()
        hdata = {"tree_start_index": tree_start_index,
                 "tree_end_index"  : tree_end_index,
                 "tree_size"       : ntrees}
        hdata.update(rdata)
        htypes = dict((f, "index") for f in hdata)
        htypes.update(rtypes)

        filename = _determine_output_filename(filename, ".h5")
        header_filename = "%s.h5" % filename
        save_as_dataset(ds, header_filename, hdata,
                        field_types=htypes,
                        extra_attrs=extra_attrs)

        # write data files
        ftypes = dict((f, "data") for f in fieldnames)
        for i in range(nfiles):
            my_nodes = trees[tree_start_index[i]:tree_end_index[i]]
            self._node_io_loop(
                self._node_io.get_fields,
                pbar="Getting fields [%d/%d]" % (i+1, nfiles),
                root_nodes=my_nodes, fields=fields, root_only=False)
            fdata = dict((field, np.empty(nnodes[i])) for field in fieldnames)
            my_tree_size  = tree_size[tree_start_index[i]:tree_end_index[i]]
            my_tree_end   = my_tree_size.cumsum()
            my_tree_start = my_tree_end - my_tree_size
            pbar = get_pbar("Creating field arrays [%d/%d]" %
                            (i+1, nfiles), len(fields)*nnodes[i])
            c = 0
            for field, fieldname in zip(fields, fieldnames):
                for di, node in enumerate(my_nodes):
                    start = my_tree_start[di]
                    end = my_tree_end[di]
                    node_is_root = node.is_root
                    if node_is_root:
                        ndata = node._field_data[field]
                    else:
                        ndata = node["tree", field]
                    fdata[fieldname][start:end] = ndata
                    if field == "desc_uid" and not node_is_root:
                        # Make sure it's a root when loaded.  Write to
                        # the output array so the node's own field data
                        # is never modified.
                        fdata[fieldname][start] = -1
                    c += my_tree_size[di]
                    pbar.update(c)
            pbar.finish()
            fdata["tree_start_index"] = my_tree_start
            fdata["tree_end_index"]   = my_tree_end
            fdata["tree_size"]        = my_tree_size
            for ft in ["tree_start_index",
                      "tree_end_index",
                      "tree_size"]:
                ftypes[ft] = "index"
            my_filename = "%s_%04d.h5" % (filename, i)
            save_as_dataset({}, my_filename, fdata,
                            field_types=ftypes)

        return header_filename
Beispiel #5
0
class TreeNode:
    """
    Class for objects stored in Arbors.

    Each TreeNode represents a halo in a tree.  A TreeNode knows
    its halo ID, the level in the tree, and its global ID in the
    Arbor that holds it.  It also has a list of its ancestors.
    Fields can be queried for it, its progenitor list, and the
    tree beneath.
    """

    # NodeLink for conventional Arbor objects; stays None otherwise.
    _link = None

    def __init__(self, uid, arbor=None, root=False):
        """
        Initialize a TreeNode with at least its halo catalog ID and
        its level in the tree.

        Parameters
        ----------
        uid : int
            Unique ID of the halo.
        arbor : Arbor
            The Arbor holding this node.  Only a weak proxy to it is
            stored, so the node does not keep the Arbor alive.
        root : optional, bool
            If True, mark this node as a root and give it its own
            field storage.
            Default: False.
        """
        self.uid = uid
        self.arbor = weakref.proxy(arbor)
        if root:
            # -1 is one of the two root markers accepted by is_root
            # (the other being the node itself).
            self.root = -1
            self._field_data = FieldContainer(arbor)
        else:
            self.root = None

    _tree_id = None  # used by CatalogArbor

    @property
    def tree_id(self):
        """
        Return the index of this node in a list of all nodes in the tree.
        """
        if self.is_root:
            return 0
        elif self._link is not None:
            return self._link.tree_id
        else:
            return self._tree_id

    @tree_id.setter
    def tree_id(self, value):
        """
        Set the tree_id manually in CatalogArbors.
        """
        self._tree_id = value

    @property
    def is_root(self):
        """
        Is this node the last in the tree?

        A node is a root when its root marker is -1 or the node itself.
        """
        return self.root in [-1, self]

    def find_root(self):
        """
        Find the root node.

        Returns self for a root node, the stored root if it is known,
        and otherwise walks the descendents to find it.
        """

        if self.is_root:
            return self

        root = self.root
        if root is not None:
            return root

        return self.walk_to_root()

    def walk_to_root(self):
        """
        Walk descendents until root.

        Stops early at a node with no descendent (-1 or None) and
        returns the last node reached.
        """

        my_node = self
        while not my_node.is_root:
            # Read the property once per step; for conventional Arbors
            # each access may generate a new TreeNode.
            desc = my_node.descendent
            if desc in (-1, None):
                break
            my_node = desc
        return my_node

    def clear_fields(self):
        """
        If a root node, delete field data.
        If not root node, do nothing.
        """

        if not self.is_root:
            return
        self._field_data.clear()

    _descendent = None  # used by CatalogArbor

    @property
    def descendent(self):
        """
        Return the descendent node.

        Returns None for a root node and for a node whose NodeLink has
        no descendent.
        """

        if self.is_root:
            return None

        # set in CatalogArbor._plant_trees
        if self._descendent is not None:
            return self._descendent

        # conventional Arbor object
        desc_link = self._link.descendent
        if desc_link is None:
            # Nothing to generate; do not hand a missing link to the arbor.
            return None
        return self.arbor._generate_tree_node(self.root, desc_link)

    _ancestors = None  # used by CatalogArbor

    @property
    def ancestors(self):
        """
        Return a generator of ancestor nodes.

        The generator is empty when the node has no ancestors.
        """

        self.arbor._grow_tree(self)

        # conventional Arbor object
        if self._link is not None:
            for link in self._link.ancestors:
                yield self.arbor._generate_tree_node(self.root, link)
            return

        # set in CatalogArbor._plant_trees
        if self._ancestors is not None:
            yield from self._ancestors

    _uids = None

    @property
    def uids(self):
        """
        Array of uids for all nodes in the tree.

        Only available on root nodes; returns None otherwise.
        """
        if not self.is_root:
            return None
        if self._uids is None:
            self.arbor._build_attr("_uids", self)
        return self._uids

    _desc_uids = None

    @property
    def desc_uids(self):
        """
        Array of descendent uids for all nodes in the tree.

        Only available on root nodes; returns None otherwise.
        """
        if not self.is_root:
            return None
        if self._desc_uids is None:
            self.arbor._build_attr("_desc_uids", self)
        return self._desc_uids

    _tree_size = None

    @property
    def tree_size(self):
        """
        Number of nodes in the tree.
        """
        if self._tree_size is not None:
            return self._tree_size
        if self.is_root:
            self.arbor._setup_tree(self)
            # pass back to the arbor to avoid calculating again
            self.arbor._store_node_info(self, '_tree_size')
        else:
            self._tree_size = len(list(self["tree"]))
        return self._tree_size

    _link_storage = None

    @property
    def _links(self):
        """
        Array of NodeLink objects with the ancestor/descendent structure.

        This is only used by conventional Arbor objects, i.e., not
        CatalogArbor objects.
        """
        if not self.is_root:
            return None
        if self._link_storage is None:
            self.arbor._build_attr("_link_storage", self)
        return self._link_storage

    def __setitem__(self, key, value):
        """
        Set analysis field value for this node.

        For a root node the value is also written into the arbor's own
        field storage at this node's arbor index.
        """

        if self.is_root:
            root = self
            tree_id = 0
            # if root, set the value in the arbor field storage
            self.arbor._field_data[key][self._arbor_index] = value
        else:
            root = self.root
            tree_id = self.tree_id
        self.arbor._node_io.get_fields(self, fields=[key], root_only=False)
        data = root._field_data[key]
        data[tree_id] = value

    def __getitem__(self, key):
        """
        Return field values or forest/tree/prog generators.
        """
        return self.query(key)

    def query(self, key):
        """
        Return field values for this TreeNode, progenitor list, or tree.

        Parameters
        ----------
        key : string or tuple
            If a single string, it can be either a field to be queried or
            one of "forest", "tree", or "prog".  If a field, then return
            the value of the field for this TreeNode.  If "forest",
            "tree", or "prog", then return a generator of the TreeNodes
            in the forest, tree, or progenitor list.

            If a tuple, it must be (string, string), where the first
            element is one of "forest", "tree", or "prog" and the second
            is a field name.  Return the values of that field for the
            forest, tree, or progenitor list.

        Raises
        ------
        SyntaxError
            If key is not a string or a 2-tuple of strings, or if the
            first element of a tuple is not a valid selector.

        Examples
        --------
        >>> # virial mass for this halo
        >>> print (my_tree["mvir"].to("Msun/h"))

        >>> # all TreeNodes in the progenitor list
        >>> print (list(my_tree["prog"]))
        >>> # all TreeNodes in the entire tree
        >>> print (list(my_tree["tree"]))

        >>> # virial masses for the progenitor list
        >>> print (my_tree["prog", "mvir"].to("Msun/h"))

        Returns
        -------
        float, ndarray/unyt_array, or generator of TreeNodes

        """
        arr_types = ("forest", "prog", "tree")
        if isinstance(key, tuple):
            if len(key) != 2:
                raise SyntaxError("Must be either 1 or 2 arguments.")
            ftype, field = key
            if ftype not in arr_types:
                raise SyntaxError("First argument must be one of %s." %
                                  str(arr_types))
            if not isinstance(field, str):
                raise SyntaxError("Second argument must be a string.")

            self.arbor._node_io.get_fields(self,
                                           fields=[field],
                                           root_only=False)
            indices = getattr(self, "_%s_field_indices" % ftype)

            data_object = self.find_root()
            return data_object._field_data[field][indices]

        else:
            if not isinstance(key, str):
                raise SyntaxError("Single argument must be a string.")

            # return the progenitor list or tree nodes in a list
            if key in arr_types:
                self.arbor._setup_tree(self)
                return getattr(self, "_%s_nodes" % key)

            # return field value for this node
            self.arbor._node_io.get_fields(self,
                                           fields=[key],
                                           root_only=self.is_root)
            data_object = self.find_root()
            return data_object._field_data[key][self.tree_id]

    def __repr__(self):
        """
        Call me TreeNode.
        """
        return "TreeNode[%d]" % self.uid

    _ffi = slice(None)

    @property
    def _forest_field_indices(self):
        """
        Return default slice to select the whole forest.
        """
        return self._ffi

    @property
    def _forest_nodes(self):
        """
        An iterator over all TreeNodes in the forest.

        This is different from _tree_nodes in that we don't walk
        through the ancestors lists. We just yield every TreeNode
        there is.
        """

        self.arbor._grow_tree(self)
        # find_root also copes with a root still carrying the -1 marker.
        root = self.find_root()
        for link in root._links:
            yield self.arbor._generate_tree_node(root, link)

    @property
    def _tree_nodes(self):
        """
        An iterator over all TreeNodes in the tree beneath,
        starting with this TreeNode.

        For internal use only. Use the following instead:

        >>> for my_node in my_tree['tree']:
        ...     print (my_node)

        Examples
        --------

        >>> for my_node in my_tree._tree_nodes:
        ...     print (my_node)

        """

        self.arbor._grow_tree(self)
        yield self
        # ancestors is a generator property, so it is never None.
        for ancestor in self.ancestors:
            yield from ancestor._tree_nodes

    _tfi = None

    @property
    def _tree_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the tree beneath, starting with this TreeNode.
        """

        if self._tfi is not None:
            return self._tfi

        self.arbor._grow_tree(self)
        self._tfi = np.array([node.tree_id for node in self._tree_nodes])
        return self._tfi

    @property
    def _prog_nodes(self):
        """
        An iterator over all TreeNodes in the progenitor list,
        starting with this TreeNode.

        For internal use only. Use the following instead:

        >>> for my_node in my_tree['prog']:
        ...     print (my_node)

        Examples
        --------

        >>> for my_node in my_tree._prog_nodes:
        ...     print (my_node)

        """

        self.arbor._grow_tree(self)
        my_node = self
        while my_node is not None:
            yield my_node
            ancestors = list(my_node.ancestors)
            if ancestors:
                # the arbor's selector picks which ancestor to follow
                my_node = my_node.arbor.selector(ancestors)
            else:
                my_node = None

    _pfi = None

    @property
    def _prog_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the progenitor list, starting with this TreeNode.
        """

        if self._pfi is not None:
            return self._pfi

        self.arbor._grow_tree(self)
        self._pfi = np.array([node.tree_id for node in self._prog_nodes])
        return self._pfi

    def save_tree(self, filename=None, fields=None):
        r"""
        Save the tree to a file.

        The saved tree can be re-loaded as an arbor.

        Parameters
        ----------
        filename : optional, string
            Output file keyword.  Main header file will be named
            <filename>/<filename>.h5.
            Default: "tree_<uid>".
        fields : optional, list of strings
            The fields to be saved.  If not given, all
            fields will be saved.

        Returns
        -------
        filename : string
            The filename of the saved arbor.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> # save the first tree
        >>> fn = a[0].save_tree()
        >>> # reload it
        >>> a2 = ytree.load(fn)

        """

        if filename is None:
            filename = "tree_%d" % self.uid

        return self.arbor.save_arbor(filename=filename,
                                     fields=fields,
                                     trees=[self])
Beispiel #6
0
class TreeNode(object):
    """
    Class for objects stored in Arbors.

    Each TreeNode represents a halo in a tree.  A TreeNode knows
    its halo ID, the level in the tree, and its global ID in the
    Arbor that holds it.  It also has a list of its ancestors.
    Fields can be queried for it, its progenitor list, and the
    tree beneath.
    """
    def __init__(self, uid, arbor=None, root=False):
        """
        Initialize a TreeNode with at least its halo catalog ID and
        its level in the tree.

        Parameters
        ----------
        uid : int
            Unique ID of the halo.
        arbor : Arbor
            The Arbor holding this node; only a weak proxy is stored.
        root : optional, bool
            If True, mark this node as a root: it gets treeid 0, no
            descendent, and its own field storage.
        """
        self.uid = uid
        self.arbor = weakref.proxy(arbor)
        if root:
            self.root = -1
            self.treeid = 0
            self.descendent = None
            self._field_data = FieldContainer(arbor)
        else:
            self.root = None

    @property
    def is_root(self):
        """
        True if the root marker is -1 or the node itself.
        """
        return self.root in [-1, self]

    def find_root(self):
        """
        Find the root node.

        Walks descendents until reaching a root node, or a node with
        no descendent (-1 or None), and returns the last node reached.
        """
        my_node = self
        while not my_node.is_root:
            desc = my_node.descendent
            # Without the None check, a missing descendent would make
            # the next iteration call None.is_root.
            if desc in (-1, None):
                break
            my_node = desc
        return my_node

    def clear_fields(self):
        """
        If a root node, delete field data.
        If not root node, do nothing.
        """

        if not self.is_root:
            return
        self._field_data.clear()

    def reset(self):
        """
        Reset all data structures.
        """

        self.clear_fields()
        attrs = ["_tfi", "_tn", "_pfi", "_pn"]
        if self.is_root:
            self.root = -1
            attrs.extend(["_nodes", "_desc_uids", "_uids"])
        else:
            self.root = None
        for attr in attrs:
            setattr(self, attr, None)

    def add_ancestor(self, ancestor):
        """
        Add another TreeNode to the list of ancestors.

        Parameters
        ----------
        ancestor : TreeNode
            The ancestor TreeNode.
        """
        if self._ancestors is None:
            self._ancestors = []
        self._ancestors.append(ancestor)

    _ancestors = None
    @property
    def ancestors(self):
        """
        List of ancestor TreeNodes, or None if there are none.

        A root still carrying the -1 marker asks the arbor to grow its
        tree first.
        """
        if self.root == -1:
            self.arbor._grow_tree(self)
        return self._ancestors

    _uids = None
    @property
    def uids(self):
        """
        Array of uids for all nodes in the tree (root nodes only).
        """
        if not self.is_root:
            return None
        if self._uids is None:
            self.arbor._setup_tree(self)
        return self._uids

    _desc_uids = None
    @property
    def desc_uids(self):
        """
        Array of descendent uids for all nodes in the tree (root nodes
        only).
        """
        if not self.is_root:
            return None
        if self._desc_uids is None:
            self.arbor._setup_tree(self)
        return self._desc_uids

    _tree_size = None
    @property
    def tree_size(self):
        """
        Number of nodes in the tree beneath this node.
        """
        if not self.is_root:
            return self["tree"].size
        if self._tree_size is None:
            self.arbor._setup_tree(self)
        return self._tree_size

    _nodes = None
    @property
    def nodes(self):
        """
        All nodes of the tree (root nodes only; None otherwise).
        """
        if not self.is_root:
            return None
        self.arbor._grow_tree(self)
        return self._nodes

    def __setitem__(self, key, value):
        """
        Set a field value for this node in its root's field storage.
        """
        if self.root == -1:
            root = self
            treeid = 0
        else:
            root = self.root
            treeid = self.treeid
        self.arbor._node_io.get_fields(self, fields=[key],
                                       root_only=False)
        data = root._field_data[key]
        data[treeid] = value

    def __getitem__(self, key):
        """
        Shorthand for query(key).
        """
        return self.query(key)

    def query(self, key):
        """
        Return field values for this TreeNode, progenitor list, or tree.

        Parameters
        ----------
        key : string or tuple
            If a single string, it can be either a field to be queried or
            one of "tree" or "prog".  If a field, then return the value of
            the field for this TreeNode.  If "tree" or "prog", then return
            an array of the TreeNodes in the tree or progenitor list.

            If a tuple, it must be (string, string), where the first
            element is either "tree" or "prog" and the second is a field
            name.  Return the field values for either the tree or the
            progenitor list.

        Raises
        ------
        SyntaxError
            If key is not a string or a 2-tuple of strings, or if the
            first element of a tuple is not "tree" or "prog".

        Examples
        --------
        >>> # virial mass for this halo
        >>> print (my_tree["mvir"].to("Msun/h"))

        >>> # all TreeNodes in the progenitor list
        >>> print (my_tree["prog"])
        >>> # all TreeNodes in the entire tree
        >>> print (my_tree["tree"])

        >>> # virial masses for the progenitor list
        >>> print (my_tree["prog", "mvir"].to("Msun/h"))

        >>> # the 3rd TreeNode in the progenitor list
        >>> print (my_tree["prog"][2])

        Returns
        -------
        float, ndarray/YTArray, or ndarray of TreeNodes

        """
        arr_types = ("prog", "tree")
        if isinstance(key, tuple):
            if len(key) != 2:
                raise SyntaxError(
                    "Must be either 1 or 2 arguments.")
            ftype, field = key
            if ftype not in arr_types:
                raise SyntaxError(
                    "First argument must be one of %s." % str(arr_types))
            if not isinstance(field, str):
                raise SyntaxError("Second argument must be a string.")

            self.arbor._node_io.get_fields(self, fields=[field], root_only=False)
            indices = getattr(self, "_%s_field_indices" % ftype)
            # A root may still carry the -1 marker, so resolve the data
            # owner the same way as the single-field branch below.
            data_object = self if self.is_root else self.root
            return data_object._field_data[field][indices]

        else:
            if not isinstance(key, str):
                raise SyntaxError("Single argument must be a string.")

            # return the progenitor list or tree nodes in a list
            if key in arr_types:
                self.arbor._setup_tree(self)
                return getattr(self, "_%s_nodes" % key)

            # return field value for this node
            self.arbor._node_io.get_fields(self, fields=[key],
                                           root_only=self.is_root)
            if self.is_root:
                data_object = self
            else:
                data_object = self.root
            return data_object._field_data[key][self.treeid]

    def __repr__(self):
        """
        Short representation using the halo uid.
        """
        return "TreeNode[%d]" % self.uid

    _tfi = None
    @property
    def _tree_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the tree beneath, starting with this TreeNode.
        """
        if self._tfi is None:
            self._set_tree_attrs()
        return self._tfi

    _tn = None
    @property
    def _tree_nodes(self):
        """
        Return a list of all TreeNodes in the tree beneath,
        starting with this TreeNode.
        """
        if self._tn is None:
            self._set_tree_attrs()
        return self._tn

    def _set_tree_attrs(self):
        """
        Prepare the TreeNode list and field indices.
        """
        self.arbor._grow_tree(self)
        tfi = []
        tn = []
        for my_node in self.twalk():
            tfi.append(my_node.treeid)
            tn.append(my_node)
        self._tfi = np.array(tfi)
        self._tn = np.array(tn)

    _pfi = None
    @property
    def _prog_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the progenitor list, starting with this TreeNode.
        """
        if self._pfi is None:
            self._set_prog_attrs()
        return self._pfi

    _pn = None
    @property
    def _prog_nodes(self):
        """
        Return a list of all TreeNodes in the progenitor list, starting
        with this TreeNode.
        """
        if self._pn is None:
            self._set_prog_attrs()
        return self._pn

    def _set_prog_attrs(self):
        """
        Prepare the progenitor list and field indices.
        """
        self.arbor._grow_tree(self)
        lfi = []
        ln = []
        for my_node in self.pwalk():
            lfi.append(my_node.treeid)
            ln.append(my_node)
        self._pfi = np.array(lfi)
        self._pn = np.array(ln)

    def twalk(self):
        r"""
        An iterator over all TreeNodes in the tree beneath,
        starting with this TreeNode.

        Examples
        --------

        >>> for my_node in my_tree.twalk():
        ...     print (my_node)

        """
        self.arbor._grow_tree(self)
        yield self
        if self.ancestors is None:
            return
        for ancestor in self.ancestors:
            for a_node in ancestor.twalk():
                yield a_node

    def pwalk(self):
        r"""
        An iterator over all TreeNodes in the progenitor list,
        starting with this TreeNode.

        Examples
        --------

        >>> for my_node in my_tree.pwalk():
        ...     print (my_node)

        """
        self.arbor._grow_tree(self)
        my_node = self
        while my_node is not None:
            yield my_node
            if my_node.ancestors is None:
                my_node = None
            else:
                # the arbor's selector picks which ancestor to follow
                my_node = my_node.arbor.selector(my_node.ancestors)

    def save_tree(self, filename=None, fields=None):
        r"""
        Save the tree to a file.

        The saved tree can be re-loaded as an arbor.

        Parameters
        ----------
        filename : optional, string
            Output file keyword.  Main header file will be named
            <filename>/<filename>.h5.
            Default: "tree_<uid>".
        fields : optional, list of strings
            The fields to be saved.  If not given, all
            fields will be saved.

        Returns
        -------
        filename : string
            The filename of the saved arbor.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> # save the first tree
        >>> fn = a[0].save_tree()
        >>> # reload it
        >>> a2 = ytree.load(fn)

        """

        if filename is None:
            filename = "tree_%d" % self.uid

        return self.arbor.save_arbor(
            filename=filename, fields=fields,
            trees=[self])
Beispiel #7
0
class TreeNode:
    """
    Class for objects stored in Arbors.

    Each TreeNode represents a halo in a tree.  A TreeNode knows
    its halo ID, the level in the tree, and its global ID in the
    Arbor that holds it.  It also has a list of its ancestors.
    Fields can be queried for it, its progenitor list, and the
    tree beneath.
    """

    _link = None

    def __init__(self, uid, arbor=None, root=False):
        """
        Initialize a TreeNode with at least its halo catalog ID and
        its level in the tree.
        """
        self.uid = uid
        self.arbor = weakref.proxy(arbor)
        if root:
            self.root = -1
            self.field_data = FieldContainer(arbor)
        else:
            self.root = None

    _tree_id = None  # used by CatalogArbor

    @property
    def tree_id(self):
        """
        Return the index of this node in a list of all nodes in the tree.
        """
        if self.is_root:
            return 0
        elif self._link is not None:
            return self._link.tree_id
        else:
            return self._tree_id

    @tree_id.setter
    def tree_id(self, value):
        """
        Set the tree_id manually in CatalogArbors.
        """
        self._tree_id = value

    @property
    def is_root(self):
        """
        Is this node the last in the tree?
        """
        return self.root in [-1, self]

    def find_root(self):
        """
        Find the root node.
        """

        if self.is_root:
            return self

        root = self.root
        if root is not None:
            return root

        return self.walk_to_root()

    def walk_to_root(self):
        """
        Walk descendents until root.
        """

        my_node = self
        while not my_node.is_root:
            if my_node.descendent in (-1, None):
                break
            my_node = my_node.descendent
        return my_node

    def clear_fields(self):
        """
        If a root node, delete field data.
        If not root node, do nothing.
        """

        if not self.is_root:
            return
        self.field_data.clear()

    _descendent = None  # used by CatalogArbor

    @property
    def descendent(self):
        """
        Return the descendent node.
        """

        if self.is_root:
            return None

        # set in CatalogArbor._plant_trees
        if self._descendent is not None:
            return self._descendent

        # conventional Arbor object
        desc_link = self._link.descendent
        if desc_link is None:
            return None
        return self.arbor._generate_tree_node(self.root, desc_link)

    _ancestors = None  # used by CatalogArbor

    @property
    def ancestors(self):
        """
        Return a generator of ancestor nodes.
        """

        self.arbor._grow_tree(self)

        # conventional Arbor object
        if self._link is not None:
            for link in self._link.ancestors:
                yield self.arbor._generate_tree_node(self.root, link)
            return

        # If tree is not setup yet, the ancestor nodes will not have
        # root pointers yet.
        need_root = not self.arbor.is_setup(self)
        if need_root:
            root = self.walk_to_root()

        # set in CatalogArbor._plant_trees
        if self._ancestors is not None:
            for ancestor in self._ancestors:
                if need_root:
                    ancestor.root = root
                yield ancestor
            return
        return None

    _uids = None

    @property
    def uids(self):
        """
        Array of uids for all nodes in the tree.
        """
        if not self.is_root:
            return None
        if self._uids is None:
            self.arbor._build_attr("_uids", self)
        return self._uids

    _desc_uids = None

    @property
    def desc_uids(self):
        """
        Array of descendent uids for all nodes in the tree.
        """
        if not self.is_root:
            return None
        if self._desc_uids is None:
            self.arbor._build_attr("_desc_uids", self)
        return self._desc_uids

    _tree_size = None

    @property
    def tree_size(self):
        """
        Number of nodes in the tree.
        """
        if self._tree_size is not None:
            return self._tree_size
        if self.is_root:
            self.arbor._setup_tree(self)
            # pass back to the arbor to avoid calculating again
            self.arbor._store_node_info(self, '_tree_size')
        else:
            self._tree_size = len(list(self["tree"]))
        return self._tree_size

    _link_storage = None

    @property
    def _links(self):
        """
        Array of NodeLink objects with the ancestor/descendent structure.

        This is only used by conventional Arbor objects, i.e., not
        CatalogArbor objects.
        """
        if not self.is_root:
            return None
        if self._link_storage is None:
            self.arbor._build_attr("_link_storage", self)
        return self._link_storage

    def __setitem__(self, key, value):
        """
        Set analysis field value for this node.
        """

        fi = self.arbor.field_info[key]
        ftype = fi.get('type')
        if ftype not in ['analysis', 'analysis_saved']:
            raise ArborUnsettableField(key, self.arbor)

        vector_fieldname = fi.get("vector_fieldname", None)
        has_vector_field = vector_fieldname is not None

        if self.is_root:
            root = self
            tree_id = 0
            # if root, set the value in the arbor field storage
            self.arbor[key][self._arbor_index] = value
            if has_vector_field and vector_fieldname in self.arbor.field_data:
                del self.arbor.field_data[vector_fieldname]
        else:
            root = self.root
            tree_id = self.tree_id
        self.arbor._node_io.get_fields(self, fields=[key], root_only=False)
        data = root.field_data[key]
        data[tree_id] = value
        if has_vector_field and vector_fieldname in root.field_data:
            del root.field_data[vector_fieldname]

    def __getitem__(self, key):
        """
        Return field values or tree/prog generators.
        """
        return self.query(key)

    def query(self, key):
        """
        Return field values for this TreeNode, progenitor list, or tree.

        Parameters
        ----------
        key : string or tuple
            If a single string, it can be either a field to be queried or
            one of "tree" or "prog".  If a field, then return the value of
            the field for this TreeNode.  If "tree" or "prog", then return
            the list of TreeNodes in the tree or progenitor list.

            If a tuple, this can be either (string, string) or (string, int),
            where the first argument must be either "tree" or "prog".
            If second argument is a string, then return the field values
            for either the tree or the progenitor list.  If second argument
            is an int, then return the nth TreeNode in the tree or progenitor
            list list.

        Examples
        --------
        >>> # virial mass for this halo
        >>> print (my_tree["mvir"].to("Msun/h"))

        >>> # all TreeNodes in the progenitor list
        >>> print (my_tree["prog"])
        >>> # all TreeNodes in the entire tree
        >>> print (my_tree["tree"])

        >>> # virial masses for the progenitor list
        >>> print (my_tree["prog", "mvir"].to("Msun/h"))

        >>> # the 3rd TreeNode in the progenitor list
        >>> print (my_tree["prog", 2])

        Returns
        -------
        float, ndarray/unyt_array, TreeNode

        """
        arr_types = ("forest", "prog", "tree")
        if isinstance(key, tuple):
            if len(key) != 2:
                raise SyntaxError("Must be either 1 or 2 arguments.")
            ftype, field = key
            if ftype not in arr_types:
                raise SyntaxError(
                    f"First argument must be one of {str(arr_types)}.")
            if not isinstance(field, str):
                raise SyntaxError("Second argument must be a string.")

            self.arbor._node_io.get_fields(self,
                                           fields=[field],
                                           root_only=False)
            indices = getattr(self, f"_{ftype}_field_indices")

            data_object = self.find_root()
            return data_object.field_data[field][indices]

        else:
            if not isinstance(key, str):
                raise SyntaxError("Single argument must be a string.")

            # return the progenitor list or tree nodes in a list
            if key in arr_types:
                self.arbor._setup_tree(self)
                return getattr(self, f"_{key}_nodes")

            # return field value for this node
            self.arbor._node_io.get_fields(self,
                                           fields=[key],
                                           root_only=self.is_root)
            data_object = self.find_root()
            return data_object.field_data[key][self.tree_id]

    def __repr__(self):
        """
        Call me TreeNode.
        """
        return f"TreeNode[{self.uid}]"

    def get_node(self, selector, index):
        """
        Return a single TreeNode from a forest, tree, or progenitor list.

        The calling TreeNode is the head of the selected collection; the
        node at position ``index`` within that collection is returned.

        Parameters
        ----------
        selector : str ("forest", "tree", or "prog")
            Which collection to pick the TreeNode from. Must be one of
            "forest", "tree", or "prog".
        index : int
            Position of the wanted TreeNode within the forest, tree,
            or progenitor list.

        Returns
        -------
        node: :class:`~ytree.data_structures.tree_node.TreeNode`

        Raises
        ------
        RuntimeError
            If no ``_<selector>_field_indices`` attribute exists (or it
            is None) for the given selector.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tiny_ctrees/locations.dat")
        >>> my_tree = a[0]
        >>> # get 6th TreeNode in the progenitor list
        >>> my_node = my_tree.get_node('prog', 5)

        """

        arbor = self.arbor
        # Make sure the tree structure exists before resolving indices.
        arbor._setup_tree(self)
        arbor._grow_tree(self)

        selection = getattr(self, f"_{selector}_field_indices", None)
        if selection is None:
            raise RuntimeError("Bad selector.")

        # Links live on the root; narrow them to the selection first,
        # then take the requested position within it.
        root = self.root
        chosen_link = root._links[selection][index]
        return arbor._generate_tree_node(root, chosen_link)

    def get_leaf_nodes(self, selector=None):
        """
        Get all leaf nodes from the tree of which this is the head.

        This returns a generator of all leaf nodes belonging to this
        tree. A leaf node is a node that has no ancestors, i.e. whose
        uid does not appear as the desc_uid of any node in the selection.

        Parameters
        ----------
        selector : optional, str ("forest", "tree", or "prog")
            The tree selector from which leaf nodes will be found.
            If none given, this will be set to "forest" if the
            calling node is a root node and "tree" otherwise.

        Returns
        -------
        leaf_nodes : a generator of
            :class:`~ytree.data_structures.tree_node.TreeNode` objects.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tiny_ctrees/locations.dat")
        >>> my_tree = a[0]
        >>> for leaf in my_tree.get_leaf_nodes():
        ...     print (leaf["mass"])

        """

        if selector is None:
            selector = "forest" if self.is_root else "tree"

        uids = self[selector, "uid"]
        desc_uids = self[selector, "desc_uid"]
        # A node is a leaf if no node in the selection names it as its
        # descendent. np.isin(..., invert=True) is the supported
        # replacement for ~np.in1d(...), which is deprecated in NumPy 2.0.
        is_leaf = np.isin(uids, desc_uids, invert=True)
        lids = np.where(is_leaf)[0]
        for lid in lids:
            yield self.get_node(selector, lid)

    def get_root_nodes(self):
        """
        Yield every root node of the forest containing this node.

        A root node is one with no descendant, recognised here by a
        desc_uid of -1. Nodes are yielded in the order they appear in
        the forest's field arrays.

        Returns
        -------
        root_nodes : a generator of
            :class:`~ytree.data_structures.tree_node.TreeNode` objects.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("consistent_trees_hdf5/soa/forest.h5",
        ...                access="forest")
        >>> my_tree = a[0]
        >>> for root in my_tree.get_root_nodes():
        ...     print (root["mass"])

        """

        forest_desc_uids = self["forest", "desc_uid"]
        has_no_descendant = forest_desc_uids == -1
        for position in np.nonzero(has_no_descendant)[0]:
            yield self.get_node("forest", position)

    _ffi = slice(None)

    @property
    def _forest_field_indices(self):
        """
        Return default slice to select the whole forest.
        """
        return self._ffi

    @property
    def _forest_nodes(self):
        """
        An iterator over all TreeNodes in the forest.

        This is different from _tree_nodes in that we don't walk
        through the ancestors lists. We just yield every TreeNode
        there is.
        """

        self.arbor._grow_tree(self)
        root = self.root
        for link in root._links:
            yield self.arbor._generate_tree_node(self.root, link)

    @property
    def _tree_nodes(self):
        """
        An iterator over all TreeNodes in the tree beneath,
        starting with this TreeNode.

        For internal use only. Use the following instead:

        >>> for my_node in my_tree['tree']:
        ...     print (my_node)

        Examples
        --------

        >>> for my_node in my_tree._tree_nodes:
        ...     print (my_node)

        """

        self.arbor._grow_tree(self)
        yield self
        if self.ancestors is None:
            return
        for ancestor in self.ancestors:
            for a_node in ancestor._tree_nodes:
                yield a_node

    _tfi = None

    @property
    def _tree_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the tree beneath, starting with this TreeNode.
        """

        if self._tfi is not None:
            return self._tfi

        self.arbor._grow_tree(self)
        self._tfi = np.array([node.tree_id for node in self._tree_nodes])
        return self._tfi

    @property
    def _prog_nodes(self):
        """
        An iterator over all TreeNodes in the progenitor list,
        starting with this TreeNode.

        For internal use only. Use the following instead:

        >>> for my_node in my_tree['prog']:
        ...     print (my_node)

        Examples
        --------

        >>> for my_node in my_tree._prog_nodes:
        ...     print (my_node)

        """

        self.arbor._grow_tree(self)
        my_node = self
        while my_node is not None:
            yield my_node
            ancestors = list(my_node.ancestors)
            if ancestors:
                my_node = my_node.arbor.selector(ancestors)
            else:
                my_node = None

    _pfi = None

    @property
    def _prog_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the progenitor list, starting with this TreeNode.
        """

        if self._pfi is not None:
            return self._pfi

        self.arbor._grow_tree(self)
        self._pfi = np.array([node.tree_id for node in self._prog_nodes])
        return self._pfi

    def save_tree(self, filename=None, fields=None):
        r"""
        Write this tree to disk so it can be re-loaded as an arbor.

        This delegates to the arbor's ``save_arbor``, passing this node
        as the only tree to save.

        Parameters
        ----------
        filename : optional, string
            Output file keyword.  Main header file will be named
            <filename>/<filename>.h5.
            Default: "tree_<uid>".
        fields : optional, list of strings
            The fields to be saved.  If not given, all
            fields will be saved.

        Returns
        -------
        filename : string
            The filename of the saved arbor, as returned by
            ``save_arbor``.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> # save the first tree
        >>> fn = a[0].save_tree()
        >>> # reload it
        >>> a2 = ytree.load(fn)

        """

        output_name = f"tree_{self.uid}" if filename is None else filename
        return self.arbor.save_arbor(
            filename=output_name, fields=fields, trees=[self])