예제 #1
0
파일: model2d.py 프로젝트: snlee81/stardist
    def __getitem__(self, i):
        """Return batch ``i`` as ``[X, dist_mask], [prob, dist]``.

        One patch is sampled for every image index that ``self.perm`` assigns
        to this batch. Image patches are cropped with ``self.b`` before
        augmentation; label patches go to the augmenter uncropped.

        ``prob`` comes from ``edt_prob`` on the ``self.b``-cropped labels.
        With ``self.shape_completion``, labels first pass through
        ``clear_border``. ``star_dist`` then runs on the full patch and its
        result is cropped afterwards, and ``dist_mask`` is computed from the
        cleared labels. Otherwise ``prob`` itself is used as ``dist_mask``.
        ``dist_mask``, ``prob`` and ``dist`` are finally subsampled with
        ``self.ss_grid``.
        """
        start = i * self.batch_size
        batch_ids = list(self.perm[start:start + self.batch_size])

        # Draw a single (image, label) patch pair for every index in this batch.
        samples = [
            sample_patches_from_multiple_stacks(
                (self.X[k], self.Y[k]),
                patch_size=self.patch_size,
                n_samples=1,
                patch_filter=self.no_background_patches_cached(k),
            )
            for k in batch_ids
        ]
        X, Y = zip(*((img[0][self.b], lbl[0]) for img, lbl in samples))

        X, Y = self.augmenter(X, Y)

        prob = np.stack([edt_prob(lbl[self.b]) for lbl in Y])

        if not self.shape_completion:
            # NOTE(review): dist is computed on the uncropped labels while prob
            # uses lbl[self.b]; shapes only agree if self.b does not crop in this
            # mode (presumably set up that way by the constructor) - confirm.
            dist = np.stack([star_dist(lbl, self.n_rays, mode=self.sd_mode) for lbl in Y])
            dist_mask = prob
        else:
            cleared = [clear_border(lbl) for lbl in Y]
            # Crop the spatial axes with self.b, leave the trailing axis
            # (presumably the rays) whole.
            crop = self.b + (slice(None),)
            dist = np.stack([star_dist(lbl, self.n_rays, mode=self.sd_mode)[crop] for lbl in cleared])
            dist_mask = np.stack([edt_prob(lbl[self.b]) for lbl in cleared])

        X = np.stack(X)
        if X.ndim == 3:  # input images carry no channel axis: append one
            X = X[..., np.newaxis]
        prob = prob[..., np.newaxis]
        dist_mask = dist_mask[..., np.newaxis]

        # Subsample mask and targets with the given grid.
        grid = self.ss_grid
        dist_mask = dist_mask[grid]
        return [X, dist_mask], [prob[grid], dist[grid]]
예제 #2
0
    def __getitem__(self, i):
        """Return batch ``i`` as ``[X, dist_mask], [prob, dist]`` (3D variant).

        One patch is sampled for every index that ``self.perm`` assigns to this
        batch. Each patch is taken from the label stack together with the image
        channel(s) given by ``self.channels_as_tuple``. When ``self.n_channel``
        is set, the channel patches are stacked along a new trailing axis.

        After augmentation, the images, ``edt_prob`` maps and ``star_dist3D``
        distances are assembled into batch arrays, and the probability map also
        serves as ``dist_mask``. ``dist_mask``, ``prob`` and ``dist`` are
        subsampled with ``self.ss_grid``.

        With more than one sample, the arrays are stacked into the preallocated
        buffers ``self.out_X``, ``self.out_edt_prob`` and
        ``self.out_star_dist3D``. A single sample only gets a leading batch
        axis, and those buffers are not accessed.
        """
        start = i * self.batch_size
        batch_ids = list(self.perm[start:start + self.batch_size])

        samples = [
            sample_patches_from_multiple_stacks(
                (self.Y[k],) + self.channels_as_tuple(self.X[k]),
                patch_size=self.patch_size,
                n_samples=1,
                patch_filter=self.no_background_patches_cached(k),
            )
            for k in batch_ids
        ]
        if self.n_channel is None:
            pairs = [(img[0], lbl[0]) for lbl, img in samples]
        else:
            # Merge the per-channel patches along a new trailing channel axis.
            pairs = [
                (np.stack([ch[0] for ch in channels], axis=-1), lbl[0])
                for lbl, *channels in samples
            ]
        X, Y = zip(*pairs)

        X, Y = self.augmenter(X, Y)
        n = len(Y)

        def as_batch(items, buffer_name):
            # Batch of one: just prepend an axis; the buffer is not accessed.
            # Otherwise stack into the first ``n`` rows of the reusable buffer.
            if n == 1:
                return items[0][np.newaxis]
            return np.stack(items, out=getattr(self, buffer_name)[:n])

        # NOTE(review): for n > 1 the returned arrays are views into the reusable
        # out_* buffers and are overwritten by the next call; confirm no consumer
        # holds a batch across calls (e.g. concurrent data-loading workers).
        X = as_batch(X, 'out_X')
        if X.ndim == 4:  # input volumes carry no channel axis: append one
            X = X[..., np.newaxis]

        prob = as_batch([edt_prob(lbl, anisotropy=self.anisotropy) for lbl in Y], 'out_edt_prob')
        dist = as_batch([star_dist3D(lbl, self.rays, mode=self.sd_mode) for lbl in Y], 'out_star_dist3D')
        prob = prob[..., np.newaxis]

        # Subsample with the given grid; the probability map doubles as dist_mask.
        grid = self.ss_grid
        dist_mask = prob[grid]
        return [X, dist_mask], [prob[grid], dist[grid]]