Code example #1
File: point_sources.py  Project: jzuhone/pyxsim
def make_point_sources(area, exp_time, positions, sky_center,
                       spectra, prng=None):
    r"""
    Create a new :class:`~pyxsim.event_list.EventList` which contains
    point sources.

    Parameters
    ----------
    area : float, (value, unit) tuple, :class:`~yt.units.yt_array.YTQuantity`, or :class:`~astropy.units.Quantity`
        The collecting area to determine the number of events. If units are
        not specified, it is assumed to be in cm^2.
    exp_time : float, (value, unit) tuple, :class:`~yt.units.yt_array.YTQuantity`, or :class:`~astropy.units.Quantity`
        The exposure time to determine the number of events. If units are
        not specified, it is assumed to be in seconds.
    positions : array of source positions, shape Nx2
        The positions of the point sources in RA, Dec, where N is the
        number of point sources. Coordinates should be in degrees.
    sky_center : array-like
        Center RA, Dec of the events in degrees.
    spectra : list (size N) of :class:`~soxs.spectra.Spectrum` objects
        The spectra for the point sources, where N is the number 
        of point sources. Assumed to be in the observer frame.
    prng : integer or :class:`~numpy.random.RandomState` object 
        A pseudo-random number generator. Typically will only be specified
        if you have a reason to generate the same set of random numbers, such as for a
        test. Default is to use the :mod:`numpy.random` module.
    """
    prng = parse_prng(prng)

    spectra = ensure_list(spectra)
    positions = ensure_list(positions)

    area = parse_value(area, "cm**2")
    exp_time = parse_value(exp_time, "s")

    t_exp = exp_time.value/comm.size

    x = []
    y = []
    e = []

    for pos, spectrum in zip(positions, spectra):
        eobs = spectrum.generate_energies(t_exp, area.value, prng=prng)
        ne = eobs.size
        x.append(YTArray([pos[0]] * ne, "deg"))
        y.append(YTArray([pos[1]] * ne, "deg"))
        e.append(YTArray.from_astropy(eobs))

    parameters = {"sky_center": YTArray(sky_center, "degree"),
                  "exp_time": exp_time,
                  "area": area}

    events = {}
    events["xsky"] = uconcatenate(x)
    events["ysky"] = uconcatenate(y)
    events["eobs"] = uconcatenate(e)

    return EventList(events, parameters)
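A minimal usage sketch for make_point_sources (hedged: it assumes the function is importable from the top-level pyxsim namespace and that soxs provides Spectrum.from_powerlaw; the positions and power-law parameters are invented for illustration):

import pyxsim
from soxs import Spectrum

# Two point sources near the field center, each with its own (made-up) spectrum.
positions = [(30.01, 44.99), (29.98, 45.02)]  # (RA, Dec) in degrees
spectra = [Spectrum.from_powerlaw(1.7, 0.0, 1.0e-4, emin=0.1, emax=10.0)
           for _ in positions]
events = pyxsim.make_point_sources((500.0, "cm**2"), (100.0, "ks"),
                                   positions, (30.0, 45.0), spectra, prng=25)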
Code example #2
 def __init__(self, field):
     """
     This validator ensures that the output file has a given data field stored
     in it.
     """
     FieldValidator.__init__(self)
     self.fields = ensure_list(field)
Code example #3
def save_field(ds, fields, field_parameters=None):
    """
    Write a single field associated with the dataset ds to the
    backup file.

    Parameters
    ----------
    ds : Dataset object
        The yt dataset that the field is associated with.
    fields : field or list of fields
        The name(s) of the field(s) to save.
    field_parameters : dictionary
        A dictionary of field parameters to set.
    """

    fields = ensure_list(fields)
    for field_name in fields:
        if isinstance(field_name, tuple):
            field_name = field_name[1]
        field_obj = ds._get_field_info(field_name)
        if field_obj.particle_type:
            print("Saving particle fields currently not supported.")
            return

    with _get_backup_file(ds) as f:
        # now save the field
        _write_fields_to_gdf(ds, f, fields, 
                             particle_type_name="dark_matter",
                             field_parameters=field_parameters)
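A short usage sketch for save_field (hedged: the import path assumes yt's grid_data_format writer module, and the dataset path is illustrative):

import yt
from yt.utilities.grid_data_format.writer import save_field

ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
# Append a single gas field to the dataset's GDF backup file.
save_field(ds, ("gas", "density"))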
Code example #4
    def __init__(self, data_source, x_field, y_fields,
                 weight_field="cell_mass", n_bins=64,
                 accumulation=False, fractional=False,
                 label=None, plot_spec=None,
                 x_log=None, y_log=None):

        if x_log is None:
            logs = None
        else:
            logs = {x_field:x_log}

        profiles = [create_profile(data_source, [x_field],
                                   n_bins=[n_bins],
                                   fields=ensure_list(y_fields),
                                   weight_field=weight_field,
                                   accumulation=accumulation,
                                   fractional=fractional,
                                   logs=logs)]

        if plot_spec is None:
            plot_spec = [dict() for p in profiles]
        if not isinstance(plot_spec, list):
            plot_spec = [plot_spec.copy() for p in profiles]

        ProfilePlot._initialize_instance(self, profiles, label, plot_spec, y_log)
Code example #5
    def __init__(self, name, function, units=None,
                 take_log=True, validators=None,
                 particle_type=False, vector_field=False, display_field=True,
                 not_in_all=False, display_name=None, output_units = None):
        self.name = name
        self.take_log = take_log
        self.display_name = display_name
        self.not_in_all = not_in_all
        self.display_field = display_field
        self.particle_type = particle_type
        self.vector_field = vector_field
        if output_units is None: output_units = units
        self.output_units = output_units

        self._function = function

        if validators:
            self.validators = ensure_list(validators)
        else:
            self.validators = []

        # handle units
        if units is None:
            self.units = ''
        elif isinstance(units, str):
            if units.lower() == 'auto':
                self.units = None
            else:
                self.units = units
        elif isinstance(units, Unit):
            self.units = str(units)
        else:
            raise FieldUnitsError("Cannot handle units '%s' (type %s). "
                                  "Please provide a string or Unit "
                                  "object." % (units, type(units)))
Code example #6
 def __init__(self, ds, axis, fields, center="c", width=None, image_res=None, **kwargs):
     fields = ensure_list(fields)
     axis = fix_axis(axis, ds)
     center, dcenter = ds.coordinates.sanitize_center(center, axis)
     slc = ds.slice(axis, center[axis], **kwargs)
     w, frb = construct_image(ds, axis, slc, dcenter, width=width, image_res=image_res)
     super(FITSSlice, self).__init__(frb, fields=fields, wcs=w)
Code example #7
 def __init__(self, ds, normal, fields, center='c', width=None, image_res=512,
              north_vector=None):
     fields = ensure_list(fields)
     center, dcenter = ds.coordinates.sanitize_center(center, 4)
     cut = ds.cutting(normal, center, north_vector=north_vector)
     center = ds.arr([0.0] * 2, 'code_length')
     w, frb = construct_image(ds, normal, cut, center, width=width, image_res=image_res)
     super(FITSOffAxisSlice, self).__init__(frb, fields=fields, wcs=w)
Code example #8
 def get_dependencies(self, fields):
     deps = []
     fi = self.ds.field_info
     for field in fields:
         if any(getattr(v,"ghost_zones", 0) > 0 for v in
                fi[field].validators): continue
         deps += ensure_list(fi[field].get_dependencies(ds=self.ds).requested)
     return list(set(deps))
Code example #9
 def __init__(self, ds, axis, fields, center="c", width=None,
              weight_field=None, image_res=None, **kwargs):
     fields = ensure_list(fields)
     axis = fix_axis(axis, ds)
     center, dcenter = ds.coordinates.sanitize_center(center, axis)
     prj = ds.proj(fields[0], axis, weight_field=weight_field, **kwargs)
     w, frb = construct_image(ds, axis, prj, dcenter, width=width, image_res=image_res)
     super(FITSProjection, self).__init__(frb, fields=fields, wcs=w)
Code example #10
def write_to_gdf(ds, gdf_path, fields=None, 
                 data_author=None, data_comment=None,
                 dataset_units=None, particle_type_name="dark_matter",
                 clobber=False):
    """
    Write a dataset to the given path in the Grid Data Format.

    Parameters
    ----------
    ds : Dataset object
        The yt data to write out.
    gdf_path : string
        The path of the file to output.
    fields : field or list of fields
        The field(s) to write out. If None, defaults to
        ds.field_list.
    data_author : string, optional
        The name of the author who wrote the data. Default: None.
    data_comment : string, optional
        A descriptive comment. Default: None.
    dataset_units : dictionary, optional
        A dictionary of (value, unit) tuples to set the default units
        of the dataset. Keys can be:
            "length_unit"
            "time_unit"
            "mass_unit"
            "velocity_unit"
            "magnetic_unit"
        If not specified, these will carry over from the parent
        dataset.
    particle_type_name : string, optional
        The particle type of the particles in the dataset. Default: "dark_matter"
    clobber : boolean, optional
        Whether or not to clobber an already existing file. If False, attempting
        to overwrite an existing file will result in an exception.

    Examples
    --------
    >>> dataset_units = {"length_unit":(1.0,"Mpc"),
    ...                  "time_unit":(1.0,"Myr")}
    >>> write_to_gdf(ds, "clumps.h5", data_author="John ZuHone",
    ...              dataset_units=dataset_units,
    ...              data_comment="My Really Cool Dataset", clobber=True)
    """

    if fields is None:
        fields = ds.field_list

    fields = ensure_list(fields)
    
    with _create_new_gdf(ds, gdf_path, data_author, 
                         data_comment,
                         dataset_units=dataset_units,
                         particle_type_name=particle_type_name, 
                         clobber=clobber) as f:

        # now add the fields one-by-one
        _write_fields_to_gdf(ds, f, fields, particle_type_name)
Code example #11
 def from_sizes(cls, sizes):
     sizes = ensure_list(sizes)
     pool = cls()
     rank = pool.comm.rank
     for i,size in enumerate(sizes):
         if iterable(size):
             size, name = size
         else:
             name = "workgroup_%02i" % i
         pool.add_workgroup(size, name = name)
     for wg in pool.workgroups:
         if rank in wg.ranks: workgroup = wg
     return pool, workgroup
Code example #12
def sanitize_label(label, nprofiles):
    label = ensure_list(label)
    
    if len(label) == 1:
        label = label * nprofiles
    
    if len(label) != nprofiles:
        raise RuntimeError("Number of labels must match number of profiles")

    for l in label:
        if l is not None and not isinstance(l, string_types):
            raise RuntimeError("All labels must be None or a string")

    return label
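For reference, the broadcasting behavior of sanitize_label above, shown as comments (these lines just call the function defined in this example):

sanitize_label("gas", 3)       # -> ["gas", "gas", "gas"]
sanitize_label(None, 2)        # -> [None, None]
sanitize_label(["a", "b"], 3)  # -> RuntimeError: label/profile counts differ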
Code example #13
    def __init__(self, data_source, x_field, y_field, z_fields,
                 weight_field="cell_mass", x_bins=128, y_bins=128,
                 accumulation=False, fractional=False,
                 fontsize=18, figure_size=8.0):

        profile = create_profile(
            data_source,
            [x_field, y_field],
            ensure_list(z_fields),
            n_bins=[x_bins, y_bins],
            weight_field=weight_field,
            accumulation=accumulation,
            fractional=fractional)

        type(self)._initialize_instance(self, data_source, profile, fontsize,
                                        figure_size)
Code example #14
 def _initialize_instance(cls, obj, profiles, labels, plot_specs, y_log):
     obj.profiles = ensure_list(profiles)
     obj.x_log = None
     obj.y_log = {}
     if y_log is not None:
         for field, log in y_log.items():
             field, = obj.profiles[0].data_source._determine_fields([field])
             obj.y_log[field] = log
     obj.y_title = {}
     obj.label = sanitize_label(labels, len(obj.profiles))
     if plot_specs is None:
         plot_specs = [dict() for p in obj.profiles]
     obj.plot_spec = plot_specs
     obj.figures = FigureContainer()
     obj.axes = AxesContainer(obj.figures)
     obj._setup_plots()
     return obj
Code example #15
    def set_zlim(self, field, zmin, zmax, dynamic_range=None):
        """set the scale of the colormap

        Parameters
        ----------
        field : string
            the field to set a colormap scale
            if field == 'all', applies to all plots.
        zmin : float
            the new minimum of the colormap scale. If 'min', will
            set to the minimum value in the current view.
        zmax : float
            the new maximum of the colormap scale. If 'max', will
            set to the maximum value in the current view.

        Other Parameters
        ----------------
        dynamic_range : float (default: None)
            The dynamic range of the image.
            If zmin is None, zmin will be set to zmax / dynamic_range.
            If zmax is None, zmax will be set to zmin * dynamic_range.
            When dynamic_range is specified, defaults to setting
            zmin = zmax / dynamic_range.

        """
        if field == 'all':
            fields = list(self.plots.keys())
        else:
            fields = ensure_list(field)
        for field in self.data_source._determine_fields(fields):
            myzmin = zmin
            myzmax = zmax
            if zmin == 'min':
                myzmin = self.plots[field].image._A.min()
            if zmax == 'max':
                myzmax = self.plots[field].image._A.max()
            if dynamic_range is not None:
                if zmax is None:
                    myzmax = myzmin * dynamic_range
                else:
                    myzmin = myzmax / dynamic_range

            self.plots[field].zmin = myzmin
            self.plots[field].zmax = myzmax
        return self
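A brief usage sketch for set_zlim on a plot container (hedged: the dataset and field names are illustrative):

import yt

ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
slc = yt.SlicePlot(ds, "z", "density")
# Pin the colorbar to a fixed range for one field...
slc.set_zlim("density", 1e-30, 1e-25)
# ...or give only zmax and derive zmin from a dynamic range.
slc.set_zlim("density", None, 1e-25, dynamic_range=1e4)
slc.save()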
Code example #16
 def __init__(self, ds, normal, fields, center='c', width=(1.0, 'unitary'),
              weight_field=None, image_res=512, depth_res=256,
              north_vector=None, depth=(1.0,"unitary"), no_ghost=False, method='integrate'):
     fields = ensure_list(fields)
     center, dcenter = ds.coordinates.sanitize_center(center, 4)
     buf = {}
     width = ds.coordinates.sanitize_width(normal, width, depth)
     wd = tuple(el.in_units('code_length').v for el in width)
     if not iterable(image_res):
         image_res = (image_res, image_res)
     res = (image_res[0], image_res[1], depth_res)
     for field in fields:
         buf[field] = off_axis_projection(ds, center, normal, wd, res, field,
                                          no_ghost=no_ghost, north_vector=north_vector,
                                          method=method, weight=weight_field).swapaxes(0, 1)
     center = ds.arr([0.0] * 2, 'code_length')
     w, not_an_frb = construct_image(ds, normal, buf, center, width=width, image_res=image_res)
     super(FITSOffAxisProjection, self).__init__(buf, fields=fields, wcs=w)
Code example #17
File: fits_image.py  Project: DeovratPrasad/yt_ap
 def __init__(self,
              ds,
              axis,
              fields,
              center="c",
              width=None,
              image_res=None,
              **kwargs):
     fields = ensure_list(fields)
     axis = fix_axis(axis, ds)
     center, dcenter = ds.coordinates.sanitize_center(center, axis)
     slc = ds.slice(axis, center[axis], **kwargs)
     w, frb = construct_image(ds,
                              axis,
                              slc,
                              dcenter,
                              width=width,
                              image_res=image_res)
     super(FITSSlice, self).__init__(frb, fields=fields, wcs=w)
Code example #18
 def __init__(self,
              ds,
              normal,
              fields,
              image_res=512,
              center='c',
              width=None,
              north_vector=None,
              length_unit=None):
     fields = ensure_list(fields)
     center, dcenter = ds.coordinates.sanitize_center(center, 4)
     cut = ds.cutting(normal, center, north_vector=north_vector)
     center = ds.arr([0.0] * 2, 'code_length')
     w, frb, lunit = construct_image(ds, normal, cut, center, image_res,
                                     width, length_unit)
     super(FITSOffAxisSlice, self).__init__(frb,
                                            fields=fields,
                                            length_unit=lunit,
                                            wcs=w)
Code example #19
File: fits_image.py  Project: DeovratPrasad/yt_ap
 def __init__(self,
              ds,
              normal,
              fields,
              center='c',
              width=(1.0, 'unitary'),
              weight_field=None,
              image_res=512,
              data_source=None,
              north_vector=None,
              depth=(1.0, "unitary"),
              method='integrate'):
     fields = ensure_list(fields)
     center, dcenter = ds.coordinates.sanitize_center(center, 4)
     buf = {}
     width = ds.coordinates.sanitize_width(normal, width, depth)
     wd = tuple(el.in_units('code_length').v for el in width)
     if not iterable(image_res):
         image_res = (image_res, image_res)
     res = (image_res[0], image_res[1])
     if data_source is None:
         source = ds
     else:
         source = data_source
     for field in fields:
         buf[field] = off_axis_projection(source,
                                          center,
                                          normal,
                                          wd,
                                          res,
                                          field,
                                          north_vector=north_vector,
                                          method=method,
                                          weight=weight_field).swapaxes(
                                              0, 1)
     center = ds.arr([0.0] * 2, 'code_length')
     w, not_an_frb = construct_image(ds,
                                     normal,
                                     buf,
                                     center,
                                     width=width,
                                     image_res=image_res)
     super(FITSOffAxisProjection, self).__init__(buf, fields=fields, wcs=w)
Code example #20
    def get_log(self, field):
        """get the transform type of a field.

        Parameters
        ----------
        field : string
            the field to get a transform
            if field == 'all', applies to all plots.

        """
        # devnote : accepts_all_fields decorator is not applicable here because the return variable isn't self
        log = {}
        if field == "all":
            fields = list(self.plots.keys())
        else:
            fields = ensure_list(field)
        for field in self.data_source._determine_fields(fields):
            log[field] = self._field_transform[field] == log_transform
        return log
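A brief usage sketch for get_log (hedged: the dataset and field names are illustrative):

import yt

ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
slc = yt.SlicePlot(ds, "z", "density")
# Returns a dict mapping each resolved field to True (log) or False (linear).
print(slc.get_log("density"))
print(slc.get_log("all"))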
Code example #21
    def __init__(self, data_source, x_field, y_field, z_fields,
                 weight_field="cell_mass", x_bins=128, y_bins=128,
                 accumulation=False, fractional=False,
                 fontsize=18, figure_size=8.0):

        if isinstance(data_source.ds, YTProfileDataset):
            profile = data_source.ds.profile
        else:
            profile = create_profile(
                data_source,
                [x_field, y_field],
                ensure_list(z_fields),
                n_bins=[x_bins, y_bins],
                weight_field=weight_field,
                accumulation=accumulation,
                fractional=fractional)

        type(self)._initialize_instance(self, data_source, profile, fontsize,
                                        figure_size)
Code example #22
    def __init__(self, data_source, x_field, y_field, z_fields=None,
                 color='b', x_bins=800, y_bins=800, weight_field=None,
                 deposition='ngp', fontsize=18, figure_size=8.0):

        # if no z_fields are passed in, use a constant color
        if z_fields is None:
            self.use_cbar = False
            self.splat_color = color
            z_fields = ['particle_ones']

        profile = create_profile(
            data_source,
            [x_field, y_field],
            ensure_list(z_fields),
            n_bins=[x_bins, y_bins],
            weight_field=weight_field,
            deposition=deposition)

        type(self)._initialize_instance(self, data_source, profile, fontsize,
                                        figure_size)
Code example #23
 def __init__(self,
              ds,
              axis,
              fields,
              center="c",
              width=None,
              weight_field=None,
              image_res=None,
              **kwargs):
     fields = ensure_list(fields)
     axis = fix_axis(axis, ds)
     center, dcenter = ds.coordinates.sanitize_center(center, axis)
     prj = ds.proj(fields[0], axis, weight_field=weight_field, **kwargs)
     w, frb = construct_image(ds,
                              axis,
                              prj,
                              dcenter,
                              width=width,
                              image_res=image_res)
     super(FITSProjection, self).__init__(frb, fields=fields, wcs=w)
Code example #24
 def __init__(self,
              data_source,
              conditionals,
              ds=None,
              field_parameters=None,
              base_object=None):
     if base_object is not None:
         # passing base_object explicitly has been deprecated,
         # but we handle it here for backward compatibility
         if data_source is not None:
             raise RuntimeError(
                 "Cannot use both base_object and data_source")
         data_source = base_object
     super(YTCutRegion, self).__init__(data_source.center,
                                       ds,
                                       field_parameters,
                                       data_source=data_source)
     self.conditionals = ensure_list(conditionals)
     self.base_object = data_source
     self._selector = None
Code example #25
    def __init__(self, data_source, x_field, y_field, z_fields=None,
                 color='b', x_bins=800, y_bins=800, weight_field=None,
                 deposition='ngp', fontsize=18, figure_size=8.0):

        # if no z_fields are passed in, use a constant color
        if z_fields is None:
            self.use_cbar = False
            self.splat_color = color
            z_fields = ['particle_ones']

        profile = create_profile(
            data_source,
            [x_field, y_field],
            ensure_list(z_fields),
            n_bins=[x_bins, y_bins],
            weight_field=weight_field,
            deposition=deposition)

        type(self)._initialize_instance(self, data_source, profile, fontsize,
                                        figure_size)
Code example #26
    def export_fits(self, filename, fields=None, overwrite=False,
                    other_keys=None, units="cm", **kwargs):
        r"""Export a set of pixelized fields to a FITS file.

        This will export a set of FITS images of either the fields specified
        or all the fields already in the object.

        Parameters
        ----------
        filename : string
            The name of the FITS file to be written.
        fields : list of strings
            These fields will be pixelized and output. If None, the keys of the
            FRB will be used.
        overwrite : boolean
            If the file exists, this governs whether we will overwrite.
        other_keys : dictionary, optional
            A set of header keys and values to write into the FITS header.
        units : string, optional
            the length units that the coordinates are written in, default 'cm'.
        """

        from yt.visualization.fits_image import FITSImageData

        if fields is None:
            fields = list(self.data.keys())
        else:
            fields = ensure_list(fields)

        if len(fields) == 0:
            raise RuntimeError(
                "No fields to export. Either pass a field or list of fields to "
                "export_fits or access a field from the fixed resolution buffer "
                "object."
            )

        fib = FITSImageData(self, fields=fields, units=units)
        if other_keys is not None:
            for k,v in other_keys.items():
                fib.update_all_headers(k,v)
        fib.writeto(filename, overwrite=overwrite, **kwargs)
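A hedged usage sketch for export_fits on a FixedResolutionBuffer (the dataset, width, and header keyword are illustrative):

import yt

ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
frb = ds.slice("z", 0.5).to_frb((500.0, "kpc"), 512)
# Write the density image to FITS, adding a custom header keyword.
frb.export_fits("density_slice.fits", fields=["density"],
                other_keys={"OBSERVER": "yt"}, units="kpc", overwrite=True)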
Code example #27
    def __init__(
        self,
        data_source,
        conditionals,
        ds=None,
        field_parameters=None,
        base_object=None,
        locals=None,
    ):
        if locals is None:
            locals = {}
        validate_object(data_source, YTSelectionContainer)
        validate_iterable(conditionals)
        for condition in conditionals:
            validate_object(condition, str)
        validate_object(ds, Dataset)
        validate_object(field_parameters, dict)
        validate_object(base_object, YTSelectionContainer)
        if base_object is not None:
            # passing base_object explicitly has been deprecated,
            # but we handle it here for backward compatibility
            if data_source is not None:
                raise RuntimeError(
                    "Cannot use both base_object and data_source")
            data_source = base_object

        self.conditionals = ensure_list(conditionals)
        if isinstance(data_source, YTCutRegion):
            # If the source is also a cut region, add its conditionals
            # and set the source to be its source.
            # Preserve order of conditionals.
            self.conditionals = data_source.conditionals + self.conditionals
            data_source = data_source.base_object

        super(YTCutRegion, self).__init__(data_source.center,
                                          ds,
                                          field_parameters,
                                          data_source=data_source)
        self.base_object = data_source
        self.locals = locals
        self._selector = None
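Cut regions like the one above are usually built through the cut_region method of another data object rather than by calling the constructor directly; a hedged sketch (dataset and threshold are illustrative):

import yt

ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
ad = ds.all_data()
# Conditionals are strings evaluated with the data object bound to "obj".
dense = ad.cut_region(['obj["gas", "density"] > 1e-26'])
print(dense["gas", "cell_mass"].sum())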
Code example #28
File: profile_plotter.py  Project: nmearl/yt
 def _initialize_instance(cls, obj, profiles, labels, plot_specs, y_log):
     from matplotlib.font_manager import FontProperties
     obj._font_properties = FontProperties(family='stixgeneral', size=18)
     obj._font_color = None
     obj.profiles = ensure_list(profiles)
     obj.x_log = None
     obj.y_log = {}
     if y_log is not None:
         for field, log in y_log.items():
             field, = obj.profiles[0].data_source._determine_fields([field])
             obj.y_log[field] = log
     obj.y_title = {}
     obj.label = sanitize_label(labels, len(obj.profiles))
     if plot_specs is None:
         plot_specs = [dict() for p in obj.profiles]
     obj.plot_spec = plot_specs
     obj.plots = PlotContainer()
     obj.figures = FigureContainer(obj.plots)
     obj.axes = AxesContainer(obj.plots)
     obj._setup_plots()
     return obj
Code example #29
File: selection_data_containers.py  Project: tlmnb/yt
 def __init__(self, data_source, conditionals, ds=None,
              field_parameters=None, base_object=None):
     validate_object(data_source, YTSelectionContainer)
     validate_iterable(conditionals)
     for condition in conditionals:
         validate_object(condition, string_types)
     validate_object(ds, Dataset)
     validate_object(field_parameters, dict)
     validate_object(base_object, YTSelectionContainer)
     if base_object is not None:
         # passing base_object explicitly has been deprecated,
         # but we handle it here for backward compatibility
         if data_source is not None:
             raise RuntimeError(
                 "Cannot use both base_object and data_source")
         data_source=base_object
     super(YTCutRegion, self).__init__(
         data_source.center, ds, field_parameters, data_source=data_source)
     self.conditionals = ensure_list(conditionals)
     self.base_object = data_source
     self._selector = None
Code example #30
    def __init__(self,
                 name,
                 function,
                 units=None,
                 take_log=True,
                 validators=None,
                 particle_type=False,
                 vector_field=False,
                 display_field=True,
                 not_in_all=False,
                 display_name=None,
                 output_units=None):
        self.name = name
        self.take_log = take_log
        self.display_name = display_name
        self.not_in_all = not_in_all
        self.display_field = display_field
        self.particle_type = particle_type
        self.vector_field = vector_field
        if output_units is None: output_units = units
        self.output_units = output_units

        self._function = function

        if validators:
            self.validators = ensure_list(validators)
        else:
            self.validators = []

        # handle units
        if units is None:
            self.units = ""
        elif isinstance(units, str):
            self.units = units
        elif isinstance(units, Unit):
            self.units = str(units)
        else:
            raise FieldUnitsError("Cannot handle units '%s' (type %s). "
                                  "Please provide a string or Unit "
                                  "object." % (units, type(units)))
Code example #31
File: profile_plotter.py  Project: nmearl/yt
    def set_ylim(self, field, ymin=None, ymax=None):
        """Sets the plot limits for the specified field we are binning.

        Parameters
        ----------

        field : string or field tuple
          The field that we want to adjust the plot limits for.

        ymin : float or None
          The new y minimum.  Defaults to None, which leaves the ymin
          unchanged.

        ymax : float or None
          The new y maximum.  Defaults to None, which leaves the ymax
          unchanged.

        Examples
        --------

        >>> import yt
        >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
        >>> pp = yt.ProfilePlot(ds.all_data(), 'density', ['temperature', 'x-velocity'])
        >>> pp.set_ylim('temperature', 1e4, 1e6)
        >>> pp.save()

        """
        if field == 'all':
            fields = list(self.axes.keys())
        else:
            fields = ensure_list(field)
        for profile in self.profiles:
            for field in profile.data_source._determine_fields(fields):
                if field in profile.field_map:
                    field = profile.field_map[field]
                self.axes.ylim[field] = (ymin, ymax)
                # Continue on to the next profile.
                break
        return self
Code example #32
 def get_data(self, fields):
     mylog.info("Getting %s using ParticleIO" % str(fields))
     fields = ensure_list(fields)
     if not self.ds.index.io._particle_reader:
         mylog.info("not self.ds.index.io._particle_reader")
         return self.source.get_data(fields)
     rtype, args = self._get_args()
     count_list, grid_list = [], []
     for grid in self.source._grids:
         if grid.NumberOfParticles == 0: continue
         grid_list.append(grid)
         if self.source._is_fully_enclosed(grid):
             count_list.append(grid.NumberOfParticles)
         else:
             count_list.append(-1)
     # region type, left_edge, right_edge, periodic, grid_list
     fields_to_read = []
     conv_factors = []
     for field in fields:
         f = self.ds.field_info[field]
         to_add = f.get_dependencies(ds=self.ds).requested
         to_add = list(np.unique(to_add))
         if len(to_add) != 1: raise KeyError
         fields_to_read += to_add
         if f._particle_convert_function is None:
             func = f._convert_function
         else:
             func = f.particle_convert
         func = particle_converter(func)
         conv_factors.append(
             np.fromiter((func(g) for g in grid_list),
                         count=len(grid_list),
                         dtype='float64'))
     conv_factors = np.array(conv_factors).transpose()
     self.conv_factors = conv_factors
     rvs = self.ds.index.io._read_particles(fields_to_read, rtype, args,
                                            grid_list, count_list,
                                            conv_factors)
     for [n, v] in zip(fields, rvs):
         self.source.field_data[n] = v
Code example #33
    def hide_colorbar(self, field=None):
        """
        Hides the colorbar for a plot and updates the size of the
        plot accordingly.  Defaults to operating on all fields for a
        PlotContainer object.

        Parameters
        ----------

        field : string, field tuple, or list of strings or field tuples (optional)
            The name of the field(s) that we want to hide the colorbar. If None
            is provided, will default to using all fields available for this
            object.

        Examples
        --------

        This will save an image with no colorbar.

        >>> import yt
        >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
        >>> s = SlicePlot(ds, 2, 'density', 'c', (20, 'kpc'))
        >>> s.hide_colorbar()
        >>> s.save()

        This will save an image with no axis or colorbar.

        >>> import yt
        >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
        >>> s = SlicePlot(ds, 2, 'density', 'c', (20, 'kpc'))
        >>> s.hide_axes()
        >>> s.hide_colorbar()
        >>> s.save()
        """
        if field is None:
            field = self.fields
        field = ensure_list(field)
        for f in field:
            self.plots[f].hide_colorbar()
        return self
Code example #34
    def set_ylim(self, field, ymin=None, ymax=None):
        """Sets the plot limits for the specified field we are binning.

        Parameters
        ----------

        field : string or field tuple
          The field that we want to adjust the plot limits for.

        ymin : float or None
          The new y minimum.  Defaults to None, which leaves the ymin
          unchanged.

        ymax : float or None
          The new y maximum.  Defaults to None, which leaves the ymax
          unchanged.

        Examples
        --------

        >>> import yt
        >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
        >>> pp = yt.ProfilePlot(ds.all_data(), 'density', ['temperature', 'x-velocity'])
        >>> pp.set_ylim('temperature', 1e4, 1e6)
        >>> pp.save()

        """
        if field == 'all':
            fields = self.axes.keys()
        else:
            fields = ensure_list(field)
        for profile in self.profiles:
            for field in profile.data_source._determine_fields(fields):
                if field in profile.field_map:
                    field = profile.field_map[field]
                self.axes.ylim[field] = (ymin, ymax)
                # Continue on to the next profile.
                break
        return self
Code example #35
File: time_series.py  Project: zackcd/yt
 def eval(self, tasks, obj=None):
     tasks = ensure_list(tasks)
     return_values = {}
     for store, ds in self.piter(return_values):
         store.result = []
         for task in tasks:
             try:
                 style = inspect.getfullargspec(task.eval)[0][1]
                 if style == 'ds':
                     arg = ds
                 elif style == 'data_object':
                     if obj is None:
                         obj = DatasetSeriesObject(self, "all_data")
                     arg = obj.get(ds)
                 rv = task.eval(arg)
             # We catch and store YT-originating exceptions
             # This fixes the standard problem of having a sphere that's too
             # small.
             except YTException:
                 # Skip this task so an undefined result is not appended.
                 continue
             store.result.append(rv)
     return [v for k, v in sorted(return_values.items())]
Code example #36
File: profile_plotter.py  Project: lindsayad/yt
    def __init__(self,
                 data_source,
                 x_field,
                 y_fields,
                 weight_field="cell_mass",
                 n_bins=64,
                 accumulation=False,
                 fractional=False,
                 label=None,
                 plot_spec=None,
                 x_log=None,
                 y_log=None):

        if x_log is None:
            logs = None
        else:
            logs = {x_field: x_log}

        if isinstance(data_source.ds, YTProfileDataset):
            profiles = [data_source.ds.profile]
        else:
            profiles = [
                create_profile(data_source, [x_field],
                               n_bins=[n_bins],
                               fields=ensure_list(y_fields),
                               weight_field=weight_field,
                               accumulation=accumulation,
                               fractional=fractional,
                               logs=logs)
            ]

        if plot_spec is None:
            plot_spec = [dict() for p in profiles]
        if not isinstance(plot_spec, list):
            plot_spec = [plot_spec.copy() for p in profiles]

        ProfilePlot._initialize_instance(self, profiles, label, plot_spec,
                                         y_log)
Code example #37
    def _find_field_values_at_points(self, fields, coords):
        r"""Find the value of fields at a set of coordinates.

        Returns the values [field1, field2,...] of the fields at the given
        (x, y, z) points. Returns a numpy array of field values cross coords
        """
        coords = self.ds.arr(ensure_numpy_array(coords), 'code_length')
        grids = self._find_points(coords[:, 0], coords[:, 1], coords[:, 2])[0]
        fields = ensure_list(fields)
        mark = np.zeros(3, dtype=np.int64)
        out = []

        # create point -> grid mapping
        grid_index = {}
        for coord_index, grid in enumerate(grids):
            if grid not in grid_index:
                grid_index[grid] = []
            grid_index[grid].append(coord_index)

        out = []
        for field in fields:
            funit = self.ds._get_field_info(field).units
            out.append(self.ds.arr(np.empty((len(coords))), funit))

        for grid in grid_index:
            cellwidth = (grid.RightEdge -
                         grid.LeftEdge) / grid.ActiveDimensions
            for field_index, field in enumerate(fields):
                for coord_index in grid_index[grid]:
                    mark = ((coords[coord_index, :] - grid.LeftEdge) /
                            cellwidth)
                    mark = np.array(mark, dtype='int64')
                    out[field_index][coord_index] = \
                        grid[field][mark[0], mark[1], mark[2]]
        if len(fields) == 1:
            return out[0]
        return out
Code example #38
    def to_pw(self, fields=None, center="c", width=None, axes_unit=None):
        r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this
        object.

        This is a bare-bones mechanism of creating a plot window from this
        object, which can then be moved around, zoomed, and on and on.  All
        behavior of the plot window is relegated to that routine.
        """
        normal = self.normal
        center = self.center
        self.fields = ensure_list(fields) + [
            k for k in self.field_data.keys() if k not in self._key_fields
        ]
        from yt.visualization.fixed_resolution import FixedResolutionBuffer
        from yt.visualization.plot_window import (
            PWViewerMPL,
            get_oblique_window_parameters,
        )

        (bounds, center_rot) = get_oblique_window_parameters(
            normal, center, width, self.ds
        )
        pw = PWViewerMPL(
            self,
            bounds,
            fields=self.fields,
            origin="center-window",
            periodic=False,
            oblique=True,
            frb_generator=FixedResolutionBuffer,
            plot_type="OffAxisSlice",
        )
        if axes_unit is not None:
            pw.set_axes_unit(axes_unit)
        pw._setup_plots()
        return pw
Code example #39
 def __init__(self, name, sub_types):
     self.name = name
     self.sub_types = ensure_list(sub_types)
Code example #40
File: derived_field.py  Project: cgyurgyik/yt
    def __init__(
        self,
        name,
        sampling_type,
        function,
        units=None,
        take_log=True,
        validators=None,
        particle_type=None,
        vector_field=False,
        display_field=True,
        not_in_all=False,
        display_name=None,
        output_units=None,
        dimensions=None,
        ds=None,
        nodal_flag=None,
    ):
        self.name = name
        self.take_log = take_log
        self.display_name = display_name
        self.not_in_all = not_in_all
        self.display_field = display_field
        if particle_type:
            warnings.warn(
                "particle_type for derived fields "
                "has been replaced with sampling_type = 'particle'",
                DeprecationWarning,
            )
            sampling_type = "particle"
        self.sampling_type = sampling_type
        self.vector_field = vector_field
        self.ds = ds

        if self.ds is not None:
            self._ionization_label_format = self.ds._ionization_label_format
        else:
            self._ionization_label_format = "roman_numeral"

        if nodal_flag is None:
            self.nodal_flag = [0, 0, 0]
        else:
            self.nodal_flag = nodal_flag

        self._function = function

        if validators:
            self.validators = ensure_list(validators)
        else:
            self.validators = []

        # handle units
        if units is None:
            self.units = ""
        elif isinstance(units, str):
            if units.lower() == "auto":
                if dimensions is None:
                    raise RuntimeError(
                        "To set units='auto', please specify the dimensions "
                        "of the field with dimensions=<dimensions of field>!")
                self.units = None
            else:
                self.units = units
        elif isinstance(units, Unit):
            self.units = str(units)
        elif isinstance(units, bytes):
            self.units = units.decode("utf-8")
        else:
            raise FieldUnitsError("Cannot handle units '%s' (type %s). "
                                  "Please provide a string or Unit "
                                  "object." % (units, type(units)))
        if output_units is None:
            output_units = self.units
        self.output_units = output_units

        if isinstance(dimensions, str):
            dimensions = getattr(ytdims, dimensions)
        self.dimensions = dimensions
Code example #41
File: derived_field.py  Project: cgyurgyik/yt
 def __init__(self, prop):
     """
     This validator ensures that the data object has a given python attribute.
     """
     FieldValidator.__init__(self)
     self.prop = ensure_list(prop)
Code example #42
File: unions.py  Project: caicairay/yt
 def __init__(self, name, sub_types):
     self.name = name
     self.sub_types = ensure_list(sub_types)
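A short sketch of how a particle union built from this class is typically registered with a dataset (hedged: the import path and particle type names vary by yt version and frontend):

import yt
from yt.data_objects.unions import ParticleUnion

ds = yt.load("snapshot_033/snap_033.0.hdf5")
# Group two existing particle types under a single name.
pu = ParticleUnion("stars_and_dm", ["PartType4", "PartType1"])
ds.add_particle_union(pu)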
Code example #43
 def run(self):
     self.result = [[p.id for p in ensure_list(g.Parent) \
         if g.Parent is not None]
         for g in self.ds.index.grids]
Code example #44
File: derived_quantities.py  Project: aemerick/yt
 def __call__(self, fields, non_zero=False):
     fields = ensure_list(fields)
     rv = super(Extrema, self).__call__(fields, non_zero)
     if len(rv) == 1: rv = rv[0]
     return rv
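For context, the Extrema quantity above is normally reached through a data object's quantities interface; a short sketch (the dataset name is illustrative):

import yt

ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
ad = ds.all_data()
# One field returns a single (min, max) pair; a list returns one pair per field.
rho_ex = ad.quantities.extrema(("gas", "density"))
rho_ex, temp_ex = ad.quantities.extrema([("gas", "density"), ("gas", "temperature")])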
Code example #45
File: writer.py  Project: victorgabr/yt
def write_to_gdf(ds,
                 gdf_path,
                 fields=None,
                 data_author=None,
                 data_comment=None,
                 dataset_units=None,
                 particle_type_name="dark_matter",
                 clobber=False):
    """
    Write a dataset to the given path in the Grid Data Format.

    Parameters
    ----------
    ds : Dataset object
        The yt data to write out.
    gdf_path : string
        The path of the file to output.
    fields : field or list of fields
        The field(s) to write out. If None, defaults to
        ds.field_list.
    data_author : string, optional
        The name of the author who wrote the data. Default: None.
    data_comment : string, optional
        A descriptive comment. Default: None.
    dataset_units : dictionary, optional
        A dictionary of (value, unit) tuples to set the default units
        of the dataset. Keys can be:

        * "length_unit"
        * "time_unit"
        * "mass_unit"
        * "velocity_unit"
        * "magnetic_unit"

        If not specified, these will carry over from the parent
        dataset.
    particle_type_name : string, optional
        The particle type of the particles in the dataset. Default: "dark_matter"
    clobber : boolean, optional
        Whether or not to clobber an already existing file. If False, attempting
        to overwrite an existing file will result in an exception.

    Examples
    --------
    >>> dataset_units = {"length_unit":(1.0,"Mpc"),
    ...                  "time_unit":(1.0,"Myr")}
    >>> write_to_gdf(ds, "clumps.h5", data_author="John ZuHone",
    ...              dataset_units=dataset_units,
    ...              data_comment="My Really Cool Dataset", clobber=True)
    """

    if fields is None:
        fields = ds.field_list

    fields = ensure_list(fields)

    with _create_new_gdf(ds,
                         gdf_path,
                         data_author,
                         data_comment,
                         dataset_units=dataset_units,
                         particle_type_name=particle_type_name,
                         clobber=clobber) as f:

        # now add the fields one-by-one
        _write_fields_to_gdf(ds, f, fields, particle_type_name)
Code example #46
File: derived_quantities.py  Project: aemerick/yt
 def __call__(self, fields, weight):
     fields = ensure_list(fields)
     rv = super(WeightedAverageQuantity, self).__call__(fields, weight)
     if len(rv) == 1: rv = rv[0]
     return rv
Code example #47
    def __init__(self,
                 filename,
                 dataset_type='fits',
                 auxiliary_files=[],
                 nprocs=None,
                 storage_filename=None,
                 nan_mask=None,
                 suppress_astropy_warnings=True,
                 parameters=None,
                 units_override=None,
                 unit_system="cgs"):

        if parameters is None:
            parameters = {}
        parameters["nprocs"] = nprocs
        self.specified_parameters = parameters

        if suppress_astropy_warnings:
            warnings.filterwarnings('ignore', module="astropy", append=True)
        auxiliary_files = ensure_list(auxiliary_files)
        self.filenames = [filename] + auxiliary_files
        self.num_files = len(self.filenames)
        self.fluid_types += ("fits", )
        if nan_mask is None:
            self.nan_mask = {}
        elif isinstance(nan_mask, float):
            self.nan_mask = {"all": nan_mask}
        elif isinstance(nan_mask, dict):
            self.nan_mask = nan_mask
        self._handle = FITSFileHandler(self.filenames[0])
        if (isinstance(self.filenames[0],
                       _astropy.pyfits.hdu.image._ImageBaseHDU)
                or isinstance(self.filenames[0], _astropy.pyfits.HDUList)):
            fn = "InMemoryFITSFile_%s" % uuid.uuid4().hex
        else:
            fn = self.filenames[0]
        self._handle._fits_files.append(self._handle)
        if self.num_files > 1:
            for fits_file in auxiliary_files:
                if isinstance(fits_file,
                              _astropy.pyfits.hdu.image._ImageBaseHDU):
                    f = _astropy.pyfits.HDUList([fits_file])
                elif isinstance(fits_file, _astropy.pyfits.HDUList):
                    f = fits_file
                else:
                    if os.path.exists(fits_file):
                        fn = fits_file
                    else:
                        fn = os.path.join(ytcfg.get("yt", "test_data_dir"),
                                          fits_file)
                    f = _astropy.pyfits.open(fn,
                                             memmap=True,
                                             do_not_scale_image_data=True,
                                             ignore_blank=True)
                self._handle._fits_files.append(f)

        self.refine_by = 2

        Dataset.__init__(self,
                         fn,
                         dataset_type,
                         units_override=units_override,
                         unit_system=unit_system)
        self.storage_filename = storage_filename
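For reference, this constructor is normally reached through yt.load, which forwards keyword arguments such as nan_mask; a hedged sketch (the FITS file name is illustrative):

import yt

# Replace NaNs in the image with zero before gridding.
ds = yt.load("m33_hi.fits", nan_mask=0.0)
ds.print_stats()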
Code example #48
    def __init__(
        self,
        data,
        fields=None,
        length_unit=None,
        width=None,
        img_ctr=None,
        wcs=None,
        current_time=None,
        time_unit=None,
        mass_unit=None,
        velocity_unit=None,
        magnetic_unit=None,
        ds=None,
        unit_header=None,
        **kwargs,
    ):
        r""" Initialize a FITSImageData object.

        FITSImageData contains a collection of FITS ImageHDU instances and
        WCS information, along with units for each of the images. FITSImageData
        instances can be constructed from ImageArrays, NumPy arrays, dicts
        of such arrays, FixedResolutionBuffers, and YTCoveringGrids. The latter
        two are the most powerful because WCS information can be constructed
        automatically from their coordinates.

        Parameters
        ----------
        data : FixedResolutionBuffer, YTCoveringGrid, ImageArray,
            numpy.ndarray, or dict of such arrays
            The data to be made into a FITS image or images.
        fields : single string or list of strings, optional
            The field names for the data. If *fields* is None and *data* has
            keys, it will use these for the fields. If *data* is just a
            single array one field name must be specified.
        length_unit : string
            The units of the WCS coordinates and the length unit of the file.
            Defaults to the length unit of the dataset, if there is one, or
            "cm" if there is not.
        width : float or YTQuantity
            The width of the image. Either a single value or iterable of values.
            If a float, assumed to be in *length_unit*. Only used if this information
            is not already provided by *data*.
        img_ctr : array_like or YTArray
            The center coordinates of the image. If a list or NumPy array,
            it is assumed to be in *length_unit*. Only used if this information
            is not already provided by *data*.
        wcs : `~astropy.wcs.WCS` instance, optional
            Supply an AstroPy WCS instance. Will override automatic WCS
            creation from FixedResolutionBuffers and YTCoveringGrids.
        current_time : float, tuple, or YTQuantity, optional
            The current time of the image(s). If not specified, one will
            be set from the dataset if there is one. If a float, it will
            be assumed to be in *time_unit* units.
        time_unit : string
            The default time units of the file. Defaults to "s".
        mass_unit : string
            The default mass units of the file. Defaults to "g".
        velocity_unit : string
            The default velocity units of the file. Defaults to "cm/s".
        magnetic_unit : string
            The default magnetic units of the file. Defaults to "gauss".
        ds : `~yt.static_output.Dataset` instance, optional
            The dataset associated with the image(s), typically used
            to transfer metadata to the header(s). Does not need to be
            specified if *data* has a dataset as an attribute.

        Examples
        --------

        >>> # This example uses a FRB.
        >>> ds = load("sloshing_nomag2_hdf5_plt_cnt_0150")
        >>> prj = ds.proj(2, "kT", weight_field="density")
        >>> frb = prj.to_frb((0.5, "Mpc"), 800)
        >>> # This example just uses the FRB and puts the coords in kpc.
        >>> f_kpc = FITSImageData(frb, fields="kT", length_unit="kpc",
        ...                       time_unit=(1.0, "Gyr"))
        >>> # This example specifies a specific WCS.
        >>> from astropy.wcs import WCS
        >>> w = WCS(naxis=2)
        >>> w.wcs.crval = [30., 45.] # RA, Dec in degrees
        >>> w.wcs.cunit = ["deg"]*2
        >>> nx, ny = 800, 800
        >>> w.wcs.crpix = [0.5*(nx+1), 0.5*(ny+1)]
        >>> w.wcs.ctype = ["RA---TAN","DEC--TAN"]
        >>> scale = 1./3600. # One arcsec per pixel
        >>> w.wcs.cdelt = [-scale, scale]
        >>> f_deg = FITSImageData(frb, fields="kT", wcs=w)
        >>> f_deg.writeto("temp.fits")
        """

        if fields is not None:
            fields = ensure_list(fields)

        if "units" in kwargs:
            issue_deprecation_warning(
                "The 'units' keyword argument has been replaced "
                "by the 'length_unit' keyword argument and the "
                "former has been deprecated. Setting 'length_unit' "
                "to 'units'.")
            length_unit = kwargs.pop("units")

        if ds is None:
            ds = getattr(data, "ds", None)

        self.fields = []
        self.field_units = {}

        if unit_header is None:
            self._set_units(ds, [
                length_unit, mass_unit, time_unit, velocity_unit, magnetic_unit
            ])
        else:
            self._set_units_from_header(unit_header)

        wcs_unit = str(self.length_unit.units)

        self._fix_current_time(ds, current_time)

        if width is None:
            width = 1.0
        if isinstance(width, tuple):
            if ds is None:
                width = YTQuantity(width[0], width[1])
            else:
                width = ds.quan(width[0], width[1])
        if img_ctr is None:
            img_ctr = np.zeros(3)

        exclude_fields = [
            "x",
            "y",
            "z",
            "px",
            "py",
            "pz",
            "pdx",
            "pdy",
            "pdz",
            "weight_field",
        ]

        if isinstance(data, _astropy.pyfits.PrimaryHDU):
            data = _astropy.pyfits.HDUList([data])

        if isinstance(data, _astropy.pyfits.HDUList):
            self.hdulist = data
            for hdu in data:
                self.fields.append(hdu.header["btype"])
                self.field_units[hdu.header["btype"]] = hdu.header["bunit"]

            self.shape = self.hdulist[0].shape
            self.dimensionality = len(self.shape)
            wcs_names = [
                key for key in self.hdulist[0].header if "WCSNAME" in key
            ]
            for name in wcs_names:
                if name == "WCSNAME":
                    key = " "
                else:
                    key = name[-1]
                w = _astropy.pywcs.WCS(header=self.hdulist[0].header,
                                       key=key,
                                       naxis=self.dimensionality)
                setattr(self, "wcs" + key.strip().lower(), w)

            return

        self.hdulist = _astropy.pyfits.HDUList()

        if hasattr(data, "keys"):
            img_data = data
            if fields is None:
                fields = list(img_data.keys())
        elif isinstance(data, np.ndarray):
            if fields is None:
                mylog.warning("No field name given for this array. "
                              "Calling it 'image_data'.")
                fn = "image_data"
                fields = [fn]
            else:
                fn = fields[0]
            img_data = {fn: data}

        for fd in fields:
            if isinstance(fd, tuple):
                self.fields.append(fd[1])
            elif isinstance(fd, DerivedField):
                self.fields.append(fd.name[1])
            else:
                self.fields.append(fd)

        # Sanity checking names
        s = set()
        duplicates = set(f for f in self.fields if f in s or s.add(f))
        if len(duplicates) > 0:
            for i, fd in enumerate(self.fields):
                if fd in duplicates:
                    if isinstance(fields[i], tuple):
                        ftype, fname = fields[i]
                    elif isinstance(fields[i], DerivedField):
                        ftype, fname = fields[i].name
                    else:
                        raise RuntimeError("Cannot distinguish between fields "
                                           "with same name %s!" % fd)
                    self.fields[i] = "%s_%s" % (ftype, fname)

        first = True
        for i, name, field in zip(count(), self.fields, fields):
            if name not in exclude_fields:
                this_img = img_data[field]
                if hasattr(img_data[field], "units"):
                    if this_img.units.is_code_unit:
                        mylog.warning("Cannot generate an image with code "
                                      "units. Converting to units in CGS.")
                        funits = this_img.units.get_base_equivalent("cgs")
                    else:
                        funits = this_img.units
                    self.field_units[name] = str(funits)
                else:
                    self.field_units[name] = "dimensionless"
                mylog.info("Making a FITS image of field %s", name)
                if isinstance(this_img, ImageArray):
                    if i == 0:
                        self.shape = this_img.shape[::-1]
                    this_img = np.asarray(this_img)
                else:
                    if i == 0:
                        self.shape = this_img.shape
                    this_img = np.asarray(this_img.T)
                if first:
                    hdu = _astropy.pyfits.PrimaryHDU(this_img)
                    first = False
                else:
                    hdu = _astropy.pyfits.ImageHDU(this_img)
                hdu.name = name
                hdu.header["btype"] = name
                hdu.header["bunit"] = re.sub("()", "", self.field_units[name])
                for unit in ("length", "time", "mass", "velocity", "magnetic"):
                    if unit == "magnetic":
                        short_unit = "bf"
                    else:
                        short_unit = unit[0]
                    key = "{}unit".format(short_unit)
                    value = getattr(self, "{}_unit".format(unit))
                    if value is not None:
                        hdu.header[key] = float(value.value)
                        hdu.header.comments[key] = "[%s]" % value.units
                hdu.header["time"] = float(self.current_time.value)
                self.hdulist.append(hdu)

        self.dimensionality = len(self.shape)

        if wcs is None:
            w = _astropy.pywcs.WCS(header=self.hdulist[0].header,
                                   naxis=self.dimensionality)
            # FRBs and covering grids are special cases where
            # we have coordinate information, so we take advantage
            # of this and construct the WCS object
            if isinstance(img_data, FixedResolutionBuffer):
                dx = (img_data.bounds[1] -
                      img_data.bounds[0]).to_value(wcs_unit)
                dy = (img_data.bounds[3] -
                      img_data.bounds[2]).to_value(wcs_unit)
                dx /= self.shape[0]
                dy /= self.shape[1]
                xctr = 0.5 * (img_data.bounds[1] +
                              img_data.bounds[0]).to_value(wcs_unit)
                yctr = 0.5 * (img_data.bounds[3] +
                              img_data.bounds[2]).to_value(wcs_unit)
                center = [xctr, yctr]
                cdelt = [dx, dy]
            elif isinstance(img_data, YTCoveringGrid):
                cdelt = img_data.dds.to_value(wcs_unit)
                center = 0.5 * (img_data.left_edge +
                                img_data.right_edge).to_value(wcs_unit)
            else:
                # If img_data is just an array we use the width and img_ctr
                # parameters to determine the cell widths
                if not iterable(width):
                    width = [width] * self.dimensionality
                if isinstance(width[0], YTQuantity):
                    cdelt = [
                        wh.to_value(wcs_unit) / n
                        for wh, n in zip(width, self.shape)
                    ]
                else:
                    cdelt = [float(wh) / n for wh, n in zip(width, self.shape)]
                center = img_ctr[:self.dimensionality]
            w.wcs.crpix = 0.5 * (np.array(self.shape) + 1)
            w.wcs.crval = center
            w.wcs.cdelt = cdelt
            w.wcs.ctype = ["linear"] * self.dimensionality
            w.wcs.cunit = [wcs_unit] * self.dimensionality
            self.set_wcs(w)
        else:
            self.set_wcs(wcs)
Code example #49
 def __init__(self, parameters):
     """
     This validator ensures that the dataset has a given parameter.
     """
     FieldValidator.__init__(self)
     self.parameters = ensure_list(parameters)
Code example #50
 def __init__(self, prop):
     """
     This validator ensures that the data object has a given python attribute.
     """
     FieldValidator.__init__(self)
     self.prop = ensure_list(prop)
Code example #51
File: derived_quantities.py  Project: aemerick/yt
 def __call__(self, fields):
     fields = ensure_list(fields)
     rv = super(TotalQuantity, self).__call__(fields)
     if len(rv) == 1: rv = rv[0]
     return rv
Code example #52
 def get_data(self, fields):
     fields = ensure_list(fields)
     self.source.get_data(fields, force_particle_read=True)
     rvs = [self.source[field] for field in fields]
     if len(fields) == 1: return rvs[0]
     return rvs
Code example #53
 def run(self):
     self.result = \
         all(g in ensure_list(c.Parent) for g in self.ds.index.grids
                                         for c in g.Children)