def setup_class(cls):
        """Create the shared fixtures and save each of them with ``saveData``.

        Builds a light curve with Poisson counts, an event list with PI
        channels and energies, and a deep copy of that event list whose
        optional attributes are blanked out. The directory returned by
        ``saveData`` for each object is kept on the class.
        """
        # Light curve: one bin per unit time, Poisson counts (lambda = 10).
        lc_times = np.arange(0, 1e7)
        lc_counts = np.random.poisson(10, lc_times.size)
        cls.lc = Lightcurve(lc_times, lc_counts, skip_checks=True)

        # Event list: sorted uniform arrival times, random PI channels and
        # an energy obtained from the channel through a linear relation.
        event_times = np.sort(np.random.uniform(0, 1e7, 10**7))
        channels = np.random.randint(0, 100, event_times.size)
        cls.ev = EventList(
            time=event_times,
            pi=channels,
            energy=channels * 0.04 + 1.6,
            gti=[[0, 1e7]],
            dt=1e-5,
            notes="Bu",
        )

        # Same events, but with the optional attributes reset one by one
        # (the order of the assignments is kept deliberately).
        cls.ev_noattrs = copy.deepcopy(cls.ev)
        blanked = (
            ("energy", None),
            ("pi", None),
            ("mjdref", 0),
            ("gti", None),
            ("dt", 0),
            ("notes", None),
        )
        for attr_name, attr_value in blanked:
            setattr(cls.ev_noattrs, attr_name, attr_value)

        # Save everything without persisting and remember the locations.
        cls.lc_path = saveData(cls.lc, persist=False)
        cls.ev_path = saveData(cls.ev, persist=False)
        cls.ev_path_noattrs = saveData(cls.ev_noattrs, persist=False)
Exemple #2
0
    def __init__(self, data=None, segment_size=None, norm="frac", gti=None,
                 silent=False, dt=None, lc=None, large_data=False,
                 save_all=False):
        """Set up the averaged power spectrum from ``data``.

        Parameters
        ----------
        data : object, optional
            Input data. With ``large_data`` it must be an ``EventList`` or
            a ``Lightcurve``; otherwise it is handed to
            ``Powerspectrum.__init__``.
        segment_size : float, optional
            Segment length. Required whenever data is supplied, and must be
            finite.
        norm : str, optional
            Normalization, forwarded unchanged.
        gti : optional
            Good time intervals, forwarded unchanged.
        silent : bool, optional
            ``show_progress`` is set to ``not silent``.
        dt : float, optional
            Sampling time, stored on the instance and forwarded.
        lc : object, optional
            Deprecated alias of ``data``; only used when ``data`` is None.
        large_data : bool, optional
            If true (and data is given), the data are saved with
            ``saveData`` and the spectrum is built by
            ``createChunkedSpectra``; its attributes are copied onto
            ``self``.
        save_all : bool, optional
            Stored on the instance.

        Raises
        ------
        ValueError
            If data is given without ``segment_size``, if ``segment_size``
            is not finite, or if ``large_data`` is requested for an
            unsupported input type.

        Notes
        -----
        For an ``EventList`` input outside the ``large_data`` path,
        ``data.gti`` is overwritten in place, keeping only the intervals
        at least ``segment_size`` long.
        """
        if lc is not None:
            warnings.warn("The lc keyword is now deprecated. Use data "
                          "instead", DeprecationWarning)
        # Backwards compatibility: fall back on the deprecated keyword.
        data = lc if data is None else data

        if data is not None and segment_size is None:
            raise ValueError("segment_size must be specified")
        if not (segment_size is None or np.isfinite(segment_size)):
            raise ValueError("segment_size must be finite!")

        if large_data and data is not None:
            n_chunks = None

            if isinstance(data, EventList):
                kind = 'EventList'
            elif isinstance(data, Lightcurve):
                kind = 'Lightcurve'
                # Whole number of bins per segment; the segment size is
                # then adjusted to be an exact multiple of the bin time.
                n_chunks = int(np.rint(segment_size // data.dt))
                segment_size = n_chunks * data.dt
            else:
                raise ValueError(
                    f'Invalid input data type: {type(data).__name__}')

            saved_dir = saveData(data, persist=False, chunks=n_chunks)
            chunked = createChunkedSpectra(
                kind,
                'AveragedPowerspectrum',
                data_path=genDataPath(saved_dir),
                segment_size=segment_size,
                norm=norm,
                gti=gti,
                power_type=None,
                silent=silent,
                dt=dt,
            )
            # Adopt every attribute of the spectrum computed in chunks.
            for attr_name, attr_value in chunked.__dict__.items():
                setattr(self, attr_name, attr_value)
            return

        self.type = "powerspectrum"
        self.dt = dt
        self.save_all = save_all

        if isinstance(data, EventList):
            # Drop the GTIs that cannot contain a full segment.
            durations = data.gti[:, 1] - data.gti[:, 0]
            long_enough = durations >= segment_size
            data.gti = data.gti[long_enough]

        self.segment_size = segment_size
        self.show_progress = not silent
        Powerspectrum.__init__(self, data, norm, gti=gti, dt=dt)
    def test_save_lc_small(self):
        """Save a light curve slice far shorter than the requested chunks."""
        lc_copy = copy.deepcopy(self.lc)
        # Touch counts_err so that it exists before slicing.
        _ = lc_copy.counts_err

        # 300 bins only, certainly fewer than the chunk size below.
        short_lc = lc_copy[:300]
        _ = saveData(short_lc, persist=False, chunks=100000)
    def test_save_ev_missing_psutil_not_linux(self, monkeypatch):
        """Saving without ``psutil`` must warn that RAM is not considered."""
        # A None entry in sys.modules makes ``import psutil`` fail.
        monkeypatch.setitem(sys.modules, "psutil", None)

        with pytest.warns(UserWarning) as caught:
            _ = saveData(self.ev, persist=False)

        expected = 'will not depend on available RAM'
        found = [expected in warning.message.args[0] for warning in caught]
        assert any(found)
    def test_save_fits_data(self):
        """Save a FITS event file and compare the stored arrays with it.

        The reference values are read from the same file with
        ``load_events_and_gtis`` and ``ref_mjd``; the saved values are read
        back from the ``main_data`` and ``meta_data`` stores. All
        mismatches are collected and reported together.
        """
        fname = os.path.join(datadir, "monol_testA.evt")
        dir_name = saveData(fname, persist=False)

        evtdata = load_events_and_gtis(fname, additional_columns=["PI"])
        mjdref_def = ref_mjd(fname, hdu=1)
        time_def = evtdata.ev_list
        pi_channel_def = evtdata.additional_data["PI"]
        gti_def = evtdata.gti_list
        tstart_def = evtdata.t_start
        tstop_def = evtdata.t_stop

        main = os.path.join(dir_name, "main_data")
        meta = os.path.join(dir_name, "meta_data")

        errors = []

        # Hidden files do not count as saved data.
        main_files = [f for f in os.listdir(main) if not f.startswith(".")]
        meta_files = [f for f in os.listdir(meta) if not f.startswith(".")]

        # Either store being empty means the save failed. The former check,
        # ``(len(a) or len(b)) == 0``, only fired when BOTH were empty.
        if not main_files or not meta_files:
            errors.append("EventList is not saved or does not exist")
        else:
            times = zarr.open_array(store=main, mode="r", path="times")[...]
            pi_channel = zarr.open_array(
                store=main, mode="r", path="pi_channel"
            )[...]
            gti = zarr.open_array(store=main, mode="r", path="gti")[...]
            # GTIs are reshaped back to (start, stop) pairs.
            gti = gti.reshape((gti.size // 2, 2))
            tstart = zarr.open_array(store=meta, mode="r", path="tstart")[...]
            tstop = zarr.open_array(store=meta, mode="r", path="tstop")[...]
            mjdref = zarr.open_array(store=meta, mode="r", path="mjdref")[...]

            # Sort the saved events in time, keeping PI aligned with them.
            order = np.argsort(times)
            times = times[order]
            pi_channel = pi_channel[order]

            if not np.allclose(time_def, times):
                errors.append("fits.events.data.time is not saved precisely")
            if not np.array_equal(pi_channel_def, pi_channel):
                errors.append("fits.events.data.pi is not saved precisely")
            if not np.allclose(gti_def, gti):
                errors.append("fits.gti.data is not saved precisely")
            if not (tstart == tstart_def):
                errors.append(
                    "fits.events.header.tstart is not saved precisely"
                )
            if not (tstop == tstop_def):
                errors.append(
                    "fits.events.header.tstop is not saved precisely"
                )
            if not (mjdref == mjdref_def):
                errors.append(
                    "fits.events.header.mjdref is not saved precisely"
                )

        assert not errors, "Errors encountered:\n{}".format("\n".join(errors))
    def test_save_lc(self):
        """Save a light curve and check every stored array against it.

        The saved values are read back from the ``main_data`` and
        ``meta_data`` stores. All mismatches are collected and reported
        together.
        """
        test_lc = copy.deepcopy(self.lc)
        # Make sure counts_err exists
        _ = test_lc.counts_err

        dir_name = saveData(test_lc, persist=False)

        main = os.path.join(dir_name, "main_data")
        meta = os.path.join(dir_name, "meta_data")

        errors = []

        # Hidden files do not count as saved data.
        main_files = [f for f in os.listdir(main) if not f.startswith(".")]
        meta_files = [f for f in os.listdir(meta) if not f.startswith(".")]

        # Either store being empty means the save failed. The former check,
        # ``(len(a) or len(b)) == 0``, only fired when BOTH were empty.
        if not main_files or not meta_files:
            errors.append("Lightcurve is not saved or does not exist")
        else:
            times = zarr.open_array(store=main, mode="r", path="times")[...]
            counts = zarr.open_array(store=main, mode="r", path="counts")[...]
            count_err = zarr.open_array(
                store=main, mode="r", path="count_err"
            )[...]
            gti = zarr.open_array(store=main, mode="r", path="gti")[...]
            # GTIs are reshaped back to (start, stop) pairs.
            gti = gti.reshape((gti.size // 2, 2))

            dt = zarr.open_array(store=meta, mode="r", path="dt")[...]
            mjdref = zarr.open_array(store=meta, mode="r", path="mjdref")[...]
            err_dist = zarr.open_array(store=meta, mode="r", path="err_dist")[
                ...
            ]

            if not np.array_equal(test_lc.time, times):
                errors.append("lc.time is not saved precisely")
            if not np.array_equal(test_lc.counts, counts):
                errors.append("lc.counts is not saved precisely")
            if not np.array_equal(test_lc.counts_err, count_err):
                errors.append("lc.counts_err is not saved precisely")
            if not np.array_equal(test_lc.gti, gti):
                errors.append("lc.gti is not saved precisely")
            if not (test_lc.dt == dt):
                errors.append("lc.dt is not saved precisely")
            if not (test_lc.mjdref == mjdref):
                errors.append("lc.mjdref is not saved precisely")
            if not (test_lc.err_dist == err_dist):
                errors.append("lc.err_dist is not saved precisely")

        assert not errors, "Errors encountered:\n{}".format("\n".join(errors))
    def test_save_ev(self):
        """Save the event list and check every stored array against it.

        The saved values are read back from the ``main_data`` and
        ``meta_data`` stores. All mismatches are collected and reported
        together.
        """
        dir_name = saveData(self.ev, persist=False)

        main = os.path.join(dir_name, "main_data")
        meta = os.path.join(dir_name, "meta_data")

        errors = []

        # Hidden files do not count as saved data.
        main_files = [f for f in os.listdir(main) if not f.startswith(".")]
        meta_files = [f for f in os.listdir(meta) if not f.startswith(".")]

        # Either store being empty means the save failed. The former check,
        # ``(len(a) or len(b)) == 0``, only fired when BOTH were empty.
        if not main_files or not meta_files:
            errors.append("EventList is not saved or does not exist")

        else:
            times = zarr.open_array(store=main, mode="r", path="times")[...]
            energy = zarr.open_array(store=main, mode="r", path="energy")[...]
            pi_channel = zarr.open_array(
                store=main, mode="r", path="pi_channel"
            )[...]
            gti = zarr.open_array(store=main, mode="r", path="gti")[...]
            # GTIs are reshaped back to (start, stop) pairs.
            gti = gti.reshape((gti.size // 2, 2))
            dt = zarr.open_array(store=meta, mode="r", path="dt")[...]
            ncounts = zarr.open_array(store=meta, mode="r", path="ncounts")[
                ...
            ]
            mjdref = zarr.open_array(store=meta, mode="r", path="mjdref")[...]
            notes = zarr.open_array(store=meta, mode="r", path="notes")[...]

            if not np.array_equal(self.ev.time, times):
                errors.append("ev.time is not saved precisely")
            if not np.array_equal(self.ev.energy, energy):
                errors.append("ev.energy is not saved precisely")
            if not np.array_equal(self.ev.pi, pi_channel):
                errors.append("ev.pi is not saved precisely")
            if not np.array_equal(self.ev.gti, gti):
                errors.append("ev.gti is not saved precisely")
            if not np.isclose(self.ev.dt, dt):
                errors.append("ev.dt is not saved precisely")
            if not self.ev.ncounts == ncounts:
                errors.append("ev.ncounts is not saved precisely")
            if not np.isclose(self.ev.mjdref, mjdref):
                errors.append("ev.mjdref is not saved precisely")
            if not self.ev.notes == notes:
                errors.append("ev.notes is not saved precisely")

        assert not errors, "Errors encountered:\n{}".format("\n".join(errors))
    def __init__(self, data=None, segment_size=None, norm="frac", gti=None,
                 silent=False, dt=None, lc=None, large_data=False,
                 save_all=False, skip_checks=False,
                 use_common_mean=True, legacy=False):
        """Set up the averaged power spectrum from ``data``.

        Parameters
        ----------
        data : object, optional
            Input data. A generator is expanded into a list. With
            ``large_data`` it must be an ``EventList`` or a ``Lightcurve``.
        segment_size : float, optional
            Segment length, stored on the instance and forwarded.
        norm : str, optional
            Normalization; it is lower-cased before being stored.
        gti : optional
            Good time intervals, forwarded unchanged.
        silent : bool, optional
            ``show_progress`` is set to ``not silent``.
        dt : float, optional
            Sampling time, stored on the instance and forwarded.
        lc : object, optional
            Deprecated alias of ``data``; only used when ``data`` is None.
        large_data : bool, optional
            If true, the data are saved with ``saveData`` and the spectrum
            is built by ``createChunkedSpectra``. Forces ``legacy=True``.
        save_all : bool, optional
            Stored on the instance. Forces ``legacy=True``.
        skip_checks : bool, optional
            If false, ``self.initial_checks`` validates the input first.
        use_common_mean : bool, optional
            Forwarded to ``_initialize_from_any_input`` (non-legacy path).
        legacy : bool, optional
            Select the legacy code path ending in
            ``Powerspectrum.__init__``.

        Raises
        ------
        ImportError
            If ``large_data`` is requested and ``HAS_ZARR`` is false.
        ValueError
            If ``large_data`` is requested for an unsupported input type.

        Notes
        -----
        In the legacy path without ``large_data``, an ``EventList`` input
        has its ``gti`` attribute overwritten in place, keeping only the
        intervals at least ``segment_size`` long.
        """
        self._type = None
        if lc is not None:
            warnings.warn("The lc keyword is now deprecated. Use data "
                          "instead", DeprecationWarning)
        # Backwards compatibility: user might have supplied lc instead
        if data is None:
            data = lc

        good_input = True
        if not skip_checks:
            good_input = self.initial_checks(
                data1=data,
                data2=data,
                norm=norm,
                gti=gti,
                lc1=lc,
                lc2=lc,
                dt=dt,
                segment_size=segment_size
            )

        norm = norm.lower()
        self.norm = norm
        self.dt = dt
        self.save_all = save_all
        self.segment_size = segment_size
        self.show_progress = not silent

        if not good_input:
            return self._initialize_empty()

        if isinstance(data, Generator):
            warnings.warn(
                "The averaged Power spectrum from a generator of "
                "light curves pre-allocates the full list of light "
                "curves, losing all advantage of lazy loading. If it "
                "is important for you, use the "
                "AveragedPowerspectrum.from_lc_iterable static "
                "method, specifying the sampling time `dt`.")
            data = list(data)

        # The large_data and save_all options require the legacy interface.
        if (large_data or save_all) and not legacy:
            # Note the trailing space: adjacent literals are concatenated
            # as-is, and the message used to read "are onlyavailable".
            warnings.warn("The large_data option and the save_all options are only "
                          "available with the legacy interface (legacy=True).")
            legacy = True

        if not legacy and data is not None:
            return self._initialize_from_any_input(
                data, dt=dt, segment_size=segment_size, norm=norm,
                silent=silent, use_common_mean=use_common_mean)

        if large_data and data is not None:
            if not HAS_ZARR:
                raise ImportError("The large_data option requires zarr.")
            chunks = None

            if isinstance(data, EventList):
                input_data = 'EventList'
            elif isinstance(data, Lightcurve):
                input_data = 'Lightcurve'
                # Whole number of bins per segment; the segment size is
                # then adjusted to be an exact multiple of the bin time.
                chunks = int(np.rint(segment_size // data.dt))
                segment_size = chunks * data.dt
            else:
                raise ValueError(
                    f'Invalid input data type: {type(data).__name__}')

            dir_path = saveData(data, persist=False, chunks=chunks)

            data_path = genDataPath(dir_path)
            spec = createChunkedSpectra(input_data,
                                        'AveragedPowerspectrum',
                                        data_path=data_path,
                                        segment_size=segment_size,
                                        norm=norm,
                                        gti=gti,
                                        power_type=None,
                                        silent=silent,
                                        dt=dt)
            # Adopt every attribute of the spectrum computed in chunks.
            for key, val in spec.__dict__.items():
                setattr(self, key, val)

            return

        if isinstance(data, EventList):
            # Drop the GTIs that cannot contain a full segment.
            lengths = data.gti[:, 1] - data.gti[:, 0]
            good = lengths >= segment_size
            data.gti = data.gti[good]

        Powerspectrum.__init__(
            self, data, norm, gti=gti, dt=dt, skip_checks=True, legacy=legacy)

        return
 def test_save_wrong_data(self):
     """``saveData`` must raise ValueError for an unsupported input."""
     with pytest.raises(ValueError) as raised:
         saveData("A string", "bububu")

     # The message names both the offending value and its type.
     message = str(raised.value)
     assert "Invalid data: A string (str)" in message
    def test_save_ev_missing_psutil_linux(self, monkeypatch):
        """Saving the event list must still work when ``psutil`` is absent."""
        # A None entry in sys.modules makes ``import psutil`` fail.
        hidden_module = "psutil"
        monkeypatch.setitem(sys.modules, hidden_module, None)

        # Only checks that the call completes; the result is not inspected.
        _ = saveData(self.ev, persist=False)
 def test_save_ev_small(self):
     """Save an event list far shorter than the requested chunk size."""
     # 1000 events only, certainly fewer than the chunk size below.
     small_ev = EventList(time=np.arange(1000))
     # Save the small list itself: the test used to build it and then save
     # the large ``self.ev``, so the small-data path was never exercised.
     _ = saveData(small_ev, persist=False, chunks=100000)