def test_input_warnings(self):
    """Check that DiceLoss emits a Warning for each of three option settings.

    Every configuration below is built and applied to all-ones tensors of
    shape (1, 1, 3); both the construction and the forward call happen
    inside the ``assertWarns`` context.
    """
    chn_input = torch.ones((1, 1, 3))
    chn_target = torch.ones((1, 1, 3))

    warning_configs = (
        {"include_background": False},
        {"softmax": True},
        {"to_onehot_y": True},
    )
    for loss_kwargs in warning_configs:
        with self.assertWarns(Warning):
            criterion = DiceLoss(**loss_kwargs)
            criterion.forward(chn_input, chn_target)
def training_step(self, batch, batch_idx):
    """Run one optimisation step and log the Dice loss of the batch.

    Args:
        batch: Raw batch from the training dataloader; it is unpacked into
            ``(inputs, targets)`` by ``self.prepare_batch``.
        batch_idx: Index of the batch within the epoch (not used here).

    Returns:
        A ``pl.TrainResult`` that minimises the Dice loss and logs it as
        ``"train_loss"`` on every step (progress bar and logger).
    """
    inputs, targets = self.prepare_batch(batch)
    pred = self(inputs)

    criterion = DiceLoss(
        include_background=self.hparams.include_background,
        to_onehot_y=True,
    )
    # Call the module instance instead of ``.forward`` so that any hooks
    # registered on the loss module are executed as well.
    loss = criterion(input=pred, target=targets)

    # NOTE(review): the original author reported that the value shown in the
    # progress bar looked wrong and was unsure about the reduction across the
    # batch; this has not been investigated here.
    result = pl.TrainResult(minimize=loss)
    # Log per step only (no epoch aggregation), reduced with the mean.
    result.log(
        "train_loss",
        loss,
        prog_bar=True,
        sync_dist=True,
        logger=True,
        reduce_fx=torch.mean,
        on_step=True,
        on_epoch=False,
    )
    # Segmentation metrics are deliberately not computed on patches: a patch
    # does not contain all 138 segmentation classes, so some classes would
    # score 0 and make the metrics inaccurate.
    return result
def test_ill_shape(self):
    """Expect an AssertionError when input and target shapes differ."""
    criterion = DiceLoss()
    prediction = torch.ones((1, 2, 3))
    mismatched_target = torch.ones((4, 5, 6))
    with self.assertRaisesRegex(AssertionError, ""):
        criterion.forward(prediction, mismatched_target)
def compute_from_aggregating(self, input, target, if_path: bool, type_as_tensor=None,
                             whether_to_return_img=False, result: pl.EvalResult = None):
    """Predict a whole volume patch by patch and aggregate the label map.

    The volume is preprocessed with the validation transform, split into
    patches by a torchio ``GridSampler``, each patch is run through the
    model, and the arg-max labels are stitched back together with a
    ``GridAggregator``.

    Args:
        input: Image to segment. Passed positionally to ``torchio.Image``
            when ``if_path`` is true, otherwise used as a tensor
            (``input.squeeze()``) and as the dtype/device reference for the
            patches.
        target: Ground-truth label, handled the same way as ``input``.
        if_path: Selects the branch: true builds the images from ``input`` /
            ``target`` directly, false treats them as tensors.
        type_as_tensor: Only used when ``if_path`` is true. Must be a
            mapping whose ``'val_dice'`` entry is a tensor; patches are
            converted with ``type_as`` to match it. Leaving the default
            ``None`` in that branch fails on the subscript.
        whether_to_return_img: Only affects the tensor branch (see Returns).
        result: Unused; kept for interface compatibility.

    Returns:
        * ``if_path`` true: ``(preprocessed image data, aggregated labels,
          preprocessed label data)`` regardless of
          ``whether_to_return_img``.
        * ``if_path`` false and ``whether_to_return_img`` true:
          ``(image data, aggregated labels, label data)``.
        * ``if_path`` false and ``whether_to_return_img`` false:
          ``(aggregated labels, label data, stacked per-patch Dice losses)``.
    """
    transform = get_val_transform()
    # NOTE(review): the original author was unsure whether patch_overlap has
    # constraints relative to patch_size - confirm against torchio.
    patch_overlap = self.hparams.patch_overlap

    if if_path:
        image_subject = torchio.Subject(
            img=torchio.Image(input, type=torchio.INTENSITY))
        label_subject = torchio.Subject(
            img=torchio.Image(target, type=torchio.LABEL))
        preprocessed_img = transform(image_subject)
        preprocessed_label = transform(label_subject)

        grid_sampler = torchio.inference.GridSampler(
            preprocessed_img,
            self.patch_size,
            patch_overlap,
        )
        patch_loader = torch.utils.data.DataLoader(grid_sampler)
        aggregator = torchio.inference.GridAggregator(grid_sampler)

        for patches_batch in patch_loader:
            # Match the reference tensor so the patch ends up on the same
            # device/dtype as the model (originally: "convert to CUDA").
            patch = patches_batch['img'][torchio.DATA].type_as(
                type_as_tensor['val_dice'])
            locations = patches_batch[torchio.LOCATION]
            preds = self(patch)
            labels = preds.argmax(dim=torchio.CHANNELS_DIMENSION, keepdim=True)
            aggregator.add_batch(labels, locations)

        # Per the original author's note, the aggregated tensor is not on
        # the GPU.
        output_tensor = aggregator.get_output_tensor()
        # The original guarded this with `if if_path or whether_to_return_img`,
        # which is always true here; the unreachable 2-tuple return was
        # removed without changing behaviour.
        return preprocessed_img.img.data, output_tensor, preprocessed_label.img.data

    cur_subject = torchio.Subject(
        img=torchio.Image(tensor=input.squeeze(), type=torchio.INTENSITY),
        label=torchio.Image(tensor=target.squeeze(), type=torchio.LABEL))
    preprocessed_subject = transform(cur_subject)

    grid_sampler = torchio.inference.GridSampler(
        preprocessed_subject,
        self.patch_size,
        patch_overlap,
    )
    patch_loader = torch.utils.data.DataLoader(grid_sampler)
    aggregator = torchio.inference.GridAggregator(grid_sampler)

    # The loss configuration does not change between patches, so build it once.
    criterion = DiceLoss(
        include_background=self.hparams.include_background,
        to_onehot_y=True)
    dice_loss = []
    for patches_batch in patch_loader:
        input_tensor = patches_batch['img'][torchio.DATA]
        target_tensor = patches_batch['label'][torchio.DATA]
        # Match the caller's tensor so the patch is on the model's device.
        input_tensor = input_tensor.type_as(input)
        locations = patches_batch[torchio.LOCATION]
        preds_tensor = self(input_tensor)

        # The target patch comes straight from the DataLoader; move it to the
        # prediction's device so the loss does not mix devices.
        target_tensor = target_tensor.to(preds_tensor.device)
        # Call the module instance (not ``.forward``) so hooks are honoured.
        loss = criterion(input=preds_tensor, target=target_tensor)
        dice_loss.append(loss)

        labels = preds_tensor.argmax(dim=torchio.CHANNELS_DIMENSION, keepdim=True)
        aggregator.add_batch(labels, locations)

    # Per the original author's note, the aggregated tensor is not on the GPU.
    output_tensor = aggregator.get_output_tensor()

    # NOTE(review): image/label are taken from `cur_subject` (built before the
    # transform) while the prediction was made on `preprocessed_subject`; if
    # the validation transform changes the data, these may not line up -
    # confirm.
    if whether_to_return_img:
        return cur_subject['img'].data, output_tensor, cur_subject['label'].data
    return output_tensor, cur_subject['label'].data, torch.stack(dice_loss)