예제 #1
0
    def _check_result(self, with_mask=False):
        """Compare the stored mutex watershed result against a reference run.

        Reads the input dataset (and, if ``with_mask``, the mask) from
        ``self.input_path``, checks that the result at
        ``self.output_path`` / ``self.output_key`` has the input shape without
        its leading axis, recomputes the segmentation with ``mutex_watershed``
        on the whole volume and asserts that ``1 - adjusted_rand_score`` between
        reference and result stays below a tolerance.

        Args:
            with_mask: if True, also load ``self.mask_key``, require both the
                result and the reference to be 0 outside the mask and use the
                tighter tolerance (0.01 instead of 0.175).
        """
        # Read everything needed from the input file in a single pass. The
        # leading axis of the input is dropped to get the expected output
        # shape (presumably the affinity-channel axis).
        with z5py.File(self.input_path) as f:
            ds = f[self.input_key]
            ds.n_threads = 8
            shape = ds.shape[1:]
            affs = ds[:]
            mask = f[self.mask_key][:] if with_mask else None

        with z5py.File(self.output_path) as f:
            res = f[self.output_key][:]
        self.assertEqual(res.shape, shape)

        if with_mask:
            outside = np.logical_not(mask)
            self.assertTrue(np.allclose(res[outside], 0))
            exp = mutex_watershed(affs, self.offsets, self.strides, mask=mask)
            self.assertTrue(np.allclose(exp[outside], 0))
            score = adjusted_rand_score(exp.ravel(), res.ravel())
            # score is much better with mask, so most of the differences seem
            # to be due to boundary artifacts
            self.assertLess(1. - score, .01)
        else:
            exp = mutex_watershed(affs, self.offsets, self.strides)
            score = adjusted_rand_score(exp.ravel(), res.ravel())
            self.assertLess(1. - score, .175)
예제 #2
0
def _mws_block_pass1(block_id, blocking, ds_in, ds_out, mask, offsets, strides,
                     randomize_strides, halo, noise_level, max_block_id,
                     tmp_folder):
    """First pass of blockwise mutex watershed for a single block.

    Segments the block (with halo) read from ``ds_in`` using
    ``mutex_watershed``, crops the result to the inner block, relabels it
    consecutively starting at an id offset derived from ``block_id`` and
    writes it to ``ds_out``. The segmentation state computed by
    ``compute_state_for_segmentation`` (edges, weights, attractive-edge mask)
    is saved to ``<tmp_folder>/seg_state_block<block_id>.h5``.

    Args:
        block_id: id of the block within ``blocking``.
        blocking: blocking providing ``getBlockWithHalo`` and ``blockShape``.
        ds_in: input dataset; its leading axis is read in full (presumably
            the affinity channels).
        ds_out: output segmentation dataset.
        mask: optional mask dataset; blocks with no foreground are skipped
            without writing output or state.
        offsets, strides, randomize_strides, noise_level: forwarded to
            ``mutex_watershed``.
        halo: halo added around the block for segmentation.
        max_block_id: id of the last block; that block additionally calls
            ``_write_nlabels``.
        tmp_folder: folder for the serialized per-block states.
    """
    fu.log("(Pass1) start processing block %i" % block_id)

    block = blocking.getBlockWithHalo(block_id, halo)
    in_bb = vu.block_to_bb(block.outerBlock)

    if mask is None:
        bb_mask = None
    else:
        bb_mask = mask[in_bb].astype('bool')
        # nothing to segment if the block lies completely outside the mask
        # NOTE(review): if this is the last block, _write_nlabels is never
        # called - confirm that downstream code tolerates this.
        if not bb_mask.any():
            fu.log_block_success(block_id)
            return

    aff_bb = (slice(None), ) + in_bb
    affs = vu.normalize(ds_in[aff_bb])

    seg = mutex_watershed(affs,
                          offsets,
                          strides=strides,
                          mask=bb_mask,
                          randomize_strides=randomize_strides,
                          noise_level=noise_level)

    out_bb = vu.block_to_bb(block.innerBlock)
    local_bb = vu.block_to_bb(block.innerBlockLocal)
    seg = seg[local_bb]

    # FIXME once vigra supports uint64 or we implement our own ...
    # seg = vigra.analysis.labelVolumeWithBackground(seg)

    # Offset ids with the lowest block coordinate. Clamp to 1 because with
    # keep_zeros=True a start label of 0 (block 0) collides with the
    # background label. int() keeps the product a Python int instead of a
    # fixed-width numpy integer that could silently overflow.
    offset_id = max(block_id * int(np.prod(blocking.blockShape)), 1)
    assert offset_id < np.iinfo('uint64').max, "Id overflow"
    vigra.analysis.relabelConsecutive(seg,
                                      start_label=offset_id,
                                      keep_zeros=True,
                                      out=seg)
    ds_out[out_bb] = seg

    # get the state of the segmentation of this block; seg and affs are
    # restricted to the inner block, so the mask has to be cropped the same way
    local_mask = None if bb_mask is None else bb_mask[local_bb]
    grid_graph = compute_grid_graph(seg.shape, mask=local_mask)
    affs = affs[(slice(None), ) + local_bb]
    # FIXME this function yields incorrect uv-ids !
    state_uvs, state_weights, state_attractive = grid_graph.compute_state_for_segmentation(
        affs, seg, offsets, n_attractive_channels=3, ignore_label=True)
    # serialize the states
    save_path = os.path.join(tmp_folder, 'seg_state_block%i.h5' % block_id)
    with vu.file_reader(save_path) as f:
        f.create_dataset('edges', data=state_uvs)
        f.create_dataset('weights', data=state_weights)
        f.create_dataset('attractive_edge_mask', data=state_attractive)

    # write max-id for the last block
    if block_id == max_block_id:
        _write_nlabels(ds_out, seg)
    # log block success
    fu.log_block_success(block_id)
예제 #3
0
    def test_mutex_watershed(self):
        """Smoke test for elf's mutex watershed on random affinities.

        Builds a random float32 affinity map with one channel per offset
        (distance 1, 3 and 9 along each of the three axes), runs
        ``mutex_watershed`` with strides of 4 and checks that the result has
        the spatial input shape and more than 10 distinct labels.
        """
        from elf.segmentation.mutex_watershed import mutex_watershed

        spatial_shape = (10, 256, 256)
        n_channels = 9
        affs = np.random.rand(n_channels, *spatial_shape).astype('float32')

        # offsets grouped by distance: 1, 3 and 9 pixels along each axis
        offsets = [
            [-1, 0, 0], [0, -1, 0], [0, 0, -1],
            [-3, 0, 0], [0, -3, 0], [0, 0, -3],
            [-9, 0, 0], [0, -9, 0], [0, 0, -9],
        ]
        # True is the fourth positional argument of mutex_watershed
        # (presumably randomize_strides - confirm against elf's signature)
        seg = mutex_watershed(affs, offsets, [4] * 3, True)

        self.assertEqual(seg.shape, spatial_shape)
        # a trivial segmentation would only contain a handful of labels
        n_labels = np.unique(seg).size
        self.assertGreater(n_labels, 10)
예제 #4
0
def _mws_block(block_id, blocking, ds_in, ds_out, mask, offsets, strides,
               randomize_strides, halo, noise_level):
    """Segment a single block with mutex watershed and write it to ``ds_out``.

    The block plus halo is read from ``ds_in``, normalized and segmented with
    ``mutex_watershed``. The result is cropped to the inner block, then
    relabeled consecutively from an id offset derived from ``block_id``
    (intended to keep ids of different blocks apart). Finally it is written
    to ``ds_out``. Blocks that lie completely outside ``mask``, or whose
    input values are all zero, are skipped without writing anything.

    Args:
        block_id: id of the block within ``blocking``.
        blocking: blocking providing ``blockShape``; also passed to
            ``_get_bbs`` together with ``halo``.
        ds_in: input dataset; its leading axis is read in full (presumably
            the affinity channels).
        ds_out: output segmentation dataset.
        mask: optional mask dataset, or None.
        offsets, strides, randomize_strides, noise_level: forwarded to
            ``mutex_watershed``.
        halo: halo used to compute the bounding boxes.
    """
    fu.log("start processing block %i" % block_id)

    in_bb, out_bb, local_bb = _get_bbs(blocking, block_id, halo)
    if mask is None:
        bb_mask = None
    else:
        bb_mask = mask[in_bb].astype('bool')
        # skip blocks that lie completely outside of the mask
        if not bb_mask.any():
            fu.log_block_success(block_id)
            return

    aff_bb = (slice(None), ) + in_bb
    affs = ds_in[aff_bb]
    # Skip blocks without any signal. any() tests for all-zero directly,
    # whereas sum() == 0 can also hold when signed values cancel out.
    if not affs.any():
        fu.log_block_success(block_id)
        return

    affs = vu.normalize(affs)
    seg = mutex_watershed(affs,
                          offsets,
                          strides=strides,
                          mask=bb_mask,
                          randomize_strides=randomize_strides,
                          noise_level=noise_level)
    seg = seg[local_bb]

    # Offset with lowest block coordinate. Clamped to 1 because with
    # keep_zeros=True a start label of 0 would collide with background.
    # NOTE(review): block 0 can reach label prod(blockShape), which equals
    # block 1's start label if every voxel of block 0 is its own segment.
    offset_id = max(block_id * int(np.prod(blocking.blockShape)), 1)
    assert offset_id < np.iinfo('uint64').max, "Id overflow"
    vigra.analysis.relabelConsecutive(seg,
                                      start_label=offset_id,
                                      keep_zeros=True,
                                      out=seg)
    ds_out[out_bb] = seg

    # log block success
    fu.log_block_success(block_id)