def solve_block_subproblem(block_id, graph, block_prefix, costs, agglomerator,
                           shape, block_shape, cut_outer_edges):
    """Solve the agglomeration sub-problem restricted to one block.

    The node ids of block ``block_id`` are loaded from
    ``block_prefix + str(block_id)``. The subgraph they induce is extracted
    from ``graph``, and ``agglomerator`` is run on it with the matching
    entries of ``costs``.

    Args:
        block_id: id of the block; appended to ``block_prefix`` to form the
            path of the per-block node file.
        graph: global graph; must provide ``extractSubgraphFromNodes``.
        block_prefix: path prefix of the per-block node files.
        costs: edge costs of the global graph, indexed by global edge id.
        agglomerator: callable ``(sub_graph, sub_costs) -> node labels``,
            where the labels are indexable by local node id.
        shape, block_shape: not used by the active code path. They were only
            needed by a former block-wise subgraph extraction and are kept for
            interface compatibility.
        cut_outer_edges: if True, the block's outer edges are appended to the
            returned cut edges.

    Returns:
        Global ids of the cut edges. For a block holding a single node, the
        outer edges are returned if ``cut_outer_edges`` is set, else None.
    """
    node_file = block_prefix + str(block_id)
    assert os.path.exists(node_file), node_file
    block_nodes = ndist.loadNodes(node_file)

    # NOTE: nodes are used exactly as loaded. An earlier version dropped
    # node 0 (the ignore label) here, but that code was disabled.
    inner, outer, local_uvs = graph.extractSubgraphFromNodes(block_nodes)

    # A lone node has no internal edges to decide on.
    if len(block_nodes) == 1:
        if cut_outer_edges:
            return outer
        return None

    assert len(local_uvs) == len(inner)
    assert len(local_uvs) > 0, str(block_id)

    local_graph = undirectedGraph(int(local_uvs.max() + 1))
    local_graph.insertEdges(local_uvs)

    local_costs = costs[inner]
    assert len(local_costs) == local_graph.numberOfEdges

    labels = agglomerator(local_graph, local_costs)
    # An inner edge is cut when its two endpoints end up in different clusters.
    is_cut = labels[local_uvs[:, 0]] != labels[local_uvs[:, 1]]
    assert len(is_cut) == len(inner)

    cut_ids = inner[is_cut]
    return np.concatenate([cut_ids, outer]) if cut_outer_edges else cut_ids
Ejemplo n.º 2
0
def debug_costs(block_id=1):
    """Inspect edge costs and the resulting GAEC segmentation of one block.

    Loads the edge costs and graph edges of validation block ``block_id`` and
    prints summary statistics of the costs. It then solves a multicut with
    ``multicut_gaec`` and maps the node labels onto the watershed via
    ``nt.take``. Finally it shows affinities, watershed and segmentation for a
    fixed bounding box in the volumina viewer.

    Args:
        block_id: index of the validation block, formatted into both data
            paths as ``'0%i'``. The default of 1 reproduces the previously
            hard-coded ``val_block_01``. Before this change, ``block_id`` was
            referenced without being defined, so the function always raised
            ``NameError``.
    """
    from cremi_tools.viewer.volumina import view
    from nifty.graph import undirectedGraph
    root = '/g/kreshuk/data/arendt/platyneris_v1/membrane_training_data/validation'
    path = '%s/segmentation/val_block_0%i.n5' % (root, block_id)
    input_path = '%s/predictions/val_block_0%i_unet_lr_v3_ds122.n5' % (root, block_id)

    # Open the segmentation container once and read everything we need from it.
    with z5py.File(path) as f:
        costs = f['costs'][:]
        edges = f['graph/edges'][:]
        ds = f['volumes/watershed']
        ds.n_threads = 8
        ws = ds[:]

    assert len(costs) == len(edges)
    print(np.mean(costs), "+-", np.std(costs))
    print(costs.min(), costs.max())

    n_nodes = int(edges.max()) + 1
    graph = undirectedGraph(n_nodes)
    graph.insertEdges(edges)

    assert graph.numberOfEdges == len(costs)
    node_labels = multicut_gaec(graph, costs)
    # Map each watershed fragment id to the label of its graph node.
    seg = nt.take(node_labels, ws)

    # Region of interest shown in the viewer.
    bb = np.s_[25:75, 500:1624, 100:1624]

    with z5py.File(input_path) as f:
        ds = f['data']
        ds.n_threads = 8
        # only the first three channels are loaded
        affs = ds[(slice(0, 3), ) + bb]
    # channels-first -> channels-last for the viewer
    view([affs.transpose((1, 2, 3, 0)), ws[bb], seg[bb]])
Ejemplo n.º 3
0
def compute_graph_and_weights(path, return_edge_sizes=False):
    """Build a region adjacency graph and boundary-based edge features.

    Reads the ``watershed`` and ``boundaries`` datasets from the HDF5 file at
    ``path``. The boundary map is sharpened: values above 0.2 are tripled and
    the result is clipped to [0, 1]. The function then computes the region
    adjacency graph of the watershed and accumulates boundary features over
    its edges.

    Args:
        path: HDF5 file containing the ``watershed`` and ``boundaries``
            datasets.
        return_edge_sizes: if True, additionally return the edge sizes, the
            z-edge mask and the processed boundary map.

    Returns:
        ``(graph, feats)``, or ``(graph, feats, edge_sizes, z_edges,
        boundaries)`` if ``return_edge_sizes`` is set.

        - ``graph`` is a nifty ``undirectedGraph`` with the RAG's nodes and
          edges.
        - ``feats`` is the first boundary-feature column per edge (presumably
          the mean boundary value; confirm against elf docs).
        - ``edge_sizes`` is the last feature column.
    """
    from nifty.graph import undirectedGraph
    from elf.segmentation.features import (compute_rag, compute_boundary_features,
                                           compute_z_edge_mask)

    # The data is only read, so open read-only. Mode 'a' demanded write
    # access and would silently create the file if it were missing. The old,
    # disabled feature cache was removed. Re-enabling caching would need a
    # writable mode again.
    with h5py.File(path, 'r') as f:
        seg = f['watershed'][:]
        boundaries = f['boundaries'][:]

    # Emphasize strong boundaries before accumulating features.
    boundaries[boundaries > .2] *= 3
    boundaries = np.clip(boundaries, 0, 1)

    rag = compute_rag(seg)
    feats = compute_boundary_features(rag, boundaries)
    feats, edge_sizes = feats[:, 0], feats[:, -1]
    z_edges = compute_z_edge_mask(rag, seg)

    graph = undirectedGraph(rag.numberOfNodes)
    graph.insertEdges(rag.uvIds())
    if return_edge_sizes:
        return graph, feats, edge_sizes, z_edges, boundaries
    return graph, feats