コード例 #1
0
ファイル: benchmark.py プロジェクト: constantinpape/affogato
def measure_runtime(affs, n):
    """Return the best-of-``n`` runtime (in seconds) of one MWS segmentation.

    Runs ``compute_mws_segmentation(affs, OFFSETS, 3)`` ``n`` times and
    returns the minimum duration; the minimum is the least noisy estimate
    of the achievable runtime.

    Raises:
        ValueError: if ``n`` is smaller than 1 (nothing would be measured).
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    times = []
    for _ in range(n):
        # perf_counter is monotonic and high-resolution, unlike time.time,
        # which can jump when the system clock is adjusted
        start = time.perf_counter()
        compute_mws_segmentation(affs, OFFSETS, 3)
        times.append(time.perf_counter() - start)
    return np.min(times)
コード例 #2
0
 def test_mws_consistency(self):
     """The kruskal and prim MWS variants must agree on random weights."""
     from affogato.segmentation import compute_mws_segmentation
     offsets = [[-1, 0], [0, -1], [-3, 0], [0, 3], [5, 5]]
     n_attractive = 2
     weights = np.random.rand(len(offsets), 100, 100)

     def run(algorithm):
         # same weights / offsets for both algorithms, only the solver differs
         return compute_mws_segmentation(weights, offsets, n_attractive,
                                         algorithm=algorithm)

     # arguments are evaluated left to right: kruskal first, then prim
     self._check_segmentation(run('kruskal'), run('prim'))
コード例 #3
0
    def test_mutex_malis_2d_gradient_descent_with_ignore_label_learning(self):
        """Mutex-MALIS training must recover the labels outside an ignore region.

        A 10x10 grid of square segments (ids 1..100) is generated and the
        columns 40-59 are set to the ignore label 0. After 50 gradient steps
        with ``learn_in_ignore_label=True`` the loss must be 0, and the kruskal
        MWS on the learned affinities must reproduce the label boundaries once
        the ignore region is blanked out again.
        """
        from affogato.segmentation import compute_mws_segmentation
        from affogato.learning import mutex_malis
        shape = (100, 100)
        tile = 10
        ignore = (slice(None), slice(40, 60))

        labels = np.zeros(shape, dtype=int)
        for row in range(10):
            for col in range(10):
                labels[tile * row:tile * (row + 1),
                       tile * col:tile * (col + 1)] = tile * row + col + 1
        # large vertical ignore region
        labels[ignore] = 0

        offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3]]
        affs = 0.5 * np.random.rand(len(offsets), *shape)

        for _epoch in range(50):
            loss, grads, _seg1, _seg2 = mutex_malis(
                affs, labels, offsets, 2, learn_in_ignore_label=True)
            affs -= grads
            affs = np.clip(affs, 0., 1.)

        number_of_attractive_channels = 2
        labels1 = compute_mws_segmentation(affs, offsets,
                                           number_of_attractive_channels,
                                           algorithm='kruskal')

        self.assertEqual(grads.shape, affs.shape)
        self.assertEqual(loss, 0)

        # the ignore region is not constrained, so exclude it from the check
        labels1[ignore] = 0
        predicted_edges = self.seg2edges(labels1)
        expected_edges = self.seg2edges(labels)
        self.assertTrue(np.allclose(predicted_edges, expected_edges))
コード例 #4
0
    def test_mutex_malis_2d_gradient_descent(self):
        """Mutex-MALIS gradient descent must recover a 10x10 grid segmentation.

        Starting from constant 0.5 affinities, 30 steps of plain gradient
        descent (clipped to [0, 1]) must drive the loss to 0, and the kruskal
        MWS on the learned affinities must reproduce the label boundaries.
        """
        from affogato.segmentation import compute_mws_segmentation
        from affogato.learning import mutex_malis
        shape = (100, 100)

        labels = np.zeros(shape)
        for row in range(10):
            for col in range(10):
                labels[10 * row:10 * (row + 1),
                       10 * col:10 * (col + 1)] = 10 * row + col + 1

        offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3]]
        affs = np.full((len(offsets),) + shape, 0.5)

        for _epoch in range(30):
            loss, grads, _seg1, _seg2 = mutex_malis(affs, labels, offsets, 2)
            affs -= grads
            affs = np.clip(affs, 0, 1)

        number_of_attractive_channels = 2
        labels1 = compute_mws_segmentation(affs, offsets,
                                           number_of_attractive_channels,
                                           algorithm='kruskal')

        self.assertEqual(grads.shape, affs.shape)
        self.assertEqual(loss, 0)

        predicted_edges = self.seg2edges(labels1)
        expected_edges = self.seg2edges(labels)
        self.assertTrue(np.allclose(predicted_edges, expected_edges))
コード例 #5
0
 def _run_mws(self, input_):
     """Run the mutex watershed on ``input_`` with this object's settings.

     Uses ``self.offsets``, ``self.number_of_attractive_channels``,
     ``self.strides`` and ``self.randomize_strides``.
     """
     options = {
         "number_of_attractive_channels": self.number_of_attractive_channels,
         "strides": self.strides,
         "randomize_strides": self.randomize_strides,
     }
     return compute_mws_segmentation(input_, self.offsets, **options)
コード例 #6
0
def create_test_data_3d(path):
    """Create the 3d MWS reference data from the affinities stored at ``path``.

    Loads a crop (first 4 / 256 / 256 entries of the three spatial axes) of
    the "affinities" dataset, inverts the first three (attractive) channels,
    runs the mutex watershed, shows the result in napari for a visual check
    and writes affinities, segmentation and offsets to
    "../data/test_data_3d.h5".
    """
    n_attractive = 3
    with h5py.File(path, "r") as f_in:
        affs = f_in["affinities"][:, :4, :256, :256]
    assert affs.shape[0] == len(OFFSETS)

    # attractive channels: a -> 1 - a
    affs[:n_attractive] *= -1
    affs[:n_attractive] += 1
    offsets = OFFSETS
    seg = compute_mws_segmentation(affs, offsets,
                                   number_of_attractive_channels=n_attractive,
                                   strides=None)

    assert affs.shape[0] == len(offsets)

    # inspect the result before it is written as reference data
    import napari
    viewer = napari.Viewer()
    viewer.add_image(affs)
    viewer.add_labels(seg)
    napari.run()

    with h5py.File("../data/test_data_3d.h5", "w") as f_out:
        f_out.create_dataset("affinities", data=affs, compression="gzip")
        f_out.create_dataset("segmentation", data=seg, compression="gzip")
        f_out.attrs["offsets"] = offsets
        f.attrs["offsets"] = offsets
コード例 #7
0
    def mws_block(block_id):
        """Segment one block of ``blocking`` with the mutex watershed.

        The block's affinities are copied, optionally perturbed with uniform
        noise, their first ``ndim`` channels inverted, and segmented. The
        consecutively relabeled result is written into the enclosing
        ``segmentation`` array. Returns element [1] of ``relabelConsecutive``'s
        result (presumably the max label id -- confirm against vigra docs).
        """
        block = blocking.getBlock(block_id)
        bb = tuple(slice(begin, end)
                   for begin, end in zip(block.begin, block.end))

        # copy so that the noise / inversion below leave the shared affs intact
        block_affs = affs[(slice(None), ) + bb].copy()
        block_mask = mask[bb] if mask is not None else None

        if noise_level > 0:
            block_affs += noise_level * np.random.rand(*block_affs.shape)
        # attractive channels: a -> 1 - a
        block_affs[:ndim] *= -1
        block_affs[:ndim] += 1

        seg = compute_mws_segmentation(block_affs,
                                       offsets,
                                       number_of_attractive_channels=ndim,
                                       strides=strides,
                                       mask=block_mask,
                                       randomize_strides=randomize_strides)
        relabel_result = relabelConsecutive(seg,
                                            out=seg,
                                            start_label=1,
                                            keep_zeros=mask is not None)
        segmentation[bb] = seg
        return relabel_result[1]
コード例 #8
0
ファイル: hela.py プロジェクト: constantinpape/affogato
def run_default_mws_2d(affs, fg, offsets):
    """Segment a time series frame by frame with the mutex watershed.

    Every frame ``t`` is segmented from its spatial affinity channels (the
    offsets whose first component is 0), restricted to the foreground
    ``fg > .5``. From the second frame on, the frame's ids are shifted above
    the ids of the previous frame and ``merge_causal`` links the two frames
    using the remaining (causal) channels.

    Arguments:
        affs [np.ndarray] - affinities; channel axis first, time axis second
        fg [np.ndarray] - foreground map (time axis first), thresholded at 0.5
        offsets [list[list[int]]] - offsets of the affinity channels; the
            first component is presumably the time offset -- confirm

    Returns:
        np.ndarray - uint32 segmentation with the shape of ``fg``
    """
    shape = fg.shape
    mask = fg > .5
    segmentation = np.zeros(shape, dtype='uint32')

    spatial_channels = [i for i, off in enumerate(offsets) if off[0] == 0]
    causal_channels = [i for i, off in enumerate(offsets) if off[0] != 0]

    # the per-frame watershed has no time axis, so drop the (zero) first
    # component of the spatial offsets to match the frame's dimensionality
    spatial_offsets = [list(offsets[i][1:]) for i in spatial_channels]
    causal_offsets = [offsets[i] for i in causal_channels]

    for t in range(shape[0]):
        affs_t = affs[:, t]
        affs_spatial = affs_t[spatial_channels]

        # NOTE(review): `strides` is not a parameter of this function, it must
        # be defined at module level -- confirm.
        seg = compute_mws_segmentation(affs_spatial,
                                       spatial_offsets,
                                       number_of_attractive_channels=2,
                                       strides=strides,
                                       mask=mask[t])
        if t > 0:
            affs_causal = affs_t[causal_channels]
            seg_prev = segmentation[t - 1]
            # shift this frame's ids above all ids used in the previous frame
            max_id = int(seg_prev.max()) + 1
            seg[seg != 0] += max_id
            seg = merge_causal(seg, seg_prev, affs_causal, causal_offsets)

        segmentation[t] = seg
    return segmentation
コード例 #9
0
def mutex_watershed(affs, offsets, strides,
                    randomize_strides=False, mask=None,
                    noise_level=0):
    """ Compute mutex watershed segmentation.

    Introduced in "The Mutex Watershed and its Objective: Efficient, Parameter-Free Image Partitioning":
    https://arxiv.org/pdf/1904.12654.pdf

    The first ``len(offsets[0])`` channels are treated as attractive and
    inverted (a -> 1 - a) before running the watershed. The computation is
    done on a copy, so the caller's ``affs`` array is left unchanged.

    Arguments:
        affs [np.ndarray] - input affinity map
        offsets [list[list[int]]] - pixel offsets corresponding to affinity channels
        strides [list[int]] - strides used to sub-sample long range edges
        randomize_strides [bool] - randomize the strides? (default: False)
        mask [np.ndarray] - foreground mask; pixels outside of it get label 0 (default: None)
        noise_level [float] - scale of the uniform noise in [0, noise_level)
            added to the affinities (default: 0)

    Returns:
        np.ndarray - segmentation relabeled consecutively starting at 1
    """
    ndim = len(offsets[0])
    # work on a copy: the noise and inversion below would otherwise modify
    # the caller's array in place
    affs = affs.copy()
    if noise_level > 0:
        affs += noise_level * np.random.rand(*affs.shape)
    affs[:ndim] *= -1
    affs[:ndim] += 1
    seg = compute_mws_segmentation(affs, offsets,
                                   number_of_attractive_channels=ndim,
                                   strides=strides, mask=mask,
                                   randomize_strides=randomize_strides)
    relabelConsecutive(seg, out=seg, start_label=1, keep_zeros=mask is not None)
    return seg
コード例 #10
0
ファイル: leptin_data.py プロジェクト: paulhfu/RLForSeg
def get_data(img,
             gt,
             affs,
             sigma,
             strides=None,
             overseg_factor=1.2,
             random_strides=False,
             fname='ex1.h5'):
    """Compute MWS superpixels from ``affs``, save them and plot them.

    The attractive channels (``[:sep_chnl]``, module-level constant) are
    divided once by ``overseg_factor`` to push the mutex watershed towards an
    oversegmentation. The resulting labeling, shifted by -1, is written to
    the dataset 'data' of ``fname`` in the leptin example directory and shown
    with matplotlib.

    Arguments:
        img, gt, sigma - not used by this function
        affs [np.ndarray] - affinities; not modified (a copy is scaled)
        strides [list[int]] - MWS strides (default: [4, 4])
        overseg_factor [float] - divisor applied to the attractive channels
        random_strides [bool] - randomize the strides
        fname [str] - name of the output file
    """
    if strides is None:
        strides = [4, 4]
    affinities = affs.copy()
    # scale affinities in order to get an oversegmentation
    # (applied once; it used to be applied twice by a copy-paste leftover)
    affinities[:sep_chnl] /= overseg_factor
    node_labeling = compute_mws_segmentation(affinities,
                                             offs,
                                             sep_chnl,
                                             strides=strides,
                                             randomize_strides=random_strides)
    node_labeling = node_labeling - 1
    out_path = (
        '/g/kreshuk/hilt/projects/data/leptin_fused_tp1_ch_0/train/raw_wtsd_cpy/exs/'
        + fname)
    # `with` guarantees the file is closed even if writing fails
    with h5py.File(out_path, 'w') as save_file:
        save_file.create_dataset(name='data', data=node_labeling)
    plt.imshow(cm.prism(node_labeling / node_labeling.max()))
    plt.show()
コード例 #11
0
 def segment(self, affinities):
     """Segment ``affinities`` with the mutex watershed.

     Uses 2 attractive channels (stored on ``self.att_c``), ``self.offsets``
     and ``self.strides`` without a mask; the segmentation is cast to int32
     and passed through ``label_cont``.
     """
     self.att_c = 2
     # read the attribute: the bare name `att_c` is not defined in this scope
     seg = compute_mws_segmentation(affinities,
                                    self.offsets,
                                    self.att_c,
                                    strides=self.strides,
                                    mask=None).astype(np.int32)
     return label_cont(seg)
コード例 #12
0
def mws_segmentation(affs, algo, strides=None):
    """Run the mutex watershed with ``algo`` and measure its runtime.

    Arguments:
        affs - affinities matching the module-level ``OFFSETS``
        algo [str] - algorithm passed to ``compute_mws_segmentation``
            (e.g. 'kruskal' or 'prim')
        strides - strides used to sub-sample long range edges (default: None)

    Returns:
        tuple - (segmentation, runtime in seconds)
    """
    separating_channel = 2
    # perf_counter is monotonic; time.time can jump with clock adjustments
    start = time.perf_counter()
    seg = compute_mws_segmentation(affs,
                                   OFFSETS,
                                   separating_channel,
                                   algorithm=algo,
                                   strides=strides)
    return seg, time.perf_counter() - start
コード例 #13
0
 def test_mws_reference_3d(self):
     """The 3d MWS result must reproduce the stored reference segmentation."""
     from affogato.segmentation import compute_mws_segmentation
     here = os.path.split(__file__)[0]
     test_path = os.path.join(here, "../../../../data/test_data_3d.h5")
     with h5py.File(test_path, "r") as f:
         affs = f["affinities"][:]
         ref = f["segmentation"][:]
         offsets = f.attrs["offsets"]
     # 3 attractive channels, no stride sub-sampling
     seg = compute_mws_segmentation(affs, offsets, 3, strides=None)
     self._check_segmentation(ref, seg)
コード例 #14
0
ファイル: arand.py プロジェクト: vzinche/neurofire
 def _run_mws(self, input_):
     """Invert the attractive channels of ``input_`` and run the mutex watershed.

     The first ``self.dim`` channels are mapped a -> 1 - a in place, i.e. the
     caller's array is modified.
     """
     assert len(input_) == len(self.offsets)
     n_attractive = self.dim
     input_[:n_attractive] *= -1
     input_[:n_attractive] += 1
     return compute_mws_segmentation(input_,
                                     self.offsets,
                                     number_of_attractive_channels=n_attractive,
                                     strides=self.strides,
                                     randomize_strides=self.randomize_strides)
コード例 #15
0
def mws_segmenter(prediction, offset_version='v2'):
    """Segment an affinity ``prediction`` with the mutex watershed.

    ``offset_version`` selects the offsets: 'v1' uses ``get_default_offsets()``
    and 'v2' uses ``get_mws_offsets()``. The first two channels are inverted
    in place (so ``prediction`` is modified) and passed as the 2 attractive
    channels; strides are [4, 4].
    """
    from affogato.segmentation import compute_mws_segmentation
    from train_affs import get_default_offsets, get_mws_offsets
    assert offset_version in ('v1', 'v2')
    if offset_version == 'v1':
        offsets = get_default_offsets()
    else:
        offsets = get_mws_offsets()
    # invert the first two channels: a -> 1 - a
    prediction[:2] *= -1
    prediction[:2] += 1
    # TODO change this api
    return compute_mws_segmentation(prediction, offsets, 2, strides=[4, 4])
コード例 #16
0
def get_graphs(img, gt, sigma, edge_offsets):
    """Build a superpixel region graph with ground-truth edge weights for ``img``.

    Steps: naive affinities of the Gaussian-smoothed image, MWS superpixels
    biased towards an oversegmentation, edge features from affinity
    statistics, gt edge costs, and a multicut from the gt edge weights that
    is plotted next to gt, superpixels and the raw image.

    Returns:
        img, gt, edges (2 x n_edges, each edge's node pair sorted), edge_feat,
        diff_to_gt (sum of |edge_feat[:, 0] - gt_edge_weights|),
        gt_edge_weights, node_labeling, nodes, affinities -- all arrays
        computed here are cast to float32 except ``edges``.
    """
    import warnings

    overseg_factor = 1.7
    sep_chnl = 2  # channels [:sep_chnl] are attractive, the rest repulsive

    affinities = get_naive_affinities(gaussian(img, sigma=sigma), edge_offsets)
    affinities[:sep_chnl] *= -1
    affinities[:sep_chnl] += +1
    # scale affinities in order to get an oversegmentation
    affinities[:sep_chnl] /= overseg_factor
    affinities[sep_chnl:] *= overseg_factor
    affinities = np.clip(affinities, 0, 1)
    node_labeling = compute_mws_segmentation(affinities, edge_offsets,
                                             sep_chnl)
    # shift ids so that they start at 0, as the check below expects
    node_labeling = node_labeling - 1
    nodes = np.unique(node_labeling)
    # node ids are expected to be exactly 0 .. n_nodes - 1; the old code only
    # constructed a Warning object, so nothing was ever reported
    if not np.array_equal(nodes, np.arange(len(nodes))):
        warnings.warn("node ids are off")

    # get edges from node labeling and edge features from affinity stats
    edge_feat, neighbors = get_edge_features_1d(node_labeling, edge_offsets,
                                                affinities)
    # get gt edge weights based on edges and gt image
    gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                              node_labeling.squeeze(),
                                              gt.squeeze(), 0.5)
    # np.long was removed in NumPy 1.24; int64 is explicit and portable
    edges = neighbors.astype(np.int64)

    # calc multicut from gt
    gt_seg = get_current_soln(gt_edge_weights, node_labeling, edges)

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4)
    ax1.imshow(cm.prism(gt / gt.max()))
    ax1.set_title('gt')
    ax2.imshow(cm.prism(node_labeling / node_labeling.max()))
    ax2.set_title('sp')
    ax3.imshow(cm.prism(gt_seg / gt_seg.max()))
    ax3.set_title('mc')
    ax4.imshow(img)
    ax4.set_title('raw')
    plt.show()

    affinities = affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()

    edges = np.sort(edges, axis=-1)
    edges = edges.T

    return img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, affinities
コード例 #17
0
    def get_node_features(self, embeddings, _, post_input=False):
        """Compute one mean embedding vector per MWS superpixel.

        Edge weights are derived from ``embeddings`` with
        ``self.get_edges_from_embeddings``, superpixels are computed with the
        mutex watershed (strides [10, 10], randomized), and each superpixel's
        feature vector is the mean of the embeddings of its pixels. If a
        writer is set and ``post_input`` is true, a PCA projection of the
        embeddings is logged as a figure.

        Returns:
            (sp_feat_vecs, sp_similarity_reg) -- a float tensor of shape
            (n_superpixels, embeddings.shape[0]) on ``self.device`` and the
            similarity regularizer, which is currently always 0 because its
            computation is commented out.
        """
        separating_channel = 2  # the first 2 offsets are attractive
        # NOTE: every row needs its trailing comma -- without the comma after
        # [1, -1] the literal became `[1, -1][-9, 0]` and raised TypeError.
        offsets = [[-1, 0], [0, -1],  # direct nhood (attractive edges)
                   [-1, -1], [1, 1], [-1, 1], [1, -1],  # diagonal nhood
                   [-9, 0], [0, -9],  # long range direct nhood
                   [-9, -9], [9, -9], [-9, -4], [-4, -9], [4, -9], [9, -4],  # long range diagonal edges
                   [-27, 0], [0, -27]]  # very long range direct edges

        edges = self.get_edges_from_embeddings(embeddings, offsets,
                                               separating_channel)

        node_labeling = compute_mws_segmentation(edges,
                                                 offsets,
                                                 separating_channel,
                                                 strides=[10, 10],
                                                 randomize_strides=True)

        stacked_superpixels = [
            node_labeling == n for n in range(node_labeling.max() + 1)
        ]
        sp_indices = [sp.nonzero().cpu() for sp in stacked_superpixels]

        sp_feat_vecs = torch.empty(
            (len(sp_indices), embeddings.shape[0])).to(self.device).float()
        sp_similarity_reg = 0
        for i, sp in enumerate(sp_indices):
            sp = sp.to(self.device)
            mass = len(sp)
            assert mass > 0
            # ival = torch.index_select(features.squeeze(), 1, sp[:, -2].long())
            # sp_features = torch.gather(ival, 2, torch.stack([sp[:, -1].long() for i in range(ival.shape[0])], dim=0).unsqueeze(-1)).squeeze()
            sp_features = embeddings[:, sp[:, -2], sp[:, -1]].T
            # if sp_features.shape[0] > 1:
            #     shift = torch.randint(1, sp_features.shape[0], (1,)).item()
            #     sp_similarity_reg = sp_similarity_reg + self.pw_dist(sp_features, sp_features.roll(shift, dims=0)).sum()/mass
            sp_feat_vecs[i] = sp_features.sum(0) / mass

        if self.writer is not None and post_input:
            plt.clf()
            fig = plt.figure(frameon=False)
            plt.imshow(
                _pca_project(embeddings.detach().squeeze().cpu().numpy()))
            plt.colorbar()
            self.writer.add_figure("image/embedding_proj", fig,
                                   self.writer_counter)
            self.writer_counter += 1

        return sp_feat_vecs, sp_similarity_reg
def mutex_watershed(affs, offsets, strides,
                    randomize_strides=False, mask=None,
                    noise_level=0):
    """Compute a mutex watershed segmentation with affogato.

    The first ``len(offsets[0])`` channels are treated as attractive and
    inverted (a -> 1 - a) before running the watershed. The computation is
    done on a copy, so the caller's ``affs`` array is left unchanged.

    Arguments:
        affs [np.ndarray] - input affinity map
        offsets [list[list[int]]] - pixel offsets corresponding to affinity channels
        strides [list[int]] - strides used to sub-sample long range edges
        randomize_strides [bool] - randomize the strides? (default: False)
        mask [np.ndarray] - foreground mask; pixels outside of it get label 0 (default: None)
        noise_level [float] - scale of the uniform noise in [0, noise_level)
            added to the affinities (default: 0)

    Returns:
        np.ndarray - segmentation relabeled consecutively starting at 1
    """
    assert compute_mws_segmentation is not None, "Need affogato for mutex watershed"
    ndim = len(offsets[0])
    # work on a copy: the noise and inversion below would otherwise modify
    # the caller's array in place
    affs = affs.copy()
    if noise_level > 0:
        affs += noise_level * np.random.rand(*affs.shape)
    affs[:ndim] *= -1
    affs[:ndim] += 1
    seg = compute_mws_segmentation(affs, offsets,
                                   number_of_attractive_channels=ndim,
                                   strides=strides, mask=mask,
                                   randomize_strides=randomize_strides)
    relabelConsecutive(seg, out=seg, start_label=1, keep_zeros=mask is not None)
    return seg
コード例 #19
0
ファイル: mws.py プロジェクト: jamiegrieser/segmfriends
    def __call__(self, affinities):
        """Segment ``affinities`` with the kruskal mutex watershed.

        When ``self.invert_affinities`` is set, ``1 - affinities`` is
        segmented instead. All other settings (offsets, separating channel,
        stride, randomization, repulsive-weight inversion) come from the
        instance; the bias cut is fixed to 0 and no mask is used.
        """
        weights = 1. - affinities if self.invert_affinities else affinities
        return compute_mws_segmentation(
            weights,
            self.offsets,
            self.seperating_channel,
            strides=self.stride,
            randomize_strides=self.randomize_bounds,
            invert_repulsive_weights=self.invert_dam_channels,
            bias_cut=0.,
            mask=None,
            algorithm='kruskal',
        )
コード例 #20
0
    def test_mws_masked(self):
        """With a mask, labels must be non-zero exactly at the masked-in pixels."""
        from affogato.segmentation import compute_mws_segmentation
        number_of_attractive_channels = 2
        offsets = [[-1, 0], [0, -1], [-3, 0], [0, 3], [5, 5]]
        weights = np.random.rand(len(offsets), 100, 100)

        # drop a random 10 % of the pixels from the foreground mask
        mask = np.ones((100, 100), dtype='bool')
        rows, cols = np.where(mask)
        n_out = int(len(rows) * .1)
        picked = np.random.permutation(len(rows))[:n_out]
        mask[rows[picked], cols[picked]] = False

        node_labels = compute_mws_segmentation(weights, offsets,
                                               number_of_attractive_channels,
                                               mask=mask)
        self.assertEqual(weights.shape[1:], node_labels.shape)
        inside, outside = node_labels[mask], node_labels[~mask]
        # every foreground pixel is labeled ...
        self.assertTrue((inside != 0).all())
        # ... and every background pixel stays 0
        self.assertTrue((outside == 0).all())
コード例 #21
0
    # Time the `run_mws` implementation (strides [1, 1, 1], 3 attractive channels).
    tick = time.time()
    segm = run_mws(affinities, offsets, [1,1,1], seperating_channel=3, randomize_bounds=False)
    print(segm)
    print("Took ", time.time() - tick)

    # Output file for the segmentations; the write of `segm` is disabled below,
    # only the affogato result is written to it at the end.
    file_path_segm = os.path.join('/home/abailoni_local/', "ISBI_results_new_MWS.h5")
    # vigra.writeHDF5(segm.astype('uint32'), file_path_segm, 'segm')

    # NOTE(review): the affinities are stored under the dataset name 'kwargs' -- confirm intended.
    file_path = os.path.join('/home/abailoni_local/', "noise.h5")
    vigra.writeHDF5(affinities, file_path, 'kwargs')


    # Constantin implementation:
    # time affogato's compute_mws_segmentation on the inverted affinities (kruskal) for comparison
    tick = time.time()
    labels = compute_mws_segmentation(1 - affinities, offsets, 3,
                                  randomize_strides=False,
                                 algorithm='kruskal')
    print("Took ", time.time() - tick)

  #   labels = compute_mws_segmentation(np.random.uniform(size=(3,1,2,2)), np.array([[-1, 0, 0],
  # [0, -1, 0],
  # [0, 0, -1]]),
  #                                     3,
  #                                     randomize_strides=False,
  #                                     algorithm='kruskal')

    # Store the affogato result as dataset 'segm_dMWS'.
    print(labels)
    vigra.writeHDF5(labels.astype('uint32'), file_path_segm, 'segm_dMWS')

    # # Steffen:
    # configs = {'models': yaml2dict('./experiments/models_config.yml'),
コード例 #22
0
def get_pix_data(shape=(256, 256)):
    """Generate raw image, gt, MWS superpixels and region graph of a synthetic scene.

    The image is a noisy interference background with 5 circles, 10 random
    polygons and 5 rectangles drawn on top. Circles get gt labels
    ``1 + i / 10`` and rectangles ``2 + i / 10``; polygons are distractors
    and stay 0 in gt. Superpixels are computed with the mutex watershed on
    naive affinities of the lightly smoothed image, biased towards an
    oversegmentation.

    Arguments:
        shape [tuple[int, int]] - image shape (default: (256, 256))

    Returns:
        img, gt, edges (2 x n_edges, each edge's node pair sorted), edge_feat,
        diff_to_gt, gt_edge_weights, node_labeling, nodes, affinities
    """
    import warnings

    def rsign():
        """Return -1 or +1 with equal probability."""
        return (-1)**np.random.randint(0, 2)

    def disk(center_r, center_c, radius):
        """Coordinates of a filled circle, clipped to ``shape``."""
        # skimage.draw.circle was removed in scikit-image 0.19 in favour of
        # draw.disk; fall back to circle only on old versions
        if hasattr(draw, 'disk'):
            return draw.disk((center_r, center_c), radius, shape=shape)
        return draw.circle(center_r, center_c, radius, shape=shape)

    edge_offsets = [[0, -1], [-1, 0], [-3, 0], [0, -3], [-6, 0],
                    [0, -6]]  # offsets defining the edges for pixel affinities
    overseg_factor = 1.7
    sep_chnl = 2  # channel separating attractive from repulsive edges
    n_circles = 5  # number of circles in image
    n_polys = 10  # number of rand polys in image
    n_rect = 5  # number rectangles in image
    # np.float was removed in NumPy 1.24; the builtin float is the same dtype
    circle_color = np.array([1, 0, 0], dtype=float)
    rect_color = np.array([0, 0, 1], dtype=float)
    col_diff = 0.4  # by this margin object color can vary randomly
    min_r, max_r = 10, 20  # min and max radii of the circles
    min_dist = max_r

    img = np.random.randn(*(shape + (3, ))) / 5  # init image with some noise
    gt = np.zeros(shape)

    #  get some random frequencies
    ri1, ri2, ri3, ri4, ri5, ri6 = rsign() * ((np.random.rand() * 2) + .5), \
                                   rsign() * ((np.random.rand() * 2) + .5), \
                                   (np.random.rand() * 4) + 3, \
                                   (np.random.rand() * 4) + 3, \
                                   rsign() * ((np.random.rand() * 2) + .5), \
                                   rsign() * ((np.random.rand() * 2) + .5)
    # coordinate grids: y[i, j] = i (row), x[i, j] = j (column); unlike the
    # old arange/transpose construction this also works for non-square shapes
    y, x = np.indices(shape, dtype=float)
    # add background frequency interferences
    img += (np.sin(
        np.sqrt((x * ri1)**2 + ((shape[1] - y) * ri2)**2) * ri3 * np.pi /
        shape[0]))[..., np.newaxis]
    img += (np.sin(
        np.sqrt((x * ri5)**2 + ((shape[1] - y) * ri6)**2) * ri4 * np.pi /
        shape[1]))[..., np.newaxis]
    # smooth a bit
    img = gaussian(np.clip(img, 0.1, 1), sigma=.8)
    # add some circles
    circles = []
    cmps = []
    while len(circles) < n_circles:
        mp = np.random.randint(min_r, shape[0] - min_r, 2)
        if any(np.linalg.norm(cmp - mp) < min_dist for cmp in cmps):
            continue
        r = np.random.randint(min_r, max_r, 2)
        circles.append(disk(mp[0], mp[1], r[0]))
        cmps.append(mp)

    # add some random polygons
    polys = []
    while len(polys) < n_polys:
        mp = np.random.randint(min_r, shape[0] - min_r, 2)
        if any(np.linalg.norm(cmp - mp) < min_dist // 2 for cmp in cmps):
            continue
        circle = draw.circle_perimeter(mp[0], mp[1], max_r)
        poly_vert = np.random.choice(len(circle[0]),
                                     np.random.randint(3, 6),
                                     replace=False)
        polys.append(
            draw.polygon(circle[0][poly_vert],
                         circle[1][poly_vert],
                         shape=shape))
        cmps.append(mp)

    # add some random rectangles
    rects = []
    while len(rects) < n_rect:
        mp = np.random.randint(min_r, shape[0] - min_r, 2)
        _len = np.random.randint(min_r // 2, max_r, (2, ))
        if any(np.linalg.norm(cmp - mp) < min_dist for cmp in cmps):
            continue
        start = (mp[0] - _len[0], mp[1] - _len[1])
        rects.append(
            draw.rectangle(start,
                           extent=(_len[0] * 2, _len[1] * 2),
                           shape=shape))
        cmps.append(mp)

    # draw polys and give them some noise
    for poly in polys:
        color = np.random.rand(3)
        while np.linalg.norm(color -
                             circle_color) < col_diff or np.linalg.norm(
                                 color - rect_color) < col_diff:
            color = np.random.rand(3)
        img[poly[0], poly[1], :] = color
        img[poly[0], poly[1], :] += np.random.randn(len(
            poly[1]), 3) / 5  # add noise to the polygons

    # draw circles with some frequency
    cols = np.random.choice(np.arange(4, 11, 1).astype(float) / 10,
                            n_circles,
                            replace=False)  # get colors
    for i, circle in enumerate(circles):
        gt[circle[0], circle[1]] = 1 + (i / 10)
        ri1, ri2, ri3, ri4, ri5, ri6 = rsign() * ((np.random.rand() * 4) + 7), \
                                       rsign() * ((np.random.rand() * 4) + 7), \
                                       (np.random.rand() + 1) * 8, \
                                       (np.random.rand() + 1) * 8, \
                                       rsign() * ((np.random.rand() * 4) + 7), \
                                       rsign() * ((np.random.rand() * 4) + 7)

        img[circle[0],
            circle[1], :] = np.array([cols[i], 0.0,
                                      0.0])  # set even color intensity
        # set interference of two freqs in circle color channel
        img[circle[0], circle[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[circle[0], circle[1]] * ri5)**2 +
                    ((shape[1] - y[circle[0], circle[1]]) * ri2)**2) * ri3 *
            np.pi / shape[0]))[..., np.newaxis] * 0.15) + 0.2
        img[circle[0], circle[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[circle[0], circle[1]] * ri6)**2 +
                    ((shape[1] - y[circle[0], circle[1]]) * ri1)**2) * ri4 *
            np.pi / shape[1]))[..., np.newaxis] * 0.15) + 0.2

    # draw rectangles with some frequency
    cols = np.random.choice(np.arange(4, 11, 1).astype(float) / 10,
                            n_rect,
                            replace=False)
    for i, rect in enumerate(rects):
        gt[rect[0], rect[1]] = 2 + (i / 10)
        ri1, ri2, ri3, ri4, ri5, ri6 = rsign() * ((np.random.rand() * 4) + 7), \
                                       rsign() * ((np.random.rand() * 4) + 7), \
                                       (np.random.rand() + 1) * 8, \
                                       (np.random.rand() + 1) * 8, \
                                       rsign() * ((np.random.rand() * 4) + 7), \
                                       rsign() * ((np.random.rand() * 4) + 7)
        img[rect[0], rect[1], :] = np.array([0.0, 0.0, cols[i]])
        img[rect[0], rect[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[rect[0], rect[1]] * ri5)**2 +
                    ((shape[1] - y[rect[0], rect[1]]) * ri2)**2) * ri3 *
            np.pi / shape[0]))[..., np.newaxis] * 0.15) + 0.2
        img[rect[0], rect[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[rect[0], rect[1]] * ri1)**2 +
                    ((shape[1] - y[rect[0], rect[1]]) * ri6)**2) * ri4 *
            np.pi / shape[1]))[..., np.newaxis] * 0.15) + 0.2

    img = np.clip(img, 0, 1)  # clip to valid range
    # get affinities and calc superpixels with mutex watershed
    affinities = get_naive_affinities(gaussian(img, sigma=.2), edge_offsets)
    affinities[:sep_chnl] *= -1
    affinities[:sep_chnl] += +1
    # scale affinities in order to get an oversegmentation
    affinities[:sep_chnl] /= overseg_factor
    affinities[sep_chnl:] *= overseg_factor
    affinities = np.clip(affinities, 0, 1)
    node_labeling = compute_mws_segmentation(affinities, edge_offsets,
                                             sep_chnl)
    # shift ids so that they start at 0, as the check below expects
    node_labeling = node_labeling - 1
    nodes = np.unique(node_labeling)
    # node ids are expected to be exactly 0 .. n_nodes - 1; the old code only
    # constructed a Warning object, so nothing was ever reported
    if not np.array_equal(nodes, np.arange(len(nodes))):
        warnings.warn("node ids are off")

    # get edges from node labeling and edge features from affinity stats
    edge_feat, neighbors = get_edge_features_1d(node_labeling, edge_offsets,
                                                affinities)
    # get gt edge weights based on edges and gt image
    gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                              node_labeling.squeeze(),
                                              gt.squeeze())
    # np.long was removed in NumPy 1.24; int64 is explicit and portable
    edges = neighbors.astype(np.int64)

    affinities = affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()

    edges = np.sort(edges, axis=-1)
    edges = edges.T

    return img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, affinities
コード例 #23
0
    atr = 1

    # Reorder / invert the affinity channels; the number of attractive
    # channels passed is atr * n_dir.
    aff, offsets = reorder_and_invert(aff,
                                      offsets,
                                      atr * n_dir,
                                      dist_per_dir=n_d)

    # Keep channels 0, 1, 2, 6, 10, 14 (used as the 6 attractive channels
    # below) plus every channel from 18 on; the 12 channels 3-5, 7-9, 11-13
    # and 15-17 are dropped, hence `aff.shape[0] - 12`.
    aff_new = np.empty((aff.shape[0] - 12, *aff.shape[1:]))
    aff_new[:6] = aff[[0, 1, 2, 6, 10, 14]]
    aff_new[6:] = aff[18:]
    offsets_new = [offsets[i] for i in [0, 1, 2, 6, 10, 14]] + offsets[18:]

    labels = compute_mws_segmentation(aff_new[:],
                                      offsets_new[:],
                                      number_of_attractive_channels=6,
                                      strides=None,
                                      randomize_strides=False)

    # Index 1 along the first axis of the segmentation; presumably a 2d slice
    # that is displayed below -- confirm.
    X = labels[1]
    numrows, numcols = X.shape

    def format_coord(x, y):
        """Format a cursor position, adding the label value ``z`` inside ``X``.

        Presumably installed as a matplotlib ``Axes.format_coord`` -- confirm.
        Coordinates are rounded to the nearest pixel to look up ``X``.
        """
        row, col = int(y + 0.5), int(x + 0.5)
        inside = 0 <= col < numcols and 0 <= row < numrows
        if not inside:
            return 'x=%1.4f, y=%1.4f' % (x, y)
        z = X[row, col]
        return f'x={x:.1f}, y={y:.1f}, z={z}'
コード例 #24
0
def getFastIOUMWS(dist, S, angles, strides=None, S_attr=1, mask=None, smws=False, verbose=False):
    """Mutex-watershed segmentation from ray-distance IoU affinities.

    For each channel offset, the affinity between a pixel and its offset
    neighbour is the 1-D intersection-over-union of the two pixels' extents
    along one ray direction (see ``getIntersection``).  Two attractive
    channels (0 and 90 degrees, length ``S_attr``) are followed by one
    repulsive channel per angle in the first half of ``angles`` (length
    ``S``).  All affinities are clipped to [0, 1], and the repulsive channels
    are inverted (``1 - aff``) before ``compute_mws_segmentation`` is run
    with Kruskal's algorithm.

    Args:
        dist: array indexed ``dist[ray, row, col]`` with per-pixel ray
            distances; a pixel whose distances are all zero is treated as
            background (affinity 1 between two background pixels).
        S: offset length of the repulsive channels.
        angles: 1-D array of ray angles in radians.  Ray ``i + n_rays // 2``
            is used as the counterpart of ray ``i`` (presumably the opposite
            direction -- TODO confirm against the ray generator).
        strides: forwarded to ``compute_mws_segmentation``.
        S_attr: offset length of the attractive channels.
        mask: forwarded to ``compute_mws_segmentation``.
        smws: currently unused (the seeded-MWS branch below is disabled).
        verbose: print the ray angles used for each channel.

    Returns:
        The label image returned by ``compute_mws_segmentation``.

    Raises:
        ValueError: if an attractive direction (0 or 90 degrees) has no
            matching ray in ``angles``.
    """
    # [row, col] integer offset of length ``s`` along ``theta`` (radians).
    getOffset = lambda theta, s: list(map(int, list(map(np.around, [s * np.sin(theta), s * np.cos(theta)]))))
    S = float(S)
    S_attr = float(S_attr)

    @jit(nopython=True)
    def validateCoordinates(row, col):
        # True if (row, col) lies inside the spatial extent of ``dist``.
        max_r = dist.shape[1]
        max_c = dist.shape[2]
        in_r = 0 <= row < max_r
        in_c = 0 <= col < max_c
        return in_r & in_c

    @jit(nopython=True)
    def getIntersection(r1: float, _r1: float, r2: float, _r2: float, S: float) -> float:
        # Overlap length of the intervals [-_r1, r1] (pixel at 0) and
        # [S - _r2, S + r2] (pixel at S) on the line through both pixels,
        # clamped at 0.
        inter = 0
        if r1 > S:
            if _r2 > S:
                inter += S
                inter += min(_r1, _r2 - S)
            else:
                inter += _r2
            inter += min(r2, r1 - S)
        else:
            if _r2 > S:
                inter += r1
                inter += min(_r1, _r2 - S)
            else:
                inter += max(r1 + _r2 - S, 0)
        return float(inter)

    # NOTE: parallel=True has no effect here because the loops use range,
    # not prange.  ``n_rays`` is read from the enclosing scope; it is
    # assigned below, before the first call (which triggers compilation).
    @jit(nopython=True, parallel=True)
    def fillAffinities(dist, offset, ray_idx, S):
        rows = dist.shape[1]
        cols = dist.shape[2]
        affs = np.zeros((rows, cols))
        for row in range(rows):
            for col in range(cols):
                row_off = row + offset[0]
                col_off = col + offset[1]
                if validateCoordinates(row_off, col_off):
                    dists_per_pixel0 = dist[:, row, col]
                    dists_per_pixel1 = dist[:, row_off, col_off]
                    if not np.any(dists_per_pixel0) and not np.any(dists_per_pixel1):
                        # Both pixels are background: fully attractive.
                        affs[row, col] = 1.0
                    else:
                        _dist1 = dists_per_pixel0[ray_idx]
                        _dist2 = dists_per_pixel0[ray_idx + n_rays // 2]
                        dist1 = dists_per_pixel1[ray_idx]
                        dist2 = dists_per_pixel1[ray_idx + n_rays // 2]
                        intersection = getIntersection(_dist1.item(), _dist2.item(), dist1.item(), dist2.item(), S)
                        union = _dist1.item() + _dist2.item() + dist1.item() + dist2.item() - intersection
                        if union == 0:
                            iou = 0
                        else:
                            iou = intersection / union
                        affs[row, col] = iou
        return affs

    attractive_angles = [0, 90]
    n_rays = angles.shape[0]
    angl = angles[:n_rays // 2]
    offsets0 = [getOffset(np.deg2rad(angle), S_attr) for angle in attractive_angles]
    offsets1 = [getOffset(a, S) for a in angl]
    offsets = offsets0 + offsets1
    print('Generated offsets:', offsets)
    affs_attr = np.ones((len(offsets0), dist.shape[1], dist.shape[2]))
    angles_d = np.rad2deg(angles)
    # numba typed dict so the offset can be passed into the jitted function.
    off = Dict.empty(key_type=types.int16, value_type=types.int16)
    for idx, offset in enumerate(offsets0):
        ray = np.rad2deg(np.arctan2(*offset))
        # Tolerant comparison: rad2deg of a computed angle is rarely
        # bit-for-bit equal to 0.0 / 90.0, and an exact == would leave the
        # match empty.
        ray_idx = np.flatnonzero(np.isclose(angles_d, ray))
        if ray_idx.size == 0:
            raise ValueError(f'no ray at {ray} degrees in `angles`; '
                             f'attractive directions {attractive_angles} must be present')
        if verbose: print('Angles', angles_d[ray_idx], angles_d[ray_idx + n_rays // 2])
        off[0] = offset[0]
        off[1] = offset[1]
        affs_attr[idx] = fillAffinities(dist, off, ray_idx[0], S_attr)

    affs_repul = np.zeros((len(offsets1), dist.shape[1], dist.shape[2]))
    for idx, offset in enumerate(offsets1):
        if verbose: print('Angles', np.rad2deg(angles[idx]), np.rad2deg(angles[idx + n_rays // 2]))
        off[0] = offset[0]
        off[1] = offset[1]
        affs_repul[idx] = fillAffinities(dist, off, idx, S)

    merged_aff = np.vstack((affs_attr, affs_repul))
    merged_aff[merged_aff > 1] = 1
    merged_aff[merged_aff < 0] = 0
    # Repulsive channels: high IoU must mean weak repulsion, so invert.
    merged_aff[len(attractive_angles):] *= -1
    merged_aff[len(attractive_angles):] += 1
    # for i in range(merged_aff.shape[0]):
    #     plt.figure()
    #     plt.imshow(merged_aff[i])
    # plt.show()
    # if smws:
    #     mask = np.expand_dims(mask, 0)
    #     merged_aff = np.vstack((merged_aff, mask))
    #     merged_aff = np.vstack((merged_aff, 1 - mask))
    #     labels = computeSMWS(merged_aff, offsets, len(attractive_angles), stride=strides)
    # else:
    labels = compute_mws_segmentation(merged_aff, offsets, len(attractive_angles), algorithm='kruskal',
                                      strides=strides, mask=mask)

    return labels
コード例 #25
0
# Quick visual check of the first 10 entries of `distances`
# (plotting_helper is defined elsewhere).
plotting_helper(distances[:10])

# Reorder the channels with reorder_and_invert (defined elsewhere).
# n_directions*attr_layers is passed as its third argument and
# len(default_distances) as dist_per_dir -- TODO confirm the exact
# reordering/inversion semantics against reorder_and_invert.
affinities, offsets = reorder_and_invert(affinities, offsets,
                                         n_directions*attr_layers,
                                         dist_per_dir=len(default_distances))

print(f'affinities.shape: \t {affinities.shape} \n'
      f'offsets.shape: \t {np.array(offsets).shape}')

# Drop part of the short-range channels (controlled by sampling_factor=less_attr);
# returns the reduced affinities together with their matching offsets.
affinities_new, offsets_new = exclude_some_short_edges(affinities, offsets, sampling_factor=less_attr,
                                                       n_directions=n_directions*attr_layers, z_dir=compute_z)
print(offsets_new)

# Mutex watershed (Kruskal); the first number_of_attractive_channels channels
# are the attractive ones.
labels_new = compute_mws_segmentation(affinities_new, offsets_new,
                                      number_of_attractive_channels=number_of_attractive_channels,
                                      algorithm='kruskal')

#plotting_helper(affinities_new)
#plotAffinities(affinities_new[:8, 1], 'Affinities')

# labels_new is 3-D here; shape[1:] drops the leading axis (presumably z --
# TODO confirm) and keeps the in-plane size used by the coordinate formatter.
numrows, numcols = labels_new.shape[1:]

def format_maker(z_values):
    """Build a matplotlib ``format_coord`` callback for a 2-D image.

    The returned function maps data coordinates (x, y) to the nearest pixel
    (pixel centres at integer coordinates, the matplotlib imshow default)
    and appends that pixel's value from ``z_values``; outside the image it
    reports only the coordinates.

    Args:
        z_values: 2-D array indexed ``z_values[row, col]``.

    Returns:
        A callable ``format_coord(x, y) -> str``, suitable for assigning to
        ``Axes.format_coord``.
    """
    # shape is (rows, cols), matching the z_values[row, col] indexing below.
    numrows, numcols = z_values.shape

    def format_coord(x, y):
        # floor() rather than int(): int() truncates toward zero and would
        # map e.g. x = -0.7 (outside the image) onto column 0.
        col = int(np.floor(x + 0.5))
        row = int(np.floor(y + 0.5))
        if 0 <= col < numcols and 0 <= row < numrows:
            z = z_values[row, col]
            return f'x={x:.1f}, y={y:.1f}, z={z}'
        return 'x=%1.4f, y=%1.4f' % (x, y)

    return format_coord
コード例 #26
0
def mws_segmentation(affs, separating_channel=3):
    """Run the mutex watershed on ``affs`` and time it.

    Args:
        affs: affinity array passed as the MWS weights, with one channel per
            entry of the module-level ``OFFSETS``.
        separating_channel: number of leading attractive channels (default 3).

    Returns:
        ``(seg, seconds)``: the segmentation and the elapsed wall-clock time.
    """
    t0 = time.perf_counter()
    # Argument order is (weights, offsets, number_of_attractive_channels),
    # matching the other compute_mws_segmentation calls in this file.
    seg = compute_mws_segmentation(affs, OFFSETS, separating_channel)
    return seg, time.perf_counter() - t0
コード例 #27
0
def trainAffPredCircles(saveToFile,
                        device,
                        separating_channel,
                        offsets,
                        strides,
                        numEpochs=8):
    """Train a UNet to predict MWS affinities on synthetic disc images.

    After every epoch, the last batch's predicted and ground-truth
    affinities are segmented with the mutex watershed and tiled into a debug
    mosaic (see ``_render_epoch_segmentations``).  The mosaic is built but
    not displayed.  The final weights are saved with ``torch.save``.

    Args:
        saveToFile: destination path for the model's ``state_dict``.
        device: torch device for both the model and the batches.
        separating_channel: number of attractive channels; channels from this
            index on are treated as repulsive.
        offsets: per-channel pixel offsets, all of the same length.
        strides: forwarded to ``get_valid_edges``.
        numEpochs: number of passes over the 5-sample dataset.

    Raises:
        ValueError: if ``offsets`` is empty or its entries differ in length.
    """
    # Validate before training rather than after the first epoch.
    ndim = len(offsets[0]) if offsets else 0
    if not offsets or any(len(off) != ndim for off in offsets):
        raise ValueError('offsets must be a non-empty list of equal-length offsets')

    dloader = DataLoader(CustomDiscDset(length=5),
                         batch_size=1,
                         shuffle=True,
                         pin_memory=True)
    print('----START TRAINING----' * 4)

    model = UNet(n_channels=1, n_classes=len(offsets), bilinear=True)
    for param in model.parameters():
        param.requires_grad = True
    # Move the model to the same device as the data, before the optimizer
    # is constructed (as recommended by the torch.optim docs).
    model.to(device)

    criterion = nn.MSELoss()
    optim = torch.optim.Adam(model.parameters())

    since = time.time()

    for epoch in range(numEpochs):
        print('Epoch {}/{}'.format(epoch, numEpochs - 1))
        print('-' * 10)
        for inputs, _, affinities in dloader:
            inputs = inputs.to(device)
            affinities = affinities.to(device)
            optim.zero_grad()
            # Force grad mode even if the caller runs us under torch.no_grad().
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                loss = criterion(outputs, affinities)
                loss.backward()
                optim.step()

        # Debug mosaic for the epoch's last batch; not displayed by default.
        img2 = _render_epoch_segmentations(inputs, outputs, affinities,
                                           separating_channel, offsets, strides)
        # import matplotlib.pyplot as plt; plt.imshow(img2); plt.show()

    torch.save(model.state_dict(), saveToFile)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))


def _render_epoch_segmentations(inputs, outputs, affinities, separating_channel, offsets, strides):
    """Segment predictions and GT affinities; return an RGBA debug mosaic.

    Layout: ``[[pred (Kruskal), pred (partial Prim)], [raw, GT (Kruskal)]]``.
    """
    from matplotlib import cm

    # .copy(): for CPU tensors .numpy() shares memory with the tensor, and
    # the in-place edits below must not alter the model/dataset tensors.
    weights = outputs.squeeze().detach().cpu().numpy().copy()
    affs = affinities.squeeze().detach().cpu().numpy().copy()
    # Invert repulsive channels (1 - x) for the MWS.
    weights[separating_channel:] *= -1
    weights[separating_channel:] += +1
    affs[separating_channel:] *= -1
    affs[separating_channel:] += +1
    # Damp predicted attractive weights (presumably a hand-tuned factor).
    weights[:separating_channel] /= 1.5

    image_shape = weights.shape[1:]
    valid_edges = get_valid_edges(weights.shape, offsets,
                                  separating_channel, strides, False)
    node_labeling1, _, _, _ = compute_partial_mws_prim_segmentation(
        weights.ravel(), valid_edges.ravel(), offsets, separating_channel,
        image_shape)
    node_labeling_gt = compute_mws_segmentation(affs,
                                                offsets,
                                                separating_channel,
                                                algorithm='kruskal')
    labels = compute_mws_segmentation(weights,
                                      offsets,
                                      separating_channel,
                                      algorithm='kruskal')

    labels = labels.reshape(image_shape)
    labels1 = node_labeling1.reshape(image_shape)
    node_labeling_gt = node_labeling_gt.reshape(image_shape)

    show_seg = cm.prism(labels / labels.max())
    show_seg1 = cm.prism(labels1 / labels1.max())
    show_seg2 = cm.prism(node_labeling_gt / node_labeling_gt.max())
    show_raw = cm.gray(inputs.squeeze().detach().cpu().numpy())
    return np.concatenate([
        np.concatenate([show_seg, show_seg1], axis=1),
        np.concatenate([show_raw, show_seg2], axis=1),
    ], axis=0)
コード例 #28
0
from affogato.affinities import compute_multiscale_affinities, compute_affinities
from affogato.segmentation import compute_mws_segmentation

# Only the mask output is used: affinities of an all-zero int64 label image
# with raw's shape, for the given offsets.  Printed below as the count of
# valid edges (the mask is NOT passed to the segmentation, which uses mask=None).
_, mask = compute_affinities(np.zeros_like(raw, dtype='int64'), offsets)

print("Total valid edges: ", mask.sum())

# Invert affs: channels 3 onward (the repulsive ones, given
# number_of_attractive_channels=3 below) become 1 - aff, in place.
affs[3:] = 1. - affs[3:]
import time

# Time the seeded MWS variant; initial_coordinate is presumably the seed
# point in the image's axis order (e.g. z, y, x) -- TODO confirm.
tick = time.time()
final_segm = compute_mws_segmentation(affs,
                                      offsets,
                                      number_of_attractive_channels=3,
                                      strides=None,
                                      randomize_strides=False,
                                      algorithm='seeded',
                                      mask=None,
                                      initial_coordinate=(4, 150, 150))
print("Time seeded: ", time.time() - tick)

# Disabled timing comparison against the default (non-seeded) algorithm:
# tick = time.time()
# _ = compute_mws_segmentation(affs, offsets, number_of_attractive_channels=3,
#                              strides=None, randomize_strides=False,
#                              mask=None, initial_coordinate=(0,0,70))
# print("Time normal: ", time.time() - tick)

import segmfriends.vis as vis

fig, ax = vis.get_figure(1, 2, figsize=(10, 20))
vis.plot_segm(ax[0],