Esempio n. 1
0
 def store_examples(engine: Engine):
     """Publish the example splits on the engine state and announce them.

     ``train_examples`` and ``val_examples`` come from the enclosing scope and are
     stored in ``engine.state.examples`` keyed by split name; afterwards an info
     message is logged and ``CustomEvents.EXAMPLE_DATA_READY`` is fired.
     """
     examples_by_split = {
         "training": train_examples,
         "validation": val_examples,
     }
     engine.state.examples = examples_by_split
     engine.logger.info("Example data ready")
     engine.fire_event(CustomEvents.EXAMPLE_DATA_READY)
Esempio n. 2
0
    def _update(engine: Engine, batch: Sequence[torch.Tensor]):
        """Run one truncated-BPTT training step over a whole sequence batch.

        ``x`` and ``y`` are split along ``dim`` into chunks of ``tbtt_step``
        elements and processed chunk by chunk: forward pass (the hidden state
        returned by the model is fed back in for the next chunk after going
        through ``_detach_hidden``), loss, backward pass and optimizer step.
        ``Tbptt_Events.TIME_ITERATION_STARTED`` / ``TIME_ITERATION_COMPLETED``
        bracket every chunk.

        Args:
            engine: engine running this update; ``engine.state.output`` is set
                to the loss of each chunk as soon as it is known.
            batch: ``(x, y)`` pair of tensors, split along ``dim``.

        Returns:
            float: the mean of the per-chunk losses.
        """
        loss_list = []
        hidden = None

        x, y = batch
        for batch_t in zip(x.split(tbtt_step, dim=dim),
                           y.split(tbtt_step, dim=dim)):
            x_t, y_t = prepare_batch(batch_t,
                                     device=device,
                                     non_blocking=non_blocking)
            # Fire event for start of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)
            # Forward, backward and optimizer step on this time chunk
            model.train()
            optimizer.zero_grad()
            if hidden is None:
                y_pred_t, hidden = model(x_t)
            else:
                # presumably cuts the autograd graph at the chunk boundary so
                # back-propagation stays truncated — see _detach_hidden
                hidden = _detach_hidden(hidden)
                y_pred_t, hidden = model(x_t, hidden)
            loss_t = loss_fn(y_pred_t, y_t)
            loss_t.backward()
            optimizer.step()

            # Read the scalar once: every ``.item()`` call forces a device sync.
            loss_value = loss_t.item()
            # Setting state of engine for consistent behaviour
            engine.state.output = loss_value
            loss_list.append(loss_value)

            # Fire event for end of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

        # return average loss over the time splits
        return sum(loss_list) / len(loss_list)
Esempio n. 3
0
    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
        """
        Ensemble-evaluation callback run by the Ignite Engine for one iteration.

        ``engine.state.output`` starts as ``{self.keys["IMAGE"]: inputs,
        self.keys["LABEL"]: targets}``. Each network in ``self.networks`` is then
        run through ``self.inferer`` inside ``self.mode(network)`` (and under
        CUDA autocast when ``self.amp`` is set); its prediction is stored under
        the entry of ``self.pred_keys`` with the same index.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Returns:
            ``engine.state.output``.

        Raises:
            ValueError: When ``batchdata`` is None.
        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)  # type: ignore
        if len(batch) != 2:
            inputs, targets, args, kwargs = batch
        else:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}

        # put iteration outputs into engine.state
        engine.state.output = {self.keys["IMAGE"]: inputs, self.keys["LABEL"]: targets}

        def _add_prediction(index: int, net) -> None:
            # output was set to a dict just above; this guard appears to be
            # there for the type checker
            if not isinstance(engine.state.output, dict):
                return
            engine.state.output.update(
                {self.pred_keys[index]: self.inferer(inputs, net, *args, **kwargs)}
            )

        for index, network in enumerate(self.networks):
            with self.mode(network):
                if not self.amp:
                    _add_prediction(index, network)
                else:
                    with torch.cuda.amp.autocast():
                        _add_prediction(index, network)
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
Esempio n. 4
0
 def predict_on_examples(engine: Engine):
     """Attach a model prediction to every stored example.

     Switches ``model`` to eval mode (it is not switched back here), runs it on
     each example input moved to ``args.device`` without gradient tracking and
     replaces every ``(x, y)`` entry of ``engine.state.examples`` with
     ``(x, y, y_pred)`` where ``y_pred`` lives on the CPU. Then logs a message and
     fires ``CustomEvents.EXAMPLE_PREDICTIONS_READY``.
     """
     model.eval()
     examples = engine.state.examples
     for tag, (inputs, target) in examples.items():
         with torch.no_grad():
             prediction = model(inputs.to(device=args.device)).detach().cpu()
         # overwriting an existing key does not resize the dict, so this is
         # safe while iterating over it
         examples[tag] = (inputs, target, prediction)
     engine.logger.info("Example predictions ready")
     engine.fire_event(CustomEvents.EXAMPLE_PREDICTIONS_READY)
Esempio n. 5
0
def train_function(
    config: Any,
    engine: Engine,
    batch: Any,
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    optimizer: Optimizer,
    device: torch.device,
):
    """Model training step.

    Runs one forward / backward / optimizer step on ``batch`` and publishes the
    resulting loss on the engine.

    Parameters
    ----------
    config
        config object; ``config.use_amp`` toggles ``autocast`` for the forward
        pass and the loss computation
    engine
        Engine instance; ``engine.state.backward_completed`` and
        ``engine.state.optim_step_completed`` are incremented,
        ``TrainEvents.BACKWARD_COMPLETED`` / ``TrainEvents.OPTIM_STEP_COMPLETED``
        are fired, and ``engine.state.metrics`` is replaced by a new dict with
        ``"epoch"`` and ``"train_loss"``
    batch
        batch in current iteration; ``batch[0]`` holds the inputs and
        ``batch[1]`` the targets
    model
        nn.Module model
    loss_fn
        nn.Module loss
    optimizer
        torch optimizer
    device
        device to use for training

    Returns
    -------
    float
        the training loss of this step as a Python number (``loss.item()``)
    """

    model.train()

    samples = batch[0].to(device, non_blocking=True)
    targets = batch[1].to(device, non_blocking=True)

    # NOTE(review): mixed precision is enabled without a GradScaler; with a
    # float16 autocast small gradients may underflow — confirm this is intended.
    with autocast(enabled=config.use_amp):
        outputs = model(samples)
        loss = loss_fn(outputs, targets)

    loss.backward()
    engine.state.backward_completed += 1
    engine.fire_event(TrainEvents.BACKWARD_COMPLETED)

    optimizer.step()
    engine.state.optim_step_completed += 1
    engine.fire_event(TrainEvents.OPTIM_STEP_COMPLETED)

    # gradients are cleared after the step, ready for the next iteration
    optimizer.zero_grad()

    loss_value = loss.item()
    engine.state.metrics = {"epoch": engine.state.epoch, "train_loss": loss_value}
    return loss_value
 def __call__(self, engine: Engine):
     """Publish statistics for every episode finished since the previous call.

     For each ``(reward, steps)`` pair from ``self._exp_source.pop_rewards_steps()``
     the episode counter on ``engine.state`` is incremented, the raw reward and
     step count are stored on the state and in ``engine.state.metrics``, the
     smoothed metrics are refreshed and ``self.Events.EPISODE_COMPLETED`` is
     fired. ``self.Events.BOUND_REWARD_REACHED`` follows when a bound is
     configured and ``metrics['avg_reward']`` has reached it.
     """
     def _bound_reached():
         # without a configured bound there is nothing to report
         if self._bound_avg_reward is None:
             return False
         return engine.state.metrics['avg_reward'] >= self._bound_avg_reward

     for reward, steps in self._exp_source.pop_rewards_steps():
         engine.state.episode = getattr(engine.state, "episode", 0) + 1
         engine.state.episode_reward, engine.state.episode_steps = reward, steps
         engine.state.metrics.update({'reward': reward, 'steps': steps})
         self._update_smoothed_metrics(engine, reward, steps)
         engine.fire_event(self.Events.EPISODE_COMPLETED)
         if _bound_reached():
             engine.fire_event(self.Events.BOUND_REWARD_REACHED)
Esempio n. 7
0
    def _iteration(
            self, engine: Engine,
            batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Ensemble-evaluation callback executed by the Ignite Engine for one iteration.

        Every network is run in eval mode (under CUDA autocast when ``self.amp``
        is set) and the returned dictionary, also stored as
        ``engine.state.output``, holds:
            - IMAGE: input image tensor, already on the target device.
            - LABEL: label tensor for the image, already on the target device.
            - pred_keys[i]: prediction of ``self.networks[i]`` for i = 0 .. N.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) != 2:
            inputs, targets, args, kwargs = batch
        else:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}

        output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
        # the same dict is exposed on the engine state and filled in below
        engine.state.output = output

        def _predict(net):
            return self.inferer(inputs, net, *args, **kwargs)

        for idx, network in enumerate(self.networks):
            with eval_mode(network):
                pred_key = self.pred_keys[idx]
                if not self.amp:
                    output[pred_key] = _predict(network)
                else:
                    with torch.cuda.amp.autocast():
                        output[pred_key] = _predict(network)
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)

        return output
Esempio n. 8
0
    def __call__(self, engine: Engine):
        """Record each finished episode and fire the matching episode events.

        Per ``(reward, steps)`` pair popped from the experience source:
        the episode counter and raw reward/steps are stored on ``engine.state``
        (and in ``engine.state.metrics``), smoothed metrics are refreshed, and
        ``EPISODE_COMPLETED`` / ``BOUND_REWARD_REACHED`` / ``BEST_REWARD_REACHED``
        are fired when their conditions hold.
        """
        # Right after the buffer is initialised, pop_rewards_steps() may return
        # several finished episodes at once; during regular training it usually
        # yields a single pair once an episode has ended.
        for reward, steps in self._exp_source.pop_rewards_steps():
            engine.state.episode = getattr(engine.state, "episode", 0) + 1
            engine.state.episode_reward = reward
            engine.state.episode_steps = steps
            engine.state.metrics.update({'reward': reward, 'steps': steps})
            self._update_smoothed_metrics(engine, reward, steps)

            # Episode-end events are normally not subsampled, so this usually fires.
            every = self._subsample_end_of_episode
            if every is None or engine.state.episode % every == 0:
                engine.fire_event(EpisodeEvents.EPISODE_COMPLETED)

            bound = self._bound_avg_reward
            if bound is not None and engine.state.metrics['avg_reward'] >= bound:
                engine.fire_event(EpisodeEvents.BOUND_REWARD_REACHED)

            # The first average only initialises the record; later improvements
            # fire an event and raise the record.
            best = self._best_avg_reward
            if best is None:
                self._best_avg_reward = engine.state.metrics['avg_reward']
            elif best < engine.state.metrics['avg_reward']:
                engine.fire_event(EpisodeEvents.BEST_REWARD_REACHED)
                self._best_avg_reward = engine.state.metrics['avg_reward']
Esempio n. 9
0
    def test_compute(self, input_params, expected_avg, details_shape):
        """MeanDice attached to an engine reports the expected epoch average and
        per-item detail shape after two iterations."""
        dice_metric = MeanDice(**input_params)

        # minimal engine whose process function does nothing
        engine = Engine(lambda engine, batch: None)
        dice_metric.attach(engine=engine, name="mean_dice")

        def _run_iteration(label):
            # predictions are passed as a list of channel-first tensors
            engine.state.output = {
                "pred": [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])],
                "label": label,
            }
            engine.fire_event(Events.ITERATION_COMPLETED)

        _run_iteration(torch.Tensor([[[0], [1]], [[0], [1]]]))
        _run_iteration(torch.Tensor([[[0], [1]], [[1], [0]]]))
        engine.fire_event(Events.EPOCH_COMPLETED)

        torch.testing.assert_allclose(engine.state.metrics["mean_dice"], expected_avg)
        details = engine.state.metric_details["mean_dice"]
        self.assertTupleEqual(tuple(details.shape), details_shape)
Esempio n. 10
0
    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
        """
        Supervised-evaluation callback executed by the Ignite Engine for one iteration.

        The network runs inside ``self.mode(self.network)`` (and under CUDA
        autocast when ``self.amp`` is set). The returned dictionary, also stored
        as ``engine.state.output``, holds:
            - IMAGE: input image tensor, already on the target device.
            - LABEL: label tensor for the image, already on the target device.
            - PRED: output of ``self.inferer`` for the network.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)  # type: ignore
        if len(batch) != 2:
            inputs, targets, args, kwargs = batch
        else:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}

        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}  # type: ignore

        def _forward():
            return self.inferer(inputs, self.network, *args, **kwargs)

        # execute forward computation
        with self.mode(self.network):
            if not self.amp:
                engine.state.output[Keys.PRED] = _forward()  # type: ignore
            else:
                with torch.cuda.amp.autocast():
                    engine.state.output[Keys.PRED] = _forward()  # type: ignore
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
Esempio n. 11
0
    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Evaluation callback for one iteration, with optional latent outputs.

        ``engine.state.output`` receives the image and label under
        ``self.keys["IMAGE"]`` / ``self.keys["LABEL"]``. When ``self.network_latent``
        is set, the inferer is run on it and its three results are stored under
        ``self.keys["PRED"]``, ``["FORWARD"]`` and ``["BACKWARD"]``; otherwise
        ``self.network`` is used and only ``["PRED"]`` is filled. Inference runs
        inside ``self.mode(self.network)`` and under CUDA autocast when
        ``self.amp`` is set.

        Raises:
            ValueError: When ``batchdata`` is None.
        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) != 2:
            inputs, targets, args, kwargs = batch
        else:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}

        # put iteration outputs into engine.state
        engine.state.output = {self.keys["IMAGE"]: inputs, self.keys["LABEL"]: targets}

        def _forward() -> None:
            if not self.network_latent:
                engine.state.output[self.keys["PRED"]] = self.inferer(inputs, self.network, *args, **kwargs)
                return
            pred, forward_feats, backward_feats = self.inferer(inputs, self.network_latent, *args, **kwargs)
            engine.state.output[self.keys["PRED"]] = pred
            engine.state.output[self.keys["FORWARD"]] = forward_feats
            engine.state.output[self.keys["BACKWARD"]] = backward_feats

        # execute forward computation
        with self.mode(self.network):
            if self.amp:
                with torch.cuda.amp.autocast():
                    _forward()
            else:
                _forward()
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
Esempio n. 12
0
    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
        """
        Supervised-training callback executed by the Ignite Engine for one iteration.

        Performs one optimisation step: forward pass, loss, backward pass and
        optimizer update (through ``self.scaler`` when ``self.amp`` is set and a
        scaler is available). The returned dictionary, also stored as
        ``engine.state.output``, holds:
            - IMAGE: input image tensor, already on the target device.
            - LABEL: label tensor for the image, already on the target device.
            - PRED: output of ``self.inferer`` for the network.
            - LOSS: mean of ``self.loss_function(PRED, LABEL)``.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) != 2:
            inputs, targets, args, kwargs = batch
        else:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}

        output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
        # the same dict object is exposed on the engine state and filled in below
        engine.state.output = output

        def _forward_and_loss():
            output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
            engine.fire_event(IterationEvents.FORWARD_COMPLETED)
            # PRED is read back after the event so handlers may replace it
            output[Keys.LOSS] = self.loss_function(output[Keys.PRED], targets).mean()
            engine.fire_event(IterationEvents.LOSS_COMPLETED)

        self.network.train()
        self.optimizer.zero_grad()
        use_scaler = self.amp and self.scaler is not None
        if not use_scaler:
            _forward_and_loss()
            output[Keys.LOSS].backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            self.optimizer.step()
        else:
            with torch.cuda.amp.autocast():
                _forward_and_loss()
            self.scaler.scale(output[Keys.LOSS]).backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        engine.fire_event(IterationEvents.OPTIMIZER_COMPLETED)

        return output
Esempio n. 13
0
    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
        """Run one supervised training step (forward, loss, backward, optimizer step).

        ``engine.state.output`` receives the image, label, prediction and loss
        under ``self.keys["IMAGE"]``, ``["LABEL"]``, ``["PRED"]`` and
        ``["LOSS"]``. The loss is ``self.loss_function(...).mean()``; when
        ``self.ensure_dims`` is set, prediction and target are first passed
        through ``ensure_same_dim``. With ``self.amp`` and a ``self.scaler`` the
        forward pass runs under CUDA autocast and backward/step go through the
        scaler.

        Args:
            engine: Ignite Engine running this iteration.
            batchdata: input data for this iteration, unpacked by ``self.prepare_batch``.

        Returns:
            ``engine.state.output``.

        Raises:
            StrixException: When ``batchdata`` is None (no data reached the trainer).
        """
        if batchdata is None:
            raise StrixException(
                "No data were fed into the Trainer engine. "
                "Consider the possibility that Transforms did not succeed or "
                "there is a problem with your dataset.")
        batch = self.prepare_batch(batchdata, engine.state.device,
                                   engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, targets, args, kwargs = batch
        # put iteration outputs into engine.state
        engine.state.output = {
            self.keys["IMAGE"]: inputs,
            self.keys["LABEL"]: targets
        }

        def _compute_pred_loss():
            engine.state.output[self.keys["PRED"]] = self.inferer(
                inputs, self.network, *args, **kwargs)
            engine.fire_event(IterationEvents.FORWARD_COMPLETED)

            if self.ensure_dims:
                engine.state.output[self.keys["LOSS"]] = self.loss_function(
                    *ensure_same_dim(engine.state.output[self.keys["PRED"]],
                                     targets)).mean()
            else:
                engine.state.output[self.keys["LOSS"]] = self.loss_function(
                    engine.state.output[self.keys["PRED"]], targets).mean()

            engine.fire_event(IterationEvents.LOSS_COMPLETED)

        self.network.train()
        self.optimizer.zero_grad()
        if self.amp and self.scaler is not None:
            with torch.cuda.amp.autocast():
                _compute_pred_loss()
            self.scaler.scale(
                engine.state.output[self.keys["LOSS"]]).backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            _compute_pred_loss()
            engine.state.output[self.keys["LOSS"]].backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            self.optimizer.step()
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
Esempio n. 14
0
 def __call__(self, engine: Engine):
     """Record every finished episode and fire the matching episode events.

     For each ``(reward, steps)`` pair from ``self._exp_source.pop_rewards_steps()``:
     increment ``engine.state.episode``, store the raw reward/steps on the state
     and in ``engine.state.metrics``, refresh smoothed metrics, then fire
     ``EPISODE_COMPLETED`` (every ``self._subsample_end_of_episode``-th episode,
     or always when that is None), ``BOUND_REWARD_REACHED`` (when a bound is set
     and ``metrics['avg_reward']`` reaches it) and ``BEST_REWARD_REACHED`` (when
     ``metrics['avg_reward']`` beats the best so far; the first value only
     initialises the best).
     """
     for reward, steps in self._exp_source.pop_rewards_steps():
         engine.state.episode = getattr(engine.state, "episode", 0) + 1
         engine.state.episode_reward = reward
         engine.state.episode_steps = steps
         engine.state.metrics.update({'reward': reward, 'steps': steps})
         self._update_smoothed_metrics(engine, reward, steps)

         subsample = self._subsample_end_of_episode
         if subsample is None or engine.state.episode % subsample == 0:
             engine.fire_event(EpisodeEvents.EPISODE_COMPLETED)

         if self._bound_avg_reward is not None:
             if engine.state.metrics['avg_reward'] >= self._bound_avg_reward:
                 engine.fire_event(EpisodeEvents.BOUND_REWARD_REACHED)

         if self._best_avg_reward is None:
             self._best_avg_reward = engine.state.metrics['avg_reward']
             continue
         if self._best_avg_reward < engine.state.metrics['avg_reward']:
             engine.fire_event(EpisodeEvents.BEST_REWARD_REACHED)
             self._best_avg_reward = engine.state.metrics['avg_reward']
Esempio n. 15
0
    def _iteration(self, engine: Engine,
                   batchdata: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """
        Inference callback executed by the Ignite Engine for one iteration.

        Predicts a label map for one image, restores it to the original image
        geometry and writes it to ``self.output_dir`` as a NIfTI file named after
        the input file. Return below item in a dictionary:
            - pred: predicted label map padded back to the original image shape
              (a numpy array).

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
                Besides the image it must provide ``image_meta_dict`` (``affine``,
                ``filename_or_obj``), ``resample_flag``, ``anisotrophy_flag``,
                ``crop_shape``, ``original_shape`` and ``bbox``.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device,
                                   engine.non_blocking)
        if len(batch) == 2:
            inputs, _ = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, _, args, kwargs = batch

        def _compute_pred():
            # softmax class probabilities on the CPU, optionally averaged over
            # 7 additional flipped views of the input (test-time augmentation)
            ct = 1.0
            pred = self.inferer(inputs, self.network, *args, **kwargs).cpu()
            pred = nn.functional.softmax(pred, dim=1)
            if not self.tta_val:
                return pred
            else:
                for dims in [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]:
                    flip_inputs = torch.flip(inputs, dims=dims)
                    # NOTE(review): unlike the un-flipped call above, args and
                    # kwargs are not forwarded here — confirm this is intended.
                    flip_pred = torch.flip(self.inferer(
                        flip_inputs, self.network).cpu(),
                                           dims=dims)
                    flip_pred = nn.functional.softmax(flip_pred, dim=1)
                    del flip_inputs
                    pred += flip_pred
                    del flip_pred
                    ct += 1
                return pred / ct

        # execute forward computation
        with eval_mode(self.network):
            if self.amp:
                with torch.cuda.amp.autocast():
                    predictions = _compute_pred()
            else:
                predictions = _compute_pred()

        inputs = inputs.cpu()
        predictions = self.post_pred(predictions)

        # the metadata is read at batch index 0: one image per batch is assumed
        affine = batchdata["image_meta_dict"]["affine"].numpy()[0]
        resample_flag = batchdata["resample_flag"]
        anisotrophy_flag = batchdata["anisotrophy_flag"]
        crop_shape = batchdata["crop_shape"][0].tolist()
        original_shape = batchdata["original_shape"][0].tolist()

        # Drop the batch axis first so both branches yield a channel-first
        # probability map. recovery_prediction already returns
        # [n_classes, *crop_shape]; indexing [0] after it would select class
        # channel 0 and make the argmax below run over a spatial axis.
        predictions = predictions.numpy()[0]
        if resample_flag:
            # convert the prediction back to the original (after cropped) shape
            predictions = recovery_prediction(predictions,
                                              [self.n_classes, *crop_shape],
                                              anisotrophy_flag)

        # collapse class probabilities into a label map
        predictions = np.argmax(predictions, axis=0)

        # pad the prediction back to the original shape
        predictions_org = np.zeros([*original_shape])
        box_start, box_end = batchdata["bbox"][0]
        h_start, w_start, d_start = box_start
        h_end, w_end, d_end = box_end
        predictions_org[h_start:h_end, w_start:w_end,
                        d_start:d_end] = predictions
        del predictions

        filename = batchdata["image_meta_dict"]["filename_or_obj"][0].split(
            "/")[-1]

        print("save {} with shape: {}, mean values: {}".format(
            filename, predictions_org.shape, predictions_org.mean()))
        write_nifti(
            data=predictions_org,
            file_name=os.path.join(self.output_dir, filename),
            affine=affine,
            resample=False,
            output_dtype=np.uint8,
        )
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        return {"pred": predictions_org}
Esempio n. 16
0
    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
        """Run one multi-task training step.

        The inferer must return a tuple of predictions and the targets must be a
        tuple of the same length; ``self.loss_function`` receives both tuples.
        Image, label, predictions and loss are stored in ``engine.state.output``
        under ``self.keys["IMAGE"]``, ``["LABEL"]``, ``["PRED"]`` and ``["LOSS"]``.

        Args:
            engine: Ignite Engine running this iteration.
            batchdata: input data for this iteration, unpacked by ``self.prepare_batch``.

        Returns:
            ``engine.state.output``.

        Raises:
            ValueError: When ``batchdata`` is None, when predictions or targets
                are not tuples, or when their lengths differ.
        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device,
                                   engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, targets, args, kwargs = batch
        # put iteration outputs into engine.state
        engine.state.output = {
            self.keys["IMAGE"]: inputs,
            self.keys["LABEL"]: targets
        }

        def _compute_pred_loss():
            preds = self.inferer(inputs, self.network, *args, **kwargs)
            engine.state.output[self.keys["PRED"]] = preds
            engine.fire_event(IterationEvents.FORWARD_COMPLETED)
            if not isinstance(preds, tuple):
                # single formatted message (two positional args would render
                # the exception text as a tuple)
                raise ValueError(
                    f"Predictions must be tuple in multi-task framework, but got {type(preds)}"
                )
            if not isinstance(targets, tuple):
                raise ValueError(
                    f"Targets must be tuple in multi-task framework, but got {type(targets)}"
                )
            if len(preds) != len(targets):
                raise ValueError(
                    f"Predictions len must equal to targets, but got {len(preds)} != {len(targets)}"
                )

            loss = self.loss_function(preds, targets)

            engine.state.output[self.keys["LOSS"]] = loss
            engine.fire_event(IterationEvents.LOSS_COMPLETED)

        self.network.train()
        # `set_to_none` only work from PyTorch 1.7.0
        if not pytorch_after(1, 7):
            self.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)

        if self.amp and self.scaler is not None:
            with torch.cuda.amp.autocast():
                _compute_pred_loss()
            self.scaler.scale(
                engine.state.output[self.keys["LOSS"]]).backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            _compute_pred_loss()
            engine.state.output[self.keys["LOSS"]].backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            self.optimizer.step()
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
Esempio n. 17
0
    def _iteration(self, engine: Engine, batchdata: Dict[str, Any]):
        """
        Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - PRED: prediction result of model; a tuple of per-head outputs in
              deep supervision mode.
            - LOSS: loss value computed by loss function. In deep supervision
              mode the loss of head ``i`` is weighted by ``0.5 ** i`` and summed.

        Gradients are clipped to a total norm of 12 before every optimizer step,
        in both the AMP and the full-precision path.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device,
                                   engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, targets, args, kwargs = batch
        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

        def _compute_pred_loss():
            preds = self.inferer(inputs, self.network, *args, **kwargs)
            if len(preds.size()) - len(targets.size()) == 1:
                # deep supervision mode, need to unbind feature maps first.
                preds = torch.unbind(preds, dim=1)
            engine.state.output[Keys.PRED] = preds
            del preds
            engine.fire_event(IterationEvents.FORWARD_COMPLETED)
            outputs = engine.state.output[Keys.PRED]
            # A single (non deep-supervision) prediction is one head; iterating
            # the tensor itself would walk its batch axis instead.
            heads = outputs if isinstance(outputs, tuple) else (outputs,)
            engine.state.output[Keys.LOSS] = sum(
                0.5**i * self.loss_function.forward(p, targets)
                for i, p in enumerate(heads))
            engine.fire_event(IterationEvents.LOSS_COMPLETED)

        def _clip_gradients():
            # under DDP clip the wrapped module's parameters
            if isinstance(self.network, DistributedDataParallel):
                params = self.network.module.parameters()
            else:
                params = self.network.parameters()
            torch.nn.utils.clip_grad_norm_(params, 12)

        self.network.train()
        self.optimizer.zero_grad()
        if self.amp and self.scaler is not None:
            with torch.cuda.amp.autocast():
                _compute_pred_loss()
            self.scaler.scale(engine.state.output[Keys.LOSS]).backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            # unscale first so the clipping threshold applies to true gradients
            self.scaler.unscale_(self.optimizer)
            _clip_gradients()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            _compute_pred_loss()
            engine.state.output[Keys.LOSS].backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            _clip_gradients()
            self.optimizer.step()
        # fired on both paths so handlers behave the same with and without AMP
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
Esempio n. 18
0
 def eval_and_log(e: Engine):
     """Run ``self.eval()`` and publish its validation results on the engine.

     Copies accuracy and average loss of the ``'val'`` result into
     ``e.state.metrics`` (as ``val_accuracy`` / ``val_loss``), keeps the full
     results on ``e.state.eval_results`` and fires the ``"EVAL_DONE"`` event.
     """
     results = self.eval()
     val_result = results['val']
     e.state.metrics['val_accuracy'] = val_result.metrics['accuracy']
     e.state.metrics['val_loss'] = val_result.metrics['avg_loss']
     e.state.eval_results = results
     e.fire_event("EVAL_DONE")
Esempio n. 19
0
    def _iteration(self, engine: Engine,
                   batchdata: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """
        Evaluation callback executed by the Ignite Engine for one iteration.

        The network prediction (softmax probabilities on the CPU, averaged over
        flipped copies of the input when ``self.tta_val`` is set) is
        post-processed, resampled back to the cropped shape when
        ``resample_flag`` is set, and pasted at ``bbox`` into a zero volume of
        the original image shape. The returned dictionary, also stored as
        ``engine.state.output``, holds:
            - IMAGE: model input, moved back to the CPU.
            - LABEL: post-processed label with a leading axis of size 1.
            - PRED: prediction of shape ``[1, num_classes, *original_shape]``.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device,
                                   engine.non_blocking)
        if len(batch) != 2:
            inputs, targets, args, kwargs = batch
        else:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        targets = targets.cpu()

        flip_axes = [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]

        def _compute_pred():
            probs = nn.functional.softmax(
                self.inferer(inputs, self.network, *args, **kwargs).cpu(), dim=1)
            if not self.tta_val:
                return probs
            views = 1.0
            for axes in flip_axes:
                # predict on the flipped input, then flip the output back
                flipped_logits = self.inferer(torch.flip(inputs, dims=axes),
                                              self.network).cpu()
                probs += nn.functional.softmax(
                    torch.flip(flipped_logits, dims=axes), dim=1)
                views += 1
            return probs / views

        # execute forward computation
        with eval_mode(self.network):
            if not self.amp:
                predictions = _compute_pred()
            else:
                with torch.cuda.amp.autocast():
                    predictions = _compute_pred()

        inputs = inputs.cpu()
        predictions = self.post_pred(decollate_batch(predictions)[0])
        targets = self.post_label(decollate_batch(targets)[0])

        resample_flag = batchdata["resample_flag"]
        anisotrophy_flag = batchdata["anisotrophy_flag"]
        crop_shape = batchdata["crop_shape"][0].tolist()
        original_shape = batchdata["original_shape"][0].tolist()
        if resample_flag:
            # bring the prediction back to the (cropped) original resolution
            predictions = torch.tensor(
                recovery_prediction(predictions.numpy(),
                                    [self.num_classes, *crop_shape],
                                    anisotrophy_flag))

        padded_pred = torch.zeros([1, self.num_classes, *original_shape])
        engine.state.output = {
            Keys.IMAGE: inputs,
            Keys.LABEL: targets.unsqueeze(0),
            Keys.PRED: padded_pred,
        }
        # paste the prediction into its bounding box of the full-size volume
        (h_start, w_start, d_start), (h_end, w_end, d_end) = batchdata["bbox"][0]
        padded_pred[0, :, h_start:h_end, w_start:w_end,
                    d_start:d_end] = predictions

        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
Esempio n. 20
0
 def __call__(self, engine: Engine):
     """Fire every event of ``self.INTERVAL_TO_EVENT`` whose period evenly
     divides ``engine.state.iteration``, in the mapping's iteration order."""
     for period, event in self.INTERVAL_TO_EVENT.items():
         if engine.state.iteration % period:
             # not a multiple of this period: nothing to fire
             continue
         engine.fire_event(event)