Example #1
0
def mws_agglomerator(affs, offsets, previous_segmentation=None,
                     previous_edges=None, previous_weights=None, return_state=False,
                     strides=None, randomize_strides=True):
    """Segment ``affs`` with the mutex watershed, optionally seeded by a previous state.

    Args:
        affs: affinities forwarded to ``mutex_watershed`` /
            ``mutex_watershed_with_seeds``.
        offsets: affinity offsets forwarded to the watershed.
        previous_segmentation: optional seed segmentation. When given,
            ``previous_edges`` and ``previous_weights`` are required as well.
        previous_edges: edges of the previous state (any array-like); they
            are split into attractive / repulsive by the sign of the weights.
        previous_weights: signed weights of the previous state, one per edge.
            Negative weights are treated as repulsive (their absolute value is
            passed on), non-negative ones as attractive.
        return_state: if True, also return ``compute_state`` of the result.
        strides: forwarded to the watershed.
        randomize_strides: forwarded to the watershed.

    Returns:
        The segmentation, or ``(segmentation, state)`` if ``return_state``.

    Raises:
        ValueError: if ``previous_segmentation`` is given without edges or
            weights, or if edges and weights differ in length.
    """
    if previous_segmentation is not None:
        # explicit checks instead of ``assert`` so validation survives ``python -O``
        if previous_edges is None or previous_weights is None:
            raise ValueError("previous_edges and previous_weights are required "
                             "when previous_segmentation is given")
        # boolean-mask indexing below needs numpy arrays; asarray does not copy
        # inputs that already are arrays
        previous_edges = np.asarray(previous_edges)
        previous_weights = np.asarray(previous_weights)
        if len(previous_edges) != len(previous_weights):
            raise ValueError("%i, %i" % (len(previous_edges), len(previous_weights)))

        # transform the seed state to what is expected by mutex_watershed_with_seeds
        repulsive = previous_weights < 0
        attractive = np.logical_not(repulsive)
        seed_state = {'attractive': (previous_edges[attractive], previous_weights[attractive]),
                      'repulsive': (previous_edges[repulsive], np.abs(previous_weights[repulsive]))}

        segmentation = mutex_watershed_with_seeds(affs, offsets, seeds=previous_segmentation,
                                                  strides=strides, randomize_strides=randomize_strides,
                                                  seed_state=seed_state)
    else:
        segmentation = mutex_watershed(affs, offsets, strides,
                                       randomize_strides=randomize_strides)

    if return_state:
        # NOTE(review): 3 is presumably the number of attractive channels — confirm
        state = compute_state(affs, segmentation, offsets, 3)
        return segmentation, state
    return segmentation
def _mws_block_pass1(block_id, blocking, ds_in, ds_out, mask, offsets, strides,
                     randomize_strides, halo, noise_level, max_block_id,
                     tmp_folder):
    """Run the first mutex-watershed pass on one block and serialize its state.

    The affinities of the block plus ``halo`` are normalized and segmented
    with ``su.mutex_watershed``. The result is cropped to the inner block,
    relabeled into a block-specific id range and written to ``ds_out``. The
    state of the inner-block segmentation (edges, weights, attractive mask)
    is stored in ``<tmp_folder>/seg_state_block<block_id>.h5``.

    Blocks whose mask is empty are skipped: nothing is written and no state
    file is created.
    """
    fu.log("(Pass1) start processing block %i" % block_id)

    block = blocking.getBlockWithHalo(block_id, halo)
    in_bb = vu.block_to_bb(block.outerBlock)

    if mask is None:
        bb_mask = None
    else:
        bb_mask = mask[in_bb].astype('bool')
        if not bb_mask.any():
            # NOTE(review): if block_id == max_block_id, _write_nlabels is never
            # called for this block — confirm downstream handles that
            fu.log_block_success(block_id)
            return

    aff_bb = (slice(None), ) + in_bb
    affs = vu.normalize(ds_in[aff_bb])

    seg = su.mutex_watershed(affs,
                             offsets,
                             strides=strides,
                             mask=bb_mask,
                             randomize_strides=randomize_strides,
                             noise_level=noise_level)

    out_bb = vu.block_to_bb(block.innerBlock)
    local_bb = vu.block_to_bb(block.innerBlockLocal)
    seg = seg[local_bb]

    # FIXME once vigra supports uint64 or we implement our own ...
    # seg = vigra.analysis.labelVolumeWithBackground(seg)

    # offset with lowest block coordinate; the +1 keeps block 0 from starting
    # at 0, which would collide with the preserved background (keep_zeros).
    # Block k gets ids in [k * P + 1, (k + 1) * P], disjoint across blocks.
    offset_id = block_id * np.prod(blocking.blockShape) + 1
    vigra.analysis.relabelConsecutive(seg,
                                      start_label=offset_id,
                                      keep_zeros=True,
                                      out=seg)
    ds_out[out_bb] = seg

    # get the state of the segmentation of this block; both the affinities and
    # the mask are cropped to the inner block so they match ``seg``'s shape
    inner_mask = None if bb_mask is None else bb_mask[local_bb]
    grid_graph = su.compute_grid_graph(seg.shape, mask=inner_mask)
    affs = affs[(slice(None), ) + local_bb]
    state_uvs, state_weights, state_attractive = grid_graph.compute_state_for_segmentation(
        affs, seg, offsets, n_attractive_channels=3, ignore_label=True)
    # serialize the states
    save_path = os.path.join(tmp_folder, 'seg_state_block%i.h5' % block_id)
    with vu.file_reader(save_path) as f:
        f.create_dataset('edges', data=state_uvs)
        f.create_dataset('weights', data=state_weights)
        f.create_dataset('attractive_edge_mask', data=state_attractive)

    # write max-id for the last block
    if block_id == max_block_id:
        _write_nlabels(ds_out, seg)
    # log block success
    fu.log_block_success(block_id)
Example #3
0
    def _check_result(self, with_mask=False, view=False):
        """Compare the blockwise result against a full-volume mutex watershed.

        Checks that the result at ``self.output_path``/``self.output_key`` has
        the spatial shape of the input affinities. It then compares the result
        to ``su.mutex_watershed`` on the full affinities using the adjusted
        rand score.

        Args:
            with_mask: load ``self.mask_key``. Both result and expectation must
                then be zero outside the mask, and a tighter tolerance is used.
            view: open the volumina viewer for manual inspection. It is off by
                default so the test runs headless; the old unconditional call
                crashed without a mask.
        """
        # read spatial shape, affinities and (optionally) the mask in one pass
        with z5py.File(self.input_path) as f:
            ds = f[self.input_key]
            ds.n_threads = 8
            shape = ds.shape[1:]
            affs = ds[:]
            mask = f[self.mask_key][:] if with_mask else None

        with z5py.File(self.output_path) as f:
            res = f[self.output_key][:]
        self.assertEqual(res.shape, shape)

        if with_mask:
            self.assertTrue(np.allclose(res[np.logical_not(mask)], 0))
            exp = su.mutex_watershed(affs,
                                     self.isbi_offsets,
                                     self.strides,
                                     mask=mask)
            self.assertTrue(np.allclose(exp[np.logical_not(mask)], 0))
            score = adjusted_rand_score(exp.ravel(), res.ravel())
            # score is much better with mask, so most of the differences seem
            # to be due to boundary artifacts
            self.assertLess(1. - score, .01)
        else:
            exp = su.mutex_watershed(affs, self.isbi_offsets, self.strides)
            score = adjusted_rand_score(exp.ravel(), res.ravel())
            self.assertLess(1. - score, .175)

        if view:
            # optional GUI dependency, imported only when requested
            from cremi_tools.viewer.volumina import view as view_volumes
            data = [affs.transpose((1, 2, 3, 0)), res, exp]
            labels = ['affs', 'result', 'expected']
            if mask is not None:
                data.append(mask.astype('uint32'))
                labels.append('mask')
            view_volumes(data, labels)
Example #4
0
def mws_agglomerator(affs,
                     offsets,
                     previous_segmentation=None,
                     previous_edges=None,
                     previous_weights=None,
                     return_state=False,
                     strides=None,
                     randomize_strides=True):
    """Segment ``affs`` with the mutex watershed, optionally seeded by a previous state.

    Args:
        affs: affinities forwarded to ``mutex_watershed`` /
            ``mutex_watershed_with_seeds``.
        offsets: affinity offsets forwarded to the watershed.
        previous_segmentation: optional seed segmentation. When given,
            ``previous_edges`` and ``previous_weights`` are required as well.
        previous_edges: edges of the previous state (any array-like); they
            are split into attractive / repulsive by the sign of the weights.
        previous_weights: signed weights of the previous state, one per edge.
            Negative weights are treated as repulsive (their absolute value is
            passed on), non-negative ones as attractive.
        return_state: if True, also return ``compute_state`` of the result.
        strides: forwarded to the watershed.
        randomize_strides: forwarded to the watershed.

    Returns:
        The segmentation, or ``(segmentation, state)`` if ``return_state``.

    Raises:
        ValueError: if ``previous_segmentation`` is given without edges or
            weights, or if edges and weights differ in length.
    """
    if previous_segmentation is not None:
        # explicit checks instead of ``assert`` so validation survives ``python -O``
        if previous_edges is None or previous_weights is None:
            raise ValueError(
                "previous_edges and previous_weights are required "
                "when previous_segmentation is given")
        # boolean-mask indexing below needs numpy arrays; asarray does not copy
        # inputs that already are arrays
        previous_edges = np.asarray(previous_edges)
        previous_weights = np.asarray(previous_weights)
        if len(previous_edges) != len(previous_weights):
            raise ValueError("%i, %i" % (
                len(previous_edges), len(previous_weights)))

        # transform the seed state to what is expected by mutex_watershed_with_seeds
        repulsive = previous_weights < 0
        attractive = np.logical_not(repulsive)
        seed_state = {
            'attractive':
            (previous_edges[attractive], previous_weights[attractive]),
            'repulsive':
            (previous_edges[repulsive], np.abs(previous_weights[repulsive]))
        }

        segmentation = mutex_watershed_with_seeds(
            affs,
            offsets,
            seeds=previous_segmentation,
            strides=strides,
            randomize_strides=randomize_strides,
            seed_state=seed_state)
    else:
        segmentation = mutex_watershed(affs,
                                       offsets,
                                       strides,
                                       randomize_strides=randomize_strides)

    if return_state:
        # NOTE(review): 3 is presumably the number of attractive channels — confirm
        state = compute_state(affs, segmentation, offsets, 3)
        return segmentation, state
    return segmentation
Example #5
0
def _mws_block(block_id, blocking, ds_in, ds_out, mask, offsets, strides,
               randomize_strides, halo, noise_level):
    """Run the mutex watershed on one block (with halo) and write the inner block.

    The block is skipped (nothing written) if its mask is empty or its
    affinities sum to zero. Otherwise the affinities are normalized and
    segmented with ``su.mutex_watershed``. The result is cropped to the inner
    block, relabeled into a block-specific id range and written to ``ds_out``.
    """
    fu.log("start processing block %i" % block_id)

    in_bb, out_bb, local_bb = _get_bbs(blocking, block_id, halo)
    if mask is None:
        bb_mask = None
    else:
        bb_mask = mask[in_bb].astype('bool')
        if not bb_mask.any():
            fu.log_block_success(block_id)
            return

    aff_bb = (slice(None), ) + in_bb
    affs = ds_in[aff_bb]
    if affs.sum() == 0:
        fu.log_block_success(block_id)
        return
    affs = vu.normalize(affs)

    seg = su.mutex_watershed(affs,
                             offsets,
                             strides=strides,
                             mask=bb_mask,
                             randomize_strides=randomize_strides,
                             noise_level=noise_level)
    seg = seg[local_bb]

    # FIXME once vigra supports uint64 or we implement our own ...
    # seg = vigra.analysis.labelVolumeWithBackground(seg)

    # offset with lowest block coordinate; the +1 keeps block 0 from starting
    # at 0, which would collide with the preserved background (keep_zeros).
    # Block k gets ids in [k * P + 1, (k + 1) * P], disjoint across blocks.
    offset_id = block_id * np.prod(blocking.blockShape) + 1
    vigra.analysis.relabelConsecutive(seg,
                                      start_label=offset_id,
                                      keep_zeros=True,
                                      out=seg)
    ds_out[out_bb] = seg

    # log block success
    fu.log_block_success(block_id)