def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext,
             context: ctx.Context) -> None:
    """Run the model on the images concatenated with an existing prediction.

    The label tensor entry at index 1 along dimension 1 is used as the prediction
    to be reproduced and is appended to the images as one extra input channel.
    Logits, softmax probabilities, that prediction ('orig_prediction') and the
    unmodified labels are stored in ``batch_context.output``.

    Raises:
        ValueError: if ``context`` is neither a TorchTrainContext nor a TorchTestContext.
    """
    allowed_contexts = (ctx.TorchTrainContext, ctx.TorchTestContext)
    if not isinstance(context, allowed_contexts):
        raise ValueError(msg.get_type_error_msg(context, allowed_contexts))

    device = context.device
    inputs = batch_context.input
    inputs['images'] = inputs['images'].float().to(device)
    inputs['labels'] = inputs['labels'].long().to(device)

    existing_pred = inputs['labels'][:, 1]
    pred_channel = existing_pred.unsqueeze(1)
    network_input = torch.cat([inputs['images'], pred_channel.float()], dim=1)

    logits = context.model(network_input)
    outputs = batch_context.output
    outputs['logits'] = logits
    outputs['probabilities'] = F.softmax(logits, 1)
    # the existing prediction that the model should reproduce
    outputs['orig_prediction'] = pred_channel
    # subject_eval needs clean (non-modified) labels
    outputs['labels'] = inputs['labels']
def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext,
             context: ctx.Context) -> None:
    """Perform one optimisation step of a model that detects prediction errors.

    Along dimension 1 of the labels, index 0 is the ground truth and index 1 is a
    prediction. The training target is 1 wherever the two disagree and 0
    elsewhere; it replaces ``batch_context.input['labels']`` so that later
    evaluation sees it. The prediction is appended to the images as an extra
    input channel. The scalar loss is recorded in ``batch_context.metrics['loss']``.

    Raises:
        ValueError: if ``context`` is not a TorchTrainContext.
    """
    if not isinstance(context, ctx.TorchTrainContext):
        raise ValueError(msg.get_type_error_msg(context, ctx.TorchTrainContext))

    context.optimizer.zero_grad()

    device = context.device
    images = batch_context.input['images'].float().to(device)
    all_labels = batch_context.input['labels'].long().to(device)
    batch_context.input['images'] = images
    batch_context.input['labels'] = all_labels

    prediction = all_labels[:, 1, ...]
    ground_truth = all_labels[:, 0, ...]
    error_target = (prediction != ground_truth).long()
    # update with correct label for the evaluation
    batch_context.input['labels'] = error_target

    net_input = torch.cat([images, prediction.unsqueeze(1).float()], dim=1)
    logits = context.model(net_input)
    batch_context.output['logits'] = logits

    loss = self.criterion(logits, error_target)
    loss.backward()
    context.optimizer.step()
    batch_context.metrics['loss'] = loss.item()
def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext,
             context: ctx.Context) -> None:
    """Perform one optimisation step of ``context.model`` on features of ``self.test_model``.

    ``self.test_model`` is run on the images. The training target is 1 where its
    argmax prediction (over dimension 1) differs from the ground-truth labels and
    0 elsewhere; this target replaces ``batch_context.input['labels']``.
    ``context.model`` is then trained on ``self.test_model.features`` and the
    scalar loss is recorded in ``batch_context.metrics['loss']``.

    Raises:
        ValueError: if ``context`` is not a TorchTrainContext.
    """
    if not isinstance(context, ctx.TorchTrainContext):
        raise ValueError(msg.get_type_error_msg(context, ctx.TorchTrainContext))

    context.optimizer.zero_grad()

    images = batch_context.input['images'].float().to(context.device)
    batch_context.input['images'] = images
    gt = batch_context.input['labels'].long().to(context.device)

    # Only context.optimizer is stepped here, so self.test_model acts as a fixed
    # provider of targets and input features. Running it without autograd keeps
    # loss.backward() from traversing its graph, accumulating never-zeroed
    # gradients in its parameters, and holding its activations in memory.
    # NOTE(review): confirm test_model's parameters are not part of context.optimizer.
    with torch.no_grad():
        net1_logits = self.test_model(images)
        features = self.test_model.features

    net_prediction = net1_logits.argmax(dim=1)
    batch_context.input['labels'] = (net_prediction != gt).long()

    logits = context.model(features)
    batch_context.output['logits'] = logits

    loss = self.criterion(logits, batch_context.input['labels'])
    loss.backward()
    context.optimizer.step()
    batch_context.metrics['loss'] = loss.item()
def _on_test_subject_end(subject_context: ctx.SubjectContext, task_context: ctx.TaskContext,
                         context: ctx.TestContext):
    """Write a subject's foreground confidence and original prediction to disk.

    Reads 'probabilities', 'orig_prediction', 'properties' and 'subject' from
    ``subject_context.subject_data``. Converts them to SimpleITK images and writes
    ``<subject>_confidence.nii.gz`` and ``<subject>_prediction.nii.gz`` into
    ``context.test_dir``.

    Raises:
        ValueError: if ``context`` is not a TorchTestContext.
    """
    if not isinstance(context, ctx.TorchTestContext):
        raise ValueError(msg.get_type_error_msg(context, ctx.TorchTestContext))

    data = subject_context.subject_data
    # index 1 of the last axis is the foreground class
    foreground_confidence = data['probabilities'][..., 1]
    original_prediction = data['orig_prediction']
    image_props = data['properties']  # type: conversion.ImageProperties

    to_image = conversion.NumpySimpleITKImageBridge.convert
    confidence_img = to_image(foreground_confidence, image_props)
    prediction_img = to_image(original_prediction, image_props)

    subject = data['subject']
    out_dir = context.test_dir
    sitk.WriteImage(confidence_img, os.path.join(out_dir, '{}_confidence.nii.gz'.format(subject)))
    sitk.WriteImage(prediction_img, os.path.join(out_dir, '{}_prediction.nii.gz'.format(subject)))
def _on_test_subject_end(subject_context: ctx.SubjectContext, task_context: ctx.TaskContext,
                         context: ctx.TestContext):
    """Write a subject's foreground probability map and hard prediction to disk.

    The hard prediction is the argmax over the last axis of the subject's
    'probabilities', cast to uint8. The foreground probability is index 1 of that
    axis. Both are written into ``context.test_dir`` as
    ``<subject>_probabilities.nii.gz`` and ``<subject>_prediction.nii.gz``.

    Raises:
        ValueError: if ``context`` is not a TorchTestContext.
    """
    if not isinstance(context, ctx.TorchTestContext):
        raise ValueError(msg.get_type_error_msg(context, ctx.TorchTestContext))

    data = subject_context.subject_data
    class_probs = data['probabilities']
    label_map = np.argmax(class_probs, axis=-1).astype(np.uint8)
    foreground_probs = class_probs[..., 1]  # foreground class
    image_props = data['properties']  # type: conversion.ImageProperties

    to_image = conversion.NumpySimpleITKImageBridge.convert
    probability_img = to_image(foreground_probs, image_props)
    prediction_img = to_image(label_map, image_props)

    subject = data['subject']
    out_dir = context.test_dir
    sitk.WriteImage(probability_img, os.path.join(out_dir, '{}_probabilities.nii.gz'.format(subject)))
    sitk.WriteImage(prediction_img, os.path.join(out_dir, '{}_prediction.nii.gz'.format(subject)))
def on_test_subject_end(self, subject_context: ctx.SubjectContext, task_context: ctx.TaskContext,
                        context: ctx.TestContext):
    """Hand the per-subject image writing to ``thread.do_work`` in the background.

    The work itself is ``WriteHook._on_test_subject_end``, called with the same
    three context arguments.

    Raises:
        ValueError: if ``context`` is not a TorchTestContext.
    """
    if isinstance(context, ctx.TorchTestContext):
        writer = WriteHook._on_test_subject_end
        thread.do_work(writer, subject_context, task_context, context, in_background=True)
        return
    raise ValueError(msg.get_type_error_msg(context, ctx.TorchTestContext))
def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext,
             context: ctx.Context) -> None:
    """Forward the images plus an existing prediction channel through the model.

    The label entry at index 1 along dimension 1 is appended to the images as one
    extra input channel. Logits and softmax probabilities (over dimension 1) are
    stored in ``batch_context.output``.

    Raises:
        ValueError: if ``context`` is neither a TorchTrainContext nor a TorchTestContext.
    """
    accepted = (ctx.TorchTrainContext, ctx.TorchTestContext)
    if not isinstance(context, accepted):
        raise ValueError(msg.get_type_error_msg(context, accepted))

    device = context.device
    images = batch_context.input['images'].float().to(device)
    labels = batch_context.input['labels'].long().to(device)
    batch_context.input['images'] = images
    batch_context.input['labels'] = labels

    extra_channel = labels[:, 1].unsqueeze(1).float()
    logits = context.model(torch.cat([images, extra_channel], dim=1))

    batch_context.output['logits'] = logits
    batch_context.output['probabilities'] = F.softmax(logits, 1)
def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext,
             context: ctx.Context) -> None:
    """Collect softmax probabilities from ``context.model`` and every additional model.

    The per-model probabilities are stacked along a new leading dimension, with
    ``context.model`` at index 0 followed by ``self.additional_models`` in order.
    The result is stored as ``batch_context.output['multi_probabilities']``.

    Raises:
        ValueError: if ``context`` is neither a TorchTrainContext nor a TorchTestContext.
    """
    accepted = (ctx.TorchTrainContext, ctx.TorchTestContext)
    if not isinstance(context, accepted):
        raise ValueError(msg.get_type_error_msg(context, accepted))

    images = batch_context.input['images'].float().to(context.device)
    batch_context.input['images'] = images

    members = [context.model, *self.additional_models]
    per_model_probs = [F.softmax(member(images), 1) for member in members]
    batch_context.output['multi_probabilities'] = torch.stack(per_model_probs)
def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext,
             context: ctx.Context) -> None:
    """Test-time forward of a model that returns (mean logits, sigma).

    Stores the mean logits, a non-negative sigma and softmax probabilities of the
    mean logits in ``batch_context.output``. The sigma is ``exp`` of the output when
    ``self.is_log_sigma`` is set, otherwise its absolute value.

    Raises:
        ValueError: if ``context`` is not a TorchTestContext.
    """
    if not isinstance(context, ctx.TorchTestContext):
        raise ValueError(msg.get_type_error_msg(context, ctx.TorchTestContext))

    images = batch_context.input['images'].float().to(context.device)
    batch_context.input['images'] = images

    mean_logits, raw_sigma = context.model(images)
    out = batch_context.output
    out['logits'] = mean_logits
    out['sigma'] = raw_sigma.exp() if self.is_log_sigma else raw_sigma.abs()
    out['probabilities'] = F.softmax(mean_logits, 1)
def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext,
             context: ctx.Context) -> None:
    """Forward pass of ``context.model`` on the features of ``self.test_model``.

    Stores the argmax of ``self.test_model``'s logits over dimension 1 (keeping that
    dimension) as 'net_predictions'. Stores the softmax of ``context.model``'s
    logits, computed on ``self.test_model.features``, as 'probabilities'.

    Raises:
        ValueError: if ``context`` is neither a TorchTrainContext nor a TorchTestContext.
    """
    accepted = (ctx.TorchTrainContext, ctx.TorchTestContext)
    if not isinstance(context, accepted):
        raise ValueError(msg.get_type_error_msg(context, accepted))

    images = batch_context.input['images'].float().to(context.device)
    batch_context.input['images'] = images

    reference_logits = self.test_model(images)
    batch_context.output['net_predictions'] = reference_logits.argmax(dim=1, keepdim=True)

    # self.test_model.features is read after the forward call above
    logits = context.model(self.test_model.features)
    batch_context.output['probabilities'] = F.softmax(logits, 1)
def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext,
             context: ctx.Context) -> None:
    """Perform one optimisation step of a model that returns (mean logits, sigma).

    ``self.criterion`` is called with the mean logits, sigma and the labels. The
    scalar loss is recorded in ``batch_context.metrics['loss']``.

    Raises:
        ValueError: if ``context`` is not a TorchTrainContext.
    """
    if not isinstance(context, ctx.TorchTrainContext):
        raise ValueError(msg.get_type_error_msg(context, ctx.TorchTrainContext))

    optimizer = context.optimizer
    optimizer.zero_grad()

    device = context.device
    images = batch_context.input['images'].float().to(device)
    labels = batch_context.input['labels'].long().to(device)
    batch_context.input['images'] = images
    batch_context.input['labels'] = labels

    mean_logits, sigma = context.model(images)
    loss = self.criterion(mean_logits, sigma, labels)
    loss.backward()
    optimizer.step()

    batch_context.metrics['loss'] = loss.item()
def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext,
             context: ctx.Context) -> None:
    """Plain forward pass of ``context.model`` on the batch images.

    Labels are moved to the device (as long) only when ``self.has_labels`` is set.
    The logits are always stored. Softmax probabilities over dimension 1 are
    stored only when ``self.do_probs`` is set.

    Raises:
        ValueError: if ``context`` is neither a TorchTrainContext nor a TorchTestContext.
    """
    accepted = (ctx.TorchTrainContext, ctx.TorchTestContext)
    if not isinstance(context, accepted):
        raise ValueError(msg.get_type_error_msg(context, accepted))

    device = context.device
    inputs = batch_context.input
    inputs['images'] = inputs['images'].float().to(device)
    if self.has_labels:
        inputs['labels'] = inputs['labels'].long().to(device)

    logits = context.model(inputs['images'])
    batch_context.output['logits'] = logits

    if not self.do_probs:
        return
    batch_context.output['probabilities'] = F.softmax(logits, 1)
def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext,
             context: ctx.Context) -> None:
    """Compute weight-scaling and Monte-Carlo dropout probabilities for one batch.

    First runs ``context.model`` once in its current mode and stores the softmax
    as 'ws_probabilities'. Then enables dropout with
    ``th.set_dropout_mode(..., is_train=True)`` and runs ``self.mc_steps`` forward
    passes. The stacked softmax outputs (step index first) are stored as
    'multi_probabilities'. Dropout is switched back off afterwards, even if a
    forward pass raises.

    Raises:
        ValueError: if ``context`` is neither a TorchTrainContext nor a TorchTestContext.
    """
    if not isinstance(context, (ctx.TorchTrainContext, ctx.TorchTestContext)):
        raise ValueError(msg.get_type_error_msg(context, (ctx.TorchTrainContext, ctx.TorchTestContext)))

    images = batch_context.input['images'].float().to(context.device)
    batch_context.input['images'] = images

    # weight scaling part, just for comparison
    ws_logits = context.model(images)
    batch_context.output['ws_probabilities'] = F.softmax(ws_logits, 1)

    # mc part
    th.set_dropout_mode(context.model, is_train=True)
    try:
        mc_probabilities = [F.softmax(context.model(images), 1) for _ in range(self.mc_steps)]
    finally:
        # Reset to eval for the next batch. This runs in `finally` so that a failed
        # forward pass (e.g. out of memory) cannot leave dropout active for all
        # subsequent batches.
        th.set_dropout_mode(context.model, is_train=False)

    batch_context.output['multi_probabilities'] = torch.stack(mc_probabilities)