def diffuse(model: ChunkGrid[bool], repeat=1):
    """
    Spread the voxels of ``model`` into their surroundings as a float field.

    Every voxel set in ``model`` is held at 0.0 while all other voxels start
    at 1.0. Each step replaces every value with the mean of a 19-cell stencil
    (the union of the three axis-aligned middle planes of a 3x3x3 cube, i.e.
    the cube without its 8 corners); voxels of ``model`` are then reset to 0.0
    again. Cells outside each padded chunk window are treated as 1.0.

    :param model: boolean voxel grid whose set voxels act as 0.0 sources
    :param repeat: number of diffusion steps to apply
    :return: float ChunkGrid holding the diffused values
    """
    # A cell belongs to the stencil when at least one of its coordinates is
    # the centre index 1 -> 19 cells, each weighted 1/19.
    stencil = (np.indices((3, 3, 3)) == 1).any(axis=0).astype(float)
    kernel = stencil / stencil.sum()

    field = ChunkGrid(model.chunk_size, dtype=float, fill_value=1.0)
    field[model] = 0.0
    # Pre-allocate enough surrounding chunks for `repeat` steps to spread into.
    field.pad_chunks(repeat // field.chunk_size + 1)

    for _step in range(repeat):
        next_field = field.copy(empty=True)
        for chunk in field.chunks:
            window = chunk.padding(field, 1)
            ndimage.convolve(window,
                             kernel,
                             output=window,
                             mode='constant',
                             cval=1.0)
            core = window[1:-1, 1:-1, 1:-1]

            # insert=False: presumably fetches the model chunk without adding
            # it to `model` — confirm against ChunkGrid.
            source = model.ensure_chunk_at_index(chunk.index, insert=False)
            uniform = source.is_filled()
            if uniform and source.value:
                # Whole chunk is part of the model: pin it to 0.0.
                next_field.ensure_chunk_at_index(chunk.index).set_fill(0.0)
                continue
            if not uniform:
                core[source.to_array()] = 0.0
            next_field.ensure_chunk_at_index(chunk.index).set_array(core)

            # Make sure neighbouring chunks exist so the next step can grow.
            for _face, neighbor_index in ChunkGrid.iter_neighbors_indices(
                    chunk.index):
                next_field.ensure_chunk_at_index(neighbor_index)

        field = next_field

    field.cleanup(remove=True)
    return field


def crust_fix(
        crust: ChunkGrid[np.bool_],
        outer_fill: ChunkGrid[np.bool_],
        crust_outer: ChunkGrid[np.bool_],
        crust_inner: ChunkGrid[np.bool_],
        min_distance: int = 1,
        data_pts: Optional[np.ndarray] = None,  # for plotting
        export_path: Optional[str] = None):
    """
    Extract medial voxels from a normal field seeded on the outer crust.

    Steps:
      1. For every ``crust_outer`` voxel, sum the ``make_normal_kernel()``
         entries selected by its 3x3x3 ``outer_fill`` neighbourhood and
         normalise the result (zero vectors are left as zero).
      2. Propagate these normals with ``propagate_normals`` for
         ``2 * chunk_size`` iterations (passing ``~outer_fill``), then zero
         the field on ``outer_fill ^ crust_outer``.
      3. Mark as medial the voxels reported by ``normal_cone_angles``
         (threshold pi/2) on each 1-padded chunk of the field.
      4. Keep only medial voxels that are outside ``outer_fill`` and outside
         ``dilate(outer_fill, steps=max(0, min_distance) + 2)``.

    Intermediate results are rendered and shown; when ``export_path`` is set
    the figures are also written there as ``normal_start.html``,
    ``medial.html`` and ``medial_cleaned.html``.

    :param crust: only its ``chunk_size`` is used
    :param outer_fill: filled region used for normal estimation and cleanup
    :param crust_outer: voxels whose normals seed the propagation
    :param crust_inner: currently unused; kept for interface compatibility
    :param min_distance: extra dilation steps (clamped at 0) used in cleanup
    :param data_pts: optional point cloud drawn in the figures
    :param export_path: optional directory for the HTML figures
    :return: boolean ChunkGrid of cleaned medial voxels
    """
    chunk_size = crust.chunk_size
    normal_kernel = make_normal_kernel()

    inv_outer_fill = ~outer_fill

    # Method cache (prevent lookup in loop)
    np_sum = np.sum

    print("\tCreate Normals: ")
    with timed("\t\tTime: "):
        normal_pos = np.array(list(crust_outer.where()))
        normal_val = np.full((len(normal_pos), 3), 0.0, dtype=np.float32)
        for n, (x, y, z) in enumerate(normal_pos):
            mask: np.ndarray = outer_fill[x - 1:x + 2, y - 1:y + 2,
                                          z - 1:z + 2]
            normal_val[n] = np_sum(normal_kernel[mask], axis=0)
        norms = np.linalg.norm(normal_val, axis=1, keepdims=True)
        # Kernel vectors that cancel out (or an empty mask) give a zero sum;
        # keep those normals at zero instead of producing NaN through 0/0,
        # which would otherwise leak into the propagated field.
        np.divide(normal_val, norms, out=normal_val, where=norms > 0)

    print("\tGrid Normals: ")
    with timed("\t\tTime: "):
        normals: ChunkGrid[np.float32] = ChunkGrid(
            chunk_size, np.dtype((np.float32, (3, ))), 0.0)
        normals[normal_pos] = normal_val

    print("\tRender Normal Propagation: ")
    with timed("\t\tTime: "):
        _render_crust_start_normals(normals, crust_outer, data_pts,
                                    export_path)

    print("\tNormal Propagation")
    with timed("\t\tTime: "):
        iterations = chunk_size * 2
        nfield = propagate_normals(iterations, normals, crust_outer,
                                   inv_outer_fill)
        field_reset_mask = outer_fill ^ crust_outer
        nfield[field_reset_mask] = 0
        nfield.cleanup(remove=True)

    print("\tNormal cone: ")
    with timed("\t\tTime: "):
        medial = ChunkGrid(crust.chunk_size, np.bool_, False)
        cone_threshold: float = 0.5 * np.pi
        min_norm: float = 1e-15
        for chunk in nfield.chunks:
            padded = nfield.padding_at(chunk.index,
                                       1,
                                       corners=True,
                                       edges=True)
            cones = normal_cone_angles(padded, cone_threshold, min_norm)
            # Store an independent copy of the per-chunk cone mask.
            medial.ensure_chunk_at_index(chunk.index).set_array(cones.copy())
        medial.cleanup(remove=True)

    print("\tResult: ")
    with timed("\t\tTime: "):
        # Remove artifacts where the inner and outer crusts are touching
        artifacts_fix = outer_fill.copy().pad_chunks(1)
        artifacts_fix.fill_value = False
        artifacts_fix = ~dilate(artifacts_fix,
                                steps=max(0, min_distance) + 2) & ~outer_fill
        medial_cleaned = medial & artifacts_fix
        medial_cleaned.cleanup(remove=True)

    print("\tRender 2: ")
    with timed("\t\tTime: "):
        _render_crust_medial(medial,
                             crust_outer,
                             data_pts,
                             export_path,
                             title="Crust-Normal-Fix: Result before cleanup",
                             trace_name='Medial',
                             file_name="medial.html",
                             tag="Ren2",
                             grid_label="medial")

    print("\tRender 3: ")
    with timed("\t\tTime: "):
        _render_crust_medial(medial_cleaned,
                             crust_outer,
                             data_pts,
                             export_path,
                             title="Crust-Fix: Result after cleanup",
                             trace_name='Medial-Cleaned',
                             file_name="medial_cleaned.html",
                             tag="Ren3",
                             grid_label="medial_cleaned")

    return medial_cleaned


def _render_crust_start_normals(normals, crust_outer, data_pts, export_path):
    """Plot the seed normals on ``crust_outer`` as segments plus tip markers."""
    # Each normal becomes the segment p -> p + n; the NaN row presumably
    # breaks the polyline between segments, and +0.5 presumably shifts
    # voxel indices to voxel centres.
    markers_outer = np.array([
        v for p, n in normals.items(mask=crust_outer)
        for v in (p, p + n, (np.nan, np.nan, np.nan))
    ],
                             dtype=np.float32) + 0.5
    markers_outer_tips = np.array(
        [p + n for p, n in normals.items(mask=crust_outer)],
        dtype=np.float32) + 0.5

    ren = CloudRender()
    fig = ren.make_figure(title="Crust-Fix: Start Normal Propagation")
    fig.add_trace(
        ren.make_scatter(markers_outer,
                         marker=dict(opacity=0.5, ),
                         mode="lines",
                         name="Start normal"))
    fig.add_trace(
        ren.make_scatter(markers_outer_tips,
                         marker=dict(size=1, symbol='x'),
                         name="Start normal end"))
    if data_pts is not None:
        fig.add_trace(
            ren.make_scatter(data_pts, opacity=0.1, size=1, name='Model'))
    if export_path:
        fig.write_html(os.path.join(export_path, "normal_start.html"))
    fig.show()


def _render_crust_medial(grid, crust_outer, data_pts, export_path, *, title,
                         trace_name, file_name, tag, grid_label):
    """Show ``grid`` with the outer crust (and optional points); export if asked."""
    time.sleep(0.01)
    ren = VoxelRender()
    fig = ren.make_figure(title=title)
    print(f"{tag}-{grid_label}")
    fig.add_trace(ren.grid_voxel(grid, opacity=0.3, name=trace_name))
    print(f"{tag}-crust_outer")
    fig.add_trace(ren.grid_voxel(crust_outer, opacity=0.05, name='Outer'))
    if data_pts is not None:
        print(f"{tag}-data_pts")
        fig.add_trace(CloudRender().make_scatter(data_pts,
                                                 opacity=0.2,
                                                 size=1,
                                                 name='Model'))
    print(f"{tag}-show")
    if export_path:
        fig.write_html(os.path.join(export_path, file_name))
    fig.show()