Example #1
def segment_volume_mc(pmaps,
                      threshold=0.4,
                      sigma=2.0,
                      beta=0.6,
                      ws=None,
                      sp_min_size=100):
    if ws is None:
        ws = distance_transform_watershed(pmaps,
                                          threshold,
                                          sigma,
                                          min_size=sp_min_size)[0]

    rag = compute_rag(ws, 1)
    features = nrag.accumulateEdgeMeanAndLength(rag, pmaps, numberOfThreads=1)
    probs = features[:, 0]  # mean edge prob
    edge_sizes = features[:, 1]
    costs = transform_probabilities_to_costs(probs,
                                             edge_sizes=edge_sizes,
                                             beta=beta)
    graph = nifty.graph.undirectedGraph(rag.numberOfNodes)
    graph.insertEdges(rag.uvIds())

    node_labels = multicut_kernighan_lin(graph, costs)

    return nifty.tools.take(node_labels, ws)
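A minimal usage sketch for the function above, assuming the usual import layout of the elf/nifty ecosystem; the random pmaps volume is a stand-in for a real boundary-probability map:

import numpy as np
import nifty
import nifty.graph.rag as nrag
from elf.segmentation.watershed import distance_transform_watershed
from elf.segmentation.features import compute_rag
from elf.segmentation.multicut import (transform_probabilities_to_costs,
                                       multicut_kernighan_lin)

# toy boundary-probability volume with values in [0, 1]
pmaps = np.random.rand(16, 64, 64).astype('float32')
seg = segment_volume_mc(pmaps, threshold=0.4, beta=0.6)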
Example #2
    def segment_volume(self, pmaps):
        if self.ws_2D:
            # WS in 2D
            ws = self.ws_dt_2D(pmaps)
        else:
            # WS in 3D
            ws, _ = distance_transform_watershed(pmaps,
                                                 self.ws_threshold,
                                                 self.ws_sigma,
                                                 sigma_weights=self.ws_w_sigma,
                                                 min_size=self.ws_minsize)

        rag = compute_rag(ws, 1)
        # Computing edge features
        features = nrag.accumulateEdgeMeanAndLength(
            rag, pmaps, numberOfThreads=1)  # DO NOT CHANGE numberOfThreads
        probs = features[:, 0]  # mean edge prob
        edge_sizes = features[:, 1]
        # Prob -> edge costs
        costs = transform_probabilities_to_costs(probs,
                                                 edge_sizes=edge_sizes,
                                                 beta=self.beta)
        # Creating graph
        graph = nifty.graph.undirectedGraph(rag.numberOfNodes)
        graph.insertEdges(rag.uvIds())
        # Solving Multicut
        node_labels = multicut_kernighan_lin(graph, costs)
        return nifty.tools.take(node_labels, ws)
Example #3
def get_edge_features_1d(sp_seg, offsets, affinities):
    offsets_3d = []
    for off in offsets:
        offsets_3d.append([0] + off)

    rag = feats.compute_rag(np.expand_dims(sp_seg, axis=0))
    edge_feat = feats.compute_affinity_features(
        rag, np.expand_dims(affinities, axis=1), offsets_3d)[:, :]
    return edge_feat
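A hypothetical call for the helper above, with inputs shaped the way it expects: a 2D superpixel map and one affinity channel per 2D offset (a leading 0 is prepended to each offset inside the function):

import numpy as np
import elf.segmentation.features as feats

offsets = [[-1, 0], [0, -1]]                    # 2D offsets
sp_seg = np.random.randint(0, 20, (64, 64))     # toy (H, W) superpixel labels
affinities = np.random.rand(2, 64, 64).astype('float32')  # (n_offsets, H, W)
edge_feat = get_edge_features_1d(sp_seg, offsets, affinities)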
Example #4
def segment_mc(pred, seg, delta):
    rag = feats.compute_rag(seg)
    edge_probs = embed.edge_probabilities_from_embeddings(
        pred, seg, rag, delta)
    edge_sizes = feats.compute_boundary_mean_and_length(rag, pred[0])[:, 1]
    costs = mc.transform_probabilities_to_costs(edge_probs,
                                                edge_sizes=edge_sizes)
    mc_seg = mc.multicut_kernighan_lin(rag, costs)
    mc_seg = feats.project_node_labels_to_pixels(rag, mc_seg)
    return mc_seg
Example #5
    def test_region_features(self):
        from elf.segmentation.features import compute_rag, compute_region_features

        shape = (32, 128, 128)
        inp = np.random.rand(*shape).astype('float32')
        seg = self.make_seg(shape)
        rag = compute_rag(seg)
        uv_ids = rag.uvIds()

        feats = compute_region_features(uv_ids, inp, seg)
        self.assertEqual(len(uv_ids), len(feats))
        self.assertFalse(np.allclose(feats, 0))
Example #6
def multicut_from_probas(segmentation, edges, edge_weights):
    rag = compute_rag(segmentation)
    edge_dict = dict(zip(list(map(tuple, edges)), edge_weights))
    costs = np.empty(len(edge_weights))
    for i, neighbors in enumerate(rag.uvIds()):
        if tuple(neighbors) in edge_dict:
            costs[i] = edge_dict[tuple(neighbors)]
        else:
            costs[i] = edge_dict[(neighbors[1], neighbors[0])]
    costs = transform_probabilities_to_costs(costs)
    node_labels = multicut_kernighan_lin(rag, costs)

    return project_node_labels_to_pixels(rag, node_labels).squeeze()
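A hypothetical call for multicut_from_probas: the edge list and weights would normally come from an edge classifier; here the weights are random stand-ins aligned with the rag edges:

import numpy as np
from elf.segmentation.features import compute_rag

seg = np.random.randint(0, 20, (64, 64)).astype('uint32')  # toy superpixels
edges = compute_rag(seg).uvIds()           # one weight per rag edge
edge_weights = np.random.rand(len(edges))  # stand-in boundary probabilities
result = multicut_from_probas(seg, edges, edge_weights)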
Example #7
def supervoxel_merging(mem, sv, beta=0.5, verbose=False):

    rag = feats.compute_rag(sv)
    costs = feats.compute_boundary_features(rag, mem)[:, 0]

    edge_sizes = feats.compute_boundary_mean_and_length(rag, mem)[:, 1]
    costs = mc.transform_probabilities_to_costs(costs,
                                                edge_sizes=edge_sizes,
                                                beta=beta)

    node_labels = mc.multicut_kernighan_lin(rag, costs)
    segmentation = feats.project_node_labels_to_pixels(rag, node_labels)

    return segmentation
Example #8
    def test_boundary_features_with_filters(self):
        from elf.segmentation.features import compute_rag, compute_boundary_features_with_filters

        shape = (64, 128, 128)
        inp = np.random.rand(*shape).astype('float32')
        seg = self.make_seg(shape)
        rag = compute_rag(seg)

        feats = compute_boundary_features_with_filters(rag, inp)
        self.assertEqual(rag.numberOfEdges, len(feats))
        self.assertFalse(np.allclose(feats, 0))

        feats = compute_boundary_features_with_filters(rag, inp, apply_2d=True)
        self.assertEqual(rag.numberOfEdges, len(feats))
        self.assertFalse(np.allclose(feats, 0))
Example #9
    def test_lifted_problem_from_probabilities(self):
        from elf.segmentation.features import (
            compute_rag, lifted_problem_from_probabilities)
        shape = (32, 128, 128)
        seg = self.make_seg(shape)
        rag = compute_rag(seg)

        n_classes = 3
        input_maps = [
            np.random.rand(*shape).astype('float32') for _ in range(n_classes)
        ]

        assignment_threshold = .5
        graph_depth = 4
        lifted_uvs, lifted_costs = lifted_problem_from_probabilities(
            rag, seg, input_maps, assignment_threshold, graph_depth)
        self.assertEqual(len(lifted_uvs), len(lifted_costs))
Example #10
def segment_volume_lmc_from_seg(boundary_pmaps,
                                nuclei_seg,
                                threshold=0.4,
                                sigma=2.0,
                                sp_min_size=100):
    watershed = distance_transform_watershed(boundary_pmaps,
                                             threshold,
                                             sigma,
                                             min_size=sp_min_size)[0]

    # compute the region adjacency graph
    rag = compute_rag(watershed)

    # compute the edge costs
    features = compute_boundary_mean_and_length(rag, boundary_pmaps)
    costs, sizes = features[:, 0], features[:, 1]

    # transform the edge costs from [0, 1] to [-inf, inf], which is
    # necessary for the multicut. This is done by interpreting the values
    # as probabilities for an edge being 'true' and taking the negative log-likelihood.
    # in addition, we weight the costs by the size of the corresponding edge

    # we choose a boundary bias larger than the neutral 0.5, which makes the
    # edges more repulsive and biases the solution against false merges
    boundary_bias = .6

    costs = transform_probabilities_to_costs(costs,
                                             edge_sizes=sizes,
                                             beta=boundary_bias)
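    # (for reference: up to the size weighting and numerical clipping, the
    #  transform is assumed to follow the standard negative log-likelihood form
    #      cost = log((1 - p) / p) + log((1 - beta) / beta),
    #  so edges with low boundary probability get positive, attractive costs
    #  and edges with high probability get negative, repulsive ones)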
    max_cost = np.abs(np.max(costs))
    lifted_uvs, lifted_costs = lifted_problem_from_segmentation(
        rag,
        watershed,
        nuclei_seg,
        overlap_threshold=0.2,
        graph_depth=4,
        same_segment_cost=5 * max_cost,
        different_segment_cost=-5 * max_cost)

    # solve the full lifted problem using the Kernighan-Lin approximation introduced in
    # http://openaccess.thecvf.com/content_iccv_2015/html/Keuper_Efficient_Decomposition_of_ICCV_2015_paper.html
    node_labels = lmc.lifted_multicut_kernighan_lin(rag, costs, lifted_uvs,
                                                    lifted_costs)
    lifted_segmentation = project_node_labels_to_pixels(rag, node_labels)
    return lifted_segmentation
Example #11
    def test_lifted_problem_from_segmentation(self):
        from elf.segmentation.features import (compute_rag,
                                               lifted_problem_from_segmentation
                                               )
        shape = (32, 128, 128)
        N = 5
        for _ in range(N):
            ws = self.make_seg(shape)
            rag = compute_rag(ws)
            seg = self.make_seg(shape)

            overlap_threshold = .5
            graph_depth = 4
            lifted_uvs, lifted_costs = lifted_problem_from_segmentation(
                rag,
                ws,
                seg,
                overlap_threshold,
                graph_depth,
                same_segment_cost=1.,
                different_segment_cost=-1)
            self.assertEqual(len(lifted_uvs), len(lifted_costs))
Example #12
def update_env_data(env,
                    data_iter,
                    data_set,
                    device,
                    with_gt_edges=False,
                    fe_grad=False):
    raw, gt, sp_seg, indices = next(data_iter)
    rags = [compute_rag(sseg.numpy()) for sseg in sp_seg]
    # edges = [torch.from_numpy(rag.uvIds().astype(np.long)).T.to(device) for rag in rags]
    raw, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(device)
    edges, edge_feat, gt_edges = data_set.get_graphs(indices, sp_seg, device)
    env.update_data(edge_ids=edges,
                    gt_edges=gt_edges,
                    sp_seg=sp_seg,
                    raw=raw,
                    gt=gt,
                    fe_grad=fe_grad,
                    rags=rags,
                    edge_features=edge_feat)
Example #13
def compute_graph_and_weights(path, return_edge_sizes=False):
    from nifty.graph import undirectedGraph
    with h5py.File(path, 'a') as f:
        # feature caching is disabled; change the condition to `'features' in f`
        # to load previously stored edges and features instead of recomputing
        if False:
            edges = f['edges'][:]
            feats = f['features'][:]
            edge_sizes = f['edge_sizes'][:]
            z_edges = f['z_edges'][:]
            n_nodes = int(edges.max()) + 1

        else:
            from elf.segmentation.features import compute_rag, compute_boundary_features, compute_z_edge_mask
            seg = f['watershed'][:]
            boundaries = f['boundaries'][:]
            # boost strong boundary responses, then clip back to [0, 1]
            boundaries[boundaries > .2] *= 3
            boundaries = np.clip(boundaries, 0, 1)
            rag = compute_rag(seg)
            n_nodes = rag.numberOfNodes
            feats = compute_boundary_features(rag, boundaries)
            feats, edge_sizes = feats[:, 0], feats[:, -1]
            edges = rag.uvIds()

            z_edges = compute_z_edge_mask(rag, seg)

            # f.create_dataset('edges', data=edges)
            # f.create_dataset('edge_sizes', data=edge_sizes)
            # f.create_dataset('features', data=feats)
            # f.create_dataset('z_edges', data=z_edges)

    graph = undirectedGraph(n_nodes)
    graph.insertEdges(edges)
    if return_edge_sizes:
        return graph, feats, edge_sizes, z_edges, boundaries
    else:
        return graph, feats
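The function above reads its inputs from an HDF5 file; a toy file with the two datasets it expects could be set up as follows (file name and shapes are hypothetical):

import h5py
import numpy as np

with h5py.File('toy_problem.h5', 'w') as f:
    f.create_dataset('watershed',
                     data=np.random.randint(0, 50, (8, 64, 64)).astype('uint32'))
    f.create_dataset('boundaries',
                     data=np.random.rand(8, 64, 64).astype('float32'))

graph, feats = compute_graph_and_weights('toy_problem.h5')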
Example #14
                                          separating_channel, offsets),
                          batch_size=1,
                          shuffle=True,
                          pin_memory=True)
neighbors, nodes, seg, gt_seg, affs, gt_affs = next(iter(dloader_disc))

offsets = [[0, 0, -1], [0, -1, 0], [0, -3, 0], [0, 0, -3]]

affs = np.transpose(affs.cpu().numpy(), (1, 0, 2, 3))
gt_affs = np.transpose(gt_affs.cpu().numpy(), (1, 0, 2, 3))
seg = seg.cpu().numpy()
gt_seg = gt_seg.cpu().numpy()
boundary_input = np.mean(affs, axis=0)
gt_boundary_input = np.mean(gt_affs, axis=0)

rag = feats.compute_rag(seg)
# rag.uvIds() returns the graph edges as node-id pairs, e.g. [[1, 2], ...]

costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0]
gt_costs = calculate_gt_edge_costs(rag.uvIds(), seg.squeeze(),
                                   gt_seg.squeeze())

edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1]
gt_edge_sizes = feats.compute_boundary_mean_and_length(rag,
                                                       gt_boundary_input)[:, 1]
costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)
gt_costs = mc.transform_probabilities_to_costs(gt_costs, edge_sizes=edge_sizes)

node_labels = mc.multicut_kernighan_lin(rag, costs)
gt_node_labels = mc.multicut_kernighan_lin(rag, gt_costs)
Example #15
def refine_seg(raw,
               seeds,
               restrict_to_seeds=True,
               restrict_to_bb=False,
               return_intermediates=False):
    pred = get_prediction(raw, cache=False)

    n_threads = 1
    # make watershed
    ws, _ = stacked_watershed(pred,
                              threshold=.5,
                              sigma_seeds=1.,
                              n_threads=n_threads)
    rag = compute_rag(ws, n_threads=n_threads)
    edge_feats = compute_boundary_mean_and_length(rag,
                                                  pred,
                                                  n_threads=n_threads)
    edge_feats, edge_sizes = edge_feats[:, 0], edge_feats[:, 1]
    z_edges = compute_z_edge_mask(rag, ws)
    edge_costs = compute_edge_costs(edge_feats,
                                    beta=.4,
                                    weighting_scheme='xyz',
                                    edge_sizes=edge_sizes,
                                    z_edge_mask=z_edges)

    # make seeds and map them to edges
    bb = tuple(
        slice(sh // 2 - ha // 2, sh // 2 + ha // 2)
        for sh, ha in zip(pred.shape, seeds.shape))

    seeds[seeds < 0] = 0
    seeds = vigra.analysis.labelVolumeWithBackground(seeds.astype('uint32'))
    seed_ids = np.unique(seeds)
    seed_mask = binary_erosion(seeds, iterations=2)

    seeds_new = seeds.copy()
    seeds_new[~seed_mask] = 0
    seed_ids_new = np.unique(seeds_new)
    for seed_id in seed_ids:
        if seed_id in seed_ids_new:
            continue
        seeds_new[seeds == seed_id] = seed_id

    seeds_full = np.zeros(pred.shape, dtype=seeds.dtype)
    seeds_full[bb] = seeds
    seeds = seeds_full

    seed_labels = compute_maximum_label_overlap(ws, seeds, ignore_zeros=True)

    edge_ids = rag.uvIds()
    labels_u = seed_labels[edge_ids[:, 0]]
    labels_v = seed_labels[edge_ids[:, 1]]

    seed_mask = np.logical_and(labels_u != 0, labels_v != 0)
    same_seed = np.logical_and(seed_mask, labels_u == labels_v)
    diff_seed = np.logical_and(seed_mask, labels_u != labels_v)

    max_att = edge_costs.max() + .1
    max_rep = edge_costs.min() - .1
    edge_costs[same_seed] = max_att
    edge_costs[diff_seed] = max_rep

    # run multicut
    node_labels = multicut_kernighan_lin(rag, edge_costs)
    if restrict_to_seeds:
        seed_nodes = np.unique(node_labels[seed_labels > 0])
        node_labels[~np.isin(node_labels, seed_nodes)] = 0
        vigra.analysis.relabelConsecutive(node_labels, out=node_labels)

    seg = project_node_labels_to_pixels(rag, node_labels, n_threads=n_threads)

    if restrict_to_bb:
        bb_mask = np.zeros(seg.shape, dtype='bool')
        bb_mask[bb] = 1
        seg[~bb_mask] = 0

    if return_intermediates:
        return pred, ws, seeds, seg
    else:
        return seg
Example #16
    def get(self, idx):
        radius = np.random.randint(max(self.shape) // 5, max(self.shape) // 3)
        mp = (np.random.randint(0 + radius, self.shape[0] - radius),
              np.random.randint(0 + radius, self.shape[1] - radius))
        # mp = self.mp
        data = np.zeros(shape=self.shape, dtype=float)
        gt = np.zeros(shape=self.shape, dtype=float)
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                ly, lx = y - mp[0], x - mp[1]
                if (ly**2 + lx**2)**.5 <= radius:
                    data[y, x] += np.sin(x * 10 * np.pi / self.shape[1])
                    data[y, x] += np.sin(
                        np.sqrt(x**2 + y**2) * 20 * np.pi / self.shape[1])
                    # data[y, x] += 4
                    gt[y, x] = 1
                else:
                    data[y, x] += np.sin(y * 10 * np.pi / self.shape[1])
                    data[y, x] += np.sin(
                        np.sqrt(x**2 + (self.shape[1] - y)**2) * 10 * np.pi /
                        self.shape[1])
        data += 1
        # plt.imshow(data);plt.show()
        gt_affinities, _ = compute_affinities(gt == 1, offsets)

        seg_arbitrary = np.zeros_like(data)
        square_dict = {}
        i = 0
        granularity = 30
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                if (x // granularity, y // granularity) not in square_dict:
                    square_dict[(x // granularity, y // granularity)] = i
                    i += 1
                seg_arbitrary[y, x] += square_dict[(x // granularity,
                                                    y // granularity)]
        seg_arbitrary += gt * 1000
        i = 0
        segs = np.unique(seg_arbitrary)
        seg_arb = np.zeros_like(seg_arbitrary)
        for seg in segs:
            seg_arb += (seg_arbitrary == seg) * i
            i += 1
        seg_arbitrary = seg_arb
        rag = feats.compute_rag(np.expand_dims(seg_arbitrary, axis=0))
        neighbors = rag.uvIds()

        affinities = get_naive_affinities(data, offsets)
        # edge_feat = get_edge_features_1d(seg_arbitrary, offsets, affinities)
        # self.edge_offsets = [[1, 0], [0, 1], [1, 0], [0, 1]]
        # self.sep_chnl = 2
        # affinities = np.stack((ndimage.sobel(data, axis=0), ndimage.sobel(data, axis=1)))
        # affinities = np.concatenate((affinities, affinities), axis=0)
        affinities[:self.sep_chnl] *= -1
        affinities[:self.sep_chnl] += 1
        affinities[self.sep_chnl:] /= 0.2
        raw = torch.tensor(data).unsqueeze(0).unsqueeze(0).float()

        valid_edges = get_valid_edges((len(self.edge_offsets), ) + self.shape,
                                      self.edge_offsets, self.sep_chnl, None,
                                      False)
        node_labeling, neighbors, cutting_edges, mutexes = compute_mws_segmentation_cstm(
            affinities.ravel(), valid_edges.ravel(), offsets, self.sep_chnl,
            self.shape)
        node_labeling = node_labeling - 1
        # the MWS result is discarded here; the arbitrary grid superpixels are used instead
        node_labeling = seg_arbitrary
        # plt.imshow(cm.prism(node_labeling/node_labeling.max()));plt.show()
        # plt.imshow(data);plt.show()
        neighbors = (node_labeling.ravel())[neighbors]
        nodes = np.unique(node_labeling)
        edge_feat = get_edge_features_1d(node_labeling, offsets, affinities)

        offsets_3d = [[0, 0, -1], [0, -1, 0], [0, -3, 0], [0, 0, -3]]

        # rag = feats.compute_rag(np.expand_dims(node_labeling, axis=0))
        # edge_feat = feats.compute_affinity_features(rag, np.expand_dims(affinities, axis=1), offsets_3d)[:, :]
        # gt_edge_weights = feats.compute_affinity_features(rag, np.expand_dims(gt_affinities, axis=1), offsets_3d)[:, 0]
        gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                  node_labeling.squeeze(),
                                                  gt.squeeze())
        # gt_edge_weights = utils.calculate_naive_gt_edge_costs(neighbors, node_features).unsqueeze(-1)
        # affs = np.expand_dims(affinities, axis=1)
        # boundary_input = np.mean(affs, axis=0)
        # plt.imshow(multicut_from_probas(node_labeling, neighbors, gt_edge_weights, boundary_input));plt.show()

        edges = torch.from_numpy(neighbors.astype(np.int64))
        raw = raw.squeeze()
        edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
        nodes = torch.from_numpy(nodes.astype(np.float32))
        # gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.float32))
        # affinities = torch.from_numpy(affinities.astype(np.float32))
        # note: the predicted affinities are replaced by the ground-truth affinities here
        affinities = torch.from_numpy(gt_affinities.astype(np.float32))
        gt_affinities = torch.from_numpy(gt_affinities.astype(np.float32))
        node_labeling = torch.from_numpy(node_labeling.astype(np.float32))

        gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.float32))
        # noise = torch.randn_like(edge_feat) / 3
        # edge_feat += noise
        # edge_feat = torch.min(edge_feat, torch.ones_like(edge_feat))
        # edge_feat = torch.max(edge_feat, torch.zeros_like(edge_feat))
        diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum()

        node_features, angles = get_stacked_node_data(nodes,
                                                      edges,
                                                      node_labeling,
                                                      raw,
                                                      size=[32, 32])
        # plt.imshow(node_features.view(-1, 32));
        # plt.show()

        edges = edges.t().contiguous()
        edges = torch.cat((edges, torch.stack((edges[1], edges[0]))), dim=1)

        return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles
Example #17
    def train(self):
        writer = SummaryWriter(logdir=self.log_dir)
        device = "cuda:0"
        wu_cfg = self.cfg.fe.trainer
        model = UNet2D(**self.cfg.fe.backbone)
        model.cuda(device)
        train_set = SpgDset(
            "/g/kreshuk/hilt/projects/data/leptin_fused_tp1_ch_0/true_val",
            reorder_sp=True)
        val_set = SpgDset(
            "/g/kreshuk/hilt/projects/data/leptin_fused_tp1_ch_0/train",
            reorder_sp=True)
        # pm = StridedPatches2D(wu_cfg.patch_stride, wu_cfg.patch_shape, train_set.image_shape)
        train_loader = DataLoader(train_set,
                                  batch_size=wu_cfg.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=0)
        val_loader = DataLoader(val_set,
                                batch_size=wu_cfg.batch_size,
                                shuffle=True,
                                pin_memory=True,
                                num_workers=0)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
        scheduler = ReduceLROnPlateau(optimizer,
                                      patience=40,
                                      threshold=1e-4,
                                      min_lr=1e-5,
                                      factor=0.1)
        criterion = RagContrastiveWeights(delta_var=0.1, delta_dist=0.3)
        acc_loss = 0
        valit = 0
        iteration = 0
        best_loss = np.inf

        while iteration <= wu_cfg.n_iterations:
            for it, (raw, gt, sp_seg, affinities, offs,
                     indices) in enumerate(train_loader):
                raw, gt = raw.to(device), gt.to(device)

                loss_embeds = model(raw[:, :, None]).squeeze(2)
                loss_embeds = loss_embeds / (
                    torch.norm(loss_embeds, dim=1, keepdim=True) + 1e-9)

                edges = [
                    feats.compute_rag(seg.cpu().numpy()).uvIds() for seg in gt
                ]
                edges = [
                    torch.from_numpy(e.astype(np.int64)).to(device).T
                    for e in edges
                ]

                loss = criterion(loss_embeds, gt.long(), edges, None, 30)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                print(loss.item())
                iteration += 1
                print(iteration)
                if iteration > wu_cfg.n_iterations:
                    print(self.save_dir)
                    torch.save(model.state_dict(),
                               os.path.join(self.save_dir, "last_model.pth"))
                    break
        return
Example #18
        # relabel gt_seg to consecutive ids (same pattern as for the other segmentations below)
        mask = gt_seg[None] == torch.unique(gt_seg)[:, None, None]
        gt_seg = (mask * (torch.arange(len(
            torch.unique(gt_seg)), device=dev)[:, None, None] + 1)).sum(0) - 1
        mask = superpixel_seg[None] == torch.unique(superpixel_seg)[:, None,
                                                                    None]
        superpixel_seg = (
            mask * (torch.arange(len(torch.unique(superpixel_seg)),
                                 device=dev)[:, None, None] + 1)).sum(0) - 1
        mask = pred_seg[None] == torch.unique(pred_seg)[:, None, None]
        pred_seg = (mask *
                    (torch.arange(len(torch.unique(pred_seg)),
                                  device=dev)[:, None, None] + 1)).sum(0) - 1
        # assert the segmentations are consecutive integers
        assert pred_seg.max() == len(torch.unique(pred_seg)) - 1
        assert superpixel_seg.max() == len(torch.unique(superpixel_seg)) - 1

        edges = torch.from_numpy(
            compute_rag(superpixel_seg.cpu().numpy()).uvIds().astype(
                np.int64)).T.to(dev)
        dir_edges = torch.stack((torch.cat(
            (edges[0], edges[1])), torch.cat((edges[1], edges[0]))))

        gt_edges = calculate_gt_edge_costs(edges.T,
                                           superpixel_seg.squeeze(),
                                           gt_seg.squeeze(),
                                           thresh=0.5)

        mc_seg_old = None
        for itr in range(10):
            actions = gt_edges + torch.randn_like(gt_edges) * (itr / 10)
            # actions = torch.randn_like(gt_edges)
            # actions = torch.zeros_like(gt_edges)
            # actions[1:4] = 1.0
            actions -= actions.min()