def _init_interpolators(self, image, labels, bg_value, bg_class, affine):
        """Build per-channel image interpolators and a label interpolator.

        The voxel grid of ``image`` is expressed in real (scanner) space via
        ``get_voxel_axes_real_space``; the rotation matrix it returns is
        stored on ``self.rot_mat``. Every axis whose basis diagonal entry is
        negative is reversed (in the coordinate vector, in ``image`` and in
        ``labels``) so that all coordinate axes are strictly increasing.

        Args:
            image: Image array indexed as ``image[..., channel]`` for
                ``channel`` in ``range(self.n_channels)``.
            labels: Label array aligned with ``image``, or None.
            bg_value: Fill value for image samples outside the grid.
            bg_class: Fill value for label samples outside the grid.
            affine: Affine matrix forwarded to ``get_voxel_axes_real_space``.

        Returns:
            A 2-tuple: a list holding one linear (float32) interpolator per
            image channel, and a nearest-neighbour (uint8) label
            interpolator. The label interpolator is None if constructing it
            raised AttributeError, TypeError or ValueError (presumably the
            case when ``labels`` is None -- TODO confirm).
        """
        axes, basis, rot_mat = get_voxel_axes_real_space(image, affine,
                                                         return_basis=True)
        axes = list(axes)

        self.rot_mat = rot_mat

        # The interpolator requires strictly increasing coordinates, so
        # reverse every axis whose basis direction is negative.
        descending = np.sign(np.diagonal(basis)) == -1
        for axis, (coords, reverse) in enumerate(zip(axes, descending)):
            if not reverse:
                continue
            axes[axis] = np.flip(coords, 0)
            image = np.flip(image, axis)
            if labels is not None:
                labels = np.flip(labels, axis)

        # Unpacking enforces exactly three spatial axes.
        x_coords, y_coords, z_coords = axes
        points = (x_coords, y_coords, z_coords)

        # One linear interpolator per image channel.
        image_interpolators = [
            RegularGridInterpolator(points,
                                    image[..., channel].squeeze(),
                                    bounds_error=False,
                                    fill_value=bg_value,
                                    method="linear",
                                    dtype=np.float32)
            for channel in range(self.n_channels)
        ]

        # Best-effort label interpolator; falls back to None on failure.
        try:
            label_interpolator = RegularGridInterpolator(points, labels,
                                                         bounds_error=False,
                                                         fill_value=bg_class,
                                                         method="nearest",
                                                         dtype=np.uint8)
        except (AttributeError, TypeError, ValueError):
            label_interpolator = None

        return image_interpolators, label_interpolator
# Example #2
# 0
def map_real_space_pred(pred, grid, inv_basis, voxel_grid_real_space, method="nearest"):
    """Resample a prediction volume onto a real-space voxel grid.

    ``pred`` is interpolated on the regular grid ``grid``. Every point of
    ``voxel_grid_real_space`` is first mapped through ``inv_basis``, and the
    prediction is evaluated at the resulting coordinates. Points outside
    ``grid`` receive a fill vector that is 1.0 in channel 0 (background) and
    0.0 in all other channels.

    Interpolation runs slice-by-slice (over the first axis of each transformed
    coordinate array) on a thread pool, with a progress counter printed to
    stdout.

    Args:
        pred: Prediction array whose last axis is the channel axis.
        grid: Grid coordinate arrays passed to ``RegularGridInterpolator``.
        inv_basis: Matrix applied (via ``.dot``) to every real-space point.
        voxel_grid_real_space: mgrid-style coordinate arrays; element ``[0]``
            defines the spatial output shape.
        method: Interpolation method forwarded to ``RegularGridInterpolator``.

    Returns:
        Array of shape ``transformed_grid[0].shape + (pred.shape[-1],)`` with
        dtype ``pred.dtype``.
    """
    from concurrent.futures import ThreadPoolExecutor
    from multiprocessing import cpu_count

    print("Mapping to real coordinate space...")

    # Fill vector for out-of-grid points: 1.0 for the background channel (0)
    n_channels = pred.shape[-1]
    fill = np.zeros(shape=n_channels, dtype=np.float32)
    fill[0] = 1.0

    intrp = RegularGridInterpolator(grid, pred, fill_value=fill,
                                    bounds_error=False, method=method)

    # Map the real-space grid points through the inverse basis
    points = inv_basis.dot(mgrid_to_points(voxel_grid_real_space).T).T
    transformed_grid = points_to_mgrid(points, voxel_grid_real_space[0].shape)

    mapped = np.empty(transformed_grid[0].shape + (n_channels,),
                      dtype=pred.dtype)

    def _interpolate_slice(xs, ys, zs):
        # Interpolate one slice of transformed coordinates
        return intrp((xs, ys, zs))

    n_slices = len(transformed_grid[0])
    # At least 7 workers, or one per CPU if there are more. The context
    # manager guarantees the pool is shut down even if interpolation raises.
    with ThreadPoolExecutor(max_workers=max(7, cpu_count())) as pool:
        results = pool.map(_interpolate_slice, transformed_grid[0],
                           transformed_grid[1], transformed_grid[2])
        # Executor.map yields results in submission order, so the
        # enumeration index is the slice index.
        for ind, slice_values in enumerate(results):
            print("  %i/%i" % (ind + 1, n_slices), end="\r", flush=True)
            mapped[ind] = slice_values

    print("")
    return mapped
# Example #3
# 0
def pred_3D_iso(model, sequence, image, extra_boxes, min_coverage=None):
    """Predict on a 3D image by summing predictions from many sampled boxes.

    All base patches from ``sequence.get_base_patches_from`` are predicted
    first. After that, random patches from
    ``sequence.get_N_random_patches_from`` are predicted until the requested
    number of extra boxes is reached. If ``min_coverage`` is truthy, sampling
    also continues until at least that fraction of voxels has a nonzero
    summed prediction. Each patch prediction is mapped back to voxel indices
    by nearest-neighbour lookup in real (scanner) space and added into the
    prediction volume.

    Args:
        model: Object with a ``predict`` method accepting a batch of patches.
        sequence: Object providing ``n_classes``, ``get_base_patches_from``
            and ``get_N_random_patches_from``.
        image: Image object exposing ``shape``, ``image``, ``affine`` and
            ``interpolator.apply_rotation``.
        extra_boxes: Number of random patches to predict after the base
            patches: either a number, or a string multiplier such as ``"2x"``
            or ``"2.5x"`` relative to the number of base patches.
        min_coverage: Optional minimum fraction of covered voxels.

    Returns:
        float32 array of shape ``image.shape[:3] + (n_classes,)`` holding the
        summed predictions. OBS: not normalized.
    """
    total_extra_boxes = extra_boxes

    n_classes = sequence.n_classes
    pred_shape = tuple(image.shape[:3]) + (n_classes,)
    vox_shape = tuple(image.shape[:3]) + (3,)

    # Grid of voxel coordinates of the image
    vox_grid = get_voxel_grid(image, as_points=False)

    # Get voxel regular grid centered in real space
    g_all, basis, _ = get_voxel_axes_real_space(image.image, image.affine,
                                                return_basis=True)
    g_all = list(g_all)

    # Flip axes? Must be strictly increasing
    flip = np.sign(np.diagonal(basis)) == -1
    for i, (g, f) in enumerate(zip(g_all, flip)):
        if f:
            g_all[i] = np.flip(g, 0)
            # vox_grid appears to carry a leading coordinate axis, hence i+1
            vox_grid = np.flip(vox_grid, i + 1)
    vox_points = mgrid_to_points(vox_grid).reshape(vox_shape).astype(np.float32)

    # Setup interpolator - takes a point in the scanner space and returns
    # the nearest voxel coordinate (NaN outside the grid)
    intrp = RegularGridInterpolator(tuple(g_all), vox_points,
                                    method="nearest", bounds_error=False,
                                    fill_value=np.nan, dtype=np.float32)

    pred_vol = np.zeros(shape=pred_shape, dtype=np.float32)

    # Predict on base patches first
    base_patches = sequence.get_base_patches_from(image, return_y=False)

    # total_base is reported by the base-patch generator. Default it to 0 so
    # an empty generator does not leave it unbound.
    total_base = 0
    is_covered = not min_coverage
    base_reached = extra_reached = False
    N_base = N_extra = 0

    while not is_covered or not base_reached or not extra_reached:
        try:
            im, rgrid, _, _, total_base = next(base_patches)
            N_base += 1
        except StopIteration:
            p = sequence.get_N_random_patches_from(image, 1, return_y=False)
            im, rgrid, _, _ = next(p)
            N_extra += 1

        if isinstance(total_extra_boxes, str):
            # Number specified in string format '2x', '2.5x' etc. as a
            # multiplier of number of base patches. Resolved here (not only on
            # the base-patch path) so it is always numeric before comparison.
            total_extra_boxes = int(float(total_extra_boxes.split("x")[0]) * total_base)

        # Predict on the box
        pred = model.predict(np.expand_dims(im, 0))[0]

        # Apply rotation if needed
        rgrid = image.interpolator.apply_rotation(rgrid)

        # Interpolate to nearest vox grid positions
        vox_inds = intrp(tuple(rgrid)).reshape(-1, 3)

        # Drop points that fell outside the grid (all-NaN rows).
        # np.int was removed in NumPy 1.24; builtin int is the same dtype.
        mask = np.logical_not(np.all(np.isnan(vox_inds), axis=-1))
        vox_inds = tuple(vox_inds[mask].astype(int).T)

        # Add to volume.
        # NOTE(review): with fancy indexing, `+=` applies only one update per
        # repeated index. If several patch points map to the same voxel and
        # should all be summed, np.add.at is required -- confirm intent.
        pred_vol[vox_inds] += pred.reshape(-1, n_classes)[mask]

        # Check coverage fraction
        if min_coverage:
            covered = np.logical_not(np.all(np.isclose(pred_vol, 0), axis=-1))
            coverage = np.sum(covered) / np.prod(pred_vol.shape[:3])
            cov_string = "%.3f/%.3f" % (coverage, min_coverage)
            is_covered = coverage >= min_coverage
        else:
            cov_string = "[Not calculated]"

        print("   N base patches: %i/%i --- N extra patches %i/%i --- "
              "Coverage: %s" % (
                N_base, total_base, N_extra, total_extra_boxes, cov_string),
              end="\r", flush=True)

        # Check convergence
        base_reached = N_base >= total_base
        extra_reached = N_extra >= total_extra_boxes
    print("")

    # Return prediction volume - OBS not normalized
    return pred_vol