Ejemplo n.º 1
0
    def get_current_soln_pic(self, b):
        """Build a 2x2 colour mosaic comparing multicut solutions of sample ``b``.

        ``general.multicut_from_probas`` is run three times on sample ``b``'s
        superpixel segmentation. The edge costs are taken from:

        * the initial edge weights,
        * the actions derived from ``self.sg_current_edge_weights``,
        * the actions derived from ``self.sg_gt_edge_weights``.

        The initial edge weights are also the boundary input of every call.
        Each label image is divided by its maximum and coloured with
        ``cm.prism``. The tiles are arranged as::

            [ initial-weights MC | ground-truth MC   ]
            [ current-actions MC | input superpixels ]

        Args:
            b: index of the sample within the batch.

        Returns:
            The ``np.concatenate`` of the four colour images.
        """
        actions = self.get_batched_actions_from_global_graph(
            self.sg_current_edge_weights.view(-1))
        gt = self.get_batched_actions_from_global_graph(
            self.sg_gt_edge_weights.view(-1))

        # Range of sample b's edges inside the batched edge arrays.
        e_start, e_end = self.e_offs[b], self.e_offs[b + 1]

        # Subtract the sample's node offset so edge endpoints index into this
        # sample's superpixel labels, then transpose so that edges run along
        # the first axis.
        edge_ids = self.edge_ids[:, e_start:e_end] - self.n_offs[b]
        edge_ids = edge_ids.cpu().t().contiguous().numpy()

        # The initial weights are both the boundary input and the edge costs of
        # the first solution: transfer them to the host only once.  The copy
        # keeps the first call from aliasing the shared boundary input.
        boundary_input = self.initial_edge_weights[e_start:e_end].cpu().numpy()
        sp_seg = self.init_sp_seg[b].squeeze().cpu()

        mc_seg1 = general.multicut_from_probas(
            sp_seg, edge_ids, boundary_input.copy(), boundary_input)
        mc_seg = general.multicut_from_probas(
            sp_seg, edge_ids, actions[e_start:e_end].cpu().numpy(),
            boundary_input)
        gt_mc_seg = general.multicut_from_probas(
            sp_seg, edge_ids, gt[e_start:e_end].cpu().numpy(), boundary_input)

        # Divide each label image by its maximum before applying the colormap.
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(sp_seg / sp_seg.max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())

        left_column = np.concatenate((mc_seg1, mc_seg), 0)
        right_column = np.concatenate((gt_mc_seg, seg), 0)
        return np.concatenate((left_column, right_column), 1)
Ejemplo n.º 2
0
 def show_current_soln(self):
     """Plot multicut solutions for the current state as a 2x2 mosaic.

     ``general.multicut_from_probas`` is run on ``self.init_sp_seg`` with the
     initial, current and ground-truth edge weights. The boundary input is the
     mean of ``self.affinities`` over its first axis. Each label image is
     coloured with ``cm.prism`` and shown with ``plt.imshow``::

         [ initial-weights MC | ground-truth MC   ]
         [ current-weights MC | input superpixels ]
     """
     affs = np.expand_dims(self.affinities, axis=1)
     boundary_input = np.mean(affs, axis=0)

     # Inputs shared by all three multicut calls. edge_ids is transposed so
     # that edges run along the first axis.
     sp_seg = self.init_sp_seg.cpu()
     edge_ids = self.edge_ids.cpu().t().contiguous().numpy()

     mc_seg1 = general.multicut_from_probas(
         sp_seg, edge_ids,
         self.initial_edge_weights.squeeze().cpu().numpy(), boundary_input)
     mc_seg = general.multicut_from_probas(
         sp_seg, edge_ids,
         self.current_edge_weights.squeeze().cpu().numpy(), boundary_input)
     gt_mc_seg = general.multicut_from_probas(
         sp_seg, edge_ids,
         self.gt_edge_weights.squeeze().cpu().numpy(), boundary_input)

     # Divide each label image by its maximum before applying the colormap.
     mc_seg = cm.prism(mc_seg / mc_seg.max())
     mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
     seg = cm.prism(sp_seg / sp_seg.max())
     gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())

     left_column = np.concatenate((mc_seg1, mc_seg), 0)
     right_column = np.concatenate((gt_mc_seg, seg), 0)
     plt.imshow(np.concatenate((left_column, right_column), 1))
     plt.show()
Ejemplo n.º 3
0
        def process_file(i):
            """Rebuild the superpixel graph of sample ``fnames[i]`` and save it.

            Names not defined here (``fnames``, ``graph_dir``, ``pix_dir``,
            ``new_graph_dir``, ``new_pix_dir``, ``model``, ``device``, ``offs``,
            ``distance`` and the helper functions) come from the enclosing scope.

            Steps:

            1. Load ``raw_2chnl``/``gt`` from ``pix_dir`` and ``node_labeling``/
               ``edges``/``affinities`` from ``graph_dir``.
            2. Embed the raw image with ``model``, derive edge weights from the
               embeddings and solve a multicut on the stored superpixels.
            3. Treat the largest and second-largest multicut segments as
               background. Their pixels become 0 and 1, every other pixel keeps
               its superpixel id shifted by 2, and the labelling is then
               compacted to consecutive ids.
            4. Recompute edge features from the first four affinity channels
               and the ground-truth edge weights for the new labelling.
            5. Write the results to ``new_pix_dir`` / ``new_graph_dir``.

            Args:
                i: index into ``fnames``.
            """
            def relabel_consecutive(labels):
                # Map the sorted unique labels onto 0..K-1. This gives the same
                # result as comparing against every unique value, but without
                # materialising a K x H x W boolean mask.
                return torch.unique(labels, return_inverse=True)[1]

            fname = fnames[i]
            _, tail = os.path.split(fname)
            # Drop a 4-character prefix and the ".h5" extension to get the
            # sample number (presumably names like "pix_<num>.h5" — confirm).
            num = tail[4:-3]

            with h5py.File(os.path.join(graph_dir, "graph_" + num + ".h5"),
                           'r') as graph_file:
                sp_seg = graph_file["node_labeling"][:]
                edges = graph_file["edges"][:]
                affs = graph_file["affinities"][:]
            with h5py.File(os.path.join(pix_dir, "pix_" + num + ".h5"),
                           'r') as pix_file:
                raw = pix_file["raw_2chnl"][:]
                gt = pix_file["gt"][:]

            # Edge weights from the learned embeddings (1 - embedding affinity).
            # The model sees the raw image before the min/max normalisation below.
            with torch.no_grad():
                embeddings = model(
                    torch.from_numpy(raw).to(device).float()[None])
            emb_affs = get_affinities_from_embeddings_2d(
                embeddings, offs, 0.4, distance)[0].cpu().numpy()
            ew_embedaffs = 1 - get_edge_features_1d(sp_seg, offs,
                                                    emb_affs)[0][:, 0]
            # np.long was removed in NumPy 1.24; int64 is the intended dtype.
            mc_soln = torch.from_numpy(
                multicut_from_probas(sp_seg, edges.T,
                                     ew_embedaffs).astype(np.int64)).to(device)
            mc_soln = relabel_consecutive(mc_soln)

            # Labels are now 0..K-1, so bincount yields the pixel count of each
            # segment. The two largest segments are taken as background.
            masses = torch.bincount(mc_soln.flatten())
            bg1id = masses.argmax()
            masses[bg1id] = 0
            bg1_mask = mc_soln == bg1id
            bg2_mask = mc_soln == masses.argmax()

            sp_seg = relabel_consecutive(
                torch.from_numpy(sp_seg.astype(np.int64)).to(device))
            # Shift superpixel ids to start at 2. Then paint the largest
            # background region 0 and the second largest 1.
            sp_seg += 2
            sp_seg *= (bg1_mask == 0)
            sp_seg *= (bg2_mask == 0)
            sp_seg += bg2_mask
            sp_seg = relabel_consecutive(sp_seg).cpu()

            raw -= raw.min()
            raw /= raw.max()
            edge_feat, edges = get_edge_features_1d(sp_seg.numpy(), offs[:4],
                                                    affs[:4])
            edges = edges.astype(np.int64)

            gt_edge_weights = calculate_gt_edge_costs(
                torch.from_numpy(edges).to(device), sp_seg.to(device),
                torch.from_numpy(gt).to(device), 0.4)

            affs = affs.astype(np.float32)
            edge_feat = edge_feat.astype(np.float32)
            node_labeling = sp_seg.numpy().astype(np.float32)
            gt_edge_weights = gt_edge_weights.cpu().numpy().astype(np.float32)
            edges = edges.T

            with h5py.File(os.path.join(new_pix_dir, "pix_" + num + ".h5"),
                           'w') as new_pix_file:
                new_pix_file.create_dataset("raw_2chnl", data=raw, chunks=True)
                new_pix_file.create_dataset("gt", data=gt, chunks=True)

            with h5py.File(os.path.join(new_graph_dir, "graph_" + num + ".h5"),
                           'w') as new_graph_file:
                new_graph_file.create_dataset("edges", data=edges, chunks=True)
                new_graph_file.create_dataset("edge_feat",
                                              data=edge_feat,
                                              chunks=True)
                new_graph_file.create_dataset("gt_edge_weights",
                                              data=gt_edge_weights,
                                              chunks=True)
                new_graph_file.create_dataset("node_labeling",
                                              data=node_labeling,
                                              chunks=True)
                new_graph_file.create_dataset("affinities", data=affs,
                                              chunks=True)
                # NOTE(review): hard-coded; presumably meant to match the
                # offsets the stored affinities were computed with — confirm.
                new_graph_file.create_dataset(
                    "offsets",
                    data=np.array([[1, 0], [0, 1], [2, 0], [0, 2], [4, 0],
                                   [0, 4], [8, 0], [0, 8], [16, 0], [0, 16]]),
                    chunks=True)
Ejemplo n.º 4
0
 def get_current_soln(self):
     """Return the multicut segmentation for the current state.

     ``self.state[0]`` supplies the edge costs. The boundary input is the mean
     of ``self.affinities`` over its first axis, with the axis inserted by
     ``expand_dims`` kept.

     Returns:
         torch.Tensor: the multicut label image, cast to float64.
     """
     affs = np.expand_dims(self.affinities, axis=1)
     boundary_input = np.mean(affs, axis=0)
     mc_seg = general.multicut_from_probas(
         self.init_sp_seg.squeeze().cpu(),
         self.edge_ids.cpu().t().contiguous().numpy(),
         self.state[0].squeeze().cpu().numpy(), boundary_input)
     # np.float was an alias of the builtin float (float64) and was removed in
     # NumPy 1.24; use the explicit dtype.
     return torch.from_numpy(mc_seg.astype(np.float64))
Ejemplo n.º 5
0
    for i in range(1):
        # Read everything needed up front so both HDF5 files are closed
        # (they were previously left open on every iteration).
        with h5py.File(fnames_graph[i], 'r') as g_file:
            superpixel_seg = g_file['node_labeling'][:]
            # take initial edge features as weights
            probas = g_file['edge_feat'][:, 0]
            edges = g_file['edges'][:]
        with h5py.File(fnames_pix[i], 'r') as pix_file:
            gt_seg = pix_file['gt'][:]

        superpixel_seg = torch.from_numpy(superpixel_seg.astype(
            np.int64)).to(dev)
        gt_seg = torch.from_numpy(gt_seg.astype(np.int64)).to(dev)

        # make sure probas are probas and get a sample prediction
        probas -= probas.min()
        probas /= (probas.max() + 1e-6)
        pred_seg = multicut_from_probas(superpixel_seg.cpu().numpy(),
                                        edges.T, probas)
        pred_seg = torch.from_numpy(pred_seg.astype(np.int64)).to(dev)

        # relabel to consecutive integers: sorted unique labels -> 0..K-1
        gt_seg = torch.unique(gt_seg, return_inverse=True)[1]
        superpixel_seg = torch.unique(superpixel_seg, return_inverse=True)[1]
        pred_seg = torch.unique(pred_seg, return_inverse=True)[1]
Ejemplo n.º 6
0
    fnames_graph = sorted(glob('/g/kreshuk/hilt/projects/data/artificial_cells/graph_data/*.h5'))
    sample_idx = 42  # index of the sample used for this check

    with h5py.File(fnames_pix[sample_idx], 'r') as pix_file:
        gt = torch.from_numpy(pix_file['gt'][:]).to(dev)

    # set gt to integer labels: sorted unique values -> 0..K-1
    gt = torch.unique(gt, return_inverse=True)[1]
    # One binary mask per label; [1:] drops label 0, which should be background.
    sample_shapes = torch.zeros((int(gt.max()) + 1,) + gt.size(), device=dev).scatter_(0, gt[None], 1)[1:]

    with h5py.File(fnames_graph[sample_idx], 'r') as g_file:
        superpixel_seg = g_file['node_labeling'][:]
        # take initial edge features as weights
        probas = g_file['edge_feat'][:, 0]
        edges = g_file['edges'][:]

    # make sure probas are probas and get a sample prediction
    probas -= probas.min()
    probas /= (probas.max() + 1e-6)
    pred_seg = multicut_from_probas(superpixel_seg, edges.T, probas)
    pred_seg = torch.from_numpy(pred_seg.astype(np.int64)).to(dev)
    superpixel_seg = torch.from_numpy(superpixel_seg.astype(np.int64)).to(dev)

    # Check that the segmentations are consecutive integers. Raise instead of
    # using assert, because asserts are stripped under `python -O`.
    for seg_name, seg in (('pred_seg', pred_seg), ('superpixel_seg', superpixel_seg)):
        if int(seg.max()) != len(torch.unique(seg)) - 1:
            raise ValueError(f"{seg_name} is not labelled with consecutive integers")

    # add batch dimension
    pred_seg = pred_seg[None]
    superpixel_seg = superpixel_seg[None]

    f = ArtificialCellsReward(sample_shapes)
    rewards = f(pred_seg.long(), superpixel_seg.long())