Example #1
import os
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

# Project-level helpers assumed to be importable from the surrounding codebase:
# MetaSegmentationDataset, vis_from_batch, gpu_stats and the global config `params`.

def sem_seg_inference(model, nd_array, overlay, chunk_size, num_classes, device, meta_map=None, metadata=None, output_path=Path(os.getcwd()), index=0, debug=False):
    """Inference on images using semantic segmentation
    Args:
        model: model to use for inference
        nd_array: nd_array
        overlay: amount of overlay to apply
        num_classes: number of different classes that may be predicted by the model
        device: device used by pytorch (cpu ou cuda)
        meta_map:
        metadata:
        output_path: path to save debug files
        index: (int) index of array from list of images on which inference is performed

        returns a numpy array of the same size (h,w) as the input image, where each value is the predicted output.
    """

    # switch to evaluate mode
    model.eval()

    if len(nd_array.shape) == 3:
        h, w, nb = nd_array.shape
        # Pad with overlay on left and top and pad with chunk_size on right and bottom
        padded_array = np.pad(nd_array, ((overlay, chunk_size), (overlay, chunk_size), (0, 0)), mode='constant')
    elif len(nd_array.shape) == 2:
        h, w = nd_array.shape
        # Add a trailing channel dimension (axis=-1) so 2D arrays follow the same HWC layout as 3D ones.
        padded_array = np.expand_dims(np.pad(nd_array, ((overlay, chunk_size), (overlay, chunk_size)),
                                             mode='constant'), axis=-1)
    else:
        h = 0
        w = 0
        padded_array = None

    if padded_array is not None:
        h_padded, w_padded = padded_array.shape[:2]
        # Create an accumulator of dimensions (c x h x w): num_classes x height of padded array x width of padded array.
        # np.zeros rather than np.empty: predictions are summed into this array below, so it must start at 0.
        output_probs = np.zeros([num_classes, h_padded, w_padded], dtype=np.float32)
        # Identical 0-filled array without the channel dimension, counting how many predictions cover each pixel.
        output_counts = np.zeros([output_probs.shape[1], output_probs.shape[2]], dtype=np.int32)

        with torch.no_grad():
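            # Chunks advance by (chunk_size - overlay) pixels, so consecutive windows share
            # `overlay` pixels; the overlapping predictions are averaged via output_counts below.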
            for row in tqdm(range(overlay, h + chunk_size, chunk_size - overlay), position=1, leave=False,
                      desc=f'Inferring rows with "{device}"'):
                row_start = row - overlay
                row_end = row_start + chunk_size
                with tqdm(range(overlay, w + chunk_size, chunk_size - overlay), position=2, leave=False, desc='Inferring columns') as _tqdm:
                    for col in _tqdm:
                        col_start = col - overlay
                        col_end = col_start + chunk_size

                        chunk_input = padded_array[row_start:row_end, col_start:col_end, :]
                        if meta_map:
                            chunk_input = MetaSegmentationDataset.append_meta_layers(chunk_input, meta_map, metadata)
                        inputs = torch.from_numpy(np.float32(np.transpose(chunk_input, (2, 0, 1))))  # HWC -> CHW for PyTorch

                        inputs.unsqueeze_(0)  # Add dummy batch dimension

                        inputs = inputs.to(device)
                        # forward
                        outputs = model(inputs)

                        # torchvision models give output in 'out' key. May cause problems in future versions of torchvision.
                        if isinstance(outputs, OrderedDict) and 'out' in outputs:
                            outputs = outputs['out']

                        if debug:
                            if index == 0:
                                tqdm.write('(debug mode) Visualizing inferred tiles...')
                            vis_from_batch(params, inputs, outputs, batch_index=0, vis_path=output_path,
                                        dataset=f'{row_start}_{col_start}_inf', ep_num=index, debug=True)

                        outputs = F.softmax(outputs, dim=1)

                        output_counts[row_start:row_end, col_start:col_end] += 1

                        # Add inference on sub-image to all completed inferences on previous sub-images.
                        # FIXME: This operation needs to be optimized; it uses a lot of RAM on large images.
                        output_probs[:, row_start:row_end, col_start:col_end] += np.squeeze(outputs.cpu().numpy(),
                                                                                            axis=0)

                        if debug and device.type == 'cuda':
                            res, mem = gpu_stats(device=device.index)
                            _tqdm.set_postfix(OrderedDict(gpu_perc=f'{res.gpu} %',
                                                          gpu_RAM=f'{mem.used / (1024 ** 2):.0f}/{mem.total / (1024 ** 2):.0f} MiB',
                                                          inp_size=inputs.cpu().numpy().shape,
                                                          out_size=outputs.cpu().numpy().shape,
                                                          overlay=overlay))
            if debug:
                output_counts_PIL = Image.fromarray(output_counts.astype(np.uint8), mode='L')
                output_counts_PIL.save(output_path.joinpath('output_counts.png'))
                tqdm.write('Dividing array according to output counts...\n')

            # Divide accumulated probabilities by the output counts. This averages the overlapping
            # regions and yields a softmax array as if a single forward pass had covered each pixel.
            output_mask_raw = np.divide(output_probs, np.maximum(output_counts, 1))  # max with 1 avoids division by zero

            # Crop the output array back to the size of the input image (channels last) and return it
            output_mask_raw_cropped = np.moveaxis(output_mask_raw, 0, -1)
            output_mask_raw_cropped = output_mask_raw_cropped[overlay:(h + overlay), overlay:(w + overlay), :]

            return output_mask_raw_cropped
    else:
        raise IOError(f"Error classifying image : Image shape of {len(nd_array.shape)} is not recognized")
Example #2
# Imports as in Example #1. A `debug` parameter is added to the signature:
# the original referenced `debug` in the loop below without ever defining it.
def sem_seg_inference(model,
                      nd_array,
                      overlay,
                      chunk_size,
                      num_classes,
                      device,
                      meta_map=None,
                      metadata=None,
                      debug=False):
    """Inference on images using semantic segmentation
    Args:
        model: model to use for inference
        nd_array: nd_array
        overlay: amount of overlay to apply
        num_classes: number of different classes that may be predicted by the model
        device: device used by pytorch (cpu ou cuda)

        returns a numpy array of the same size (h,w) as the input image, where each value is the predicted output.
    """

    # switch to evaluate mode
    model.eval()

    if len(nd_array.shape) == 3:
        h, w, nb = nd_array.shape
        padded_array = np.pad(nd_array, ((overlay, chunk_size),
                                         (overlay, chunk_size), (0, 0)),
                              mode='constant')
    elif len(nd_array.shape) == 2:
        h, w = nd_array.shape
        # Add a trailing channel dimension (axis=-1) so 2D arrays follow the
        # same HWC layout as 3D ones.
        padded_array = np.expand_dims(np.pad(nd_array, ((overlay, chunk_size),
                                                        (overlay, chunk_size)),
                                             mode='constant'),
                                      axis=-1)
    else:
        h = 0
        w = 0
        padded_array = None

    # np.zeros rather than np.empty: predictions are summed into this array below.
    output_probs = np.zeros(
        [num_classes, h + overlay + chunk_size, w + overlay + chunk_size],
        dtype=np.float32)
    output_counts = np.zeros([output_probs.shape[1], output_probs.shape[2]],
                             dtype=np.int32)

    if padded_array is not None:
        with torch.no_grad():
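            # Note: these ranges stop at h and w (Example #1 runs them to h + chunk_size
            # and w + chunk_size), so the bottom/right `overlay` pixels of the image may
            # receive no prediction and will fall back to class 0 after the argmax.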
            with tqdm(range(overlay, h, chunk_size - overlay),
                      position=1,
                      leave=False) as _tqdm:
                for row in _tqdm:
                    row_start = row - overlay
                    row_end = row_start + chunk_size
                    for col in range(overlay, w, chunk_size - overlay):
                        col_start = col - overlay
                        col_end = col_start + chunk_size

                        chunk_input = padded_array[row_start:row_end,
                                                   col_start:col_end, :]
                        if meta_map:
                            chunk_input = MetaSegmentationDataset.append_meta_layers(
                                chunk_input, meta_map, metadata)
                        inputs = torch.from_numpy(
                            np.float32(np.transpose(chunk_input, (2, 0, 1))))

                        inputs.unsqueeze_(0)

                        inputs = inputs.to(device)
                        # forward
                        outputs = model(inputs)

                        # torchvision models give output in 'out' key. May cause problems in future versions of torchvision.
                        if isinstance(outputs,
                                      OrderedDict) and 'out' in outputs:
                            outputs = outputs['out']

                        output_counts[row_start:row_end,
                                      col_start:col_end] += 1
                        output_probs[:, row_start:row_end,
                                     col_start:col_end] += np.squeeze(
                                         outputs.cpu().numpy(), axis=0)

                    if debug and device.type == 'cuda':
                        res, mem = gpu_stats(device=device.index)
                        _tqdm.set_postfix(
                            OrderedDict(
                                device=device,
                                gpu_perc=f'{res.gpu} %',
                                gpu_RAM=
                                f'{mem.used / (1024 ** 2):.0f}/{mem.total / (1024 ** 2):.0f} MiB',
                                chunk_size=inputs.cpu().numpy().shape,
                                output_size=outputs.cpu().numpy().shape))

            output_mask = np.argmax(np.divide(output_probs,
                                              np.maximum(output_counts, 1)),
                                    axis=0)
            # Crop the output array back to the size of the input image and return it
            return output_mask[overlay:(h + overlay),
                               overlay:(w + overlay)].astype(np.uint8)
    else:
        raise IOError(
            f"Error classifying image: unrecognized number of dimensions ({len(nd_array.shape)})"
        )
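
Note the difference in return values: Example #1 returns the averaged float32 softmax probabilities with shape (h, w, num_classes), while Example #2 collapses them with argmax and returns a uint8 class mask of shape (h, w). A minimal sketch for writing that mask to disk, reusing model, image and num_classes from the sketch above and PIL's Image as in Example #1's debug output:

from PIL import Image

mask = sem_seg_inference(model, image, overlay=10, chunk_size=256,
                         num_classes=num_classes, device=torch.device('cpu'))
Image.fromarray(mask, mode='L').save('inference_mask.png')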