def vis(arr, height_func, crit_pts, regions, step=True, new_fig=False):
    """Fill contour-tree regions on a copy of ``arr`` and display them.

    Regions are filled in descending order of the heights of their two
    endpoint nodes, as computed by ``height_func``. Every point of a filled
    region is set to ``2 * arr.max()``. ``arr`` itself is not modified.

    :param arr: array to draw; it is indexed as ``arr[X, Y]``, so presumably
        2-D -- TODO confirm.
    :param height_func: maps a node to a sortable height key.
    :param crit_pts: forwarded unchanged to ``visualize``.
    :param regions: mapping from ``(a, b)`` node pairs to iterables of
        ``(x, y)`` points.
    :param step: if true, draw and pause before filling each region, and
        once more after the last region is filled. Otherwise draw only the
        final state.
    :param new_fig: forwarded to ``visualize`` in step mode. The non-step
        display always passes ``new_fig=True``.
    """
    try:
        wait = raw_input  # Python 2
    except NameError:  # Python 3 (Python 2's ``input`` would eval the reply)
        wait = input

    def region_heights(key):
        # Unpack in the body: ``def f((a, b))`` is a SyntaxError on Python 3.
        a, b = key
        return height_func(a), height_func(b)

    filled_arr = arr.copy()
    for region in sorted(regions, key=region_heights, reverse=True):
        if step:
            visualize(filled_arr, crit_pts=crit_pts, ncontours=None, cmap='gray', new_fig=new_fig)
            wait("enter to continue")
        points = regions[region]
        X = [pt[0] for pt in points]
        Y = [pt[1] for pt in points]
        filled_arr[X, Y] = 2 * arr.max()
    # Always show the final state. Step mode draws each frame *before* its
    # fill, so without this the last region's fill would never be displayed.
    visualize(filled_arr, crit_pts=crit_pts, ncontours=None, cmap='gray',
              new_fig=new_fig if step else True)
    wait("enter to continue")
def _sqrt_bins(n):
    """Return a square-root-rule histogram bin count for ``n`` samples.

    The result is always a positive ``int``. NumPy's histogram rejects float
    bin counts (``pl.sqrt(n)`` returns a float) and a count of zero
    (which ``n == 0`` would give).
    """
    return max(1, int(n ** 0.5))


def test_contour_tree():
    """Run the contour-tree / flux-tube pipeline on every record of 'data.h5'.

    For each ``(bx, by, psi)`` triple yielded by ``h5_gen``, this function:

    - meshes ``psi`` and builds its contour tree;
    - prunes regions (``len``-based, threshold 3);
    - computes critical points;
    - condenses regions again (threshold ``psi.size // 200``);
    - extracts flux tubes and logs diagnostics.

    After every record, the cumulative list of per-record flux-tube-area
    arrays is re-pickled to 'flux_tube_areas.dat'. The function also draws
    histograms and flux-tube overlays on ``psi`` and ``|B|``. The overlays
    are passed ``save_fig`` names 'psi_flux_tubes%03d' /
    'b_mag_flux_tubes%03d', numbered by record index.
    """
    ctr = 0#{{{
    flux_tube_areas_record = []
    for bx_arr, by_arr, psi_arr in h5_gen('data.h5', ('bx', 'by', 'psi')):
        b_mag = np.sqrt(bx_arr**2 + by_arr**2)
        logger("array memory size: %d" % psi_arr.nbytes)
        def height_func(n):
            # (value, node) tuple: equal psi values are ordered by node.
            return (psi_arr[n], n)
        logger("meshing array...")
        mesh = ct.make_mesh(psi_arr)
        logger("done")
        logger("mesh memory size: %d" % total_graph_memory(mesh))

        logger("computing contour tree...")
        c_tree = ct.contour_tree(mesh, height_func)
        logger("done")

        def region_func(r):
            return len(r)

        logger("pruning small regions...")
        ct.prune_regions(
                c_tree,
                region_func=region_func,
                threshold=3,
                )
        logger("done")

        logger("computing critical points...")
        cpts = ct.critical_points(c_tree)
        logger("done")

        logger("condensing small regions...")
        ct.prune_regions(
                c_tree,
                region_func=region_func,
                # Floor division keeps this an int on Python 3, as on Python 2.
                threshold=psi_arr.size // 200,
                )
        logger("done")

        logger("computing regions...")
        regions = ct.get_regions(c_tree)
        logger("done")

        flux_tubes = ct.get_flux_tubes(c_tree)
        flux_tube_areas = np.array([len(ft) for (e, ft) in flux_tubes], dtype=np.int64)
        flux_tube_areas_record.append(flux_tube_areas)

        # The whole cumulative record is rewritten each iteration (presumably
        # as a checkpoint). ``with`` guarantees the file is flushed and closed.
        with open('flux_tube_areas.dat', 'wb') as fh:
            pickle.dump(flux_tube_areas_record, fh)

        # NOTE(review): flux_tubes is iterated a second time here; this
        # assumes ct.get_flux_tubes returns a re-iterable, not a generator.
        flux_tube_mask = np.zeros(psi_arr.shape, dtype=np.bool_)
        for (e, ft) in flux_tubes:
            for p in ft:
                flux_tube_mask[p] = True

        peaks = cpts.peaks
        passes = cpts.passes
        pits = cpts.pits
        all_cpts = peaks.union(passes).union(pits)

        logger("peaks + pits - passes = %d" % (len(peaks) + len(pits) - len(passes)))
        logger("len(crit_pts) = %d" % (len(peaks) + len(pits) + len(passes)))

        # |B| at each critical point, normalized by the global minimum of |B|.
        # NOTE(review): gives inf if b_mag.min() == 0, which would break the
        # histogram below -- confirm the data never has |B| == 0.
        b_min = b_mag.min()
        cpt_grads = [b_mag[cpt] / b_min for cpt in all_cpts]
        logger("b_mag.mean()=%f" % b_mag.mean())
        logger("b_mag.std()=%f" % b_mag.std())
        logger("b_mag.max()=%f" % b_mag.max())
        logger("b_mag.min()=%f" % b_mag.min())
        if 1:  # plotting toggle (currently always on)
            pl.figure()
            pl.hist(cpt_grads, bins=_sqrt_bins(len(cpt_grads)))
            pl.title("cpoint gradient values")
            pl.figure()
            pl.hist(flux_tube_areas, bins=_sqrt_bins(len(flux_tube_areas)))
            pl.title("flux tube areas")
            pl.figure()
            filled_arr = psi_arr.copy()
            filled_arr[flux_tube_mask] = 2*psi_arr.max()
            visualize(filled_arr, crit_pts=cpts, ncontours=None, cmap='hot', new_fig=False, save_fig='psi_flux_tubes%03d' % ctr)
            pl.figure()
            filled_arr = b_mag.copy()
            filled_arr[flux_tube_mask] = 2*b_mag.max()
            visualize(filled_arr, crit_pts=cpts, ncontours=None, cmap='hot', new_fig=False, save_fig='b_mag_flux_tubes%03d' % ctr)
            pl.close('all')
            if 0:
                # Disabled interactive debugging: fill regions one at a time.
                # It uses its own counter so that, if re-enabled, it cannot
                # clobber ``ctr`` (which numbers the saved figures).
                try:
                    wait = raw_input  # Python 2
                except NameError:  # Python 3
                    wait = input
                filled_arr = psi_arr.copy()
                def hf(key):
                    # Unpack in the body: ``def f((a, b))`` is a Python 3
                    # SyntaxError.
                    a, b = key
                    return height_func(a), height_func(b)
                # Redraw about every 1/20th of the regions. The floor of 1
                # avoids a modulo by zero when there are fewer than 20 regions.
                redraw_every = max(1, len(regions) // 20)
                shown = 0
                for region in sorted(regions, key=hf, reverse=True):
                    X = [pt[0] for pt in regions[region]]
                    Y = [pt[1] for pt in regions[region]]
                    filled_arr[X,Y] = 2*psi_arr.max()
                    if not shown:
                        visualize(filled_arr, crit_pts=cpts, ncontours=None, cmap='gray', new_fig=False)
                        wait("enter to continue")
                    shown += 1
                    shown %= redraw_every
                visualize(filled_arr, crit_pts=cpts, ncontours=None, cmap='gray', new_fig=False)
                wait("enter to continue")
                pl.close('all')
        ctr += 1

        del c_tree
        del regions
        del cpts#}}}