コード例 #1
0
ファイル: test_serialize_yaml.py プロジェクト: fjhzwl/gammapy
def test_datasets_to_io(tmpdir):
    """Read the example datasets/models from YAML, check them, and round-trip them.

    The example files are read from ``$GAMMAPY_DATA``. After checking the
    loaded content, the datasets are written back with ``Datasets.to_yaml``
    under ``tmpdir`` and read again, and the key dataset properties are
    re-checked on the read-back copy.
    """
    filedata = "$GAMMAPY_DATA/tests/models/gc_example_datasets.yaml"
    filemodel = "$GAMMAPY_DATA/tests/models/gc_example_models.yaml"

    datasets = Datasets.from_yaml(filedata, filemodel)

    assert len(datasets.datasets) == 2
    assert len(datasets.parameters.parameters) == 20

    dataset0 = datasets.datasets[0]
    assert dataset0.counts.data.sum() == 6824
    assert_allclose(dataset0.exposure.data.sum(), 2072125400000.0, atol=0.1)
    assert dataset0.psf is not None
    assert dataset0.edisp is not None

    assert_allclose(dataset0.background_model.evaluate().data.sum(),
                    4094.2,
                    atol=0.1)

    assert dataset0.background_model.name == "background_irf_gc"

    dataset1 = datasets.datasets[1]
    assert dataset1.background_model.name == "background_irf_g09"

    # Both datasets must expose an equal "gll_iem_v06_cutout" model.
    assert dataset0.model["gll_iem_v06_cutout"] == dataset1.model[
        "gll_iem_v06_cutout"]

    assert isinstance(dataset0.model, SkyModels)
    assert len(dataset0.model.skymodels) == 2
    assert dataset0.model.skymodels[0].name == "gc"
    assert dataset0.model.skymodels[1].name == "gll_iem_v06_cutout"

    # The "reference" parameter must be the identical object in both places
    # (checked with ``is``), not merely an equal copy.
    assert (dataset0.model.skymodels[0].parameters["reference"] is
            dataset1.model.skymodels[1].parameters["reference"])

    assert_allclose(dataset1.model.skymodels[1].parameters["lon_0"].value,
                    0.9,
                    atol=0.1)

    # Join a *relative* name. A leading "/" only happened to work because
    # py.path's join strips separators; with pathlib-style joining an
    # absolute component would point at the filesystem root instead.
    path = str(tmpdir / "written_")
    datasets.to_yaml(path, overwrite=True)
    datasets_read = Datasets.from_yaml(path + "datasets.yaml",
                                       path + "models.yaml")
    assert len(datasets_read.datasets) == 2
    dataset0 = datasets_read.datasets[0]
    assert dataset0.counts.data.sum() == 6824
    assert_allclose(dataset0.exposure.data.sum(), 2072125400000.0, atol=0.1)
    assert dataset0.psf is not None
    assert dataset0.edisp is not None
    assert_allclose(dataset0.background_model.evaluate().data.sum(),
                    4094.2,
                    atol=0.1)
コード例 #2
0
def test_spectrum_dataset_on_off_to_yaml(tmpdir):
    """Write on/off spectrum datasets to YAML and read them back.

    Checks the number of datasets, the names of the first two datasets and
    the total counts of the second dataset after the round trip.
    """
    original = Datasets(make_observation_list())
    original.to_yaml(path=tmpdir)

    datasets_file = tmpdir / "_datasets.yaml"
    models_file = tmpdir / "_models.yaml"
    restored = Datasets.from_yaml(datasets_file, models_file)

    assert len(restored) == len(original)
    for idx in (0, 1):
        assert restored[idx].name == original[idx].name

    restored_counts = restored[1].counts.data.sum()
    assert restored_counts == original[1].counts.data.sum()
コード例 #3
0
    def read_regions(self):
        """Load per-ROI fit results from ``self.resdir`` and collect diagnostics.

        For every ROI number in ``self.ROIs_sel`` this does the following:

        - reads the fitted ``*_datasets.yaml`` / ``*_models.yaml`` files;
        - loads the covariance matrix (``<dataset name>_covariance.npy``);
        - loads the fit-info archive (``*_fit_infos.npz``);
        - appends the fit message and statistic to ``self.diags``.

        Then, for each model whose 3FHL catalog entry has this ``ROI_num`` and
        ``Signif_Avg >= self.sig_cut``, it does the following:

        - assigns sub-covariances;
        - reads the fitted flux points;
        - calls ``update_spec_diags``;
        - plots when ``self.savefig`` is set.

        ROIs whose YAML files raise ``OSError`` on reading are skipped.
        """
        for kr in self.ROIs_sel:
            # Common file-name prefix of all per-ROI output files.
            roi_prefix = self.resdir + "/3FHL_ROI_num" + str(kr)
            filedata = Path(roi_prefix + "_datasets.yaml")
            filemodel = Path(roi_prefix + "_models.yaml")
            try:
                dataset = list(Datasets.from_yaml(filedata, filemodel))[0]
            except OSError:
                # FileNotFoundError and IOError are both OSError in Python 3.
                continue

            pars = dataset.parameters
            pars.covariance = np.load(self.resdir + "/" + dataset.name +
                                      "_covariance.npy")

            # np.load on an .npz file returns an NpzFile holding an open file
            # handle; the context manager closes it once the arrays are read.
            with np.load(roi_prefix + "_fit_infos.npz") as infos:
                self.diags["message"].append(infos["message"])
                self.diags["stat"].append(infos["stat"])

            if self.savefig:
                self.plot_maps(dataset)

            for model in dataset.model:
                # Look the catalog source up once instead of on every use.
                cat_source = self.FHL3[model.name]
                if not (cat_source.data["ROI_num"] == kr
                        and cat_source.data["Signif_Avg"] >= self.sig_cut):
                    continue

                model.spatial_model.parameters.covariance = pars.get_subcovariance(
                    model.spatial_model.parameters)
                model.spectral_model.parameters.covariance = pars.get_subcovariance(
                    model.spectral_model.parameters)
                # Does not depend on ``model``; kept here so it is only set
                # when at least one model passes the selection.
                dataset.background_model.parameters.covariance = pars.get_subcovariance(
                    dataset.background_model.parameters)
                res_spec = model.spectral_model
                cat_spec = cat_source.spectral_model()

                res_fp = FluxPoints.read(self.resdir + "/" + model.name +
                                         "_flux_points.fits")
                # Mark fitted flux points with TS < 1 as upper limits.
                res_fp.table["is_ul"] = res_fp.table["ts"] < 1.0
                cat_fp = cat_source.flux_points.to_sed_type("dnde")

                self.update_spec_diags(dataset, model, cat_spec, res_spec,
                                       cat_fp, res_fp)
                if self.savefig:
                    self.plot_spec(kr, model, cat_spec, res_spec, cat_fp,
                                   res_fp)
コード例 #4
0
def test_flux_point_dataset_serialization(tmp_path):
    """Round-trip a FluxPointsDataset through ``Datasets.to_yaml``/``from_yaml``.

    Checks that the ``dnde`` column, the fit and safe masks and the dataset
    name survive writing to and reading from ``tmp_path``.
    """
    path = "$GAMMAPY_DATA/tests/spectrum/flux_points/diff_flux_points.fits"
    data = FluxPoints.read(path)
    data.table["e_ref"] = data.e_ref.to("TeV")
    # TODO: remove this duplicate definition once the model is redefined as a
    # SkyModel.
    spatial_model = ConstantSpatialModel()
    spectral_model = PowerLawSpectralModel(index=2.3,
                                           amplitude="2e-13 cm-2 s-1 TeV-1",
                                           reference="1 TeV")
    model = SkyModel(spatial_model, spectral_model, name="test_model")
    dataset = FluxPointsDataset(SkyModels([model]), data, name="test_dataset")

    Datasets([dataset]).to_yaml(tmp_path, prefix="tmp")
    datasets = Datasets.from_yaml(tmp_path / "tmp_datasets.yaml",
                                  tmp_path / "tmp_models.yaml")
    new_dataset = datasets[0]
    assert_allclose(new_dataset.data.table["dnde"], dataset.data.table["dnde"],
                    rtol=1e-4)
    if dataset.mask_fit is None:
        # Without an explicit fit mask, the read-back mask_fit is expected
        # to equal the original mask_safe.
        assert np.all(new_dataset.mask_fit == dataset.mask_safe)
    else:
        # An explicit fit mask must survive the round trip unchanged.
        assert np.all(new_dataset.mask_fit == dataset.mask_fit)
    assert np.all(new_dataset.mask_safe == dataset.mask_safe)
    assert new_dataset.name == "test_dataset"