Example #1
0
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None, aff_ref='FS', dist_map=False):
    """Load a volume and turn it into a network input.

    The volume is read with utils.get_volume_info, optionally padded, cropped so that every spatial size is a
    multiple of 2**n_levels, re-oriented (only when n_dims > 2), min-max rescaled and finally given extra axes
    with utils.add_axis.

    :param im_path: path of the image to load.
    :param n_levels: exponent of the divisibility constraint: spatial sizes must be multiples of 2**n_levels.
    :param crop_shape: (optional) requested crop shape, reformatted to a list of n_dims ints and capped to the
    (padded) image shape. If None, a crop is only computed when the image shape violates the constraint.
    :param padding: (optional) forwarded as padding_shape to edit_volumes.pad_volume. Falsy values skip padding.
    :param aff_ref: (optional) 'FS' or 'identity', the orientation the image is aligned to. Any other value
    skips the alignment.
    :param dist_map: (optional) if True, only even-indexed channels of a multi-channel image are rescaled.
    :return: im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx
    """

    # load the volume together with its geometry
    im, shape, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(im_path, return_volume=True)

    # optional padding
    pad_shape = shape
    if padding:
        im = edit_volumes.pad_volume(im, padding_shape=padding)
        pad_shape = im.shape[:n_dims]

    # derive a crop shape whose sizes are all multiples of 2**n_levels
    multiple = 2 ** n_levels
    if crop_shape is None:
        if any([size % multiple != 0 for size in pad_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, multiple) for size in pad_shape]
    else:
        crop_shape = utils.reformat_to_list(crop_shape, length=n_dims, dtype='int')
        # the crop cannot be larger than the (padded) image
        if any([pad_shape[d] < crop_shape[d] for d in range(len(pad_shape))]):
            crop_shape = [min(pad_shape[d], crop_shape[d]) for d in range(n_dims)]
        if any([size % multiple != 0 for size in crop_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, multiple) for size in crop_shape]

    # apply the crop, remembering where it was taken from
    crop_idx = None
    if crop_shape is not None:
        im, crop_idx = edit_volumes.crop_volume(im, cropping_shape=crop_shape, return_crop_idx=True)

    # bring the image to the requested orientation
    if n_dims > 2:
        target_aff = None
        if aff_ref == 'FS':
            target_aff = np.array([[-1., 0., 0., 0.], [0., 0., 1., 0.], [0., -1., 0., 0.], [0., 0., 0., 1.]])
        elif aff_ref == 'identity':
            target_aff = np.eye(4)
        if target_aff is not None:
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=target_aff, return_aff=False, n_dims=n_dims)

    # min-max rescaling to [0, 1] (constant inputs become all zeros)
    if n_channels == 1:
        lowest, highest = np.min(im), np.max(im)
        im = np.zeros(im.shape) if highest == lowest else (im - lowest) / (highest - lowest)
    else:
        for c in range(im.shape[-1]):
            rescale_channel = (not dist_map) | (dist_map & (c % 2 == 0))
            if not rescale_channel:
                continue
            channel = im[..., c]
            lowest, highest = np.min(channel), np.max(channel)
            if highest == lowest:
                im[..., c] = np.zeros(channel.shape)
            else:
                im[..., c] = (channel - lowest) / (highest - lowest)

    # multi-channel images only get one extra axis, single-channel images get a first and a last one
    if n_channels > 1:
        im = utils.add_axis(im)
    else:
        im = utils.add_axis(im, axis=[0, -1])

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx
Example #2
0
 def generate_brain(self):
     """Draw the next (image, labels) batch from self.brain_generator and return both arrays after
     re-aligning every batch element from the identity orientation to self.aff."""
     image, labels = next(self.brain_generator)

     def to_native(volume):
         # the generated volumes are treated as identity-oriented and moved to the orientation of self.aff
         return edit_volumes.align_volume_to_ref(volume, np.eye(4), aff_ref=self.aff, n_dims=self.n_dims)

     # re-align each element of the batch (image first, then its label map)
     realigned = [(to_native(image[b]), to_native(labels[b])) for b in range(self.batchsize)]

     # rebuild the batch axis
     image = np.stack([pair[0] for pair in realigned], axis=0)
     labels = np.stack([pair[1] for pair in realigned], axis=0)
     return image, labels
Example #3
0
def postprocess(prediction, pad_shape, im_shape, crop, n_dims, labels,
                keep_biggest_component, aff):
    """Turn a network prediction into a label map and posteriors defined on the original image grid.

    :param prediction: network output. Singleton axes are squeezed, and the last remaining axis holds one
    posterior per entry of labels.
    :param pad_shape: spatial shape of the (padded) image the predicted patch was cropped from.
    :param im_shape: spatial shape of the image before padding.
    :param crop: cropping indices (all lower bounds followed by all upper bounds) locating the patch in the
    padded image, or None if the image was not cropped.
    :param n_dims: number of spatial dimensions. Patches are only pasted back when n_dims is 2 or 3.
    :param labels: numpy array giving the label value of each posterior channel.
    :param keep_biggest_component: if True, the hard segmentation only keeps the largest connected component
    of channels 1-5 and the largest connected component of channels 6 and above.
    :param aff: affine matrix the outputs are aligned back to (only used when n_dims > 2).
    :return: seg, posteriors. Both are cropped to im_shape when the image was padded.
    """

    # get posteriors and segmentation
    post_patch = np.squeeze(prediction)
    seg_patch = post_patch.argmax(-1)

    # keep biggest connected component (use it with smoothing!)
    if keep_biggest_component:
        left_mask = edit_volumes.get_largest_connected_component(
            (seg_patch > 0) & (seg_patch < 6))
        right_mask = edit_volumes.get_largest_connected_component(
            seg_patch > 5)
        seg_patch *= (left_mask | right_mask)

    # paste patches back to matrix of original image size
    if crop is not None:
        seg = np.zeros(shape=pad_shape, dtype='int32')
        posteriors = np.zeros(shape=[*pad_shape, labels.shape[0]])
        posteriors[..., 0] = np.ones(pad_shape)  # place background around patch
        if n_dims == 2:
            seg[crop[0]:crop[2], crop[1]:crop[3]] = seg_patch
            posteriors[crop[0]:crop[2], crop[1]:crop[3], :] = post_patch
        elif n_dims == 3:
            seg[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5]] = seg_patch
            posteriors[crop[0]:crop[3], crop[1]:crop[4],
                       crop[2]:crop[5], :] = post_patch
    else:
        seg = seg_patch
        posteriors = post_patch
    seg = labels[seg.astype('int')].astype('int')

    # remove the padding. Shapes are compared as tuples so that a list and a tuple holding the same sizes are
    # seen as equal, and the posteriors are cropped with the segmentation so both outputs share the same grid.
    if tuple(im_shape) != tuple(pad_shape):
        lower = [int((p - i) / 2) for (p, i) in zip(pad_shape, im_shape)]
        window = tuple(slice(lo, lo + size) for (lo, size) in zip(lower, im_shape))
        seg = seg[window]
        posteriors = posteriors[window]  # trailing channel axis is left untouched

    # align prediction back to first orientation
    if n_dims > 2:
        seg = edit_volumes.align_volume_to_ref(seg, np.eye(4), aff_ref=aff)
        posteriors = edit_volumes.align_volume_to_ref(posteriors,
                                                      np.eye(4),
                                                      aff_ref=aff,
                                                      n_dims=n_dims)

    return seg, posteriors
Example #4
0
def postprocess(prediction, pad_shape, im_shape, crop, n_dims, labels, keep_biggest_component,
                aff, aff_ref='FS', keep_biggest_of_each_group=True, n_neutral_labels=None):
    """Turn a network prediction into a label map and posteriors defined on the original image grid.

    Note that the posteriors are edited and renormalised in place on the squeezed prediction, so the array
    given as prediction may be modified by this function.

    :param prediction: network output. Singleton axes are squeezed, and the last remaining axis holds one
    posterior per entry of labels.
    :param pad_shape: spatial shape of the (padded) image the predicted patch was cropped from.
    :param im_shape: spatial shape of the image before padding.
    :param crop: cropping indices (all lower bounds followed by all upper bounds) locating the patch in the
    padded image, or None if the image was not cropped.
    :param n_dims: number of spatial dimensions. Patches are only pasted back when n_dims is 2 or 3.
    :param labels: numpy array giving the label value of each posterior channel.
    :param keep_biggest_component: if True, the hard segmentation only keeps the largest connected component
    of all non-zero channels.
    :param aff: affine matrix the outputs are aligned back to (only used when n_dims > 2).
    :param aff_ref: (optional) orientation the prediction currently is in: 'FS' or 'identity'. Any other value
    skips the alignment.
    :param keep_biggest_of_each_group: (optional) if True, posteriors are set to zero outside the largest
    connected component of each hard-coded topological class (channels are grouped by class, and a channel
    takes part in the component search where its posterior exceeds 0.2).
    :param n_neutral_labels: (optional) if different from len(labels), the topological classes from this index
    onwards are appended a second time with shifted class values.
    :return: seg, posteriors. Both are cropped to im_shape when the image was padded.
    """

    # get posteriors and segmentation
    post_patch = np.squeeze(prediction)

    # reset posteriors to zero outside the largest connected component of each topological class
    if keep_biggest_of_each_group:
        topology_classes = np.array([0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 7, 8, 9, 10, 11, 12, 13, 14, 5])
        if n_neutral_labels != len(labels):
            left = topology_classes[n_neutral_labels:]
            topology_classes = np.concatenate([topology_classes, left + np.max(left) - np.min(left) + 1])
        unique_topology_classes = np.unique(topology_classes)
        post_patch_mask = post_patch > 0.2
        for topology_class in unique_topology_classes[1:]:  # first class (background) is left untouched
            tmp_topology_indices = np.where(topology_classes == topology_class)[0]
            tmp_mask = np.any(post_patch_mask[..., tmp_topology_indices], axis=-1)
            tmp_mask = edit_volumes.get_largest_connected_component(tmp_mask)
            for idx in tmp_topology_indices:
                post_patch[..., idx] *= tmp_mask

    # renormalise posteriors and get hard segmentation
    post_patch /= np.sum(post_patch, axis=-1)[..., np.newaxis]
    seg_patch = post_patch.argmax(-1)

    # keep biggest connected component (use it with smoothing!)
    if keep_biggest_component:
        mask = seg_patch > 0
        mask = edit_volumes.get_largest_connected_component(mask)
        seg_patch = seg_patch * mask

    # align prediction back to first orientation
    if n_dims > 2:
        if aff_ref == 'FS':
            current_aff = np.array([[-1., 0., 0., 0.], [0., 0., 1., 0.], [0., -1., 0., 0.], [0., 0., 0., 1.]])
        elif aff_ref == 'identity':
            current_aff = np.eye(4)
        else:
            current_aff = None
        if current_aff is not None:
            seg_patch = edit_volumes.align_volume_to_ref(seg_patch, current_aff, aff_ref=aff, return_aff=False)
            post_patch = edit_volumes.align_volume_to_ref(post_patch, current_aff, aff_ref=aff, n_dims=n_dims)

    # paste patches back to matrix of original image size
    if crop is not None:
        seg = np.zeros(shape=pad_shape, dtype='int32')
        posteriors = np.zeros(shape=[*pad_shape, labels.shape[0]])
        posteriors[..., 0] = np.ones(pad_shape)  # place background around patch
        if n_dims == 2:
            seg[crop[0]:crop[2], crop[1]:crop[3]] = seg_patch
            posteriors[crop[0]:crop[2], crop[1]:crop[3], :] = post_patch
        elif n_dims == 3:
            seg[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5]] = seg_patch
            posteriors[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], :] = post_patch
    else:
        seg = seg_patch
        posteriors = post_patch
    seg = labels[seg.astype('int')].astype('int')

    # remove the padding. Shapes are compared as tuples so that a list and a tuple holding the same sizes are
    # seen as equal, and the posteriors are cropped with the segmentation so both outputs share the same grid.
    if tuple(im_shape) != tuple(pad_shape):
        lower = [int((p - i) / 2) for (p, i) in zip(pad_shape, im_shape)]
        window = tuple(slice(lo, lo + size) for (lo, size) in zip(lower, im_shape))
        seg = seg[window]
        posteriors = posteriors[window]  # trailing channel axis is left untouched

    return seg, posteriors
Example #5
0
# Do the actual work: load each image, bring it to the network input format and run the U-Net on it.
# NOTE(review): this excerpt stops after the prediction is rescaled; nothing is written to path_prediction here.
print('Found %d images' % len(images_to_segment))
for idx, (path_image, path_prediction) in enumerate(
        zip(images_to_segment, path_predictions)):
    print('  Working on image %d ' % (idx + 1))
    print('  ' + path_image)

    # load the image with its affine matrix and header
    im, aff, hdr = utils.load_volume(path_image, im_only=False, dtype='float')
    # when the 'ct' option is set, clip intensities to the [0, 80] range
    if args['ct']:
        im[im < 0] = 0
        im[im > 80] = 80
    # resample with [1.0, 1.0, 1.0] (presumably the target voxel size - confirm against resample_volume)
    im, aff = edit_volumes.resample_volume(im, aff, [1.0, 1.0, 1.0])
    # align the image to the identity orientation, keeping the updated affine in aff2
    aff_ref = np.eye(4)
    im, aff2 = edit_volumes.align_volume_to_ref(im,
                                                aff,
                                                aff_ref=aff_ref,
                                                return_aff=True,
                                                n_dims=3)
    # min-max rescaling to [0, 1]
    # NOTE(review): divides by zero if the image is constant
    im = im - np.min(im)
    im = im / np.max(im)
    # add a leading and a trailing axis
    I = im[np.newaxis, ..., np.newaxis]
    # spatial shape rounded up to the next multiple of 32
    W = (np.ceil(np.array(I.shape[1:-1]) / 32.0) * 32).astype('int')
    # offsets that centre the image in the enlarged volume
    # NOTE(review): this rebinds idx, the loop counter used in the print above
    idx = np.floor((W - I.shape[1:-1]) / 2).astype('int')
    # zero-filled volume of the enlarged shape, with the image pasted at the offsets
    S = np.zeros([1, *W, 1])
    S[0, idx[0]:idx[0] + I.shape[1], idx[1]:idx[1] + I.shape[2],
      idx[2]:idx[2] + I.shape[3], :] = I
    # run the network and drop singleton axes
    output = unet_model.predict(S)
    pred = np.squeeze(output)
    # scale the prediction by 255 and clip it to [0, 128]
    pred = 255 * pred
    pred[pred < 0] = 0
    pred[pred > 128] = 128
Example #6
0
# Do the actual work: load each pair of images and bring both to the intensity ranges used below.
# NOTE(review): this excerpt stops after intensity rescaling; nothing is predicted or saved here.
print('Found %d images' % len(images_to_segment_t1))
for idx, (path_image_t1, path_image_t2, path_prediction) in enumerate(
        zip(images_to_segment_t1, images_to_segment_t2, path_predictions)):
    print('  Working on image %d ' % (idx + 1))
    print('  ' + path_image_t1 + ', ' + path_image_t2)

    # load the first image with its affine matrix and header
    im1, aff1, hdr1 = utils.load_volume(path_image_t1,
                                        im_only=False,
                                        dtype='float')
    # resample with [1.0, 1.0, 1.0] (presumably the target voxel size - confirm against resample_volume)
    im1, aff1 = edit_volumes.resample_volume(im1, aff1, [1.0, 1.0, 1.0])
    # align the first image to the identity orientation, keeping the updated affine in aff1_mod
    aff_ref = np.eye(4)
    im1, aff1_mod = edit_volumes.align_volume_to_ref(im1,
                                                     aff1,
                                                     aff_ref=aff_ref,
                                                     return_aff=True,
                                                     n_dims=3)
    # load the second image
    im2, aff2, hdr2 = utils.load_volume(path_image_t2,
                                        im_only=False,
                                        dtype='float')
    # presumably resamples im2 onto the grid of the aligned first image - confirm against resample_volume_like
    im2 = edit_volumes.resample_volume_like(im1, aff1_mod, im2, aff2)

    # first image: subtract the minimum, then divide by a third of the maximum, i.e. rescale to [0, 3]
    minimum = np.min(im1)
    im1 = im1 - minimum
    spread = np.max(
        im1) / 3.0  # don't ask, it's something I messed up at training
    im1 = im1 / spread
    # second image: min-max rescaling followed by a factor 2, i.e. rescale to [0, 2]
    im2 = im2 - np.min(im2)
    im2 = im2 / np.max(
        im2) * 2.0  # don't ask, it's something I messed up at training
Example #7
0
def postprocess(post_patch,
                pad_shape,
                im_shape,
                crop,
                n_dims,
                segmentation_labels,
                lr_indices,
                keep_biggest_component,
                aff,
                topology_classes=True,
                post_patch_flip=None):
    """Turn predicted posteriors into a label map and posteriors defined on the original image grid.

    Without post_patch_flip, the posteriors are edited and renormalised in place on the squeezed input, so the
    array given as post_patch may be modified by this function.

    :param post_patch: predicted posteriors. Singleton axes are squeezed, and the last remaining axis holds one
    posterior per entry of segmentation_labels.
    :param pad_shape: spatial shape of the (padded) image the predicted patch was cropped from.
    :param im_shape: spatial shape of the image before padding.
    :param crop: cropping indices (all lower bounds followed by all upper bounds) locating the patch in the
    padded image, or None if the image was not cropped.
    :param n_dims: number of spatial dimensions. Patches are only pasted back when n_dims is 2 or 3.
    :param segmentation_labels: numpy array giving the label value of each posterior channel.
    :param lr_indices: array of channel indices that are exchanged in the flipped posteriors (the channels
    lr_indices.flatten() receive the channels lr_indices[::-1].flatten()). Can be None.
    :param keep_biggest_component: if True, channels 1 and above are masked with the largest connected
    component of the region where their sum exceeds 0.25.
    :param aff: affine matrix the outputs are aligned back to (only used when n_dims > 2).
    :param topology_classes: (optional) array giving the topological class of each channel. Posteriors are set
    to zero outside the largest connected component of each class except the first one. None disables this
    step; the default True leaves the posteriors untouched but still triggers the renormalisation.
    :param post_patch_flip: (optional) posteriors predicted for the right/left flipped image. They are flipped
    back and averaged with post_patch.
    :return: seg, posteriors
    """

    post_patch = np.squeeze(post_patch)

    # average with the prediction obtained on the flipped image
    flipped_given = post_patch_flip is not None
    if flipped_given:
        post_patch_flip = edit_volumes.flip_volume(np.squeeze(post_patch_flip), direction='rl', aff=np.eye(4))
        if lr_indices is not None:
            swapped = post_patch_flip[..., lr_indices[::-1].flatten()]
            post_patch_flip[..., lr_indices.flatten()] = swapped
        post_patch = 0.5 * (post_patch + post_patch_flip)

    # mask all non-background channels with the largest connected foreground component
    if keep_biggest_component:
        foreground = post_patch[..., 1:]
        keep = edit_volumes.get_largest_connected_component(np.sum(foreground, axis=-1) > 0.25)
        keep = np.stack([keep] * foreground.shape[-1], axis=-1)
        post_patch[..., 1:] = edit_volumes.mask_volume(foreground, mask=keep)

    # zero the posteriors outside the largest connected component of each topological class
    use_topology = topology_classes is not None
    if use_topology:
        above_threshold = post_patch > 0.25
        for class_id in np.unique(topology_classes)[1:]:
            channels = np.where(topology_classes == class_id)[0]
            component = np.any(above_threshold[..., channels], axis=-1)
            component = edit_volumes.get_largest_connected_component(component)
            for channel in channels:
                post_patch[..., channel] *= component

    # the posteriors need renormalising as soon as one of the steps above was requested
    if flipped_given | keep_biggest_component | use_topology:
        post_patch /= np.sum(post_patch, axis=-1)[..., np.newaxis]
    seg_patch = post_patch.argmax(-1)

    # put the patch back at its location in the padded image
    if crop is None:
        seg = seg_patch
        posteriors = post_patch
    else:
        seg = np.zeros(shape=pad_shape, dtype='int32')
        posteriors = np.zeros(shape=[*pad_shape, segmentation_labels.shape[0]])
        posteriors[..., 0] = np.ones(pad_shape)  # background everywhere around the patch
        if n_dims in (2, 3):
            window = tuple(slice(crop[d], crop[d + n_dims]) for d in range(n_dims))
            seg[window] = seg_patch
            posteriors[window] = post_patch

    # convert channel indices to label values
    seg = segmentation_labels[seg.astype('int')].astype('int')

    # remove the padding
    if im_shape != pad_shape:
        bounds = [int((p - i) / 2) for (p, i) in zip(pad_shape, im_shape)]
        bounds += [low + size for (low, size) in zip(bounds, im_shape)]
        seg = edit_volumes.crop_volume_with_idx(seg, bounds)
        posteriors = edit_volumes.crop_volume_with_idx(posteriors, bounds, n_dims=n_dims)

    # go back to the orientation of the input image
    if n_dims > 2:
        seg = edit_volumes.align_volume_to_ref(seg, aff=np.eye(4), aff_ref=aff, n_dims=n_dims, return_aff=False)
        posteriors = edit_volumes.align_volume_to_ref(posteriors, aff=np.eye(4), aff_ref=aff, n_dims=n_dims)

    return seg, posteriors
Example #8
0
def preprocess_image(im_path,
                     n_levels,
                     target_res,
                     crop=None,
                     padding=None,
                     flip=False,
                     path_resample=None):
    """Load a volume and turn it into a network input.

    The volume is read with utils.get_volume_info, optionally resampled, aligned to the identity orientation,
    optionally padded, cropped so that every spatial size is a multiple of 2**n_levels, rescaled to [0, 1] and
    finally given extra axes with utils.add_axis.

    :param im_path: path of the image to load.
    :param n_levels: exponent of the divisibility constraint: spatial sizes must be multiples of 2**n_levels.
    :param target_res: resolution the image is resampled to when its own resolution differs from it by more
    than 0.05 along any axis. None disables resampling.
    :param crop: (optional) requested crop shape, reformatted to a list of n_dims ints and capped to the
    (padded) image shape. If None, a crop is only computed when the image shape violates the constraint.
    :param padding: (optional) forwarded as padding_shape to edit_volumes.pad_volume. Falsy values skip padding.
    :param flip: (optional) if True and n_dims > 2, also return a right/left flipped copy of the input.
    :param path_resample: (optional) path where the resampled image is saved, if resampling took place.
    :return: im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx, im_flipped.
    shape and pad_shape only hold spatial sizes. im_flipped is None when no flipped copy was requested.
    """

    # read image and corresponding info
    im, _, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(
        im_path, True)

    # resample image if necessary
    if target_res is not None:
        target_res = np.squeeze(
            utils.reformat_to_n_channels_array(target_res, n_dims))
        if np.any((im_res > target_res + 0.05) | (im_res < target_res - 0.05)):
            im_res = target_res
            im, aff = edit_volumes.resample_volume(im, aff, im_res)
            if path_resample is not None:
                utils.save_volume(im, aff, header, path_resample)

    # align image
    im = edit_volumes.align_volume_to_ref(im,
                                          aff,
                                          aff_ref=np.eye(4),
                                          n_dims=n_dims)
    # only keep the spatial sizes: multi-channel images carry a trailing channel axis, which must take part
    # neither in the comparison with the n_dims-long crop shape nor in the divisibility check below
    shape = list(im.shape[:n_dims])

    # pad image if specified
    if padding:
        im = edit_volumes.pad_volume(im, padding_shape=padding)
        pad_shape = im.shape[:n_dims]
    else:
        pad_shape = shape

    # check that patch_shape or im_shape are divisible by 2**n_levels
    multiple = 2 ** n_levels
    if crop is not None:
        crop = utils.reformat_to_list(crop, length=n_dims, dtype='int')
        # the crop cannot be larger than the (padded) image
        if not all([pad_shape[i] >= crop[i] for i in range(len(pad_shape))]):
            crop = [min(pad_shape[i], crop[i]) for i in range(n_dims)]
        if not all([size % multiple == 0 for size in crop]):
            crop = [utils.find_closest_number_divisible_by_m(size, multiple) for size in crop]
    else:
        if not all([size % multiple == 0 for size in pad_shape]):
            crop = [utils.find_closest_number_divisible_by_m(size, multiple) for size in pad_shape]

    # crop image if necessary
    if crop is not None:
        im, crop_idx = edit_volumes.crop_volume(im,
                                                cropping_shape=crop,
                                                return_crop_idx=True)
    else:
        crop_idx = None

    # normalise image (each channel separately)
    if n_channels == 1:
        im = edit_volumes.rescale_volume(im,
                                         new_min=0.,
                                         new_max=1.,
                                         min_percentile=0.5,
                                         max_percentile=99.5)
    else:
        for i in range(im.shape[-1]):
            im[..., i] = edit_volumes.rescale_volume(im[..., i],
                                                     new_min=0.,
                                                     new_max=1.,
                                                     min_percentile=0.5,
                                                     max_percentile=99.5)

    # flip image along right/left axis
    if flip and n_dims > 2:
        im_flipped = edit_volumes.flip_volume(im,
                                              direction='rl',
                                              aff=np.eye(4))
        if n_channels > 1:
            im_flipped = utils.add_axis(im_flipped)
        else:
            im_flipped = utils.add_axis(im_flipped, axis=[0, -1])
    else:
        im_flipped = None

    # multi-channel images only get one extra axis, single-channel images get a first and a last one
    if n_channels > 1:
        im = utils.add_axis(im)
    else:
        im = utils.add_axis(im, axis=[0, -1])

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx, im_flipped
Example #9
0
def postprocess(prediction, crop_shape, pad_shape, im_shape, crop, n_dims, labels, keep_biggest_component,
                aff, aff_ref='FS'):
    """Turn a network prediction into a label map and posteriors defined on the original image grid.

    :param prediction: network output. Singleton axes are squeezed, and the last remaining axis holds one
    posterior per entry of labels.
    :param crop_shape: shape the image was cropped to, or None if it was not cropped. It is only used to decide
    whether the patch must be pasted back.
    :param pad_shape: spatial shape of the (padded) image the predicted patch was cropped from.
    :param im_shape: spatial shape of the image before padding.
    :param crop: cropping indices (all lower bounds followed by all upper bounds) locating the patch in the
    padded image. Only read when crop_shape is not None.
    :param n_dims: number of spatial dimensions. Patches are only pasted back when n_dims is 2 or 3.
    :param labels: numpy array giving the label value of each posterior channel.
    :param keep_biggest_component: if True, the hard segmentation only keeps its largest connected component.
    :param aff: affine matrix the outputs are aligned back to (only used when n_dims > 2).
    :param aff_ref: (optional) orientation the prediction currently is in: 'FS', 'identity' or 'MS'. Any other
    value skips the alignment.
    :return: seg, posteriors. Both are cropped to im_shape when the image was padded.
    """

    # get posteriors and segmentation
    post_patch = np.squeeze(prediction)
    seg_patch = post_patch.argmax(-1)

    # keep biggest connected component (use it with smoothing!)
    if keep_biggest_component:
        components, n_components = label(seg_patch)
        if n_components > 1:
            sizes = np.bincount(components.ravel())
            sizes[0] = 0  # component 0 is the background and must never win
            # argmax returns the first maximum, i.e. the lowest component id in case of a tie
            seg_patch[components != sizes.argmax()] = 0

    # align prediction back to first orientation
    if n_dims > 2:
        if aff_ref == 'FS':
            current_aff = np.array([[-1., 0., 0., 0.], [0., 0., 1., 0.], [0., -1., 0., 0.], [0., 0., 0., 1.]])
        elif aff_ref == 'identity':
            current_aff = np.eye(4)
        elif aff_ref == 'MS':
            current_aff = np.array([[-1., 0., 0., 0.], [0., -1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]])
        else:
            current_aff = None
        if current_aff is not None:
            seg_patch = edit_volumes.align_volume_to_ref(seg_patch, current_aff, aff_ref=aff, return_aff=False)
            post_patch = edit_volumes.align_volume_to_ref(post_patch, current_aff, aff_ref=aff, n_dims=n_dims)

    # paste patches back to matrix of original image size
    if crop_shape is not None:
        seg = np.zeros(shape=pad_shape, dtype='int32')
        posteriors = np.zeros(shape=[*pad_shape, labels.shape[0]])
        posteriors[..., 0] = np.ones(pad_shape)  # place background around patch
        if n_dims == 2:
            seg[crop[0]:crop[2], crop[1]:crop[3]] = seg_patch
            posteriors[crop[0]:crop[2], crop[1]:crop[3], :] = post_patch
        elif n_dims == 3:
            seg[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5]] = seg_patch
            posteriors[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], :] = post_patch
    else:
        seg = seg_patch
        posteriors = post_patch
    seg = labels[seg.astype('int')].astype('int')

    # remove the padding. Shapes are compared as tuples so that a list and a tuple holding the same sizes are
    # seen as equal. The upper bound is lower bound + image size, which gives exactly im_shape even when the
    # padding is odd, and the posteriors are cropped with the segmentation so both outputs share the same grid.
    if tuple(im_shape) != tuple(pad_shape):
        lower_bound = [int((p - i) / 2) for (p, i) in zip(pad_shape, im_shape)]
        window = tuple(slice(lo, lo + size) for (lo, size) in zip(lower_bound, im_shape))
        seg = seg[window]
        posteriors = posteriors[window]  # trailing channel axis is left untouched

    return seg, posteriors
Example #10
0
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None, aff_ref='FS'):
    """Load a volume and turn it into a network input.

    The volume is read with utils.get_volume_info, optionally zero-padded, cropped around its centre so that
    every spatial size is a multiple of 2**n_levels, re-oriented (only when n_dims > 2), min-max rescaled and
    finally given extra axes with utils.add_axis.

    :param im_path: path of the image to load.
    :param n_levels: exponent of the divisibility constraint: spatial sizes must be multiples of 2**n_levels.
    :param crop_shape: (optional) requested crop shape, reformatted to a list of n_dims ints and capped to the
    (padded) image shape. If None, a crop is only computed when the image shape violates the constraint.
    :param padding: (optional) margin given to np.pad. For multi-channel images it is applied on both sides of
    every spatial axis, and the channel axis is left untouched. Falsy values skip padding.
    :param aff_ref: (optional) 'FS', 'identity' or 'MS', the orientation the image is aligned to. Any other
    value skips the alignment.
    :return: im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_shape, crop_idx
    """

    # load the volume together with its geometry
    im, shape, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(im_path, return_volume=True)

    # optional zero-padding (never along the channel axis)
    pad_shape = shape
    if padding:
        single_channel = n_channels == 1
        pad_width = padding if single_channel else tuple([(padding, padding)] * n_dims + [(0, 0)])
        im = np.pad(im, pad_width, mode='constant')
        pad_shape = im.shape if single_channel else im.shape[:-1]

    # derive a crop shape whose sizes are all multiples of 2**n_levels
    multiple = 2 ** n_levels
    if crop_shape is None:
        if any([size % multiple != 0 for size in pad_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, multiple) for size in pad_shape]
    else:
        crop_shape = utils.reformat_to_list(crop_shape, length=n_dims, dtype='int')
        # the crop cannot be larger than the (padded) image
        if any([pad_shape[d] < crop_shape[d] for d in range(len(pad_shape))]):
            crop_shape = [min(pad_shape[d], crop_shape[d]) for d in range(n_dims)]
            print('cropping dimensions are higher than image size, changing cropping size to {}'.format(crop_shape))
        if any([size % multiple != 0 for size in crop_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, multiple) for size in crop_shape]

    # centred crop: lower bounds followed by upper bounds
    crop_idx = None
    if crop_shape is not None:
        lower = np.round((pad_shape - np.array(crop_shape)) / 2).astype('int')
        crop_idx = np.concatenate((lower, lower + crop_shape), axis=0)
        im = edit_volumes.crop_volume_with_idx(im, crop_idx=crop_idx)

    # bring the image to the requested orientation
    if n_dims > 2:
        target_aff = None
        if aff_ref == 'FS':
            target_aff = np.array([[-1., 0., 0., 0.], [0., 0., 1., 0.], [0., -1., 0., 0.], [0., 0., 0., 1.]])
        elif aff_ref == 'identity':
            target_aff = np.eye(4)
        elif aff_ref == 'MS':
            target_aff = np.array([[-1., 0., 0., 0.], [0., -1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]])
        if target_aff is not None:
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=target_aff, return_aff=False)

    def rescale(volume):
        # min-max rescaling to [0, 1] (constant inputs become all zeros)
        lowest = np.min(volume)
        highest = np.max(volume)
        if highest == lowest:
            return np.zeros(volume.shape)
        return (volume - lowest) / (highest - lowest)

    # normalise the image, channel by channel
    if n_channels == 1:
        im = rescale(im)
    elif n_channels > 1:
        for c in range(im.shape[-1]):
            im[..., c] = rescale(im[..., c])

    # add batch and channel axes
    im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im, -2)

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_shape, crop_idx