Пример #1
0
 def test_projection_plot_wf(self):
     """Save an on-axis density projection once for every entry in WEIGHT_FIELDS."""
     dataset = fake_random_ds(16)
     for weight in WEIGHT_FIELDS:
         plot = ProjectionPlot(
             dataset, 0, ("gas", "density"), weight_field=weight
         )
         plot.save()
def test_plot_data():
    """Round-trip plot data sources through ``save_as_dataset`` and re-plot them.

    For a z-slice, a z-projection and an off-axis slice of a fake dataset, the
    plot's data source is written to its own HDF5 file, reloaded with ``load``,
    plotted again with the same plot class, and the saved image filename is
    checked with ``assert_fname``. All files are written inside a temporary
    directory that is removed afterwards.
    """
    tmpdir = tempfile.mkdtemp()
    curdir = os.getcwd()
    os.chdir(tmpdir)
    try:
        ds = fake_random_ds(16)
        # (plot class, axis or normal vector, file name). Each reloaded dataset
        # comes from the file written for the same plot type, so the projection
        # is rebuilt from 'proj.h5' rather than 'slice.h5'.
        cases = [
            (SlicePlot, 'z', 'slice.h5'),
            (ProjectionPlot, 'z', 'proj.h5'),
            (SlicePlot, [1, 1, 1], 'oas.h5'),
        ]
        for plot_cls, axis, filename in cases:
            plot = plot_cls(ds, axis, 'density')
            plot.data_source.save_as_dataset(filename)
            reloaded = load(filename)
            fn = plot_cls(reloaded, axis, 'density').save()
            assert_fname(fn[0])
    finally:
        # Restore the working directory and delete the scratch files even when
        # an assertion above fails.
        os.chdir(curdir)
        shutil.rmtree(tmpdir)
Пример #3
0
 def test_projection_plot_c(self):
     """Save an on-axis density projection once for every entry in CENTER_SPECS."""
     dataset = fake_random_ds(16)
     for spec in CENTER_SPECS:
         plot = ProjectionPlot(dataset, 0, ("gas", "density"), center=spec)
         plot.save()
Пример #4
0
def test_field_access():
    """Tuple field keys and ``ds.fields`` attributes must be interchangeable.

    Every data-object kind must return equal arrays for both spellings of the
    density field, and every plot kind must accept either spelling when its
    plots are set up.
    """
    ds = fake_random_ds(16)
    density = ("gas", "density")
    radius = ("index", "radius")

    everything = ds.all_data()
    containers = [
        everything,
        ds.sphere(ds.domain_center, 0.25),
        ds.covering_grid(0, ds.domain_left_edge, ds.domain_dimensions),
        ds.smoothed_covering_grid(0, ds.domain_left_edge, ds.domain_dimensions),
        ds.slice(0, ds.domain_center[0]),
        ds.proj(density, 0),
        create_profile(everything, radius, density),
    ]
    for container in containers:
        assert_equal(container[density], container[ds.fields.gas.density])

    for field in (density, ds.fields.gas.density):
        source = ds.all_data()
        profile_plot = ProfilePlot(source, radius, field)
        phase_plot = PhasePlot(source, radius, field, ("gas", "cell_mass"))
        image_plots = [
            SlicePlot(ds, 2, field),
            SlicePlot(ds, [1, 1, 1], field),
            ProjectionPlot(ds, 2, field),
            ProjectionPlot(ds, [1, 1, 1], field),
        ]
        for plot in image_plots + [profile_plot, phase_plot]:
            plot._setup_plots()
            # Plots that carry an ``_frb`` buffer must also be indexable by the field.
            if hasattr(plot, "_frb"):
                plot._frb[field]
Пример #5
0
 def test_projection_plot_m(self):
     """Save an on-axis density projection once for every entry in PROJECTION_METHODS."""
     dataset = fake_random_ds(16)
     for projection_method in PROJECTION_METHODS:
         plot = ProjectionPlot(
             dataset, 0, ("gas", "density"), method=projection_method
         )
         plot.save()
Пример #6
0
def test_plot_data():
    """Round-trip plot data sources through ``save_as_dataset`` and re-plot them.

    For a z-slice, a z-projection and an off-axis slice of a fake dataset, the
    plot's data source is written to HDF5, reloaded from the filename returned
    by ``save_as_dataset``, plotted again with the same plot class, and the
    saved image filename is checked with ``assert_fname``. All files are
    written inside a temporary directory that is removed afterwards.
    """
    tmpdir = tempfile.mkdtemp()
    curdir = os.getcwd()
    os.chdir(tmpdir)
    try:
        ds = fake_random_ds(16)
        field = ("gas", "density")
        # (plot class, axis or normal vector, requested file name)
        cases = [
            (SlicePlot, "z", "slice.h5"),
            (ProjectionPlot, "z", "proj.h5"),
            (SlicePlot, [1, 1, 1], "oas.h5"),
        ]
        for plot_cls, axis, filename in cases:
            plot = plot_cls(ds, axis, field)
            # Reload from whatever name save_as_dataset reports having written.
            saved = plot.data_source.save_as_dataset(filename)
            reloaded = load(saved)
            fn = plot_cls(reloaded, axis, field).save()
            assert_fname(fn[0])
    finally:
        # Always restore the working directory and remove the scratch files,
        # even when an assertion above fails. mkdtemp always returns a freshly
        # created directory, so it is always safe to remove here.
        os.chdir(curdir)
        shutil.rmtree(tmpdir)
Пример #7
0
 def test_projection_plot_ds(self):
     """Save density projections along each axis, restricted to a region data source."""
     dataset = fake_random_ds(16)
     # Arguments are presumably (center, left_edge, right_edge): a small cube.
     region = dataset.region([0.5] * 3, [0.4] * 3, [0.6] * 3)
     for axis in range(3):
         plot = ProjectionPlot(
             dataset, axis, ("gas", "density"), data_source=region
         )
         plot.save()
Пример #8
0
def test_dispatch_plot_classes():
    """``ProjectionPlot``/``SlicePlot`` dispatch on the kind of axis argument.

    An axis name yields the axis-aligned subclass; a 3-vector yields the
    off-axis subclass.
    """
    ds = fake_random_ds(16)
    field = ("gas", "density")
    expectations = [
        (ProjectionPlot(ds, "z", field), AxisAlignedProjectionPlot),
        (ProjectionPlot(ds, (1, 2, 3), field), OffAxisProjectionPlot),
        (SlicePlot(ds, "z", field), AxisAlignedSlicePlot),
        (SlicePlot(ds, (1, 2, 3), field), OffAxisSlicePlot),
    ]
    for plot, expected_cls in expectations:
        assert isinstance(plot, expected_cls)
Пример #9
0
    def setUp(self):
        """Apply test-specific ``ytcfg`` overrides and build the shared plot.

        Before overriding, the current value of every override key that already
        exists in ``ytcfg`` is copied into ``self.oldConfig``. Keys missing from
        ``ytcfg`` are skipped. The overrides themselves are kept in
        ``self.newConfig``. The fake dataset and its projection plot are only
        created when ``self.ds`` is still None.
        """
        from yt.config import ytcfg

        overrides = {
            ("yt", "default_colormap"): "viridis",
            ("plot", "gas", "log"): False,
            ("plot", "gas", "density", "units"): "lb/yard**3",
            ("plot", "gas", "density", "path_length_units"): "mile",
            ("plot", "gas", "density", "cmap"): "plasma",
            ("plot", "gas", "temperature", "log"): True,
            ("plot", "gas", "temperature", "linthresh"): 100,
            ("plot", "gas", "temperature", "cmap"): "hot",
            ("plot", "gas", "pressure", "log"): True,
            ("plot", "index", "radius", "linthresh"): 1e3,
        }

        # Snapshot the existing values before anything is overwritten.
        saved = {}
        for key in overrides:
            try:
                saved[key] = ytcfg[key]
            except KeyError:
                continue
        for key in overrides:
            ytcfg[key] = overrides[key]

        self.oldConfig = saved
        self.newConfig = overrides

        fields = [("gas", "density"), ("gas", "temperature"), ("gas", "pressure")]
        units = ["g/cm**3", "K", "dyn/cm**2"]
        if self.ds is None:
            self.ds = fake_random_ds(16, fields=fields, units=units)
            self.slc = ProjectionPlot(self.ds, 0, fields + [("index", "radius")])
Пример #10
0
 def __call__(self, args):
     """Serve a pannable map of a slice or projection of ``args.ds`` with bottle.

     Returns -1 immediately on Python 3, where this command is disabled.
     Prints a message and returns None when ``args.axis`` is 4, since multiple
     axes are not supported. ``args.host`` may be ``"host"`` (port 8080) or
     ``"host:port"``.
     """
     if sys.version_info >= (3, 0, 0):
         print("yt mapserver is disabled for Python 3.")
         return -1
     ds = args.ds
     if args.axis == 4:
         print("Doesn't work with multiple axes!")
         return
     if args.projection:
         p = ProjectionPlot(ds,
                            args.axis,
                            args.field,
                            weight_field=args.weight)
     else:
         p = SlicePlot(ds, args.axis, args.field)
     from yt.visualization.mapserver.pannable_map import PannableMapServer
     PannableMapServer(p.data_source, args.field)
     import yt.extern.bottle as bottle
     bottle.debug(True)
     # bottle's directory is put on sys.path only while the server runs
     # (presumably so the 'rocket' backend can be imported from there).
     bottle_dir = os.path.dirname(bottle.__file__)
     sys.path.append(bottle_dir)
     try:
         if args.host is not None:
             colonpl = args.host.find(":")
             if colonpl >= 0:
                 port = int(args.host.split(":")[-1])
                 args.host = args.host[:colonpl]
             else:
                 port = 8080
             bottle.run(server='rocket', host=args.host, port=port)
         else:
             bottle.run(server='rocket')
     finally:
         # Undo the sys.path change even if port parsing or the server raises.
         sys.path.remove(bottle_dir)
Пример #11
0
def test_old_plot_data():
    """Plot pre-generated slice, projection and off-axis slice datasets.

    Each file under ``ytdata_dir`` (presumably written by an earlier yt
    version) is loaded with ``data_dir_load`` and re-plotted, and the saved
    image filename is checked with ``assert_fname``. Images are written
    inside a temporary directory that is removed afterwards.
    """
    tmpdir = tempfile.mkdtemp()
    curdir = os.getcwd()
    os.chdir(tmpdir)
    try:
        field = ("gas", "density")
        # (plot class, stored file name, axis or normal vector)
        cases = [
            (SlicePlot, "slice.h5", "z"),
            (ProjectionPlot, "proj.h5", "z"),
            (SlicePlot, "oas.h5", [1, 1, 1]),
        ]
        for plot_cls, filename, axis in cases:
            ds = data_dir_load(os.path.join(ytdata_dir, filename))
            fn = plot_cls(ds, axis, field).save()
            assert_fname(fn[0])
    finally:
        # Restore the working directory and delete the scratch images even
        # when an assertion above fails.
        os.chdir(curdir)
        shutil.rmtree(tmpdir)
Пример #12
0
    def test_projection_plot_bs(self):
        """The projection image shape matches each requested ``buff_size``."""
        dataset = fake_random_ds(16)
        for size in BUFF_SIZES:
            plot = ProjectionPlot(dataset, 0, ("gas", "density"), buff_size=size)
            pixels = plot.frb["gas", "density"]
            # The array shape is ordered in reverse relative to buff_size,
            # so flip it before comparing.
            assert_equal(pixels.shape[::-1], size)
Пример #13
0
def test_sph_particle_filter_plotting():
    """Project a field of a particle-filtered SPH type and save the image.

    Registers a ``central_gas`` filter over ``io`` particles. The filter keeps
    particles whose absolute position components are all below 1.6. It then
    saves a z-projection of ``('central_gas', 'density')`` inside a temporary
    directory that is removed afterwards.
    """
    ds = fake_sph_grid_ds()

    @particle_filter("central_gas", requires=["particle_position"], filtered_type="io")
    def _filter(pfilter, data):
        coords = np.abs(data[pfilter.filtered_type, "particle_position"])
        # Keep particles with |x|, |y| and |z| all below 1.6.
        return (
            (coords[:, 0] < 1.6) & (coords[:, 1] < 1.6) & (coords[:, 2] < 1.6))

    ds.add_particle_filter("central_gas")

    plot = ProjectionPlot(ds, 'z', ('central_gas', 'density'))
    tmpdir = tempfile.mkdtemp()
    curdir = os.getcwd()
    os.chdir(tmpdir)
    try:
        plot.save()
    finally:
        # Restore the working directory and clean up even if saving fails.
        os.chdir(curdir)
        shutil.rmtree(tmpdir)
Пример #14
0
def test_ds_arr_invariance_under_projection_plot(tmp_path):
    """``annotate_line`` plus ``save`` must not mutate the endpoint arrays passed in.

    Copies of the unit-carrying endpoints are taken before plotting and are
    compared, both values and units, with the originals afterwards.
    """
    values = np.random.random((10, 10, 10))
    ds = load_uniform_grid(
        {("gas", "density"): (values, "g*cm**(-3)")},
        values.shape,
        length_unit="kpc",
        bbox=np.array([[-100, 100], [-100, 100], [-100, 100]]),
    )

    start = ds.arr(np.array((0, 0, -0.5)), "unitary")
    end = ds.arr(np.array((0, 0, 0.5)), "unitary")
    snapshots = [(start, start.copy()), (end, end.copy())]

    plot = ProjectionPlot(ds, 0, "number_density")
    plot.annotate_line(start, end)
    plot.save(tmp_path)

    # There is no unyt.testing.assert_unit_array_equal, so compare values and
    # units separately.
    for current, before in snapshots:
        np.testing.assert_array_equal(before, current)
        assert before.units == current.units
Пример #15
0
    def test_creation_with_width(self):
        """Projection limits, width and axis unit names follow each width spec.

        For every key of ``WIDTH_SPECS``, the expected x/y limits and plot width
        are ``(value, unit)`` pairs. They are converted with ``plot.ds.quan`` and
        compared with the plot to 14 decimal places. The expected axis unit
        names must equal ``plot._axes_unit_names``.
        """
        test_ds = fake_random_ds(16)
        for width in WIDTH_SPECS:
            xlim, ylim, pwidth, aun = WIDTH_SPECS[width]
            plot = ProjectionPlot(test_ds, 0, ("gas", "density"), width=width)
            quan = plot.ds.quan

            xlim = [quan(el[0], el[1]) for el in xlim]
            ylim = [quan(el[0], el[1]) for el in ylim]
            pwidth = [quan(el[0], el[1]) for el in pwidth]

            # Plain loops: these calls are made only for their assertions, so
            # no throwaway list should be built.
            for px, x in zip(plot.xlim, xlim):
                assert_array_almost_equal(px, x, 14)
            for py, y in zip(plot.ylim, ylim):
                assert_array_almost_equal(py, y, 14)
            for pw, w in zip(plot.width, pwidth):
                assert_array_almost_equal(pw, w, 14)
            assert_true(aun == plot._axes_unit_names)
Пример #16
0
    def __call__(self, args):
        """Save a slice or projection image of ``args.field`` for each requested axis.

        Center selection:
        - ``ds.find_max("density")`` is used when ``args.center`` is
          ``(-1, -1, -1)`` or when ``args.max`` is set.
        - The domain midpoint is used when ``args.center`` is None.
        - Otherwise ``args.center`` is used as given.

        Axes: datasets with fewer than 3 dimensions use a single axis taken from
        the first grid. ``args.axis == 4`` means all three axes.

        Width: ``args.width`` is interpreted in ``args.unit``, or in
        ``'unitary'`` when no unit is given. Images are saved to
        ``os.path.join(args.output, str(ds))``.
        """
        ds = args.ds
        center = args.center
        if args.center == (-1, -1, -1):
            mylog.info("No center fed in; seeking.")
            v, center = ds.find_max("density")
        if args.max:
            v, center = ds.find_max("density")
        elif args.center is None:
            center = 0.5 * (ds.domain_left_edge + ds.domain_right_edge)
        center = np.array(center)
        if ds.dimensionality < 3:
            # Use the first axis along which the first grid is at most one cell thick.
            dummy_dimensions = np.nonzero(
                ds.index.grids[0].ActiveDimensions <= 1)
            axes = ensure_list(dummy_dimensions[0][0])
        elif args.axis == 4:
            axes = range(3)
        else:
            axes = [args.axis]

        unit = args.unit
        if unit is None:
            unit = 'unitary'
        if args.width is None:
            width = None
        else:
            # Use the defaulted ``unit``; ``args.unit`` itself may be None.
            width = (args.width, unit)

        for ax in axes:
            mylog.info("Adding plot for axis %i", ax)
            if args.projection:
                plt = ProjectionPlot(ds,
                                     ax,
                                     args.field,
                                     center=center,
                                     width=width,
                                     weight_field=args.weight)
            else:
                plt = SlicePlot(ds, ax, args.field, center=center, width=width)
            if args.grids:
                plt.annotate_grids()
            if args.time:
                time = ds.current_time.in_units("yr")
                plt.annotate_text((0.2, 0.8), 't = %5.2e yr' % time)

            plt.set_cmap(args.field, args.cmap)
            plt.set_log(args.field, args.takelog)
            if args.zlim:
                plt.set_zlim(args.field, *args.zlim)
            ensure_dir_exists(args.output)
            plt.save(os.path.join(args.output, "%s" % (ds)))
Пример #17
0
    def __call__(self, args):
        """Serve an interactive pannable map of ``args.field`` with bottle.

        A slice, or a projection when ``args.projection`` is set, is handed to
        ``PannableMapServer``. It is then served via ``bottle.run(server="auto")``.
        When both ``args.center`` and ``args.width`` are truthy, only the cube
        of side ``args.width`` around ``args.center`` is used. For RAMSES
        outputs, only that bounding box is loaded. Prints a message and returns
        early when ``args.axis`` is 4 or more.

        Raises ImportError when the ``bottle`` package is not installed.
        """
        from yt.frontends.ramses.data_structures import RAMSESDataset
        from yt.visualization.mapserver.pannable_map import PannableMapServer

        def _corners():
            # [left_edge, right_edge] of the cube of side args.width centered on args.center.
            half = args.width / 2
            return [
                [c - half for c in args.center],
                [c + half for c in args.center],
            ]

        # For RAMSES datasets, use the bbox feature to make the dataset load faster
        if RAMSESDataset._is_valid(args.ds) and args.center and args.width:
            ds = _fix_ds(args.ds, bbox=_corners())
        else:
            ds = _fix_ds(args.ds)

        if args.center and args.width:
            center, width = args.center, args.width
            left_edge, right_edge = _corners()
            ad = ds.box(left_edge=left_edge, right_edge=right_edge)
        else:
            center, width = [0.5] * 3, 1.0
            ad = ds.all_data()

        if args.axis >= 4:
            print("Doesn't work with multiple axes!")
            return

        region_kwargs = dict(data_source=ad, center=center, width=width)
        if args.projection:
            p = ProjectionPlot(
                ds, args.axis, args.field, weight_field=args.weight, **region_kwargs
            )
        else:
            p = SlicePlot(ds, args.axis, args.field, **region_kwargs)
        p.set_log("all", args.takelog)
        p.set_cmap("all", args.cmap)

        PannableMapServer(p.data_source, args.field, args.takelog, args.cmap)
        try:
            import bottle
        except ImportError as e:
            raise ImportError(
                "The mapserver functionality requires the bottle "
                "package to be installed. Please install using `pip "
                "install bottle`."
            ) from e
        bottle.debug(True)

        if args.host is None:
            bottle.run(server="auto")
            return
        if ":" in args.host:
            # Port is the text after the last ':'; host is the text before the first ':'.
            port = int(args.host.rpartition(":")[2])
            args.host = args.host.partition(":")[0]
        else:
            port = 8080
        bottle.run(server="auto", host=args.host, port=port)
Пример #18
0
    def __call__(self, args):
        """Save a slice or projection image of ``args.field`` for each requested axis.

        Center selection:
        - ``ds.find_max("density")`` is used when ``args.center`` is
          ``(-1, -1, -1)`` or when ``args.max`` is set.
        - The domain midpoint is used when ``args.center`` is None.
        - Otherwise ``args.center`` is used as given.

        Axes: datasets with fewer than 3 dimensions use a single axis taken from
        the first grid. ``args.axis == 4`` means all three axes.

        Width: ``args.width`` is interpreted in ``args.unit``, or in
        ``"unitary"`` when no unit is given. Images are saved to
        ``os.path.join(args.output, f"{ds}")``.
        """
        ds = args.ds
        center = args.center
        if args.center == (-1, -1, -1):
            mylog.info("No center fed in; seeking.")
            v, center = ds.find_max("density")
        if args.max:
            v, center = ds.find_max("density")
        elif args.center is None:
            center = 0.5 * (ds.domain_left_edge + ds.domain_right_edge)
        center = np.array(center)
        if ds.dimensionality < 3:
            # Use the first axis along which the first grid is at most one cell thick.
            dummy_dimensions = np.nonzero(
                ds.index.grids[0].ActiveDimensions <= 1)
            axes = dummy_dimensions[0][0]
        elif args.axis == 4:
            axes = range(3)
        else:
            axes = args.axis

        unit = args.unit
        if unit is None:
            unit = "unitary"
        if args.width is None:
            width = None
        else:
            # Use the defaulted ``unit``; ``args.unit`` itself may be None.
            width = (args.width, unit)

        for ax in always_iterable(axes):
            mylog.info("Adding plot for axis %i", ax)
            if args.projection:
                plt = ProjectionPlot(
                    ds,
                    ax,
                    args.field,
                    center=center,
                    width=width,
                    weight_field=args.weight,
                )
            else:
                plt = SlicePlot(ds, ax, args.field, center=center, width=width)
            if args.grids:
                plt.annotate_grids()
            if args.time:
                plt.annotate_timestamp()
            if args.show_scale_bar:
                plt.annotate_scale()

            if args.field_unit:
                plt.set_unit(args.field, args.field_unit)

            plt.set_cmap(args.field, args.cmap)
            plt.set_log(args.field, args.takelog)
            if args.zlim:
                plt.set_zlim(args.field, *args.zlim)
            ensure_dir_exists(args.output)
            plt.save(os.path.join(args.output, f"{ds}"))
Пример #19
0
 def test_projection_plot(self):
     """Save density projections along each axis under every name in TEST_FLNMS."""
     dataset = fake_random_ds(16)
     for axis in range(3):
         plot = ProjectionPlot(dataset, axis, ("gas", "density"))
         for filename in TEST_FLNMS:
             saved = plot.save(filename)
             assert_fname(saved[0])