# Example 1
    def _setup_tree(self, tree_node, **kwargs):
        """
        Check for desc_uids missing from uid list.

        After the parent class finishes tree setup, find descendant uids
        that do not appear in the node's uid list and reassign each one
        to the uid actually stored at the corresponding position in
        tree_node._status. The reassignment mutates tree_node.desc_uids
        in place.
        """

        super()._setup_tree(tree_node, **kwargs)
        uids = tree_node.uids
        desc_uids = tree_node.desc_uids
        # Descendant uids with no matching entry in the uid list.
        missing = np.setdiff1d(desc_uids, uids)
        # -1 appears to be the "no descendant" sentinel; never treat it
        # as missing. TODO confirm against the file format.
        missing = np.setdiff1d(missing, [-1])
        if missing.size == 0:
            return

        # Re-read the raw snapshot/descendant indices for this node so we
        # can reconstruct where each broken descendant pointer should go.
        mfields = ["snap_index", "descendant_index"]
        field_data = self._node_io._read_fields(tree_node, mfields)

        # Grid dimensions: snapshots (minus the final one) by the halos
        # belonging to this node's [_si, _ei) slice.
        xsize = self._redshifts.size - 1
        ysize = tree_node._ei - tree_node._si
        # snap_index is flipped because snapshots are stored in reverse
        # order relative to the redshift list — presumably latest-first.
        xindex = xsize - field_data["snap_index"] - 1
        # Halo index relative to the start of this node's slice.
        yindex = field_data["descendant_index"] - tree_node._si

        for muid in missing:
            # All halos whose descendant pointer references the missing uid.
            mi = np.where(desc_uids == muid)[0]
            # Flatten the 2D (snapshot, halo) coordinates into the flat
            # index scheme that tree_node._status uses. NOTE(review):
            # assumes _status is laid out as a raveled (xsize, ysize)
            # grid — confirm against the reader that builds _status.
            ndi = np.ravel_multi_index([xindex[mi], yindex[mi]],
                                       (xsize, ysize))
            for my_mi, my_ndi in zip(mi, ndi):
                # The uid that actually lives at the computed position
                # becomes the corrected descendant.
                new_uid = uids[np.where(my_ndi == tree_node._status)][0]
                mylog.info(
                    f"Reassigning descendent of halo {uids[my_mi]} from "
                    f"{desc_uids[my_mi]} to {new_uid}.")
                desc_uids[my_mi] = new_uid
# Example 2
def compare_arbors(a1, a2, groups=None, fields=None, skip1=1, skip2=1):
    """
    Compare all fields for all trees in two arbors.

    First check every arbor-level field array for equality (optionally
    strided by skip1/skip2), then compare each pair of trees node group
    by node group.
    """

    groups = ["tree", "prog"] if groups is None else groups
    fields = a1.field_list if fields is None else fields

    nfields = len(fields)
    for ifield, field in enumerate(fields):
        mylog.info(f"Comparing arbor field: {field} ({ifield+1}/{nfields}).")
        assert_array_equal(a1[field][::skip1],
                           a2[field][::skip2],
                           err_msg=f"Arbor field mismatch: {a1, a2, field}.")

    trees1 = list(a1[::skip1])
    trees2 = list(a2[::skip2])

    pbar = get_pbar("Comparing trees", len(trees1))
    # enumerate from 1 so the progress bar reflects completed pairs.
    for count, (my_t1, my_t2) in enumerate(zip(trees1, trees2), start=1):
        compare_trees(my_t1, my_t2, groups=groups, fields=fields)
        pbar.update(count)
    pbar.finish()
# Example 3
    def test_vector_fields(self):
        """
        Verify vector fields are consistent with their derived fields.

        For every registered vector field, check that the magnitude
        field matches the norm of the components and that each x/y/z
        component field matches the corresponding column — at the arbor
        level, the tree level, and for each node group.
        """
        arbor = self.arbor
        my_tree = arbor[0]
        for field in arbor.field_info.vector_fields:

            mylog.info(f"Comparing vector field: {field}.")
            expected_mag = np.sqrt((arbor[field]**2).sum(axis=1))
            assert_array_equal(arbor[f"{field}_magnitude"],
                               expected_mag,
                               err_msg=f"Magnitude field incorrect: {field}.")

            for dim, ax in enumerate("xyz"):
                # Arbor-level component field vs. column of the vector.
                assert_array_equal(
                    arbor[f"{field}_{ax}"],
                    arbor[field][:, dim],
                    err_msg=(f"Arbor vector field {field} does not match "
                             f"in dimension {dim}."))

                # Tree-node component vs. element of the node's vector.
                assert_array_equal(
                    my_tree[f"{field}_{ax}"],
                    my_tree[field][dim],
                    err_msg=(f"Tree vector field {field} does not match "
                             f"in dimension {dim}."))

                # Node-group component arrays vs. columns.
                for group in ("prog", "tree"):
                    assert_array_equal(
                        my_tree[group, f"{field}_{ax}"],
                        my_tree[group, field][:, dim],
                        err_msg=(
                            f"{group} vector field {field} does not match "
                            f"in dimension {dim}."))
# Example 4
def _print_link_info(ds1, ds2):
    """
    Print information about linking datasets.

    Log basename, time, and redshift for both datasets side by side.
    Time values are converted to Gyr in place before printing.
    """

    units = {"current_time": "Gyr"}
    for attr in ("basename", "current_time", "current_redshift"):
        values = [getattr(ds, attr) for ds in (ds1, ds2)]
        unit = units.get(attr)
        if unit is not None:
            # NOTE: convert_to_units mutates the attribute on the dataset.
            for val in values:
                val.convert_to_units(unit)
        mylog.info("Linking: %-20s = %-28s - %-28s"
                   % (attr, values[0], values[1]))
# Example 5
    def _save_catalog(self, filename, ds, halos, fields=None):
        """
        Save halo catalog with descendent information.

        The output filename is derived from the dataset basename and the
        communicator rank. `halos` may be either a list of halo objects
        (saved via _create_halo_data_lists) or a halo type string, in
        which case all halos of that type are read from the dataset and
        descendent_identifier is filled with -1.
        """
        rank = 0 if self.comm is None else self.comm.rank
        filename = get_output_filename(
            filename, "%s.%d" % (_get_tree_basename(ds), rank), ".h5")

        my_fields = [] if fields is None else fields[:]

        # Fields that must always be present in the catalog.
        default_fields = (
            ["particle_identifier",
             "descendent_identifier",
             "particle_mass"]
            + ["particle_position_%s" % ax for ax in "xyz"]
            + ["particle_velocity_%s" % ax for ax in "xyz"])
        my_fields.extend(
            field for field in default_fields if field not in my_fields)

        if isinstance(halos, list):
            num_halos = len(halos)
            data = self._create_halo_data_lists(halos, my_fields)
        else:
            num_halos = ds.index.particle_count[halos]
            data = {field: ds.r[halos, field].in_base()
                    for field in my_fields
                    if field != "descendent_identifier"}
            # No descendants known yet; -1 is the sentinel. NOTE(review):
            # np.ones gives a float array — confirm downstream readers
            # accept a float identifier column.
            data["descendent_identifier"] = -1 * np.ones(num_halos)

        ftypes = {field: "." for field in data}
        extra_attrs = {"num_halos": num_halos, "data_type": "halo_catalog"}
        mylog.info("Saving catalog with %d halos to %s." %
                   (num_halos, filename))
        save_as_dataset(ds,
                        filename,
                        data,
                        field_types=ftypes,
                        extra_attrs=extra_attrs)
# Example 6
    def trace_ancestors(self, halo_type, root_ids, fields=None, filename=None):
        """
        Trace the ancestry of a given set of halos.

        A merger-tree for a specific set of halos will be created,
        starting with the last halo catalog and moving backward.

        Parameters
        ----------
        halo_type : string
            The type of halo, typically "FOF" for FoF groups or
            "Subfind" for subhalos.
        root_ids : integer or array of integers
            The halo IDs from the last halo catalog for the
            targeted halos.
        fields : optional, list of strings
            List of additional fields to be saved to halo catalogs.
        filename : optional, string
            Directory in which merger-tree catalogs will be saved.
        """

        # Only rank 0 creates the output directory to avoid a race.
        output_dir = os.path.dirname(filename)
        if self.comm.rank == 0 and len(output_dir) > 0:
            ensure_dir(output_dir)

        # Walk the time series from the latest output backward.
        all_outputs = self.ts.outputs[::-1]
        ds1 = None

        # fn1 is the later (descendent) output, fn2 the earlier
        # (ancestor) output for each linking step.
        for i, fn2 in enumerate(all_outputs[1:]):
            fn1 = all_outputs[i]
            target_filename = get_output_filename(
                filename, "%s.%d" % (_get_tree_basename(fn1), 0), ".h5")
            catalog_filename = get_output_filename(
                filename, "%s.%d" % (_get_tree_basename(fn2), 0), ".h5")
            # Skip steps whose ancestor catalog already exists (restart
            # support).
            if os.path.exists(catalog_filename):
                continue

            # ds1 is carried over from the previous iteration when
            # possible; only load it fresh on the first step.
            if ds1 is None:
                ds1 = self._load_ds(fn1, index_ptype=halo_type)
            ds2 = self._load_ds(fn2, index_ptype=halo_type)

            if self.comm.rank == 0:
                _print_link_info(ds1, ds2)

            # No halos further back in time; the tree ends here.
            if ds2.index.particle_count[halo_type] == 0:
                mylog.info("%s has no halos of type %s, ending." %
                           (ds2, halo_type))
                break

            if i == 0:
                # First step: normalize user-supplied root_ids into an
                # int64 numpy array.
                target_ids = root_ids
                if not iterable(target_ids):
                    target_ids = np.array([target_ids])
                if isinstance(target_ids, YTArray):
                    target_ids = target_ids.d
                if target_ids.dtype != np.int64:
                    target_ids = target_ids.astype(np.int64)
            else:
                # Later steps: the targets are the halos saved by the
                # previous step's catalog.
                mylog.info("Loading target ids from %s.", target_filename)
                ds_target = yt_load(target_filename)
                target_ids = \
                  ds_target.r["halos",
                              "particle_identifier"].d.astype(np.int64)
                del ds_target

            id_store = []
            target_halos = []
            ancestor_halos = []

            # Distribute target halos across MPI ranks.
            njobs = min(self.comm.size, target_ids.size)
            pbar = get_pbar("Linking halos", target_ids.size, parallel=True)
            my_i = 0
            for halo_id in parallel_objects(target_ids, njobs=njobs):
                my_halo = ds1.halo(halo_type, halo_id)

                target_halos.append(my_halo)
                my_ancestors = self._find_ancestors(my_halo,
                                                    ds2,
                                                    id_store=id_store)
                ancestor_halos.extend(my_ancestors)
                # Advance by njobs per local item — assumes round-robin
                # distribution across ranks. TODO confirm with
                # parallel_objects semantics.
                my_i += njobs
                pbar.update(my_i)
            pbar.finish()

            if i == 0:
                # Roots have no descendents by definition.
                for halo in target_halos:
                    halo.descendent_identifier = -1
                self._save_catalog(filename, ds1, target_halos, fields)
            self._save_catalog(filename, ds2, ancestor_halos, fields)

            # No ancestors found anywhere; nothing further to trace.
            if len(ancestor_halos) == 0:
                break

            # Reuse ds2 as next iteration's descendent dataset.
            ds1 = ds2
            clear_id_cache()