Code Example #1
# NOTE: these snippets assume module-level imports from their source projects
# (numpy as np, torch, h5py, tqdm, a configured `logger`, and a project `utils` module).
def predict(model, dataset, out_channels, device):
    """
    Return prediction masks by applying the model on the given dataset

    Args:
        model (Unet3D): trained 3D UNet model used for prediction
        dataset (torch.utils.data.Dataset): input dataset
        out_channels (int): number of channels in the network output
        device (torch.device): device to run the prediction on

    Returns:
         probability_maps (numpy array): prediction masks for given dataset
    """
    logger.info(f'Running prediction on {len(dataset)} patches...')
    # dimensionality of the output (CxDxHxW)
    dataset_shape = dataset.raw.shape
    if len(dataset_shape) == 3:
        volume_shape = dataset_shape
    else:
        volume_shape = dataset_shape[1:]
    probability_maps_shape = (out_channels, ) + volume_shape
    logger.info(
        f'Shape of the output probability map: {probability_maps_shape}')
    # initialize the output prediction array
    probability_maps = np.zeros(probability_maps_shape, dtype='float32')

    # initialize normalization mask in order to average out probabilities
    # of overlapping patches
    normalization_mask = np.zeros(probability_maps_shape, dtype='float32')

    # Sets the module in evaluation mode explicitly, otherwise the final Softmax/Sigmoid won't be applied!
    model.eval()
    with torch.no_grad():
        for patch, index in dataset:
            logger.info(f'Predicting slice:{index}')

            # save patch index: (C,) + (D,H,W)
            channel_slice = slice(0, out_channels)
            index = (channel_slice, ) + index

            # convert patch to torch tensor NxCxDxHxW and send to device
            # we're using batch size of 1
            patch = patch.view((1, ) + patch.shape).to(device)

            # forward pass
            probs = model(patch)
            # convert back to numpy array
            probs = probs.squeeze().cpu().numpy()
            # for out_channel == 1 we need to expand back to 4D
            if probs.ndim == 3:
                probs = np.expand_dims(probs, axis=0)
            # unpad in order to avoid block artifacts in the output probability maps
            probs, index = utils.unpad(probs, index, volume_shape)
            # accumulate probabilities into the output prediction array
            probability_maps[index] += probs
            # count voxel visits for normalization
            normalization_mask[index] += 1

    return probability_maps / normalization_mask
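
Every variant in this listing delegates halo cropping to a project-level utils.unpad helper that is not shown. For orientation only, here is a minimal sketch of what such a helper might look like; the fixed pad_width default, the slice-tuple layout of index, and the exact border handling are assumptions, not taken from the source.

def unpad(probs, index, volume_shape, pad_width=4):
    """Sketch: crop the halo from a predicted patch and shrink its destination
    slices accordingly, keeping the halo only where the patch touches the
    volume border. Assumes index == (channel_slice, z_slice, y_slice, x_slice)."""
    patch_slices = [slice(None)]   # keep all channels of the patch
    dest_slices = [index[0]]       # the destination channel slice is unchanged
    for axis, dest in enumerate(index[1:]):
        # crop pad_width voxels per side, except at the volume border
        lo = 0 if dest.start == 0 else pad_width
        hi = 0 if dest.stop == volume_shape[axis] else pad_width
        patch_slices.append(slice(lo, (dest.stop - dest.start) - hi))
        dest_slices.append(slice(dest.start + lo, dest.stop - hi))
    return probs[tuple(patch_slices)], tuple(dest_slices)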
Code Example #2
    def predict(self):
        out_channels = self.config['model'].get('out_channels')
        if out_channels is None:
            out_channels = self.config['model']['dt_out_channels']

        prediction_channel = self.config.get('prediction_channel', None)
        if prediction_channel is not None:
            self.logger.info(
                f"Using only channel '{prediction_channel}' from the network output"
            )

        device = self.config['device']
        output_heads = self.config['model'].get('output_heads', 1)

        self.logger.info(
            f'Running prediction on {len(self.loader)} batches...')

        # dimensionality of the output predictions
        volume_shape = self._volume_shape(self.loader.dataset)
        if prediction_channel is None:
            prediction_maps_shape = (out_channels, ) + volume_shape
        else:
            # single channel prediction map
            prediction_maps_shape = (1, ) + volume_shape

        self.logger.info(
            f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}'
        )

        avoid_block_artifacts = self.predictor_config.get(
            'avoid_block_artifacts', True)
        self.logger.info(f'Avoid block artifacts: {avoid_block_artifacts}')

        # create destination H5 file
        h5_output_file = h5py.File(self.output_file, 'w')

        # allocate prediction and normalization arrays
        self.logger.info('Allocating prediction and normalization arrays...')
        prediction_maps, normalization_masks = self._allocate_prediction_maps(
            prediction_maps_shape, output_heads, h5_output_file)

        # Sets the module in evaluation mode explicitly, otherwise the final Softmax/Sigmoid won't be applied!
        self.model.eval()
        # Run predictions on the entire input dataset
        with torch.no_grad():
            for batch, indices in self.loader:
                # send batch to device
                batch = batch.to(device)

                # forward pass
                predictions = self.model(batch)

                # wrap predictions into a list if there is only one output head from the network
                if output_heads == 1:
                    predictions = [predictions]

                # for each output head
                for prediction, prediction_map, normalization_mask in zip(
                        predictions, prediction_maps, normalization_masks):

                    # convert to numpy array
                    prediction = prediction.cpu().numpy()

                    # for each batch sample
                    for pred, index in zip(prediction, indices):
                        # save patch index: (C,D,H,W)
                        if prediction_channel is None:
                            channel_slice = slice(0, out_channels)
                        else:
                            channel_slice = slice(0, 1)
                        index = (channel_slice, ) + index

                        if prediction_channel is not None:
                            # use only the 'prediction_channel'
                            self.logger.info(
                                f"Using channel '{prediction_channel}'...")
                            pred = np.expand_dims(pred[prediction_channel],
                                                  axis=0)

                        self.logger.info(
                            f'Saving predictions for slice:{index}...')

                        if avoid_block_artifacts:
                            # unpad in order to avoid block artifacts in the output probability maps
                            u_prediction, u_index = unpad(
                                pred, index, volume_shape)
                            # accumulate probabilities into the output prediction array
                            prediction_map[u_index] += u_prediction
                            # count voxel visits for normalization
                            normalization_mask[u_index] += 1
                        else:
                            # accumulate probabilities into the output prediction array
                            prediction_map[index] += pred
                            # count voxel visits for normalization
                            normalization_mask[index] += 1

        # save the results to the output H5 file
        self._save_results(prediction_maps, normalization_masks, output_heads,
                           h5_output_file, self.loader.dataset)

        # close the output H5 file
        h5_output_file.close()
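
_allocate_prediction_maps and _save_results are private helpers of the predictor class and are not shown above. A purely hypothetical sketch of what they might do, assuming in-memory accumulation and one H5 dataset per output head (the dataset names here are illustrative, not from the source):

import numpy as np

class _PredictorSketch:
    """Hypothetical stand-in for the predictor class the method above belongs to."""

    def _allocate_prediction_maps(self, output_shape, output_heads, output_file):
        # one prediction map and one normalization mask per output head
        maps = [np.zeros(output_shape, dtype='float32') for _ in range(output_heads)]
        masks = [np.zeros(output_shape, dtype='float32') for _ in range(output_heads)]
        return maps, masks

    def _save_results(self, prediction_maps, normalization_masks, output_heads,
                      output_file, dataset):
        # average overlapping patches, then persist one dataset per output head
        for i, (pmap, mask) in enumerate(zip(prediction_maps, normalization_masks)):
            output_file.create_dataset(f'predictions{i}', data=pmap / mask,
                                       compression='gzip')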
Code Example #3
File: predict.py  Project: wjx2/pytorch-3dunet
def predict(model, hdf5_dataset, config):
    """
    Return prediction masks by applying the model on the given dataset

    Args:
        model (Unet3D): trained 3D UNet model used for prediction
        hdf5_dataset (torch.utils.data.Dataset): input dataset
        config (dict): global config dict with 'model' and 'device' entries

    Returns:
         prediction_maps (list of numpy arrays): prediction masks for the given dataset, one per output head
    """

    def _volume_shape(hdf5_dataset):
        #TODO: support multiple internal datasets
        raw = hdf5_dataset.raws[0]
        if raw.ndim == 3:
            return raw.shape
        else:
            return raw.shape[1:]

    out_channels = config['model'].get('out_channels')
    if out_channels is None:
        out_channels = config['model']['dt_out_channels']

    device = config['device']
    output_heads = config['model'].get('output_heads', 1)

    logger.info(f'Running prediction on {len(hdf5_dataset)} patches...')
    # dimensionality of the output (CxDxHxW)
    volume_shape = _volume_shape(hdf5_dataset)
    prediction_maps_shape = (out_channels,) + volume_shape
    logger.info(f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}')

    # initialize the output prediction arrays
    prediction_maps = [np.zeros(prediction_maps_shape, dtype='float32') for _ in range(output_heads)]
    # initialize normalization mask in order to average out probabilities of overlapping patches
    normalization_masks = [np.zeros(prediction_maps_shape, dtype='float32') for _ in range(output_heads)]

    # Sets the module in evaluation mode explicitly, otherwise the final Softmax/Sigmoid won't be applied!
    model.eval()
    # Run predictions on the entire input dataset
    with torch.no_grad():
        for patch, index in hdf5_dataset:
            logger.info(f'Predicting slice:{index}')

            # save patch index: (C,D,H,W)
            channel_slice = slice(0, out_channels)
            index = (channel_slice,) + index

            # convert patch to a NxCxDxHxW torch tensor and send to device; we're using a batch size of 1
            patch = patch.unsqueeze(dim=0).to(device)

            # forward pass
            predictions = model(patch)
            # wrap predictions into a list if there is only one output head from the network
            if output_heads == 1:
                predictions = [predictions]

            for prediction, prediction_map, normalization_mask in zip(predictions, prediction_maps,
                                                                      normalization_masks):
                # squeeze batch dimension and convert back to numpy array
                prediction = prediction.squeeze(dim=0).cpu().numpy()
                # unpad in order to avoid block artifacts in the output probability maps
                u_prediction, u_index = utils.unpad(prediction, index, volume_shape)
                # accumulate probabilities into the output prediction array
                prediction_map[u_index] += u_prediction
                # count voxel visits for normalization
                normalization_mask[u_index] += 1

    return [prediction_map / normalization_mask for prediction_map, normalization_mask in
            zip(prediction_maps, normalization_masks)]
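
The config-driven variants (#2 through #6) read only a handful of keys from the config dict. A minimal usage sketch, with illustrative values that are not taken from any of the source projects:

import torch

# `model` and `hdf5_dataset` are assumed to exist already
config = {
    'model': {
        'out_channels': 2,   # falls back to config['model']['dt_out_channels'] if absent
        'output_heads': 1,   # optional, defaults to 1
    },
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
}
prediction_maps = predict(model, hdf5_dataset, config)  # one map per output head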
Code Example #4
def predict(model, data_loader, output_file, config):
    """
    Apply the model to the given dataset and save the prediction masks to the output H5 file

    Args:
        model (Unet3D): trained 3D UNet model used for prediction
        data_loader (torch.utils.data.DataLoader): input data loader
        output_file (str): path to the output H5 file
        config (dict): global config dict

    """
    def _volume_shape(dataset):
        # TODO: support multiple internal datasets
        raw = dataset.raws[0]
        if raw.ndim == 3:
            return raw.shape
        else:
            return raw.shape[1:]

    out_channels = config['model'].get('out_channels')
    if out_channels is None:
        out_channels = config['model']['dt_out_channels']

    prediction_channel = config.get('prediction_channel', None)
    if prediction_channel is not None:
        logger.info(
            f"Using only channel '{prediction_channel}' from the network output"
        )

    device = config['device']
    output_heads = config['model'].get('output_heads', 1)

    logger.info(f'Running prediction on {len(data_loader)} patches...')

    # dimensionality of the output (CxDxHxW)
    volume_shape = _volume_shape(data_loader.dataset)
    if prediction_channel is None:
        prediction_maps_shape = (out_channels, ) + volume_shape
    else:
        # single channel prediction map
        prediction_maps_shape = (1, ) + volume_shape

    logger.info(
        f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}'
    )

    with h5py.File(output_file, 'w') as f:
        # allocate datasets for probability maps
        prediction_datasets = _get_dataset_names(config,
                                                 output_heads,
                                                 prefix='predictions')
        prediction_maps = [
            f.create_dataset(dataset_name,
                             shape=prediction_maps_shape,
                             dtype='float32',
                             chunks=True,
                             compression='gzip')
            for dataset_name in prediction_datasets
        ]

        # allocate datasets for normalization masks
        normalization_datasets = _get_dataset_names(config,
                                                    output_heads,
                                                    prefix='normalization')
        normalization_masks = [
            f.create_dataset(dataset_name,
                             shape=prediction_maps_shape,
                             dtype='uint8',
                             chunks=True,
                             compression='gzip')
            for dataset_name in normalization_datasets
        ]

        # Sets the module in evaluation mode explicitly, otherwise the final Softmax/Sigmoid won't be applied!
        model.eval()
        # Run predictions on the entire input dataset
        with torch.no_grad():
            for patch, index in data_loader:
                logger.info(f'Predicting slice:{index}')

                # save patch index: (C,D,H,W)
                if prediction_channel is None:
                    channel_slice = slice(0, out_channels)
                else:
                    channel_slice = slice(0, 1)

                index = (channel_slice, ) + tuple(index)

                # send patch to device
                patch = patch.to(device)
                # forward pass
                predictions = model(patch)

                # wrap predictions into a list if there is only one output head from the network
                if output_heads == 1:
                    predictions = [predictions]

                for prediction, prediction_map, normalization_mask in zip(
                        predictions, prediction_maps, normalization_masks):
                    # squeeze batch dimension and convert back to numpy array
                    prediction = prediction.squeeze(dim=0).cpu().numpy()
                    if prediction_channel is not None:
                        # use only the 'prediction_channel'
                        logger.info(f"Using channel '{prediction_channel}'...")
                        prediction = np.expand_dims(
                            prediction[prediction_channel], axis=0)

                    # unpad in order to avoid block artifacts in the output probability maps
                    u_prediction, u_index = utils.unpad(
                        prediction, index, volume_shape)
                    # accumulate probabilities into the output prediction array
                    prediction_map[u_index] += u_prediction
                    # count voxel visits for normalization
                    normalization_mask[u_index] += 1

        # normalize the prediction_maps inside the H5
        for prediction_map, normalization_mask, prediction_dataset, normalization_dataset in zip(
                prediction_maps, normalization_masks, prediction_datasets,
                normalization_datasets):
            # TODO: iterate block by block
            # split the volume into 4 parts and load each into the memory separately
            logger.info(f'Normalizing {prediction_dataset}...')
            z, y, x = volume_shape
            mid_x = x // 2
            mid_y = y // 2
            prediction_map[:, :, 0:mid_y, 0:mid_x] /= normalization_mask[:, :, 0:mid_y, 0:mid_x]
            prediction_map[:, :, mid_y:, 0:mid_x] /= normalization_mask[:, :, mid_y:, 0:mid_x]
            prediction_map[:, :, 0:mid_y, mid_x:] /= normalization_mask[:, :, 0:mid_y, mid_x:]
            prediction_map[:, :, mid_y:, mid_x:] /= normalization_mask[:, :, mid_y:, mid_x:]
            logger.info(f'Deleting {normalization_dataset}...')
            del f[normalization_dataset]
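
The quadrant-wise division above bounds peak memory, but the TODO hints at a more general scheme. One way to realize it is to normalize in fixed-size z-chunks; this is a sketch under the same CDHW layout assumptions, not the project's actual implementation:

def normalize_blockwise(prediction_map, normalization_mask, block_size=64):
    # works for both numpy arrays and h5py datasets: each z-chunk is read,
    # divided by its visit counts, and written back, one chunk at a time
    depth = prediction_map.shape[1]
    for z in range(0, depth, block_size):
        sel = (slice(None), slice(z, min(z + block_size, depth)))
        prediction_map[sel] /= normalization_mask[sel]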
Code Example #5
def predict_in_memory(model, data_loader, output_file, config):
    """
    Apply the model to the given dataset and save the prediction masks to the output H5 file

    Args:
        model (Unet3D): trained 3D UNet model used for prediction
        data_loader (torch.utils.data.DataLoader): input data loader
        output_file (str): path to the output H5 file
        config (dict): global config dict

    """
    def _volume_shape(dataset):
        # TODO: support multiple internal datasets
        raw = dataset.raws[0]
        if raw.ndim == 3:
            return raw.shape
        else:
            return raw.shape[1:]

    out_channels = config['model'].get('out_channels')
    if out_channels is None:
        out_channels = config['model']['dt_out_channels']

    prediction_channel = config.get('prediction_channel', None)
    if prediction_channel is not None:
        logger.info(
            f"Using only channel '{prediction_channel}' from the network output"
        )

    device = config['device']
    output_heads = config['model'].get('output_heads', 1)

    logger.info(f'Running prediction on {len(data_loader)} patches...')
    # dimensionality of the output (CxDxHxW)
    volume_shape = _volume_shape(data_loader.dataset)
    if prediction_channel is None:
        prediction_maps_shape = (out_channels, ) + volume_shape
    else:
        # single channel prediction map
        prediction_maps_shape = (1, ) + volume_shape

    logger.info(
        f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}'
    )

    # initialize the output prediction arrays
    prediction_maps = [
        np.zeros(prediction_maps_shape, dtype='float32')
        for _ in range(output_heads)
    ]
    # initialize normalization mask in order to average out probabilities of overlapping patches
    normalization_masks = [
        np.zeros(prediction_maps_shape, dtype='float32')
        for _ in range(output_heads)
    ]

    # Sets the module in evaluation mode explicitly, otherwise the final Softmax/Sigmoid won't be applied!
    model.eval()
    # Run predictions on the entire input dataset
    with torch.no_grad():
        for patch, index in data_loader:
            logger.info(f'Predicting slice:{index}')

            # save patch index: (C,D,H,W)
            if prediction_channel is None:
                channel_slice = slice(0, out_channels)
            else:
                channel_slice = slice(0, 1)

            index = (channel_slice, ) + tuple(index)

            # send patch to device
            patch = patch.to(device)
            # forward pass
            predictions = model(patch)

            # wrap predictions into a list if there is only one output head from the network
            if output_heads == 1:
                predictions = [predictions]

            for prediction, prediction_map, normalization_mask in zip(
                    predictions, prediction_maps, normalization_masks):
                # squeeze batch dimension and convert back to numpy array
                prediction = prediction.squeeze(dim=0).cpu().numpy()
                if prediction_channel is not None:
                    # use only the 'prediction_channel'
                    logger.info(f"Using channel '{prediction_channel}'...")
                    prediction = np.expand_dims(prediction[prediction_channel],
                                                axis=0)

                # unpad in order to avoid block artifacts in the output probability maps
                u_prediction, u_index = utils.unpad(prediction, index,
                                                    volume_shape)
                # accumulate probabilities into the output prediction array
                prediction_map[u_index] += u_prediction
                # count voxel visits for normalization
                normalization_mask[u_index] += 1

    # save probability maps
    prediction_datasets = _get_dataset_names(config,
                                             output_heads,
                                             prefix='predictions')
    with h5py.File(output_file, 'w') as f:
        for prediction_map, normalization_mask, prediction_dataset in zip(
                prediction_maps, normalization_masks, prediction_datasets):
            prediction_map = prediction_map / normalization_mask
            logger.info(
                f'Saving predictions to: {output_file}/{prediction_dataset}...'
            )
            f.create_dataset(prediction_dataset,
                             data=prediction_map,
                             compression="gzip")
Code Example #6
File: predict_fd.py  Project: zyxwvu321/CT_seg
def predict(model, data_loader, config):
    """
    Return prediction masks by applying the model on the given dataset

    Args:
        model (Unet3D): trained 3D UNet model used for prediction
        data_loader (torch.utils.data.DataLoader): input data loader
        config (dict): global config dict with 'model' and 'device' entries

    Returns:
         prediction_maps (list of numpy arrays): prediction masks for the given dataset, one per output head
    """

    def _volume_shape(dataset):
        # TODO: support multiple internal datasets
        raw = dataset.raws[0]
        if raw.ndim == 3:
            return raw.shape
        else:
            return raw.shape[1:]

    out_channels = config['model'].get('out_channels')
    if out_channels is None:
        out_channels = config['model']['dt_out_channels']

    prediction_channel = config.get('prediction_channel', None)
    if prediction_channel is not None:
        logger.info(f"Using only channel '{prediction_channel}' from the network output")

    device = config['device']
    output_heads = config['model'].get('output_heads', 1)

    logger.info(f'Running prediction on {len(data_loader)} patches...')
    # dimensionality of the output (CxDxHxW)
    volume_shape = _volume_shape(data_loader.dataset)
    if prediction_channel is None:
        prediction_maps_shape = (out_channels,) + volume_shape
    else:
        # single channel prediction map
        prediction_maps_shape = (1,) + volume_shape

    logger.info(f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}')

    pad_width = config['model'].get('pad_width', None)

    # initialize the output prediction arrays
    prediction_maps = [np.zeros(prediction_maps_shape, dtype='float32') for _ in range(output_heads)]
    # initialize normalization mask in order to average out probabilities of overlapping patches
    normalization_masks = [np.zeros(prediction_maps_shape, dtype='float32') for _ in range(output_heads)]

    # Sets the module in evaluation mode explicitly, otherwise the final Softmax/Sigmoid won't be applied!
    model.eval()
    # Run predictions on the entire input dataset
    with torch.no_grad():
        for tt in tqdm(data_loader):
            if len(tt) == 2:
                patch, index = tt
            elif len(tt) == 3:
                patch, gp, index = tt
            else:
                raise ValueError(f'Unexpected number of items returned by the loader: {len(tt)}')
            #logger.info(f'Predicting slice:{index}')

            # save patch index: (C,D,H,W)
            if prediction_channel is None:
                channel_slice = slice(0, out_channels)
            else:
                channel_slice = slice(0, 1)

            index = (channel_slice,) + tuple(index)

            # send patch to device
            patch = patch.to(device)
            
            # forward pass
            if len(tt) == 2:
                predictions = model(patch)
            elif len(tt) == 3:
                gp = gp.to(device)
                predictions = model(patch, gp)
            

            # wrap predictions into a list if there is only one output head from the network
            if output_heads == 1:
                predictions = [predictions]

            for prediction, prediction_map, normalization_mask in zip(predictions, prediction_maps,
                                                                      normalization_masks):
                # squeeze batch dimension and convert back to numpy array
                prediction = prediction.squeeze(dim=0).cpu().numpy()
                if prediction_channel is not None:
                    # use only the 'prediction_channel'
                    #logger.info(f"Using channel '{prediction_channel}'...")
                    prediction = np.expand_dims(prediction[prediction_channel], axis=0)

                # unpad in order to avoid block artifacts in the output probability maps
                u_prediction, u_index = utils.unpad(prediction, index, volume_shape, pad_width)

                # accumulate probabilities into the output prediction array
                prediction_map[u_index] += u_prediction
                # count voxel visits for normalization
                normalization_mask[u_index] += 1

    return [prediction_map / normalization_mask for prediction_map, normalization_mask in
            zip(prediction_maps, normalization_masks)]
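
All six variants share the same accumulate-and-normalize pattern: overlapping patch predictions are summed into the output array while a mask counts voxel visits, and the final division averages the overlaps. A toy 1-D illustration:

import numpy as np

volume = np.zeros(6, dtype='float32')
visits = np.zeros(6, dtype='float32')
for start, pred in [(0, np.full(4, 1.0)), (2, np.full(4, 3.0))]:
    sel = slice(start, start + len(pred))
    volume[sel] += pred   # accumulate probabilities
    visits[sel] += 1      # count voxel visits
print(volume / visits)    # [1. 1. 2. 2. 3. 3.] -- the overlap (indices 2-3) is averaged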