def test_mixed_axes():
    """Coordinates, indices and GADF table export of energy/time/label axes."""
    energy_ax = MapAxis.from_energy_bounds("1 TeV", "10 TeV", nbin=4)
    time_ax = TimeMapAxis(
        edges_min=[1, 10] * u.day,
        edges_max=[2, 13] * u.day,
        reference_time=Time("2020-03-19"),
    )
    label_ax = LabelMapAxis(labels=["label-1", "label-2", "label-3"], name="label")

    axes = MapAxes(axes=[energy_ax, time_ax, label_ax])
    grid = axes.get_coord()

    # Each axis' coordinate array is broadcastable: it is only extended along
    # the dimension belonging to that axis.
    expected_shapes = {
        "label": (1, 1, 3),
        "energy": (4, 1, 1),
        "time": (1, 2, 1),
    }
    for axis_name, shape in expected_shapes.items():
        assert grid[axis_name].shape == shape

    # Round-tripping the coordinates must give back the plain bin indices,
    # in axis order (energy, time, label).
    indices = axes.coord_to_idx(grid)
    for position, shape in enumerate([(4, 1, 1), (1, 2, 1), (1, 1, 3)]):
        expected = np.arange(max(shape)).reshape(shape)
        assert_allclose(indices[position], expected)

    table = Table.read(axes.to_table_hdu(format="gadf"))
    # Every label ("label-N") is seven characters wide.
    assert table["LABEL"].dtype == np.dtype("<U7")
    # 24 == 4 energy bins * 2 time bins * 3 labels.
    assert len(table) == 24
def test_region_nd_map_plot_label_axis():
    """A region map with a label axis can be plotted along each of its axes."""
    axes = [
        MapAxis.from_energy_bounds("1 TeV", "10 TeV", nbin=5),
        LabelMapAxis(labels=["dataset-1", "dataset-2"], name="dataset"),
    ]
    region_map = RegionNDMap.create(region=None, axes=axes)

    # Plot once along the energy axis, then once along the label axis.
    for axis_name in ("energy", "dataset"):
        with mpl_plot_check():
            region_map.plot(axis_name=axis_name)
def test_label_map_axis_basics():
    """String form, basic properties, name checks and copying of a label axis."""
    label_axis = LabelMapAxis(labels=["label-1", "label-2"], name="label-axis")

    text = str(label_axis)
    for fragment in ("node type", "labels", "label-2"):
        assert fragment in text

    # Asserting a name other than the axis' own must fail.
    with pytest.raises(ValueError):
        label_axis.assert_name("time")

    assert label_axis.nbin == 2
    assert label_axis.node_type == "label"
    assert_allclose(label_axis.bin_width, 1)
    assert label_axis.name == "label-axis"

    # Accessing edges on a label axis is expected to raise.
    with pytest.raises(ValueError):
        _ = label_axis.edges

    assert label_axis.copy().name == "label-axis"
def test_label_axis_io(tmpdir):
    """Energy and label axes survive a GADF write/read round trip."""
    energy_axis = MapAxis.from_energy_bounds("1 TeV", "10 TeV", nbin=5)
    dataset_axis = LabelMapAxis(labels=["dataset-1", "dataset-2"], name="dataset")

    original = RegionNDMap.create(region=None, axes=[energy_axis, dataset_axis])
    original.data = np.arange(original.data.size)

    path = tmpdir / "test.fits"
    original.write(path, format="gadf")
    restored = RegionNDMap.read(path, format="gadf")

    # Compare each axis of the written map with its re-read counterpart.
    for axis_name in ("dataset", "energy"):
        assert original.geom.axes[axis_name] == restored.geom.axes[axis_name]
def expand_map(m, dataset_names):
    """Expand a map along a ``"dataset"`` label axis.

    A new geometry is derived from ``m.geom`` by replacing its axis with a
    `LabelMapAxis` named ``"dataset"`` that holds ``dataset_names``. The
    expanded map is initialised with NaN, and the values of ``m`` are then
    written at the coordinates of ``m``'s own geometry. Bins of the new
    label axis that ``m`` does not cover therefore stay NaN.

    Parameters
    ----------
    m : `Map`
        Map to expand. Presumably it already has a ``"dataset"`` axis whose
        labels all appear in ``dataset_names`` (otherwise the label lookup
        in ``set_by_coord`` may fail) -- TODO confirm against callers.
    dataset_names : list of str
        Labels of the ``"dataset"`` axis of the expanded map.

    Returns
    -------
    map : `Map`
        Expanded map, filled with NaN except where ``m`` provides values.
    """
    label_axis = LabelMapAxis(labels=dataset_names, name="dataset")
    geom = m.geom.replace_axis(axis=label_axis)

    # Start from NaN so that missing datasets are distinguishable from zero.
    result = Map.from_geom(geom, data=np.nan)

    # Sparse coordinates of the input map; the dataset labels in them are
    # mapped onto positions of the new, larger label axis.
    coords = m.geom.get_coord(sparse=True)
    result.set_by_coord(coords, vals=m.data)
    return result
def test_label_map_axis_coord_to_idx():
    """Label-to-index lookup for scalar, 1D and nested inputs, plus bad labels."""
    axis = LabelMapAxis(labels=["label-1", "label-2", "label-3"], name="label-axis")

    # A single label maps to a scalar index.
    labels = "label-1"
    idx = axis.coord_to_idx(coord=labels)
    assert_allclose(idx, 0)

    # A flat list of labels maps element-wise.
    labels = ["label-1", "label-3"]
    idx = axis.coord_to_idx(coord=labels)
    assert_allclose(idx, [0, 2])

    # Nested input keeps its shape.
    labels = [["label-1"], ["label-2"]]
    idx = axis.coord_to_idx(coord=labels)
    assert_allclose(idx, [[0], [1]])

    # An unknown label must raise. The input is built outside the ``raises``
    # block so that only the lookup itself is expected to fail.
    labels = [["bad-label"], ["label-2"]]
    with pytest.raises(ValueError):
        axis.coord_to_idx(coord=labels)