Example #1
    def input_function(self, tensor):
        # print("affs: in shape", tensor.shape)
        if self.ignore_label is not None:
            # output.shape = (C, Z, Y, X)
            output, mask = compute_affinities(tensor,
                                              self.offsets,
                                              ignore_label=self.ignore_label,
                                              have_ignore_label=True)
            if self.learn_ignore_transitions:
                output, mask = self.include_ignore_transitions(
                    output, mask, tensor)
        else:
            output, mask = compute_affinities(tensor, self.offsets)

        # FIXME what does this do? needs to be refactored
        # hack for platynereis data
        platy_hack = False
        if platy_hack:
            chan_mask = mask[1].astype('bool')
            output[0][chan_mask] = np.min(output[:2], axis=0)[chan_mask]

            chan_mask = mask[2].astype('bool')
            output[0][chan_mask] = np.minimum(output[0], output[2])[chan_mask]

        # Cast to be sure
        if not output.dtype == self.dtype:
            output = output.astype(self.dtype)
        #
        # print("affs: shape before binary", output.shape)
        if self.segmentation_to_binary:
            output = np.concatenate(
                (self.to_binary_segmentation(tensor)[None], output), axis=0)
        # print("affs: shape after binary", output.shape)

        # print("affs: shape before mask", output.shape)
        # We might want to carry the mask along.
        # If this is the case, we insert it after the targets.
        if self.retain_mask:
            mask = mask.astype(self.dtype, copy=False)
            if self.segmentation_to_binary:
                if self.ignore_label is None:
                    additional_mask = np.ones((1, ) + tensor.shape,
                                              dtype=self.dtype)
                else:
                    additional_mask = (tensor[None] !=
                                       self.ignore_label).astype(self.dtype)
                mask = np.concatenate([additional_mask, mask], axis=0)
            output = np.concatenate((output, mask), axis=0)
        # print("affs: shape after mask", output.shape)

        # We might want to carry the segmentation along for validation.
        # If this is the case, we insert it before the targets.
        if self.retain_segmentation:
            # Add a channel axis to tensor to make it (C, Z, Y, X) before concatenating with the output
            output = np.concatenate(
                (tensor[None].astype(self.dtype, copy=False), output), axis=0)

        # print("affs: out shape", output.shape)
        return output
Example #2
    def to_affinities(self, seg, mask):
        seg[~mask] = 0
        affs, aff_mask = compute_affinities(seg, self.offsets, have_ignore_label=True)
        aff_mask = aff_mask.astype('bool')
        affs = 1. - affs

        mask_transition, aff_mask2 = compute_affinities(mask, self.offsets)
        mask_transition[~aff_mask2.astype('bool')] = 1
        aff_mask[~mask_transition.astype('bool')] = True
        return affs, aff_mask
Example #3
    def __call__(self, labels):
        dtype = "uint64"
        if np.dtype(labels.dtype) in (np.dtype("int16"), np.dtype("int32"), np.dtype("int64")):
            dtype = "int64"
        labels = ensure_spatial_array(labels, self.ndim, dtype=dtype)
        affs, mask = compute_affinities(labels, self.offsets,
                                        have_ignore_label=self.ignore_label is not None,
                                        ignore_label=0 if self.ignore_label is None else self.ignore_label)
        # we use the "disaffinity" convention for training; i.e. 1 means repulsive, 0 attractive
        affs = 1. - affs

        # remove transitions to the ignore label from the mask
        if self.ignore_label is not None and self.include_ignore_transitions:
            affs, mask = self.add_ignore_transitions(affs, mask, labels)

        if self.add_binary_target:
            binary = labels_to_binary(labels)[None].astype(affs.dtype)
            assert binary.ndim == affs.ndim
            affs = np.concatenate([binary, affs], axis=0)

        if self.add_mask:
            if self.add_binary_target:
                if self.ignore_label is None:
                    mask_for_bin = np.ones((1,) + labels.shape, dtype=mask.dtype)
                else:
                    mask_for_bin = (labels != self.ignore_label)[None].astype(mask.dtype)
                assert mask.ndim == mask_for_bin.ndim
                mask = np.concatenate([mask_for_bin, mask], axis=0)
            assert affs.shape == mask.shape
            affs = np.concatenate([affs, mask.astype(affs.dtype)], axis=0)

        return affs
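A self-contained sketch (added for illustration, not taken from the repository above) that mirrors the channel layout produced by this transform when both add_binary_target and add_mask are enabled: for N = len(offsets), the target stacks [binary, N affinity channels, binary mask, N affinity masks] along axis 0.

import numpy as np
from affogato.affinities import compute_affinities

labels = np.random.randint(0, 5, size=(32, 32)).astype("uint64")
offsets = [[-1, 0], [0, -1]]

affs, mask = compute_affinities(labels, offsets)
affs = 1. - affs                                    # "disaffinity" convention, as above
binary = (labels != 0)[None].astype(affs.dtype)     # foreground/background target
bin_mask = np.ones((1,) + labels.shape, dtype=mask.dtype)

target = np.concatenate([binary, affs,
                         bin_mask.astype(affs.dtype),
                         mask.astype(affs.dtype)], axis=0)
print(target.shape)  # (2 * (len(offsets) + 1), 32, 32)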
Example #4
 def add_ignore_transitions(self, affs, mask, labels):
     ignore_seg = (labels == self.ignore_label).astype(labels.dtype)
     ignore_transitions, invalid_mask = compute_affinities(ignore_seg, self.offsets)
     invalid_mask = np.logical_not(invalid_mask)
     # NOTE affinity convention returned by affogato: transitions are marked by 0
     ignore_transitions = ignore_transitions == 0
     ignore_transitions[invalid_mask] = 0
     affs[ignore_transitions] = 1
     mask[ignore_transitions] = 1
     return affs, mask
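A minimal check (an illustrative sketch, assuming only numpy and affogato are installed; the toy arrays are made up) of the affinity convention noted in the comment above: affogato marks transitions between labels with 0 and within-segment edges with 1.

import numpy as np
from affogato.affinities import compute_affinities

toy_labels = np.array([[1, 1, 2, 2],
                       [1, 1, 2, 2]], dtype='uint64')
toy_offsets = [[0, -1]]  # affinity to the left neighbor

toy_affs, toy_valid = compute_affinities(toy_labels, toy_offsets)
# toy_affs is 1 inside each segment and 0 where labels 1 and 2 meet;
# toy_valid is 0 in the first column, where the offset points outside the volume.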
Example #5
def from_foreground_mask_to_edge_mask(foreground_mask,
                                      offsets,
                                      mask_used_edges=None):
    _, valid_edges = compute_affinities(foreground_mask.astype('uint64'),
                                        offsets.tolist(), True, 0)

    if mask_used_edges is not None:
        return np.logical_and(valid_edges, mask_used_edges)
    else:
        return valid_edges.astype('bool')
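A hypothetical usage of the helper above (the variable names below are assumptions, not from the source); note that offsets is expected as a numpy array, since .tolist() is called on it.

import numpy as np

foreground = np.zeros((64, 64), dtype=bool)
foreground[16:48, 16:48] = True
offsets = np.array([[-1, 0], [0, -1]])

edge_mask = from_foreground_mask_to_edge_mask(foreground, offsets)
print(edge_mask.shape)  # (len(offsets), 64, 64)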
Example #6
    def tensor_function(self, tensor):
        # for 2d input we expect a singleton leading axis, which we drop
        if self.dim == 2:
            assert tensor.shape[0] == 1
            tensor = tensor[0]

        outputs = []
        for ii, bs in enumerate(self.block_shapes):
            # if the block shape is all ones, we can compute normal affinities
            # with nearest neighbor offsets. This should yield the same result,
            # but should be more efficient.
            original_scale = all(s == 1 for s in bs)
            if original_scale:
                if self.original_scale_offsets is None:
                    offsets = [[0 if i != d else -1 for i in range(self.dim)]
                               for d in range(self.dim)]
                else:
                    offsets = self.original_scale_offsets
                output, mask = compute_affinities(
                    tensor.squeeze().astype('uint64'),
                    offsets,
                    ignore_label=0
                    if self.ignore_label is None else self.ignore_label,
                    have_ignore_label=False
                    if self.ignore_label is None else True)
            else:
                output, mask = compute_multiscale_affinities(
                    tensor.squeeze().astype('uint64'),
                    bs,
                    ignore_label=0
                    if self.ignore_label is None else self.ignore_label,
                    have_ignore_label=False
                    if self.ignore_label is None else True)

            # Cast to be sure
            if not output.dtype == self.dtype:
                output = output.astype(self.dtype)

            # We might want to carry the mask along.
            # If this is the case, we insert it after the targets.
            if self.retain_mask:
                output = np.concatenate(
                    (output, mask.astype(self.dtype, copy=False)), axis=0)
            # We might want to carry the segmentation along for validation.
            # If this is the case, we insert it before the targets for the original scale.
            if self.retain_segmentation:
                ds_target = self.downsamplers[ii](tensor.astype(self.dtype,
                                                                copy=False))
                if ds_target.ndim != output.ndim:
                    assert ds_target.ndim == output.ndim - 1
                    ds_target = ds_target[None]
                output = np.concatenate((ds_target, output), axis=0)
            outputs.append(output)

        return outputs
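As a side note (a small sketch, not from the original code): the nearest-neighbor offsets built in the original_scale branch above expand to one -1 entry per spatial axis.

dim = 3
offsets = [[0 if i != d else -1 for i in range(dim)] for d in range(dim)]
print(offsets)  # [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]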
Example #7
def get_sp_graph(data, gt, scal=1.01):
    offsets = [[0, -1], [-1, 0], [-3, 0], [0, -3]]
    sep_chnl = 2
    shape = (128, 128)

    affinities = affutils.get_naive_affinities(data, offsets)
    gt_affinities, _ = compute_affinities(gt == 1, offsets)
    gt_affinities[sep_chnl:] *= -1
    gt_affinities[sep_chnl:] += +1
    affinities[sep_chnl:] *= -1
    affinities[sep_chnl:] += +1
    affinities[sep_chnl:] *= scal
    affinities = (affinities - (affinities * gt_affinities)) + gt_affinities

    affinities = affinities.clip(0, 1)

    valid_edges = get_valid_edges((len(offsets), ) + shape, offsets, sep_chnl,
                                  None, False)
    node_labeling, neighbors, cutting_edges, mutexes = compute_mws_segmentation_cstm(
        affinities.ravel(), valid_edges.ravel(), offsets, sep_chnl, shape)
    node_labeling = node_labeling - 1

    nodes = np.unique(node_labeling)
    try:
        assert all(nodes == np.arange(len(nodes)))
    except AssertionError:
        print("Warning: node ids are off")

    noisy_affinities = np.random.rand(*affinities.shape)
    noisy_affinities = noisy_affinities.clip(0, 1)
    noisy_affinities = affinities

    edge_feat, neighbors = get_edge_features_1d(node_labeling, offsets,
                                                noisy_affinities)
    gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                              node_labeling.squeeze(),
                                              gt.squeeze())

    edges = neighbors.astype(np.long)
    noisy_affinities = noisy_affinities.astype(np.float32)
    edge_feat = edge_feat.astype(np.float32)
    nodes = nodes.astype(np.float32)
    node_labeling = node_labeling.astype(np.float32)
    gt_edge_weights = gt_edge_weights.astype(np.float32)
    diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()

    edges = np.sort(edges, axis=-1)
    edges = edges.T
    # edges = np.concatenate((edges, np.stack((edges[1], edges[0]))), axis=1)

    # return node_labeling
    # print('imbalance: ', abs(gt_edge_weights.sum() - (len(gt_edge_weights) / 2)))

    return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, nodes, noisy_affinities
Example #8
    def test_affs_2d(self):
        from affogato.affinities import compute_affinities
        shape = (100, 100)
        labels = np.random.randint(0, 100, size=shape)
        offsets = [[-1, 0], [0, -1], [-5, 0], [0, -5], [10, 10], [3, 9]]

        affs, mask = compute_affinities(labels, offsets)
        expected_shape = (len(offsets), ) + labels.shape
        self.assertEqual(affs.shape, expected_shape)
        self.assertEqual(mask.shape, expected_shape)
        self.assertNotEqual(np.sum(affs == 0), 0)
        self.assertNotEqual(np.sum(mask == 0), 0)
Example #9
def computeAffs(file_from, offsets):
    file = h5py.File(file_from, 'a')
    keys = list(file.keys())
    file.create_group('masks')
    file.create_group('affs')
    for k in keys:
        data = file[k][:].copy()
        affinities, _ = compute_affinities(data != 0, offsets)
        file['affs'].create_dataset(k, data=affinities)
        file['masks'].create_dataset(k, data=data)
        del file[k]
    return
Example #10
 def test_malis_2d(self):
     from affogato.affinities import compute_affinities
     from affogato.learning import compute_malis_2d
     shape = (100, 100)
     labels = np.random.randint(0, 100, size=shape)
     offsets = [[-1, 0], [0, -1]]
     affs, _ = compute_affinities(labels, offsets)
     affs += 0.1 * np.random.randn(*affs.shape)
     loss, grads = compute_malis_2d(affs, labels, offsets)
     self.assertEqual(grads.shape, affs.shape)
     self.assertNotEqual(loss, 0)
     self.assertFalse(np.allclose(grads, 0))
Example #11
 def __getitem__(self, idx):
     img = np.zeros((20, 20))
     affinities = np.ones((len(offsets), 20, 20))
     gt_affinities = np.ones((len(offsets), 20, 20))
     for y in range(len(img)):
         for x in range(len(img[0])):
             if y < 10 and x < 10:
                 img[y, x] = 1
             if y >= 10 and x < 10:
                 img[y, x] = 2
             if y < 10 and x >= 10:
                 img[y, x] = 3
             if y >= 10 and x >= 10:
                 img[y, x] = 4
             if 7 < y < 13 and 7 < x < 13:
                 img[y, x] = 5
     for i in np.unique(img):
         affs, _ = compute_affinities(img == i, offsets)
         gt_affinities *= affs
         # gt_affinities = (gt_affinities == 0).astype(np.float)
     for y in range(len(img)):
         for x in range(len(img[0])):
             if 10 < y < 12 and 10 < x < 12:
                 img[y, x] += 1
             if 16 < y < 18 and 6 < x < 8:
                 img[y, x] += 1
             if 6 < y < 10 and 9 < x < 11:
                 img[y, x] += 1
     for i in np.unique(img):
         affs, _ = compute_affinities(img == i, offsets)
         affinities *= affs
         # affinities = (affinities != 0).astype(np.float)
     affinities = (affinities == 1).astype(np.float)
     gt_affinities = (gt_affinities == 1).astype(np.float)
     affinities[sep_chnl:] *= -1
     affinities[sep_chnl:] += +1
     gt_affinities[sep_chnl:] *= -1
     gt_affinities[sep_chnl:] += +1
     return torch.tensor(img).unsqueeze(0).float(), torch.tensor(
         affinities).float(), torch.tensor(gt_affinities).float()
Example #12
 def __getitem__(self, idx):
     simple_img = [[1, 1, 1, 1], [1, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 0]]
     simple_affs = [[[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0,
                                                                0]],
                    [[0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 0,
                                                                1]],
                    [[1, 1, 1, 1], [1, 0, 0, 1], [1, 1, 0, 1], [1, 1, 1,
                                                                1]],
                    [[1, 1, 0, 1], [1, 0, 1, 1], [1, 1, 1, 0], [1, 1, 1,
                                                                0]]]
     simple_affs, _ = compute_affinities(np.array(simple_img) == 0, offsets)
     return torch.tensor(simple_img).unsqueeze(0).float(), torch.tensor(
         simple_affs).unsqueeze(0).float()
Example #13
    def get_multicut_energy_segmentation(self,
                                         pixel_segm,
                                         affinities,
                                         offsets,
                                         edge_mask=None):
        if edge_mask is None:
            edge_mask = np.ones_like(affinities, dtype='bool')

        log_affinities = compute_edge_costs(1 - affinities,
                                            beta=self.beta_bias)

        # Find affinities "on cut":
        affs_not_on_cut, _ = compute_affinities(pixel_segm.astype('uint64'),
                                                offsets.tolist(), False, 0)
        return log_affinities[np.logical_and(affs_not_on_cut == 0,
                                             edge_mask)].sum()
Example #14
 def test_mutex_malis_3d(self):
     from affogato.affinities import compute_affinities
     from affogato.learning import mutex_malis
     shape = (32, 64, 64)
     labels = np.random.randint(0, 1000, size=shape)
     offsets = [[-1, 0, 0], [0, -1, 0], [0, 0, -1],
                [-3, 0, 0], [0, -3, 0], [0, 0, -3]]
     affs, _ = compute_affinities(labels, offsets)
     affs += 0.1 * np.random.randn(*affs.shape)
     affs -= affs.min()
     affs /= affs.max()
     loss, grads, _, _ = mutex_malis(affs, labels, offsets, 3)
     self.assertEqual(grads.shape, affs.shape)
     # FIXME this fails
     self.assertNotEqual(loss, 0)
     self.assertFalse(np.allclose(grads, 0))
Example #15
def _insert_affinities_block(block_id, blocking, ds, objects, offsets):
    fu.log("start processing block %i" % block_id)
    halo = np.max(np.abs(offsets), axis=0)

    block = blocking.getBlockWithHalo(block_id, halo.tolist())
    outer_bb = vu.block_to_bb(block.outerBlock)
    inner_bb = (slice(None), ) + vu.block_to_bb(block.innerBlock)
    local_bb = (slice(None), ) + vu.block_to_bb(block.innerBlockLocal)

    # load objects and check if we have any in this block
    objs = objects[outer_bb]
    if objs.sum() == 0:
        fu.log_block_success(block_id)
        return

    affs, _ = compute_affinities(objs, offsets)
    affs = cast(1. - affs, ds.dtype)
    ds[inner_bb] += affs[local_bb]
    fu.log_block_success(block_id)
Example #16
def _insert_affinities(affs, objs, offsets, dilate_by):
    dtype = affs.dtype
    # compute affinities to objs and bring them to our aff convention
    affs_insert, mask = compute_affinities(objs, offsets)
    mask = mask == 0
    affs_insert = 1. - affs_insert
    affs_insert[mask] = 0

    # dilate affinity channels
    for c in range(affs_insert.shape[0]):
        affs_insert[c] = dilate(affs_insert[c], iterations=dilate_by, dilate_2d=True)
    # dirty hack: z affinities look pretty weird, so we add the averaged xy affinities
    affs_insert[0] += np.mean(affs_insert[1:3], axis=0)

    # insert affinities
    affs = vu.normalize(affs)
    affs += affs_insert
    affs = np.clip(affs, 0., 1.)
    affs = cast(affs, dtype)
    return affs
Example #17
    def __getitem__(self, idx):
        radius = np.random.randint(max(self.shape) // 5, max(self.shape) // 3)
        mp = (np.random.randint(0 + radius, self.shape[0] - radius),
              np.random.randint(0 + radius, self.shape[1] - radius))
        # mp = self.mp
        data = np.zeros(shape=self.shape, dtype=np.float)
        gt = np.zeros(shape=self.shape, dtype=np.float)
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                ly, lx = y - mp[0], x - mp[1]
                if (ly**2 + lx**2)**.5 <= radius:
                    data[y, x] += np.sin(x * 10 * np.pi / self.shape[1])
                    data[y, x] += np.sin(
                        np.sqrt(x**2 + y**2) * 20 * np.pi / self.shape[1])
                    gt[y, x] = 1
                else:
                    data[y, x] += np.sin(y * 5 * np.pi / self.shape[1])
                    data[y, x] += np.sin(
                        np.sqrt(x**2 + (self.shape[1] - y)**2) * 10 * np.pi /
                        self.shape[1])
        # plt.imshow(data);plt.show()
        gt_affinities, _ = compute_affinities(gt == 1, offsets)

        affinities = gt_affinities
        raw = torch.tensor(data).unsqueeze(0).unsqueeze(0).float()
        if self.aff_pred is not None:
            gt_affinities[self.sep_chnl:] *= -1
            gt_affinities[self.sep_chnl:] += +1
            with torch.set_grad_enabled(False):
                affinities = self.aff_pred(raw.to(self.aff_pred.device))
                affinities = affinities.squeeze().detach().cpu().numpy()
                affinities[self.sep_chnl:] *= -1
                affinities[self.sep_chnl:] += +1
                affinities[:self.sep_chnl] /= 1.5

        return raw.squeeze(0), affinities, gt_affinities
Example #18
    def get(self, idx):
        radius = np.random.randint(max(self.shape) // 5, max(self.shape) // 3)
        mp = (np.random.randint(0 + radius, self.shape[0] - radius),
              np.random.randint(0 + radius, self.shape[1] - radius))
        # mp = self.mp
        data = np.zeros(shape=self.shape, dtype=np.float)
        gt = np.zeros(shape=self.shape, dtype=np.float)
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                ly, lx = y - mp[0], x - mp[1]
                if (ly**2 + lx**2)**.5 <= radius:
                    data[y, x] += np.sin(x * 10 * np.pi / self.shape[1])
                    data[y, x] += np.sin(
                        np.sqrt(x**2 + y**2) * 20 * np.pi / self.shape[1])
                    # data[y, x] += 4
                    gt[y, x] = 1
                else:
                    data[y, x] += np.sin(y * 10 * np.pi / self.shape[1])
                    data[y, x] += np.sin(
                        np.sqrt(x**2 + (self.shape[1] - y)**2) * 10 * np.pi /
                        self.shape[1])
        data += 1
        # plt.imshow(data);plt.show()
        gt_affinities, _ = compute_affinities(gt == 1, offsets)

        seg_arbitrary = np.zeros_like(data)
        square_dict = {}
        i = 0
        granularity = 30
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                if (x // granularity, y // granularity) not in square_dict:
                    square_dict[(x // granularity, y // granularity)] = i
                    i += 1
                seg_arbitrary[y, x] += square_dict[(x // granularity,
                                                    y // granularity)]
        seg_arbitrary += gt * 1000
        i = 0
        segs = np.unique(seg_arbitrary)
        seg_arb = np.zeros_like(seg_arbitrary)
        for seg in segs:
            seg_arb += (seg_arbitrary == seg) * i
            i += 1
        seg_arbitrary = seg_arb
        rag = feats.compute_rag(np.expand_dims(seg_arbitrary, axis=0))
        neighbors = rag.uvIds()

        affinities = get_naive_affinities(data, offsets)
        # edge_feat = get_edge_features_1d(seg_arbitrary, offsets, affinities)
        # self.edge_offsets = [[1, 0], [0, 1], [1, 0], [0, 1]]
        # self.sep_chnl = 2
        # affinities = np.stack((ndimage.sobel(data, axis=0), ndimage.sobel(data, axis=1)))
        # affinities = np.concatenate((affinities, affinities), axis=0)
        affinities[:self.sep_chnl] *= -1
        affinities[:self.sep_chnl] += +1
        affinities[self.sep_chnl:] /= 0.2
        #
        raw = torch.tensor(data).unsqueeze(0).unsqueeze(0).float()
        # if self.aff_pred is not None:
        #     gt_affinities[self.sep_chnl:] *= -1
        #     gt_affinities[self.sep_chnl:] += +1
        #     gt_affinities[:self.sep_chnl] /= 1.5
        # with torch.set_grad_enabled(False):
        #     affinities = self.aff_pred(raw.to(self.aff_pred.device))
        #     affinities = affinities.squeeze().detach().cpu().numpy()
        #     affinities[self.sep_chnl:] *= -1
        #     affinities[self.sep_chnl:] += +1
        #     affinities[:self.sep_chnl] /= 1.2

        valid_edges = get_valid_edges((len(self.edge_offsets), ) + self.shape,
                                      self.edge_offsets, self.sep_chnl, None,
                                      False)
        node_labeling, neighbors, cutting_edges, mutexes = compute_mws_segmentation_cstm(
            affinities.ravel(), valid_edges.ravel(), offsets, self.sep_chnl,
            self.shape)
        node_labeling = node_labeling - 1
        node_labeling = seg_arbitrary
        # plt.imshow(cm.prism(node_labeling/node_labeling.max()));plt.show()
        # plt.imshow(data);plt.show()
        neighbors = (node_labeling.ravel())[neighbors]
        nodes = np.unique(node_labeling)
        edge_feat = get_edge_features_1d(node_labeling, offsets, affinities)

        # for i, node in enumerate(nodes):
        #     seg = node_labeling == node
        #     masked_data = seg * data
        #     idxs = np.where(seg)
        #     dxs1 = np.stack(idxs).transpose()
        #     # y, x = bbox(np.expand_dims(seg, 0))
        #     # y, x = y[0], x[0]
        #     mass = np.sum(seg)
        #     # _, s, _ = np.linalg.svd(StandardScaler().fit_transform(seg))
        #     mean = np.sum(masked_data) / mass
        #     cm = np.sum(dxs1, axis=0) / mass
        #     var = np.var(data[idxs[0], idxs[1]])
        #
        #     mean = 0 if mean < .5 else 1
        #
        #     node_features[node] = torch.tensor([mean])

        offsets_3d = [[0, 0, -1], [0, -1, 0], [0, -3, 0], [0, 0, -3]]

        # rag = feats.compute_rag(np.expand_dims(node_labeling, axis=0))
        # edge_feat = feats.compute_affinity_features(rag, np.expand_dims(affinities, axis=1), offsets_3d)[:, :]
        # gt_edge_weights = feats.compute_affinity_features(rag, np.expand_dims(gt_affinities, axis=1), offsets_3d)[:, 0]
        gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                  node_labeling.squeeze(),
                                                  gt.squeeze())
        # gt_edge_weights = utils.calculate_naive_gt_edge_costs(neighbors, node_features).unsqueeze(-1)
        # affs = np.expand_dims(affinities, axis=1)
        # boundary_input = np.mean(affs, axis=0)
        # plt.imshow(multicut_from_probas(node_labeling, neighbors, gt_edge_weights, boundary_input));plt.show()

        # neighs = np.empty((10, 2))
        # gt_neighs = np.empty(10)
        # neighs[0] = neighbors[30]
        # gt_neighs[0] = gt_edge_weights[30]
        # i = 0
        # while True:
        #     for idx, n in enumerate(neighbors):
        #         if n[0] in neighs.ravel() or n[1] in neighs.ravel():
        #             neighs[i] = n
        #             gt_neighs[i] = gt_edge_weights[idx]
        #             i += 1
        #             if i == 10:
        #                 break
        #     if i == 10:
        #         break
        #
        # nodes = nodes[np.unique(neighs.ravel())]
        # node_features = nodes
        # neighbors = neighs

        edges = torch.from_numpy(neighbors.astype(np.long))
        raw = raw.squeeze()
        edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
        nodes = torch.from_numpy(nodes.astype(np.float32))
        # gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.float32))
        # affinities = torch.from_numpy(affinities.astype(np.float32))
        affinities = torch.from_numpy(gt_affinities.astype(np.float32))
        gt_affinities = torch.from_numpy(gt_affinities.astype(np.float32))
        node_labeling = torch.from_numpy(node_labeling.astype(np.float32))

        gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.float32))
        # noise = torch.randn_like(edge_feat) / 3
        # edge_feat += noise
        # edge_feat = torch.min(edge_feat, torch.ones_like(edge_feat))
        # edge_feat = torch.max(edge_feat, torch.zeros_like(edge_feat))
        diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum()

        node_features, angles = get_stacked_node_data(nodes,
                                                      edges,
                                                      node_labeling,
                                                      raw,
                                                      size=[32, 32])
        # plt.imshow(node_features.view(-1, 32));
        # plt.show()

        edges = edges.t().contiguous()
        edges = torch.cat((edges, torch.stack((edges[1], edges[0]))), dim=1)

        return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles
Example #19
    def input_function(self, tensor):
        labels = tensor[0]
        boundary_mask = None
        glia_mask = None
        extra_masks = None

        if tensor.shape[0] > 1:
            # Here we get both the segmentation and an additional mask:
            assert tensor.shape[
                0] == 2, "Only one additional mask is supported at the moment"
            extra_masks = tensor[1]

            if self.boundary_label is not None:
                boundary_mask = (extra_masks == self.boundary_label)
            if not self.train_affs_on_glia and self.glia_label is not None:
                glia_mask = (extra_masks == self.glia_label)

        output, mask = compute_affinities(labels.astype('int64'),
                                          self.offsets,
                                          ignore_label=self.ignore_label,
                                          boundary_mask=boundary_mask,
                                          glia_mask=glia_mask)

        if self.learn_ignore_transitions and self.ignore_label is not None:
            output, mask = self.include_ignore_transitions(
                output, mask, labels)

        # Cast to be sure
        if not output.dtype == self.dtype:
            output = output.astype(self.dtype)
        #
        # print("affs: shape before binary", output.shape)
        if self.segmentation_to_binary:
            output = np.concatenate(
                (self.to_binary_segmentation(labels)[None], output), axis=0)
        # print("affs: shape after binary", output.shape)

        # print("affs: shape before mask", output.shape)
        # We might want to carry the mask along.
        # If this is the case, we insert it after the targets.
        if self.retain_mask:
            mask = mask.astype(self.dtype, copy=False)
            if self.segmentation_to_binary:
                if self.ignore_label is None:
                    additional_mask = np.ones((1, ) + labels.shape,
                                              dtype=self.dtype)
                else:
                    additional_mask = (labels[None] !=
                                       self.ignore_label).astype(self.dtype)
                mask = np.concatenate([additional_mask, mask], axis=0)
            output = np.concatenate((output, mask), axis=0)
        # print("affs: shape after mask", output.shape)

        # We might want to carry the segmentation along for validation.
        # If this is the case, we insert it before the targets.
        if self.retain_segmentation:
            # Add a channel axis to labels to make it (C, Z, Y, X) before concatenating with the output
            if self.retain_extra_masks:
                assert extra_masks is not None, "Extra masks were not passed and cannot be concatenated"
                output = np.concatenate(
                    (labels[None].astype(self.dtype, copy=False),
                     extra_masks[None].astype(self.dtype, copy=False), output),
                    axis=0)
            else:
                output = np.concatenate(
                    (labels[None].astype(self.dtype, copy=False), output),
                    axis=0)

        if self.retain_glia_mask:
            assert self.glia_label is not None
            output = np.concatenate(
                (output,
                 np.expand_dims(
                     (extra_masks == self.glia_label).astype('float32'),
                     axis=0)),
                axis=0)

        # print("affs: out shape", output.shape)
        return output
Example #20
offsets = [[-1, 0, 0], [0, -1, 0], [0, 0, -1], [-2, 0, 0], [0, -3, 0],
           [0, 0, -3], [-3, 0, 0], [0, -9, 0], [0, 0, -9], [-4, 0, 0],
           [0, -27, 0], [0, 0, -27]]

# Fake duplicate affinities:
duplicate_affs = np.empty_like(affs)
for i, off in enumerate(offsets):
    duplicate_affs[i] = np.roll(affs[i], off, axis=(0, 1, 2))
affs = np.concatenate([affs, duplicate_affs], axis=0)
offsets = offsets + [[-off[0], -off[1], -off[2]] for off in offsets]

from neurofire.transform.affinities import Segmentation2AffinitiesDynamicOffsets, affinity_config_to_transform

from affogato.affinities import compute_multiscale_affinities, compute_affinities

_, mask = compute_affinities(np.zeros_like(raw, dtype='int64'), offsets)

print("Total valid edges: ", mask.sum())

dynHC = DynamicHC(affs - 0.5, offsets)

# for n in range(dynHC.nb_nodes):
#     print(n, dynHC.from_label_to_coord(n))

dynHC.run_v3(np.array([0, 20, 20]))

print("Final number edges inserted: ", dynHC.graph.numberOfEdges)
print("Final number active nodes: ", dynHC.is_node_active.sum())

final_segm, active_nodes = dynHC.get_segmentation()
import segmfriends.vis as vis
Example #21
    def create_dsets(self, num):
        for file_index in range(num):
            n_disc = np.random.randint(25, 30)
            rads = []
            mps = []
            for disc in range(n_disc):
                radius = np.random.randint(
                    max(self.shape) // 25,
                    max(self.shape) // 20)
                touching = True
                while touching:
                    mp = np.array([
                        np.random.randint(0 + radius, self.shape[0] - radius),
                        np.random.randint(0 + radius, self.shape[1] - radius)
                    ])
                    touching = False
                    for other_rad, other_mp in zip(rads, mps):
                        diff = mp - other_mp
                        if (diff**2).sum()**.5 <= radius + other_rad + 2:
                            touching = True
                rads.append(radius)
                mps.append(mp)

            data = np.zeros(shape=self.shape, dtype=np.float)
            gt = np.zeros(shape=self.shape, dtype=np.float)
            for y in range(self.shape[0]):
                for x in range(self.shape[1]):
                    bg = True
                    for radius, mp in zip(rads, mps):
                        ly, lx = y - mp[0], x - mp[1]
                        if (ly**2 + lx**2)**.5 <= radius:
                            data[y, x] += np.cos(
                                np.sqrt((x - self.shape[1])**2 + y**2) * 50 *
                                np.pi / self.shape[1])
                            data[y, x] += np.cos(
                                np.sqrt(x**2 + y**2) * 50 * np.pi /
                                self.shape[1])
                            # data[y, x] += 6
                            gt[y, x] = 1
                            bg = False
                    if bg:
                        data[y, x] += np.cos(y * 40 * np.pi / self.shape[0])
                        data[y, x] += np.cos(
                            np.sqrt(x**2 + (self.shape[0] - y)**2) * 30 *
                            np.pi / self.shape[1])
            data += 1
            # plt.imshow(data);plt.show()
            if self.no_suppix:
                raw = torch.from_numpy(data).float()
                return raw.unsqueeze(0), torch.from_numpy(gt.astype(np.long))
                # return torch.stack((torch.rand_like(raw), raw, torch.rand_like(raw))), torch.from_numpy(gt.astype(np.long))

            affinities = affutils.get_naive_affinities(data, self.offsets)
            gt_affinities, _ = compute_affinities(gt == 1, self.offsets)
            gt_affinities[self.sep_chnl:] *= -1
            gt_affinities[self.sep_chnl:] += +1
            affinities[self.sep_chnl:] *= -1
            affinities[self.sep_chnl:] += +1
            # affinities[:self.sep_chnl] /= 1.1
            affinities[self.sep_chnl:] *= 1.01
            affinities = (affinities -
                          (affinities * gt_affinities)) + gt_affinities

            # affinities[self.sep_chnl:] *= -1
            # affinities[self.sep_chnl:] += +1
            # affinities[self.sep_chnl:] *= 4
            affinities = affinities.clip(0, 1)

            valid_edges = get_valid_edges((len(self.offsets), ) + self.shape,
                                          self.offsets, self.sep_chnl, None,
                                          False)
            node_labeling, neighbors, cutting_edges, mutexes = compute_mws_segmentation_cstm(
                affinities.ravel(), valid_edges.ravel(), self.offsets,
                self.sep_chnl, self.shape)
            node_labeling = node_labeling - 1
            nodes = np.unique(node_labeling)
            try:
                assert all(nodes == np.arange(len(nodes)))
            except AssertionError:
                print("Warning: node ids are off")

            noisy_affinities = affinities

            edge_feat, neighbors = get_edge_features_1d(
                node_labeling, self.offsets, noisy_affinities)
            gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                      node_labeling.squeeze(),
                                                      gt.squeeze())

            while abs(gt_edge_weights.sum() - (len(gt_edge_weights) / 2)) > 1:
                edge_idx = np.random.choice(np.arange(len(gt_edge_weights)),
                                            p=torch.softmax(torch.from_numpy(
                                                (gt_edge_weights == 0).astype(
                                                    np.float)),
                                                            dim=0).numpy())
                if gt_edge_weights[edge_idx] != 0.0:
                    continue

                # print(abs(gt_edge_weights.sum() - (len(gt_edge_weights) / 2)))
                edge = neighbors[edge_idx].astype(np.int)
                # merge superpixel
                diff = edge[0] - edge[1]

                mass = (node_labeling == edge[0]).sum()
                node_labeling = node_labeling - (node_labeling
                                                 == edge[0]) * diff
                new_mass = (node_labeling == edge[1]).sum()
                try:
                    assert new_mass >= mass
                except AssertionError:
                    pass  # the merged superpixel should not lose mass

                # if edge_idx == 0:
                #     neighbors = neighbors[1:]
                #     gt_edge_weights = gt_edge_weights[1:]
                # elif edge_idx == len(gt_edge_weights):
                #     neighbors = neighbors[:-1]
                #     gt_edge_weights = gt_edge_weights[:-1]
                # else:
                #     neighbors = np.concatenate((neighbors[:edge_idx], neighbors[edge_idx+1:]), axis=0)
                #     gt_edge_weights = np.concatenate((gt_edge_weights[:edge_idx], gt_edge_weights[edge_idx+1:]), axis=0)
                #
                # neighbors[neighbors == edge[0]] == edge[1]

                edge_feat, neighbors = get_edge_features_1d(
                    node_labeling, self.offsets, noisy_affinities)
                gt_edge_weights = calculate_gt_edge_costs(
                    neighbors, node_labeling.squeeze(), gt.squeeze())

            edge_feat, neighbors = get_edge_features_1d(
                node_labeling, self.offsets, noisy_affinities)
            gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                      node_labeling.squeeze(),
                                                      gt.squeeze())

            gt = torch.from_numpy(gt.astype(np.float32)).squeeze().float()

            edges = torch.from_numpy(neighbors.astype(np.long))
            raw = torch.tensor(data).squeeze().float()
            noisy_affinities = torch.tensor(noisy_affinities).squeeze().float()
            edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
            nodes = torch.from_numpy(nodes.astype(np.float32))
            node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
            gt_edge_weights = torch.from_numpy(
                gt_edge_weights.astype(np.float32))
            diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum()
            edges = edges.t().contiguous()
            edges = torch.cat((edges, torch.stack((edges[1], edges[0]))),
                              dim=1)

            self.write_to_h5(
                '/g/kreshuk/hilt/projects/fewShotLearning/mutexWtsd/data/storage/balanced_graphs/balanced_graph_data'
                + str(file_index) + '.h5', edges, edge_feat, diff_to_gt,
                gt_edge_weights, node_labeling, raw, nodes, noisy_affinities,
                gt)
Example #22
    def get(self, idx):
        n_disc = np.random.randint(8, 10)
        rads = []
        mps = []
        for disc in range(n_disc):
            radius = np.random.randint(
                max(self.shape) // 18,
                max(self.shape) // 15)
            touching = True
            while touching:
                mp = np.array([
                    np.random.randint(0 + radius, self.shape[0] - radius),
                    np.random.randint(0 + radius, self.shape[1] - radius)
                ])
                touching = False
                for other_rad, other_mp in zip(rads, mps):
                    diff = mp - other_mp
                    if (diff**2).sum()**.5 <= radius + other_rad + 2:
                        touching = True
            rads.append(radius)
            mps.append(mp)

        # take static image
        # rads = self.rads
        # mps = self.mps

        data = np.zeros(shape=self.shape, dtype=np.float)
        gt = np.zeros(shape=self.shape, dtype=np.float)
        for y in range(self.shape[0]):
            for x in range(self.shape[1]):
                bg = True
                for radius, mp in zip(rads, mps):
                    ly, lx = y - mp[0], x - mp[1]
                    if (ly**2 + lx**2)**.5 <= radius:
                        data[y, x] += np.cos(
                            np.sqrt((x - self.shape[1])**2 + y**2) * 50 *
                            np.pi / self.shape[1])
                        data[y, x] += np.cos(
                            np.sqrt(x**2 + y**2) * 50 * np.pi / self.shape[1])
                        # data[y, x] += 6
                        gt[y, x] = 1
                        bg = False
                if bg:
                    data[y, x] += np.cos(y * 40 * np.pi / self.shape[0])
                    data[y, x] += np.cos(
                        np.sqrt(x**2 + (self.shape[0] - y)**2) * 30 * np.pi /
                        self.shape[1])
        data += 1
        # plt.imshow(data);plt.show()
        # if self.no_suppix:
        #     raw = torch.from_numpy(data).float()
        #     return raw.unsqueeze(0), torch.from_numpy(gt.astype(np.long))
        # return torch.stack((torch.rand_like(raw), raw, torch.rand_like(raw))), torch.from_numpy(gt.astype(np.long))

        affinities = affutils.get_naive_affinities(data, self.offsets)
        gt_affinities, _ = compute_affinities(gt == 1, self.offsets)
        gt_affinities[self.sep_chnl:] *= -1
        gt_affinities[self.sep_chnl:] += +1
        affinities[self.sep_chnl:] *= -1
        affinities[self.sep_chnl:] += +1
        # affinities[:self.sep_chnl] /= 1.1
        affinities[self.sep_chnl:] *= 1.01
        affinities = (affinities -
                      (affinities * gt_affinities)) + gt_affinities

        # affinities[self.sep_chnl:] *= -1
        # affinities[self.sep_chnl:] += +1
        # affinities[self.sep_chnl:] *= 4
        affinities = affinities.clip(0, 1)

        valid_edges = get_valid_edges((len(self.offsets), ) + self.shape,
                                      self.offsets, self.sep_chnl, None, False)
        node_labeling, neighbors, cutting_edges, mutexes = compute_mws_segmentation_cstm(
            affinities.ravel(), valid_edges.ravel(), self.offsets,
            self.sep_chnl, self.shape)
        node_labeling = node_labeling - 1
        # rag = elf.segmentation.features.compute_rag(np.expand_dims(node_labeling, axis=0))
        # neighbors = rag.uvIds()
        i = 0

        # node_labeling = gt * 5000 + node_labeling
        # segs = np.unique(node_labeling)
        #
        # new_labeling = np.zeros_like(node_labeling)
        # for seg in segs:
        #     i += 1
        #     new_labeling += (node_labeling == seg) * i
        #
        # node_labeling = new_labeling - 1

        # gt_labeling, _, _, _ = compute_mws_segmentation_cstm(gt_affinities.ravel(),
        #                                                      valid_edges.ravel(),
        #                                                      offsets,
        #                                                      self.shape)
        #                                                      self.sep_chnl,

        nodes = np.unique(node_labeling)
        try:
            assert all(nodes == np.arange(len(nodes)))
        except AssertionError:
            print("Warning: node ids are off")

        noisy_affinities = np.random.rand(*affinities.shape)
        noisy_affinities = noisy_affinities.clip(0, 1)
        noisy_affinities = affinities

        edge_feat, neighbors = get_edge_features_1d(node_labeling,
                                                    self.offsets,
                                                    noisy_affinities)
        gt_edge_weights = calculate_gt_edge_costs(neighbors,
                                                  node_labeling.squeeze(),
                                                  gt.squeeze())

        if self.less:
            raw = torch.from_numpy(data).float()
            node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
            gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.long))
            edges = torch.from_numpy(neighbors.astype(np.long))
            edges = edges.t().contiguous()
            edges = torch.cat((edges, torch.stack((edges[1], edges[0]))),
                              dim=1)
            return raw.unsqueeze(0), node_labeling, torch.from_numpy(
                gt.astype(np.long)), gt_edge_weights, edges

        # affs = np.expand_dims(affinities, axis=1)
        # boundary_input = np.mean(affs, axis=0)
        # gt1 = gutils.multicut_from_probas(node_labeling.astype(np.float32), neighbors.astype(np.float32),
        #                                  gt_edge_weights.astype(np.float32), boundary_input.astype(np.float32))

        # plt.imshow(node_labeling)
        # plt.show()
        # plt.imshow(gt1)
        # plt.show()

        gt = torch.from_numpy(gt.astype(np.float32)).squeeze().float()

        edges = torch.from_numpy(neighbors.astype(np.long))
        raw = torch.tensor(data).squeeze().float()
        noisy_affinities = torch.tensor(noisy_affinities).squeeze().float()
        edge_feat = torch.from_numpy(edge_feat.astype(np.float32))
        nodes = torch.from_numpy(nodes.astype(np.float32))
        node_labeling = torch.from_numpy(node_labeling.astype(np.float32))
        gt_edge_weights = torch.from_numpy(gt_edge_weights.astype(np.float32))
        diff_to_gt = (edge_feat[:, 0] - gt_edge_weights).abs().sum().item()
        # node_features, angles = get_stacked_node_data(nodes, edges, node_labeling, raw, size=[32, 32])

        # file = h5py.File("/g/kreshuk/hilt/projects/rags/" + "rag_" + str(self.fidx) + ".h5", "w")
        # file.create_dataset("edges", data=edges.numpy())
        # self.fidx += 1

        if self.no_suppix:
            raw = torch.from_numpy(data).float()
            return raw.unsqueeze(0), torch.from_numpy(gt.numpy().astype(
                np.long))

        edges = edges.t().contiguous()
        edges = torch.cat((edges, torch.stack((edges[1], edges[0]))), dim=1)

        # print('imbalance: ', abs(gt_edge_weights.sum() - (len(gt_edge_weights) / 2)))

        return edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, noisy_affinities, gt
input_h5_file, output_h5_file = parse_args(sys.argv)

print(f"Computing affinities based on {input_h5_file}")

with h5py.File(input_h5_file, "r+") as input_h5:
    label_ids = numpy.copy(input_h5["/volumes/labels/merged_ids"])

    start = time.time()
    affinities, mask = compute_affinities(label_ids,
                                          offset=[
                                              [-1, 0, 0],
                                              [0, -1, 0],
                                              [0, 0, -1],

                                              [-7, 0, 0],
                                              [0, -7, 0],
                                              [0, 0, -7],

                                              [-15, 0, 0],
                                              [0, -15, 0],
                                              [0, 0, -15]
                                          ],
                                          have_ignore_label=True,
                                          ignore_label=0)

    end = time.time()

    print("Computing affinities took %.3f" % (end - start))

    with h5py.File(output_h5_file, "w") as output_h5:
        output_h5.create_dataset("affinities", data=affinities,
                                 compression="gzip")
        else:
            GT = f['segmentations/groundtruth_fixed'][:]
            # raw = f['raw'][:]

    from affogato.affinities import compute_affinities

    offsets = [
        [0, 1, 0],
        [0, 0, 1],
    ]
    print(GT.max())

    # affs: 0 boundary, 1 segment;
    # valid_mask: 1 is valid
    affs, affs_valid_mask = compute_affinities(GT.astype('int64'),
                                               offsets,
                                               ignore_label=0,
                                               have_ignore_label=True)

    # Where it is not valid, we should not predict a boundary label:
    affs[affs_valid_mask == 0] = 1

    # Combine left and right affinities:
    segment_mask = np.logical_and(affs[0], affs[1])

    # This erodes the binary mask (segments 1, boundary 0)
    eroded_segment_mask = segment_mask.copy()
    for z in range(eroded_segment_mask.shape[0]):
        eroded_segment_mask[z] = vigra.filters.multiBinaryErosion(
            segment_mask[z], radius=2.)
    boundary_mask = np.logical_not(eroded_segment_mask)

have_ignore_label: (boolean) whether there is a label that should be ignored while computing the affinities.
                    Defaults to False.

ignore_label: (int) value of the ignore label. Defaults to 0. Ignored if `have_ignore_label` is False.


##########
# Outputs:
##########

affinities: boolean numpy array of shape (len(offset), ) + labels.shape.

valid_mask: boolean numpy array with the same shape as `affinities`, indicating which computed affinities
            are valid and which are not (for example because they reach outside the segmentation boundaries
            or involve the ignore_label).
"""

import numpy as np
from affogato.affinities import compute_affinities

example_offsets = [[-1, 0, 0], [0, -1, 0], [0, 0, -1], [0, -4, 0], [0, 0, -4],
                   [0, -12, 0], [0, 0, -12]]

test_shape = (20, 20, 20)
example_segmentation = np.random.randint(0, 1000, size=test_shape)

affinities, valid_affinities = compute_affinities(
    example_segmentation.astype('uint64'), example_offsets, False, 0)
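
A short follow-up sketch continuing the example above; the loss-masking pattern is an assumption about typical usage, not part of the original script.

print(affinities.shape)        # (len(example_offsets),) + test_shape
print(valid_affinities.shape)  # same shape as the affinities

# Only valid entries should contribute to a training loss; invalid ones
# (e.g. offsets that reach outside the volume) can simply be zeroed out:
masked_affinities = np.where(valid_affinities, affinities, 0)
print("valid affinity entries:", int(valid_affinities.sum()))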