def __call__(self, batch_context: ctx.BatchContext,
             task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Run one optimisation step of ``context.model`` on ``self.test_model`` features.

        The training target is a map that is 1 where the arg-max (over dim 1)
        of ``self.test_model``'s output differs from the batch labels and 0
        elsewhere. This map replaces ``batch_context.input['labels']``.

        Args:
            batch_context: Holds the batch ``input`` ('images', 'labels') and
                receives ``output['logits']`` and ``metrics['loss']``.
            task_context: Not used by this step.
            context: Must be a ``ctx.TorchTrainContext``; provides the model,
                optimizer and device.

        Raises:
            ValueError: If ``context`` is not a ``ctx.TorchTrainContext``.
        """
        if not isinstance(context, ctx.TorchTrainContext):
            raise ValueError(
                msg.get_type_error_msg(context, ctx.TorchTrainContext))

        context.optimizer.zero_grad()

        images = batch_context.input['images'].float().to(context.device)
        batch_context.input['images'] = images
        ground_truth = batch_context.input['labels'].long().to(context.device)

        # forward pass of the inspected network; its arg-max over dim 1 is
        # compared against the ground truth to build the target
        inspected_prediction = self.test_model(images).argmax(dim=1)
        error_target = (inspected_prediction != ground_truth).long()
        batch_context.input['labels'] = error_target

        # the trained model consumes the `features` attribute of the inspected
        # network (presumably populated by the forward pass above -- confirm)
        logits = context.model(self.test_model.features)
        batch_context.output['logits'] = logits

        step_loss = self.criterion(logits, error_target)
        step_loss.backward()
        context.optimizer.step()

        batch_context.metrics['loss'] = step_loss.item()
    def __call__(self, batch_context: ctx.BatchContext,
                 task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Forward the images together with label channel 1 through ``context.model``.

        Stores the logits, their softmax over dim 1, and the unmodified
        (type/device converted) labels in ``batch_context.output``.

        Args:
            batch_context: Holds the batch ``input`` ('images', 'labels') and
                receives 'logits', 'probabilities' and 'labels' in ``output``.
            task_context: Not used by this step.
            context: Must be a ``ctx.TorchTrainContext`` or
                ``ctx.TorchTestContext``; provides the model and device.

        Raises:
            ValueError: If ``context`` has neither of the accepted types.
        """
        accepted_types = (ctx.TorchTrainContext, ctx.TorchTestContext)
        if not isinstance(context, accepted_types):
            raise ValueError(
                msg.get_type_error_msg(context, accepted_types))

        images = batch_context.input['images'].float().to(context.device)
        batch_context.input['images'] = images
        labels = batch_context.input['labels'].long().to(context.device)
        batch_context.input['labels'] = labels

        # index 1 along dim 1 of the labels is appended to the images as an
        # extra channel (presumably an already existing prediction -- confirm)
        extra_channel = labels[:, 1].unsqueeze(1).float()
        network_input = torch.cat([images, extra_channel], dim=1)

        logits = context.model(network_input)
        batch_context.output['logits'] = logits
        batch_context.output['probabilities'] = F.softmax(logits, 1)

        # subject_eval needs clean (non-modified) labels
        batch_context.output['labels'] = labels
コード例 #3
0
    def __call__(self, batch_context: ctx.BatchContext,
                 task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Forward the images together with the existing prediction through ``context.model``.

        Stores the logits, their softmax over dim 1, and the existing
        prediction (label channel 1, with a singleton dim at position 1) in
        ``batch_context.output``.

        Args:
            batch_context: Holds the batch ``input`` ('images', 'labels') and
                receives 'logits', 'probabilities' and 'orig_prediction'.
            task_context: Not used by this step.
            context: Must be a ``ctx.TorchTrainContext`` or
                ``ctx.TorchTestContext``; provides the model and device.

        Raises:
            ValueError: If ``context`` has neither of the accepted types.
        """
        accepted_types = (ctx.TorchTrainContext, ctx.TorchTestContext)
        if not isinstance(context, accepted_types):
            raise ValueError(
                msg.get_type_error_msg(context, accepted_types))

        images = batch_context.input['images'].float().to(context.device)
        batch_context.input['images'] = images
        labels = batch_context.input['labels'].long().to(context.device)
        batch_context.input['labels'] = labels

        # channel 1 of the labels is the existing prediction; give it a
        # singleton dim at position 1 so it can be concatenated to the images
        existing_prediction = labels[:, 1].unsqueeze(1)
        network_input = torch.cat(
            [images, existing_prediction.float()], dim=1)

        logits = context.model(network_input)
        batch_context.output['logits'] = logits
        batch_context.output['probabilities'] = F.softmax(logits, 1)

        # add the existing prediction to be reproduced
        batch_context.output['orig_prediction'] = existing_prediction
    def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Run one optimisation step that learns where an existing prediction is wrong.

        Channel 1 of the labels is treated as the prediction and channel 0 as
        the ground truth. The target is 1 where both differ and 0 elsewhere;
        it replaces ``batch_context.input['labels']``.

        Args:
            batch_context: Holds the batch ``input`` ('images', 'labels') and
                receives ``output['logits']`` and ``metrics['loss']``.
            task_context: Not used by this step.
            context: Must be a ``ctx.TorchTrainContext``; provides the model,
                optimizer and device.

        Raises:
            ValueError: If ``context`` is not a ``ctx.TorchTrainContext``.
        """
        if not isinstance(context, ctx.TorchTrainContext):
            raise ValueError(msg.get_type_error_msg(context, ctx.TorchTrainContext))

        context.optimizer.zero_grad()

        images = batch_context.input['images'].float().to(context.device)
        batch_context.input['images'] = images
        stacked_labels = batch_context.input['labels'].long().to(context.device)
        batch_context.input['labels'] = stacked_labels

        prediction = stacked_labels[:, 1, ...]
        ground_truth = stacked_labels[:, 0, ...]

        # update with correct label for the evaluation
        error_target = (prediction != ground_truth).long()
        batch_context.input['labels'] = error_target

        # the prediction is appended to the images as an additional channel
        network_input = torch.cat([images, prediction.unsqueeze(1).float()], dim=1)
        logits = context.model(network_input)
        batch_context.output['logits'] = logits

        step_loss = self.criterion(logits, error_target)
        step_loss.backward()
        context.optimizer.step()

        batch_context.metrics['loss'] = step_loss.item()
コード例 #5
0
    def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Collect the softmax outputs of ``context.model`` and all additional models.

        The per-model probabilities (softmax over dim 1) are stacked along a
        new leading dim and stored as ``output['multi_probabilities']``, with
        ``context.model`` first and ``self.additional_models`` after it.

        Args:
            batch_context: Holds the batch ``input`` ('images') and receives
                ``output['multi_probabilities']``.
            task_context: Not used by this step.
            context: Must be a ``ctx.TorchTrainContext`` or
                ``ctx.TorchTestContext``; provides the model and device.

        Raises:
            ValueError: If ``context`` has neither of the accepted types.
        """
        accepted_types = (ctx.TorchTrainContext, ctx.TorchTestContext)
        if not isinstance(context, accepted_types):
            raise ValueError(msg.get_type_error_msg(context, accepted_types))

        images = batch_context.input['images'].float().to(context.device)
        batch_context.input['images'] = images

        # the context model is always the first ensemble member
        ensemble_members = [context.model, *self.additional_models]
        member_probabilities = [F.softmax(member(images), 1) for member in ensemble_members]

        batch_context.output['multi_probabilities'] = torch.stack(member_probabilities)
    def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Run a test forward pass of a model that returns ``(mean_logits, sigma)``.

        Stores the mean logits, the non-negative sigma and the softmax (over
        dim 1) of the mean logits in ``batch_context.output``.

        Args:
            batch_context: Holds the batch ``input`` ('images') and receives
                'logits', 'sigma' and 'probabilities' in ``output``.
            task_context: Not used by this step.
            context: Must be a ``ctx.TorchTestContext``; provides the model
                and device.

        Raises:
            ValueError: If ``context`` is not a ``ctx.TorchTestContext``.
        """
        if not isinstance(context, ctx.TorchTestContext):
            raise ValueError(msg.get_type_error_msg(context, ctx.TorchTestContext))

        images = batch_context.input['images'].float().to(context.device)
        batch_context.input['images'] = images

        mean_logits, raw_sigma = context.model(images)
        batch_context.output['logits'] = mean_logits

        # the second model output is either log(sigma) or a signed sigma;
        # both are mapped to a non-negative value
        batch_context.output['sigma'] = raw_sigma.exp() if self.is_log_sigma else raw_sigma.abs()

        batch_context.output['probabilities'] = F.softmax(mean_logits, 1)
    def __call__(self, batch_context: ctx.BatchContext,
                 task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Predict with ``self.test_model`` and run ``context.model`` on its features.

        Stores the arg-max prediction of ``self.test_model`` (dim 1 kept) as
        ``output['net_predictions']`` and the softmax (over dim 1) of
        ``context.model``'s output as ``output['probabilities']``.

        Args:
            batch_context: Holds the batch ``input`` ('images') and receives
                'net_predictions' and 'probabilities' in ``output``.
            task_context: Not used by this step.
            context: Must be a ``ctx.TorchTrainContext`` or
                ``ctx.TorchTestContext``; provides the model and device.

        Raises:
            ValueError: If ``context`` has neither of the accepted types.
        """
        accepted_types = (ctx.TorchTrainContext, ctx.TorchTestContext)
        if not isinstance(context, accepted_types):
            raise ValueError(
                msg.get_type_error_msg(context, accepted_types))

        images = batch_context.input['images'].float().to(context.device)
        batch_context.input['images'] = images

        inspected_logits = self.test_model(images)
        batch_context.output['net_predictions'] = inspected_logits.argmax(
            dim=1, keepdim=True)

        # context.model consumes the `features` attribute of the inspected
        # network (presumably populated by the forward pass above -- confirm)
        logits = context.model(self.test_model.features)
        batch_context.output['probabilities'] = F.softmax(logits, 1)
コード例 #8
0
    def __call__(self, batch_context: ctx.BatchContext,
                 task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Derive probabilities and predictions from the logits and evaluate them.

        The softmax over dim 1 of ``output['logits']`` is passed through
        ``th.channel_to_end`` and made contiguous; the prediction is the index
        of the maximum along the last dim. Both are written to
        ``batch_context.output``. ``self.evaluate`` is then called with the
        prediction, probabilities and ``input['labels']`` and is expected to
        fill the passed result dict, including a 'dice' entry that becomes
        ``batch_context.score``.

        Args:
            batch_context: Provides ``output['logits']`` and
                ``input['labels']``; receives the outputs, metrics and score.
            task_context: Not used by this step.
            context: Not used by this step.
        """
        class_probabilities = F.softmax(batch_context.output['logits'], 1)
        class_probabilities = th.channel_to_end(class_probabilities).contiguous()
        # `max` along the last dim returns (values, indices); only the indices are needed
        prediction = class_probabilities.max(-1)[1]

        batch_context.output['probabilities'] = class_probabilities
        batch_context.output['prediction'] = prediction

        evaluation_results = {}
        self.evaluate(
            {
                'prediction': prediction,
                'probabilities': class_probabilities,
                'target': batch_context.input['labels']
            },
            evaluation_results)

        batch_context.metrics.update(evaluation_results)
        batch_context.score = evaluation_results['dice']
    def __call__(self, batch_context: ctx.BatchContext,
                 task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Run one optimisation step for a model that returns ``(mean_logits, sigma)``.

        Args:
            batch_context: Holds the batch ``input`` ('images', 'labels') and
                receives ``metrics['loss']``.
            task_context: Not used by this step.
            context: Must be a ``ctx.TorchTrainContext``; provides the model,
                optimizer and device.

        Raises:
            ValueError: If ``context`` is not a ``ctx.TorchTrainContext``.
        """
        if not isinstance(context, ctx.TorchTrainContext):
            raise ValueError(
                msg.get_type_error_msg(context, ctx.TorchTrainContext))

        context.optimizer.zero_grad()

        images = batch_context.input['images'].float().to(context.device)
        batch_context.input['images'] = images
        targets = batch_context.input['labels'].long().to(context.device)
        batch_context.input['labels'] = targets

        # the criterion receives both model outputs in addition to the targets
        mean_logits, sigma = context.model(images)
        step_loss = self.criterion(mean_logits, sigma, targets)
        step_loss.backward()
        context.optimizer.step()

        batch_context.metrics['loss'] = step_loss.item()
コード例 #10
0
    def __call__(self, batch_context: ctx.BatchContext,
                 task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Run a forward pass of ``context.model`` on the batch images.

        The labels are converted and moved to the device only when
        ``self.has_labels`` is set; the softmax (over dim 1) of the logits is
        stored only when ``self.do_probs`` is set.

        Args:
            batch_context: Holds the batch ``input`` ('images', optionally
                'labels') and receives 'logits' (and optionally
                'probabilities') in ``output``.
            task_context: Not used by this step.
            context: Must be a ``ctx.TorchTrainContext`` or
                ``ctx.TorchTestContext``; provides the model and device.

        Raises:
            ValueError: If ``context`` has neither of the accepted types.
        """
        accepted_types = (ctx.TorchTrainContext, ctx.TorchTestContext)
        if not isinstance(context, accepted_types):
            raise ValueError(
                msg.get_type_error_msg(context, accepted_types))

        batch_input = batch_context.input
        images = batch_input['images'].float().to(context.device)
        batch_input['images'] = images
        if self.has_labels:
            batch_input['labels'] = batch_input['labels'].long().to(
                context.device)

        logits = context.model(images)
        batch_context.output['logits'] = logits

        if self.do_probs:
            batch_context.output['probabilities'] = F.softmax(logits, 1)
コード例 #11
0
    def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Reduce ``output['multi_probabilities']`` to mean probabilities and uncertainty maps.

        Always writes 'probabilities' (mean over dim 0) and 'entropy' (of the
        mean, over dim 1). Writes 'mutual_info' when ``self.do_mi`` is set and
        'variance' when ``self.do_var`` is set. The 'multi_probabilities'
        entry is removed from the output when ``self.remove_multi_probs`` is
        set.

        Args:
            batch_context: Provides ``output['multi_probabilities']``
                (presumably several probability tensors stacked along dim 0
                -- confirm against the producing step) and receives the
                results in ``output``.
            task_context: Not used by this step.
            context: Not used by this step.
        """
        outputs = batch_context.output
        samples = (outputs.pop('multi_probabilities') if self.remove_multi_probs
                   else outputs['multi_probabilities'])

        mean_probabilities = samples.mean(dim=0)
        outputs['probabilities'] = mean_probabilities

        predictive_entropy = th.entropy(mean_probabilities, dim=1, keepdim=True)
        outputs['entropy'] = predictive_entropy

        if self.do_mi:
            # entropy of the mean minus the mean of the per-sample entropies
            per_sample_entropy = th.entropy(samples, dim=2, keepdim=True)
            outputs['mutual_info'] = predictive_entropy - per_sample_entropy.mean(dim=0)

        if self.do_var:
            # as done by the bayesian segnet -> not sure if best solution
            outputs['variance'] = samples.var(dim=0).mean(dim=1, keepdim=True)
コード例 #12
0
    def __call__(self, batch_context: ctx.BatchContext, task_context: ctx.TaskContext, context: ctx.Context) -> None:
        """Sample ``self.mc_steps`` stochastic forward passes with dropout enabled.

        First a regular forward pass is stored as ``output['ws_probabilities']``
        (softmax over dim 1). Dropout is then switched to train mode via
        ``th.set_dropout_mode``, the model is evaluated ``self.mc_steps``
        times, and the stacked softmax outputs are stored as
        ``output['multi_probabilities']``. Dropout is always switched back
        off afterwards, even if one of the sampling passes raises.

        Args:
            batch_context: Holds the batch ``input`` ('images') and receives
                'ws_probabilities' and 'multi_probabilities' in ``output``.
            task_context: Not used by this step.
            context: Must be a ``ctx.TorchTrainContext`` or
                ``ctx.TorchTestContext``; provides the model and device.

        Raises:
            ValueError: If ``context`` has neither of the accepted types.
        """
        if not isinstance(context, (ctx.TorchTrainContext, ctx.TorchTestContext)):
            raise ValueError(msg.get_type_error_msg(context, (ctx.TorchTrainContext, ctx.TorchTestContext)))

        images = batch_context.input['images'].float().to(context.device)
        batch_context.input['images'] = images

        # weight scaling part, just for comparison
        ws_logits = context.model(images)
        batch_context.output['ws_probabilities'] = F.softmax(ws_logits, 1)

        th.set_dropout_mode(context.model, is_train=True)
        try:
            # mc part
            mc_probabilities = torch.stack(
                [F.softmax(context.model(images), 1) for _ in range(self.mc_steps)])
        finally:
            # reset to eval for next batch; done in `finally` so a failing
            # forward pass cannot leave the model with dropout enabled
            th.set_dropout_mode(context.model, is_train=False)

        batch_context.output['multi_probabilities'] = mc_probabilities
コード例 #13
0
 def __call__(self, batch_context: ctx.BatchContext,
              task_context: ctx.TaskContext, context: ctx.Context) -> None:
     """Copy the input labels to the output with a singleton dim inserted at position 1.

     Args:
         batch_context: Provides ``input['labels']`` and receives
             ``output['labels']``.
         task_context: Not used by this step.
         context: Not used by this step.
     """
     input_labels = batch_context.input['labels']
     # re-add previously removed dim
     batch_context.output['labels'] = input_labels.unsqueeze(1)
コード例 #14
0
 def __call__(self, batch_context: ctx.BatchContext,
              task_context: ctx.TaskContext, context: ctx.Context) -> None:
     """Store the softmax of the logits and the labels with a re-added dim 1.

     Args:
         batch_context: Provides ``output['logits']`` and ``input['labels']``;
             receives 'probabilities' and 'labels' in ``output``.
         task_context: Not used by this step.
         context: Not used by this step.
     """
     outputs = batch_context.output
     outputs['probabilities'] = F.softmax(outputs['logits'], 1)

     input_labels = batch_context.input['labels']
     # re-add previously removed dim
     outputs['labels'] = input_labels.unsqueeze(1)