Пример #1
0
 def test_projection_plot_wf(self):
     """Create and save an axis-0 density projection for each weight field.

     Iterates over WEIGHT_FIELDS and passes every entry as ``weight_field``;
     each plot is written with a no-argument ``save()`` call.
     """
     dataset = fake_random_ds(16)
     density = ("gas", "density")
     for weight in WEIGHT_FIELDS:
         plot = ProjectionPlot(dataset, 0, density, weight_field=weight)
         plot.save()
Пример #2
0
 def test_projection_plot_m(self):
     """Create and save an axis-0 density projection for each method.

     Iterates over PROJECTION_METHODS and passes every entry as ``method``;
     each plot is written with a no-argument ``save()`` call.
     """
     dataset = fake_random_ds(16)
     density = ("gas", "density")
     for proj_method in PROJECTION_METHODS:
         plot = ProjectionPlot(dataset, 0, density, method=proj_method)
         plot.save()
Пример #3
0
def test_ds_arr_invariance_under_projection_plot(tmp_path):
    """Check that ``annotate_line`` leaves its endpoint arrays unchanged.

    Two arrays created with ``ds.arr`` are copied, handed to
    ``annotate_line`` on a projection plot, and compared against the copies
    (values and units) after the plot has been saved to ``tmp_path``.
    """
    density = np.random.random((10, 10, 10))
    domain_bounds = np.array([[-100, 100], [-100, 100], [-100, 100]])
    ds = load_uniform_grid(
        {("gas", "density"): (density, "g*cm**(-3)")},
        density.shape,
        length_unit="kpc",
        bbox=domain_bounds,
    )

    start = ds.arr(np.array((0, 0, -0.5)), "unitary")
    end = ds.arr(np.array((0, 0, 0.5)), "unitary")

    # Snapshots taken before the plot gets a chance to touch the inputs.
    start_before = start.copy()
    end_before = end.copy()

    plot = ProjectionPlot(ds, 0, "number_density")
    plot.annotate_line(start, end)
    plot.save(tmp_path)

    # for lack of a unyt.testing.assert_unit_array_equal function,
    # values and units are compared separately
    for before, after in ((start_before, start), (end_before, end)):
        np.testing.assert_array_equal(before, after)
        assert before.units == after.units
Пример #4
0
 def test_projection_plot_c(self):
     """Create and save an axis-0 density projection for each center spec.

     Iterates over CENTER_SPECS and passes every entry as ``center``;
     each plot is written with a no-argument ``save()`` call.
     """
     dataset = fake_random_ds(16)
     density = ("gas", "density")
     for center_spec in CENTER_SPECS:
         plot = ProjectionPlot(dataset, 0, density, center=center_spec)
         plot.save()
Пример #5
0
 def test_projection_plot_ds(self):
     """Create and save a density projection along each axis using a region.

     A single region object built from the dataset is supplied as
     ``data_source`` for the projections along axes 0, 1 and 2.
     """
     dataset = fake_random_ds(16)
     sub_region = dataset.region([0.5] * 3, [0.4] * 3, [0.6] * 3)
     density = ("gas", "density")
     for axis in (0, 1, 2):
         plot = ProjectionPlot(dataset, axis, density, data_source=sub_region)
         plot.save()
Пример #6
0
    def __call__(self, args):
        """Build and save a slice or projection plot for each requested axis.

        Parameters
        ----------
        args
            Parsed argument namespace. Attributes read here: ``ds``,
            ``center``, ``max``, ``axis``, ``unit``, ``width``,
            ``projection``, ``field``, ``weight``, ``grids``, ``time``,
            ``show_scale_bar``, ``field_unit``, ``cmap``, ``takelog``,
            ``zlim`` and ``output``.

        Notes
        -----
        One image per axis is written under ``args.output``, using the
        string form of the dataset as the file name stem.
        """
        ds = args.ds
        center = args.center

        # (-1, -1, -1) is treated as "no center supplied": fall back to the
        # location of the density maximum, same as an explicit args.max.
        seek_center = args.center == (-1, -1, -1)
        if seek_center:
            mylog.info("No center fed in; seeking.")
        if args.max or seek_center:
            # Look the maximum up once even when both conditions hold.
            _, center = ds.find_max("density")
        elif args.center is None:
            center = 0.5 * (ds.domain_left_edge + ds.domain_right_edge)
        center = np.array(center)

        if ds.dimensionality < 3:
            # Pick the first axis whose active dimension on the first grid
            # is at most one cell wide.
            dummy_dimensions = np.nonzero(
                ds.index.grids[0].ActiveDimensions <= 1)
            axes = dummy_dimensions[0][0]
        elif args.axis == 4:
            axes = range(3)
        else:
            axes = args.axis

        # Default the width unit so that a width given without a unit does
        # not produce a (value, None) tuple.
        unit = "unitary" if args.unit is None else args.unit
        width = None if args.width is None else (args.width, unit)

        for ax in always_iterable(axes):
            mylog.info("Adding plot for axis %i", ax)
            if args.projection:
                plot = ProjectionPlot(
                    ds,
                    ax,
                    args.field,
                    center=center,
                    width=width,
                    weight_field=args.weight,
                )
            else:
                plot = SlicePlot(
                    ds, ax, args.field, center=center, width=width)

            # Optional annotations.
            if args.grids:
                plot.annotate_grids()
            if args.time:
                plot.annotate_timestamp()
            if args.show_scale_bar:
                plot.annotate_scale()

            if args.field_unit:
                plot.set_unit(args.field, args.field_unit)

            plot.set_cmap(args.field, args.cmap)
            plot.set_log(args.field, args.takelog)
            if args.zlim:
                plot.set_zlim(args.field, *args.zlim)

            ensure_dir_exists(args.output)
            plot.save(os.path.join(args.output, f"{ds}"))
Пример #7
0
    def __call__(self, args):
        """Build and save a slice or projection plot for each requested axis.

        Parameters
        ----------
        args
            Parsed argument namespace. Attributes read here: ``ds``,
            ``center``, ``max``, ``axis``, ``unit``, ``width``,
            ``projection``, ``field``, ``weight``, ``grids``, ``time``,
            ``cmap``, ``takelog``, ``zlim`` and ``output``.

        Notes
        -----
        One image per axis is written under ``args.output``, using the
        string form of the dataset as the file name stem.
        """
        ds = args.ds
        center = args.center

        # (-1, -1, -1) is treated as "no center supplied": fall back to the
        # location of the density maximum, same as an explicit args.max.
        seek_center = args.center == (-1, -1, -1)
        if seek_center:
            mylog.info("No center fed in; seeking.")
        if args.max or seek_center:
            # Look the maximum up once even when both conditions hold.
            _, center = ds.find_max("density")
        elif args.center is None:
            center = 0.5 * (ds.domain_left_edge + ds.domain_right_edge)
        center = np.array(center)

        if ds.dimensionality < 3:
            # Pick the first axis whose active dimension on the first grid
            # is at most one cell wide.
            dummy_dimensions = np.nonzero(
                ds.index.grids[0].ActiveDimensions <= 1)
            axes = ensure_list(dummy_dimensions[0][0])
        elif args.axis == 4:
            axes = range(3)
        else:
            axes = [args.axis]

        # Default the width unit so that a width given without a unit does
        # not produce a (value, None) tuple.
        unit = 'unitary' if args.unit is None else args.unit
        width = None if args.width is None else (args.width, unit)

        for ax in axes:
            mylog.info("Adding plot for axis %i", ax)
            if args.projection:
                plot = ProjectionPlot(ds,
                                      ax,
                                      args.field,
                                      center=center,
                                      width=width,
                                      weight_field=args.weight)
            else:
                plot = SlicePlot(
                    ds, ax, args.field, center=center, width=width)

            if args.grids:
                plot.annotate_grids()
            if args.time:
                # Stamp the current simulation time, converted to years.
                time = ds.current_time.in_units("yr")
                plot.annotate_text((0.2, 0.8), 't = %5.2e yr' % time)

            plot.set_cmap(args.field, args.cmap)
            plot.set_log(args.field, args.takelog)
            if args.zlim:
                plot.set_zlim(args.field, *args.zlim)

            ensure_dir_exists(args.output)
            plot.save(os.path.join(args.output, str(ds)))
Пример #8
0
def test_sph_particle_filter_plotting():
    """Project a field of a particle-filtered type and save the plot.

    A ``central_gas`` filter is registered on a fake SPH dataset, and a
    projection of ``("central_gas", "density")`` along ``z`` is saved from
    inside a throwaway working directory.
    """
    ds = fake_sph_grid_ds()

    @particle_filter("central_gas", requires=["particle_position"], filtered_type="io")
    def _filter(pfilter, data):
        # Keep particles whose absolute coordinates are all below 1.6.
        coords = np.abs(data[pfilter.filtered_type, "particle_position"])
        return (
            (coords[:, 0] < 1.6) & (coords[:, 1] < 1.6) & (coords[:, 2] < 1.6))

    ds.add_particle_filter("central_gas")

    plot = ProjectionPlot(ds, 'z', ('central_gas', 'density'))
    tmpdir = tempfile.mkdtemp()
    curdir = os.getcwd()
    try:
        # save() is called without a path, so run it from the temp directory.
        os.chdir(tmpdir)
        plot.save()
    finally:
        # Restore the working directory and remove the temp directory even
        # when saving fails, so later tests are not affected.
        os.chdir(curdir)
        shutil.rmtree(tmpdir)
Пример #9
0
 def test_projection_plot(self):
     """Save a density projection along each axis under every test filename.

     For each entry of TEST_FLNMS the first element returned by ``save`` is
     passed to ``assert_fname``.
     """
     dataset = fake_random_ds(16)
     density = ("gas", "density")
     for axis in (0, 1, 2):
         plot = ProjectionPlot(dataset, axis, density)
         for target in TEST_FLNMS:
             written = plot.save(target)
             assert_fname(written[0])