Exemplo n.º 1
0
def segment_vol_by_slice(segmenter_model,
                         X,
                         label_mapping,
                         batch_size=8,
                         Y_oh=None,
                         compute_cce=False):
    '''
    Segments a 3D volume by running a per-slice segmenter on batches of slices.

    The volume is indexed as X[0, :, :, z0:z1] and the result is transposed
    with a 4-axis permutation, so X must be 5D with the slice axis second to
    last (presumably 1 x h x w x n_slices x c -- confirm against callers).

    :param segmenter_model: model exposing predict(batch) and, when
        compute_cce is set, evaluate(batch, targets, verbose=False)
    :param X: 5D volume, we assume this has a batch size of 1 (only X[0] is read)
    :param label_mapping: forwarded to classification_utils.onehot_to_labels to
        convert the one-hot predictions into label values
    :param batch_size: number of slices passed to the model at a time
    :param Y_oh: target volume indexed exactly like X; required when
        compute_cce is True (assumed one-hot from the name -- confirm)
    :param compute_cce: if True, also evaluate the model on every batch and
        return the slice-averaged loss
    :return: preds, an array of shape X.shape[:-1] + (1,) holding the predicted
        labels; or the tuple (preds, mean loss per slice) when compute_cce is True
    :raises ValueError: if compute_cce is True but Y_oh was not provided
    '''
    if compute_cce and Y_oh is None:
        # fail early with a clear message instead of crashing on Y_oh[0, ...]
        # after a prediction has already been computed
        raise ValueError('Y_oh must be provided when compute_cce is True')

    n_slices = X.shape[-2]
    preds = np.zeros(X.shape[:-1] + (1, ))
    n_batches = int(np.ceil(float(n_slices) / batch_size))

    cce_total = 0.
    for sbi in range(n_batches):
        # slice bounds for this batch; the last batch may be shorter
        z_start = sbi * batch_size
        z_end = min(n_slices, z_start + batch_size)

        # slice in z, then make slices into batch
        X_batched_slices = np.transpose(X[0, :, :, z_start:z_end],
                                        (2, 0, 1, 3))

        preds_slices_oh = segmenter_model.predict(X_batched_slices)
        if compute_cce:
            Y_batched_slices = np.transpose(Y_oh[0, :, :, z_start:z_end],
                                            (2, 0, 1, 3))
            slice_cce = segmenter_model.evaluate(X_batched_slices,
                                                 Y_batched_slices,
                                                 verbose=False)
            # if we have multiple losses, take the first one
            if isinstance(slice_cce, list):
                slice_cce = slice_cce[0]

            # we want an average over slices, so weight each batch loss by the
            # number of slices it actually contained
            cce_total += slice_cce * X_batched_slices.shape[0]

        # convert onehot to labels, move the slice axis back to the end and
        # assign to preds volume
        preds_slices = classification_utils.onehot_to_labels(
            preds_slices_oh, label_mapping=label_mapping)
        preds[0, :, :, z_start:z_end] = np.transpose(
            preds_slices, (1, 2, 0))[..., np.newaxis]

    if compute_cce:
        return preds, cce_total / float(n_slices)
    return preds
Exemplo n.º 2
0
    def _make_results_im(self, input_im_batches, labels,
                         overlay_on_ims=None,
                         do_normalize=None, is_seg=None,
                         max_batch_size=32):
        """Build one visualization image from several batches of images.

        Each non-None batch becomes a panel; all panels are concatenated
        along axis 1. A regular batch is rendered with vis_utils.label_ims.
        A segmentation batch (is_seg[i] true) is rendered as two panels side
        by side: the predicted labels drawn as contours over
        overlay_on_ims[i], and the single channel show_label_idx of the batch.

        :param input_im_batches: list of batches; None entries are skipped
        :param labels: per-batch labels, labels[i] is forwarded to
            vis_utils.label_ims for non-segmentation batches
        :param overlay_on_ims: per-batch images to draw segmentations on;
            entry i is required whenever is_seg[i] is true
        :param do_normalize: per-batch flags forwarded as inverse_normalize
            (defaults to all False)
        :param is_seg: per-batch flags marking segmentation batches, assumed
            to be one-hot since they go through onehot_to_labels
            (defaults to all False)
        :param max_batch_size: maximum number of examples shown per batch
        :return: the concatenated image array
        :raises ValueError: if a segmentation batch has no overlay image
        """
        batch_size = self.batch_size
        display_batch_size = min(max_batch_size, batch_size)

        if do_normalize is None:
            do_normalize = [False] * len(input_im_batches)
        if is_seg is None:
            is_seg = [False] * len(input_im_batches)

        if display_batch_size < batch_size:
            # only show the first display_batch_size examples; None batches
            # are skipped further down, so they must survive the truncation
            input_im_batches = [
                batch[:display_batch_size] if batch is not None else None
                for batch in input_im_batches]
            if overlay_on_ims is not None:
                overlay_on_ims = [
                    im[:display_batch_size] if im is not None else None
                    for im in overlay_on_ims]

        show_label_idx = 12  # cerebral wm

        panels = []
        for i, batch in enumerate(input_im_batches):
            if batch is None:
                continue

            if not is_seg[i]:
                panels.append(vis_utils.label_ims(
                    batch, labels[i], inverse_normalize=do_normalize[i]))
                continue

            if overlay_on_ims is None or overlay_on_ims[i] is None:
                raise ValueError(
                    'overlay_on_ims[{}] is required to display a '
                    'segmentation batch'.format(i))

            # overlay_segs_on_ims_batch is fed batch-last arrays, so move the
            # batch axis to the end going in and back to the front coming out
            seg_labels = np.transpose(
                classification_utils.onehot_to_labels(
                    batch, label_mapping=self.label_mapping), (1, 2, 0))
            overlaid = utils.overlay_segs_on_ims_batch(
                ims=np.transpose(overlay_on_ims[i], (1, 2, 3, 0)),
                segs=seg_labels,
                include_labels=self.label_mapping,
                draw_contours=True,
            )
            overlay_panel = vis_utils.label_ims(
                np.transpose(overlaid, (3, 0, 1, 2)), [])

            # also show one channel of the segmentation on its own
            single_label_panel = vis_utils.label_ims(
                batch[..., [show_label_idx]],
                'label {}'.format(self.label_mapping[show_label_idx]),
                normalize=True)

            panels.append(
                np.concatenate([overlay_panel, single_label_panel], axis=1))

        out_im = np.concatenate(panels, axis=1)

        return out_im
Exemplo n.º 3
0
def _format_label_text(label):
    '''Turns one label (tuple/list, float, array or anything else) into display text.'''
    if isinstance(label, (tuple, list)):
        parts = []
        for item in label:
            if isinstance(item, bytes):
                parts.append(item.decode('UTF-8'))
            elif isinstance(item, str):
                parts.append(item)
            elif isinstance(item, (float, np.floating)):
                parts.append(str(round(item, 2)))  # round floats to 2 digits
            else:
                parts.append(str(item))
        return ', '.join(parts)
    if isinstance(label, (float, np.floating)):
        return str(round(label, 2))  # round floats to 2 digits
    if isinstance(label, np.ndarray):
        # assume that this is a 1D numeric array
        return np.array2string(np.squeeze(label).astype(np.float32),
                               precision=2,
                               separator=',')
    return '{}'.format(label)


def _load_label_font(font_size):
    '''Loads the first available label font, falling back to PIL's built-in font.'''
    for font_name in ('Ubuntu-M.ttf', 'arial.ttf'):
        try:
            return ImageFont.truetype(font_name, font_size)
        except (OSError, ImportError):
            continue
    return ImageFont.load_default()


def label_ims(ims_batch,
              labels=None,
              inverse_normalize=False,
              normalize=False,
              clip_flow=10,
              display_h=128,
              pad_top=None,
              clip_norm=None,
              padding_size=0,
              padding_color=255,
              border_size=0,
              border_color=0,
              color_space='rgb',
              combine_from_axis=0,
              concat_axis=0,
              interp=cv2.INTER_LINEAR):
    '''
    Displays a batch of matrices as an image.

    The batch is converted to displayable tiles (2-channel input is treated
    as x,y flow, more than 3 channels as one-hot labels), each tile is resized
    to display_h, optionally padded/bordered, the tiles are concatenated and
    the labels are drawn on top of the result.

    :param ims_batch: n_batches x h x w x c array of images (or a list of
        equally sized images). A 3D array whose last dimension is 3 is assumed
        to already be an image and is returned unchanged.
    :param labels: optional labels. A list or array gives one label per image
        (a one-element list labels only the first image); any other value is
        used as the label of every image. Entries can be tuples, floats,
        arrays or strings. The caller's object is never modified.
    :param inverse_normalize: boolean to pass the batch through
        image_utils.inverse_normalize (originally documented as [-1, 1] to [0, 255])
    :param normalize: boolean to normalize each image's [min, max] to [0, 1]
        for display; the per-channel (min, max) is appended to the labels
    :param clip_flow: float for the min, max absolute flow magnitude to
        display (forwarded to flow_to_im)
    :param display_h: integer number of pixels for the height of each image to display
    :param pad_top: integer number of pixels to pad each image at the top with (for more readable labels)
    :param clip_norm: if set, normalize maps [-clip_norm, clip_norm] to [0, 1]
        instead of using each image's own range
    :param padding_size: pixels of padding_color inserted between consecutive
        images (ignored when concat_axis is None)
    :param padding_color: value used to fill the padding
    :param border_size: pixels of constant border added around every image
    :param border_color: value used to fill the border
    :param color_space: string of either 'rgb' or 'ycbcr' to do color space conversion before displaying
    :param combine_from_axis: axis of ims_batch that indexes the batch
    :param concat_axis: integer axis number to concatenate batch along (default
        is 0 for rows). None returns a list with one image per batch entry.
    :param interp: cv2 interpolation flag used when resizing

    :return: a uint8 image, or a list of uint8 images if concat_axis is None
    '''
    text_margin = 5
    text_line_height = 14
    text_color = (50, 50, 255)

    if isinstance(ims_batch, np.ndarray) and ims_batch.ndim == 3 \
            and ims_batch.shape[-1] == 3:
        # already an image
        return ims_batch

    # lists of images are accepted too; everything below relies on .shape
    if not isinstance(ims_batch, np.ndarray):
        ims_batch = np.asarray(ims_batch)

    # move the batch axis to the front
    if combine_from_axis != 0:
        ims_batch = np.moveaxis(ims_batch, combine_from_axis, 0)

    batch_size = len(ims_batch)
    h = ims_batch[0].shape[0]
    if ims_batch[0].ndim == 2:
        n_chans = 1
    else:
        n_chans = ims_batch[0].shape[-1]

    # normalize labels into a private per-image list, so that appending value
    # ranges below never modifies the caller's list or array
    if labels is None:
        pass
    elif isinstance(labels, np.ndarray):
        labels = list(labels)
    elif isinstance(labels, list):
        labels = list(labels)
        if len(labels) == 1:  # only label the first image
            labels = labels + [''] * (batch_size - 1)
    else:
        labels = [labels] * batch_size

    scale_factor = display_h / float(h)

    # make sure we have a channels dimension
    if ims_batch.ndim < 4:
        ims_batch = np.expand_dims(ims_batch, 3)

    if ims_batch.shape[-1] == 2:  # assume to be x,y flow; map to color im
        flow_ims = np.concatenate(
            [ims_batch.copy(),
             np.zeros(ims_batch.shape[:-1] + (1, ))], axis=3)

        # every image gets a label here, since the flow range is always shown
        if labels is None:
            labels = [None] * batch_size
        else:
            labels = labels + [None] * (batch_size - len(labels))

        for i in range(batch_size):
            flow_ims[i], min_flow, max_flow = flow_to_im(
                ims_batch[i], clip_flow=clip_flow)

            # also include the min and max flow in the label
            if labels[i] is not None:
                labels[i] = '{},'.format(labels[i])
            else:
                labels[i] = ''

            for lo, hi in zip(min_flow, max_flow):
                labels[i] += '({}, {})'.format(round(lo, 1), round(hi, 1))
        ims_batch = flow_ims
    elif ims_batch.shape[-1] > 3:
        # not an image, probably labels
        n_labels = ims_batch.shape[-1]
        cmap = make_cmap_rainbow(n_labels)

        labels_im = classification_utils.onehot_to_labels(
            ims_batch, n_classes=n_labels)
        labels_im_flat = labels_im.flatten()
        labeled_im_flat = np.tile(labels_im_flat[..., np.newaxis],
                                  (1, 3)).astype(np.float32)

        # paint every pixel with the color of its label
        for label_idx in range(n_labels):
            labeled_im_flat[labels_im_flat == label_idx, :] = cmap[label_idx]
        ims_batch = labeled_im_flat.reshape((-1, ) + ims_batch.shape[1:-1] +
                                            (3, ))

    elif inverse_normalize:
        ims_batch = image_utils.inverse_normalize(ims_batch)

    elif normalize:
        # per-image, per-channel range of the raw data (reported in the labels)
        X_spatially_flat = np.reshape(ims_batch, (batch_size, -1, n_chans))
        X_orig_min = np.min(X_spatially_flat, axis=1)
        X_orig_max = np.max(X_spatially_flat, axis=1)

        # now actually flatten and normalize across channels
        X_flat = np.reshape(ims_batch, (batch_size, -1))
        if clip_norm is None:
            X_flat = X_flat - np.min(X_flat, axis=1, keepdims=True)
            # avoid dividing by 0
            X_flat = X_flat / np.clip(np.max(X_flat, axis=1, keepdims=True),
                                      1e-5, None)
        else:
            # map the fixed range [-clip_norm, clip_norm] to [0, 1]
            X_flat = X_flat + float(clip_norm)
            X_flat = X_flat / (2. * clip_norm)

        ims_batch = np.reshape(X_flat, ims_batch.shape)
        ims_batch = np.clip(ims_batch.astype(np.float32), 0., 1.)

        if labels is not None and len(labels) > 0:
            for i in range(min(batch_size, len(labels))):
                if labels[i] is not None:
                    labels[i] = '{},'.format(labels[i])
                else:
                    labels[i] = ''
                # show the min, max of each channel
                for c in range(n_chans):
                    labels[i] += '({:.2f}, {:.2f})'.format(
                        round(X_orig_min[i, c], 2), round(X_orig_max[i, c], 2))
    else:
        ims_batch = np.clip(ims_batch, 0., 1.)

    if color_space == 'ycbcr':
        for i in range(batch_size):
            ims_batch[i] = cv2.cvtColor(ims_batch[i], cv2.COLOR_YCR_CB2BGR)

    # data that fits in [0, 1] is assumed to need scaling to [0, 255]
    if np.max(ims_batch) <= 1.0:
        ims_batch = ims_batch * 255.0

    # with concat_axis=None the tiles are stacked along rows and split again
    # at the end, so that all labels can be drawn in a single pass
    stack_axis = 0 if concat_axis is None else concat_axis

    tiles = []
    tile_offsets = []  # start of each tile along stack_axis, used to place labels
    next_offset = 0
    for i in range(batch_size):
        # convert grayscale to rgb if needed
        if ims_batch.shape[-1] == 1:
            curr_im = np.tile(ims_batch[i], (1, 1, 3))
        else:
            curr_im = ims_batch[i]

        # scale to specified display size
        if scale_factor != 1:
            curr_im = cv2.resize(curr_im,
                                 None,
                                 fx=scale_factor,
                                 fy=scale_factor,
                                 interpolation=interp)

        if pad_top:
            top_pad = np.zeros((pad_top, curr_im.shape[1], curr_im.shape[2]))
            curr_im = np.concatenate([top_pad, curr_im], axis=0)

        if border_size > 0:
            # add a border all around the image
            curr_im = cv2.copyMakeBorder(curr_im,
                                         border_size,
                                         border_size,
                                         border_size,
                                         border_size,
                                         borderType=cv2.BORDER_CONSTANT,
                                         value=border_color)

        # padding would make the tiles unequal, which the final split cannot
        # handle, so it is only applied when the tiles stay concatenated
        if padding_size > 0 and concat_axis is not None \
                and i < batch_size - 1:
            # include a border between images
            padding_shape = list(curr_im.shape[:3])
            padding_shape[concat_axis] = padding_size

            curr_im = np.concatenate(
                [curr_im, np.ones(padding_shape) * padding_color],
                axis=concat_axis)

        tile_offsets.append(next_offset)
        next_offset += curr_im.shape[stack_axis]
        tiles.append(curr_im)

    out_im = np.concatenate(tiles, axis=stack_axis).astype(np.uint8)

    if labels is not None and len(labels) > 0:
        im_pil = Image.fromarray(out_im)

        # only print labels if we have room, and only for row/column layouts
        if display_h > 30 and stack_axis in (0, 1):
            draw = ImageDraw.Draw(im_pil)
            font = _load_label_font(15 if display_h > 50 else 10)
            max_text_width = int(17 * display_h / 128.)  # empirically determined

            for i, label in enumerate(labels[:batch_size]):
                # wrap the text so it fits
                text_lines = textwrap.wrap(_format_label_text(label),
                                           width=max_text_width)
                for li, line in enumerate(text_lines):
                    line_y = text_margin + text_line_height * li
                    if stack_axis == 0:
                        origin = (text_margin, tile_offsets[i] + line_y)
                    else:
                        origin = (text_margin + tile_offsets[i], line_y)
                    draw.text(origin, line, font=font, fill=text_color)

        out_im = np.asarray(im_pil)

    if concat_axis is None:
        # un-concat the image along the axis the tiles were stacked on.
        # faster this way than labeling every tile separately
        out_im = np.split(out_im, batch_size, axis=0)
    return out_im