Code example #1
File: losses.py Project: wolny/pytorch-3dunet
    def forward(self, input, target, weights):
        assert target.size() == weights.size()
        # normalize the input
        log_probabilities = self.log_softmax(input)
        # the target comes as (NxDxHxW); expand it to one-hot (NxCxDxHxW) to match the input
        target = expand_as_one_hot(target,
                                   C=input.size()[1],
                                   ignore_index=self.ignore_index)
        # add a channel dimension so the per-voxel weights broadcast over all classes
        weights = weights.unsqueeze(1)
        weights = weights.expand_as(input)

        # create default class_weights if None
        if self.class_weights is None:
            class_weights = torch.ones(input.size()[1]).float().to(
                input.device)
        else:
            class_weights = self.class_weights

        # resize class_weights to be broadcastable into the weights
        class_weights = class_weights.view(1, -1, 1, 1, 1)

        # multiply weights tensor by class weights
        weights = class_weights * weights

        # compute the losses
        result = -weights * target * log_probabilities
        # average the losses
        return result.mean()
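
Every example on this page leans on the pytorch-3dunet helper expand_as_one_hot, which is never shown. As a reference, here is a minimal sketch of what it does, assuming the usual scatter-based one-hot expansion; the ignore_index handling mirrors what the callers above expect, but treat it as an approximation rather than the library's exact code:

import torch

def expand_as_one_hot(input, C, ignore_index=None):
    """Sketch: convert an NxDxHxW label tensor into an NxCxDxHxW one-hot tensor."""
    assert input.dim() == 4
    # add a channel axis: NxDxHxW -> Nx1xDxHxW
    input = input.unsqueeze(1)
    shape = list(input.size())
    shape[1] = C
    if ignore_index is not None:
        # remember which voxels carry the ignore label
        mask = input.expand(shape) == ignore_index
        input = input.clone()
        input[input == ignore_index] = 0
        result = torch.zeros(shape, device=input.device).scatter_(1, input, 1)
        # re-insert the ignore label so downstream code can mask it out
        result[mask] = ignore_index
        return result
    return torch.zeros(shape, device=input.device).scatter_(1, input, 1)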
Code example #2
    def __call__(self, input, target):
        """
		:param input: 5D probability maps torch float tensor (NxCxDxHxW)
		:param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
		:return: intersection over union averaged over all channels
		"""
        assert input.dim() == 5

        predictions, target = convert_to_numpy(input, target)
        # assumes a batch of size 1: drop the batch dim before thresholding
        predictions = predictions[0]

        # global otsu threshold on the predictions
        global_thresh = threshold_otsu(predictions)

        low_intensity_region = np.where(predictions < global_thresh)

        # work on a writable copy and suppress everything below the threshold
        predictions = np.array(predictions)
        predictions[low_intensity_region] = 0
        # restore the batch dim dropped earlier
        predictions = np.expand_dims(predictions, axis=0)

        predictions = torch.tensor(predictions)

        target = torch.tensor(target)

        n_classes = input.size()[1]

        if target.dim() == 4:
            target = expand_as_one_hot(target,
                                       C=n_classes,
                                       ignore_index=self.ignore_index)

        assert predictions.size() == target.size()

        per_batch_iou = []
        for _input, _target in zip(predictions, target):
            binary_prediction = self._binarize_predictions(_input, n_classes)

            if self.ignore_index is not None:
                # zero out ignore_index
                mask = _target == self.ignore_index
                binary_prediction[mask] = 0
                _target[mask] = 0

            # convert to uint8 just in case
            binary_prediction = binary_prediction.byte()
            _target = _target.byte()

            per_channel_iou = []
            for c in range(n_classes):
                if c in self.skip_channels:
                    continue

                per_channel_iou.append(
                    self._jaccard_index(binary_prediction[c], _target[c]))

            assert per_channel_iou, "All channels were ignored from the computation"
            mean_iou = torch.mean(torch.tensor(per_channel_iou))
            per_batch_iou.append(mean_iou)

        return torch.mean(torch.tensor(per_batch_iou))
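
The per-channel loop above delegates to self._jaccard_index, which is not part of the snippet. A minimal sketch of the standard formulation on binarized tensors (the epsilon clamp is an assumption added to avoid division by zero on empty channels):

def _jaccard_index(self, prediction, target):
    """Sketch: intersection over union between two binary (0/1) tensors."""
    intersection = torch.sum(prediction & target).float()
    union = torch.sum(prediction | target).float()
    # clamp the denominator so a channel that is empty in both tensors yields 0, not NaN
    return intersection / torch.clamp(union, min=1e-8)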
Code example #3
    def forward(self, input, target, weights):
        assert target.size() == weights.size()
        # normalize the input
        log_probabilities = self.log_softmax(input)
        # the target comes as (NxDxHxW); expand it to one-hot (NxCxDxHxW) to match the input
        target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index)
        # add a channel dimension so the per-voxel weights broadcast over all classes
        # (unsqueeze(0) would put the new dim in front of the batch axis, and
        # expand_as(input) would then fail unless N == C)
        weights = weights.unsqueeze(1)
        weights = weights.expand_as(input)

        # mask ignore_index if present
        if self.ignore_index is not None:
            # torch.autograd.Variable is deprecated; a plain tensor mask suffices
            mask = target.ne(self.ignore_index).float()
            log_probabilities = log_probabilities * mask
            target = target * mask

        # create default class_weights if None
        if self.class_weights is None:
            # register_buffer would raise here because the attribute already
            # exists on the module, so just cache the default weights directly
            self.class_weights = torch.ones(input.size()[1]).float().to(input.device)

        # resize class_weights to be broadcastable into the weights
        class_weights = self.class_weights.view(1, -1, 1, 1, 1)

        # multiply weights tensor by class weights
        weights = class_weights * weights

        # compute the losses
        result = -weights * target * log_probabilities
        # average the losses
        return result.mean()
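
For orientation, a hedged usage sketch of this loss: the class name PixelWiseCrossEntropyLoss and its constructor arguments are assumptions, and the snippet only exercises the tensor shapes the forward pass expects:

import torch

# hypothetical instantiation; the real constructor may differ
loss_fn = PixelWiseCrossEntropyLoss(class_weights=None, ignore_index=None)

N, C, D, H, W = 2, 3, 8, 16, 16
input = torch.randn(N, C, D, H, W, requires_grad=True)  # raw logits
target = torch.randint(0, C, (N, D, H, W))              # class indices
weights = torch.rand(N, D, H, W)                        # per-voxel weights

loss = loss_fn(input, target, weights)
loss.backward()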
Code example #4
File: losses.py Project: charmsoya/pytorch-3dunet
def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Computes the Dice coefficient as defined in https://arxiv.org/abs/1606.04797 given a multi-channel input and target.
    Assumes the input is a normalized probability, e.g. the output of a Sigmoid or Softmax function.

    Args:
         input (torch.Tensor): NxCxSpatial input tensor
         target (torch.Tensor): NxCxSpatial target tensor
         epsilon (float): prevents division by zero
         weight (torch.Tensor): Cx1 tensor of weight per channel/class
    """

    # input and target shapes must match
    target = expand_as_one_hot(target, input.size()[1])
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    input = flatten(input)
    target = flatten(target)
    target = target.float()

    # compute per channel Dice Coefficient
    intersect = (input * target).sum(-1)
    if weight is not None:
        intersect = weight * intersect

    # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1)
    denominator = (input * input).sum(-1) + (target * target).sum(-1)
    return 2 * (intersect / denominator.clamp(min=epsilon))
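
The flatten helper used here reshapes an NxCxSpatial tensor to CxN*Spatial so that the sums in the Dice computation run per channel. A minimal sketch consistent with that usage:

def flatten(tensor):
    """Sketch: (N, C, D, H, W) -> (C, N * D * H * W); channel axis first, then flatten."""
    C = tensor.size(1)
    # new axis order puts the channel dim in front: (C, N, D, H, W)
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    transposed = tensor.permute(axis_order)
    return transposed.contiguous().view(C, -1)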
Code example #5
    def forward(self, input, target):
        """
        Args:
             input (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims)
             target (torch.tensor): ground truth instance segmentation (NxDxHxW)

        Returns:
            Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term
        """
        # number of distinct instance labels across the batch (including background)
        C = torch.unique(target).size()[0]
        # expand each label as a one-hot vector: N x D x H x W -> N x C x D x H x W
        target = expand_as_one_hot(target, C)
        # compare spatial dimensions
        assert input.dim() == target.dim() == 5
        assert input.size()[2:] == target.size()[2:]

        # compute mean embeddings and assign embeddings to instances
        cluster_means, embeddings_per_instance = self._compute_cluster_means(input, target)
        variance_term = self._compute_variance_term(cluster_means, embeddings_per_instance, target)
        distance_term = self._compute_distance_term(cluster_means, C)
        regularization_term = self._compute_regularizer_term(cluster_means, C)
        # total loss
        loss = self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term
        # reduce batch dimension
        return torch.mean(loss)
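
A hedged usage sketch of this discriminative-style loss; the class name ContrastiveLoss and the constructor arguments are assumptions, the point being the expected tensor shapes:

import torch

# hypothetical instantiation; the real constructor likely takes more arguments
criterion = ContrastiveLoss(alpha=1.0, beta=1.0, gamma=0.001)

N, E, D, H, W = 1, 16, 8, 32, 32
embeddings = torch.randn(N, E, D, H, W, requires_grad=True)  # NxExDxHxW
instances = torch.randint(0, 5, (N, D, H, W))                # NxDxHxW instance ids

loss = criterion(embeddings, instances)
loss.backward()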
Code example #6
    def __call__(self, input, target):
        """
        :param input: 5D probability maps torch float tensor (NxCxDxHxW)
        :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
        :return: intersection over union averaged over all channels
        """
        assert input.dim() == 5

        predictions, target = convert_to_numpy(input, target)
        # assumes a batch of size 1: drop the batch dim before thresholding
        predictions = predictions[0]

        # global otsu threshold on the predictions
        global_thresh = threshold_otsu(predictions)

        # define the foreground based on threshold
        foreground = predictions > global_thresh

        # local peaks of the (single-channel) prediction volume
        local_max = skimage.feature.peak_local_max(predictions[0], min_distance=5)

        # prepare the volume that will hold the peaks; derive the shape from the
        # predictions rather than hard-coding (1, 48, 128, 128)
        local_peaks_vol = np.zeros(predictions.shape)

        for coordinate in local_max:
            local_peaks_vol[0][coordinate[0], coordinate[1], coordinate[2]] = 1.0

        # dilate each peak into a sphere: invert the peak mask so the EDT below
        # measures the distance to the nearest peak
        inv_local_peaks_vol = np.logical_not(local_peaks_vol)

        # get distance transform
        local_peaks_edt = ndimage.distance_transform_edt(inv_local_peaks_vol)

        # threshold the edt and invert back: fg as 1, bg as 0
        spherical_peaks = local_peaks_edt > 3
        spherical_peaks = np.logical_not(spherical_peaks).astype(np.float64)

        # zero out peak voxels that fall outside the Otsu foreground
        outliers = np.where(spherical_peaks != foreground)
        spherical_peaks[outliers] = 0

        spherical_peaks = np.expand_dims(spherical_peaks, axis=0)

        spherical_peaks = torch.tensor(spherical_peaks)
        target = torch.tensor(target)

        n_classes = input.size()[1]

        if target.dim() == 4:
            target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)

        assert spherical_peaks.size() == target.size()

        per_batch_iou = []
        for _input, _target in zip(spherical_peaks, target):
            binary_prediction = self._binarize_predictions(_input, n_classes)

            if self.ignore_index is not None:
                # zero out ignore_index
                mask = _target == self.ignore_index
                binary_prediction[mask] = 0
                _target[mask] = 0

            # convert to uint8 just in case
            binary_prediction = binary_prediction.byte()
            _target = _target.byte()

            per_channel_iou = []
            for c in range(n_classes):
                if c in self.skip_channels:
                    continue

                per_channel_iou.append(self._jaccard_index(binary_prediction[c], _target[c]))

            assert per_channel_iou, "All channels were ignored from the computation"
            mean_iou = torch.mean(torch.tensor(per_channel_iou))
            per_batch_iou.append(mean_iou)

        return torch.mean(torch.tensor(per_batch_iou))
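
The peak-to-sphere trick above (mark each local maximum, invert the mask, take the Euclidean distance transform, and keep everything within a fixed radius) is easy to verify in isolation; a self-contained sketch on a toy volume:

import numpy as np
from scipy import ndimage

# toy volume with two point "peaks"
peaks = np.zeros((16, 32, 32))
peaks[4, 10, 10] = 1.0
peaks[10, 20, 20] = 1.0

# distance from every voxel to its nearest peak
edt = ndimage.distance_transform_edt(np.logical_not(peaks))

# voxels within radius 3 of a peak form two small spheres
spheres = (edt <= 3).astype(np.float64)
print(spheres.sum())  # total volume of the two spheres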
Code example #7
File: sxth.py Project: charmsoya/pytorch-3dunet
    def create_datasets(cls, dataset_config, phase):
        # preprocess each patient case
        sxth_case_loader = SXTHCaseLoader(dataset_config['original_data_dir'])
        sxth_case_loader.set_phase(phase)
        cls.original_dir = Path(dataset_config['original_data_dir'])
        cls.train_dir = Path(dataset_config[phase]['file_paths'][0])
        if not cls.train_dir.is_dir():
            cls.train_dir.mkdir(parents=True)

        # does preprocessed data already exist?
        files = os.listdir(cls.train_dir)
        need_processing = True

        if need_processing:
            # delete any existing files
            for f in files:
                os.remove(Path(cls.train_dir) / f)

            # We use the following steps to carry out the preprocessing:
            # STEP 1: loop over all cases' voxel spacings along the three axes and determine a suitable voxel spacing
            # STEP 2: for each case,
            #   a. remove the dark borders (air) of each slice, cropping each slice to the human body
            #   b. resample each slice according to the voxel spacing determined above
            #   c. only keep the slices from the first to the last slice with a mask
            # STEP 3: loop over all cases to determine a suitable patch size
    
            # STEP 1:
            seg_info = []
            foreground_vol = np.array([])
            one_case_visualizer = visualizer()

            for vol, seg, img_info in sxth_case_loader:

                spacing = img_info['Spacing']
                case_id = img_info['Case_id']

                # resample (or re-slice) for isotropic voxel
                new_spacing = [2, 0.8, 0.8]
                resize_factor = np.array(spacing) / new_spacing
                new_real_shape = vol.shape * resize_factor
                new_shape = np.round(new_real_shape)   
                real_resize_factor = new_shape / vol.shape
                new_spacing = spacing / real_resize_factor

                # recompute the zoom factors against the achievable spacing
                resize_factor = np.array(spacing) / new_spacing
                print(f'\t shape after crop: {vol.shape}')
                print(f'\t resampling image voxel spacing from {spacing} to {new_spacing} ...')
                vol = ndimage.zoom(vol, resize_factor, order=0)
                print(f'\t shape after resample: {vol.shape}') 
                
                print('\t resampling segmentation voxels ...')
                nclass = 2
                seg_resampled = np.zeros((nclass, vol.shape[0], vol.shape[1], vol.shape[2]), dtype='int32')
                if phase != 'test':
                    seg_expanded = torch.from_numpy(np.expand_dims(seg.astype('int64'), 0))
                    seg_onehot = expand_as_one_hot(seg_expanded, nclass).squeeze(0).numpy()
                    for ilabel in range(nclass):
                        seg_resampled[ilabel, ] = ndimage.zoom(seg_onehot[ilabel, ], resize_factor, order=0)
                    seg = np.argmax(seg_resampled, axis=0)
                    # collect all foreground voxels to later compute max/min/mean/std
                    foreground_vol = np.concatenate((foreground_vol, vol[seg > 0]), axis=0)
                print(f'\t after removing slices without masks, image & segmentation data shape is {vol.shape}')
                print('\t done.')

                # store as an HDF5 file
                f = h5py.File(cls.train_dir / 'case_{:05d}.h5'.format(case_id), 'w')
                f.create_dataset('raw', data=vol)
                if phase != 'test':
                    f.create_dataset('label', data=seg)
                    f.create_dataset('weight', data=np.ones(seg.shape))
                f.close()
                seg_info.append(vol.shape)
            
        return super().create_datasets(dataset_config, phase)
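
The spacing arithmetic in the middle of this loop (requested spacing -> rounded output shape -> achievable spacing) reappears verbatim in the kits19 loader below, so it is worth isolating. A hedged standalone helper under the same rounding convention:

import numpy as np
from scipy import ndimage

def resample_to_spacing(vol, spacing, new_spacing, order=0):
    """Sketch: resample vol to (approximately) new_spacing.

    Because the output shape is rounded to whole voxels, the achievable spacing
    differs slightly from the requested one; both results are returned.
    """
    resize_factor = np.array(spacing) / new_spacing
    new_shape = np.round(vol.shape * resize_factor)
    real_resize_factor = new_shape / vol.shape
    achieved_spacing = np.array(spacing) / real_resize_factor
    resampled = ndimage.zoom(vol, real_resize_factor, order=order)
    return resampled, achieved_spacing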
Code example #8
File: kits19.py Project: charmsoya/pytorch-3dunet
    def create_datasets(cls, dataset_config, phase):
        # preprocess each patient case

        # does the folder for storing preprocessed files exist?
        cls.original_dir = Path(dataset_config['original_data_dir'])
        cls.train_dir = Path(dataset_config[phase]['file_paths'][0])
        if not cls.train_dir.is_dir():
            cls.train_dir.mkdir()

        # does preprocessed data already exist?
        files = os.listdir(cls.train_dir)
        need_processing = True

        if phase == 'train':
            id_range = [
                x for x in list(range(0, 200)) if x not in [15, 37, 88]
            ]
        if phase == 'val':
            id_range = range(200, 210)
        if phase == 'test':
            id_range = range(210, 300)

        # skip preprocessing when every expected case file is already present
        if len(files) == len(id_range):
            need_processing = False

        if need_processing:
            # delete any existing files
            for f in files:
                os.remove(Path(cls.train_dir) / f)

            # We use the following steps to carry out the preprocessing:
            # STEP 1: loop over all cases' voxel spacings along the three axes and determine a suitable voxel spacing
            # STEP 2: for each case,
            #   a. remove the dark borders (air) of each slice, cropping each slice to the human body
            #   b. resample each slice according to the voxel spacing determined above
            #   c. only keep the slices from the first to the last slice with a mask
            # STEP 3: loop over all cases to determine a suitable patch size

            # STEP 1:
            seg_info = []
            foreground_vol = np.array([])
            one_case_visualizer = visualizer()

            for case_id in id_range:

                print(f'Preprocessing {phase} case {case_id} ({len(id_range)} cases in total),')
                print('\t reading data ...,')
                if phase != 'test':
                    vol, seg = cls.load_case(cls, case_id)
                    seg = seg.get_data()
                    seg = seg.astype(np.int32)
                    print(f'\t original shape: {vol.shape}')
                else:
                    vol = cls.load_case(cls, case_id)
                spacing = vol.header.get_zooms()
                vol = vol.get_data()

                # crop the CT image to the human body
                if phase != 'test':
                    crop_border = cls.crop_image_only_outside(cls, vol, -1000)
                    seg = seg[:, crop_border[0]:crop_border[1],
                              crop_border[2]:crop_border[3]]
                    vol = vol[:, crop_border[0]:crop_border[1],
                              crop_border[2]:crop_border[3]]

                # resample (or re-slice) for isotropic voxel
                new_spacing = [2, 2, 2]
                resize_factor = np.array(spacing) / new_spacing
                new_real_shape = vol.shape * resize_factor
                new_shape = np.round(new_real_shape)
                real_resize_factor = new_shape / vol.shape
                new_spacing = spacing / real_resize_factor

                # recompute the zoom factors against the achievable spacing
                resize_factor = np.array(spacing) / new_spacing
                print(f'\t shape after crop: {vol.shape}')
                print(
                    f'\t resampling image voxel spacing from {spacing} to {new_spacing} ...'
                )
                vol = ndimage.zoom(vol, resize_factor, order=0)
                print(f'\t shape after resample: {vol.shape}')

                print('\t resampling segmentation voxels ...')
                nclass = 3
                seg_resampled = np.zeros(
                    (nclass, vol.shape[0], vol.shape[1], vol.shape[2]),
                    dtype='int32')
                if phase != 'test':
                    seg_expanded = torch.from_numpy(
                        np.expand_dims(seg.astype('int64'), 0))
                    seg_onehot = expand_as_one_hot(seg_expanded,
                                                   nclass).squeeze(0).numpy()
                    for ilabel in range(nclass):
                        seg_resampled[ilabel, ] = ndimage.zoom(
                            seg_onehot[ilabel, ], resize_factor, order=0)
                    seg = np.argmax(seg_resampled, axis=0)
                    # 3D ROI: only CT slices that contain masks will be preserved
                    ior_z = []
                    for i in range(seg.shape[0]):
                        unique_list = np.unique(seg[i, :, :])
                        assert all(element in (0, 1, 2) for element in unique_list)
                        if len(unique_list) > 1 and len(unique_list) <= 3:
                            ior_z.append(i)
                    # extend the ROI by 3 slices on each side, staying inside the volume
                    ior_zmin = min(ior_z) if min(ior_z) - 3 < 0 else min(ior_z) - 3
                    ior_zmax = max(ior_z) if max(ior_z) + 3 > vol.shape[0] else max(ior_z) + 3
                    vol = vol[ior_zmin:ior_zmax + 1, :, :]
                    seg = seg[ior_zmin:ior_zmax + 1, :, :]
                    # collect all foreground voxels to later compute max/min/mean/std
                    foreground_vol = np.concatenate((foreground_vol, vol[seg > 0]), axis=0)
                print(f'\t after removing slices without masks, image & segmentation data shape is {vol.shape}')
                print('\t done.')

                # store as an HDF5 file
                f = h5py.File(cls.train_dir / 'case_{:05d}.h5'.format(case_id), 'w')
                f.create_dataset('raw', data=vol)
                if phase != 'test':
                    f.create_dataset('label', data=seg)
                    f.create_dataset('weight', data=np.ones(seg.shape))
                f.close()
                seg_info.append(vol.shape)

            if phase != 'test':
                print(np.percentile(foreground_vol, 99.5))
                print(np.percentile(foreground_vol, 0.5))
                # clip foreground intensities to the empirically chosen HU range
                foreground_vol = foreground_vol[(foreground_vol >= -73)
                                                & (foreground_vol <= 289)]

        return super().create_datasets(dataset_config, phase)
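
crop_image_only_outside is referenced above but never shown. Given how its return value is used (four in-plane bounds applied identically to every slice), here is a minimal standalone sketch, assuming a simple "any voxel denser than air" criterion; the real method may differ:

import numpy as np

def crop_image_only_outside(vol, tol=-1000):
    """Sketch: (row_min, row_max, col_min, col_max) bounds of the non-air region."""
    mask = vol > tol            # True where the voxel is denser than air
    mask_yx = mask.any(axis=0)  # project along z so every slice gets the same crop
    rows = np.where(mask_yx.any(axis=1))[0]
    cols = np.where(mask_yx.any(axis=0))[0]
    return rows[0], rows[-1] + 1, cols[0], cols[-1] + 1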