Beispiel #1
0
    def __init__(self,
                 affinities,
                 separating_channel,
                 offsets,
                 strides,
                 gt_affinities,
                 stop_cnt=500,
                 win_reward=-100,
                 use_bbox=False,
                 writer=None,
                 action_aggression=2,
                 penalize_diff_thresh=2000,
                 keep_first_state=True):
        """Set up the environment on top of the parent constructor.

        ``affinities``, ``separating_channel``, ``offsets``, ``strides``,
        ``gt_affinities``, ``stop_cnt`` and ``win_reward`` are forwarded to
        the parent ``__init__``.  The remaining arguments are stored as
        attributes of the same name.

        Besides storing the configuration this constructor

        * computes ``self.gt_seg`` by running
          ``compute_mws_segmentation_cstm`` on ``gt_affinities``,
        * derives ``self.loosing_diff_thresh`` as three times
          ``penalize_diff_thresh``,
        * initialises ``self.bbox`` from ``self.img_shape``,
        * finishes with a call to ``self._update_state_1()``.
        """
        super(MtxWtsdEnvUnet, self).__init__(
            affinities,
            separating_channel,
            offsets,
            strides,
            gt_affinities=gt_affinities,
            stop_cnt=stop_cnt,
            win_reward=win_reward,
        )
        self.writer = writer

        # valid_edges / mtx_offsets / mtx_separating_channel / img_shape are
        # not assigned in this method; presumably the parent constructor
        # provides them -- TODO confirm.
        gt_seg, _neighbors, _cut_edges, _mutexes = compute_mws_segmentation_cstm(
            gt_affinities,
            self.valid_edges,
            self.mtx_offsets,
            self.mtx_separating_channel,
            self.img_shape,
        )
        self.gt_seg = gt_seg

        # Plain configuration values.
        self.use_bbox = use_bbox
        self.action_aggression = action_aggression
        self.penalize_diff_thresh = penalize_diff_thresh
        self.loosing_diff_thresh = penalize_diff_thresh * 3

        self.bbox = np.array(self.img_shape)
        self.keep_first_state = keep_first_state

        # Must run last: every attribute above is set before this call.
        self._update_state_1()
Beispiel #2
0
def get_sp_graph(data, gt, scal=1.01):
    """Build a superpixel graph from an image and its binary ground truth.

    Affinities are computed from ``data`` with ``get_naive_affinities``,
    the channels from ``sep_chnl`` on are inverted (``a -> 1 - a``) and
    scaled by ``scal``, and the result is blended with the ground-truth
    affinities as ``a + gt - a * gt`` before being clipped to ``[0, 1]``.
    A segmentation is then computed with ``compute_mws_segmentation_cstm``
    and edge features / ground-truth edge costs are derived from it.

    The offsets, the separating channel (2) and the image shape
    ``(128, 128)`` are hard-coded; ``data`` and ``gt`` are presumably
    128x128 arrays -- TODO confirm against callers.

    Args:
        data: input image handed to ``get_naive_affinities``.
        gt: ground-truth labels; pixels equal to 1 are foreground.
        scal: factor applied to the inverted channels of the affinities.

    Returns:
        tuple: ``(edges, edge_feat, diff_to_gt, gt_edge_weights,
        node_labeling, nodes, noisy_affinities)`` where ``edges`` is the
        transposed, per-edge sorted ``int64`` neighbour array, the other
        arrays are ``float32`` and ``diff_to_gt`` is the summed absolute
        difference between ``edge_feat[:, 0]`` and ``gt_edge_weights``.
        ``noisy_affinities`` is the (noise-free) blended affinity array.
    """
    import warnings

    offsets = [[0, -1], [-1, 0], [-3, 0], [0, -3]]
    sep_chnl = 2
    shape = (128, 128)

    affinities = affutils.get_naive_affinities(data, offsets)
    gt_affinities, _ = compute_affinities(gt == 1, offsets)

    # Invert the channels from sep_chnl on: a -> 1 - a.
    gt_affinities[sep_chnl:] *= -1
    gt_affinities[sep_chnl:] += +1
    affinities[sep_chnl:] *= -1
    affinities[sep_chnl:] += +1
    affinities[sep_chnl:] *= scal

    # a + gt - a * gt: wherever gt is 1 the result is exactly 1.
    affinities = (affinities - (affinities * gt_affinities)) + gt_affinities
    affinities = affinities.clip(0, 1)

    valid_edges = get_valid_edges((len(offsets), ) + shape, offsets, sep_chnl,
                                  None, False)
    node_labeling, _, _, _ = compute_mws_segmentation_cstm(
        affinities.ravel(), valid_edges.ravel(), offsets, sep_chnl, shape)
    node_labeling = node_labeling - 1

    nodes = np.unique(node_labeling)
    # The original wrapped an assert in a bare ``except`` and merely
    # constructed a Warning object, so a mismatch was never reported.
    if not np.array_equal(nodes, np.arange(len(nodes))):
        warnings.warn("node ids are off")

    # No noise is applied; the name is kept because it is part of the
    # returned tuple's documented meaning.
    noisy_affinities = affinities

    edge_feat, neighbors = get_edge_features_1d(node_labeling, offsets,
                                                noisy_affinities)
    gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                              node_labeling.squeeze(),
                                              gt.squeeze())

    # np.long / np.float were removed from NumPy; use explicit dtypes.
    edges = neighbors.astype(np.int64)
    noisy_affinities = noisy_affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()

    # Sort the two node ids of every edge, then transpose so that the
    # edge index is the last axis.
    edges = np.sort(edges, axis=-1)
    edges = edges.T

    return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, noisy_affinities
Beispiel #3
0
    def get(self, idx):
        """Generate one random synthetic RGB image with simple shapes.

        The image has shape ``self.dim + (3,)``.  A noisy, sinusoidally
        textured background is drawn first; then ``self.n_polys`` random
        polygons, ``self.n_ellips`` discs (red-ish) and ``self.n_rect``
        rectangles (blue-ish) are painted on top, each disc and rectangle
        with its own sinusoidal texture.

        Args:
            idx: unused; every call draws a fresh random image.

        Returns:
            tuple: ``(img, None)`` with ``img`` a ``float32`` array clipped
            to ``[0, 1]``.
        """

        def signed_rand(scale, shift):
            # Random sign (np.sign may also yield 0) times a uniform value
            # in [shift, shift + scale).  The randint draw happens before
            # the rand draw, matching the original expression order.
            sign = np.sign(np.random.randint(-100, 100))
            return sign * ((np.random.rand() * scale) + shift)

        def wave(xs, ys, fx, fy, freq, norm):
            # Sinusoidal pattern over the given coordinates.
            return np.sin(
                np.sqrt((xs * fx)**2 + ((self.dim[1] - ys) * fy)**2) * freq *
                np.pi / norm)

        def shading(xs, ys, fx, fy, freq, norm):
            # Texture added to the first two colour channels; the constant
            # 0.2 offset is added to all three channels.
            return np.array([1.0, 1.0, 0.0]) * (
                wave(xs, ys, fx, fy, freq, norm)[..., np.newaxis] * 0.15) + 0.2

        def too_close(point, limit):
            # True if ``point`` lies closer than ``limit`` to any midpoint
            # already placed.
            return any(
                np.linalg.norm(other - point) < limit for other in cmps)

        img = np.random.randn(*(self.dim + (3, ))) / 5

        ri1 = signed_rand(2, .5)
        ri2 = signed_rand(2, .5)
        ri3 = (np.random.rand() * 4) + 3
        ri4 = (np.random.rand() * 4) + 3
        ri5 = signed_rand(2, .5)
        ri6 = signed_rand(2, .5)

        # x holds the column index and y the row index of every pixel
        # (identical to the original arange/transpose construction for
        # square ``dim``, and also valid for non-square ``dim``).
        y, x = np.indices(self.dim, dtype=float)
        img += wave(x, y, ri1, ri2, ri3, self.dim[0])[..., np.newaxis]
        img += wave(x, y, ri5, ri6, ri4, self.dim[1])[..., np.newaxis]
        img = gaussian(np.clip(img, 0.1, 1), sigma=.8)

        # Midpoints of all shapes placed so far (shared by all shape kinds).
        # Both coordinates are sampled against dim[0], as in the original.
        cmps = []

        circles = []
        while len(circles) < self.n_ellips:
            mp = np.random.randint(self.min_r, self.dim[0] - self.min_r, 2)
            if too_close(mp, self.min_dist):
                continue
            r = np.random.randint(self.min_r, self.max_r, 2)
            # NOTE(review): newer scikit-image releases appear to replace
            # draw.circle with draw.disk -- confirm the installed version.
            circles.append(draw.circle(mp[0], mp[1], r[0], shape=self.dim))
            cmps.append(mp)

        polys = []
        while len(polys) < self.n_polys:
            mp = np.random.randint(self.min_r, self.dim[0] - self.min_r, 2)
            if too_close(mp, self.min_dist // 2):
                continue
            # Polygon vertices are 3-5 random points on a circle perimeter.
            circle = draw.circle_perimeter(mp[0], mp[1], self.max_r)
            poly_vert = np.random.choice(len(circle[0]),
                                         np.random.randint(3, 6),
                                         replace=False)
            polys.append(
                draw.polygon(circle[0][poly_vert],
                             circle[1][poly_vert],
                             shape=self.dim))
            cmps.append(mp)

        rects = []
        while len(rects) < self.n_rect:
            mp = np.random.randint(self.min_r, self.dim[0] - self.min_r, 2)
            # Drawn before the distance check to keep the RNG call order.
            _len = np.random.randint(self.min_r // 2, self.max_r, (2, ))
            if too_close(mp, self.min_dist):
                continue
            start = (mp[0] - _len[0], mp[1] - _len[1])
            rects.append(
                draw.rectangle(start,
                               extent=(_len[0] * 2, _len[1] * 2),
                               shape=self.dim))
            cmps.append(mp)

        for poly in polys:
            # Resample until the colour is far enough from both the
            # ellipse and the rectangle reference colours.
            color = np.random.rand(3)
            while (np.linalg.norm(color - self.ellips_color) < self.col_diff
                   or np.linalg.norm(color - self.rect_color) < self.col_diff):
                color = np.random.rand(3)
            img[poly[0], poly[1], :] = color
            img[poly[0], poly[1], :] += np.random.randn(len(poly[1]), 3) / 5

        cols = np.random.choice(np.arange(4, 11, 1).astype(float) / 10,
                                self.n_ellips,
                                replace=False)
        for i, ellipse in enumerate(circles):
            ri1 = signed_rand(4, 7)
            ri2 = signed_rand(4, 7)
            ri3 = (np.random.rand() + 1) * 3
            ri4 = (np.random.rand() + 1) * 3
            ri5 = signed_rand(4, 7)
            ri6 = signed_rand(4, 7)
            ex, ey = x[ellipse[0], ellipse[1]], y[ellipse[0], ellipse[1]]
            img[ellipse[0], ellipse[1], :] = np.array([cols[i], 0.0, 0.0])
            img[ellipse[0], ellipse[1], :] += shading(ex, ey, ri5, ri2, ri3,
                                                      self.dim[0])
            img[ellipse[0], ellipse[1], :] += shading(ex, ey, ri6, ri1, ri4,
                                                      self.dim[1])

        cols = np.random.choice(np.arange(4, 11, 1).astype(float) / 10,
                                self.n_rect,
                                replace=False)
        # enumerate() so every rectangle picks its own entry of ``cols``;
        # the original reused the stale index left over from the ellipse
        # loop.
        for i, rect in enumerate(rects):
            ri1 = signed_rand(4, 7)
            ri2 = signed_rand(4, 7)
            ri3 = (np.random.rand() + 1) * 3
            ri4 = (np.random.rand() + 1) * 3
            ri5 = signed_rand(4, 7)
            ri6 = signed_rand(4, 7)
            rx, ry = x[rect[0], rect[1]], y[rect[0], rect[1]]
            img[rect[0], rect[1], :] = np.array([0.0, 0.0, cols[i]])
            img[rect[0], rect[1], :] += shading(rx, ry, ri5, ri2, ri3,
                                                self.dim[0])
            img[rect[0], rect[1], :] += shading(rx, ry, ri1, ri6, ri4,
                                                self.dim[1])

        img = np.clip(img, 0, 1).astype(np.float32)

        smooth_image = gaussian(img, sigma=.2)

        # NOTE(review): everything below is computed but not returned.  It
        # is kept because the helpers are defined elsewhere and their side
        # effects are not visible here.  ``offsets`` is not defined in this
        # method (presumably a module-level name, possibly meant to be
        # ``self.edge_offsets``) -- TODO confirm.
        affinities = get_naive_affinities(smooth_image, offsets)
        affinities[:self.sep_chnl] *= -1
        affinities[:self.sep_chnl] += +1
        affinities[:self.sep_chnl] /= 1.3
        affinities[self.sep_chnl:] *= 1.3
        affinities = np.clip(affinities, 0, 1)

        valid_edges = get_valid_edges((len(self.edge_offsets), ) + self.dim,
                                      self.edge_offsets, self.sep_chnl, None,
                                      False)
        node_labeling, neighbors, cutting_edges, mutexes = compute_mws_segmentation_cstm(
            affinities.ravel(), valid_edges.ravel(), offsets, self.sep_chnl,
            self.dim)

        return img, None
Beispiel #4
0
    def get(self, idx):
        """Generate one synthetic disc image and a grid-superpixel graph.

        A disc with random radius and midpoint is drawn into an image of
        shape ``self.shape``; disc and background get different sinusoidal
        textures and ``gt`` marks the disc with 1.  Superpixels are square
        grid cells of side 30 split along the disc boundary.  Edge features
        and ground-truth edge costs are computed for the superpixel graph.

        Args:
            idx: unused; every call draws a fresh random sample.

        Returns:
            tuple: ``(edges, edge_feat, diff_to_gt, gt_edge_weights,
            node_labeling, raw, nodes, angles)``.  ``edges`` has the edge
            index as its last axis and contains every edge in both
            directions; ``diff_to_gt`` is the summed absolute difference
            between ``edge_feat[:, 0]`` and ``gt_edge_weights``.
        """
        radius = np.random.randint(max(self.shape) // 5, max(self.shape) // 3)
        mp = (np.random.randint(0 + radius, self.shape[0] - radius),
              np.random.randint(0 + radius, self.shape[1] - radius))

        # np.float was removed from NumPy; the builtin float is the same dtype.
        data = np.zeros(shape=self.shape, dtype=float)
        gt = np.zeros(shape=self.shape, dtype=float)
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                ly, lx = y - mp[0], x - mp[1]
                if (ly**2 + lx**2)**.5 <= radius:
                    # inside the disc
                    data[y, x] += np.sin(x * 10 * np.pi / self.shape[1])
                    data[y, x] += np.sin(
                        np.sqrt(x**2 + y**2) * 20 * np.pi / self.shape[1])
                    gt[y, x] = 1
                else:
                    # background
                    data[y, x] += np.sin(y * 10 * np.pi / self.shape[1])
                    data[y, x] += np.sin(
                        np.sqrt(x**2 + (self.shape[1] - y)**2) * 10 * np.pi /
                        self.shape[1])
        data += 1

        # NOTE(review): ``offsets`` is not defined in this method
        # (presumably a module-level name) and ``gt_affinities`` is not used
        # afterwards; the call is kept because the helper is defined
        # elsewhere -- TODO confirm it can be dropped.
        gt_affinities, _ = compute_affinities(gt == 1, offsets)

        # Label every pixel with the id of its 30x30 grid cell, ids being
        # handed out in order of first appearance.
        seg_arbitrary = np.zeros_like(data)
        square_dict = {}
        granularity = 30
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                cell = (x // granularity, y // granularity)
                seg_arbitrary[y, x] += square_dict.setdefault(
                    cell, len(square_dict))
        # Split cells along the disc boundary, then relabel consecutively.
        seg_arbitrary += gt * 1000
        seg_arb = np.zeros_like(seg_arbitrary)
        for new_id, seg in enumerate(np.unique(seg_arbitrary)):
            seg_arb += (seg_arbitrary == seg) * new_id
        seg_arbitrary = seg_arb

        # NOTE(review): these RAG neighbours are overwritten by the
        # segmentation call below -- confirm which of the two is intended.
        rag = feats.compute_rag(np.expand_dims(seg_arbitrary, axis=0))
        neighbors = rag.uvIds()

        affinities = get_naive_affinities(data, offsets)
        # Invert the channels before sep_chnl (a -> 1 - a) and amplify the
        # remaining ones.
        affinities[:self.sep_chnl] *= -1
        affinities[:self.sep_chnl] += +1
        affinities[self.sep_chnl:] /= 0.2

        raw = torch.tensor(data).unsqueeze(0).unsqueeze(0).float()

        valid_edges = get_valid_edges((len(self.edge_offsets), ) + self.shape,
                                      self.edge_offsets, self.sep_chnl, None,
                                      False)
        # Only the neighbour list of this segmentation is used; the node
        # labeling is replaced by the grid superpixels.
        _, neighbors, _, _ = compute_mws_segmentation_cstm(
            affinities.ravel(), valid_edges.ravel(), offsets, self.sep_chnl,
            self.shape)
        node_labeling = seg_arbitrary

        # Look up the label found at each flat index in ``neighbors``.
        neighbors = (node_labeling.ravel())[neighbors]
        nodes = np.unique(node_labeling)
        # NOTE(review): other callers in this file unpack this helper's
        # result as ``(edge_feat, neighbors)``; here it is used as a single
        # array -- confirm against the helper's current return value.
        edge_feat = get_edge_features_1d(node_labeling, offsets, affinities)

        gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                  node_labeling.squeeze(),
                                                  gt.squeeze())

        # np.long was removed from NumPy; int64 maps to a torch LongTensor.
        edges = torch.from_numpy(neighbors.astype(np.int64))
        raw = raw.squeeze()
        edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
        nodes = torch.from_numpy(nodes.astype(np.float32))
        node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
        gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.float32))
        diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum()

        node_features, angles = get_stacked_node_data(nodes,
                                                      edges,
                                                      node_labeling,
                                                      raw,
                                                      size=[32, 32])

        # Put the edge index last and append every edge reversed.
        edges = edges.t().contiguous()
        edges = torch.cat((edges, torch.stack((edges[1], edges[0]))), dim=1)

        return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles
Beispiel #5
0
    def get(self, idx):
        """Generate one synthetic multi-disc image and its superpixel graph.

        Eight or nine non-touching discs with random radii and midpoints
        are drawn into an image of shape ``self.shape``; discs and
        background get different sinusoidal textures and ``gt`` marks disc
        pixels with 1.  Affinities are computed from the image, blended
        with the ground-truth affinities, and segmented with
        ``compute_mws_segmentation_cstm`` to obtain superpixels.

        Args:
            idx: unused; every call draws a fresh random sample.

        Returns:
            Depending on the instance flags:

            * ``self.less``: ``(raw[None], node_labeling, gt, gt_edge_weights,
              edges)`` with ``gt`` and ``gt_edge_weights`` as integer tensors;
            * ``self.no_suppix`` (checked after ``self.less``):
              ``(raw[None], gt)`` with ``gt`` as an integer tensor;
            * otherwise ``(edges, edge_feat, diff_to_gt, gt_edge_weights,
              node_labeling, raw, nodes, noisy_affinities, gt)``.

            ``edges`` has the edge index as its last axis and contains every
            edge in both directions.  ``noisy_affinities`` holds the blended
            affinities; no noise is applied.
        """
        import warnings

        n_disc = np.random.randint(8, 10)
        rads = []
        mps = []
        for _ in range(n_disc):
            radius = np.random.randint(
                max(self.shape) // 18,
                max(self.shape) // 15)
            # Resample the midpoint until the disc keeps a gap of more than
            # 2 pixels to every disc placed so far.
            touching = True
            while touching:
                mp = np.array([
                    np.random.randint(0 + radius, self.shape[0] - radius),
                    np.random.randint(0 + radius, self.shape[1] - radius)
                ])
                touching = any(
                    ((mp - other_mp)**2).sum()**.5 <= radius + other_rad + 2
                    for other_rad, other_mp in zip(rads, mps))
            rads.append(radius)
            mps.append(mp)

        # np.float was removed from NumPy; the builtin float is the same dtype.
        data = np.zeros(shape=self.shape, dtype=float)
        gt = np.zeros(shape=self.shape, dtype=float)
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                bg = True
                for radius, mp in zip(rads, mps):
                    ly, lx = y - mp[0], x - mp[1]
                    if (ly**2 + lx**2)**.5 <= radius:
                        # inside a disc
                        data[y, x] += np.cos(
                            np.sqrt((x - self.shape[1])**2 + y**2) * 50 *
                            np.pi / self.shape[1])
                        data[y, x] += np.cos(
                            np.sqrt(x**2 + y**2) * 50 * np.pi / self.shape[1])
                        gt[y, x] = 1
                        bg = False
                if bg:
                    data[y, x] += np.cos(y * 40 * np.pi / self.shape[0])
                    data[y, x] += np.cos(
                        np.sqrt(x**2 + (self.shape[0] - y)**2) * 30 * np.pi /
                        self.shape[1])
        data += 1

        affinities = affutils.get_naive_affinities(data, self.offsets)
        gt_affinities, _ = compute_affinities(gt == 1, self.offsets)
        # Invert the channels from sep_chnl on: a -> 1 - a.
        gt_affinities[self.sep_chnl:] *= -1
        gt_affinities[self.sep_chnl:] += +1
        affinities[self.sep_chnl:] *= -1
        affinities[self.sep_chnl:] += +1
        affinities[self.sep_chnl:] *= 1.01
        # a + gt - a * gt: wherever gt is 1 the result is exactly 1.
        affinities = (affinities -
                      (affinities * gt_affinities)) + gt_affinities
        affinities = affinities.clip(0, 1)

        valid_edges = get_valid_edges((len(self.offsets), ) + self.shape,
                                      self.offsets, self.sep_chnl, None, False)
        node_labeling, _, _, _ = compute_mws_segmentation_cstm(
            affinities.ravel(), valid_edges.ravel(), self.offsets,
            self.sep_chnl, self.shape)
        node_labeling = node_labeling - 1

        nodes = np.unique(node_labeling)
        # The original wrapped an assert in a bare ``except`` and merely
        # constructed a Warning object, so a mismatch was never reported.
        if not np.array_equal(nodes, np.arange(len(nodes))):
            warnings.warn("node ids are off")

        # No noise is applied; the name is kept because it is part of the
        # returned tuple.
        noisy_affinities = affinities

        edge_feat, neighbors = get_edge_features_1d(node_labeling,
                                                    self.offsets,
                                                    noisy_affinities)
        gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                  node_labeling.squeeze(),
                                                  gt.squeeze())

        if self.less:
            # np.long was removed from NumPy; int64 maps to a LongTensor.
            raw = torch.from_numpy(data).float()
            node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
            gt_edge_weights = torch.from_numpy(
                gt_edge_weights.astype(np.int64))
            edges = torch.from_numpy(neighbors.astype(np.int64))
            edges = edges.t().contiguous()
            edges = torch.cat((edges, torch.stack((edges[1], edges[0]))),
                              dim=1)
            return raw.unsqueeze(0), node_labeling, torch.from_numpy(
                gt.astype(np.int64)), gt_edge_weights, edges

        gt = torch.from_numpy(gt.astype(np.float32)).squeeze().float()

        edges = torch.from_numpy(neighbors.astype(np.int64))
        raw = torch.tensor(data).squeeze().float()
        noisy_affinities = torch.tensor(noisy_affinities).squeeze().float()
        edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
        nodes = torch.from_numpy(nodes.astype(np.float32))
        node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
        gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.float32))
        diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum().item()

        if self.no_suppix:
            raw = torch.from_numpy(data).float()
            return raw.unsqueeze(0), torch.from_numpy(gt.numpy().astype(
                np.int64))

        # Put the edge index last and append every edge reversed.
        edges = edges.t().contiguous()
        edges = torch.cat((edges, torch.stack((edges[1], edges[0]))), dim=1)

        return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, noisy_affinities, gt
Beispiel #6
0
def get_pix_data(length=50000, shape=(128, 128), radius=72, show_plot=True):
    """Generate one random synthetic RGB image with ground truth and graph data.

    A 256x256 noise image is overlaid with two sinusoidal background waves.
    Random polygons (not part of the ground truth), filled discs (red base
    colour, ground-truth label ``1 + i / 10``) and rectangles (blue base
    colour, ground-truth label ``2 + i / 10``) are then painted on top.
    Affinities computed from the image are passed to
    ``compute_mws_segmentation_cstm`` to obtain a node labeling, from which
    edge features and ground-truth edge weights are derived.

    The order of all ``np.random`` calls matches the previous implementation,
    so seeded runs draw the same random numbers.

    Args:
        length: Unused; kept so existing callers keep working.
        shape: Unused; the image size is fixed to ``(256, 256)``.
        radius: Unused; kept so existing callers keep working.
        show_plot: If true (default, the previous behaviour), compute the
            solution induced by the ground-truth edge weights and show a
            three-panel matplotlib figure. ``plt.show()`` may block.

    Returns:
        Tuple ``(img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights,
        node_labeling, nodes, affinities)``. ``edges`` is the transposed,
        per-edge sorted int64 neighbour array; ``diff_to_gt`` is the summed
        absolute difference between ``edge_feat[:, 0]`` and
        ``gt_edge_weights``; the remaining arrays are float32 except ``img``
        and ``gt``.
    """
    import warnings  # function-scope import: the file header is not visible here

    dim = (256, 256)
    edge_offsets = [
        [0, -1],
        [-1, 0],
        # direct 3d nhood for attractive edges
        # [0, -1], [-1, 0]]
        [-3, 0],
        [0, -3],
        [-6, 0],
        [0, -6]
    ]
    sep_chnl = 2
    n_ellips = 5
    n_polys = 10
    n_rect = 5
    # np.float / np.long were removed in NumPy 1.24; use explicit dtypes.
    ellips_color = np.array([1, 0, 0], dtype=np.float64)
    rect_color = np.array([0, 0, 1], dtype=np.float64)
    col_diff = 0.4
    min_r, max_r = 10, 20
    min_dist = max_r
    # The sinusoidal texture is only added to the first two colour channels.
    tex_channels = np.array([1.0, 1.0, 0.0])

    def signed_rand(scale, shift):
        # randint is drawn before rand, as in the original expression.
        # np.sign yields 0 when randint returns 0, so the result can be 0.
        sign = np.sign(np.random.randint(-100, 100))
        return sign * ((np.random.rand() * scale) + shift)

    def shape_wave_params():
        # Six texture parameters for one disc / rectangle, drawn in order.
        p1 = signed_rand(4, 7)
        p2 = signed_rand(4, 7)
        p3 = (np.random.rand() + 1) * 3
        p4 = (np.random.rand() + 1) * 3
        p5 = signed_rand(4, 7)
        p6 = signed_rand(4, 7)
        return p1, p2, p3, p4, p5, p6

    def wave(xs, ys, fx, fy, freq, norm):
        # Radial sine pattern with a trailing channel axis for broadcasting.
        dist = np.sqrt((xs * fx)**2 + ((dim[1] - ys) * fy)**2)
        return (np.sin(dist * freq * np.pi / norm))[..., np.newaxis]

    def too_close(point, limit):
        # True if any already placed shape centre is nearer than `limit`.
        return any(np.linalg.norm(cmp - point) < limit for cmp in cmps)

    def filled_disk(center, rad):
        # draw.circle was removed in scikit-image 0.19; draw.disk replaces it.
        if hasattr(draw, "disk"):
            return draw.disk((center[0], center[1]), rad, shape=dim)
        return draw.circle(center[0], center[1], rad, shape=dim)

    img = np.random.randn(*(dim + (3, ))) / 5
    gt = np.zeros(dim)

    ri1 = signed_rand(2, .5)
    ri2 = signed_rand(2, .5)
    ri3 = (np.random.rand() * 4) + 3
    ri4 = (np.random.rand() * 4) + 3
    ri5 = signed_rand(2, .5)
    ri6 = signed_rand(2, .5)
    x = np.zeros(dim)
    x[:, :] = np.arange(img.shape[0])[np.newaxis, :]
    y = x.transpose()
    img += wave(x, y, ri1, ri2, ri3, dim[0])
    img += wave(x, y, ri5, ri6, ri4, dim[1])
    img = gaussian(np.clip(img, 0.1, 1), sigma=.8)

    circles = []
    cmps = []
    while len(circles) < n_ellips:
        mp = np.random.randint(min_r, dim[0] - min_r, 2)
        if too_close(mp, min_dist):
            continue
        # Two radii are drawn (keeps the random stream) but only r[0] is used.
        r = np.random.randint(min_r, max_r, 2)
        circles.append(filled_disk(mp, r[0]))
        cmps.append(mp)

    polys = []
    while len(polys) < n_polys:
        mp = np.random.randint(min_r, dim[0] - min_r, 2)
        if too_close(mp, min_dist // 2):
            continue
        # Polygon vertices are 3-5 random points on a circle around mp.
        circle = draw.circle_perimeter(mp[0], mp[1], max_r)
        n_vert = np.random.randint(3, 6)
        poly_vert = np.random.choice(len(circle[0]), n_vert, replace=False)
        polys.append(
            draw.polygon(circle[0][poly_vert], circle[1][poly_vert],
                         shape=dim))
        cmps.append(mp)

    rects = []
    while len(rects) < n_rect:
        mp = np.random.randint(min_r, dim[0] - min_r, 2)
        # Half extents are drawn before the distance check (random order).
        _len = np.random.randint(min_r // 2, max_r, (2, ))
        if too_close(mp, min_dist):
            continue
        start = (mp[0] - _len[0], mp[1] - _len[1])
        rects.append(
            draw.rectangle(start, extent=(_len[0] * 2, _len[1] * 2),
                           shape=dim))
        cmps.append(mp)

    for poly in polys:
        # Re-draw until the colour is far enough from both object colours.
        color = np.random.rand(3)
        while (np.linalg.norm(color - ellips_color) < col_diff
               or np.linalg.norm(color - rect_color) < col_diff):
            color = np.random.rand(3)
        img[poly[0], poly[1], :] = color
        img[poly[0], poly[1], :] += np.random.randn(len(poly[1]), 3) / 5

    cols = np.random.choice(np.arange(4, 11, 1).astype(np.float64) / 10,
                            n_ellips,
                            replace=False)
    for i, ellipse in enumerate(circles):
        rr, cc = ellipse[0], ellipse[1]
        gt[rr, cc] = 1 + (i / 10)
        ri1, ri2, ri3, ri4, ri5, ri6 = shape_wave_params()
        img[rr, cc, :] = np.array([cols[i], 0.0, 0.0])
        img[rr, cc, :] += tex_channels * (
            wave(x[rr, cc], y[rr, cc], ri5, ri2, ri3, dim[0]) * 0.15) + 0.2
        img[rr, cc, :] += tex_channels * (
            wave(x[rr, cc], y[rr, cc], ri6, ri1, ri4, dim[1]) * 0.15) + 0.2

    cols = np.random.choice(np.arange(4, 11, 1).astype(np.float64) / 10,
                            n_rect,
                            replace=False)
    for i, rect in enumerate(rects):
        rr, cc = rect[0], rect[1]
        gt[rr, cc] = 2 + (i / 10)
        ri1, ri2, ri3, ri4, ri5, ri6 = shape_wave_params()
        img[rr, cc, :] = np.array([0.0, 0.0, cols[i]])
        img[rr, cc, :] += tex_channels * (
            wave(x[rr, cc], y[rr, cc], ri5, ri2, ri3, dim[0]) * 0.15) + 0.2
        img[rr, cc, :] += tex_channels * (
            wave(x[rr, cc], y[rr, cc], ri1, ri6, ri4, dim[1]) * 0.15) + 0.2

    img = np.clip(img, 0, 1)

    affinities = get_naive_affinities(gaussian(img, sigma=.2), edge_offsets)
    # The first sep_chnl channels are inverted and damped, the rest boosted.
    affinities[:sep_chnl] *= -1
    affinities[:sep_chnl] += +1
    affinities[:sep_chnl] /= 1.3
    affinities[sep_chnl:] *= 1.3
    affinities = np.clip(affinities, 0, 1)

    valid_edges = get_valid_edges((len(edge_offsets), ) + dim, edge_offsets,
                                  sep_chnl, None, False)
    node_labeling, neighbors, _, _ = compute_mws_segmentation_cstm(
        affinities.ravel(), valid_edges.ravel(), edge_offsets, sep_chnl, dim)
    node_labeling = node_labeling - 1
    nodes = np.unique(node_labeling)
    # Previously `Warning(...)` was only instantiated, so nothing was reported.
    if not np.array_equal(nodes, np.arange(len(nodes))):
        warnings.warn("node ids are off")

    edge_feat, neighbors = get_edge_features_1d(node_labeling, edge_offsets,
                                                affinities)
    gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                              node_labeling.squeeze(),
                                              gt.squeeze())
    edges = neighbors.astype(np.int64)

    if show_plot:
        gt_seg = get_current_soln(gt_edge_weights, node_labeling, edges)
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        ax1.imshow(cm.prism(gt / gt.max()))
        ax1.set_title('gt')
        ax2.imshow(cm.prism(node_labeling / node_labeling.max()))
        ax2.set_title('sp')
        ax3.imshow(cm.prism(gt_seg / gt_seg.max()))
        ax3.set_title('mc')
        plt.show()

    affinities = affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()

    # Sort each edge's endpoints, then transpose
    # (assumes neighbors is (n_edges, 2), giving (2, n_edges) — TODO confirm).
    edges = np.sort(edges, axis=-1)
    edges = edges.T

    return img, gt, edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, affinities
Beispiel #7
0
    def create_dsets(self, num, out_prefix=None):
        """Generate ``num`` synthetic disc images with graph data and store them.

        For every file index, 25-29 non-touching discs are placed at random,
        a cosine texture image ``data`` and a binary ground truth ``gt`` are
        rendered pixel by pixel, and affinities are computed and blended with
        the ground-truth affinities. A node labeling is obtained from
        ``compute_mws_segmentation_cstm``; nodes joined by an edge whose
        ground-truth weight is 0 are then merged at random until the sum of
        the ground-truth edge weights is within 1 of half the edge count.
        The result is written with ``self.write_to_h5``.

        Reads ``self.shape``, ``self.no_suppix``, ``self.offsets`` and
        ``self.sep_chnl``.

        Args:
            num: Number of datasets to generate.
            out_prefix: Path prefix of the output files; the file index and
                ``'.h5'`` are appended. ``None`` (default) selects the
                previously hard-coded location.

        Returns:
            ``None`` after writing all files. If ``self.no_suppix`` is set,
            the method instead returns ``(raw, gt)`` for the first generated
            image right away and writes nothing.
        """
        import warnings  # function-scope import: the file header is not visible here

        if out_prefix is None:
            out_prefix = (
                '/g/kreshuk/hilt/projects/fewShotLearning/mutexWtsd/data/storage/balanced_graphs/balanced_graph_data'
            )

        for file_index in range(num):
            n_disc = np.random.randint(25, 30)
            rads = []
            mps = []
            for _ in range(n_disc):
                radius = np.random.randint(
                    max(self.shape) // 25,
                    max(self.shape) // 20)
                # Re-sample the centre until the disc keeps a gap of more
                # than 2 pixels to every disc placed so far.
                touching = True
                while touching:
                    mp = np.array([
                        np.random.randint(0 + radius, self.shape[0] - radius),
                        np.random.randint(0 + radius, self.shape[1] - radius)
                    ])
                    touching = any(
                        ((mp - other_mp)**2).sum()**.5 <=
                        radius + other_rad + 2
                        for other_rad, other_mp in zip(rads, mps))
                rads.append(radius)
                mps.append(mp)

            # np.float / np.long / np.int were removed in NumPy 1.24.
            data = np.zeros(shape=self.shape, dtype=np.float64)
            gt = np.zeros(shape=self.shape, dtype=np.float64)
            for y in range(self.shape[0]):
                for x in range(self.shape[1]):
                    bg = True
                    for radius, mp in zip(rads, mps):
                        ly, lx = y - mp[0], x - mp[1]
                        if (ly**2 + lx**2)**.5 <= radius:
                            # Foreground texture: two high-frequency waves.
                            data[y, x] += np.cos(
                                np.sqrt((x - self.shape[1])**2 + y**2) * 50 *
                                np.pi / self.shape[1])
                            data[y, x] += np.cos(
                                np.sqrt(x**2 + y**2) * 50 * np.pi /
                                self.shape[1])
                            gt[y, x] = 1
                            bg = False
                    if bg:
                        # Background texture: different frequencies.
                        data[y, x] += np.cos(y * 40 * np.pi / self.shape[0])
                        data[y, x] += np.cos(
                            np.sqrt(x**2 + (self.shape[0] - y)**2) * 30 *
                            np.pi / self.shape[1])
            data += 1

            if self.no_suppix:
                # Pixel-level mode: hand back the first image, skip the graph.
                raw = torch.from_numpy(data).float()
                return raw.unsqueeze(0), torch.from_numpy(gt.astype(np.int64))

            affinities = affutils.get_naive_affinities(data, self.offsets)
            gt_affinities, _ = compute_affinities(gt == 1, self.offsets)
            # Invert the channels from sep_chnl onwards (a -> 1 - a).
            gt_affinities[self.sep_chnl:] *= -1
            gt_affinities[self.sep_chnl:] += +1
            affinities[self.sep_chnl:] *= -1
            affinities[self.sep_chnl:] += +1
            affinities[self.sep_chnl:] *= 1.01
            # a + g - a * g: equals 1 wherever the ground truth is 1.
            affinities = (affinities -
                          (affinities * gt_affinities)) + gt_affinities
            affinities = affinities.clip(0, 1)

            valid_edges = get_valid_edges((len(self.offsets), ) + self.shape,
                                          self.offsets, self.sep_chnl, None,
                                          False)
            node_labeling, neighbors, _, _ = compute_mws_segmentation_cstm(
                affinities.ravel(), valid_edges.ravel(), self.offsets,
                self.sep_chnl, self.shape)
            node_labeling = node_labeling - 1
            nodes = np.unique(node_labeling)
            # Previously `Warning(...)` was only instantiated, never emitted.
            if not np.array_equal(nodes, np.arange(len(nodes))):
                warnings.warn("node ids are off")

            noisy_affinities = affinities

            edge_feat, neighbors = get_edge_features_1d(
                node_labeling, self.offsets, noisy_affinities)
            gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                      node_labeling.squeeze(),
                                                      gt.squeeze())

            # Balance the edge labels by merging nodes across weight-0 edges.
            while abs(gt_edge_weights.sum() - (len(gt_edge_weights) / 2)) > 1:
                # Softmax of the 0/1 indicator favours weight-0 edges.
                pick_probs = torch.softmax(torch.from_numpy(
                    (gt_edge_weights == 0).astype(np.float64)),
                                           dim=0).numpy()
                edge_idx = np.random.choice(np.arange(len(gt_edge_weights)),
                                            p=pick_probs)
                if gt_edge_weights[edge_idx] != 0.0:
                    continue

                edge = neighbors[edge_idx].astype(np.int64)
                # Merge: relabel every pixel of node edge[0] to edge[1].
                diff = edge[0] - edge[1]
                node_labeling = node_labeling - (node_labeling
                                                 == edge[0]) * diff

                edge_feat, neighbors = get_edge_features_1d(
                    node_labeling, self.offsets, noisy_affinities)
                gt_edge_weights = calculate_gt_edge_costs(
                    neighbors, node_labeling.squeeze(), gt.squeeze())

            # Recomputed once more after the loop, as in the original code.
            edge_feat, neighbors = get_edge_features_1d(
                node_labeling, self.offsets, noisy_affinities)
            gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                      node_labeling.squeeze(),
                                                      gt.squeeze())

            gt = torch.from_numpy(gt.astype(np.float32)).squeeze().float()

            edges = torch.from_numpy(neighbors.astype(np.int64))
            raw = torch.tensor(data).squeeze().float()
            noisy_affinities = torch.tensor(noisy_affinities).squeeze().float()
            edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
            nodes = torch.from_numpy(nodes.astype(np.float32))
            node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
            gt_edge_weights = torch.from_numpy(
                gt_edge_weights.astype(np.float32))
            diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum()
            # Append the reversed edges so both directions are stored.
            edges = edges.t().contiguous()
            edges = torch.cat((edges, torch.stack((edges[1], edges[0]))),
                              dim=1)

            self.write_to_h5(
                out_prefix + str(file_index) + '.h5', edges, edge_feat,
                diff_to_gt, gt_edge_weights, node_labeling, raw, nodes,
                noisy_affinities, gt)
Beispiel #8
0
 def _calc_wtsd(self):
     """Run ``compute_mws_segmentation_cstm`` on the current affinities.

     The current affinities and the valid-edge array are flattened first;
     the result of the call is returned unchanged.
     """
     flat_affs = self.current_affs.ravel()
     flat_valid = self.valid_edges.ravel()
     return compute_mws_segmentation_cstm(
         flat_affs,
         flat_valid,
         self.mtx_offsets,
         self.mtx_separating_channel,
         self.img_shape,
     )