예제 #1
0
 def particle_list_filter(self):
     """Return a ``ParticleListFilter`` built from this object's context.

     The context is read from ``self.array_context.context``. The
     ``boxtree`` import happens at call time rather than at module import.
     """
     from boxtree.tree import ParticleListFilter

     context = self.array_context.context
     return ParticleListFilter(context)
예제 #2
0
def test_fmm_completeness(ctx_getter, dims, nsources_req, ntargets_req,
                          who_has_extent, source_gen, target_gen, filter_kind,
                          well_sep_is_n_away, extent_norm,
                          from_sep_smaller_crit):
    """Tests whether the built FMM traversal structures and driver completely
    capture all interactions.

    The test builds a tree and an FMM traversal, drives the FMM with
    all-ones source weights through one of the ``ConstantOneExpansionWrangler``
    variants, and asserts that the potential at every (retained) target
    matches the sum of the weights to within a scaled tolerance of ``1e-8``.

    :arg ctx_getter: called without arguments to obtain the context used for
        the command queue, tree builder, traversal builder and particle
        filter.
    :arg dims: forwarded to *source_gen* and *target_gen*.
    :arg nsources_req: requested source count, forwarded to *source_gen*.
        The count actually used is ``len(sources[0])``.
    :arg ntargets_req: requested target count, forwarded to *target_gen*, or
        *None* to pass ``targets=None`` to the tree builder ("same as
        sources").
    :arg who_has_extent: tested for membership of ``"s"`` (sources get random
        radii) and ``"t"`` (targets get random radii). Any truthy value also
        causes the traversal's close lists to be merged.
    :arg source_gen: callable invoked as
        ``source_gen(queue, nsources_req, dims, dtype, seed=15)``.
    :arg target_gen: callable invoked as
        ``target_gen(queue, ntargets_req, dims, dtype, seed=16)``.
    :arg filter_kind: falsy for no target filtering, or ``"user"`` / ``"tree"``
        to filter targets with random flags in user or tree order.
    :arg well_sep_is_n_away: forwarded to ``FMMTraversalBuilder``.
    :arg extent_norm: forwarded to the tree builder.
    :arg from_sep_smaller_crit: forwarded to ``FMMTraversalBuilder``.

    :raises ValueError: if *filter_kind* is truthy but neither ``"user"`` nor
        ``"tree"``.

    The test is skipped when particle generation raises :exc:`ImportError`.
    """

    sources_have_extent = "s" in who_has_extent
    targets_have_extent = "t" in who_has_extent

    logging.basicConfig(level=logging.INFO)

    ctx = ctx_getter()
    queue = cl.CommandQueue(ctx)

    dtype = np.float64

    # Particle generation may need an optional dependency; an ImportError
    # from either generator turns into a skip rather than a failure.
    try:
        sources = source_gen(queue, nsources_req, dims, dtype, seed=15)
        nsources = len(sources[0])

        if ntargets_req is None:
            # This says "same as sources" to the tree builder.
            targets = None
            ntargets = ntargets_req
        else:
            targets = target_gen(queue, ntargets_req, dims, dtype, seed=16)
            ntargets = len(targets[0])
    except ImportError:
        pytest.skip("loo.py not available, but needed for particle array "
                    "generation")

    # A single seeded generator is shared by the radii and (further down) the
    # filter flags, so the order of the draws below matters for
    # reproducibility.
    from pyopencl.clrandom import PhiloxGenerator
    rng = PhiloxGenerator(queue.context, seed=12)
    if sources_have_extent:
        # Radii are 2**u with u drawn using a=-10, b=0.
        source_radii = 2**rng.uniform(queue, nsources, dtype=dtype, a=-10, b=0)
    else:
        source_radii = None

    if targets_have_extent:
        target_radii = 2**rng.uniform(queue, ntargets, dtype=dtype, a=-10, b=0)
    else:
        target_radii = None

    from boxtree import TreeBuilder
    tb = TreeBuilder(ctx)

    tree, _ = tb(queue,
                 sources,
                 targets=targets,
                 max_particles_in_box=30,
                 source_radii=source_radii,
                 target_radii=target_radii,
                 debug=True,
                 stick_out_factor=0.25,
                 extent_norm=extent_norm)
    # Disabled debugging aid: plot the tree.
    if 0:
        tree.get().plot()
        import matplotlib.pyplot as pt
        pt.show()

    from boxtree.traversal import FMMTraversalBuilder
    tbuild = FMMTraversalBuilder(ctx,
                                 well_sep_is_n_away=well_sep_is_n_away,
                                 from_sep_smaller_crit=from_sep_smaller_crit)
    trav, _ = tbuild(queue, tree, debug=True)

    if who_has_extent:
        # Keep the unmerged traversal around; it is only read by the disabled
        # box-list plot further down.
        pre_merge_trav = trav
        trav = trav.merge_close_lists(queue)

    # Unit weights are used. A random alternative is kept for reference:
    #weights = np.random.randn(nsources)
    weights = np.ones(nsources)
    weights_sum = np.sum(weights)

    # The FMM driver below works on host-side copies of the traversal/tree.
    host_trav = trav.get(queue=queue)
    host_tree = host_trav.tree

    if who_has_extent:
        pre_merge_host_trav = pre_merge_trav.get(queue=queue)

    from boxtree.tree import ParticleListFilter
    plfilt = ParticleListFilter(ctx)

    if filter_kind:
        # One random flag per target (or per source when targets == sources),
        # drawn as int32 with a=0, b=2 and cast to int8.
        # NOTE(review): presumably this yields 0/1 flags - confirm against the
        # generator's integer semantics.
        flags = rng.uniform(queue, ntargets or nsources, np.int32, a=0, b=2) \
                .astype(np.int8)
        if filter_kind == "user":
            filtered_targets = plfilt.filter_target_lists_in_user_order(
                queue, tree, flags)
            wrangler = ConstantOneExpansionWranglerWithFilteredTargetsInUserOrder(
                host_tree, filtered_targets.get(queue=queue))
        elif filter_kind == "tree":
            filtered_targets = plfilt.filter_target_lists_in_tree_order(
                queue, tree, flags)
            wrangler = ConstantOneExpansionWranglerWithFilteredTargetsInTreeOrder(
                host_tree, filtered_targets.get(queue=queue))
        else:
            raise ValueError("unsupported value of 'filter_kind'")
    else:
        wrangler = ConstantOneExpansionWrangler(host_tree)
        # All-ones flags: without a filter these are only read by the
        # disabled plotting code at the end of the test.
        flags = cl.array.empty(queue, ntargets or nsources, dtype=np.int8)
        flags.fill(1)

    if ntargets is None and not filter_kind:
        # This check only works for targets == sources.
        # NOTE(review): with all-ones weights this round trip likely cannot
        # detect a wrong permutation - confirm whether that is intended.
        assert (wrangler.reorder_potentials(
            wrangler.reorder_sources(weights)) == weights).all()

    from boxtree.fmm import drive_fmm
    pot = drive_fmm(host_trav, wrangler, weights)

    if filter_kind:
        # Only compare potentials at targets whose flag is set.
        pot = pot[flags.get() > 0]

    # Norm of the per-target deviation from the expected sum of weights,
    # scaled by the number of sources.
    rel_err = la.norm((pot - weights_sum) / nsources)
    good = rel_err < 1e-8

    # {{{ build, evaluate matrix (and identify incorrect interactions)

    # Disabled debugging aid: run the FMM once per unit source vector to
    # assemble the full interaction matrix; entries different from 1 are
    # reported as incorrect interactions.
    # NOTE(review): ntargets is None when ntargets_req is None, so the
    # allocation below would need adjusting before re-enabling in that case.
    if 0 and not good:
        mat = np.zeros((ntargets, nsources), dtype)
        from pytools import ProgressBar

        # Quieten logging while the FMM is driven once per source.
        logging.getLogger().setLevel(logging.WARNING)

        pb = ProgressBar("matrix", nsources)
        for i in range(nsources):
            unit_vec = np.zeros(nsources, dtype=dtype)
            unit_vec[i] = 1
            mat[:, i] = drive_fmm(host_trav, wrangler, unit_vec)
            pb.progress()
        pb.finished()

        logging.getLogger().setLevel(logging.INFO)

        import matplotlib.pyplot as pt

        if 0:
            pt.imshow(mat)
            pt.colorbar()
            pt.show()

        incorrect_tgts, incorrect_srcs = np.where(mat != 1)

        if 1 and len(incorrect_tgts):
            from boxtree.visualization import TreePlotter
            plotter = TreePlotter(host_tree)
            plotter.draw_tree(fill=False, edgecolor="black")
            plotter.draw_box_numbers()
            plotter.set_bounding_box()

            # Map the offending indices to tree order, then to box numbers.
            tree_order_incorrect_tgts = \
                    host_tree.indices_to_tree_target_order(incorrect_tgts)
            tree_order_incorrect_srcs = \
                    host_tree.indices_to_tree_source_order(incorrect_srcs)

            src_boxes = [
                host_tree.find_box_nr_for_source(i)
                for i in tree_order_incorrect_srcs
            ]
            tgt_boxes = [
                host_tree.find_box_nr_for_target(i)
                for i in tree_order_incorrect_tgts
            ]
            print(src_boxes)
            print(tgt_boxes)

            # plot all sources/targets
            if 0:
                pt.plot(host_tree.targets[0],
                        host_tree.targets[1],
                        "v",
                        alpha=0.9)
                pt.plot(host_tree.sources[0],
                        host_tree.sources[1],
                        "gx",
                        alpha=0.9)

            # plot offending sources/targets
            if 0:
                pt.plot(host_tree.targets[0][tree_order_incorrect_tgts],
                        host_tree.targets[1][tree_order_incorrect_tgts], "rv")
                pt.plot(host_tree.sources[0][tree_order_incorrect_srcs],
                        host_tree.sources[1][tree_order_incorrect_srcs], "go")
            pt.gca().set_aspect("equal")

            # The unmerged traversal is plotted when close lists were merged.
            # NOTE(review): 22 is a hard-coded argument, presumably the box
            # number under inspection - TODO confirm.
            from boxtree.visualization import draw_box_lists
            draw_box_lists(
                plotter, pre_merge_host_trav if who_has_extent else host_trav,
                22)
            # from boxtree.visualization import draw_same_level_non_well_sep_boxes
            # draw_same_level_non_well_sep_boxes(plotter, host_trav, 2)

            pt.show()

    # }}}

    # Disabled debugging aid: plot the per-target deviation from the
    # expected potential.
    if 0 and not good:
        import matplotlib.pyplot as pt
        pt.plot(pot - weights_sum)
        pt.show()

    # Disabled debugging aid: plot the retained targets and highlight those
    # whose potential is off by at least 1e-3.
    if 0 and not good:
        import matplotlib.pyplot as pt
        filt_targets = [
            host_tree.targets[0][flags.get() > 0],
            host_tree.targets[1][flags.get() > 0],
        ]
        host_tree.plot()
        bad = np.abs(pot - weights_sum) >= 1e-3
        bad_targets = [
            filt_targets[0][bad],
            filt_targets[1][bad],
        ]
        print(bad_targets[0].shape)
        pt.plot(filt_targets[0], filt_targets[1], "x")
        pt.plot(bad_targets[0], bad_targets[1], "v")
        pt.show()

    assert good