Exemplo n.º 1
0
def get_voxel_grid_real_space(images, append_ones=False):
    """Return the voxel grid of ``images`` mapped into centered real space.

    Every voxel index of the first three axes is multiplied by the upper-left
    3x3 block of ``images.affine`` (the translation column is not used).
    The resulting point cloud is then centered by subtracting its mean.

    Args:
        images: Image object exposing ``shape`` (the last axis is treated as
            channels) and ``affine`` (presumably a 4x4 voxel-to-world
            matrix; TODO confirm).
        append_ones: If True, append a column of ones to the *centered*
            points before reshaping them back into grid form.

    Returns:
        The centered real-space coordinates in mgrid layout, as produced by
        ``points_to_mgrid``.
    """
    # Spatial shape only; the last axis holds channels.
    shape = images.shape[:-1]

    # Linear part of the affine. Translation is irrelevant because the
    # points are re-centered below anyway.
    vox_to_real_affine = images.affine[:-1, :-1]

    # Integer voxel coordinates along the three spatial axes.
    grid_vox_space = np.mgrid[0:shape[0]:1, 0:shape[1]:1, 0:shape[2]:1]

    # Move grid to real space.
    grid_points_real_space = vox_to_real_affine.dot(
        mgrid_to_points(grid_vox_space).T).T

    # Center the point cloud on the origin.
    centered = grid_points_real_space - np.mean(grid_points_real_space, axis=0)

    if append_ones:
        # Build homogeneous coordinates from the centered points. The
        # previous version stacked the ones onto the un-centered points,
        # which silently threw the centering away.
        centered = np.column_stack((centered, np.ones(len(centered))))

    # Return real space grid as mgrid.
    return points_to_mgrid(centered, shape)
Exemplo n.º 2
0
    def get_base_patches_from(self, image, return_y=False):
        """Yield interpolated patches on a regular grid covering ``image``.

        Box start positions are spread evenly per axis so that the boxes
        (side ``self.real_box_dim``) span ``image.real_shape``, centered on
        the origin. Each box is sampled with ``sample_box_at`` in test mode
        without rotation noise.

        Args:
            image: Image object exposing ``real_shape``; it is passed to
                ``self._intrp_and_norm`` for interpolation.
            return_y: If True, also yield the interpolated labels.

        Yields:
            ``(im, lab, grid, axes, inv_mat, n_placements)`` if ``return_y``,
            otherwise ``(im, grid, axes, inv_mat, n_placements)``. Here
            ``n_placements`` is the total number of patches this generator
            produces.
        """
        real_dims = image.real_shape

        # Per-axis extent to cover; never smaller than one box.
        sample_space = np.asarray([max(i, self.real_box_dim) for i in real_dims])
        d = sample_space - self.real_box_dim
        # Minimum number of boxes per axis needed to cover the extent.
        # Builtin int is used because ``np.int`` was removed in NumPy 1.24.
        min_cov = [int(np.ceil(sample_space[i] / self.real_box_dim))
                   for i in range(3)]
        # Box start positions, shifted so the covered region is centered on
        # the origin.
        ds = [np.linspace(0, d[i], min_cov[i]) - sample_space[i] / 2
              for i in range(3)]

        # Get placement coordinate points.
        placements = mgrid_to_points(np.meshgrid(*ds))
        n_placements = len(placements)

        for p in placements:
            grid, axes, inv_mat = sample_box_at(real_placement=p,
                                                sample_dim=self.sample_dim,
                                                real_box_dim=self.real_box_dim,
                                                noise_sd=0.0,
                                                test_mode=True)

            im, lab = self._intrp_and_norm(image, grid, return_y)

            if return_y:
                yield im, lab, grid, axes, inv_mat, n_placements
            else:
                yield im, grid, axes, inv_mat, n_placements
Exemplo n.º 3
0
 def apply_rotation(self, mgrid):
     """Rotate the coordinates held in ``mgrid`` by ``self.rot_mat``.

     The grid is flattened to points with ``mgrid_to_points``,
     left-multiplied by the rotation matrix, and reshaped back with
     ``points_to_mgrid`` using the shape of ``mgrid[0]``. When
     ``self.rot_mat`` is None, the input object is returned unchanged.
     """
     if self.rot_mat is None:
         return mgrid

     grid_shape = mgrid[0].shape
     flat_points = mgrid_to_points(mgrid)
     rotated_points = self.rot_mat.dot(flat_points.T).T
     return points_to_mgrid(rotated_points, grid_shape)
Exemplo n.º 4
0
def get_voxel_grid(images, as_points=False):
    """Build the integer voxel index grid for the first three axes of ``images``.

    Args:
        images: Any object with a ``shape`` of at least three dimensions.
        as_points: If True, return the grid flattened with
            ``mgrid_to_points``; otherwise return the ``np.mgrid`` array.

    Returns:
        The voxel index grid, either in mgrid layout or as points.
    """
    dims = images.shape[:3]
    index_grid = np.mgrid[0:dims[0], 0:dims[1], 0:dims[2]]
    if not as_points:
        return index_grid
    return mgrid_to_points(index_grid)
Exemplo n.º 5
0
def sample_plane_at(norm_vector,
                    sample_dim,
                    real_space_span,
                    offset_from_center,
                    noise_sd,
                    test_mode=False):
    """Sample a regular 2D grid on a plane with a (noisy) normal vector.

    Args:
        norm_vector: 3-vector normal of the plane; normalized internally.
        sample_dim: Number of samples along each in-plane axis.
        real_space_span: Side length of the sampled square. In-plane
            coordinates run from ``-(real_space_span // 2)`` to
            ``+(real_space_span // 2)``.
        offset_from_center: Position of the plane along the normal axis.
        noise_sd: Either an ndarray added to the normal as-is, or a scalar
            standard deviation from which 3 Gaussian noise components are
            drawn.
        test_mode: If True, also return the in-plane axis values and the
            inverse of the basis matrix.

    Returns:
        ``real_grid`` (mgrid layout, from ``points_to_mgrid``). In test mode,
        returns ``(real_grid, g, inv_basis)``, where ``g`` holds the in-plane
        axis values.
    """
    # Prepare normal vector to the plane.
    n_hat = np.array(norm_vector, np.float32)
    n_hat /= np.linalg.norm(n_hat)

    # Add noise: an ndarray is used as-is, otherwise it is a scalar SD.
    if not isinstance(noise_sd, np.ndarray):
        noise_sd = np.random.normal(scale=noise_sd, size=3)

    n_hat += noise_sd
    n_hat /= np.linalg.norm(n_hat)

    if np.all(np.abs(n_hat[:-1]) < 0.2):
        # Vector pointing primarily up (small |x| and |y|), so noise will
        # have a large effect on image orientation. We force the first two
        # components into the positive direction to control sampling
        # variability. Comparing the signed components (as before) also
        # matched strongly negative x/y, which were then flipped, so a
        # differently oriented plane was sampled.
        n_hat[:-1] = np.abs(n_hat[:-1])
    if np.all(np.isclose(n_hat[:-1], 0)):
        # Normal is (numerically) the z-axis: use the x/y axes directly.
        u = np.array([1, 0, 0])
        v = np.array([0, 1, 0])
    else:
        # Find a vector in the same vertical plane as n_hat.
        nhat_vs = n_hat.copy()
        nhat_vs[-1] = nhat_vs[-1] + 1
        nhat_vs /= np.linalg.norm(nhat_vs)

        # Get two orthogonal vectors in the plane, u pointing down in the
        # z-direction. The angle -90 is interpreted by get_rotation_matrix
        # (presumably degrees; TODO confirm).
        u = get_rotation_matrix(np.cross(n_hat, nhat_vs), -90).dot(n_hat)
        v = np.cross(n_hat, u)

    # Basis matrix mapping plane-local (u, v, n) coordinates to real space.
    basis = np.column_stack((u, v, n_hat))

    # Define regular grid (centered at origin).
    hd = real_space_span // 2
    g = np.linspace(-hd, hd, sample_dim)

    # A complex step makes np.mgrid produce `sample_dim` samples inclusive.
    j = complex(sample_dim)
    grid = np.mgrid[-hd:hd:j, -hd:hd:j,
                    offset_from_center:offset_from_center:1j]

    # Calculate real-space coordinates of the grid points.
    points = mgrid_to_points(grid)

    real_points = basis.dot(points.T).T
    real_grid = points_to_mgrid(real_points, grid.shape[1:])

    if test_mode:
        return real_grid, g, np.linalg.inv(basis)
    else:
        return real_grid
Exemplo n.º 6
0
def map_real_space_pred(pred,
                        grid,
                        inv_basis,
                        voxel_grid_real_space,
                        method="nearest",
                        max_workers=7):
    """Resample a prediction volume onto a real-space voxel grid.

    Every real-space point is mapped through ``inv_basis`` into the
    coordinate system of ``grid``. ``pred`` is then interpolated there with
    a ``RegularGridInterpolator``. Work is split into slices along the first
    spatial axis and processed in a thread pool.

    Args:
        pred: Prediction array; its last axis holds per-class values and its
            leading axes are sampled on ``grid``.
        grid: Axis coordinates for ``pred`` (the ``points`` argument of
            ``RegularGridInterpolator``).
        inv_basis: Matrix applied to each real-space point before lookup.
        voxel_grid_real_space: Target real-space coordinates in mgrid layout
            (indexed ``[0]``, ``[1]``, ``[2]`` per coordinate).
        method: Interpolation method forwarded to the interpolator.
        max_workers: Number of threads used for interpolation.

    Returns:
        Array of shape ``voxel_grid_real_space[0].shape + (pred.shape[-1],)``
        with ``pred.dtype``. Points outside ``grid`` receive the fill vector
        ``[1, 0, ..., 0]``.
    """
    from concurrent.futures import ThreadPoolExecutor

    print("Mapping to real coordinate space...")

    # Out-of-bounds points get value 1.0 for class 0 (background) and
    # 0.0 for every other class.
    fill = np.zeros(shape=pred.shape[-1], dtype=np.float32)
    fill[0] = 1.0

    intrp = RegularGridInterpolator(grid,
                                    pred,
                                    fill_value=fill,
                                    bounds_error=False,
                                    method=method)

    # Express every target point in the coordinate frame of `grid`.
    points = inv_basis.dot(mgrid_to_points(voxel_grid_real_space).T).T
    transformed_grid = points_to_mgrid(points, voxel_grid_real_space[0].shape)

    # Prepare mapped pred volume.
    mapped = np.empty(transformed_grid[0].shape + (pred.shape[-1], ),
                      dtype=pred.dtype)

    def _do(xs, ys, zs, index):
        # Interpolate one slice along the first spatial axis.
        return intrp((xs, ys, zs)), index

    n_slices = transformed_grid.shape[1]

    # The context manager shuts the pool down even if interpolation raises.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = pool.map(_do, transformed_grid[0], transformed_grid[1],
                           transformed_grid[2], range(n_slices))
        for i, (values, ind) in enumerate(results, start=1):
            print("  %i/%i" % (i, n_slices), end="\r", flush=True)
            # Map the interpolation results into the volume.
            mapped[ind] = values

    print("")
    return mapped
Exemplo n.º 7
0
    def get_base_patches(self, image):
        """Yield ``(patch, corner)`` pairs tiling ``image`` with cubes of side ``self.dim``.

        Corner positions are spaced evenly (as integers) per axis so that the
        cubes cover the first three axes of ``image.shape``. Each raw patch
        is passed through ``image.scaler.transform`` before being yielded.

        Args:
            image: Image object exposing ``image`` (the voxel array),
                ``shape`` and ``scaler``.

        Yields:
            ``(transformed_patch, p)``, where ``p`` is the integer corner of
            the patch. If an axis is shorter than ``self.dim``, slicing
            yields a patch smaller than ``self.dim`` along it.
        """
        X = image.image

        # Per-axis extent to cover; never smaller than one patch.
        sample_space = np.asarray([max(i, self.dim) for i in image.shape[:3]])
        d = sample_space - self.dim
        # Builtin int is used because ``np.int`` was removed in NumPy 1.24.
        min_cov = [int(np.ceil(sample_space[i] / self.dim)) for i in range(3)]
        ds = [np.linspace(0, d[i], min_cov[i], dtype=int) for i in range(3)]

        # Get placement coordinate points.
        placements = mgrid_to_points(np.meshgrid(*ds))

        for p in placements:
            patch = X[p[0]:p[0] + self.dim,
                      p[1]:p[1] + self.dim,
                      p[2]:p[2] + self.dim]
            yield image.scaler.transform(patch), p
Exemplo n.º 8
0
def sample_box_at(real_placement, sample_dim, real_box_dim, noise_sd,
                  test_mode):
    """Sample a cubic grid of ``sample_dim``**3 points starting at ``real_placement``.

    The cube spans ``[o, o + real_box_dim]`` along each axis, where ``o`` is
    the matching component of ``real_placement``. If ``noise_sd`` is truthy,
    the cube is rotated about its own center. The rotation axis comes from
    ``get_random_views``, and the angle is ``|N(0, noise_sd)|``, redrawn
    until it lies strictly between 0 and 2*pi.

    Returns:
        The (possibly rotated) grid in mgrid layout. In test mode, returns
        ``(grid, axes, inverse_rotation)``, where ``axes`` holds the three
        unrotated 1D axis coordinate arrays.
    """
    x0, y0, z0 = real_placement
    n_steps = complex(sample_dim)
    box = np.mgrid[x0:x0 + real_box_dim:n_steps,
                   y0:y0 + real_box_dim:n_steps,
                   z0:z0 + real_box_dim:n_steps]

    rotation = np.eye(3)
    sampled = box
    if noise_sd:
        axis = get_random_views(N=1, dim=3, pos_z=True)

        # Rejection-sample an angle in the open interval (0, 2*pi).
        while True:
            angle = np.abs(np.random.normal(scale=noise_sd, size=1)[0])
            if 0 < angle < 2 * np.pi:
                break

        rotation = get_rotation_matrix(axis, angle_rad=angle)

        # Rotate the point cloud about its centroid, then restore mgrid form.
        flat = mgrid_to_points(box)
        centroid = np.mean(flat, axis=0)
        rotated = rotation.dot((flat - centroid).T).T + centroid
        sampled = points_to_mgrid(rotated, box.shape[1:])

    if not test_mode:
        return sampled

    axes = tuple(np.linspace(o, o + real_box_dim, sample_dim)
                 for o in (x0, y0, z0))
    return sampled, axes, np.linalg.inv(rotation)
Exemplo n.º 9
0
def pred_3D_iso(model, sequence, image, extra_boxes, min_coverage=None):
    """Predict a full volume by summing predictions over sampled 3D boxes.

    Base patches from ``sequence.get_base_patches_from`` are predicted
    first. Once those are exhausted, random patches from
    ``sequence.get_N_random_patches_from`` are predicted until all stopping
    criteria hold. Each box prediction is mapped to the nearest voxel of
    ``image`` and added into an accumulator volume.

    Args:
        model: Object exposing ``predict_on_batch``.
        sequence: Patch sampler exposing ``n_classes``,
            ``get_base_patches_from`` and ``get_N_random_patches_from``.
        image: Image object exposing ``shape``, ``image``, ``affine`` and
            ``interpolator.apply_rotation``.
        extra_boxes: Number of random boxes to predict after the base
            patches. This is either an int or a string such as ``"2x"``,
            meaning a multiple of the number of base patches.
        min_coverage: Optional fraction; when given, sampling also continues
            until at least this fraction of voxels has a non-zero
            prediction.

    Returns:
        float32 array of shape ``image.shape[:3] + (n_classes,)`` with the
        summed (NOT normalized) predictions.
    """
    total_extra_boxes = extra_boxes

    # Get reference to the image.
    n_classes = sequence.n_classes
    pred_shape = tuple(image.shape[:3]) + (n_classes, )
    vox_shape = tuple(image.shape[:3]) + (3, )

    # Prepare interpolator object.
    vox_grid = get_voxel_grid(image, as_points=False)

    # Get voxel regular grid centered in real space.
    g_all, basis, _ = get_voxel_axes_real_space(image.image,
                                                image.affine,
                                                return_basis=True)
    g_all = list(g_all)

    # Flip axes? Grid axes must be strictly increasing for the interpolator.
    flip = np.sign(np.diagonal(basis)) == -1
    for i, (g, f) in enumerate(zip(g_all, flip)):
        if f:
            g_all[i] = np.flip(g, 0)
            vox_grid = np.flip(vox_grid, i + 1)
    vox_points = mgrid_to_points(vox_grid).reshape(vox_shape).astype(
        np.float32)

    # Setup interpolator: takes a point in the scanner space and returns
    # the nearest voxel coordinate (NaN outside the volume).
    intrp = RegularGridInterpolator(tuple(g_all),
                                    vox_points,
                                    method="nearest",
                                    bounds_error=False,
                                    fill_value=np.nan,
                                    dtype=np.float32)

    # Prepare prediction volume.
    pred_vol = np.zeros(shape=pred_shape, dtype=np.float32)

    # Predict on base patches first.
    base_patches = sequence.get_base_patches_from(image, return_y=False)

    # Sample boxes and predict --> sum into pred_vol.
    is_covered, base_reached, extra_reached, N_base, N_extra = not min_coverage, False, False, 0, 0

    while not is_covered or not base_reached or not extra_reached:
        try:
            im, rgrid, _, _, total_base = next(base_patches)
            N_base += 1

            if isinstance(total_extra_boxes, str):
                # Number specified in string format '2x', '2.5x' etc. as a
                # multiplier of number of base patches.
                total_extra_boxes = int(
                    float(total_extra_boxes.split("x")[0]) * total_base)

        except StopIteration:
            p = sequence.get_N_random_patches_from(image, 1, return_y=False)
            im, rgrid, _, _ = next(p)
            N_extra += 1

        # Predict on the box.
        pred = model.predict_on_batch(np.expand_dims(im, 0))[0]

        # Apply rotation if needed.
        rgrid = image.interpolator.apply_rotation(rgrid)

        # Interpolate to nearest vox grid positions.
        vox_inds = intrp(tuple(rgrid)).reshape(-1, 3)

        # Drop points that fell outside the volume (all-NaN rows).
        # Builtin int is used because ``np.int`` was removed in NumPy 1.24.
        mask = np.logical_not(np.all(np.isnan(vox_inds), axis=-1))
        vox_inds = tuple(vox_inds[mask].astype(int).T)

        # Add to volume.
        # NOTE(review): fancy-index `+=` is buffered, so when several box
        # points map to the same voxel only one contribution is added.
        # np.add.at would accumulate all of them; confirm which is intended.
        pred_vol[vox_inds] += pred.reshape(-1, n_classes)[mask]

        # Check coverage fraction.
        if min_coverage:
            covered = np.logical_not(np.all(np.isclose(pred_vol, 0), axis=-1))
            coverage = np.sum(covered) / np.prod(pred_vol.shape[:3])
            # Both values must be inside the tuple; the previous
            # `"..." % coverage, min_coverage` raised TypeError.
            cov_string = "%.3f/%.3f" % (coverage, min_coverage)
            is_covered = coverage >= min_coverage
        else:
            cov_string = "[Not calculated]"

        print("   N base patches: %i/%i --- N extra patches %i/%i --- "
              "Coverage: %s" %
              (N_base, total_base, N_extra, total_extra_boxes, cov_string),
              end="\r",
              flush=True)

        # Check convergence.
        base_reached = N_base >= total_base
        extra_reached = N_extra >= total_extra_boxes
    print("")

    # Return prediction volume - OBS not normalized.
    return pred_vol
Exemplo n.º 10
0
    def get_patch_corners(self):
        """Return the corner points of a regular 3D grid of patches.

        Along each axis ``i``, ``self.strides[i]`` positions are spaced
        evenly over ``[0, self.dim_r[i]]`` and cast to integers by
        truncation. The three axes are combined with ``np.meshgrid`` and
        flattened via ``mgrid_to_points``.
        """
        # Builtin int is used because ``np.int`` was removed in NumPy 1.24.
        axes = [np.linspace(0, self.dim_r[i], self.strides[i]).astype(int)
                for i in range(3)]
        return mgrid_to_points(np.meshgrid(*axes))