def measure_runtime(affs, n):
    """Return the best-of-``n`` runtime of one mutex watershed run.

    Runs ``compute_mws_segmentation(affs, OFFSETS, 3)`` ``n`` times (3 is the
    number of attractive channels) and reports the fastest run, which is the
    least noisy estimate of the achievable runtime.

    Args:
        affs: affinities passed unchanged to ``compute_mws_segmentation``.
        n: number of repetitions, must be at least 1.

    Returns:
        The minimal runtime in seconds.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    times = []
    for _ in range(n):
        # perf_counter is monotonic and high resolution; time.time() can jump
        # when the system clock is adjusted and is not meant for benchmarks.
        start = time.perf_counter()
        compute_mws_segmentation(affs, OFFSETS, 3)
        times.append(time.perf_counter() - start)
    return np.min(times)
def test_mws_consistency(self):
    """Kruskal and Prim mutex watershed must agree on the same random weights."""
    from affogato.segmentation import compute_mws_segmentation
    offsets = [[-1, 0], [0, -1], [-3, 0], [0, 3], [5, 5]]
    n_attractive = 2
    weights = np.random.rand(len(offsets), 100, 100)

    def run(algorithm):
        return compute_mws_segmentation(weights, offsets, n_attractive,
                                        algorithm=algorithm)

    # arguments are evaluated left to right: kruskal first, then prim
    self._check_segmentation(run('kruskal'), run('prim'))
def test_mutex_malis_2d_gradient_descent_with_ignore_label_learning(self):
    """Gradient descent on mutex MALIS (learning inside the ignore label)
    must converge to affinities whose mutex watershed reproduces the ground
    truth outside of the ignore stripe."""
    from affogato.segmentation import compute_mws_segmentation
    from affogato.learning import mutex_malis

    # 10 x 10 grid of 10 x 10 pixel squares labeled 1 .. 100 ...
    block_ids = np.arange(1, 101, dtype=int).reshape(10, 10)
    labels = np.repeat(np.repeat(block_ids, 10, axis=0), 10, axis=1)
    # ... with a large ignore (label 0) stripe in the middle
    labels[:, 40:60] = 0

    offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3]]
    n_attractive = 2
    affs = 0.5 * np.random.rand(len(offsets), 100, 100)
    for _ in range(50):
        loss, grads, _, _ = mutex_malis(affs, labels, offsets, 2,
                                        learn_in_ignore_label=True)
        affs -= grads
        affs = np.clip(affs, 0., 1.)

    prediction = compute_mws_segmentation(affs, offsets, n_attractive,
                                          algorithm='kruskal')
    self.assertEqual(grads.shape, affs.shape)
    self.assertEqual(loss, 0)

    # the ignore stripe is unconstrained, so blank it out before comparing
    prediction[:, 40:60] = 0
    self.assertTrue(np.allclose(self.seg2edges(prediction),
                                self.seg2edges(labels)))
def test_mutex_malis_2d_gradient_descent(self):
    """Gradient descent on mutex MALIS must converge to affinities whose
    mutex watershed reproduces the ground-truth label image."""
    from affogato.segmentation import compute_mws_segmentation
    from affogato.learning import mutex_malis

    # 10 x 10 grid of 10 x 10 pixel squares labeled 1 .. 100 (float labels)
    block_ids = np.arange(1, 101, dtype=float).reshape(10, 10)
    labels = np.repeat(np.repeat(block_ids, 10, axis=0), 10, axis=1)

    offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3]]
    n_attractive = 2
    affs = 0.5 * np.ones((len(offsets), 100, 100))
    for _ in range(30):
        loss, grads, _, _ = mutex_malis(affs, labels, offsets, 2)
        affs -= grads
        affs = np.clip(affs, 0, 1)

    prediction = compute_mws_segmentation(affs, offsets, n_attractive,
                                          algorithm='kruskal')
    self.assertEqual(grads.shape, affs.shape)
    self.assertEqual(loss, 0)
    self.assertTrue(np.allclose(self.seg2edges(prediction),
                                self.seg2edges(labels)))
def _run_mws(self, input_):
    """Run the mutex watershed on ``input_`` with the settings stored on ``self``."""
    offsets = self.offsets
    n_attractive = self.number_of_attractive_channels
    return compute_mws_segmentation(input_, offsets,
                                    number_of_attractive_channels=n_attractive,
                                    strides=self.strides,
                                    randomize_strides=self.randomize_strides)
def create_test_data_3d(path, out_path="../data/test_data_3d.h5", view=True):
    """Create the reference data for the 3d mutex watershed test.

    Loads a crop of the ``affinities`` dataset in ``path``. It then inverts the
    attractive channels, runs the mutex watershed and writes affinities,
    segmentation and offsets to ``out_path``.

    Args:
        path: HDF5 file with an ``affinities`` dataset whose first axis has
            one channel per entry of ``OFFSETS``.
        out_path: output HDF5 file. The default is the previously hard-coded,
            working-directory relative location.
        view: if True (the previous behaviour), show affinities and
            segmentation in napari before writing them, for visual inspection.
    """
    with h5py.File(path, "r") as f:
        affs = f["affinities"][:, :4, :256, :256]
    assert affs.shape[0] == len(OFFSETS)

    # the first ``separating_channel`` channels are attractive: x -> 1 - x
    separating_channel = 3
    affs[:separating_channel] *= -1
    affs[:separating_channel] += 1
    offsets = OFFSETS
    seg = compute_mws_segmentation(affs, offsets,
                                   number_of_attractive_channels=separating_channel,
                                   strides=None)

    if view:
        # check the results
        import napari
        v = napari.Viewer()
        v.add_image(affs)
        v.add_labels(seg)
        napari.run()

    with h5py.File(out_path, "w") as f:
        f.create_dataset("affinities", data=affs, compression="gzip")
        f.create_dataset("segmentation", data=seg, compression="gzip")
        f.attrs["offsets"] = offsets
def mws_block(block_id):
    """Segment one block with the mutex watershed and store it in ``segmentation``.

    Returns the maximum label id reported by ``relabelConsecutive`` for the block.
    """
    blk = blocking.getBlock(block_id)
    roi = tuple(slice(start, stop) for start, stop in zip(blk.begin, blk.end))
    # copy, so the in-place noise / inversion below leave ``affs`` unchanged
    block_affs = affs[(slice(None), ) + roi].copy()
    block_mask = mask[roi] if mask is not None else None

    if noise_level > 0:
        block_affs += noise_level * np.random.rand(*block_affs.shape)
    # the first ``ndim`` channels are attractive: x -> 1 - x
    block_affs[:ndim] *= -1
    block_affs[:ndim] += 1

    block_seg = compute_mws_segmentation(block_affs, offsets,
                                         number_of_attractive_channels=ndim,
                                         strides=strides,
                                         mask=block_mask,
                                         randomize_strides=randomize_strides)
    relabeled = relabelConsecutive(block_seg, out=block_seg, start_label=1,
                                   keep_zeros=mask is not None)
    segmentation[roi] = block_seg
    return relabeled[1]
def run_default_mws_2d(affs, fg, offsets, strides=None):
    """Segment every frame with a 2d mutex watershed and link consecutive frames.

    Args:
        affs: affinities indexed as ``affs[channel, t, ...]``.
        fg: foreground map indexed as ``fg[t, ...]``. Pixels with
            ``fg <= 0.5`` are excluded from the mask.
        offsets: one offset per affinity channel; the first component is the
            time offset. Channels with zero time offset drive the per-frame
            (spatial) segmentation. The others are handed to ``merge_causal``
            to link a frame to its predecessor.
        strides: strides forwarded to ``compute_mws_segmentation`` for the
            per-frame segmentation (default: None).

    Returns:
        uint32 segmentation with the same shape as ``fg``.
    """
    shape = fg.shape
    mask = fg > .5
    segmentation = np.zeros(shape, dtype='uint32')

    spatial_channels = [i for i, off in enumerate(offsets) if off[0] == 0]
    causal_channels = [i for i, off in enumerate(offsets) if off[0] != 0]
    # Each frame is segmented in 2d, so the (zero) time component must be
    # dropped; otherwise the offset rank does not match the frame rank.
    spatial_offsets = [list(offsets[i][1:]) for i in spatial_channels]
    causal_offsets = [offsets[i] for i in causal_channels]

    for t in range(shape[0]):
        affs_t = affs[:, t]
        affs_spatial = affs_t[spatial_channels]
        # argument order: weights, offsets, number of attractive channels
        seg = compute_mws_segmentation(affs_spatial, spatial_offsets, 2,
                                       strides=strides, mask=mask[t])
        if t > 0:
            affs_causal = affs_t[causal_channels]
            seg_prev = segmentation[t - 1]
            # shift the new ids past the previous frame's ids to avoid collisions
            max_id = int(seg_prev.max()) + 1
            seg[seg != 0] += max_id
            seg = merge_causal(seg, seg_prev, affs_causal, causal_offsets)
        segmentation[t] = seg
    return segmentation
def mutex_watershed(affs, offsets, strides,
                    randomize_strides=False, mask=None,
                    noise_level=0):
    """ Compute mutex watershed segmentation.

    Introduced in "The Mutex Watershed and its Objective: Efficient, Parameter-Free Image Partitioning":
    https://arxiv.org/pdf/1904.12654.pdf

    The first ``len(offsets[0])`` channels are treated as attractive and are
    inverted (x -> 1 - x) before segmentation.

    Arguments:
        affs [np.ndarray] - input affinity map; it is not modified (a copy is processed)
        offsets [list[list[int]]] - pixel offsets corresponding to affinity channels
        strides [list[int]] - strides used to sub-sample long range edges
        randomize_strides [bool] - randomize the strides? (default: False)
        mask [np.ndarray] - mask of the pixels to segment; pixels where the mask
            is False are left at label 0 (default: None)
        noise_level [float] - scale of the uniform noise in [0, noise_level)
            added to the affinities (default: 0)
    Returns:
        segmentation with consecutive labels starting at 1 (0 stays 0 if a mask is given)
    """
    ndim = len(offsets[0])
    # Work on a copy: noise and inversion are applied in place and must not
    # leak into the caller's array.
    affs = affs.copy()
    if noise_level > 0:
        affs += noise_level * np.random.rand(*affs.shape)
    affs[:ndim] *= -1
    affs[:ndim] += 1
    seg = compute_mws_segmentation(affs, offsets,
                                   number_of_attractive_channels=ndim,
                                   strides=strides, mask=mask,
                                   randomize_strides=randomize_strides)
    relabelConsecutive(seg, out=seg, start_label=1, keep_zeros=mask is not None)
    return seg
def get_data(img, gt, affs, sigma, strides=[4, 4], overseg_factor=1.2, random_strides=False, fname='ex1.h5'):
    """Over-segment ``affs`` with the mutex watershed, save the labels and show them.

    Uses the module-level ``sep_chnl`` (number of attractive channels) and
    ``offs`` (offsets).

    Args:
        img, gt, sigma: unused; kept for interface compatibility.
        affs: affinities; not modified (a scaled copy is segmented).
        strides: strides for the long range edges (not modified).
        overseg_factor: divisor applied to the attractive channels.
        random_strides: forwarded as ``randomize_strides``.
        fname: name of the HDF5 output file inside the hard-coded export directory.
    """
    affinities = affs.copy()
    # scale affinities in order to get an oversegmentation
    # NOTE(review): the attractive channels are divided by ``overseg_factor``
    # twice, i.e. effectively by ``overseg_factor ** 2``. Kept as-is to
    # preserve results -- confirm whether this is intended.
    affinities[:sep_chnl] /= overseg_factor
    affinities[:sep_chnl] /= overseg_factor
    node_labeling = compute_mws_segmentation(affinities, offs, sep_chnl, strides=strides,
                                             randomize_strides=random_strides)
    node_labeling = node_labeling - 1

    out_dir = '/g/kreshuk/hilt/projects/data/leptin_fused_tp1_ch_0/train/raw_wtsd_cpy/exs/'
    # the context manager closes the file even if writing the dataset fails
    with h5py.File(out_dir + fname, 'w') as save_file:
        save_file.create_dataset(name='data', data=node_labeling)

    # avoid 0 / 0 when everything ends up in a single segment (max label 0)
    plt.imshow(cm.prism(node_labeling / max(node_labeling.max(), 1)))
    plt.show()
def segment(self, affinities):
    """Segment ``affinities`` with the mutex watershed and relabel the result.

    Uses ``self.offsets`` and ``self.strides``. The first ``self.att_c``
    (set to 2 here) channels are attractive.

    Returns:
        ``label_cont`` applied to the int32 segmentation.
    """
    self.att_c = 2
    # use the attribute just set; the bare name ``att_c`` is undefined here
    seg = compute_mws_segmentation(affinities, self.offsets, self.att_c,
                                   strides=self.strides, mask=None).astype(np.int32)
    return label_cont(seg)
def mws_segmentation(affs, algo, strides=None):
    """Run the mutex watershed with algorithm ``algo`` and time it.

    Args:
        affs: affinities, one channel per entry of ``OFFSETS``; the first 2
            channels are attractive.
        algo: algorithm name forwarded to ``compute_mws_segmentation``.
        strides: forwarded to ``compute_mws_segmentation``.

    Returns:
        ``(seg, runtime)`` with the runtime in seconds.
    """
    separating_channel = 2
    # perf_counter is the monotonic, high-resolution clock meant for timing
    start = time.perf_counter()
    seg = compute_mws_segmentation(affs, OFFSETS, separating_channel,
                                   algorithm=algo, strides=strides)
    return seg, time.perf_counter() - start
def test_mws_reference_3d(self):
    """The mutex watershed on the stored 3d test data must reproduce the stored reference."""
    from affogato.segmentation import compute_mws_segmentation
    data_dir = os.path.split(__file__)[0]
    test_path = os.path.join(data_dir, "../../../../data/test_data_3d.h5")
    with h5py.File(test_path, "r") as f:
        affs, ref = f["affinities"][:], f["segmentation"][:]
        offsets = f.attrs["offsets"]
    seg = compute_mws_segmentation(affs, offsets, 3, strides=None)
    self._check_segmentation(ref, seg)
def _run_mws(self, input_):
    """Invert the attractive channels of ``input_`` in place and run the mutex watershed.

    The first ``self.dim`` channels are attractive; note that ``input_`` itself is modified.
    """
    assert len(input_) == len(self.offsets)
    n_attractive = self.dim
    # x -> 1 - x on the attractive channels, in place
    input_[:n_attractive] *= -1
    input_[:n_attractive] += 1
    return compute_mws_segmentation(input_, self.offsets,
                                    number_of_attractive_channels=n_attractive,
                                    strides=self.strides,
                                    randomize_strides=self.randomize_strides)
def mws_segmenter(prediction, offset_version='v2'):
    """Segment an affinity ``prediction`` with the mutex watershed.

    ``offset_version`` selects the offsets ('v1': default offsets, 'v2': MWS
    offsets). Note that ``prediction`` is modified in place: its first two
    (attractive) channels are inverted.
    """
    from affogato.segmentation import compute_mws_segmentation
    from train_affs import get_default_offsets, get_mws_offsets
    assert offset_version in ('v1', 'v2')
    if offset_version == 'v1':
        offsets = get_default_offsets()
    else:
        offsets = get_mws_offsets()
    # invert the first two (attractive) channels: x -> 1 - x
    prediction[:2] *= -1
    prediction[:2] += 1
    # TODO change this api
    return compute_mws_segmentation(prediction, offsets, 2, strides=[4, 4])
def get_graphs(img, gt, sigma, edge_offsets):
    """Build a superpixel graph for ``img`` and compare it with the ground truth.

    Naive affinities of the gaussian-smoothed image are over-segmented with
    the mutex watershed. Edge features and ground-truth edge weights are then
    computed on the superpixel graph, and a multicut of the ground-truth
    weights is plotted next to ``gt``, the superpixels and ``img``.

    Returns:
        ``(img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights,
        node_labeling, nodes, affinities)``. ``edges`` is sorted per edge and
        transposed; the other array outputs are cast to float32.
    """
    import warnings

    overseg_factor = 1.7
    sep_chnl = 2
    affinities = get_naive_affinities(gaussian(img, sigma=sigma), edge_offsets)
    # invert the attractive channels: x -> 1 - x
    affinities[:sep_chnl] *= -1
    affinities[:sep_chnl] += +1
    # scale affinities in order to get an oversegmentation
    affinities[:sep_chnl] /= overseg_factor
    affinities[sep_chnl:] *= overseg_factor
    affinities = np.clip(affinities, 0, 1)

    node_labeling = compute_mws_segmentation(affinities, edge_offsets, sep_chnl)
    node_labeling = node_labeling - 1
    nodes = np.unique(node_labeling)
    # Sanity check: node ids should be exactly 0 .. n-1. Emit the warning
    # (constructing a ``Warning`` object alone does nothing).
    if not np.array_equal(nodes, np.arange(len(nodes))):
        warnings.warn("node ids are off")

    # get edges from node labeling and edge features from affinity stats
    edge_feat, neighbors = get_edge_features_1d(node_labeling, edge_offsets, affinities)
    # get gt edge weights based on edges and gt image
    gt_edge_weights = calculate_gt_edge_costs(neighbors, node_labeling.squeeze(), gt.squeeze(), 0.5)
    # ``np.long`` was removed in NumPy 1.24; int64 is what it resolved to on Linux
    edges = neighbors.astype(np.int64)

    # calc multicut from gt
    gt_seg = get_current_soln(gt_edge_weights, node_labeling, edges)
    _, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4)
    ax1.imshow(cm.prism(gt / gt.max()))
    ax1.set_title('gt')
    ax2.imshow(cm.prism(node_labeling / node_labeling.max()))
    ax2.set_title('sp')
    ax3.imshow(cm.prism(gt_seg / gt_seg.max()))
    ax3.set_title('mc')
    ax4.imshow(img)
    ax4.set_title('raw')
    plt.show()

    affinities = affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()
    edges = np.sort(edges, axis=-1)
    edges = edges.T
    return img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, affinities
def get_node_features(self, embeddings, _, post_input=False):
    """Average ``embeddings`` over mutex watershed superpixels.

    The superpixels come from ``compute_mws_segmentation`` on the edge weights
    returned by ``self.get_edges_from_embeddings`` for a fixed set of 2d
    offsets. The first ``separating_channel`` offsets are attractive.

    Args:
        embeddings: tensor indexed as ``embeddings[channel, row, col]``.
        _: unused.
        post_input: if True and ``self.writer`` is set, log a PCA projection
            of the embeddings as a figure.

    Returns:
        ``(sp_feat_vecs, sp_similarity_reg)``. ``sp_feat_vecs`` holds one mean
        embedding row per label ``0 .. node_labeling.max()``;
        ``sp_similarity_reg`` is currently always 0.
    """
    separating_channel = 2
    # NOTE: every row needs a trailing comma -- a missing one turns
    # ``[1, -1] [-9, 0]`` into a list subscript and raises a TypeError.
    offsets = [[-1, 0], [0, -1],  # direct 2d nhood for attractive edges
               [-1, -1], [1, 1], [-1, 1], [1, -1],  # indirect 2d nhood for dam edges
               [-9, 0], [0, -9],  # long range direct hood
               [-9, -9], [9, -9], [-9, -4], [-4, -9], [4, -9], [9, -4],  # inplane diagonal dam edges
               [-27, 0], [0, -27]]
    edges = self.get_edges_from_embeddings(embeddings, offsets, separating_channel)
    node_labeling = compute_mws_segmentation(edges, offsets, separating_channel,
                                             strides=[10, 10], randomize_strides=True)

    # NOTE(review): this iterates labels starting at 0; if the segmentation
    # starts its labels at 1, label 0 is empty and the ``mass > 0`` assert
    # below fails -- confirm the label convention of this segmentation.
    stacked_superpixels = [
        node_labeling == n for n in range(node_labeling.max() + 1)
    ]
    sp_indices = [sp.nonzero().cpu() for sp in stacked_superpixels]
    sp_feat_vecs = torch.empty(
        (len(sp_indices), embeddings.shape[0])).to(self.device).float()
    sp_similarity_reg = 0
    for i, sp in enumerate(sp_indices):
        sp = sp.to(self.device)
        mass = len(sp)
        assert mass > 0
        # the last two columns of the nonzero indices are (row, col)
        sp_features = embeddings[:, sp[:, -2], sp[:, -1]].T
        sp_feat_vecs[i] = sp_features.sum(0) / mass

    if self.writer is not None and post_input:
        plt.clf()
        fig = plt.figure(frameon=False)
        plt.imshow(
            _pca_project(embeddings.detach().squeeze().cpu().numpy()))
        plt.colorbar()
        self.writer.add_figure("image/embedding_proj", fig, self.writer_counter)
        self.writer_counter += 1
    return sp_feat_vecs, sp_similarity_reg
def mutex_watershed(affs, offsets, strides,
                    randomize_strides=False, mask=None,
                    noise_level=0):
    """Compute a mutex watershed segmentation.

    The first ``len(offsets[0])`` channels are treated as attractive and are
    inverted (x -> 1 - x) before segmentation.

    Args:
        affs: input affinity map; it is not modified (a copy is processed).
        offsets: pixel offsets, one per affinity channel.
        strides: strides used to sub-sample long range edges.
        randomize_strides: randomize the strides (default: False).
        mask: mask of the pixels to segment; pixels where it is False are
            left at label 0 (default: None).
        noise_level: scale of the uniform noise in [0, noise_level) added to
            the affinities (default: 0).

    Returns:
        Segmentation with consecutive labels starting at 1 (0 stays 0 when a
        mask is given).
    """
    assert compute_mws_segmentation is not None, "Need affogato for mutex watershed"
    ndim = len(offsets[0])
    # Work on a copy: noise and inversion are applied in place and must not
    # leak into the caller's array.
    affs = affs.copy()
    if noise_level > 0:
        affs += noise_level * np.random.rand(*affs.shape)
    affs[:ndim] *= -1
    affs[:ndim] += 1
    seg = compute_mws_segmentation(affs, offsets,
                                   number_of_attractive_channels=ndim,
                                   strides=strides, mask=mask,
                                   randomize_strides=randomize_strides)
    relabelConsecutive(seg, out=seg, start_label=1, keep_zeros=mask is not None)
    return seg
def __call__(self, affinities):
    """Segment ``affinities`` with the Kruskal mutex watershed.

    If ``self.invert_affinities`` is set, ``1 - affinities`` is segmented
    instead; the input array itself is never modified.
    """
    weights = 1. - affinities if self.invert_affinities else affinities
    return compute_mws_segmentation(
        weights,
        self.offsets,
        self.seperating_channel,
        strides=self.stride,
        randomize_strides=self.randomize_bounds,
        invert_repulsive_weights=self.invert_dam_channels,
        bias_cut=0.,
        mask=None,
        algorithm='kruskal',
    )
def test_mws_masked(self):
    """Masked-out pixels must get label 0 and all other pixels a non-zero label."""
    from affogato.segmentation import compute_mws_segmentation
    n_attractive = 2
    offsets = [[-1, 0], [0, -1], [-3, 0], [0, 3], [5, 5]]
    weights = np.random.rand(len(offsets), 100, 100)

    # foreground mask with a random 10 % of the pixels switched off
    mask = np.ones((100, 100), dtype='bool')
    n_pixels = mask.size
    n_out = int(n_pixels * .1)
    mask.flat[np.random.permutation(n_pixels)[:n_out]] = False

    node_labels = compute_mws_segmentation(weights, offsets, n_attractive, mask=mask)
    self.assertEqual(weights.shape[1:], node_labels.shape)
    # every foreground pixel carries a label ...
    self.assertTrue((node_labels[mask] != 0).all())
    # ... and every background pixel is zero
    self.assertTrue((node_labels[~mask] == 0).all())
tick = time.time() segm = run_mws(affinities, offsets, [1,1,1], seperating_channel=3, randomize_bounds=False) print(segm) print("Took ", time.time() - tick) file_path_segm = os.path.join('/home/abailoni_local/', "ISBI_results_new_MWS.h5") # vigra.writeHDF5(segm.astype('uint32'), file_path_segm, 'segm') file_path = os.path.join('/home/abailoni_local/', "noise.h5") vigra.writeHDF5(affinities, file_path, 'kwargs') # Constantin implementation: tick = time.time() labels = compute_mws_segmentation(1 - affinities, offsets, 3, randomize_strides=False, algorithm='kruskal') print("Took ", time.time() - tick) # labels = compute_mws_segmentation(np.random.uniform(size=(3,1,2,2)), np.array([[-1, 0, 0], # [0, -1, 0], # [0, 0, -1]]), # 3, # randomize_strides=False, # algorithm='kruskal') print(labels) vigra.writeHDF5(labels.astype('uint32'), file_path_segm, 'segm_dMWS') # # Steffen: # configs = {'models': yaml2dict('./experiments/models_config.yml'),
def get_pix_data(shape=(256, 256)):
    """ This generates raw-gt-superpixels and corresponding rags of rectangles and circles.

    Random circles, polygons and rectangles are drawn on a noisy, textured
    background. Naive pixel affinities of the result are over-segmented with
    the mutex watershed, and edge features and ground-truth edge weights are
    computed on the superpixel graph.

    Args:
        shape: 2d image shape. NOTE(review): the coordinate grids ``x``/``y``
            below only broadcast for square shapes.

    Returns:
        ``(img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights,
        node_labeling, nodes, affinities)``.
    """
    import warnings

    def rsign():
        # random sign, +1 or -1
        return (-1)**np.random.randint(0, 2)

    def draw_disk(r, c, radius):
        # ``draw.circle`` was removed in scikit-image 0.19 in favour of ``draw.disk``
        if hasattr(draw, 'disk'):
            return draw.disk((r, c), radius, shape=shape)
        return draw.circle(r, c, radius, shape=shape)

    def object_freqs():
        # six random texture frequencies for one object (random draws in this order)
        return (rsign() * ((np.random.rand() * 4) + 7),
                rsign() * ((np.random.rand() * 4) + 7),
                (np.random.rand() + 1) * 8,
                (np.random.rand() + 1) * 8,
                rsign() * ((np.random.rand() * 4) + 7),
                rsign() * ((np.random.rand() * 4) + 7))

    def interference(rr, cc, fx, fy, freq, length):
        # sinusoidal texture on the red/green channels plus a 0.2 offset on all channels
        return np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[rr, cc] * fx)**2 + ((shape[1] - y[rr, cc]) * fy)**2) * freq * np.pi / length
        ))[..., np.newaxis] * 0.15) + 0.2

    edge_offsets = [[0, -1], [-1, 0], [-3, 0], [0, -3], [-6, 0], [0, -6]]  # offsets defining the edges for pixel affinities
    overseg_factor = 1.7
    sep_chnl = 2  # channel separating attractive from repulsive edges
    n_circles = 5  # number of ellipses in image
    n_polys = 10  # number of rand polys in image
    n_rect = 5  # number rectangles in image
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the same float64 dtype
    circle_color = np.array([1, 0, 0], dtype=float)
    rect_color = np.array([0, 0, 1], dtype=float)
    col_diff = 0.4  # by this margin object color can vary randomly
    min_r, max_r = 10, 20  # min and max radii of ellipses/circles
    min_dist = max_r

    img = np.random.randn(*(shape + (3, ))) / 5  # init image with some noise
    gt = np.zeros(shape)

    # get some random frequencies
    ri1, ri2, ri3, ri4, ri5, ri6 = rsign() * ((np.random.rand() * 2) + .5), \
        rsign() * ((np.random.rand() * 2) + .5), \
        (np.random.rand() * 4) + 3, \
        (np.random.rand() * 4) + 3, \
        rsign() * ((np.random.rand() * 2) + .5), \
        rsign() * ((np.random.rand() * 2) + .5)
    x = np.zeros(shape)
    x[:, :] = np.arange(img.shape[0])[np.newaxis, :]
    y = x.transpose()
    # add background frequency interferences
    img += (np.sin(
        np.sqrt((x * ri1)**2 + ((shape[1] - y) * ri2)**2) * ri3 * np.pi / shape[0]))[..., np.newaxis]
    img += (np.sin(
        np.sqrt((x * ri5)**2 + ((shape[1] - y) * ri6)**2) * ri4 * np.pi / shape[1]))[..., np.newaxis]
    # smooth a bit
    img = gaussian(np.clip(img, 0.1, 1), sigma=.8)

    # add some circles
    circles = []
    cmps = []
    while len(circles) < n_circles:
        mp = np.random.randint(min_r, shape[0] - min_r, 2)
        too_close = False
        for cmp in cmps:
            if np.linalg.norm(cmp - mp) < min_dist:
                too_close = True
        if too_close:
            continue
        r = np.random.randint(min_r, max_r, 2)
        circles.append(draw_disk(mp[0], mp[1], r[0]))
        cmps.append(mp)

    # add some random polygons
    polys = []
    while len(polys) < n_polys:
        mp = np.random.randint(min_r, shape[0] - min_r, 2)
        too_close = False
        for cmp in cmps:
            if np.linalg.norm(cmp - mp) < min_dist // 2:
                too_close = True
        if too_close:
            continue
        circle = draw.circle_perimeter(mp[0], mp[1], max_r)
        poly_vert = np.random.choice(len(circle[0]), np.random.randint(3, 6), replace=False)
        polys.append(
            draw.polygon(circle[0][poly_vert], circle[1][poly_vert], shape=shape))
        cmps.append(mp)

    # add some random rectangles
    rects = []
    while len(rects) < n_rect:
        mp = np.random.randint(min_r, shape[0] - min_r, 2)
        _len = np.random.randint(min_r // 2, max_r, (2, ))
        too_close = False
        for cmp in cmps:
            if np.linalg.norm(cmp - mp) < min_dist:
                too_close = True
        if too_close:
            continue
        start = (mp[0] - _len[0], mp[1] - _len[1])
        rects.append(
            draw.rectangle(start, extent=(_len[0] * 2, _len[1] * 2), shape=shape))
        cmps.append(mp)

    # draw polys and give them some noise
    for poly in polys:
        color = np.random.rand(3)
        while np.linalg.norm(color - circle_color) < col_diff or np.linalg.norm(
                color - rect_color) < col_diff:
            color = np.random.rand(3)
        img[poly[0], poly[1], :] = color
        img[poly[0], poly[1], :] += np.random.randn(len(poly[1]), 3) / 5  # add noise to the polygons

    # draw circles with some frequency
    cols = np.random.choice(np.arange(4, 11, 1).astype(float) / 10, n_circles, replace=False)  # get colors
    for i, circle in enumerate(circles):
        gt[circle[0], circle[1]] = 1 + (i / 10)
        ri1, ri2, ri3, ri4, ri5, ri6 = object_freqs()
        img[circle[0], circle[1], :] = np.array([cols[i], 0.0, 0.0])  # set even color intensity
        # set interference of two freqs in circle color channel
        img[circle[0], circle[1], :] += interference(circle[0], circle[1], ri5, ri2, ri3, shape[0])
        img[circle[0], circle[1], :] += interference(circle[0], circle[1], ri6, ri1, ri4, shape[1])

    # draw rectangles with some frequency
    cols = np.random.choice(np.arange(4, 11, 1).astype(float) / 10, n_rect, replace=False)
    for i, rect in enumerate(rects):
        gt[rect[0], rect[1]] = 2 + (i / 10)
        ri1, ri2, ri3, ri4, ri5, ri6 = object_freqs()
        img[rect[0], rect[1], :] = np.array([0.0, 0.0, cols[i]])
        img[rect[0], rect[1], :] += interference(rect[0], rect[1], ri5, ri2, ri3, shape[0])
        img[rect[0], rect[1], :] += interference(rect[0], rect[1], ri1, ri6, ri4, shape[1])

    img = np.clip(img, 0, 1)  # clip to valid range

    # get affinities and calc superpixels with mutex watershed
    affinities = get_naive_affinities(gaussian(img, sigma=.2), edge_offsets)
    affinities[:sep_chnl] *= -1
    affinities[:sep_chnl] += +1
    # scale affinities in order to get an oversegmentation
    affinities[:sep_chnl] /= overseg_factor
    affinities[sep_chnl:] *= overseg_factor
    affinities = np.clip(affinities, 0, 1)
    node_labeling = compute_mws_segmentation(affinities, edge_offsets, sep_chnl)
    node_labeling = node_labeling - 1
    nodes = np.unique(node_labeling)
    # sanity check: node ids should be exactly 0 .. n-1; actually emit the warning
    if not np.array_equal(nodes, np.arange(len(nodes))):
        warnings.warn("node ids are off")

    # get edges from node labeling and edge features from affinity stats
    edge_feat, neighbors = get_edge_features_1d(node_labeling, edge_offsets, affinities)
    # get gt edge weights based on edges and gt image
    gt_edge_weights = calculate_gt_edge_costs(neighbors, node_labeling.squeeze(), gt.squeeze())
    # ``np.long`` was removed in NumPy 1.24; int64 is what it resolved to on Linux
    edges = neighbors.astype(np.int64)

    affinities = affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()
    edges = np.sort(edges, axis=-1)
    edges = edges.T
    return img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, affinities
atr = 1 aff, offsets = reorder_and_invert(aff, offsets, atr * n_dir, dist_per_dir=n_d) aff_new = np.empty((aff.shape[0] - 12, *aff.shape[1:])) aff_new[:6] = aff[[0, 1, 2, 6, 10, 14]] aff_new[6:] = aff[18:] offsets_new = [offsets[i] for i in [0, 1, 2, 6, 10, 14]] + offsets[18:] labels = compute_mws_segmentation(aff_new[:], offsets_new[:], number_of_attractive_channels=6, strides=None, randomize_strides=False) X = labels[1] numrows, numcols = X.shape def format_coord(x, y): col = int(x + 0.5) row = int(y + 0.5) if 0 <= col < numcols and 0 <= row < numrows: z = X[row, col] return f'x={x:.1f}, y={y:.1f}, z={z}' else: return 'x=%1.4f, y=%1.4f' % (x, y)
def getFastIOUMWS(dist, S, angles, strides=None, S_attr=1, mask=None, smws=False, verbose=False):
    """Mutex watershed on ray-overlap (IoU-like) affinities.

    For each offset, the affinity between a pixel and its offset neighbour is
    an intersection-over-union of the distances both pixels report along the
    offset's ray and the ray stored ``n_rays // 2`` entries later (see
    ``getIntersection``). Two attractive offsets are used (angles 0 and 90
    degrees, length ``S_attr``), plus one repulsive offset of length ``S``
    per angle in the first half of ``angles``.

    Args:
        dist: array indexed as ``dist[ray, row, col]``.
        S: length of the repulsive offsets.
        angles: ray angles in radians, one per entry of ``dist``'s first axis.
            NOTE(review): ray ``i + n_rays // 2`` is presumably the ray
            opposite to ray ``i`` -- confirm against the producer of ``dist``.
        strides: forwarded to ``compute_mws_segmentation``.
        S_attr: length of the attractive offsets (default 1).
        mask: forwarded to ``compute_mws_segmentation``.
        smws: unused (the corresponding branch below is commented out).
        verbose: print the matched ray angles for every offset.

    Returns:
        The labels returned by ``compute_mws_segmentation`` (Kruskal).
    """
    # (row, col) offset of length ``s`` in direction ``theta`` (radians), rounded to ints
    getOffset = lambda theta, s: list(map(int, list(map(np.around, [s * np.sin(theta), s * np.cos(theta)]))))
    S = float(S)
    S_attr = float(S_attr)

    @jit(nopython=True)
    def validateCoordinates(row, col):
        # True if (row, col) lies inside the spatial extent of ``dist``
        max_r = dist.shape[1]
        max_c = dist.shape[2]
        in_r = 0 <= row < max_r
        in_c = 0 <= col < max_c
        return in_r & in_c

    @jit(nopython=True)
    def getIntersection(r1: float, _r1: float, r2: float, _r2: float, S: float) -> float:
        # Overlap of the 1d extents of two pixels ``S`` apart on one ray:
        # (r1, _r1) are the first pixel's distances along the ray and the
        # opposite ray, (r2, _r2) the neighbouring pixel's.
        inter = 0
        if r1 > S:
            if _r2 > S:
                inter += S
                inter += min(_r1, _r2 - S)
            else:
                inter += _r2
            # NOTE(review): this term (overlap beyond the neighbour) is added
            # in both ``_r2`` branches -- confirm this nesting is intended.
            inter += min(r2, r1 - S)
        else:
            if _r2 > S:
                inter += r1
                inter += min(_r1, _r2 - S)
            else:
                inter += max(r1 + _r2 - S, 0)
        return float(inter)

    # NOTE(review): ``parallel=True`` is set, but the loops use plain ``range``.
    @jit(nopython=True, parallel=True)
    def fillAffinities(dist, offset, ray_idx, S):
        # affinity map for one offset; entries whose neighbour is out of bounds stay 0
        rows = dist.shape[1]
        cols = dist.shape[2]
        affs = np.zeros((rows, cols))
        for row in range(rows):
            for col in range(cols):
                row_off = row + offset[0]
                col_off = col + offset[1]
                if validateCoordinates(row_off, col_off):
                    dists_per_pixel0 = dist[:, row, col]
                    dists_per_pixel1 = dist[:, row_off, col_off]
                    # both pixels report all-zero distances: fully attractive
                    if not np.any(dists_per_pixel0) and not np.any(dists_per_pixel1):
                        affs[row, col] = 1.0
                    else:
                        # ``n_rays`` is read from the enclosing scope
                        _dist1 = dists_per_pixel0[ray_idx]
                        _dist2 = dists_per_pixel0[ray_idx + n_rays // 2]
                        dist1 = dists_per_pixel1[ray_idx]
                        dist2 = dists_per_pixel1[ray_idx + n_rays // 2]
                        intersection = getIntersection(_dist1.item(), _dist2.item(), dist1.item(), dist2.item(), S)
                        union = _dist1.item() + _dist2.item() + dist1.item() + dist2.item() - intersection
                        # avoid division by zero
                        if union == 0:
                            iou = 0
                        else:
                            iou = intersection / union
                        affs[row, col] = iou
        return affs

    attractive_angles = [0, 90]
    n_rays = angles.shape[0]
    angl = angles[:n_rays // 2]
    offsets0 = [getOffset(np.deg2rad(angle), S_attr) for angle in attractive_angles]
    offsets1 = [getOffset(a, S) for a in angl]
    offsets = offsets0 + offsets1
    print('Generated offsets:', offsets)
    affs_attr = np.ones((len(offsets0), dist.shape[1], dist.shape[2]))
    angles_d = np.rad2deg(angles)
    # typed dict (int16 -> int16) holding the current offset for the jitted fillAffinities
    off = Dict.empty(key_type=types.int16, value_type=types.int16)
    for idx, offset in enumerate(offsets0):
        # find the ray whose angle matches the attractive offset's direction
        ray = np.rad2deg(np.arctan2(*offset))
        ray_idx = np.where(angles_d == ray)[0]
        if verbose:
            print('Angles', angles_d[ray_idx], angles_d[ray_idx + n_rays // 2])
        off[0] = offset[0]
        off[1] = offset[1]
        affs_attr[idx] = fillAffinities(dist, off, ray_idx[0], S_attr)
    affs_repul = np.zeros((len(offsets1), dist.shape[1], dist.shape[2]))
    for idx, offset in enumerate(offsets1):
        if verbose:
            print('Angles', np.rad2deg(angles[idx]), np.rad2deg(angles[idx + n_rays // 2]))
        off[0] = offset[0]
        off[1] = offset[1]
        affs_repul[idx] = fillAffinities(dist, off, idx, S)
    merged_aff = np.vstack((affs_attr, affs_repul))
    # clip to [0, 1]
    merged_aff[merged_aff > 1] = 1
    merged_aff[merged_aff < 0] = 0
    # invert the repulsive channels: x -> 1 - x
    merged_aff[len(attractive_angles):] *= -1
    merged_aff[len(attractive_angles):] += 1
    # for i in range(merged_aff.shape[0]):
    #     plt.figure()
    #     plt.imshow(merged_aff[i])
    #     plt.show()
    # if smws:
    #     mask = np.expand_dims(mask, 0)
    #     merged_aff = np.vstack((merged_aff, mask))
    #     merged_aff = np.vstack((merged_aff, 1 - mask))
    #     labels = computeSMWS(merged_aff, offsets, len(attractive_angles), stride=strides)
    # else:
    labels = compute_mws_segmentation(merged_aff, offsets, len(attractive_angles), algorithm='kruskal', strides=strides, mask=mask)
    return labels
plotting_helper(distances[:10]) affinities, offsets = reorder_and_invert(affinities, offsets, n_directions*attr_layers, dist_per_dir=len(default_distances)) print(f'affinities.shape: \t {affinities.shape} \n' f'offsets.shape: \t {np.array(offsets).shape}') affinities_new, offsets_new = exclude_some_short_edges(affinities, offsets, sampling_factor=less_attr, n_directions=n_directions*attr_layers, z_dir=compute_z) print(offsets_new) labels_new = compute_mws_segmentation(affinities_new, offsets_new, number_of_attractive_channels=number_of_attractive_channels, algorithm='kruskal') #plotting_helper(affinities_new) #plotAffinities(affinities_new[:8, 1], 'Affinities') numrows, numcols = labels_new.shape[1:] def format_maker(z_values): numcols, numrows = z_values.shape def format_coord(x, y): col = int(x + 0.5) row = int(y + 0.5) if 0 <= col < numcols and 0 <= row < numrows: z = z_values[row, col] return f'x={x:.1f}, y={y:.1f}, z={z}'
def mws_segmentation(affs):
    """Run the mutex watershed on ``affs`` and time it.

    Args:
        affs: affinities, one channel per entry of ``OFFSETS``; the first 3
            channels are attractive.

    Returns:
        ``(seg, runtime)`` with the runtime in seconds.
    """
    separating_channel = 3
    start = time.perf_counter()
    # argument order is (weights, offsets, number_of_attractive_channels);
    # passing the channel count first made the call fail
    seg = compute_mws_segmentation(affs, OFFSETS, separating_channel)
    return seg, time.perf_counter() - start
def _show_mws_debug_panel(inputs, outputs, affinities, offsets, separating_channel, strides):
    """Show mutex watershed segmentations of the prediction and target next to the raw input."""
    import matplotlib.pyplot as plt
    from matplotlib import cm

    weights = outputs.squeeze().detach().cpu().numpy()
    affs = affinities.squeeze().detach().cpu().numpy()
    # invert the repulsive channels of prediction and target, damp the attractive prediction
    weights[separating_channel:] *= -1
    weights[separating_channel:] += +1
    affs[separating_channel:] *= -1
    affs[separating_channel:] += +1
    weights[:separating_channel] /= 1.5

    ndim = len(offsets[0])
    assert all(len(off) == ndim for off in offsets)
    image_shape = weights.shape[1:]
    valid_edges = get_valid_edges(weights.shape, offsets, separating_channel, strides, False)
    node_labeling1, _, _, _ = compute_partial_mws_prim_segmentation(
        weights.ravel(), valid_edges.ravel(), offsets, separating_channel, image_shape)
    node_labeling_gt = compute_mws_segmentation(affs, offsets, separating_channel, algorithm='kruskal')
    labels = compute_mws_segmentation(weights, offsets, separating_channel, algorithm='kruskal')

    labels = labels.reshape(image_shape)
    labels1 = node_labeling1.reshape(image_shape)
    node_labeling_gt = node_labeling_gt.reshape(image_shape)

    show_seg = cm.prism(labels / labels.max())
    show_seg1 = cm.prism(labels1 / labels1.max())
    show_seg2 = cm.prism(node_labeling_gt / node_labeling_gt.max())
    show_raw = cm.gray(inputs.squeeze().detach().cpu().numpy())
    # top: Kruskal | partial Prim on the prediction; bottom: raw | Kruskal on the target
    panel = np.concatenate([
        np.concatenate([show_seg, show_seg1], axis=1),
        np.concatenate([show_raw, show_seg2], axis=1)
    ], axis=0)
    plt.imshow(panel)
    plt.show()


def trainAffPredCircles(saveToFile, device, separating_channel, offsets, strides, numEpochs=8, visualize=False):
    """Train a UNet to predict affinities on the synthetic disc dataset.

    Args:
        saveToFile: path the final model ``state_dict`` is saved to.
        device: torch device used for the model and the batches.
        separating_channel: number of attractive channels; later channels are
            repulsive.
        offsets: one offset per predicted affinity channel.
        strides: strides for long range edges, only used by the debug view.
        numEpochs: number of training epochs.
        visualize: if True, segment prediction and target with the mutex
            watershed after every step and show them (slow, debug only).
    """
    dloader = DataLoader(CustomDiscDset(length=5), batch_size=1, shuffle=True, pin_memory=True)
    print('----START TRAINING----' * 4)

    model = UNet(n_channels=1, n_classes=len(offsets), bilinear=True)
    for param in model.parameters():
        param.requires_grad = True
    criterion = nn.MSELoss()
    optim = torch.optim.Adam(model.parameters())
    # put the model on the same device as the batches (was hard-coded to CUDA)
    model.to(device)

    since = time.time()
    for epoch in range(numEpochs):
        print('Epoch {}/{}'.format(epoch, numEpochs - 1))
        print('-' * 10)
        for inputs, _, affinities in dloader:
            inputs = inputs.to(device)
            affinities = affinities.to(device)
            # zero the parameter gradients
            optim.zero_grad()
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                loss = criterion(outputs, affinities)
                loss.backward()
                optim.step()
            if visualize:
                _show_mws_debug_panel(inputs, outputs, affinities, offsets, separating_channel, strides)

    torch.save(model.state_dict(), saveToFile)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    return
from affogato.affinities import compute_multiscale_affinities, compute_affinities from affogato.segmentation import compute_mws_segmentation _, mask = compute_affinities(np.zeros_like(raw, dtype='int64'), offsets) print("Total valid edges: ", mask.sum()) # Invert affs: affs[3:] = 1. - affs[3:] import time tick = time.time() final_segm = compute_mws_segmentation(affs, offsets, number_of_attractive_channels=3, strides=None, randomize_strides=False, algorithm='seeded', mask=None, initial_coordinate=(4, 150, 150)) print("Time seeded: ", time.time() - tick) # tick = time.time() # _ = compute_mws_segmentation(affs, offsets, number_of_attractive_channels=3, # strides=None, randomize_strides=False, # mask=None, initial_coordinate=(0,0,70)) # print("Time normal: ", time.time() - tick) import segmfriends.vis as vis fig, ax = vis.get_figure(1, 2, figsize=(10, 20)) vis.plot_segm(ax[0],