Пример #1
0
class FITSDataset(Dataset):
    """Dataset wrapper for FITS images and data cubes (files or in-memory HDUs)."""

    _index_class = FITSHierarchy
    _field_info_class: Type[FieldInfoContainer] = FITSFieldInfo
    _dataset_type = "fits"
    _handle = None

    def __init__(
        self,
        filename,
        dataset_type="fits",
        auxiliary_files=None,
        nprocs=None,
        storage_filename=None,
        nan_mask=None,
        suppress_astropy_warnings=True,
        parameters=None,
        units_override=None,
        unit_system="cgs",
    ):
        """
        Open a FITS file (or an in-memory image HDU / HDUList) as a dataset.

        Parameters
        ----------
        filename : str or astropy image HDU or HDUList
            The primary FITS source. In-memory HDUs get a synthetic
            ``InMemoryFITSFile_<uuid>`` name.
        dataset_type : str
            Passed through to ``Dataset.__init__``.
        auxiliary_files : str, HDU, HDUList or iterable of these, optional
            Extra FITS sources whose HDUList is appended to the file handler.
        nprocs : int, optional
            Number of grids; chosen automatically in ``_determine_nprocs``
            when None.
        storage_filename : str, optional
            Stored on the instance after ``Dataset.__init__`` runs.
        nan_mask : number or dict, optional
            A single number is stored as ``{"all": value}``; a dict is used
            as-is.
        suppress_astropy_warnings : bool
            If True, install a warnings filter ignoring astropy warnings.
        parameters : dict, optional
            User parameters. The dict is copied, so the caller's object is
            not modified when ``"nprocs"`` is recorded.
        units_override, unit_system
            Passed through to ``Dataset.__init__``.

        Raises
        ------
        TypeError
            If ``nan_mask`` is not None, a number, or a dict.
        """
        # Copy so that recording "nprocs" does not mutate the caller's dict.
        parameters = {} if parameters is None else dict(parameters)
        parameters["nprocs"] = nprocs
        self.specified_parameters = parameters

        if suppress_astropy_warnings:
            warnings.filterwarnings("ignore", module="astropy", append=True)

        self.filenames = [filename] + list(always_iterable(auxiliary_files))
        self.num_files = len(self.filenames)
        self.fluid_types += ("fits",)
        self.nan_mask = self._normalize_nan_mask(nan_mask)
        self._handle = FITSFileHandler(self.filenames[0])

        primary = self.filenames[0]
        if isinstance(
            primary,
            (_astropy.pyfits.hdu.image._ImageBaseHDU, _astropy.pyfits.HDUList),
        ):
            fn = f"InMemoryFITSFile_{uuid.uuid4().hex}"
        else:
            fn = primary
        # NOTE(review): the handler is registered as the first entry of its
        # own file list; presumably FITSFileHandler walks _fits_files to
        # reach every HDU — confirm against FITSFileHandler.
        self._handle._fits_files.append(self._handle)
        # Iterate the normalized list (not the raw argument) so a single
        # auxiliary filename string is not iterated character by character.
        # A separate helper also keeps `fn` (the primary file's name) from
        # being overwritten by auxiliary paths.
        for fits_file in self.filenames[1:]:
            self._handle._fits_files.append(self._open_auxiliary_file(fits_file))

        self.refine_by = 2

        Dataset.__init__(
            self,
            fn,
            dataset_type,
            units_override=units_override,
            unit_system=unit_system,
        )
        self.storage_filename = storage_filename

    @staticmethod
    def _normalize_nan_mask(nan_mask):
        """
        Return the NaN-replacement mapping for ``nan_mask``.

        None gives an empty dict. A real number gives ``{"all": value}``.
        A dict is returned unchanged.
        """
        import numbers

        if nan_mask is None:
            return {}
        if isinstance(nan_mask, dict):
            return nan_mask
        if isinstance(nan_mask, numbers.Real):
            return {"all": nan_mask}
        raise TypeError(
            "nan_mask must be None, a number, or a dict, "
            f"not {type(nan_mask).__name__}"
        )

    @staticmethod
    def _open_auxiliary_file(fits_file):
        """
        Return an HDUList for one auxiliary FITS source.

        An image HDU is wrapped in a new HDUList, and an HDUList is used
        as-is. A path that does not exist as given is looked up under the
        ``test_data_dir`` config entry.
        """
        if isinstance(fits_file, _astropy.pyfits.hdu.image._ImageBaseHDU):
            return _astropy.pyfits.HDUList([fits_file])
        if isinstance(fits_file, _astropy.pyfits.HDUList):
            return fits_file
        if os.path.exists(fits_file):
            path = fits_file
        else:
            path = os.path.join(ytcfg.get("yt", "test_data_dir"), fits_file)
        return _astropy.pyfits.open(
            path,
            memmap=True,
            do_not_scale_image_data=True,
            ignore_blank=True,
        )

    def _set_code_unit_attributes(self):
        """
        Set code units, inferring the length unit from the WCS when possible.

        When ``length_unit`` is not already set, the WCS ``cunit`` entries of
        the spatial axes are matched against known length units. If every
        match agrees, the length unit is ``cdelt[0]`` of that unit. Otherwise
        1 cm is assumed and ``no_cgs_equiv_length`` is set. Time and mass
        default to 1 s and 1 g unless already set. Magnetic and velocity
        units are then derived from those.
        """
        if getattr(self, "length_unit", None) is None:
            default_length_units = [
                u
                for u, v in default_unit_symbol_lut.items()
                if str(v[1]) == "(length)"
            ]
            # Add SI-prefixed variants (e.g. "kpc") of prefixable units.
            more_length_units = []
            for unit in default_length_units:
                if unit in self.unit_registry.prefixable_units:
                    more_length_units += [prefix + unit for prefix in unit_prefixes]
            default_length_units += more_length_units

            cunits = [self.wcs.wcs.cunit[i] for i in range(self.dimensionality)]
            file_units = [
                unit
                for unit in (cunit.to_string() for cunit in cunits)
                if unit in default_length_units
            ]
            if len(set(file_units)) == 1:
                length_factor = self.wcs.wcs.cdelt[0]
                length_unit = str(file_units[0])
                mylog.info("Found length units of %s.", length_unit)
            else:
                self.no_cgs_equiv_length = True
                mylog.warning("No length conversion provided. Assuming 1 = 1 cm.")
                length_factor = 1.0
                length_unit = "cm"
            setdefaultattr(self, "length_unit", self.quan(length_factor, length_unit))
        for unit, cgs in [("time", "s"), ("mass", "g")]:
            # We set these to cgs for now, but they may have been overridden
            if getattr(self, unit + "_unit", None) is not None:
                continue
            mylog.warning("Assuming 1.0 = 1.0 %s", cgs)
            setdefaultattr(self, f"{unit}_unit", self.quan(1.0, cgs))
        self.magnetic_unit = np.sqrt(
            4 * np.pi * self.mass_unit / (self.time_unit**2 * self.length_unit)
        )
        self.magnetic_unit.convert_to_units("gauss")
        self.velocity_unit = self.length_unit / self.time_unit

    def _parse_parameter_file(self):
        """
        Populate the geometry, domain, time and parameters from the FITS header.
        """
        self._determine_structure()
        self._determine_axes()

        if self.parameter_filename.startswith("InMemory"):
            self.unique_identifier = time.time()

        # Determine dimensionality
        self.dimensionality = self.naxis
        self.geometry = "cartesian"

        # Sometimes a FITS file has a 4D datacube, in which case
        # we take the 4th axis and assume it consists of different fields.
        if self.dimensionality == 4:
            self.dimensionality = 3

        self._determine_wcs()

        self.domain_dimensions = np.array(self.dims)[: self.dimensionality]
        if self.dimensionality == 2:
            # 2D images are treated as a one-cell-thick 3D domain.
            self.domain_dimensions = np.append(self.domain_dimensions, [1])
        self._determine_bbox()

        # Get the simulation time; only a missing key is expected here.
        try:
            self.current_time = self.parameters["time"]
        except KeyError:
            mylog.warning("Cannot find time")
            self.current_time = 0.0

        # For now we'll ignore these
        self._periodicity = (False,) * 3
        self.current_redshift = 0.0
        self.omega_lambda = 0.0
        self.omega_matter = 0.0
        self.hubble_constant = 0.0
        self.cosmological_simulation = 0

        self._determine_nprocs()

        # Now we can set up some of our parameters for convenience.
        for k, v in self.primary_header.items():
            self.parameters[k] = v
        # Remove potential default keys
        self.parameters.pop("", None)

    def _determine_nprocs(self):
        """
        Set ``parameters["nprocs"]``.

        Uses the user-specified value if given. Otherwise uses roughly one
        grid per 32**dim cells, clamped to [1, 512].
        """
        specified = self.specified_parameters["nprocs"]
        if specified is not None:
            self.parameters["nprocs"] = specified
            return
        nprocs = np.around(
            np.prod(self.domain_dimensions) / 32**self.dimensionality
        ).astype("int")
        self.parameters["nprocs"] = max(min(nprocs, 512), 1)

    def _determine_structure(self):
        """Read the primary header, axis count, axis names and axis sizes."""
        self.primary_header, self.first_image = find_primary_header(self._handle)
        self.naxis = self.primary_header["naxis"]
        self.axis_names = [
            self.primary_header.get("ctype%d" % (i + 1), "LINEAR")
            for i in range(self.naxis)
        ]
        self.dims = [
            self.primary_header["naxis%d" % (i + 1)] for i in range(self.naxis)
        ]

    def _determine_wcs(self):
        """Build the WCS; for 4D headers keep only the first three axes."""
        wcs = _astropy.pywcs.WCS(header=self.primary_header)
        if self.naxis == 4:
            self.wcs = _astropy.pywcs.WCS(naxis=3)
            self.wcs.wcs.crpix = wcs.wcs.crpix[:3]
            self.wcs.wcs.cdelt = wcs.wcs.cdelt[:3]
            self.wcs.wcs.crval = wcs.wcs.crval[:3]
            self.wcs.wcs.cunit = [str(unit) for unit in wcs.wcs.cunit][:3]
            self.wcs.wcs.ctype = list(wcs.wcs.ctype)[:3]
        else:
            self.wcs = wcs

    def _determine_bbox(self):
        """Set domain edges in pixel coordinates (cell centers at integers)."""
        domain_left_edge = np.array([0.5] * 3)
        domain_right_edge = np.array(
            [float(dim) + 0.5 for dim in self.domain_dimensions]
        )

        if self.dimensionality == 2:
            domain_left_edge[-1] = 0.5
            domain_right_edge[-1] = 1.5

        self.domain_left_edge = domain_left_edge
        self.domain_right_edge = domain_right_edge

    def _determine_axes(self):
        """Use generic X/Y names: axis 0 is 'longitude', axis 1 is 'latitude'."""
        self.lat_axis = 1
        self.lon_axis = 0
        self.lat_name = "Y"
        self.lon_name = "X"

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        """Return True if ``check_fits_valid`` can open ``filename``."""
        fileh = check_fits_valid(filename)
        if fileh is None:
            return False
        fileh.close()
        return True

    @classmethod
    def _guess_candidates(cls, base, directories, files):
        """Return files with a .fits/.fits.gz/.fits.fz suffix (case-insensitive)."""
        candidates = [
            fn
            for fn in files
            if fn.lower().endswith((".fits", ".fits.gz", ".fits.fz"))
        ]
        # FITS files don't preclude subdirectories
        return candidates, True

    def close(self):
        """Close the underlying FITS file handler."""
        self._handle.close()
Пример #2
0
class FITSDataset(Dataset):
    """Dataset wrapper for FITS images, spectral cubes and X-ray event lists."""

    _index_class = FITSHierarchy
    _field_info_class = FITSFieldInfo
    _dataset_type = "fits"
    _handle = None

    def __init__(self,
                 filename,
                 dataset_type='fits',
                 auxiliary_files=None,
                 nprocs=None,
                 storage_filename=None,
                 nan_mask=None,
                 spectral_factor=1.0,
                 z_axis_decomp=False,
                 suppress_astropy_warnings=True,
                 parameters=None,
                 units_override=None,
                 unit_system="cgs"):
        """
        Open a FITS file (or an in-memory image HDU / HDUList) as a dataset.

        Parameters
        ----------
        filename : str or astropy image HDU or HDUList
            The primary FITS source. In-memory HDUs get a synthetic
            ``InMemoryFITSFile_<uuid>`` name.
        dataset_type : str
            Passed through to ``Dataset.__init__``.
        auxiliary_files : str, HDU, HDUList or list of these, optional
            Extra FITS sources whose HDUList is appended to the file handler.
        nprocs : int, optional
            Number of grids; chosen automatically when None.
        storage_filename : str, optional
            Stored on the instance after ``Dataset.__init__`` runs.
        nan_mask : number or dict, optional
            A single number is stored as ``{"all": value}``; a dict is used
            as-is.
        spectral_factor : float or "auto"
            Stretch factor applied to the spectral axis of a spectral cube.
        z_axis_decomp : bool
            If True, decompose 3D data along the z axis only.
        suppress_astropy_warnings : bool
            If True, install a warnings filter ignoring astropy warnings.
        parameters : dict, optional
            User parameters (e.g. ``"reblock"``, ``"beam_size"``). The dict
            is copied, so the caller's object is never modified.
        units_override, unit_system
            Passed through to ``Dataset.__init__``.

        Raises
        ------
        TypeError
            If ``nan_mask`` is not None, a number, or a dict.
        """
        # Copy so that writes to specified_parameters (e.g. "nprocs") do not
        # mutate the caller's dict.
        parameters = {} if parameters is None else dict(parameters)
        parameters["nprocs"] = nprocs
        self.specified_parameters = parameters

        self.z_axis_decomp = z_axis_decomp
        self.spectral_factor = spectral_factor

        if suppress_astropy_warnings:
            warnings.filterwarnings('ignore', module="astropy", append=True)
        # Normalize here instead of using a mutable `[]` default, which
        # would be shared between calls.
        if auxiliary_files is None:
            auxiliary_files = []
        else:
            auxiliary_files = ensure_list(auxiliary_files)
        self.filenames = [filename] + auxiliary_files
        self.num_files = len(self.filenames)
        self.fluid_types += ("fits", )
        self.nan_mask = self._normalize_nan_mask(nan_mask)
        self._handle = FITSFileHandler(self.filenames[0])

        primary = self.filenames[0]
        if isinstance(primary, (_astropy.pyfits.hdu.image._ImageBaseHDU,
                                _astropy.pyfits.HDUList)):
            fn = "InMemoryFITSFile_%s" % uuid.uuid4().hex
        else:
            fn = primary
        # NOTE(review): the handler is registered as the first entry of its
        # own file list; presumably FITSFileHandler walks _fits_files to
        # reach every HDU — confirm against FITSFileHandler.
        self._handle._fits_files.append(self._handle)
        # The helper keeps `fn` (the primary file's name, later passed to
        # Dataset.__init__) from being overwritten by auxiliary paths.
        for fits_file in self.filenames[1:]:
            self._handle._fits_files.append(
                self._open_auxiliary_file(fits_file))

        if len(self._handle) > 1 and self._handle[1].name == "EVENTS":
            self._read_events_header()
        else:
            self._read_image_header()

        self.refine_by = 2

        Dataset.__init__(self,
                         fn,
                         dataset_type,
                         units_override=units_override,
                         unit_system=unit_system)
        self.storage_filename = storage_filename

    @staticmethod
    def _normalize_nan_mask(nan_mask):
        """
        Return the NaN-replacement mapping for ``nan_mask``.

        None gives an empty dict. A real number gives ``{"all": value}``.
        A dict is returned unchanged.
        """
        import numbers

        if nan_mask is None:
            return {}
        if isinstance(nan_mask, dict):
            return nan_mask
        if isinstance(nan_mask, numbers.Real):
            return {"all": nan_mask}
        raise TypeError("nan_mask must be None, a number, or a dict, not %s"
                        % type(nan_mask).__name__)

    @staticmethod
    def _open_auxiliary_file(fits_file):
        """
        Return an HDUList for one auxiliary FITS source.

        An image HDU is wrapped in a new HDUList, and an HDUList is used
        as-is. A path that does not exist as given is looked up under the
        ``test_data_dir`` config entry.
        """
        if isinstance(fits_file, _astropy.pyfits.hdu.image._ImageBaseHDU):
            return _astropy.pyfits.HDUList([fits_file])
        if isinstance(fits_file, _astropy.pyfits.HDUList):
            return fits_file
        if os.path.exists(fits_file):
            path = fits_file
        else:
            path = os.path.join(ytcfg.get("yt", "test_data_dir"), fits_file)
        return _astropy.pyfits.open(path,
                                    memmap=True,
                                    do_not_scale_image_data=True,
                                    ignore_blank=True)

    def _read_events_header(self):
        """
        Set up a 2D pixelized view of the event list stored in HDU 1.

        This records x/y column limits and WCS keywords, plus energy/time
        units, in ``events_info``. It then builds a 2D WCS and grid
        dimensions, both scaled by the optional ``"reblock"`` parameter.
        """
        self.events_data = True
        self.first_image = 1
        self.primary_header = self._handle[self.first_image].header
        self.naxis = 2
        self.wcs = _astropy.pywcs.WCS(naxis=2)
        self.events_info = {}
        for k, v in self.primary_header.items():
            if not k.startswith("TTYP"):
                continue
            column = v.lower()
            # Column keywords are numbered (TTYPE3 pairs with TLMIN3, ...);
            # stripping the TTYPE letters from both ends leaves that number.
            num = k.strip("TTYPE")
            if column in ["x", "y"]:
                self.events_info[column] = (
                    self.primary_header["TLMIN" + num],
                    self.primary_header["TLMAX" + num],
                    self.primary_header["TCTYP" + num],
                    self.primary_header["TCRVL" + num],
                    self.primary_header["TCDLT" + num],
                    self.primary_header["TCRPX" + num])
            elif column in ["energy", "time"]:
                unit = self.primary_header["TUNIT" + num].lower()
                if unit.endswith("ev"):
                    unit = unit.replace("ev", "eV")
                self.events_info[column] = unit
        self.axis_names = [self.events_info[ax][2] for ax in ["x", "y"]]
        self.reblock = 1
        if "reblock" in self.specified_parameters:
            self.reblock = self.specified_parameters["reblock"]
        x_info = self.events_info["x"]
        y_info = self.events_info["y"]
        self.wcs.wcs.cdelt = [x_info[4] * self.reblock,
                              y_info[4] * self.reblock]
        # Re-center the reference pixel for the coarser (reblocked) grid.
        self.wcs.wcs.crpix = [(x_info[5] - 0.5) / self.reblock + 0.5,
                              (y_info[5] - 0.5) / self.reblock + 0.5]
        self.wcs.wcs.ctype = [x_info[2], y_info[2]]
        self.wcs.wcs.cunit = ["deg", "deg"]
        self.wcs.wcs.crval = [x_info[3], y_info[3]]
        # NOTE(review): true division may yield non-integer dimensions when
        # (TLMAX - TLMIN) is not a multiple of reblock — confirm intended.
        self.dims = [(x_info[1] - x_info[0]) / self.reblock,
                     (y_info[1] - y_info[0]) / self.reblock]

    def _read_image_header(self):
        """
        Read the header, axis names, sizes and WCS of the first image HDU.
        """
        self.events_data = False
        # Sometimes the primary hdu doesn't have an image
        if len(self._handle) > 1 and self._handle[0].header["naxis"] == 0:
            self.first_image = 1
        else:
            self.first_image = 0
        self.primary_header = self._handle[self.first_image].header
        self.naxis = self.primary_header["naxis"]
        self.axis_names = [
            self.primary_header.get("ctype%d" % (i + 1), "LINEAR")
            for i in range(self.naxis)
        ]
        self.dims = [
            self.primary_header["naxis%d" % (i + 1)]
            for i in range(self.naxis)
        ]
        wcs = _astropy.pywcs.WCS(header=self.primary_header)
        if self.naxis == 4:
            # Keep only the first three axes; the 4th is treated as fields.
            self.wcs = _astropy.pywcs.WCS(naxis=3)
            self.wcs.wcs.crpix = wcs.wcs.crpix[:3]
            self.wcs.wcs.cdelt = wcs.wcs.cdelt[:3]
            self.wcs.wcs.crval = wcs.wcs.crval[:3]
            self.wcs.wcs.cunit = [str(unit) for unit in wcs.wcs.cunit][:3]
            self.wcs.wcs.ctype = list(wcs.wcs.ctype)[:3]
        else:
            self.wcs = wcs

    def _set_code_unit_attributes(self):
        """
        Generate the conversion to various physical units from the parameter file.

        The length unit is inferred from the WCS ``cunit`` entries when they
        all name the same known length unit; otherwise 1 cm is assumed.
        Mass, time and velocity default to cgs. A ``beam`` solid-angle unit
        is registered, sized from the optional ``"beam_size"`` parameter.
        For spectral cubes, a ``pixel`` unit equal to the sky area of one
        pixel is also registered.
        """
        default_length_units = [
            u for u, v in default_unit_symbol_lut.items()
            if str(v[1]) == "(length)"
        ]
        # Add SI-prefixed variants (e.g. "kpc") of prefixable units.
        more_length_units = []
        for unit in default_length_units:
            if unit in prefixable_units:
                more_length_units += [
                    prefix + unit for prefix in unit_prefixes
                ]
        default_length_units += more_length_units
        cunits = [self.wcs.wcs.cunit[i] for i in range(self.dimensionality)]
        file_units = [
            unit for unit in (cunit.to_string() for cunit in cunits)
            if unit in default_length_units
        ]
        if len(set(file_units)) == 1:
            length_factor = self.wcs.wcs.cdelt[0]
            length_unit = str(file_units[0])
            mylog.info("Found length units of %s.", length_unit)
        else:
            self.no_cgs_equiv_length = True
            mylog.warning("No length conversion provided. Assuming 1 = 1 cm.")
            length_factor = 1.0
            length_unit = "cm"
        setdefaultattr(self, 'length_unit',
                       self.quan(length_factor, length_unit))
        setdefaultattr(self, 'mass_unit', self.quan(1.0, "g"))
        setdefaultattr(self, 'time_unit', self.quan(1.0, "s"))
        setdefaultattr(self, 'velocity_unit', self.quan(1.0, "cm/s"))
        if "beam_size" in self.specified_parameters:
            # Expected as a (value, unit) pair.
            beam_size = self.specified_parameters["beam_size"]
            beam_size = self.quan(beam_size[0], beam_size[1]).in_cgs().value
        else:
            beam_size = 1.0
        self.unit_registry.add("beam",
                               beam_size,
                               dimensions=dimensions.solid_angle)
        if self.spec_cube:
            units = self.wcs_2d.wcs.cunit[0]
            if units == "deg":
                units = "degree"
            if units == "rad":
                units = "radian"
            pixel_area = np.prod(np.abs(self.wcs_2d.wcs.cdelt))
            pixel_area = self.quan(pixel_area, "%s**2" % (units)).in_cgs()
            pixel_dims = pixel_area.units.dimensions
            self.unit_registry.add("pixel",
                                   float(pixel_area.value),
                                   dimensions=pixel_dims)

    def _parse_parameter_file(self):
        """
        Populate identifier, geometry, domain, time, decomposition and
        spectral-cube detection from the header data read in ``__init__``.
        """
        if self.parameter_filename.startswith("InMemory"):
            self.unique_identifier = time.time()
        else:
            self.unique_identifier = \
                int(os.stat(self.parameter_filename)[stat.ST_CTIME])

        # Determine dimensionality
        self.dimensionality = self.naxis
        self.geometry = "cartesian"

        # Sometimes a FITS file has a 4D datacube, in which case
        # we take the 4th axis and assume it consists of different fields.
        if self.dimensionality == 4:
            self.dimensionality = 3

        self.domain_dimensions = np.array(self.dims)[:self.dimensionality]
        if self.dimensionality == 2:
            # 2D images are treated as a one-cell-thick 3D domain.
            self.domain_dimensions = np.append(self.domain_dimensions, [1])

        domain_left_edge = np.array([0.5] * 3)
        domain_right_edge = np.array(
            [float(dim) + 0.5 for dim in self.domain_dimensions])

        if self.dimensionality == 2:
            domain_left_edge[-1] = 0.5
            domain_right_edge[-1] = 1.5

        self.domain_left_edge = domain_left_edge
        self.domain_right_edge = domain_right_edge

        # Get the simulation time; only a missing key is expected here.
        try:
            self.current_time = self.parameters["time"]
        except KeyError:
            mylog.warning("Cannot find time")
            self.current_time = 0.0

        # For now we'll ignore these
        self.periodicity = (False, ) * 3
        self.current_redshift = self.omega_lambda = self.omega_matter = \
            self.hubble_constant = self.cosmological_simulation = 0.0

        if self.dimensionality == 2 and self.z_axis_decomp:
            mylog.warning(
                "You asked to decompose along the z-axis, but this is a 2D dataset. "
                + "Ignoring.")
            self.z_axis_decomp = False

        if self.events_data:
            self.specified_parameters["nprocs"] = 1

        # If nprocs is None, do some automatic decomposition of the domain
        if self.specified_parameters["nprocs"] is None:
            if self.z_axis_decomp:
                nprocs = np.around(self.domain_dimensions[2] / 8).astype("int")
            else:
                nprocs = np.around(
                    np.prod(self.domain_dimensions) /
                    32**self.dimensionality).astype("int")
            self.parameters["nprocs"] = max(min(nprocs, 512), 1)
        else:
            self.parameters["nprocs"] = self.specified_parameters["nprocs"]

        # Check to see if this data is in some kind of (Lat,Lon,Vel) format:
        # count how many known axis prefixes match any of the axis names.
        self.spec_cube = False
        self.wcs_2d = None
        n_matched = 0
        for prefix in lon_prefixes + lat_prefixes + list(spec_names.keys()):
            hits = np_char.startswith(self.axis_names[:self.dimensionality],
                                      prefix)
            n_matched += np.any(hits)
        if n_matched == self.dimensionality:
            if self.axis_names == ['LINEAR', 'LINEAR']:
                self.wcs_2d = self.wcs
                self.lat_axis = 1
                self.lon_axis = 0
                self.lat_name = "Y"
                self.lon_name = "X"
            else:
                self._setup_spec_cube()

        # Now we can set up some of our parameters for convenience.
        for k, v in self.primary_header.items():
            self.parameters[k] = v
        # Remove potential default keys
        self.parameters.pop('', None)

    def _setup_spec_cube(self):
        """
        Configure spectral-cube geometry: the lat/lon/spectral axes, the 2D
        sky WCS, and (for >2 axes) a spectral axis stretched by
        ``spectral_factor``.
        """
        self.spec_cube = True
        self.geometry = "spectral_cube"

        end = min(self.dimensionality + 1, 4)
        if self.events_data:
            ctypes = self.axis_names
        else:
            ctypes = np.array(
                [self.primary_header["CTYPE%d" % (i)] for i in range(1, end)])

        log_str = "Detected these axes: " + "%s " * len(ctypes)
        mylog.info(log_str, *ctypes)

        self.lat_axis = np.zeros((end - 1), dtype="bool")
        for p in lat_prefixes:
            self.lat_axis += np_char.startswith(ctypes, p)
        self.lat_axis = np.where(self.lat_axis)[0][0]
        self.lat_name = ctypes[self.lat_axis].split("-")[0].lower()

        self.lon_axis = np.zeros((end - 1), dtype="bool")
        for p in lon_prefixes:
            self.lon_axis += np_char.startswith(ctypes, p)
        self.lon_axis = np.where(self.lon_axis)[0][0]
        self.lon_name = ctypes[self.lon_axis].split("-")[0].lower()

        # If lat and lon resolved to the same axis, fall back to plain X/Y.
        if self.lat_axis == self.lon_axis and self.lat_name == self.lon_name:
            self.lat_axis = 1
            self.lon_axis = 0
            self.lat_name = "Y"
            self.lon_name = "X"

        if self.wcs.naxis > 2:

            self.spec_axis = np.zeros((end - 1), dtype="bool")
            for p in spec_names.keys():
                self.spec_axis += np_char.startswith(ctypes, p)
            self.spec_axis = np.where(self.spec_axis)[0][0]
            # spec_names is keyed by the first letter of the ctype.
            self.spec_name = spec_names[ctypes[self.spec_axis].split("-")[0]
                                        [0]]

            sky_axes = [self.lon_axis, self.lat_axis]
            self.wcs_2d = _astropy.pywcs.WCS(naxis=2)
            self.wcs_2d.wcs.crpix = self.wcs.wcs.crpix[sky_axes]
            self.wcs_2d.wcs.cdelt = self.wcs.wcs.cdelt[sky_axes]
            self.wcs_2d.wcs.crval = self.wcs.wcs.crval[sky_axes]
            self.wcs_2d.wcs.cunit = [
                str(self.wcs.wcs.cunit[self.lon_axis]),
                str(self.wcs.wcs.cunit[self.lat_axis])
            ]
            self.wcs_2d.wcs.ctype = [
                self.wcs.wcs.ctype[self.lon_axis],
                self.wcs.wcs.ctype[self.lat_axis]
            ]

            self._p0 = self.wcs.wcs.crpix[self.spec_axis]
            self._dz = self.wcs.wcs.cdelt[self.spec_axis]
            self._z0 = self.wcs.wcs.crval[self.spec_axis]
            self.spec_unit = str(self.wcs.wcs.cunit[self.spec_axis])

            if self.spectral_factor == "auto":
                # Make the spectral extent match the larger sky extent.
                self.spectral_factor = float(
                    max(self.domain_dimensions[sky_axes]))
                self.spectral_factor /= self.domain_dimensions[self.spec_axis]
                mylog.info("Setting the spectral factor to %f",
                           self.spectral_factor)
            Dz = self.domain_right_edge[
                self.spec_axis] - self.domain_left_edge[self.spec_axis]
            dre = self.domain_right_edge
            dre[self.spec_axis] = (self.domain_left_edge[self.spec_axis] +
                                   self.spectral_factor * Dz)
            self.domain_right_edge = dre
            # Rescale the spectral WCS so pixel<->spec mappings account for
            # the stretched axis.
            self._dz /= self.spectral_factor
            self._p0 = (self._p0 - 0.5) * self.spectral_factor + 0.5

        else:

            self.wcs_2d = self.wcs
            self.spec_axis = 2
            self.spec_name = "z"
            self.spec_unit = "code_length"

    def spec2pixel(self, spec_value):
        """Convert a spectral-axis value to a pixel (code_length) coordinate."""
        sv = self.arr(spec_value).in_units(self.spec_unit)
        return self.arr((sv.v - self._z0) / self._dz + self._p0, "code_length")

    def pixel2spec(self, pixel_value):
        """Convert a pixel (code_length) coordinate to a spectral-axis value."""
        pv = self.arr(pixel_value, "code_length")
        return self.arr((pv.v - self._p0) * self._dz + self._z0,
                        self.spec_unit)

    @classmethod
    def _is_valid(cls, *args, **kwargs):
        """
        Return True if ``args[0]`` looks like a FITS file with 2+ image axes.

        The extension must be .fits/.fts, optionally followed by .gz/.fz.
        Raises RuntimeError if the extension matches but astropy is missing.
        Any failure to open or inspect the file yields False.
        """
        filename = args[0]
        ext = filename.rsplit(".", 1)[-1]
        if ext.upper() in ("GZ", "FZ"):
            # We don't know for sure that there will be > 1
            ext = filename.rsplit(".", 1)[0].rsplit(".", 1)[-1]
        if ext.upper() not in ("FITS", "FTS"):
            return False
        elif isinstance(_astropy.pyfits, NotAModule):
            raise RuntimeError(
                "This appears to be a FITS file, but AstroPy is not installed."
            )
        try:
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore',
                                        category=UserWarning,
                                        append=True)
                fileh = _astropy.pyfits.open(filename)
            # Always close the file, even if a header lookup fails.
            try:
                valid = fileh[0].header["naxis"] >= 2
                if len(fileh) > 1:
                    valid = fileh[1].header["naxis"] >= 2 or valid
            finally:
                fileh.close()
            return valid
        except Exception:
            # Anything astropy cannot open or inspect is not a valid dataset.
            return False

    @classmethod
    def _guess_candidates(cls, base, directories, files):
        """Return files with a .fits/.fits.gz/.fits.fz suffix (case-insensitive)."""
        candidates = [
            fn for fn in files
            if fn.lower().endswith((".fits", ".fits.gz", ".fits.fz"))
        ]
        # FITS files don't preclude subdirectories
        return candidates, True

    def close(self):
        """Close the underlying FITS file handler."""
        self._handle.close()