Exemplo n.º 1
0
def save_field(ds, fields, field_parameters=None):
    """
    Write one or more fields associated with the dataset ds to the
    backup file.

    Particle fields are not supported. If any requested field is a
    particle field, a message is printed and nothing is written.

    Parameters
    ----------
    ds : Dataset object
        The yt dataset that the field is associated with.
    fields : field or list of fields
        The name(s) of the field(s) to save.
    field_parameters : dictionary, optional
        A dictionary of field parameters to set.
    """
    fields = list(iter_fields(fields))

    # Check every field before opening the backup file, so a request that
    # contains a particle field never writes anything.
    for field in fields:
        # Only the field name (not the field type) is used for the lookup.
        field_name = field[1] if isinstance(field, tuple) else field
        field_obj = ds._get_field_info(field_name)
        if field_obj.sampling_type == "particle":
            print("Saving particle fields currently not supported.")
            return

    with _get_backup_file(ds) as f:
        # now save the field(s)
        _write_fields_to_gdf(
            ds,
            f,
            fields,
            particle_type_name="dark_matter",
            field_parameters=field_parameters,
        )
Exemplo n.º 2
0
    def __init__(
        self,
        data_source,
        x_field,
        y_field,
        z_fields=None,
        color="b",
        x_bins=800,
        y_bins=800,
        weight_field=None,
        deposition="ngp",
        fontsize=18,
        figure_size=8.0,
        shading="nearest",
    ):
        """
        Bin *data_source* on (x_field, y_field) and profile z_fields on that
        grid.

        When no z_fields are supplied, the constant ("all", "particle_ones")
        field is profiled instead, the colorbar is disabled and *color* is
        stored as the splat color.
        """
        if z_fields is None:
            # Single solid color: no colorbar, constant profiled field.
            self.use_cbar = False
            self.splat_color = color
            z_fields = [("all", "particle_ones")]

        bin_fields = [x_field, y_field]
        profiled_fields = list(iter_fields(z_fields))
        bin_counts = [x_bins, y_bins]
        profile = create_profile(
            data_source,
            bin_fields,
            profiled_fields,
            n_bins=bin_counts,
            weight_field=weight_field,
            deposition=deposition,
        )

        cls = type(self)
        cls._initialize_instance(
            self, data_source, profile, fontsize, figure_size, shading
        )
Exemplo n.º 3
0
 def __init__(self, field):
     """
     Build a validator requiring that the output file stores the given
     data field(s).

     *field* is normalized with ``iter_fields`` and kept as the list
     ``self.fields``.
     """
     FieldValidator.__init__(self)
     normalized = iter_fields(field)
     self.fields = list(normalized)
Exemplo n.º 4
0
 def __call__(self, fields, weight):
     """
     Evaluate the quantity for *fields* weighted by *weight*.

     Every result is wrapped with ``ds.arr`` using the units of the
     corresponding field. A single result is returned on its own; several
     results are returned as a list.
     """
     field_list = list(iter_fields(fields))
     get_info = self.data_source.ds._get_field_info
     units = [get_info(fld).units for fld in field_list]
     raw_values = super().__call__(field_list, weight)
     to_array = self.data_source.ds.arr
     results = [to_array(value, unit) for value, unit in zip(raw_values, units)]
     return results[0] if len(results) == 1 else results
Exemplo n.º 5
0
def _2d_display(self, fields=None):
    """
    Display this 2D object with ``display_yt``.

    Parameters
    ----------
    fields : field or list of fields, optional
        Field(s) to plot. They are placed ahead of the non-key fields already
        present in ``self.field_data``; the first entry of the resulting list
        is displayed.

    Raises
    ------
    ValueError
        If no field is available to plot.
    """
    # Build a fresh set rather than aliasing ``self._key_fields``: extending
    # an alias with ``+=`` would grow the instance's list on every call.
    skip = set(self._key_fields).union(frb._exclude_fields)
    self.fields = [k for k in self.field_data if k not in skip]
    if fields is not None:
        self.fields = list(iter_fields(fields)) + self.fields
    if not self.fields:
        raise ValueError("No fields found to plot in display()")
    return display_yt(self, self.fields[0])
Exemplo n.º 6
0
    def retrieve_ghost_zones(self,
                             n_zones,
                             fields,
                             all_levels=False,
                             smoothed=False):
        """
        Return a covering grid around this grid, padded by ``n_zones`` zones
        on each side, filled with raw on-disk field data.

        If ``n_zones`` exceeds the ``NumberOfGhostZones`` dataset parameter
        (default 3), this defers to ``EnzoGrid.retrieve_ghost_zones``.

        Parameters
        ----------
        n_zones : int
            Number of ghost zones to add on each side.
        fields : field, list of fields, or None
            Fields to read from disk into the covering grid. Only fields in
            ``self.field_list`` are read, and particle fields are skipped.
            If None, the covering grid is returned without reading any data.
        all_levels : bool
            Only forwarded to the ``EnzoGrid`` fallback.
        smoothed : bool
            Use a smoothed covering grid instead of a plain one.
        """
        NGZ = self.ds.parameters.get("NumberOfGhostZones", 3)
        if n_zones > NGZ:
            return EnzoGrid.retrieve_ghost_zones(self, n_zones, fields,
                                                 all_levels, smoothed)

        # Create a datacube that is exactly n_zones*dx bigger than the grid
        # in each direction.
        nl = self.get_global_startindex() - n_zones
        new_left_edge = nl * self.dds + self.ds.domain_left_edge
        # NOTE: something different may need to be done for the root grid.
        level = self.Level
        kwargs = {
            "dims": self.ActiveDimensions + 2 * n_zones,
            "num_ghost_zones": n_zones,
            "use_pbar": False,
        }
        # Set the field parameters of the cube to those of this grid.
        kwargs.update(self.field_parameters)
        if smoothed:
            cube = self.index.smoothed_covering_grid(level, new_left_edge,
                                                     **kwargs)
        else:
            cube = self.index.covering_grid(level, new_left_edge, **kwargs)
        if fields is None:
            return cube

        # Mirrors EnzoGrid.get_data, duplicated here for efficiency. The raw
        # data presumably carries NGZ ghost zones per side, so trim the excess
        # (NGZ - n_zones) from both ends of every axis.
        start_zone = NGZ - n_zones
        end_zone = None if start_zone == 0 else -start_zone
        sl = (slice(start_zone, end_zone),) * 3
        field_info = self.ds.field_info
        for field in iter_fields(fields):
            if field not in self.field_list:
                continue
            conv_factor = 1.0
            # Fields without a field_info entry keep the unit conversion
            # factor of 1.0 instead of raising KeyError.
            if field in field_info:
                if field_info[field].sampling_type == "particle":
                    continue
                conv_factor = field_info[field]._convert_function(self)
            temp = self.index.io._read_raw_data_set(self, field)
            temp = temp.swapaxes(0, 2)
            # Scale in place, then keep only the requested ghost-zone extent.
            cube.field_data[field] = np.multiply(temp, conv_factor,
                                                 temp)[sl]
        return cube
Exemplo n.º 7
0
 def get_data(self, fields=None):
     """
     Load *fields* through the base object and store the subset selected by
     this region in ``self.field_data``.

     Parameters
     ----------
     fields : field or list of fields, optional
         The field(s) to load. If None, nothing is read.
     """
     # Other call sites guard None before calling iter_fields, which
     # presumably cannot iterate None; treat the default as "no fields".
     if fields is None:
         return
     fields = list(iter_fields(fields))
     self.base_object.get_data(fields)
     ind = self._cond_ind
     for field in fields:
         f = self.base_object[field]
         if f.shape != ind.shape:
             # Not aligned with the conditional mask (presumably particle
             # data): select via the per-field-type index instead.
             parent = getattr(self, "parent", self.base_object)
             self.field_data[field] = parent[field][self._part_ind(field[0])]
         else:
             self.field_data[field] = f[ind]
Exemplo n.º 8
0
    def to_fits_data(self,
                     fields=None,
                     other_keys=None,
                     length_unit=None,
                     **kwargs):
        r"""Export the fields in this FixedResolutionBuffer instance
        to a FITSImageData instance.

        One FITS image is produced per exported field: either the fields
        requested explicitly or, if none are given, every field already
        stored in this buffer.

        Parameters
        ----------
        fields : field or list of fields, optional
            These fields will be pixelized and output. If None, the keys of
            the FRB will be used.
        other_keys : dictionary, optional
            Header keys and values written into every FITS header.
        length_unit : string, optional
            The length units the coordinates are written in. Defaults to the
            dataset's length unit.
        **kwargs
            ``units`` is accepted as a deprecated alias of *length_unit* and
            takes precedence over it.

        Raises
        ------
        RuntimeError
            If there are no fields to export.
        """
        from yt.visualization.fits_image import FITSImageData

        if length_unit is None:
            length_unit = self.ds.length_unit

        # Deprecated spelling; it overrides whatever length_unit is by now.
        if "units" in kwargs:
            issue_deprecation_warning("The 'units' keyword argument has been "
                                      "replaced by the 'length_unit' keyword "
                                      "argument and the former has been "
                                      "deprecated. Setting 'length_unit' "
                                      "to 'units'.")
            length_unit = kwargs.pop("units")

        if fields is not None:
            export_fields = list(iter_fields(fields))
        else:
            export_fields = list(self.data.keys())

        if not export_fields:
            raise RuntimeError(
                "No fields to export. Either pass a field or list of fields to "
                "to_fits_data or access a field from the FixedResolutionBuffer "
                "object.")

        fits_data = FITSImageData(self,
                                  fields=export_fields,
                                  length_unit=length_unit)
        if other_keys is not None:
            for key, value in other_keys.items():
                fits_data.update_all_headers(key, value)
        return fits_data
Exemplo n.º 9
0
    def hide_axes(self, field=None, draw_frame=False):
        """
        Hides the axes for a plot and updates the size of the
        plot accordingly.  Defaults to operating on all fields for a
        PlotContainer object.

        Parameters
        ----------

        field : string, field tuple, or list of strings or field tuples (optional)
            The name of the field(s) that we want to hide the axes.
            If None, every field in ``self.fields`` is used.

        draw_frame : boolean
            If True, the axes frame will still be drawn. Defaults to False.
            See the Notes section below for more details.

        Returns
        -------
        self
            The plot container, so calls can be chained.

        Examples
        --------

        This will save an image with no axes.

        >>> import yt
        >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
        >>> s = yt.SlicePlot(ds, 2, "density", "c", (20, "kpc"))
        >>> s.hide_axes()
        >>> s.save()

        This will save an image with no axis or colorbar.

        >>> import yt
        >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
        >>> s = yt.SlicePlot(ds, 2, "density", "c", (20, "kpc"))
        >>> s.hide_axes()
        >>> s.hide_colorbar()
        >>> s.save()

        Notes
        -----
        By default, when removing the axes, the patch on which the axes are
        drawn is disabled, making it impossible to later change e.g. the
        background colour. To force the axes patch to be displayed while still
        hiding the axes, set the ``draw_frame`` keyword argument to ``True``.
        """
        if field is None:
            field = self.fields
        for f in iter_fields(field):
            self.plots[f].hide_axes(draw_frame)
        return self
Exemplo n.º 10
0
    def show_axes(self, field=None):
        """
        Shows the axes for a plot and updates the size of the
        plot accordingly.  Defaults to operating on all fields for a
        PlotContainer object.  See hide_axes().

        Parameters
        ----------

        field : string, field tuple, or list of strings or field tuples (optional)
            The name of the field(s) that we want to show the axes.
            If None, every field in ``self.fields`` is used.

        Returns
        -------
        self
            The plot container, so calls can be chained.
        """
        targets = self.fields if field is None else field
        for name in iter_fields(targets):
            self.plots[name].show_axes()
        return self
Exemplo n.º 11
0
 def __init__(
     self,
     ds,
     normal,
     fields,
     center="c",
     width=(1.0, "unitary"),
     weight_field=None,
     image_res=512,
     data_source=None,
     north_vector=None,
     depth=(1.0, "unitary"),
     method="integrate",
     length_unit=None,
 ):
     """
     Make FITS images of off-axis projections of *fields* along *normal*.

     One image per field is computed with ``off_axis_projection``, taken
     from *data_source* when given and from *ds* otherwise.
     """
     field_list = list(iter_fields(fields))
     center, _dcenter = ds.coordinates.sanitize_center(center, 4)
     width = ds.coordinates.sanitize_width(normal, width, depth)
     width_code = tuple(el.in_units("code_length").v for el in width)
     if not is_sequence(image_res):
         image_res = (image_res, image_res)
     resolution = (image_res[0], image_res[1])
     source = ds if data_source is None else data_source
     # Projections are transposed so that they match the image orientation.
     buf = {
         field: off_axis_projection(
             source,
             center,
             normal,
             width_code,
             resolution,
             field,
             north_vector=north_vector,
             method=method,
             weight=weight_field,
         ).swapaxes(0, 1)
         for field in field_list
     }
     image_center = ds.arr([0.0] * 2, "code_length")
     wcs_obj, _unused, lunit = construct_image(
         ds, normal, buf, image_center, image_res, width, length_unit
     )
     super().__init__(
         buf, fields=field_list, wcs=wcs_obj, length_unit=lunit, ds=ds
     )
Exemplo n.º 12
0
 def __init__(
     self,
     ds,
     normal,
     fields,
     image_res=512,
     center="c",
     width=None,
     north_vector=None,
     length_unit=None,
 ):
     """Make FITS images of *fields* on a cutting plane through *ds*."""
     field_list = list(iter_fields(fields))
     center, _dcenter = ds.coordinates.sanitize_center(center, 4)
     plane = ds.cutting(normal, center, north_vector=north_vector)
     image_center = ds.arr([0.0] * 2, "code_length")
     wcs_obj, frb, lunit = construct_image(
         ds, normal, plane, image_center, image_res, width, length_unit
     )
     super().__init__(frb, fields=field_list, length_unit=lunit, wcs=wcs_obj)
Exemplo n.º 13
0
 def __init__(
     self,
     ds,
     axis,
     fields,
     image_res=512,
     center="c",
     width=None,
     length_unit=None,
     **kwargs,
 ):
     """
     Make FITS images of *fields* from an axis-aligned slice of *ds*.

     Extra keyword arguments are forwarded to ``ds.slice``.
     """
     field_list = list(iter_fields(fields))
     axis = fix_axis(axis, ds)
     center, dcenter = ds.coordinates.sanitize_center(center, axis)
     slice_obj = ds.slice(axis, center[axis], **kwargs)
     wcs_obj, frb, lunit = construct_image(
         ds, axis, slice_obj, dcenter, image_res, width, length_unit
     )
     super().__init__(frb, fields=field_list, length_unit=lunit, wcs=wcs_obj)
Exemplo n.º 14
0
 def __init__(
     self,
     ds,
     axis,
     fields,
     image_res=512,
     center="c",
     width=None,
     weight_field=None,
     length_unit=None,
     **kwargs,
 ):
     """
     Make FITS images of *fields* from an axis-aligned projection of *ds*.

     The projection object is built from the first field; extra keyword
     arguments are forwarded to ``ds.proj``.
     """
     field_list = list(iter_fields(fields))
     axis = fix_axis(axis, ds)
     center, dcenter = ds.coordinates.sanitize_center(center, axis)
     projection = ds.proj(
         field_list[0], axis, weight_field=weight_field, **kwargs
     )
     wcs_obj, frb, lunit = construct_image(
         ds, axis, projection, dcenter, image_res, width, length_unit
     )
     super().__init__(frb, fields=field_list, length_unit=lunit, wcs=wcs_obj)
Exemplo n.º 15
0
    def _find_field_values_at_points(self, fields, coords):
        r"""Find the value of fields at a set of coordinates.

        Returns the values [field1, field2,...] of the fields at the given
        (x, y, z) points. Each entry is an array with one value per point.
        If only one field is requested, its array is returned directly
        instead of a one-element list.

        Parameters
        ----------
        fields : field or list of fields
            The field(s) to sample.
        coords : array_like
            Points with x, y, z in columns (indexed as ``coords[:, 0..2]``),
            interpreted in code_length.
        """
        coords = self.ds.arr(ensure_numpy_array(coords), "code_length")
        grids = self._find_points(coords[:, 0], coords[:, 1], coords[:, 2])[0]
        fields = list(iter_fields(fields))

        # Group point indices by the grid that contains them.
        grid_index = {}
        for coord_index, grid in enumerate(grids):
            grid_index.setdefault(grid, []).append(coord_index)

        out = [
            self.ds.arr(np.empty(len(coords)),
                        self.ds._get_field_info(field).units)
            for field in fields
        ]

        for grid, coord_indices in grid_index.items():
            cellwidth = (grid.RightEdge -
                         grid.LeftEdge) / grid.ActiveDimensions
            # The containing cell does not depend on the field, so compute
            # all cell indices for this grid once (int64 conversion
            # truncates, exactly as for a single point).
            marks = np.array(
                (coords[coord_indices, :] - grid.LeftEdge) / cellwidth,
                dtype="int64",
            )
            for field_index, field in enumerate(fields):
                grid_values = grid[field]
                field_out = out[field_index]
                for coord_index, (i, j, k) in zip(coord_indices, marks):
                    field_out[coord_index] = grid_values[i, j, k]
        if len(fields) == 1:
            return out[0]
        return out
Exemplo n.º 16
0
    def to_pw(self, fields=None, center="c", width=None, axes_unit=None):
        r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this
        object.

        This is a bare-bones mechanism of creating a plot window from this
        object, which can then be moved around, zoomed, and on and on.  All
        behavior of the plot window is relegated to that routine.

        Parameters
        ----------
        fields : field or list of fields, optional
            Field(s) to plot. They are placed ahead of the non-key fields
            already in ``self.field_data``. If None, only the latter are used.
        center : optional
            Currently ignored: the window is always built around
            ``self.center``.
        width : optional
            Passed to ``get_oblique_window_parameters`` to compute the bounds.
        axes_unit : optional
            If given, applied with ``set_axes_unit``.
        """
        normal = self.normal
        # The ``center`` argument is overridden by this object's own center.
        center = self.center
        # iter_fields cannot iterate None, so the default means "no extra
        # fields" rather than a TypeError.
        requested = [] if fields is None else list(iter_fields(fields))
        self.fields = requested + [
            k for k in self.field_data.keys() if k not in self._key_fields
        ]
        from yt.visualization.fixed_resolution import FixedResolutionBuffer
        from yt.visualization.plot_window import (
            PWViewerMPL,
            get_oblique_window_parameters,
        )

        (bounds,
         center_rot) = get_oblique_window_parameters(normal, center, width,
                                                     self.ds)
        pw = PWViewerMPL(
            self,
            bounds,
            fields=self.fields,
            origin="center-window",
            periodic=False,
            oblique=True,
            frb_generator=FixedResolutionBuffer,
            plot_type="OffAxisSlice",
        )
        if axes_unit is not None:
            pw.set_axes_unit(axes_unit)
        pw._setup_plots()
        return pw
Exemplo n.º 17
0
def write_to_gdf(
    ds,
    gdf_path,
    fields=None,
    data_author=None,
    data_comment=None,
    dataset_units=None,
    particle_type_name="dark_matter",
    overwrite=False,
    **kwargs,
):
    """
    Write a dataset to the given path in the Grid Data Format.

    Parameters
    ----------
    ds : Dataset object
        The yt data to write out.
    gdf_path : string
        The path of the file to output.
    fields : field or list of fields, optional
        The field or list of fields to write out. If None, defaults to
        ds.field_list.
    data_author : string, optional
        The name of the author who wrote the data. Default: None.
    data_comment : string, optional
        A descriptive comment. Default: None.
    dataset_units : dictionary, optional
        A dictionary of (value, unit) tuples to set the default units
        of the dataset. Keys can be:

        * "length_unit"
        * "time_unit"
        * "mass_unit"
        * "velocity_unit"
        * "magnetic_unit"

        If not specified, these will carry over from the parent
        dataset.
    particle_type_name : string, optional
        The particle type of the particles in the dataset. Default: "dark_matter"
    overwrite : boolean, optional
        Whether or not to overwrite an already existing file. If False, attempting
        to overwrite an existing file will result in an exception.
    **kwargs
        Not used by this function.

    Examples
    --------
    >>> dataset_units = {"length_unit":(1.0,"Mpc"),
    ...                  "time_unit":(1.0,"Myr")}
    >>> write_to_gdf(ds, "clumps.h5", data_author="John ZuHone",
    ...              dataset_units=dataset_units,
    ...              data_comment="My Really Cool Dataset", overwrite=True)
    """
    requested = ds.field_list if fields is None else fields
    field_list = list(iter_fields(requested))

    with _create_new_gdf(
            ds,
            gdf_path,
            data_author,
            data_comment,
            dataset_units=dataset_units,
            particle_type_name=particle_type_name,
            overwrite=overwrite,
    ) as handle:
        # Write every requested field into the freshly created file.
        _write_fields_to_gdf(
            ds, handle, field_list, particle_type_name=particle_type_name
        )
Exemplo n.º 18
0
    def __init__(
        self,
        data,
        fields=None,
        length_unit=None,
        width=None,
        img_ctr=None,
        wcs=None,
        current_time=None,
        time_unit=None,
        mass_unit=None,
        velocity_unit=None,
        magnetic_unit=None,
        ds=None,
        unit_header=None,
        **kwargs,
    ):
        r"""Initialize a FITSImageData object.

        FITSImageData contains a collection of FITS ImageHDU instances and
        WCS information, along with units for each of the images. FITSImageData
        instances can be constructed from ImageArrays, NumPy arrays, dicts
        of such arrays, FixedResolutionBuffers, and YTCoveringGrids. The latter
        two are the most powerful because WCS information can be constructed
        automatically from their coordinates. An astropy PrimaryHDU or HDUList
        may also be passed. In that case the field names, units and WCS are
        read back from its headers and all other arguments that would build
        images are ignored.

        Parameters
        ----------
        data : FixedResolutionBuffer or a YTCoveringGrid. Or, an
            ImageArray, a numpy.ndarray, or dict of such arrays, or an
            astropy PrimaryHDU / HDUList
            The data to be made into a FITS image or images.
        fields : single string or list of strings, optional
            The field names for the data. If *fields* is none and *data* has
            keys, it will use these for the fields. If *data* is a single
            array and no field name is given, the image is called
            "image_data" (a warning is logged).
        length_unit : string
            The units of the WCS coordinates and the length unit of the file.
            Defaults to the length unit of the dataset, if there is one, or
            "cm" if there is not.
        width : float, (value, unit) tuple, or YTQuantity
            The width of the image. Either a single value or iterable of values.
            If a float, assumed to be in *length_unit*. Defaults to 1.0. Only
            used if this information is not already provided by *data*.
        img_ctr : array_like or YTArray
            The center coordinates of the image. If a list or NumPy array,
            it is assumed to be in *length_unit*. Defaults to the origin. Only
            used if this information is not already provided by *data*.
        wcs : `~astropy.wcs.WCS` instance, optional
            Supply an AstroPy WCS instance. Will override automatic WCS
            creation from FixedResolutionBuffers and YTCoveringGrids.
        current_time : float, tuple, or YTQuantity, optional
            The current time of the image(s). If not specified, one will
            be set from the dataset if there is one. If a float, it will
            be assumed to be in *time_unit* units.
        time_unit : string
            The default time units of the file. Defaults to "s".
        mass_unit : string
            The default mass units of the file. Defaults to "g".
        velocity_unit : string
            The default velocity units of the file. Defaults to "cm/s".
        magnetic_unit : string
            The default magnetic units of the file. Defaults to "gauss".
        ds : `~yt.static_output.Dataset` instance, optional
            The dataset associated with the image(s), typically used
            to transfer metadata to the header(s). Does not need to be
            specified if *data* has a dataset as an attribute.
        unit_header : FITS header-like, optional
            If given, the file units are set from it via
            ``_set_units_from_header`` instead of from the *_unit* arguments.
        **kwargs
            Not used by this constructor.

        Examples
        --------

        >>> # This example uses a FRB.
        >>> ds = load("sloshing_nomag2_hdf5_plt_cnt_0150")
        >>> prj = ds.proj(2, "kT", weight_field="density")
        >>> frb = prj.to_frb((0.5, "Mpc"), 800)
        >>> # This example just uses the FRB and puts the coords in kpc.
        >>> f_kpc = FITSImageData(frb, fields="kT", length_unit="kpc",
        ...                       time_unit=(1.0, "Gyr"))
        >>> # This example specifies a specific WCS.
        >>> from astropy.wcs import WCS
        >>> w = WCS(naxis=2)
        >>> w.wcs.crval = [30., 45.] # RA, Dec in degrees
        >>> w.wcs.cunit = ["deg"]*2
        >>> nx, ny = 800, 800
        >>> w.wcs.crpix = [0.5*(nx+1), 0.5*(ny+1)]
        >>> w.wcs.ctype = ["RA---TAN","DEC--TAN"]
        >>> scale = 1./3600. # One arcsec per pixel
        >>> w.wcs.cdelt = [-scale, scale]
        >>> f_deg = FITSImageData(frb, fields="kT", wcs=w)
        >>> f_deg.writeto("temp.fits")
        """

        # Normalize a single field or a list of fields into a list.
        if fields is not None:
            fields = list(iter_fields(fields))

        # Fall back to the dataset attached to the data object, if any.
        if ds is None:
            ds = getattr(data, "ds", None)

        self.fields = []
        self.field_units = {}

        if unit_header is None:
            self._set_units(ds, [
                length_unit, mass_unit, time_unit, velocity_unit, magnetic_unit
            ])
        else:
            self._set_units_from_header(unit_header)

        # WCS coordinates are expressed in the file's length unit.
        wcs_unit = str(self.length_unit.units)

        self._fix_current_time(ds, current_time)

        # Width / center are only needed when data carries no coordinates.
        if width is None:
            width = 1.0
        if isinstance(width, tuple):
            if ds is None:
                width = YTQuantity(width[0], width[1])
            else:
                width = ds.quan(width[0], width[1])
        if img_ctr is None:
            img_ctr = np.zeros(3)

        # Coordinate/bookkeeping names that are never turned into images.
        exclude_fields = [
            "x",
            "y",
            "z",
            "px",
            "py",
            "pz",
            "pdx",
            "pdy",
            "pdz",
            "weight_field",
        ]

        if isinstance(data, _astropy.pyfits.PrimaryHDU):
            data = _astropy.pyfits.HDUList([data])

        # Existing FITS data: recover fields, units and every WCS stored in
        # the primary header, then return without building new images.
        if isinstance(data, _astropy.pyfits.HDUList):
            self.hdulist = data
            for hdu in data:
                self.fields.append(hdu.header["btype"])
                self.field_units[hdu.header["btype"]] = hdu.header["bunit"]

            self.shape = self.hdulist[0].shape
            self.dimensionality = len(self.shape)
            wcs_names = [
                key for key in self.hdulist[0].header if "WCSNAME" in key
            ]
            # "WCSNAME" is the primary WCS (key " "); "WCSNAMEx" is the
            # alternate WCS with key "x", stored as attribute "wcs<x>".
            for name in wcs_names:
                if name == "WCSNAME":
                    key = " "
                else:
                    key = name[-1]
                w = _astropy.pywcs.WCS(header=self.hdulist[0].header,
                                       key=key,
                                       naxis=self.dimensionality)
                setattr(self, "wcs" + key.strip().lower(), w)

            return

        self.hdulist = _astropy.pyfits.HDUList()

        # NOTE(review): if data is neither dict-like (no .keys()) nor an
        # ndarray, img_data is never bound and the image loop below raises
        # NameError -- consider an explicit error here.
        if hasattr(data, "keys"):
            img_data = data
            if fields is None:
                fields = list(img_data.keys())
        elif isinstance(data, np.ndarray):
            if fields is None:
                mylog.warning(
                    "No field name given for this array. Calling it 'image_data'."
                )
                fn = "image_data"
                fields = [fn]
            else:
                fn = fields[0]
            img_data = {fn: data}

        # HDU names drop the field type: ("gas", "density") -> "density".
        for fd in fields:
            if isinstance(fd, tuple):
                self.fields.append(fd[1])
            elif isinstance(fd, DerivedField):
                self.fields.append(fd.name[1])
            else:
                self.fields.append(fd)

        # Sanity checking names
        # set.add returns None, so the "or" records first occurrences in s
        # and only names seen before end up in duplicates. Duplicates are
        # renamed "<ftype>_<fname>"; plain strings cannot be disambiguated.
        s = set()
        duplicates = {f for f in self.fields if f in s or s.add(f)}
        if len(duplicates) > 0:
            for i, fd in enumerate(self.fields):
                if fd in duplicates:
                    if isinstance(fields[i], tuple):
                        ftype, fname = fields[i]
                    elif isinstance(fields[i], DerivedField):
                        ftype, fname = fields[i].name
                    else:
                        raise RuntimeError(
                            f"Cannot distinguish between fields with same name {fd}!"
                        )
                    self.fields[i] = f"{ftype}_{fname}"

        # NOTE(review): self.shape is only set for i == 0, and the primary
        # HDU only for is_first. If the first name is in exclude_fields,
        # neither happens: len(self.shape) below raises AttributeError and the
        # list would have no PrimaryHDU. Confirm whether that can occur.
        for is_first, _is_last, (i, (name, field)) in mark_ends(
                enumerate(zip(self.fields, fields))):
            if name not in exclude_fields:
                this_img = img_data[field]
                # Record units; code units are converted to their CGS base
                # equivalent for the header.
                if hasattr(img_data[field], "units"):
                    if this_img.units.is_code_unit:
                        mylog.warning("Cannot generate an image with code "
                                      "units. Converting to units in CGS.")
                        funits = this_img.units.get_base_equivalent("cgs")
                    else:
                        funits = this_img.units
                    self.field_units[name] = str(funits)
                else:
                    self.field_units[name] = "dimensionless"
                mylog.info("Making a FITS image of field %s", name)
                # ImageArrays are stored as-is with their shape reversed;
                # other arrays are transposed before storing.
                if isinstance(this_img, ImageArray):
                    if i == 0:
                        self.shape = this_img.shape[::-1]
                    this_img = np.asarray(this_img)
                else:
                    if i == 0:
                        self.shape = this_img.shape
                    this_img = np.asarray(this_img.T)
                if is_first:
                    hdu = _astropy.pyfits.PrimaryHDU(this_img)
                else:
                    hdu = _astropy.pyfits.ImageHDU(this_img)
                hdu.name = name
                hdu.header["btype"] = name
                # NOTE(review): the pattern "()" only matches the empty
                # string, so this leaves the unit string unchanged. It was
                # presumably meant to strip parentheses, but doing so could
                # change a unit's meaning (e.g. "erg/(cm**2*s)") -- confirm
                # intent before changing.
                hdu.header["bunit"] = re.sub("()", "", self.field_units[name])
                # Store each file unit as "<x>unit" (value in the header,
                # unit in the comment); the magnetic unit uses "bfunit".
                for unit in ("length", "time", "mass", "velocity", "magnetic"):
                    if unit == "magnetic":
                        short_unit = "bf"
                    else:
                        short_unit = unit[0]
                    key = f"{short_unit}unit"
                    value = getattr(self, f"{unit}_unit")
                    if value is not None:
                        hdu.header[key] = float(value.value)
                        hdu.header.comments[key] = f"[{value.units}]"
                hdu.header["time"] = float(self.current_time.value)
                # Cosmology keys are only written when current_redshift was
                # set (presumably by _fix_current_time).
                if hasattr(self, "current_redshift"):
                    hdu.header["HUBBLE"] = self.hubble_constant
                    hdu.header["REDSHIFT"] = self.current_redshift
                self.hdulist.append(hdu)

        self.dimensionality = len(self.shape)

        if wcs is None:
            w = _astropy.pywcs.WCS(header=self.hdulist[0].header,
                                   naxis=self.dimensionality)
            # FRBs and covering grids are special cases where
            # we have coordinate information, so we take advantage
            # of this and construct the WCS object
            if isinstance(img_data, FixedResolutionBuffer):
                # Pixel size and center from the FRB bounds (x0, x1, y0, y1).
                dx = (img_data.bounds[1] -
                      img_data.bounds[0]).to_value(wcs_unit)
                dy = (img_data.bounds[3] -
                      img_data.bounds[2]).to_value(wcs_unit)
                dx /= self.shape[0]
                dy /= self.shape[1]
                xctr = 0.5 * (img_data.bounds[1] +
                              img_data.bounds[0]).to_value(wcs_unit)
                yctr = 0.5 * (img_data.bounds[3] +
                              img_data.bounds[2]).to_value(wcs_unit)
                center = [xctr, yctr]
                cdelt = [dx, dy]
            elif isinstance(img_data, YTCoveringGrid):
                cdelt = img_data.dds.to_value(wcs_unit)
                center = 0.5 * (img_data.left_edge +
                                img_data.right_edge).to_value(wcs_unit)
            else:
                # If img_data is just an array we use the width and img_ctr
                # parameters to determine the cell widths
                if not is_sequence(width):
                    width = [width] * self.dimensionality
                if isinstance(width[0], YTQuantity):
                    cdelt = [
                        wh.to_value(wcs_unit) / n
                        for wh, n in zip(width, self.shape)
                    ]
                else:
                    cdelt = [float(wh) / n for wh, n in zip(width, self.shape)]
                center = img_ctr[:self.dimensionality]
            # Reference pixel is the image center (1-based FITS indexing).
            w.wcs.crpix = 0.5 * (np.array(self.shape) + 1)
            w.wcs.crval = center
            w.wcs.cdelt = cdelt
            w.wcs.ctype = ["linear"] * self.dimensionality
            w.wcs.cunit = [wcs_unit] * self.dimensionality
            self.set_wcs(w)
        else:
            self.set_wcs(wcs)
Exemplo n.º 19
0
 def __call__(self, fields):
     """
     Evaluate the quantity for *fields*.

     A single result is returned on its own; otherwise the full sequence of
     results is returned.
     """
     results = super().__call__(list(iter_fields(fields)))
     if len(results) != 1:
         return results
     return results[0]
Exemplo n.º 20
0
 def __call__(self, fields, non_zero=False):
     """
     Evaluate the quantity for *fields*, passing *non_zero* through to the
     parent implementation.

     A single result is returned on its own; otherwise the full sequence of
     results is returned.
     """
     requested = list(iter_fields(fields))
     results = super().__call__(requested, non_zero)
     return results[0] if len(results) == 1 else results
Exemplo n.º 21
0
 def __call__(self, fields, weight):
     """
     Evaluate the weighted average of *fields* using *weight*.

     A single result is returned on its own; otherwise the full sequence of
     results is returned.
     """
     requested = list(iter_fields(fields))
     results = super().__call__(requested, weight)
     return results[0] if len(results) == 1 else results