def forward(self, input, target, weights):
    """Voxel-wise cross entropy scaled by per-voxel and per-class weights.

    :param input: network output with the class axis at dim 1; the class-weight
        reshape below assumes five dimensions (NxCxDxHxW)
    :param target: label tensor (NxDxHxW); it is one-hot expanded here to
        ``input.size()[1]`` channels
    :param weights: per-voxel weights, must have exactly the size of ``target``
    :return: mean over all elements of ``-weights * one_hot_target * log_probabilities``
    """
    assert target.size() == weights.size()

    num_channels = input.size()[1]

    # log-probabilities as produced by the module's own normalisation layer
    log_probs = self.log_softmax(input)

    # (NxDxHxW) labels -> (NxCxDxHxW) one-hot encoding
    one_hot = expand_as_one_hot(target, C=num_channels, ignore_index=self.ignore_index)

    # insert the channel axis and repeat the voxel weights for every class
    voxel_weights = weights.unsqueeze(1).expand_as(input)

    # fall back to uniform class weights when none were configured
    if self.class_weights is not None:
        per_class = self.class_weights
    else:
        per_class = torch.ones(num_channels).float().to(input.device)

    # (C,) -> (1, C, 1, 1, 1) so the class weights broadcast over batch and space
    combined = per_class.view(1, -1, 1, 1, 1) * voxel_weights

    return (-combined * one_hot * log_probs).mean()
def __call__(self, input, target):
    """
    Intersection over union computed after suppressing low-intensity predictions
    with a global Otsu threshold.

    Every batch element is thresholded on its own: values below that sample's Otsu
    threshold are set to 0 before the predictions are binarized.

    :param input: 5D probability maps torch float tensor (NxCxDxHxW)
    :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
    :return: intersection over union averaged over all channels (and over the batch)
    """
    assert input.dim() == 5
    n_classes = input.size()[1]

    # NOTE(review): convert_to_numpy is assumed to return array-likes with the
    # same layout as the tensors passed in -- confirm against its definition.
    predictions, target = convert_to_numpy(input, target)

    # work on a copy so the arrays returned by convert_to_numpy are left untouched
    predictions = np.array(predictions)
    for sample in predictions:
        # `sample` is a view into `predictions`, so the masking happens in place
        sample[sample < threshold_otsu(sample)] = 0

    predictions = torch.tensor(predictions)
    target = torch.tensor(target)

    if target.dim() == 4:
        target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)

    assert predictions.size() == target.size()

    per_batch_iou = []
    for _input, _target in zip(predictions, target):
        binary_prediction = self._binarize_predictions(_input, n_classes)

        if self.ignore_index is not None:
            # zero out ignore_index in both the prediction and the target
            mask = _target == self.ignore_index
            binary_prediction[mask] = 0
            _target[mask] = 0

        # convert to uint8 just in case
        binary_prediction = binary_prediction.byte()
        _target = _target.byte()

        per_channel_iou = [
            self._jaccard_index(binary_prediction[c], _target[c])
            for c in range(n_classes)
            if c not in self.skip_channels
        ]
        assert per_channel_iou, "All channels were ignored from the computation"

        per_batch_iou.append(torch.mean(torch.tensor(per_channel_iou)))

    return torch.mean(torch.tensor(per_batch_iou))
def forward(self, input, target, weights):
    """Voxel-wise weighted cross entropy that excludes ``ignore_index`` voxels.

    :param input: network output with the class axis at dim 1; the class-weight
        reshape below assumes five dimensions (NxCxDxHxW)
    :param target: label tensor (NxDxHxW); it is one-hot expanded here to
        ``input.size()[1]`` channels
    :param weights: per-voxel weights, must have exactly the size of ``target``
    :return: mean over all elements of ``-weights * one_hot_target * log_probabilities``
    """
    assert target.size() == weights.size()

    num_channels = input.size()[1]

    # normalize the input
    log_probabilities = self.log_softmax(input)
    # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW)
    target = expand_as_one_hot(target, C=num_channels, ignore_index=self.ignore_index)

    # The weights share the target's (NxDxHxW) layout, so the axis that is missing
    # relative to `input` is the channel axis (dim 1), not the batch axis.
    weights = weights.unsqueeze(1).expand_as(input)

    # mask ignore_index if present
    # NOTE(review): presumably expand_as_one_hot leaves `ignore_index` in the
    # expanded target at ignored voxels -- confirm against its definition.
    if self.ignore_index is not None:
        # a comparison result carries no gradient, so no explicit wrapping is needed
        mask = target.ne(self.ignore_index).float()
        log_probabilities = log_probabilities * mask
        target = target * mask

    # Use uniform class weights when none were configured. The default is kept
    # local so that calling forward() does not change the module's state.
    if self.class_weights is None:
        class_weights = torch.ones(num_channels).float().to(input.device)
    else:
        class_weights = self.class_weights

    # resize class_weights to be broadcastable into the weights
    class_weights = class_weights.view(1, -1, 1, 1, 1)

    # multiply weights tensor by class weights
    weights = class_weights * weights

    # compute the losses
    result = -weights * target * log_probabilities
    # average the losses
    return result.mean()
def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target.
    Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function.

    Args:
        input (torch.Tensor): NxCxSpatial input tensor
        target (torch.Tensor): either an NxSpatial label tensor (one rank below
            ``input``), which is one-hot expanded to ``input.size()[1]`` channels,
            or an already expanded NxCxSpatial tensor, which is used as is
        epsilon (float): lower bound applied to the denominator, prevents division by zero
        weight (torch.Tensor): per channel/class weight multiplied into the intersection
            (NOTE(review): must broadcast against the per-channel sums returned by
            ``flatten(...).sum(-1)`` -- confirm the expected shape with callers)

    Returns:
        torch.Tensor: ``2 * intersect / max(denominator, epsilon)`` computed along the
        last axis of the flattened tensors.
    """
    # only label maps need expanding; a target that already has the input's rank
    # is taken to be one-hot encoded
    if target.dim() < input.dim():
        target = expand_as_one_hot(target, input.size()[1])

    # input and target shapes must match
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    input = flatten(input)
    target = flatten(target).float()

    # compute per channel Dice Coefficient
    intersect = (input * target).sum(-1)
    if weight is not None:
        intersect = weight * intersect

    # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1)
    denominator = (input * input).sum(-1) + (target * target).sum(-1)
    return 2 * (intersect / denominator.clamp(min=epsilon))
def forward(self, input, target):
    """
    Args:
         input (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims)
         target (torch.tensor): ground truth instance segmentation (NxDxHxW). Instance ids
            may be arbitrary values; they are relabelled to consecutive ids internally.

    Returns:
        Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term
    """
    # Map the instance ids onto 0..C-1 (ascending order of the original ids).
    # The one-hot expansion has exactly C channels, so ids that are not already
    # consecutive would otherwise point past the last channel. Targets that
    # already use 0..C-1 come out unchanged.
    instance_ids, target = torch.unique(target, return_inverse=True)

    # get number of instances in the batch
    C = instance_ids.size()[0]

    # expand each label as a one-hot vector: N x D x H x W -> N x C x D x H x W
    target = expand_as_one_hot(target, C)

    # compare spatial dimensions
    assert input.dim() == target.dim() == 5
    assert input.size()[2:] == target.size()[2:]

    # compute mean embeddings and assign embeddings to instances
    cluster_means, embeddings_per_instance = self._compute_cluster_means(input, target)
    variance_term = self._compute_variance_term(cluster_means, embeddings_per_instance, target)
    distance_term = self._compute_distance_term(cluster_means, C)
    regularization_term = self._compute_regularizer_term(cluster_means, C)

    # total loss
    loss = self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    # reduce batch dimension
    return torch.mean(loss)
def __call__(self, input, target):
    """
    Intersection over union between the target and a prediction rebuilt from local
    intensity peaks: every peak of the first channel is grown into a ball of radius 3
    voxels, and balls are kept only where they agree with the Otsu foreground.

    Only the first batch element and its first channel are used, so the size check
    against the target effectively requires N == 1 and C == 1.

    :param input: 5D probability maps torch float tensor (NxCxDxHxW)
    :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
    :return: intersection over union averaged over all channels
    """
    assert input.dim() == 5
    n_classes = input.size()[1]

    # NOTE(review): convert_to_numpy is assumed to return array-likes with the
    # same layout as the tensors passed in -- confirm against its definition.
    predictions, target = convert_to_numpy(input, target)
    predictions = predictions[0]

    # global otsu threshold on the predictions
    global_thresh = threshold_otsu(predictions)

    # define the foreground based on threshold
    foreground = predictions > global_thresh

    # get the local peaks from the predictions (first channel only)
    first_channel = predictions[0]
    local_max = skimage.feature.peak_local_max(first_channel, min_distance=5)

    # prepare the vol with peaks; its shape follows the prediction instead of
    # being tied to one fixed patch size
    local_peaks_vol = np.zeros((1,) + first_channel.shape)
    for z, y, x in local_max:
        local_peaks_vol[0, z, y, x] = 1.0

    # distance of every voxel to its nearest peak (peaks are the zeros of the inverted volume)
    inv_local_peaks_vol = np.logical_not(local_peaks_vol)
    local_peaks_edt = ndimage.distance_transform_edt(inv_local_peaks_vol)

    # threshold the edt and invert back: fg as 1, bg as 0
    spherical_peaks = np.logical_not(local_peaks_edt > 3).astype(np.float64)

    # clear every voxel where the grown peaks disagree with the thresholded foreground
    spherical_peaks[spherical_peaks != foreground] = 0

    # restore the batch axis
    spherical_peaks = np.expand_dims(spherical_peaks, axis=0)

    spherical_peaks = torch.tensor(spherical_peaks)
    target = torch.tensor(target)

    if target.dim() == 4:
        target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)

    assert spherical_peaks.size() == target.size()

    per_batch_iou = []
    for _input, _target in zip(spherical_peaks, target):
        binary_prediction = self._binarize_predictions(_input, n_classes)

        if self.ignore_index is not None:
            # zero out ignore_index in both the prediction and the target
            mask = _target == self.ignore_index
            binary_prediction[mask] = 0
            _target[mask] = 0

        # convert to uint8 just in case
        binary_prediction = binary_prediction.byte()
        _target = _target.byte()

        per_channel_iou = [
            self._jaccard_index(binary_prediction[c], _target[c])
            for c in range(n_classes)
            if c not in self.skip_channels
        ]
        assert per_channel_iou, "All channels were ignored from the computation"

        per_batch_iou.append(torch.mean(torch.tensor(per_channel_iou)))

    return torch.mean(torch.tensor(per_batch_iou))
def create_datasets(cls, dataset_config, phase):
    """Resample every SXTH case, store it as an HDF5 file and build the datasets.

    The output directory is emptied and all cases are regenerated on every call.
    For each case yielded by the case loader:

    1. the image is resampled (nearest neighbour) towards a voxel spacing of
       ``[2, 0.8, 0.8]``, adjusted so that the resampled shape is integral;
    2. unless ``phase == 'test'``, the segmentation is resampled the same way,
       one one-hot channel at a time, and collapsed back with ``argmax``;
    3. the result is written to ``case_XXXXX.h5`` with the datasets ``raw`` and,
       unless ``phase == 'test'``, ``label`` and ``weight`` (all ones).

    Args:
        dataset_config: mapping that provides ``'original_data_dir'`` and
            ``dataset_config[phase]['file_paths']``; the first file path is the
            output directory.
        phase: dataset phase; ``'test'`` cases are processed without labels.

    Returns:
        Whatever the parent class' ``create_datasets(dataset_config, phase)`` returns.
    """
    sxth_case_loader = SXTHCaseLoader(dataset_config['original_data_dir'])
    sxth_case_loader.set_phase(phase)

    cls.original_dir = Path(dataset_config['original_data_dir'])
    cls.train_dir = Path(dataset_config[phase]['file_paths'][0])
    cls.train_dir.mkdir(parents=True, exist_ok=True)

    # preprocessed data is always regenerated: remove whatever a previous run left behind
    for stale_file in os.listdir(cls.train_dir):
        os.remove(cls.train_dir / stale_file)

    # NOTE(review): this object is never used below; the construction is kept only
    # in case it has side effects -- drop it if it has none.
    one_case_visualizer = visualizer()

    nclass = 2
    target_spacing = [2, 0.8, 0.8]

    for vol, seg, img_info in sxth_case_loader:
        spacing = img_info['Spacing']
        case_id = img_info['Case_id']

        # resample (or re-slice) towards the target spacing; the zoom factor is
        # corrected so that it maps the input shape onto a whole number of voxels
        resize_factor = np.array(spacing) / target_spacing
        new_shape = np.round(vol.shape * resize_factor)
        real_resize_factor = new_shape / vol.shape
        new_spacing = spacing / real_resize_factor
        resize_factor = np.array(spacing) / new_spacing

        print(f'\t shape after crop: {vol.shape}')
        print(f'\t resampling image voxel spacing from {spacing} to {new_spacing} ...')
        vol = ndimage.zoom(vol, resize_factor, order=0)
        print(f'\t shape after resample: {vol.shape}')
        print('\t resampling segmetation voxel ...')

        if phase != 'test':
            # resample each one-hot channel separately, then collapse back to a label map
            seg_expanded = torch.from_numpy(np.expand_dims(seg.astype('int64'), 0))
            seg_onehot = expand_as_one_hot(seg_expanded, nclass).squeeze(0).numpy()
            seg_resampled = np.zeros((nclass,) + vol.shape, dtype='int32')
            for ilabel in range(nclass):
                seg_resampled[ilabel] = ndimage.zoom(seg_onehot[ilabel], resize_factor, order=0)
            seg = np.argmax(seg_resampled, axis=0)

        print(f'\t after removing slices without masks, image & segmentation data shape is {vol.shape} ')
        print('\t done.')

        # store as a h5d file; the context manager closes it even if a write fails
        with h5py.File(cls.train_dir / 'case_{:05d}.h5'.format(case_id), 'w') as h5_file:
            h5_file.create_dataset('raw', data=vol)
            if phase != 'test':
                h5_file.create_dataset('label', data=seg)
                h5_file.create_dataset('weight', data=np.ones(seg.shape))

    return super().create_datasets(dataset_config, phase)
def create_datasets(cls, dataset_config, phase):
    """Preprocess the cases of one phase into HDF5 files and build the datasets.

    Case ids per phase: ``train`` uses 0-199 without 15, 37 and 88, ``val`` uses
    200-209 and ``test`` uses 210-299. Preprocessing is skipped when the output
    directory already holds as many files as the phase has cases; otherwise the
    directory is emptied and every case is regenerated:

    1. load the case; for labelled phases crop image and segmentation to the
       border returned by ``crop_image_only_outside(vol, -1000)``;
    2. resample (nearest neighbour) towards a voxel spacing of ``[2, 2, 2]``,
       adjusted so that the resampled shape is integral;
    3. for labelled phases keep only the slices along axis 0 from the first to
       the last slice containing more than one label value, plus a margin of
       3 slices where that margin fits;
    4. write ``case_XXXXX.h5`` with ``raw`` and, for labelled phases, ``label``
       and ``weight`` (all ones).

    For labelled phases the 99.5th and 0.5th percentile of all foreground
    intensities are printed at the end.

    Args:
        dataset_config: mapping that provides ``'original_data_dir'`` and
            ``dataset_config[phase]['file_paths']``; the first file path is the
            output directory.
        phase: one of ``'train'``, ``'val'`` or ``'test'``.

    Returns:
        Whatever the parent class' ``create_datasets(dataset_config, phase)`` returns.

    Raises:
        ValueError: if ``phase`` is unknown, or a labelled case contains no slice
            with more than one label value.
    """
    cls.original_dir = Path(dataset_config['original_data_dir'])
    cls.train_dir = Path(dataset_config[phase]['file_paths'][0])
    cls.train_dir.mkdir(parents=True, exist_ok=True)

    files = os.listdir(cls.train_dir)

    if phase == 'train':
        excluded_ids = {15, 37, 88}
        id_range = [x for x in range(0, 200) if x not in excluded_ids]
    elif phase == 'val':
        id_range = range(200, 210)
    elif phase == 'test':
        id_range = range(210, 300)
    else:
        raise ValueError(f"unknown phase {phase!r}, expected 'train', 'val' or 'test'")

    # a directory that already holds one file per case is taken as fully preprocessed
    need_processing = len(files) != len(id_range)

    if need_processing:
        # delete existing files so that no stale result survives
        for stale_file in files:
            os.remove(cls.train_dir / stale_file)

        # NOTE(review): this object is never used below; the construction is kept only
        # in case it has side effects -- drop it if it has none.
        one_case_visualizer = visualizer()

        nclass = 3
        labelled = phase != 'test'
        # starts with an empty float array so the concatenated result keeps the
        # dtype it had when chunks were appended one at a time
        foreground_chunks = [np.array([])]

        for position, case_id in enumerate(id_range, start=1):
            print(f'Preprocessing {phase} case {position}/{len(id_range)},')
            print('\t reading data ...,')
            if labelled:
                vol, seg = cls.load_case(cls, case_id)
                seg = seg.get_data().astype(np.int32)
                print(f'\t original shape: {vol.shape}')
            else:
                vol = cls.load_case(cls, case_id)

            spacing = vol.header.get_zooms()
            vol = vol.get_data()

            # crop CT image to human body
            if labelled:
                crop_border = cls.crop_image_only_outside(cls, vol, -1000)
                rows = slice(crop_border[0], crop_border[1])
                cols = slice(crop_border[2], crop_border[3])
                seg = seg[:, rows, cols]
                vol = vol[:, rows, cols]

            # resample (or re-slice) for isotropic voxel; the zoom factor is corrected
            # so that it maps the input shape onto a whole number of voxels
            resize_factor = np.array(spacing) / [2, 2, 2]
            new_shape = np.round(vol.shape * resize_factor)
            real_resize_factor = new_shape / vol.shape
            new_spacing = spacing / real_resize_factor
            resize_factor = np.array(spacing) / new_spacing

            print(f'\t shape after crop: {vol.shape}')
            print(f'\t resampling image voxel spacing from {spacing} to {new_spacing} ...')
            vol = ndimage.zoom(vol, resize_factor, order=0)
            print(f'\t shape after resample: {vol.shape}')
            print('\t resampling segmetation voxel ...')

            if labelled:
                # resample each one-hot channel separately, then collapse back to a label map
                seg_expanded = torch.from_numpy(np.expand_dims(seg.astype('int64'), 0))
                seg_onehot = expand_as_one_hot(seg_expanded, nclass).squeeze(0).numpy()
                seg_resampled = np.zeros((nclass,) + vol.shape, dtype='int32')
                for ilabel in range(nclass):
                    seg_resampled[ilabel] = ndimage.zoom(seg_onehot[ilabel], resize_factor, order=0)
                # argmax over `nclass` channels, so labels are always in 0..nclass-1
                seg = np.argmax(seg_resampled, axis=0)

                # 3D ROI: slices holding more than one label value (ascending order)
                ior_z = [i for i in range(seg.shape[0]) if len(np.unique(seg[i, :, :])) > 1]
                if not ior_z:
                    raise ValueError(f'case {case_id} has no slice containing a mask')

                # add a 3-slice margin on either side, but only where it fits
                first, last = ior_z[0], ior_z[-1]
                ior_zmin = first if first - 3 < 0 else first - 3
                ior_zmax = last if last + 3 > vol.shape[0] else last + 3
                vol = vol[ior_zmin:ior_zmax + 1, :, :]
                seg = seg[ior_zmin:ior_zmax + 1, :, :]

                # collect all fore-ground voxels for further compute max/min/mean/std value
                foreground_chunks.append(vol[seg > 0])

            print(f'\t after removing slices without masks, image & segmentation data shape is {vol.shape} ')
            print('\t done.')

            # store as a h5d file; the context manager closes it even if a write fails
            with h5py.File(cls.train_dir / 'case_{:05d}.h5'.format(case_id), 'w') as h5_file:
                h5_file.create_dataset('raw', data=vol)
                if labelled:
                    h5_file.create_dataset('label', data=seg)
                    h5_file.create_dataset('weight', data=np.ones(seg.shape))

        if labelled:
            # concatenate once instead of re-copying the growing array for every case
            foreground_vol = np.concatenate(foreground_chunks, axis=0)
            print(np.percentile(foreground_vol, 99.5))
            print(np.percentile(foreground_vol, 0.5))

    return super().create_datasets(dataset_config, phase)