Ejemplo n.º 1
0
def preprocess_data_1():
    """Convert raw/ground-truth HDF5 volumes into pixel and graph HDF5 files.

    For every ``raw_wtsd/*.h5`` file under each of ``tgtdir_train`` and
    ``tgtdir_val`` (module-level names defined elsewhere), this:

    * reads the ``raw`` and ``wtsd`` datasets,
    * loads ``edt/<name>_predictions.h5`` and applies a sigmoid to it,
    * runs ``run_watershed`` on the smoothed map to get superpixels,
    * computes edge features and ground-truth edge weights,
    * writes ``pix_data/pix_<i>.h5`` (``raw``, ``gt``) and
      ``graph_data/graph_<i>.h5`` (``edges``, ``edge_feat``, ``diff_to_gt``,
      ``gt_edge_weights``, ``node_labeling``, ``affinities``).

    NOTE(review): ``affs`` is never assigned in this function (the lines that
    used to produce it are commented out in the original), so the
    ``get_edge_features_1d`` call raises ``UnboundLocalError``. The intended
    affinity source must be restored before this can run — TODO confirm.
    """
    for tgt_dir in [tgtdir_train, tgtdir_val]:
        fnames = sorted(glob(os.path.join(tgt_dir, 'raw_wtsd/*.h5')))
        pix_dir = os.path.join(tgt_dir, 'pix_data')
        graph_dir = os.path.join(tgt_dir, 'graph_data')
        for i, fname in enumerate(fnames):
            # One handle for both datasets, closed as soon as they are read.
            with h5py.File(fname, 'r') as src:
                raw = src['raw'][:]
                gt = src['wtsd'][:]
            tail = os.path.basename(fname)
            hmap_fname = os.path.join(tgt_dir, 'edt',
                                      tail[:-3] + '_predictions' + '.h5')
            with h5py.File(hmap_fname, 'r') as hmap_file:
                hmap = torch.from_numpy(
                    hmap_file['predictions'][:]).squeeze()
            hmap = torch.sigmoid(hmap).numpy()

            node_labeling = run_watershed(gaussian_filter(hmap, sigma=.2),
                                          min_size=4)
            # NOTE(review): `affs` is unbound here (see docstring).
            edge_feat, edges = get_edge_features_1d(node_labeling, offs, affs)
            gt_edge_weights = calculate_gt_edge_costs(edges,
                                                      node_labeling.squeeze(),
                                                      gt.squeeze())

            # Explicit int64: `np.long` is missing in NumPy 1.24-1.26.
            edges = edges.astype(np.int64)

            affs = affs.astype(np.float32)
            edge_feat = edge_feat.astype(np.float32)
            node_labeling = node_labeling.astype(np.float32)
            gt_edge_weights = gt_edge_weights.astype(np.float32)
            # Total L1 distance between the first edge-feature column and the gt weights.
            diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()
            # Sort the two endpoints of each edge, then store transposed.
            edges = np.sort(edges, axis=-1)
            edges = edges.T

            graph_fname = os.path.join(graph_dir, "graph_" + str(i) + ".h5")
            pix_fname = os.path.join(pix_dir, "pix_" + str(i) + ".h5")
            # Context managers close both files even if a write fails.
            with h5py.File(graph_fname, 'w') as graph_file, \
                    h5py.File(pix_fname, 'w') as pix_file:
                pix_file.create_dataset("raw", data=raw, chunks=True)
                pix_file.create_dataset("gt", data=gt, chunks=True)

                graph_file.create_dataset("edges", data=edges, chunks=True)
                graph_file.create_dataset("edge_feat",
                                          data=edge_feat,
                                          chunks=True)
                graph_file.create_dataset("diff_to_gt", data=diff_to_gt)
                graph_file.create_dataset("gt_edge_weights",
                                          data=gt_edge_weights,
                                          chunks=True)
                graph_file.create_dataset("node_labeling",
                                          data=node_labeling,
                                          chunks=True)
                graph_file.create_dataset("affinities", data=affs, chunks=True)
Ejemplo n.º 2
0
def get_graphs(img, gt, sigma, edge_offsets):
    """Compute a superpixel graph for ``img`` and plot it against ``gt``.

    Naive affinities of the Gaussian-smoothed image are rescaled, so that the
    first ``sep_chnl`` channels are inverted and shrunk and the remaining
    channels are enlarged, and then clipped to [0, 1]. They are segmented with
    ``compute_mws_segmentation`` to obtain superpixels (``node_labeling``,
    shifted to start at 0). Edges and edge features come from
    ``get_edge_features_1d``, ground-truth edge weights from
    ``calculate_gt_edge_costs`` (threshold 0.5), and a multicut of those gt
    weights from ``get_current_soln``. The gt, superpixel, multicut and raw
    images are shown with a blocking ``plt.show()``.

    Args:
        img: input image passed to ``gaussian``.
        gt: ground-truth label image.
        sigma: Gaussian smoothing sigma applied before computing affinities.
        edge_offsets: pixel offsets defining the affinity channels.

    Returns:
        ``(img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights,
        node_labeling, nodes, affinities)``. Floating outputs are float32.
        ``edges`` has each edge's endpoints sorted and is transposed.
    """
    import warnings

    overseg_factor = 1.7
    sep_chnl = 2

    affinities = get_naive_affinities(gaussian(img, sigma=sigma), edge_offsets)
    affinities[:sep_chnl] *= -1
    affinities[:sep_chnl] += +1
    # scale affinities in order to get an oversegmentation
    affinities[:sep_chnl] /= overseg_factor
    affinities[sep_chnl:] *= overseg_factor
    affinities = np.clip(affinities, 0, 1)
    node_labeling = compute_mws_segmentation(affinities, edge_offsets,
                                             sep_chnl)
    node_labeling = node_labeling - 1
    nodes = np.unique(node_labeling)
    # The original built a `Warning(...)` object without raising or emitting
    # it, so the check was silent. Emit a real (non-fatal) warning instead.
    if not np.array_equal(nodes, np.arange(len(nodes))):
        warnings.warn("node ids are off")

    # get edges from node labeling and edge features from affinity stats
    edge_feat, neighbors = get_edge_features_1d(node_labeling, edge_offsets,
                                                affinities)
    # get gt edge weights based on edges and gt image
    gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                              node_labeling.squeeze(),
                                              gt.squeeze(), 0.5)
    # Explicit int64: `np.long` is missing in NumPy 1.24-1.26.
    edges = neighbors.astype(np.int64)

    # calc multicut from gt
    gt_seg = get_current_soln(gt_edge_weights, node_labeling, edges)

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4)
    ax1.imshow(cm.prism(gt / gt.max()))
    ax1.set_title('gt')
    ax2.imshow(cm.prism(node_labeling / node_labeling.max()))
    ax2.set_title('sp')
    ax3.imshow(cm.prism(gt_seg / gt_seg.max()))
    ax3.set_title('mc')
    ax4.imshow(img)
    ax4.set_title('raw')
    plt.show()

    affinities = affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()

    edges = np.sort(edges, axis=-1)
    edges = edges.T

    return img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, affinities
Ejemplo n.º 3
0
def get_sp_graph(data, gt, scal=1.01):
    """Build a superpixel graph for ``data`` guided by the mask ``gt == 1``.

    Naive affinities of ``data`` and affinities of the binary mask
    ``gt == 1`` are computed for four fixed offsets. In both maps the
    repulsive channels (index ``>= sep_chnl``) are inverted (``x -> 1 - x``),
    and the data's repulsive channels are then scaled by ``scal``. The gt
    affinities are blended in as ``aff * (1 - gt_aff) + gt_aff``, and the
    result is clipped to [0, 1] and segmented with
    ``compute_mws_segmentation_cstm``.

    Args:
        data: image passed to ``affutils.get_naive_affinities``.
        gt: label image. Only ``gt == 1`` is used for the affinity blend,
            while the full ``gt`` is used for the gt edge costs.
        scal: multiplier for the (inverted) repulsive data affinities.

    Returns:
        ``(edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling,
        nodes, noisy_affinities)``. Floating outputs are float32.
        ``edges`` has each edge's endpoints sorted and is transposed.
    """
    import warnings

    offsets = [[0, -1], [-1, 0], [-3, 0], [0, -3]]
    sep_chnl = 2

    affinities = affutils.get_naive_affinities(data, offsets)
    # The segmentation grid must match the affinity maps' spatial extent
    # (previously hard-coded to (128, 128)).
    shape = tuple(affinities.shape[1:])

    gt_affinities, _ = compute_affinities(gt == 1, offsets)
    gt_affinities[sep_chnl:] *= -1
    gt_affinities[sep_chnl:] += +1
    affinities[sep_chnl:] *= -1
    affinities[sep_chnl:] += +1
    affinities[sep_chnl:] *= scal
    affinities = (affinities - (affinities * gt_affinities)) + gt_affinities

    affinities = affinities.clip(0, 1)

    valid_edges = get_valid_edges((len(offsets), ) + shape, offsets, sep_chnl,
                                  None, False)
    node_labeling, neighbors, cutting_edges, mutexes = compute_mws_segmentation_cstm(
        affinities.ravel(), valid_edges.ravel(), offsets, sep_chnl, shape)
    node_labeling = node_labeling - 1

    nodes = np.unique(node_labeling)
    # `Warning(...)` alone never emitted anything; warn for real.
    if not np.array_equal(nodes, np.arange(len(nodes))):
        warnings.warn("node ids are off")

    # The original drew a random array here and immediately overwrote it;
    # the "noisy" affinities are simply the blended affinities.
    noisy_affinities = affinities

    edge_feat, neighbors = get_edge_features_1d(node_labeling, offsets,
                                                noisy_affinities)
    gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                              node_labeling.squeeze(),
                                              gt.squeeze())

    # Explicit int64: `np.long` is missing in NumPy 1.24-1.26.
    edges = neighbors.astype(np.int64)
    noisy_affinities = noisy_affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()

    edges = np.sort(edges, axis=-1)
    edges = edges.T

    return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, noisy_affinities
Ejemplo n.º 4
0
    def read_from_h5(self, f_name):
        """Load one stored sample and rebuild its graph from the stored maps.

        Reads ``diff_to_gt``, ``node_labeling``, ``raw``, ``noisy_affinities``
        and ``gt`` from the HDF5 file ``f_name``. Superpixel ids are remapped
        to consecutive integers ``0..K-1`` in sorted order of the stored ids.
        Edges, edge features and ground-truth edge weights are then recomputed
        from ``noisy_affinities`` using ``self.offsets``.

        Returns:
            If ``self.no_suppix``: ``(raw[None], gt)`` with ``gt`` as int64.
            Else if ``self.no_rl``: ``(raw[None], gt, gt_edge_weights)`` with
            ``gt`` and ``gt_edge_weights`` as int64.
            Otherwise: ``(edges, edge_feat, diff_to_gt, gt_edge_weights,
            node_labeling, raw, nodes, noisy_affinities, gt)``, where ``edges``
            contains every edge in both directions (``(2, 2*E)``, assuming
            ``get_edge_features_1d`` returns neighbors as ``(E, 2)`` — TODO
            confirm).
        """
        # Only the datasets that are used are read. The stored edges, edge
        # features and gt edge weights are recomputed below anyway.
        with h5py.File(f_name, 'r') as h5file:
            diff_to_gt = h5file['diff_to_gt'][()]
            node_labeling = h5file['node_labeling'][:]
            raw = h5file['raw'][:]
            noisy_affinities = h5file['noisy_affinities'][:]
            gt = h5file['gt'][:]

        if self.no_suppix:
            # Pixel-only mode needs neither the relabeling nor the graph.
            raw = torch.from_numpy(raw).float()
            return raw.unsqueeze(0), torch.from_numpy(gt.astype(np.int64))

        # Map stored ids to consecutive 0..K-1 in one vectorized pass
        # (replaces a per-label loop over the whole image).
        _, inverse = np.unique(node_labeling, return_inverse=True)
        node_labeling = inverse.reshape(node_labeling.shape).astype(
            node_labeling.dtype)

        edge_feat, neighbors = get_edge_features_1d(node_labeling,
                                                    self.offsets,
                                                    noisy_affinities)
        gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                  node_labeling.squeeze(),
                                                  gt.squeeze()).astype(
                                                      np.float32)

        if self.no_rl:
            raw = torch.from_numpy(raw).float()
            gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.int64))
            return raw.unsqueeze(0), torch.from_numpy(gt.astype(
                np.int64)), gt_edge_weights

        # Explicit int64 (`np.long` is missing in NumPy 1.24-1.26); append
        # the reversed copy of every edge so both directions are present.
        edges = torch.from_numpy(neighbors.astype(np.int64))
        edges = edges.t().contiguous()
        edges = torch.cat((edges, torch.stack((edges[1], edges[0]))), dim=1)

        nodes = np.unique(node_labeling)

        print('imbalance: ',
              abs(gt_edge_weights.sum() - (len(gt_edge_weights) / 2)))

        return edges, torch.from_numpy(edge_feat).float(), diff_to_gt, torch.from_numpy(gt_edge_weights), \
               torch.from_numpy(node_labeling), torch.from_numpy(raw).float(), torch.from_numpy(nodes), \
               torch.from_numpy(noisy_affinities).float(), torch.from_numpy(gt)
Ejemplo n.º 5
0
    def get(self, idx):
        """Generate one synthetic single-disc image and its superpixel graph.

        ``idx`` is not used; every call draws a new random sample. A disc of
        random radius and centre is textured with one sine pattern and the
        background with another. ``gt`` is 1 inside the disc and 0 elsewhere.
        The superpixels are a fixed 30-pixel grid split along the disc
        boundary, relabeled to ``0..K-1``. Edge features come from
        ``get_edge_features_1d`` on naive affinities of the image, and gt edge
        weights from ``calculate_gt_edge_costs``.

        NOTE(review): affinities use the module-level ``offsets`` rather than
        ``self.edge_offsets``. Confirm these are meant to be the same.

        Returns:
            ``(edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling,
            raw, nodes, angles)``. ``edges`` contains every edge in both
            directions, and ``angles`` comes from ``get_stacked_node_data``.
        """
        radius = np.random.randint(max(self.shape) // 5, max(self.shape) // 3)
        mp = (np.random.randint(0 + radius, self.shape[0] - radius),
              np.random.randint(0 + radius, self.shape[1] - radius))
        data = np.zeros(shape=self.shape, dtype=float)
        gt = np.zeros(shape=self.shape, dtype=float)
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                ly, lx = y - mp[0], x - mp[1]
                if (ly**2 + lx**2)**.5 <= radius:
                    data[y, x] += np.sin(x * 10 * np.pi / self.shape[1])
                    data[y, x] += np.sin(
                        np.sqrt(x**2 + y**2) * 20 * np.pi / self.shape[1])
                    gt[y, x] = 1
                else:
                    data[y, x] += np.sin(y * 10 * np.pi / self.shape[1])
                    data[y, x] += np.sin(
                        np.sqrt(x**2 + (self.shape[1] - y)**2) * 10 * np.pi /
                        self.shape[1])
        data += 1

        # Grid over-segmentation: one id per granularity x granularity square,
        # numbered in order of first appearance (row-major scan), then offset
        # by 1000 inside the disc so squares are split along its boundary.
        seg_arbitrary = np.zeros_like(data)
        square_ids = {}
        granularity = 30
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                key = (x // granularity, y // granularity)
                if key not in square_ids:
                    square_ids[key] = len(square_ids)
                seg_arbitrary[y, x] += square_ids[key]
        seg_arbitrary += gt * 1000
        # Relabel to consecutive ids 0..K-1 (sorted order of the raw ids).
        _, inverse = np.unique(seg_arbitrary, return_inverse=True)
        node_labeling = inverse.reshape(seg_arbitrary.shape).astype(
            seg_arbitrary.dtype)

        affinities = get_naive_affinities(data, offsets)
        affinities[:self.sep_chnl] *= -1
        affinities[:self.sep_chnl] += +1
        affinities[self.sep_chnl:] /= 0.2

        raw = torch.tensor(data).float().squeeze()
        nodes = np.unique(node_labeling)
        # get_edge_features_1d returns (edge_feat, neighbors), as at every
        # other call site. The original bound the whole tuple to `edge_feat`,
        # which then broke `.astype` below. Using its neighbors keeps the
        # edge-feature rows aligned with the gt edge weights.
        edge_feat, neighbors = get_edge_features_1d(node_labeling, offsets,
                                                    affinities)
        gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                  node_labeling.squeeze(),
                                                  gt.squeeze())

        # Explicit int64: `np.long` is missing in NumPy 1.24-1.26.
        edges = torch.from_numpy(neighbors.astype(np.int64))
        edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
        nodes = torch.from_numpy(nodes.astype(np.float32))
        node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
        gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.float32))
        diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum()

        node_features, angles = get_stacked_node_data(nodes,
                                                      edges,
                                                      node_labeling,
                                                      raw,
                                                      size=[32, 32])

        # Append the reversed copy of every edge so both directions exist.
        edges = edges.t().contiguous()
        edges = torch.cat((edges, torch.stack((edges[1], edges[0]))), dim=1)

        return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles
Ejemplo n.º 6
0
    def get(self, idx):
        """Generate one synthetic multi-disc image and its superpixel graph.

        ``idx`` is not used; every call draws a new random sample.
        ``np.random.randint(8, 10)`` discs are rejection-sampled so that
        centre distances exceed the sum of radii plus 2. Discs and background
        get different cosine textures. ``gt`` is 1 on discs and 0 elsewhere.
        Naive affinities have their repulsive channels inverted and scaled,
        are blended with the gt affinities, clipped, and segmented with
        ``compute_mws_segmentation_cstm``.

        Returns:
            If ``self.less``: ``(raw[None], node_labeling, gt, gt_edge_weights,
            edges)``.
            Else if ``self.no_suppix``: ``(raw[None], gt)`` with ``gt`` int64.
            Otherwise: ``(edges, edge_feat, diff_to_gt, gt_edge_weights,
            node_labeling, raw, nodes, noisy_affinities, gt)``.
            In every case ``edges`` contains every edge in both directions.
        """
        import warnings

        n_disc = np.random.randint(8, 10)
        rads = []
        mps = []
        for _ in range(n_disc):
            radius = np.random.randint(
                max(self.shape) // 18,
                max(self.shape) // 15)
            # Resample the centre until it does not touch any previous disc.
            while True:
                mp = np.array([
                    np.random.randint(0 + radius, self.shape[0] - radius),
                    np.random.randint(0 + radius, self.shape[1] - radius)
                ])
                if not any(((mp - other_mp)**2).sum()**.5 <=
                           radius + other_rad + 2
                           for other_rad, other_mp in zip(rads, mps)):
                    break
            rads.append(radius)
            mps.append(mp)

        data = np.zeros(shape=self.shape, dtype=float)
        gt = np.zeros(shape=self.shape, dtype=float)
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                bg = True
                for radius, mp in zip(rads, mps):
                    ly, lx = y - mp[0], x - mp[1]
                    if (ly**2 + lx**2)**.5 <= radius:
                        data[y, x] += np.cos(
                            np.sqrt((x - self.shape[1])**2 + y**2) * 50 *
                            np.pi / self.shape[1])
                        data[y, x] += np.cos(
                            np.sqrt(x**2 + y**2) * 50 * np.pi / self.shape[1])
                        gt[y, x] = 1
                        bg = False
                if bg:
                    data[y, x] += np.cos(y * 40 * np.pi / self.shape[0])
                    data[y, x] += np.cos(
                        np.sqrt(x**2 + (self.shape[0] - y)**2) * 30 * np.pi /
                        self.shape[1])
        data += 1

        # Pixel-only mode needs no superpixels or graph, so return before that
        # work. `self.less` keeps precedence, as in the original ordering.
        if self.no_suppix and not self.less:
            raw = torch.from_numpy(data).float()
            return raw.unsqueeze(0), torch.from_numpy(
                np.squeeze(gt).astype(np.int64))

        affinities = affutils.get_naive_affinities(data, self.offsets)
        gt_affinities, _ = compute_affinities(gt == 1, self.offsets)
        # Invert repulsive channels (x -> 1 - x) in both maps.
        gt_affinities[self.sep_chnl:] *= -1
        gt_affinities[self.sep_chnl:] += +1
        affinities[self.sep_chnl:] *= -1
        affinities[self.sep_chnl:] += +1
        affinities[self.sep_chnl:] *= 1.01
        # Blend: aff * (1 - gt_aff) + gt_aff.
        affinities = (affinities -
                      (affinities * gt_affinities)) + gt_affinities
        affinities = affinities.clip(0, 1)

        valid_edges = get_valid_edges((len(self.offsets), ) + self.shape,
                                      self.offsets, self.sep_chnl, None, False)
        node_labeling, neighbors, cutting_edges, mutexes = compute_mws_segmentation_cstm(
            affinities.ravel(), valid_edges.ravel(), self.offsets,
            self.sep_chnl, self.shape)
        node_labeling = node_labeling - 1

        nodes = np.unique(node_labeling)
        # `Warning(...)` alone never emitted anything; warn for real.
        if not np.array_equal(nodes, np.arange(len(nodes))):
            warnings.warn("node ids are off")

        # The original drew a random array and immediately overwrote it with
        # `affinities`; the "noisy" affinities are just `affinities`.
        edge_feat, neighbors = get_edge_features_1d(node_labeling,
                                                    self.offsets,
                                                    affinities)
        gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                  node_labeling.squeeze(),
                                                  gt.squeeze())

        # Explicit int64: `np.long` is missing in NumPy 1.24-1.26.
        edges = torch.from_numpy(neighbors.astype(np.int64))
        edges = edges.t().contiguous()
        # Append the reversed copy of every edge so both directions exist.
        edges = torch.cat((edges, torch.stack((edges[1], edges[0]))), dim=1)

        if self.less:
            raw = torch.from_numpy(data).float()
            node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
            gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.int64))
            return raw.unsqueeze(0), node_labeling, torch.from_numpy(
                gt.astype(np.int64)), gt_edge_weights, edges

        gt = torch.from_numpy(gt.astype(np.float32)).squeeze().float()
        raw = torch.tensor(data).squeeze().float()
        noisy_affinities = torch.tensor(affinities).squeeze().float()
        edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
        nodes = torch.from_numpy(nodes.astype(np.float32))
        node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
        gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.float32))
        diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum().item()

        return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, noisy_affinities, gt
Ejemplo n.º 7
0
def get_pix_data(length=50000, shape=(128, 128), radius=72):
    """Generate a synthetic RGB image with shapes, plus its superpixel graph.

    A 256x256 noisy background with two sine interference patterns is
    smoothed, and then overlaid with:

    * ``n_polys`` random-colour polygons (not labelled in ``gt``),
    * ``n_ellips`` textured discs (``gt`` labels ``1.0, 1.1, ...``),
    * ``n_rect`` textured rectangles (``gt`` labels ``2.0, 2.1, ...``).

    Superpixels come from ``compute_mws_segmentation_cstm`` on rescaled naive
    affinities. Edges and edge features come from ``get_edge_features_1d``,
    gt edge weights from ``calculate_gt_edge_costs``, and a multicut of the
    gt weights from ``get_current_soln``. The gt, superpixel and multicut
    images are shown with a blocking ``plt.show()``.

    NOTE(review): ``length``, ``shape`` and ``radius`` are accepted but
    unused; the size is fixed by ``dim``. ``draw.circle`` was removed from
    newer scikit-image releases in favour of ``draw.disk`` — confirm the
    pinned version.

    Returns:
        ``(img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights,
        node_labeling, nodes, affinities)``. Floating outputs are float32.
        ``edges`` has each edge's endpoints sorted and is transposed.
    """
    import warnings

    dim = (256, 256)
    edge_offsets = [[0, -1], [-1, 0], [-3, 0], [0, -3], [-6, 0], [0, -6]]
    sep_chnl = 2
    n_ellips = 5
    n_polys = 10
    n_rect = 5
    # `np.float` was just an alias of the builtin `float` (removed in NumPy 1.24).
    ellips_color = np.array([1, 0, 0], dtype=float)
    rect_color = np.array([0, 0, 1], dtype=float)
    col_diff = 0.4
    min_r, max_r = 10, 20
    min_dist = max_r

    def rsign():
        # Sign of a draw from [-100, 100): -1, +1, or (rarely) 0.
        return np.sign(np.random.randint(-100, 100))

    def object_freqs():
        # Six random frequency factors used to texture one object. The
        # draws happen left to right, matching the original expression.
        return (rsign() * ((np.random.rand() * 4) + 7),
                rsign() * ((np.random.rand() * 4) + 7),
                (np.random.rand() + 1) * 3,
                (np.random.rand() + 1) * 3,
                rsign() * ((np.random.rand() * 4) + 7),
                rsign() * ((np.random.rand() * 4) + 7))

    img = np.random.randn(*(dim + (3, ))) / 5
    gt = np.zeros(dim)

    # Background frequencies (same draw order as the original tuple).
    ri1, ri2, ri3, ri4, ri5, ri6 = (rsign() * ((np.random.rand() * 2) + .5),
                                    rsign() * ((np.random.rand() * 2) + .5),
                                    (np.random.rand() * 4) + 3,
                                    (np.random.rand() * 4) + 3,
                                    rsign() * ((np.random.rand() * 2) + .5),
                                    rsign() * ((np.random.rand() * 2) + .5))
    x = np.zeros(dim)
    x[:, :] = np.arange(img.shape[0])[np.newaxis, :]
    y = x.transpose()
    img += (np.sin(
        np.sqrt((x * ri1)**2 + ((dim[1] - y) * ri2)**2) * ri3 * np.pi /
        dim[0]))[..., np.newaxis]
    img += (np.sin(
        np.sqrt((x * ri5)**2 + ((dim[1] - y) * ri6)**2) * ri4 * np.pi /
        dim[1]))[..., np.newaxis]
    img = gaussian(np.clip(img, 0.1, 1), sigma=.8)

    circles = []
    cmps = []
    while len(circles) < n_ellips:
        mp = np.random.randint(min_r, dim[0] - min_r, 2)
        too_close = False
        for cmp in cmps:
            if np.linalg.norm(cmp - mp) < min_dist:
                too_close = True
        if too_close:
            continue
        r = np.random.randint(min_r, max_r, 2)
        circles.append(draw.circle(mp[0], mp[1], r[0], shape=dim))
        cmps.append(mp)

    polys = []
    while len(polys) < n_polys:
        mp = np.random.randint(min_r, dim[0] - min_r, 2)
        too_close = False
        for cmp in cmps:
            if np.linalg.norm(cmp - mp) < min_dist // 2:
                too_close = True
        if too_close:
            continue
        circle = draw.circle_perimeter(mp[0], mp[1], max_r)
        poly_vert = np.random.choice(len(circle[0]),
                                     np.random.randint(3, 6),
                                     replace=False)
        polys.append(
            draw.polygon(circle[0][poly_vert], circle[1][poly_vert],
                         shape=dim))
        cmps.append(mp)

    rects = []
    while len(rects) < n_rect:
        mp = np.random.randint(min_r, dim[0] - min_r, 2)
        _len = np.random.randint(min_r // 2, max_r, (2, ))
        too_close = False
        for cmp in cmps:
            if np.linalg.norm(cmp - mp) < min_dist:
                too_close = True
        if too_close:
            continue
        start = (mp[0] - _len[0], mp[1] - _len[1])
        rects.append(
            draw.rectangle(start, extent=(_len[0] * 2, _len[1] * 2),
                           shape=dim))
        cmps.append(mp)

    # Polygons get a random colour kept away from the disc/rectangle colours.
    for poly in polys:
        color = np.random.rand(3)
        while np.linalg.norm(color -
                             ellips_color) < col_diff or np.linalg.norm(
                                 color - rect_color) < col_diff:
            color = np.random.rand(3)
        img[poly[0], poly[1], :] = color
        img[poly[0], poly[1], :] += np.random.randn(len(poly[1]), 3) / 5

    cols = np.random.choice(np.arange(4, 11, 1).astype(float) / 10,
                            n_ellips,
                            replace=False)
    for i, ellipse in enumerate(circles):
        gt[ellipse[0], ellipse[1]] = 1 + (i / 10)
        ri1, ri2, ri3, ri4, ri5, ri6 = object_freqs()
        img[ellipse[0], ellipse[1], :] = np.array([cols[i], 0.0, 0.0])
        img[ellipse[0], ellipse[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[ellipse[0], ellipse[1]] * ri5)**2 +
                    ((dim[1] - y[ellipse[0], ellipse[1]]) * ri2)**2) * ri3 *
            np.pi / dim[0]))[..., np.newaxis] * 0.15) + 0.2
        img[ellipse[0], ellipse[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[ellipse[0], ellipse[1]] * ri6)**2 +
                    ((dim[1] - y[ellipse[0], ellipse[1]]) * ri1)**2) * ri4 *
            np.pi / dim[1]))[..., np.newaxis] * 0.15) + 0.2

    cols = np.random.choice(np.arange(4, 11, 1).astype(float) / 10,
                            n_rect,
                            replace=False)
    for i, rect in enumerate(rects):
        gt[rect[0], rect[1]] = 2 + (i / 10)
        ri1, ri2, ri3, ri4, ri5, ri6 = object_freqs()
        img[rect[0], rect[1], :] = np.array([0.0, 0.0, cols[i]])
        img[rect[0], rect[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[rect[0], rect[1]] * ri5)**2 +
                    ((dim[1] - y[rect[0], rect[1]]) * ri2)**2) * ri3 * np.pi /
            dim[0]))[..., np.newaxis] * 0.15) + 0.2
        img[rect[0], rect[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[rect[0], rect[1]] * ri1)**2 +
                    ((dim[1] - y[rect[0], rect[1]]) * ri6)**2) * ri4 * np.pi /
            dim[1]))[..., np.newaxis] * 0.15) + 0.2

    img = np.clip(img, 0, 1)

    affinities = get_naive_affinities(gaussian(np.clip(img, 0, 1), sigma=.2),
                                      edge_offsets)
    affinities[:sep_chnl] *= -1
    affinities[:sep_chnl] += +1
    affinities[:sep_chnl] /= 1.3
    affinities[sep_chnl:] *= 1.3
    affinities = np.clip(affinities, 0, 1)

    valid_edges = get_valid_edges((len(edge_offsets), ) + dim, edge_offsets,
                                  sep_chnl, None, False)
    node_labeling, neighbors, cutting_edges, mutexes = compute_mws_segmentation_cstm(
        affinities.ravel(), valid_edges.ravel(), edge_offsets, sep_chnl, dim)
    node_labeling = node_labeling - 1
    nodes = np.unique(node_labeling)
    # `Warning(...)` alone never emitted anything; warn for real.
    if not np.array_equal(nodes, np.arange(len(nodes))):
        warnings.warn("node ids are off")

    edge_feat, neighbors = get_edge_features_1d(node_labeling, edge_offsets,
                                                affinities)
    gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                              node_labeling.squeeze(),
                                              gt.squeeze())
    # Explicit int64: `np.long` is missing in NumPy 1.24-1.26.
    edges = neighbors.astype(np.int64)

    gt_seg = get_current_soln(gt_edge_weights, node_labeling, edges)
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    ax1.imshow(cm.prism(gt / gt.max()))
    ax1.set_title('gt')
    ax2.imshow(cm.prism(node_labeling / node_labeling.max()))
    ax2.set_title('sp')
    ax3.imshow(cm.prism(gt_seg / gt_seg.max()))
    ax3.set_title('mc')
    plt.show()

    affinities = affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()

    edges = np.sort(edges, axis=-1)
    edges = edges.T

    return img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, affinities
Ejemplo n.º 8
0
def get_pix_data(shape=(256, 256)):
    """Generate a synthetic image with its gt, superpixels and region graph.

    The image consists of a noisy background with two interfering sine
    patterns, randomly coloured noisy polygons (not labelled in ``gt``), and
    textured circles and rectangles (labelled in ``gt``).  Superpixels are
    computed with a mutex watershed on affinities that are scaled to favour
    an oversegmentation; the region graph is derived from that labeling.

    Args:
        shape: ``(height, width)`` tuple of the generated image.  The
            background coordinate grid and the object centre sampling use
            ``shape[0]`` for both axes, so a square shape is assumed.

    Returns:
        ``(img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights,
        node_labeling, nodes, affinities)``:

        * ``img``: ``shape + (3,)`` image clipped to ``[0, 1]``.
        * ``gt``: ``shape`` array; 0 outside labelled objects, ``1 + i/10``
          for circle ``i`` and ``2 + i/10`` for rectangle ``i``.
        * ``edges``: int64 edge array, each edge's node pair sorted, then
          transposed (presumably to ``(2, n_edges)`` — depends on the
          layout returned by ``get_edge_features_1d``).
        * ``edge_feat``: float32 edge features from ``get_edge_features_1d``.
        * ``diff_to_gt``: sum of ``|edge_feat[:, 0] - gt_edge_weights|``.
        * ``gt_edge_weights``: float32 gt edge costs.
        * ``node_labeling``: float32 superpixel map (ids shifted down by 1).
        * ``nodes``: float32 unique superpixel ids.
        * ``affinities``: float32 scaled and clipped affinities.
    """
    import warnings

    def rsign():
        # Random sign: +1 or -1 with equal probability.
        return (-1)**np.random.randint(0, 2)

    # ``skimage.draw.circle`` was removed in scikit-image 0.19 in favour of
    # ``draw.disk``, which rasterizes the same disk; use whichever exists.
    if hasattr(draw, 'disk'):
        def draw_disk(row, col, radius):
            return draw.disk((row, col), radius, shape=shape)
    else:
        def draw_disk(row, col, radius):
            return draw.circle(row, col, radius, shape=shape)

    edge_offsets = [[0, -1], [-1, 0], [-3, 0], [0, -3], [-6, 0],
                    [0, -6]]  # offsets defining the edges for pixel affinities
    overseg_factor = 1.7  # affinity scaling that biases towards oversegmentation
    sep_chnl = 2  # channel separating attractive from repulsive edges
    n_circles = 5  # number of circles in image
    n_polys = 10  # number of random polygons in image
    n_rect = 5  # number of rectangles in image
    circle_color = np.array([1, 0, 0], dtype=float)
    rect_color = np.array([0, 0, 1], dtype=float)
    col_diff = 0.4  # min colour distance of polygons to circle/rect colours
    min_r, max_r = 10, 20  # min and max radii of circles
    min_dist = max_r  # min distance between centres of circles/rectangles

    img = np.random.randn(*(shape + (3, ))) / 5  # init image with some noise
    gt = np.zeros(shape)

    # centres of all objects placed so far
    cmps = []

    def too_close(mp, dist):
        # True if ``mp`` lies closer than ``dist`` to any centre placed so far.
        return any(np.linalg.norm(cmp - mp) < dist for cmp in cmps)

    #  get some random frequencies
    ri1, ri2, ri3, ri4, ri5, ri6 = rsign() * ((np.random.rand() * 2) + .5), \
                                   rsign() * ((np.random.rand() * 2) + .5), \
                                   (np.random.rand() * 4) + 3, \
                                   (np.random.rand() * 4) + 3, \
                                   rsign() * ((np.random.rand() * 2) + .5), \
                                   rsign() * ((np.random.rand() * 2) + .5)
    x = np.zeros(shape)
    x[:, :] = np.arange(img.shape[0])[np.newaxis, :]
    y = x.transpose()
    # add background frequency interferences
    img += (np.sin(
        np.sqrt((x * ri1)**2 + ((shape[1] - y) * ri2)**2) * ri3 * np.pi /
        shape[0]))[..., np.newaxis]
    img += (np.sin(
        np.sqrt((x * ri5)**2 + ((shape[1] - y) * ri6)**2) * ri4 * np.pi /
        shape[1]))[..., np.newaxis]
    # smooth a bit
    img = gaussian(np.clip(img, 0.1, 1), sigma=.8)

    # place circles (rejection sampling on centre distance)
    circles = []
    while len(circles) < n_circles:
        mp = np.random.randint(min_r, shape[0] - min_r, 2)
        if too_close(mp, min_dist):
            continue
        r = np.random.randint(min_r, max_r, 2)  # only r[0] is used
        circles.append(draw_disk(mp[0], mp[1], r[0]))
        cmps.append(mp)

    # place random polygons with vertices on a circle of radius max_r
    polys = []
    while len(polys) < n_polys:
        mp = np.random.randint(min_r, shape[0] - min_r, 2)
        if too_close(mp, min_dist // 2):
            continue
        circle = draw.circle_perimeter(mp[0], mp[1], max_r)
        poly_vert = np.random.choice(len(circle[0]),
                                     np.random.randint(3, 6),
                                     replace=False)
        polys.append(
            draw.polygon(circle[0][poly_vert],
                         circle[1][poly_vert],
                         shape=shape))
        cmps.append(mp)

    # place random rectangles; the half side lengths are drawn before the
    # distance check, as in the original sampling order
    rects = []
    while len(rects) < n_rect:
        mp = np.random.randint(min_r, shape[0] - min_r, 2)
        half_len = np.random.randint(min_r // 2, max_r, (2, ))
        if too_close(mp, min_dist):
            continue
        start = (mp[0] - half_len[0], mp[1] - half_len[1])
        rects.append(
            draw.rectangle(start,
                           extent=(half_len[0] * 2, half_len[1] * 2),
                           shape=shape))
        cmps.append(mp)

    # draw polys with a colour far enough from circle/rect colours plus noise
    for poly in polys:
        color = np.random.rand(3)
        while (np.linalg.norm(color - circle_color) < col_diff
               or np.linalg.norm(color - rect_color) < col_diff):
            color = np.random.rand(3)
        img[poly[0], poly[1], :] = color
        img[poly[0], poly[1], :] += np.random.randn(len(
            poly[1]), 3) / 5  # add noise to the polygons

    # draw circles: distinct red intensity plus two interfering sine patterns
    cols = np.random.choice(np.arange(4, 11, 1).astype(float) / 10,
                            n_circles,
                            replace=False)  # get colors
    for i, circle in enumerate(circles):
        gt[circle[0], circle[1]] = 1 + (i / 10)
        ri1, ri2, ri3, ri4, ri5, ri6 = rsign() * ((np.random.rand() * 4) + 7), \
                                       rsign() * ((np.random.rand() * 4) + 7), \
                                       (np.random.rand() + 1) * 8, \
                                       (np.random.rand() + 1) * 8, \
                                       rsign() * ((np.random.rand() * 4) + 7), \
                                       rsign() * ((np.random.rand() * 4) + 7)

        img[circle[0],
            circle[1], :] = np.array([cols[i], 0.0,
                                      0.0])  # set even color intensity
        # add the interference of two freqs to the red/green channels (+0.2 offset)
        img[circle[0], circle[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[circle[0], circle[1]] * ri5)**2 +
                    ((shape[1] - y[circle[0], circle[1]]) * ri2)**2) * ri3 *
            np.pi / shape[0]))[..., np.newaxis] * 0.15) + 0.2
        img[circle[0], circle[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[circle[0], circle[1]] * ri6)**2 +
                    ((shape[1] - y[circle[0], circle[1]]) * ri1)**2) * ri4 *
            np.pi / shape[1]))[..., np.newaxis] * 0.15) + 0.2

    # draw rectangles: distinct blue intensity plus two interfering patterns
    cols = np.random.choice(np.arange(4, 11, 1).astype(float) / 10,
                            n_rect,
                            replace=False)
    for i, rect in enumerate(rects):
        gt[rect[0], rect[1]] = 2 + (i / 10)
        ri1, ri2, ri3, ri4, ri5, ri6 = rsign() * ((np.random.rand() * 4) + 7), \
                                       rsign() * ((np.random.rand() * 4) + 7), \
                                       (np.random.rand() + 1) * 8, \
                                       (np.random.rand() + 1) * 8, \
                                       rsign() * ((np.random.rand() * 4) + 7), \
                                       rsign() * ((np.random.rand() * 4) + 7)
        img[rect[0], rect[1], :] = np.array([0.0, 0.0, cols[i]])
        img[rect[0], rect[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[rect[0], rect[1]] * ri5)**2 +
                    ((shape[1] - y[rect[0], rect[1]]) * ri2)**2) * ri3 *
            np.pi / shape[0]))[..., np.newaxis] * 0.15) + 0.2
        img[rect[0], rect[1], :] += np.array([1.0, 1.0, 0.0]) * ((np.sin(
            np.sqrt((x[rect[0], rect[1]] * ri1)**2 +
                    ((shape[1] - y[rect[0], rect[1]]) * ri6)**2) * ri4 *
            np.pi / shape[1]))[..., np.newaxis] * 0.15) + 0.2

    img = np.clip(img, 0, 1)  # clip to valid range
    # get affinities and calc superpixels with mutex watershed
    affinities = get_naive_affinities(gaussian(img, sigma=.2), edge_offsets)
    # replace a by 1 - a in the channels before sep_chnl
    affinities[:sep_chnl] *= -1
    affinities[:sep_chnl] += +1
    # scale affinities in order to get an oversegmentation
    affinities[:sep_chnl] /= overseg_factor
    affinities[sep_chnl:] *= overseg_factor
    affinities = np.clip(affinities, 0, 1)
    node_labeling = compute_mws_segmentation(affinities, edge_offsets,
                                             sep_chnl)
    node_labeling = node_labeling - 1
    nodes = np.unique(node_labeling)
    # the previous ``except: Warning(...)`` only built a Warning object and
    # never emitted it; actually warn when ids are not exactly 0..n-1
    if not np.array_equal(nodes, np.arange(len(nodes))):
        warnings.warn("node ids are off")

    # get edges from node labeling and edge features from affinity stats
    edge_feat, neighbors = get_edge_features_1d(node_labeling, edge_offsets,
                                                affinities)
    # get gt edge weights based on edges and gt image
    gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                              node_labeling.squeeze(),
                                              gt.squeeze())
    # ``np.long`` was removed in NumPy 1.24; edges are stored as int64
    edges = neighbors.astype(np.int64)

    affinities = affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()

    edges = np.sort(edges, axis=-1)
    edges = edges.T

    return img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, affinities
Ejemplo n.º 9
0
def preprocess_data():
    """Write a graph data file for every pixel data file of train and val.

    For each ``<dir>/pix_data/*.h5`` file (``dir`` in ``tgtdir_train`` and
    ``tgtdir_val``):

    * read ``raw`` (squeezed) and ``gt``,
    * compute superpixels by a watershed on the smoothed maximum Hessian
      eigenvalue map of ``raw``,
    * compute affinities from the smoothed raw image with
      ``get_heat_map_by_affs``,
    * compute edge features and gt edge weights,
    * write everything to ``<dir>/graph_data/graph<suffix>``, where
      ``<suffix>`` is the pixel file name without its first three characters
      (e.g. ``pix_3.h5`` -> ``graph_3.h5``).
    """
    for data_dir in [tgtdir_train, tgtdir_val]:
        pix_dir = os.path.join(data_dir, 'pix_data')
        graph_dir = os.path.join(data_dir, 'graph_data')
        for fname_pix in sorted(glob(os.path.join(pix_dir, '*.h5'))):
            # open the input once and make sure it is closed again
            with h5py.File(fname_pix, 'r') as pix_file:
                raw = pix_file['raw'][:].squeeze()
                gt = pix_file['gt'][:]

            # Noise injection is disabled: affinities are computed from the
            # clean image. The previous code built
            # ``raw + np.random.normal(0, 0.2, raw.shape)`` and discarded it.
            noisy_raw = raw

            hmap = get_max_hessian_eval(raw, sigma=.5)
            node_labeling = run_watershed(gaussian_filter(hmap, sigma=.1),
                                          min_size=10,
                                          nhood=4)
            _, affs = get_heat_map_by_affs(
                gaussian_filter(noisy_raw, sigma=[3.3, 3.3]))
            edge_feat, edges = get_edge_features_1d(node_labeling, offs, affs)
            gt_edge_weights = calculate_gt_edge_costs(edges,
                                                      node_labeling.squeeze(),
                                                      gt.squeeze())

            # ``np.long`` was removed in NumPy 1.24; store edges as int64
            edges = edges.astype(np.int64)

            affs = affs.astype(np.float32)
            edge_feat = edge_feat.astype(np.float32)
            node_labeling = node_labeling.astype(np.float32)
            gt_edge_weights = gt_edge_weights.astype(np.float32)
            diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()
            # sort each node pair, then transpose the edge array
            edges = np.sort(edges, axis=-1)
            edges = edges.T

            graph_file_name = os.path.join(
                graph_dir, "graph" + os.path.basename(fname_pix)[3:])
            # ``with`` guarantees the output is closed even if a write fails
            with h5py.File(graph_file_name, 'w') as graph_file:
                graph_file.create_dataset("edges", data=edges, chunks=True)
                graph_file.create_dataset("offsets",
                                          data=np.array(offs),
                                          chunks=True)
                graph_file.create_dataset("separating_channel",
                                          data=np.array([2]),
                                          chunks=True)
                graph_file.create_dataset("edge_feat",
                                          data=edge_feat,
                                          chunks=True)
                graph_file.create_dataset("diff_to_gt", data=diff_to_gt)
                graph_file.create_dataset("gt_edge_weights",
                                          data=gt_edge_weights,
                                          chunks=True)
                graph_file.create_dataset("node_labeling",
                                          data=node_labeling,
                                          chunks=True)
                graph_file.create_dataset("affinities", data=affs, chunks=True)
Ejemplo n.º 10
0
    def create_dsets(
            self,
            num,
            out_prefix='/g/kreshuk/hilt/projects/fewShotLearning/mutexWtsd/data/storage/balanced_graphs/balanced_graph_data'):
        """Generate ``num`` synthetic disc samples with balanced gt edge weights.

        Each sample contains 25-29 discs whose radii lie in
        ``[max(self.shape) // 25, max(self.shape) // 20)``. The discs are
        placed so that centre distances exceed the sum of radii plus 2.
        Pixels inside a disc get the disc texture and are 1 in ``gt``; all
        other pixels get the background texture and are 0 in ``gt``. Finally
        the image is offset by +1.

        If ``self.no_suppix`` is truthy, the FIRST sample is returned
        immediately as ``(raw.unsqueeze(0), gt)``, with ``raw`` a float32
        tensor and ``gt`` an int64 tensor. ``num`` is effectively ignored and
        nothing is written.

        Otherwise, per sample:

        * naive affinities are computed;
        * channels from ``self.sep_chnl`` on are replaced by ``1 - a`` (and
          scaled by 1.01);
        * affinities are forced to 1 wherever the adjusted gt affinity is 1;
        * a mutex watershed produces superpixels;
        * superpixels are randomly merged along zero-weight gt edges until
          the sum of gt edge weights is within 1 of half the number of edges;
        * the graph is written via ``self.write_to_h5`` to
          ``out_prefix + str(file_index) + '.h5'``.

        Args:
            num: number of samples to generate.
            out_prefix: path prefix of the written files (defaults to the
                previously hard-coded location).

        Returns:
            ``(raw, gt)`` tensors when ``self.no_suppix`` is set, else None.
        """
        import warnings

        height, width = self.shape
        # pixel coordinate grids (row index, column index)
        yy, xx = np.mgrid[0:height, 0:width]
        # position-dependent textures; they do not depend on the disc layout
        disc_texture = (
            np.cos(np.sqrt((xx - width)**2 + yy**2) * 50 * np.pi / width) +
            np.cos(np.sqrt(xx**2 + yy**2) * 50 * np.pi / width))
        bg_texture = (
            np.cos(yy * 40 * np.pi / height) +
            np.cos(np.sqrt(xx**2 + (height - yy)**2) * 30 * np.pi / width))

        for file_index in range(num):
            n_disc = np.random.randint(25, 30)
            rads = []
            mps = []
            for _ in range(n_disc):
                radius = np.random.randint(
                    max(self.shape) // 25,
                    max(self.shape) // 20)
                # rejection-sample a centre that keeps a > 2 px gap to all
                # previously placed discs
                while True:
                    mp = np.array([
                        np.random.randint(0 + radius, self.shape[0] - radius),
                        np.random.randint(0 + radius, self.shape[1] - radius)
                    ])
                    if all(((mp - other_mp)**2).sum()**.5 > radius + other_rad + 2
                           for other_rad, other_mp in zip(rads, mps)):
                        break
                rads.append(radius)
                mps.append(mp)

            # vectorized replacement of the former per-pixel Python loop:
            # count how many discs cover each pixel (at most 1 given the gap)
            coverage = np.zeros(self.shape, dtype=int)
            for radius, mp in zip(rads, mps):
                coverage += ((yy - mp[0])**2 + (xx - mp[1])**2)**.5 <= radius
            data = np.where(coverage > 0, coverage * disc_texture,
                            bg_texture) + 1
            gt = (coverage > 0).astype(float)

            if self.no_suppix:
                raw = torch.from_numpy(data).float()
                # NOTE: returns after the first sample.
                return raw.unsqueeze(0), torch.from_numpy(gt.astype(np.int64))

            affinities = affutils.get_naive_affinities(data, self.offsets)
            gt_affinities, _ = compute_affinities(gt == 1, self.offsets)
            # a -> 1 - a for channels from sep_chnl on
            gt_affinities[self.sep_chnl:] *= -1
            gt_affinities[self.sep_chnl:] += +1
            affinities[self.sep_chnl:] *= -1
            affinities[self.sep_chnl:] += +1
            affinities[self.sep_chnl:] *= 1.01
            # where the gt affinity is 1 the result is 1; where it is 0 the
            # affinity is kept
            affinities = (affinities -
                          (affinities * gt_affinities)) + gt_affinities
            affinities = affinities.clip(0, 1)

            valid_edges = get_valid_edges((len(self.offsets), ) + self.shape,
                                          self.offsets, self.sep_chnl, None,
                                          False)
            node_labeling, neighbors, cutting_edges, mutexes = compute_mws_segmentation_cstm(
                affinities.ravel(), valid_edges.ravel(), self.offsets,
                self.sep_chnl, self.shape)
            node_labeling = node_labeling - 1
            nodes = np.unique(node_labeling)
            # the former ``except: Warning(...)`` never emitted anything
            if not np.array_equal(nodes, np.arange(len(nodes))):
                warnings.warn("node ids are off")

            noisy_affinities = affinities

            edge_feat, neighbors = get_edge_features_1d(
                node_labeling, self.offsets, noisy_affinities)
            gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                      node_labeling.squeeze(),
                                                      gt.squeeze())

            while abs(gt_edge_weights.sum() - (len(gt_edge_weights) / 2)) > 1:
                # Only zero-weight edges are ever merged; without any left the
                # loop below would `continue` forever.
                if not (gt_edge_weights == 0).any():
                    warnings.warn("could not balance gt edge weights: "
                                  "no zero-weight edges left")
                    break
                edge_idx = np.random.choice(np.arange(len(gt_edge_weights)),
                                            p=torch.softmax(torch.from_numpy(
                                                (gt_edge_weights == 0).astype(
                                                    float)),
                                                            dim=0).numpy())
                if gt_edge_weights[edge_idx] != 0.0:
                    continue

                edge = neighbors[edge_idx].astype(int)
                # merge superpixel edge[0] into edge[1] by relabeling it
                diff = edge[0] - edge[1]
                node_labeling = node_labeling - (node_labeling
                                                 == edge[0]) * diff

                edge_feat, neighbors = get_edge_features_1d(
                    node_labeling, self.offsets, noisy_affinities)
                gt_edge_weights = calculate_gt_edge_costs(
                    neighbors, node_labeling.squeeze(), gt.squeeze())
            # edge_feat / neighbors / gt_edge_weights are up to date here:
            # every merge above recomputes them.

            gt = torch.from_numpy(gt.astype(np.float32)).squeeze().float()

            edges = torch.from_numpy(neighbors.astype(np.int64))
            raw = torch.tensor(data).squeeze().float()
            noisy_affinities = torch.tensor(noisy_affinities).squeeze().float()
            edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
            # NOTE(review): ``nodes`` was computed before the merge loop and
            # may still contain merged-away ids — confirm consumers expect it.
            nodes = torch.from_numpy(nodes.astype(np.float32))
            node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
            gt_edge_weights = torch.from_numpy(
                gt_edge_weights.astype(np.float32))
            diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum()
            edges = edges.t().contiguous()
            # append every edge a second time with its endpoints swapped
            edges = torch.cat((edges, torch.stack((edges[1], edges[0]))),
                              dim=1)

            self.write_to_h5(
                out_prefix + str(file_index) + '.h5', edges, edge_feat,
                diff_to_gt, gt_edge_weights, node_labeling, raw, nodes,
                noisy_affinities, gt)
Ejemplo n.º 11
0
# Script-level sanity check on one batch of ``dloader_disc``: run a multicut
# on edge costs from predicted affinities and, for comparison, on edge costs
# derived from the ground-truth segmentation.
neighbors, nodes, seg, gt_seg, affs, gt_affs = next(iter(dloader_disc))

# Affinity offsets with a leading 0 component; the last two components are
# the in-plane offsets.
offsets = [[0, 0, -1], [0, -1, 0], [0, -3, 0], [0, 0, -3]]

# Swap the first two axes of the affinity tensors (presumably batch <->
# channel, giving channel-first arrays for the feature functions — confirm).
affs = np.transpose(affs.cpu().numpy(), (1, 0, 2, 3))
gt_affs = np.transpose(gt_affs.cpu().numpy(), (1, 0, 2, 3))
seg = seg.cpu().numpy()
gt_seg = gt_seg.cpu().numpy()
# Per-pixel boundary maps: mean over the first (post-transpose) axis.
boundary_input = np.mean(affs, axis=0)
gt_boundary_input = np.mean(gt_affs, axis=0)

# Region adjacency graph of the superpixel segmentation ``seg``.
rag = feats.compute_rag(seg)
# edges rag.uvIds() [[1, 2], ...]

# Edge values from predicted affinities: first feature column (presumably the
# mean affinity along each edge — confirm against the ``feats`` docs).
costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0]
# Reference edge values from the ground-truth segmentation.
gt_costs = calculate_gt_edge_costs(rag.uvIds(), seg.squeeze(),
                                   gt_seg.squeeze())

# Column 1 of the boundary statistics (presumably the boundary length).
edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1]
# NOTE(review): ``gt_edge_sizes`` is not used below; ``gt_costs`` are also
# transformed with ``edge_sizes``. If column 1 is the boundary length it only
# depends on ``rag`` and the two arrays would be identical — confirm.
gt_edge_sizes = feats.compute_boundary_mean_and_length(rag,
                                                       gt_boundary_input)[:, 1]
# Turn the edge probabilities into multicut costs weighted by edge size
# (presumably signed costs; see ``mc.transform_probabilities_to_costs``).
costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)
gt_costs = mc.transform_probabilities_to_costs(gt_costs, edge_sizes=edge_sizes)

# Solve both multicut problems with the Kernighan-Lin solver.
node_labels = mc.multicut_kernighan_lin(rag, costs)
gt_node_labels = mc.multicut_kernighan_lin(rag, gt_costs)

# Map the per-node solutions back to per-pixel segmentations.
segmentation = feats.project_node_labels_to_pixels(rag, node_labels)
gt_segmentation = feats.project_node_labels_to_pixels(rag, gt_node_labels)
plt.imshow(
    np.concatenate(
        (gt_segmentation.squeeze(), segmentation.squeeze(), seg.squeeze()),
Ejemplo n.º 12
0
def graphs_for_masked_data():
    """Write pixel and graph data for the background-masked validation set.

    For every ``<dir>/raw_wtsd/*.h5`` file of ``tgtdir_val`` (``num`` is the
    file name without its first 6 characters and the ``.h5`` extension):

    * read ``raw`` and ``label`` from the file;
    * read the affinity predictions from
      ``<dir>/affinities_01_trainsz/<stem>_predictions.h5`` and apply a
      sigmoid;
    * read the superpixel map ``node_labeling`` from
      ``<dir>/bg_masked_data/graph_<num>.h5``;
    * relabel superpixels and labels to consecutive ids starting at 0;
    * compute edge features and gt edge weights (extra argument 0.5);
    * append the results to ``bg_masked_data/pix_data/pix_<num>.h5`` and
      ``bg_masked_data/graph_data/graph_<num>.h5``.

    The outputs are opened in append mode, so a dataset name that already
    exists in them makes ``create_dataset`` fail.
    """

    def to_consecutive(labels):
        # Map label values to 0..K-1 preserving their sorted order. This is
        # equivalent to the former one-hot-mask construction without
        # materializing a (K, H, W) mask.
        return torch.unique(labels, return_inverse=True)[1]

    for data_dir in [tgtdir_val]:
        fnames = sorted(glob(os.path.join(data_dir, 'raw_wtsd/*.h5')))
        pix_dir = os.path.join(data_dir, 'bg_masked_data/pix_data')
        graph_dir = os.path.join(data_dir, 'bg_masked_data/graph_data')
        for fname in fnames:
            tail = os.path.basename(fname)
            num = tail[6:-3]

            with h5py.File(fname, 'r') as raw_file:
                raw = raw_file['raw'][:]
                gt = raw_file['label'][:]

            pred_fname = os.path.join(data_dir, 'affinities_01_trainsz',
                                      tail[:-3] + '_predictions' + '.h5')
            with h5py.File(pred_fname, 'r') as pred_file:
                affs = torch.from_numpy(
                    pred_file['predictions'][:]).squeeze(1)
            affs = torch.sigmoid(affs).numpy()

            sp_fname = os.path.join(data_dir,
                                    "bg_masked_data/graph_" + num + ".h5")
            with h5py.File(sp_fname, 'r') as sp_file:
                node_labeling = sp_file["node_labeling"][:]

            # relabel to consecutive ints starting at 0
            node_labeling = to_consecutive(
                torch.from_numpy(node_labeling.astype(np.int64)))
            gt = to_consecutive(torch.from_numpy(gt.astype(np.int64)))

            edge_feat, edges = get_edge_features_1d(node_labeling.numpy(),
                                                    offs, affs)
            gt_edge_weights = calculate_gt_edge_costs(
                torch.from_numpy(edges.astype(np.int64)),
                node_labeling.squeeze(), gt.squeeze(), 0.5)

            gt_edge_weights = gt_edge_weights.numpy()
            gt = gt.numpy()
            node_labeling = node_labeling.numpy()
            # ``np.long`` was removed in NumPy 1.24; store edges as int64
            edges = edges.astype(np.int64)

            affs = affs.astype(np.float32)
            edge_feat = edge_feat.astype(np.float32)
            node_labeling = node_labeling.astype(np.float32)
            gt_edge_weights = gt_edge_weights.astype(np.float32)
            diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()
            edges = edges.T

            # outputs are only opened once all inputs were computed, and are
            # closed even if a write fails
            with h5py.File(os.path.join(pix_dir, "pix_" + num + ".h5"),
                           'a') as pix_file:
                pix_file.create_dataset("raw", data=raw, chunks=True)
                pix_file.create_dataset("gt", data=gt, chunks=True)

            with h5py.File(os.path.join(graph_dir, "graph_" + num + ".h5"),
                           'a') as graph_file:
                graph_file.create_dataset("edges", data=edges, chunks=True)
                graph_file.create_dataset("edge_feat",
                                          data=edge_feat,
                                          chunks=True)
                graph_file.create_dataset("diff_to_gt", data=diff_to_gt)
                graph_file.create_dataset("gt_edge_weights",
                                          data=gt_edge_weights,
                                          chunks=True)
                graph_file.create_dataset("node_labeling",
                                          data=node_labeling,
                                          chunks=True)
                graph_file.create_dataset("affinities", data=affs, chunks=True)
                graph_file.create_dataset("offsets",
                                          data=np.array([[1, 0], [0, 1],
                                                         [2, 0], [0, 2]]),
                                          chunks=True)
Ejemplo n.º 13
0
        def process_file(i):
            """Rebuild the superpixel graph of ``fnames[i]`` with two background segments.

            ``num`` is the file name of ``fnames[i]`` without its first 4
            characters and the ``.h5`` extension. Steps:

            * read ``raw_2chnl`` and ``gt`` from ``pix_dir/pix_<num>.h5``;
            * read ``node_labeling``, ``edges`` and ``affinities`` from
              ``graph_dir/graph_<num>.h5``;
            * run a multicut on edge weights derived from the embeddings of
              ``model`` and find its two largest segments;
            * in the superpixel map, set pixels of the largest segment to 0
              and pixels of the second largest to 1; remaining superpixels
              keep distinct ids (shifted by 2);
            * relabel the map to consecutive ids from 0;
            * min-max normalise ``raw``;
            * compute edge features on the first 4 offsets/affinity channels
              and gt edge weights (extra argument 0.4);
            * write everything to ``new_pix_dir`` / ``new_graph_dir``.

            ``fnames``, ``pix_dir``, ``graph_dir``, ``new_pix_dir``,
            ``new_graph_dir``, ``model``, ``device``, ``offs`` and
            ``distance`` come from the enclosing scope.
            """

            def to_consecutive(labels):
                # Map label values to 0..K-1 preserving their sorted order;
                # same result as the former one-hot-mask construction.
                return torch.unique(labels, return_inverse=True)[1]

            fname = fnames[i]
            num = os.path.basename(fname)[4:-3]

            with h5py.File(os.path.join(graph_dir, "graph_" + num + ".h5"),
                           'r') as graph_file:
                sp_seg = graph_file["node_labeling"][:]
                edges = graph_file["edges"][:]
                affs = graph_file["affinities"][:]
            with h5py.File(os.path.join(pix_dir, "pix_" + num + ".h5"),
                           'r') as pix_file:
                raw = pix_file["raw_2chnl"][:]
                gt = pix_file["gt"][:]

            # multicut on edge weights from the embedding-based affinities
            with torch.no_grad():
                embeddings = model(
                    torch.from_numpy(raw).to(device).float()[None])
            emb_affs = get_affinities_from_embeddings_2d(
                embeddings, offs, 0.4, distance)[0].cpu().numpy()
            ew_embedaffs = 1 - get_edge_features_1d(sp_seg, offs,
                                                    emb_affs)[0][:, 0]
            mc_soln = torch.from_numpy(
                multicut_from_probas(sp_seg, edges.T,
                                     ew_embedaffs).astype(np.int64)).to(device)
            mc_soln = to_consecutive(mc_soln)

            # pixel count per multicut segment (labels are 0..K-1 now)
            masses = torch.unique(mc_soln, return_counts=True)[1]
            bg1id = masses.argmax()
            masses[bg1id] = 0
            bg1_mask = mc_soln == bg1id
            bg2_mask = mc_soln == masses.argmax()

            sp_seg = to_consecutive(
                torch.from_numpy(sp_seg.astype(np.int64)).to(device))
            # shift ids by 2, then set the largest segment to 0 and the
            # second largest to 1
            sp_seg += 2
            sp_seg *= (bg1_mask == 0)
            sp_seg *= (bg2_mask == 0)
            sp_seg += bg2_mask
            sp_seg = to_consecutive(sp_seg).cpu()

            # min-max normalisation; skip the division for a constant image
            raw -= raw.min()
            raw_max = raw.max()
            if raw_max > 0:
                raw /= raw_max
            edge_feat, edges = get_edge_features_1d(sp_seg.numpy(), offs[:4],
                                                    affs[:4])
            # ``np.long`` was removed in NumPy 1.24; store edges as int64
            edges = edges.astype(np.int64)

            gt_edge_weights = calculate_gt_edge_costs(
                torch.from_numpy(edges).to(device), sp_seg.to(device),
                torch.from_numpy(gt).to(device), 0.4)
            node_labeling = sp_seg.numpy()

            affs = affs.astype(np.float32)
            edge_feat = edge_feat.astype(np.float32)
            node_labeling = node_labeling.astype(np.float32)
            gt_edge_weights = gt_edge_weights.cpu().numpy().astype(np.float32)
            edges = edges.T

            with h5py.File(os.path.join(new_pix_dir, "pix_" + num + ".h5"),
                           'w') as new_pix_file:
                new_pix_file.create_dataset("raw_2chnl", data=raw, chunks=True)
                new_pix_file.create_dataset("gt", data=gt, chunks=True)

            with h5py.File(
                    os.path.join(new_graph_dir, "graph_" + num + ".h5"),
                    'w') as new_graph_file:
                new_graph_file.create_dataset("edges", data=edges, chunks=True)
                new_graph_file.create_dataset("edge_feat",
                                              data=edge_feat,
                                              chunks=True)
                new_graph_file.create_dataset("gt_edge_weights",
                                              data=gt_edge_weights,
                                              chunks=True)
                new_graph_file.create_dataset("node_labeling",
                                              data=node_labeling,
                                              chunks=True)
                new_graph_file.create_dataset("affinities",
                                              data=affs,
                                              chunks=True)
                new_graph_file.create_dataset(
                    "offsets",
                    data=np.array([[1, 0], [0, 1], [2, 0], [0, 2], [4, 0],
                                   [0, 4], [8, 0], [0, 8], [16, 0], [0, 16]]),
                    chunks=True)
Ejemplo n.º 14
0
        mask = pred_seg[None] == torch.unique(pred_seg)[:, None, None]
        pred_seg = (mask *
                    (torch.arange(len(torch.unique(pred_seg)),
                                  device=dev)[:, None, None] + 1)).sum(0) - 1
        # assert the segmentations are consecutive integers
        assert pred_seg.max() == len(torch.unique(pred_seg)) - 1
        assert superpixel_seg.max() == len(torch.unique(superpixel_seg)) - 1

        edges = torch.from_numpy(
            compute_rag(superpixel_seg.cpu().numpy()).uvIds().astype(
                np.long)).T.to(dev)
        dir_edges = torch.stack((torch.cat(
            (edges[0], edges[1])), torch.cat((edges[1], edges[0]))))

        gt_edges = calculate_gt_edge_costs(edges.T,
                                           superpixel_seg.squeeze(),
                                           gt_seg.squeeze(),
                                           thresh=0.5)

        mc_seg_old = None
        for itr in range(10):
            actions = gt_edges + torch.randn_like(gt_edges) * (itr / 10)
            # actions = torch.randn_like(gt_edges)
            # actions = torch.zeros_like(gt_edges)
            # actions[1:4] = 1.0
            actions -= actions.min()
            actions /= actions.max()

            mc_seg = multicut_from_probas(superpixel_seg.cpu().numpy(),
                                          edges.T.cpu().numpy(), actions)
            mc_seg = torch.from_numpy(mc_seg.astype(np.int64)).to(dev)