def segment_volume_mc(pmaps, threshold=0.4, sigma=2.0, beta=0.6, ws=None, sp_min_size=100):
    """Segment a boundary-probability volume with a multicut over superpixels.

    When ``ws`` is not supplied, superpixels are produced by a distance-transform
    watershed on ``pmaps`` (``threshold``, ``sigma``, ``sp_min_size``). Edge costs
    are derived from the mean of ``pmaps`` along each superpixel contact, weighted
    by contact size and biased by ``beta``. The graph is partitioned with the
    Kernighan-Lin multicut solver.

    Returns:
        The superpixel volume with every superpixel id replaced by its multicut
        node label.
    """
    superpixels = ws
    if superpixels is None:
        superpixels = distance_transform_watershed(
            pmaps, threshold, sigma, min_size=sp_min_size)[0]

    region_graph = compute_rag(superpixels, 1)
    # column 0: mean edge probability, column 1: edge size
    edge_stats = nrag.accumulateEdgeMeanAndLength(region_graph, pmaps, numberOfThreads=1)
    mean_probs, contact_sizes = edge_stats[:, 0], edge_stats[:, 1]
    edge_costs = transform_probabilities_to_costs(mean_probs, edge_sizes=contact_sizes, beta=beta)

    solver_graph = nifty.graph.undirectedGraph(region_graph.numberOfNodes)
    solver_graph.insertEdges(region_graph.uvIds())
    labels = multicut_kernighan_lin(solver_graph, edge_costs)
    return nifty.tools.take(labels, superpixels)
def segment_volume(self, pmaps):
    """Segment ``pmaps`` with watershed superpixels followed by a multicut.

    Superpixels come from ``self.ws_dt_2D`` when ``self.ws_2D`` is truthy, and
    otherwise from a 3D distance-transform watershed configured by
    ``self.ws_threshold``, ``self.ws_sigma``, ``self.ws_w_sigma`` and
    ``self.ws_minsize``. Edge costs use the mean boundary probability and edge
    size, with boundary bias ``self.beta``.

    Returns:
        The superpixel volume relabelled with the multicut node labels.
    """
    if not self.ws_2D:
        # watershed in 3D
        superpixels, _ = distance_transform_watershed(
            pmaps, self.ws_threshold, self.ws_sigma,
            sigma_weights=self.ws_w_sigma, min_size=self.ws_minsize)
    else:
        # watershed slice-wise in 2D
        superpixels = self.ws_dt_2D(pmaps)

    region_graph = compute_rag(superpixels, 1)

    # edge features: column 0 = mean edge probability, column 1 = edge size
    edge_stats = nrag.accumulateEdgeMeanAndLength(
        region_graph, pmaps, numberOfThreads=1)  # DO NOT CHANGE numberOfThreads
    mean_probs = edge_stats[:, 0]
    contact_sizes = edge_stats[:, 1]

    # probabilities -> signed multicut costs
    edge_costs = transform_probabilities_to_costs(mean_probs, edge_sizes=contact_sizes, beta=self.beta)

    solver_graph = nifty.graph.undirectedGraph(region_graph.numberOfNodes)
    solver_graph.insertEdges(region_graph.uvIds())

    labels = multicut_kernighan_lin(solver_graph, edge_costs)
    return nifty.tools.take(labels, superpixels)
def get_edge_features_1d(sp_seg, offsets, affinities):
    """Compute affinity-based edge features for a single 2D superpixel image.

    The 2D problem is turned into a one-slice 3D problem:
    - the superpixel image gets a leading singleton axis;
    - the affinities get a singleton axis inserted at position 1;
    - every 2D offset gets a 0 prepended as its first component.

    Args:
        sp_seg: 2D superpixel label image.
        offsets: iterable of 2D offsets, one per affinity channel. Each offset
            may be a list, a tuple or a 1D numpy array.
        affinities: affinity maps; presumably shaped (n_channels, H, W) — TODO
            confirm against callers.

    Returns:
        The feature array returned by ``feats.compute_affinity_features`` for
        the RAG of ``sp_seg``.
    """
    # ``[0, *off]`` works for lists, tuples and numpy arrays. The former
    # ``[0] + off`` raised TypeError for tuples, and for numpy arrays silently
    # broadcast (adding 0 element-wise) instead of prepending a z-component.
    offsets_3d = [[0, *off] for off in offsets]
    rag = feats.compute_rag(np.expand_dims(sp_seg, axis=0))
    return feats.compute_affinity_features(
        rag, np.expand_dims(affinities, axis=1), offsets_3d)
def segment_mc(pred, seg, delta):
    """Multicut-merge the superpixels ``seg`` using pixel embeddings ``pred``.

    Edge probabilities come from ``embed.edge_probabilities_from_embeddings``
    with margin ``delta``. Edge sizes are taken from
    ``compute_boundary_mean_and_length`` on ``pred[0]``; only its size column
    is used.

    Returns:
        Pixel-wise segmentation obtained by projecting the multicut node labels.
    """
    region_graph = feats.compute_rag(seg)
    probs = embed.edge_probabilities_from_embeddings(pred, seg, region_graph, delta)
    # only the length column (index 1) is needed here
    sizes = feats.compute_boundary_mean_and_length(region_graph, pred[0])[:, 1]
    edge_costs = mc.transform_probabilities_to_costs(probs, edge_sizes=sizes)
    node_labels = mc.multicut_kernighan_lin(region_graph, edge_costs)
    return feats.project_node_labels_to_pixels(region_graph, node_labels)
def test_region_features(self):
    """compute_region_features returns one non-zero feature row per RAG edge."""
    from elf.segmentation.features import compute_rag, compute_region_features
    shape = (32, 128, 128)
    inp = np.random.rand(*shape).astype('float32')
    seg = self.make_seg(shape)

    uv_ids = compute_rag(seg).uvIds()
    region_feats = compute_region_features(uv_ids, inp, seg)

    self.assertEqual(len(uv_ids), len(region_feats))
    self.assertFalse(np.allclose(region_feats, 0))
def multicut_from_probas(segmentation, edges, edge_weights):
    """Run a Kernighan-Lin multicut on ``segmentation`` from per-edge probabilities.

    Args:
        segmentation: label image used to build the region adjacency graph.
        edges: iterable of (u, v) node-id pairs (numpy or torch rows). Pairs may
            be given in either orientation and may outnumber the RAG edges,
            e.g. when both directions are listed.
        edge_weights: one probability per entry of ``edges``.

    Returns:
        Pixel-wise segmentation with singleton axes squeezed.

    Raises:
        KeyError: if some RAG edge has no weight in ``edges`` (in either orientation).
    """
    rag = compute_rag(segmentation)
    # Plain-int keys so numpy and torch scalars hash and compare consistently.
    weight_of = {(int(u), int(v)): w for (u, v), w in zip(edges, edge_weights)}

    uv_ids = rag.uvIds()
    # Size by the RAG, not by ``edge_weights``: a longer ``edges`` list
    # previously left an uninitialised tail, and a shorter one raised IndexError.
    costs = np.empty(len(uv_ids))
    for i, (u, v) in enumerate(uv_ids):
        u, v = int(u), int(v)
        if (u, v) in weight_of:
            costs[i] = weight_of[(u, v)]
        elif (v, u) in weight_of:
            costs[i] = weight_of[(v, u)]
        else:
            raise KeyError(f"no edge weight given for RAG edge ({u}, {v})")

    costs = transform_probabilities_to_costs(costs)
    node_labels = multicut_kernighan_lin(rag, costs)
    return project_node_labels_to_pixels(rag, node_labels).squeeze()
def supervoxel_merging(mem, sv, beta=0.5, verbose=False):
    """Merge supervoxels ``sv`` via a multicut driven by the membrane map ``mem``.

    The edge probability is column 0 of ``compute_boundary_features``. Edge
    sizes come from ``compute_boundary_mean_and_length``. ``beta`` is the
    boundary bias for the cost transform. ``verbose`` is accepted but unused.

    Returns:
        Pixel-wise segmentation after merging.
    """
    region_graph = feats.compute_rag(sv)
    boundary_probs = feats.compute_boundary_features(region_graph, mem)[:, 0]
    boundary_sizes = feats.compute_boundary_mean_and_length(region_graph, mem)[:, 1]
    edge_costs = mc.transform_probabilities_to_costs(
        boundary_probs, edge_sizes=boundary_sizes, beta=beta)
    merged_labels = mc.multicut_kernighan_lin(region_graph, edge_costs)
    return feats.project_node_labels_to_pixels(region_graph, merged_labels)
def test_boundary_features_with_filters(self):
    """Filter-based boundary features: one non-zero row per edge, in 3D and 2D mode."""
    from elf.segmentation.features import compute_rag, compute_boundary_features_with_filters
    shape = (64, 128, 128)
    inp = np.random.rand(*shape).astype('float32')
    seg = self.make_seg(shape)
    rag = compute_rag(seg)

    # first with the default call, then with apply_2d=True
    for extra_kwargs in ({}, {'apply_2d': True}):
        filter_feats = compute_boundary_features_with_filters(rag, inp, **extra_kwargs)
        self.assertEqual(rag.numberOfEdges, len(filter_feats))
        self.assertFalse(np.allclose(filter_feats, 0))
def test_lifted_problem_from_probabilities(self):
    """lifted_problem_from_probabilities returns matching numbers of uvs and costs."""
    from elf.segmentation.features import compute_rag, lifted_problem_from_probabilities
    shape = (32, 128, 128)
    seg = self.make_seg(shape)
    rag = compute_rag(seg)

    n_classes = 3
    input_maps = [np.random.rand(*shape).astype('float32') for _ in range(n_classes)]

    threshold, depth = .5, 4
    lifted_uvs, lifted_costs = lifted_problem_from_probabilities(
        rag, seg, input_maps, threshold, depth)
    self.assertEqual(len(lifted_uvs), len(lifted_costs))
def segment_volume_lmc_from_seg(boundary_pmaps, nuclei_seg, threshold=0.4, sigma=2.0,
                                sp_min_size=100, boundary_bias=0.6):
    """Lifted-multicut segmentation of ``boundary_pmaps`` guided by ``nuclei_seg``.

    Steps:
    1. Distance-transform watershed superpixels.
    2. Local edge costs from the mean boundary probability and edge size.
    3. Lifted edges and costs from the overlap of superpixels with ``nuclei_seg``.
    4. Lifted multicut with the Kernighan-Lin solver.

    Args:
        boundary_pmaps: boundary probability map.
        nuclei_seg: segmentation used to derive lifted attractive/repulsive edges.
        threshold, sigma, sp_min_size: watershed parameters.
        boundary_bias: ``beta`` passed to ``transform_probabilities_to_costs``.
            Defaults to 0.6, the value previously hard-coded here.

    Returns:
        Pixel-wise lifted-multicut segmentation.
    """
    watershed = distance_transform_watershed(boundary_pmaps, threshold, sigma,
                                             min_size=sp_min_size)[0]

    # compute the region adjacency graph
    rag = compute_rag(watershed)

    # compute the edge costs: column 0 = mean boundary value, column 1 = edge size
    features = compute_boundary_mean_and_length(rag, boundary_pmaps)
    costs, sizes = features[:, 0], features[:, 1]

    # Transform the edge costs from [0, 1] to [-inf, inf], which is necessary for
    # the multicut. The values are interpreted as probabilities of an edge being
    # 'true' and converted via the negative log-likelihood. The costs are also
    # weighted by the size of the corresponding edge.
    # NOTE(review): an earlier comment claimed a bias *smaller* than 0.5 was chosen
    # to decrease over-segmentation, but the value used is 0.6. Confirm the
    # intended direction of ``boundary_bias``.
    costs = transform_probabilities_to_costs(costs, edge_sizes=sizes, beta=boundary_bias)

    # Lifted costs are +/- 5x the magnitude of the largest local cost,
    # presumably so that the nucleus evidence outweighs any single local edge.
    max_cost = np.abs(np.max(costs))
    lifted_uvs, lifted_costs = lifted_problem_from_segmentation(
        rag, watershed, nuclei_seg, overlap_threshold=0.2, graph_depth=4,
        same_segment_cost=5 * max_cost, different_segment_cost=-5 * max_cost)

    # solve the full lifted problem using the kernighan lin approximation introduced in
    # http://openaccess.thecvf.com/content_iccv_2015/html/Keuper_Efficient_Decomposition_of_ICCV_2015_paper.html
    node_labels = lmc.lifted_multicut_kernighan_lin(rag, costs, lifted_uvs, lifted_costs)
    lifted_segmentation = project_node_labels_to_pixels(rag, node_labels)
    return lifted_segmentation
def test_lifted_problem_from_segmentation(self):
    """Repeatedly check that lifted uvs and costs from a segmentation have equal length."""
    from elf.segmentation.features import compute_rag, lifted_problem_from_segmentation
    shape = (32, 128, 128)
    n_repeats = 5
    for _ in range(n_repeats):
        ws = self.make_seg(shape)
        rag = compute_rag(ws)
        seg = self.make_seg(shape)
        overlap_threshold, graph_depth = .5, 4
        lifted_uvs, lifted_costs = lifted_problem_from_segmentation(
            rag, ws, seg, overlap_threshold, graph_depth,
            same_segment_cost=1., different_segment_cost=-1)
        self.assertEqual(len(lifted_uvs), len(lifted_costs))
def update_env_data(env, data_iter, data_set, device, with_gt_edges=False, fe_grad=False):
    """Pull the next batch from ``data_iter`` and load it into ``env``.

    RAGs are built on the CPU from each superpixel tensor before the batch is
    moved to ``device``. Edges, edge features and ground-truth edges come from
    ``data_set.get_graphs``. ``with_gt_edges`` is accepted but unused.
    """
    raw, gt, sp_seg, indices = next(data_iter)
    rags = [compute_rag(sample_seg.numpy()) for sample_seg in sp_seg]

    raw = raw.to(device)
    gt = gt.to(device)
    sp_seg = sp_seg.to(device)

    edges, edge_feat, gt_edges = data_set.get_graphs(indices, sp_seg, device)
    env.update_data(edge_ids=edges, gt_edges=gt_edges, sp_seg=sp_seg, raw=raw, gt=gt,
                    fe_grad=fe_grad, rags=rags, edge_features=edge_feat)
def compute_graph_and_weights(path, return_edge_sizes=False):
    """Build an undirected graph and boundary edge weights from an HDF5 watershed.

    Reads the ``watershed`` and ``boundaries`` datasets from ``path``. The
    boundaries are preprocessed (values above 0.2 multiplied by 3, then clipped
    to [0, 1]). The edge weights are column 0 of ``compute_boundary_features``
    on that map. Features are recomputed on every call; the former, disabled
    (``if False``) HDF5 cache branch has been removed.

    Returns:
        ``(graph, feats)``, or ``(graph, feats, edge_sizes, z_edges, boundaries)``
        when ``return_edge_sizes`` is true.
    """
    from nifty.graph import undirectedGraph
    from elf.segmentation.features import compute_rag, compute_boundary_features, compute_z_edge_mask

    # Read-only access. The previous 'a' mode required write permission and
    # would silently create a missing file.
    with h5py.File(path, 'r') as f:
        seg = f['watershed'][:]
        boundaries = f['boundaries'][:]

    boundaries[boundaries > .2] *= 3
    boundaries = np.clip(boundaries, 0, 1)

    rag = compute_rag(seg)
    n_nodes = rag.numberOfNodes
    feats = compute_boundary_features(rag, boundaries)
    # first column: edge weight; last column: edge size
    feats, edge_sizes = feats[:, 0], feats[:, -1]
    edges = rag.uvIds()
    z_edges = compute_z_edge_mask(rag, seg)

    graph = undirectedGraph(n_nodes)
    graph.insertEdges(edges)

    if return_edge_sizes:
        return graph, feats, edge_sizes, z_edges, boundaries
    return graph, feats
separating_channel, offsets), batch_size=1, shuffle=True, pin_memory=True) neighbors, nodes, seg, gt_seg, affs, gt_affs = next(iter(dloader_disc)) offsets = [[0, 0, -1], [0, -1, 0], [0, -3, 0], [0, 0, -3]] affs = np.transpose(affs.cpu().numpy(), (1, 0, 2, 3)) gt_affs = np.transpose(gt_affs.cpu().numpy(), (1, 0, 2, 3)) seg = seg.cpu().numpy() gt_seg = gt_seg.cpu().numpy() boundary_input = np.mean(affs, axis=0) gt_boundary_input = np.mean(gt_affs, axis=0) rag = feats.compute_rag(seg) # edges rag.uvIds() [[1, 2], ...] costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0] gt_costs = calculate_gt_edge_costs(rag.uvIds(), seg.squeeze(), gt_seg.squeeze()) edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1] gt_edge_sizes = feats.compute_boundary_mean_and_length(rag, gt_boundary_input)[:, 1] costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes) gt_costs = mc.transform_probabilities_to_costs(gt_costs, edge_sizes=edge_sizes) node_labels = mc.multicut_kernighan_lin(rag, costs) gt_node_labels = mc.multicut_kernighan_lin(rag, gt_costs)
def refine_seg(raw, seeds, restrict_to_seeds=True, restrict_to_bb=False, return_intermediates=False):
    """Segment ``raw`` with a seeded multicut.

    Steps:
    1. Predict boundaries and build a stacked watershed.
    2. Compute multicut edge costs on the watershed RAG.
    3. Place ``seeds`` in a box centred in the prediction volume and map them to
       superpixels.
    4. Force edges inside one seed to be attractive and edges between
       different seeds to be repulsive, then solve the multicut.

    Args:
        raw: input image passed to ``get_prediction``.
        seeds: seed label array (not modified); negative values are treated as 0.
        restrict_to_seeds: keep only segments that contain a seed.
        restrict_to_bb: zero everything outside the seed bounding box.
        return_intermediates: also return prediction, watershed and placed seeds.

    Returns:
        ``seg``, or ``(pred, ws, seeds, seg)`` if ``return_intermediates``.
    """
    pred = get_prediction(raw, cache=False)
    n_threads = 1

    # make watershed
    ws, _ = stacked_watershed(pred, threshold=.5, sigma_seeds=1., n_threads=n_threads)
    rag = compute_rag(ws, n_threads=n_threads)
    edge_feats = compute_boundary_mean_and_length(rag, pred, n_threads=n_threads)
    edge_feats, edge_sizes = edge_feats[:, 0], edge_feats[:, 1]
    z_edges = compute_z_edge_mask(rag, ws)
    edge_costs = compute_edge_costs(edge_feats, beta=.4, weighting_scheme='xyz',
                                    edge_sizes=edge_sizes, z_edge_mask=z_edges)

    # Bounding box of the seed array, centred in the prediction volume. Use
    # start + extent: the old ``sh // 2 + ha // 2`` stop was one voxel short
    # for odd extents and made the assignment below fail.
    bb = tuple(slice(sh // 2 - ha // 2, sh // 2 - ha // 2 + ha)
               for sh, ha in zip(pred.shape, seeds.shape))

    # work on a copy so the caller's array is not modified
    seeds = seeds.copy()
    seeds[seeds < 0] = 0
    seeds = vigra.analysis.labelVolumeWithBackground(seeds.astype('uint32'))
    seed_ids = np.unique(seeds)

    # Shrink the seeds by two erosion iterations, then restore any seed that
    # the erosion removed completely.
    eroded_mask = binary_erosion(seeds, iterations=2)
    eroded_seeds = seeds.copy()
    eroded_seeds[~eroded_mask] = 0
    surviving_ids = np.unique(eroded_seeds)
    for seed_id in np.setdiff1d(seed_ids, surviving_ids):
        eroded_seeds[seeds == seed_id] = seed_id

    # NOTE(review): the eroded seeds were previously computed and then discarded
    # (the un-eroded seeds were placed); they are used here as evidently intended.
    seeds_full = np.zeros(pred.shape, dtype=seeds.dtype)
    seeds_full[bb] = eroded_seeds
    seeds = seeds_full

    # map seeds to superpixels and pin the costs of seeded edges
    seed_labels = compute_maximum_label_overlap(ws, seeds, ignore_zeros=True)
    edge_ids = rag.uvIds()
    labels_u = seed_labels[edge_ids[:, 0]]
    labels_v = seed_labels[edge_ids[:, 1]]

    seed_mask = np.logical_and(labels_u != 0, labels_v != 0)
    same_seed = np.logical_and(seed_mask, labels_u == labels_v)
    diff_seed = np.logical_and(seed_mask, labels_u != labels_v)

    max_att = edge_costs.max() + .1
    max_rep = edge_costs.min() - .1
    edge_costs[same_seed] = max_att
    edge_costs[diff_seed] = max_rep

    # run multicut
    node_labels = multicut_kernighan_lin(rag, edge_costs)
    if restrict_to_seeds:
        seed_nodes = np.unique(node_labels[seed_labels > 0])
        node_labels[~np.isin(node_labels, seed_nodes)] = 0
        vigra.analysis.relabelConsecutive(node_labels, out=node_labels)

    seg = project_node_labels_to_pixels(rag, node_labels, n_threads=n_threads)

    if restrict_to_bb:
        bb_mask = np.zeros(seg.shape, dtype='bool')
        bb_mask[bb] = 1
        seg[~bb_mask] = 0

    if return_intermediates:
        return pred, ws, seeds, seg
    return seg
def get(self, idx):
    """Generate a random synthetic sample: a textured disc on a textured background.

    ``idx`` is ignored; every call draws a new random disc. The superpixels are
    a fixed grid of squares, split further by the disc. Graph edges are the
    label pairs reached by the mutex-watershed pixel neighbourhood.

    Returns:
        edges: int64 tensor (2, 2 * n_edges); each edge appears in both directions.
        edge_feat: float32 tensor of affinity edge features.
        diff_to_gt: sum of |edge_feat[:, 0] - gt_edge_weights|.
        gt_edge_weights: float32 tensor of ground-truth edge weights.
        node_labeling: float32 tensor holding the superpixel label image.
        raw: float32 tensor of the synthetic image.
        nodes: float32 tensor of the unique node ids.
        angles: second output of ``get_stacked_node_data``.
    """
    radius = np.random.randint(max(self.shape) // 5, max(self.shape) // 3)
    mp = (np.random.randint(0 + radius, self.shape[0] - radius),
          np.random.randint(0 + radius, self.shape[1] - radius))
    # ``np.float`` was removed in NumPy 1.24; it was an alias of float64.
    data = np.zeros(shape=self.shape, dtype=np.float64)
    gt = np.zeros(shape=self.shape, dtype=np.float64)
    for y in range(self.shape[0]):
        for x in range(self.shape[1]):
            ly, lx = y - mp[0], x - mp[1]
            if (ly**2 + lx**2)**.5 <= radius:
                data[y, x] += np.sin(x * 10 * np.pi / self.shape[1])
                data[y, x] += np.sin(np.sqrt(x**2 + y**2) * 20 * np.pi / self.shape[1])
                gt[y, x] = 1
            else:
                data[y, x] += np.sin(y * 10 * np.pi / self.shape[1])
                data[y, x] += np.sin(
                    np.sqrt(x**2 + (self.shape[1] - y)**2) * 10 * np.pi / self.shape[1])
    data += 1

    # Superpixels: squares of side ``granularity``, numbered in scan order,
    # with the disc pixels offset by 1000 so the disc splits the squares.
    seg_arbitrary = np.zeros_like(data)
    square_dict = {}
    granularity = 30
    for y in range(self.shape[0]):
        for x in range(self.shape[1]):
            key = (x // granularity, y // granularity)
            if key not in square_dict:
                square_dict[key] = len(square_dict)
            seg_arbitrary[y, x] += square_dict[key]
    seg_arbitrary += gt * 1000
    # relabel consecutively 0..k-1 in sorted order of the original values
    _, consecutive = np.unique(seg_arbitrary, return_inverse=True)
    seg_arbitrary = consecutive.reshape(seg_arbitrary.shape).astype(seg_arbitrary.dtype)

    affinities = get_naive_affinities(data, offsets)
    affinities[:self.sep_chnl] *= -1
    affinities[:self.sep_chnl] += +1
    affinities[self.sep_chnl:] /= 0.2

    valid_edges = get_valid_edges((len(self.edge_offsets), ) + self.shape, self.edge_offsets,
                                  self.sep_chnl, None, False)
    # Only the pixel neighbourhood of the mutex watershed is used; its node
    # labelling is discarded in favour of the fixed superpixels.
    _, neighbors, _, _ = compute_mws_segmentation_cstm(
        affinities.ravel(), valid_edges.ravel(), offsets, self.sep_chnl, self.shape)
    node_labeling = seg_arbitrary
    # map pixel indices to superpixel labels
    neighbors = (node_labeling.ravel())[neighbors]
    nodes = np.unique(node_labeling)
    edge_feat = get_edge_features_1d(node_labeling, offsets, affinities)
    gt_edge_weights = calculate_gt_edge_costs(neighbors, node_labeling.squeeze(), gt.squeeze())

    # NOTE(review): ``raw`` was only assigned in commented-out code, so the old
    # ``raw = raw.squeeze()`` raised NameError. Build it from the synthetic image.
    raw = torch.from_numpy(data.astype(np.float32))
    # ``np.long`` was removed in NumPy 1.24; int64 yields a torch LongTensor.
    edges = torch.from_numpy(neighbors.astype(np.int64))
    edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
    nodes = torch.from_numpy(nodes.astype(np.float32))
    node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
    gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.float32))

    diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum()
    node_features, angles = get_stacked_node_data(nodes, edges, node_labeling, raw, size=[32, 32])

    edges = edges.t().contiguous()
    # append every edge in the reverse direction
    edges = torch.cat((edges, torch.stack((edges[1], edges[0]))), dim=1)
    return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles
def train(self):
    """Train the UNet2D feature extractor with a RAG-based contrastive loss.

    The loop runs over ``train_loader`` until more than
    ``self.cfg.fe.trainer.n_iterations`` steps have been taken. The final
    weights are then saved to ``<self.save_dir>/last_model.pth``. Validation
    and logging are currently disabled.
    """
    writer = SummaryWriter(logdir=self.log_dir)
    device = "cuda:0"
    wu_cfg = self.cfg.fe.trainer
    model = UNet2D(**self.cfg.fe.backbone)
    model.cuda(device)
    # NOTE(review): the training set reads ".../true_val" and the validation set
    # ".../train" — this looks swapped; confirm it is intentional.
    train_set = SpgDset("/g/kreshuk/hilt/projects/data/leptin_fused_tp1_ch_0/true_val", reorder_sp=True)
    val_set = SpgDset("/g/kreshuk/hilt/projects/data/leptin_fused_tp1_ch_0/train", reorder_sp=True)
    train_loader = DataLoader(train_set, batch_size=wu_cfg.batch_size, shuffle=True, pin_memory=True,
                              num_workers=0)
    # currently unused: the validation pass is disabled
    val_loader = DataLoader(val_set, batch_size=wu_cfg.batch_size, shuffle=True, pin_memory=True,
                            num_workers=0)

    optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
    # NOTE(review): the scheduler is never stepped, so the learning rate stays constant.
    sheduler = ReduceLROnPlateau(optimizer, patience=40, threshold=1e-4, min_lr=1e-5, factor=0.1)
    criterion = RagContrastiveWeights(delta_var=0.1, delta_dist=0.3)

    iteration = 0
    try:
        while iteration <= wu_cfg.n_iterations:
            for raw, gt, sp_seg, affinities, offs, indices in train_loader:
                raw, gt = raw.to(device), gt.to(device)

                loss_embeds = model(raw[:, :, None]).squeeze(2)
                # L2-normalise the embeddings along the channel axis
                loss_embeds = loss_embeds / (torch.norm(loss_embeds, dim=1, keepdim=True) + 1e-9)

                # RAG edges of each ground-truth segmentation, as (2, n_edges) LongTensors.
                # ``np.long`` was removed in NumPy 1.24; int64 is the LongTensor dtype.
                edges = [feats.compute_rag(seg.cpu().numpy()).uvIds() for seg in gt]
                edges = [torch.from_numpy(e.astype(np.int64)).to(device).T for e in edges]

                loss = criterion(loss_embeds, gt.long(), edges, None, 30)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print(loss.item())

                iteration += 1
                print(iteration)
                if iteration > wu_cfg.n_iterations:
                    print(self.save_dir)
                    torch.save(model.state_dict(), os.path.join(self.save_dir, "last_model.pth"))
                    break
    finally:
        # release the event-file handle even if training fails
        writer.close()
    return
torch.unique(gt_seg)), device=dev)[:, None, None] + 1)).sum(0) - 1 mask = superpixel_seg[None] == torch.unique(superpixel_seg)[:, None, None] superpixel_seg = ( mask * (torch.arange(len(torch.unique(superpixel_seg)), device=dev)[:, None, None] + 1)).sum(0) - 1 mask = pred_seg[None] == torch.unique(pred_seg)[:, None, None] pred_seg = (mask * (torch.arange(len(torch.unique(pred_seg)), device=dev)[:, None, None] + 1)).sum(0) - 1 # assert the segmentations are consecutive integers assert pred_seg.max() == len(torch.unique(pred_seg)) - 1 assert superpixel_seg.max() == len(torch.unique(superpixel_seg)) - 1 edges = torch.from_numpy( compute_rag(superpixel_seg.cpu().numpy()).uvIds().astype( np.long)).T.to(dev) dir_edges = torch.stack((torch.cat( (edges[0], edges[1])), torch.cat((edges[1], edges[0])))) gt_edges = calculate_gt_edge_costs(edges.T, superpixel_seg.squeeze(), gt_seg.squeeze(), thresh=0.5) mc_seg_old = None for itr in range(10): actions = gt_edges + torch.randn_like(gt_edges) * (itr / 10) # actions = torch.randn_like(gt_edges) # actions = torch.zeros_like(gt_edges) # actions[1:4] = 1.0 actions -= actions.min()