def store_examples(engine: Engine):
    """Attach the example data to the engine state and announce it.

    The ``train_examples`` and ``val_examples`` objects (names resolved from
    the enclosing/module scope) are stored in ``engine.state.examples`` under
    the keys ``"training"`` and ``"validation"``. A log line is then written
    and ``CustomEvents.EXAMPLE_DATA_READY`` is fired.
    """
    examples_by_split = {}
    examples_by_split["training"] = train_examples
    examples_by_split["validation"] = val_examples
    engine.state.examples = examples_by_split

    engine.logger.info("Example data ready")
    engine.fire_event(CustomEvents.EXAMPLE_DATA_READY)
def _update(engine: Engine, batch: Sequence[torch.Tensor]):
    """Run one truncated-BPTT training iteration over a sequence batch.

    The input and target tensors are split into chunks of ``tbtt_step``
    along dimension ``dim`` (both names come from the enclosing scope). For
    each chunk, the model runs forward, the loss is back-propagated and the
    optimizer steps. The recurrent hidden state returned by the model is fed
    into the next chunk after being passed through ``_detach_hidden``.

    ``Tbptt_Events.TIME_ITERATION_STARTED`` and ``TIME_ITERATION_COMPLETED``
    are fired around each chunk. ``engine.state.output`` holds the latest
    chunk loss while the completion event runs.

    Args:
        engine: the engine running this update.
        batch: an ``(x, y)`` pair of input and target tensors.

    Returns:
        The arithmetic mean of the per-chunk loss values.
    """
    loss_list = []
    hidden = None
    x, y = batch
    for batch_t in zip(x.split(tbtt_step, dim=dim), y.split(tbtt_step, dim=dim)):
        x_t, y_t = prepare_batch(batch_t, device=device, non_blocking=non_blocking)
        # Fire event for start of iteration
        engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)

        # Forward, backward and optimizer step for this time chunk
        model.train()
        optimizer.zero_grad()
        if hidden is None:
            y_pred_t, hidden = model(x_t)
        else:
            # Truncated BPTT: the hidden state from the previous chunk is
            # passed through _detach_hidden before being reused.
            hidden = _detach_hidden(hidden)
            y_pred_t, hidden = model(x_t, hidden)
        loss_t = loss_fn(y_pred_t, y_t)
        loss_t.backward()
        optimizer.step()

        # ``.item()`` synchronises with the device; call it once per chunk
        # and reuse the value.
        loss_value = loss_t.item()
        # Setting state of engine for consistent behaviour
        engine.state.output = loss_value
        loss_list.append(loss_value)

        # Fire event for end of iteration
        engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

    # return average loss over the time splits
    return sum(loss_list) / len(loss_list)
def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
    """Evaluate every ensemble member on one batch.

    The batch is moved to ``engine.state.device`` via ``self.prepare_batch``.
    ``engine.state.output`` is reset to a dict holding the inputs and targets
    under ``self.keys["IMAGE"]`` / ``self.keys["LABEL"]``. Each network in
    ``self.networks`` then adds its prediction under ``self.pred_keys[idx]``,
    running inside ``self.mode(network)`` and, when ``self.amp`` is set, under
    CUDA autocast. ``FORWARD_COMPLETED`` and ``MODEL_COMPLETED`` are fired
    once all networks have run.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)  # type: ignore
    if len(batch) == 2:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, targets, args, kwargs = batch

    # put iteration outputs into engine.state
    engine.state.output = {self.keys["IMAGE"]: inputs, self.keys["LABEL"]: targets}

    def _add_prediction(member_idx, member):
        # Only write into the output while it is still a dict.
        if not isinstance(engine.state.output, dict):
            return
        engine.state.output.update(
            {self.pred_keys[member_idx]: self.inferer(inputs, member, *args, **kwargs)}
        )

    for member_idx, member in enumerate(self.networks):
        with self.mode(member):
            if not self.amp:
                _add_prediction(member_idx, member)
                continue
            with torch.cuda.amp.autocast():
                _add_prediction(member_idx, member)

    engine.fire_event(IterationEvents.FORWARD_COMPLETED)
    engine.fire_event(IterationEvents.MODEL_COMPLETED)
    return engine.state.output
def predict_on_examples(engine: Engine):
    """Predict on every stored example split and attach the results.

    ``model`` is switched to eval mode. For each ``tag -> (x, y)`` entry of
    ``engine.state.examples``, a prediction is computed on ``args.device``
    with gradient tracking disabled. It is moved back to CPU, and the entry
    is replaced by the triple ``(x, y, y_pred)``. A log line is then written
    and ``CustomEvents.EXAMPLE_PREDICTIONS_READY`` is fired.
    """
    model.eval()
    examples = engine.state.examples
    for tag, (inputs, targets) in examples.items():
        with torch.no_grad():
            raw_prediction = model(inputs.to(device=args.device))
        # Re-assigning an existing key does not change the dict's size, so
        # it is safe while iterating over items().
        examples[tag] = (inputs, targets, raw_prediction.detach().cpu())

    engine.logger.info("Example predictions ready")
    engine.fire_event(CustomEvents.EXAMPLE_PREDICTIONS_READY)
def train_function(
    config: Any,
    engine: Engine,
    batch: Any,
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    optimizer: Optimizer,
    device: torch.device,
):
    """Model training step.

    Runs one forward / backward / optimizer step on ``batch``. It increments
    ``engine.state.backward_completed`` and ``engine.state.optim_step_completed``
    and fires ``TrainEvents.BACKWARD_COMPLETED`` and
    ``TrainEvents.OPTIM_STEP_COMPLETED`` after the respective steps.

    Parameters
    ----------
    config
        config object; ``config.use_amp`` enables autocast for the forward
        pass and loss computation
    engine
        Engine instance
    batch
        batch in current iteration; ``batch[0]`` holds the model inputs and
        ``batch[1]`` the targets
    model
        nn.Module model
    loss_fn
        nn.Module loss
    optimizer
        torch optimizer
    device
        device to use for training

    Returns
    -------
    loss_value
        The loss of this batch as a Python number (``loss.item()``). The same
        value is recorded as ``"train_loss"`` in ``engine.state.metrics``.
        That attribute is replaced by a new dict containing only ``"epoch"``
        and ``"train_loss"``.
    """
    model.train()

    samples = batch[0].to(device, non_blocking=True)
    targets = batch[1].to(device, non_blocking=True)

    # NOTE(review): autocast is used without a GradScaler; with fp16 this can
    # let small gradients underflow. Confirm this is intended.
    with autocast(enabled=config.use_amp):
        outputs = model(samples)
        loss = loss_fn(outputs, targets)

    loss.backward()
    engine.state.backward_completed += 1
    engine.fire_event(TrainEvents.BACKWARD_COMPLETED)

    optimizer.step()
    engine.state.optim_step_completed += 1
    engine.fire_event(TrainEvents.OPTIM_STEP_COMPLETED)

    optimizer.zero_grad()

    loss_value = loss.item()
    engine.state.metrics = {"epoch": engine.state.epoch, "train_loss": loss_value}
    return loss_value
def __call__(self, engine: Engine):
    """Publish statistics for each episode finished since the last call.

    For every ``(reward, steps)`` pair drained from
    ``self._exp_source.pop_rewards_steps()``:

    - increment ``engine.state.episode`` (treated as 0 if not yet set);
    - store the raw reward and step count on the engine state and in
      ``engine.state.metrics``;
    - update the smoothed metrics;
    - fire ``self.Events.EPISODE_COMPLETED``.

    If ``self._bound_avg_reward`` is set and ``metrics['avg_reward']`` has
    reached it, ``self.Events.BOUND_REWARD_REACHED`` is fired as well.
    """
    for episode_reward, episode_steps in self._exp_source.pop_rewards_steps():
        engine.state.episode = getattr(engine.state, "episode", 0) + 1
        engine.state.episode_reward = episode_reward
        engine.state.episode_steps = episode_steps
        engine.state.metrics['reward'] = episode_reward
        engine.state.metrics['steps'] = episode_steps
        self._update_smoothed_metrics(engine, episode_reward, episode_steps)
        engine.fire_event(self.Events.EPISODE_COMPLETED)

        reached_bound = (
            self._bound_avg_reward is not None
            and engine.state.metrics['avg_reward'] >= self._bound_avg_reward
        )
        if reached_bound:
            engine.fire_event(self.Events.BOUND_REWARD_REACHED)
def _iteration(
        self, engine: Engine,
        batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Evaluate every ensemble member on one batch of data.

    The returned dict (which is also stored as ``engine.state.output``)
    contains:

    - ``Keys.IMAGE``: the model input, already moved to the device;
    - ``Keys.LABEL``: the matching label, already moved to the device;
    - ``self.pred_keys[i]``: the prediction of ``self.networks[i]``, for
      every network.

    Each network runs inside ``eval_mode`` and, when ``self.amp`` is set,
    under CUDA autocast. ``IterationEvents.FORWARD_COMPLETED`` is fired once
    after all networks have run.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
        batchdata: input data for this iteration, usually can be dictionary
            or tuple of Tensor data.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) != 2:
        inputs, targets, args, kwargs = batch
    else:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}

    # ``output`` and ``engine.state.output`` refer to the same dict object.
    output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
    engine.state.output = output

    def _predict(net):
        return self.inferer(inputs, net, *args, **kwargs)

    for net_idx, net in enumerate(self.networks):
        pred_key = self.pred_keys[net_idx]
        with eval_mode(net):
            if self.amp:
                with torch.cuda.amp.autocast():
                    output[pred_key] = _predict(net)
            else:
                output[pred_key] = _predict(net)
    engine.fire_event(IterationEvents.FORWARD_COMPLETED)

    return output
def __call__(self, engine: Engine):
    """Publish statistics and events for each newly finished episode.

    Right after the experience buffer is initialised,
    ``pop_rewards_steps()`` can yield more than one entry. During normal
    learning it yields a single ``(reward, steps)`` pair once an episode
    has finished.

    For every pair:

    - bump ``engine.state.episode``;
    - record the reward and step count on the engine state and in its
      metrics, then update the smoothed metrics;
    - fire ``EpisodeEvents.EPISODE_COMPLETED``, subject to optional
      subsampling;
    - fire ``BOUND_REWARD_REACHED`` when the configured average-reward bound
      is reached;
    - fire ``BEST_REWARD_REACHED`` whenever the average reward beats the best
      value seen so far.
    """
    for total_reward, num_steps in self._exp_source.pop_rewards_steps():
        engine.state.episode = getattr(engine.state, "episode", 0) + 1
        engine.state.episode_reward = total_reward
        engine.state.episode_steps = num_steps
        engine.state.metrics['reward'] = total_reward
        engine.state.metrics['steps'] = num_steps
        self._update_smoothed_metrics(engine, total_reward, num_steps)

        # Normally there is no subsampling of episode ends, so this event is
        # usually fired for every finished episode.
        announce_episode = (
            self._subsample_end_of_episode is None
            or engine.state.episode % self._subsample_end_of_episode == 0
        )
        if announce_episode:
            engine.fire_event(EpisodeEvents.EPISODE_COMPLETED)

        if (self._bound_avg_reward is not None
                and engine.state.metrics['avg_reward'] >= self._bound_avg_reward):
            engine.fire_event(EpisodeEvents.BOUND_REWARD_REACHED)

        # The first observed average only initialises the best value; later
        # improvements fire BEST_REWARD_REACHED before being recorded.
        previous_best = self._best_avg_reward
        if previous_best is None:
            self._best_avg_reward = engine.state.metrics['avg_reward']
        elif previous_best < engine.state.metrics['avg_reward']:
            engine.fire_event(EpisodeEvents.BEST_REWARD_REACHED)
            self._best_avg_reward = engine.state.metrics['avg_reward']
def test_compute(self, input_params, expected_avg, details_shape):
    """MeanDice must aggregate over an epoch and expose per-item details.

    Two iterations are pushed through a no-op engine. Both use the same
    list of channel-first predictions, but the labels differ. After
    ``EPOCH_COMPLETED``:

    - the aggregated metric must match ``expected_avg``;
    - the detail tensor must have shape ``details_shape``.
    """
    metric = MeanDice(**input_params)

    # The process function does nothing; events are fired by hand below.
    def _val_func(engine, batch):
        pass

    engine = Engine(_val_func)
    metric.attach(engine=engine, name="mean_dice")

    # Each iteration feeds a list of channel-first prediction tensors.
    label_batches = (
        torch.Tensor([[[0], [1]], [[0], [1]]]),
        torch.Tensor([[[0], [1]], [[1], [0]]]),
    )
    for labels in label_batches:
        predictions = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]
        engine.state.output = {"pred": predictions, "label": labels}
        engine.fire_event(Events.ITERATION_COMPLETED)
    engine.fire_event(Events.EPOCH_COMPLETED)

    torch.testing.assert_allclose(engine.state.metrics["mean_dice"], expected_avg)
    detail_tensor = engine.state.metric_details["mean_dice"]
    self.assertTupleEqual(tuple(detail_tensor.shape), details_shape)
def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
    """Run the forward pass of supervised evaluation for one iteration.

    The returned dict (also stored as ``engine.state.output``) contains:

    - ``Keys.IMAGE``: the model input, already moved to the device;
    - ``Keys.LABEL``: the matching label, already moved to the device;
    - ``Keys.PRED``: the prediction of ``self.network``.

    Inference runs inside ``self.mode(self.network)`` and, when ``self.amp``
    is set, under CUDA autocast. ``FORWARD_COMPLETED`` and then
    ``MODEL_COMPLETED`` are fired afterwards.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
        batchdata: input data for this iteration, usually can be dictionary
            or tuple of Tensor data.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)  # type: ignore
    if len(batch) == 2:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, targets, args, kwargs = batch

    # put iteration outputs into engine.state
    engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}  # type: ignore

    def _run_forward():
        engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)  # type: ignore

    # execute forward computation
    with self.mode(self.network):
        if not self.amp:
            _run_forward()
        else:
            with torch.cuda.amp.autocast():
                _run_forward()

    engine.fire_event(IterationEvents.FORWARD_COMPLETED)
    engine.fire_event(IterationEvents.MODEL_COMPLETED)
    return engine.state.output
def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Run the evaluation forward pass, optionally through a latent network.

    ``engine.state.output`` is reset to hold the inputs and targets under
    ``self.keys["IMAGE"]`` / ``self.keys["LABEL"]``.

    - When ``self.network_latent`` is truthy, the inferer is run with it and
      must return three values, stored under the ``"PRED"``, ``"FORWARD"``
      and ``"BACKWARD"`` keys.
    - Otherwise ``self.network`` produces the ``"PRED"`` entry only.

    The forward pass runs inside ``self.mode(self.network)`` and, if
    ``self.amp`` is set, under CUDA autocast.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) == 2:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, targets, args, kwargs = batch

    # put iteration outputs into engine.state
    engine.state.output = {self.keys["IMAGE"]: inputs, self.keys["LABEL"]: targets}

    def _forward():
        results = engine.state.output
        if not self.network_latent:
            results[self.keys["PRED"]] = self.inferer(inputs, self.network, *args, **kwargs)
            return
        pred, fwd, bwd = self.inferer(inputs, self.network_latent, *args, **kwargs)
        results[self.keys["PRED"]] = pred
        results[self.keys["FORWARD"]] = fwd
        results[self.keys["BACKWARD"]] = bwd

    # execute forward computation
    with self.mode(self.network):
        if self.amp:
            with torch.cuda.amp.autocast():
                _forward()
        else:
            _forward()

    engine.fire_event(IterationEvents.FORWARD_COMPLETED)
    engine.fire_event(IterationEvents.MODEL_COMPLETED)
    return engine.state.output
def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
    """Run one supervised training step: forward, loss, backward, update.

    The returned dict (also stored as ``engine.state.output``) contains:

    - ``Keys.IMAGE``: the model input, already moved to the device;
    - ``Keys.LABEL``: the matching label, already moved to the device;
    - ``Keys.PRED``: the prediction of ``self.network``;
    - ``Keys.LOSS``: the mean of ``self.loss_function(pred, label)``.

    When ``self.amp`` is set and a grad scaler is available, the forward
    pass runs under CUDA autocast and the scaler drives backward and step.
    ``FORWARD_COMPLETED``, ``LOSS_COMPLETED``, ``BACKWARD_COMPLETED`` and
    ``OPTIMIZER_COMPLETED`` are fired in that order.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
        batchdata: input data for this iteration, usually can be dictionary
            or tuple of Tensor data.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) == 2:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, targets, args, kwargs = batch

    # ``step_output`` and ``engine.state.output`` are the same dict object.
    step_output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
    engine.state.output = step_output

    def _forward_and_loss():
        step_output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        # Read the prediction back from the dict, as handlers of the event
        # above see (and may touch) that same dict.
        step_output[Keys.LOSS] = self.loss_function(step_output[Keys.PRED], targets).mean()
        engine.fire_event(IterationEvents.LOSS_COMPLETED)

    self.network.train()
    self.optimizer.zero_grad()
    use_scaler = self.amp and self.scaler is not None
    if not use_scaler:
        _forward_and_loss()
        step_output[Keys.LOSS].backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        self.optimizer.step()
    else:
        with torch.cuda.amp.autocast():
            _forward_and_loss()
        self.scaler.scale(step_output[Keys.LOSS]).backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        self.scaler.step(self.optimizer)
        self.scaler.update()
    engine.fire_event(IterationEvents.OPTIMIZER_COMPLETED)

    return step_output
def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
    """Run one supervised training step: forward, loss, backward, update.

    ``engine.state.output`` holds the inputs, targets, prediction and loss
    under the ``"IMAGE"``, ``"LABEL"``, ``"PRED"`` and ``"LOSS"`` entries of
    ``self.keys``.

    - When ``self.ensure_dims`` is set, prediction and targets are first
      passed through ``ensure_same_dim`` before the loss is computed.
    - With ``self.amp`` and a grad scaler, the forward pass runs under CUDA
      autocast and the scaler drives backward and step.

    Raises:
        StrixException: When ``batchdata`` is None.
    """
    if batchdata is None:
        # The ValueError that used to follow this raise was unreachable and
        # has been removed.
        raise StrixException(
            "No data were fed into the Trainer engine. "
            "Consider the possibility that Transforms did not succeed or "
            "there is a problem with your dataset.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) == 2:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, targets, args, kwargs = batch

    # put iteration outputs into engine.state
    engine.state.output = {
        self.keys["IMAGE"]: inputs,
        self.keys["LABEL"]: targets
    }

    def _compute_pred_loss():
        engine.state.output[self.keys["PRED"]] = self.inferer(
            inputs, self.network, *args, **kwargs)
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        if self.ensure_dims:
            engine.state.output[self.keys["LOSS"]] = self.loss_function(
                *ensure_same_dim(engine.state.output[self.keys["PRED"]], targets)).mean()
        else:
            engine.state.output[self.keys["LOSS"]] = self.loss_function(
                engine.state.output[self.keys["PRED"]], targets).mean()
        engine.fire_event(IterationEvents.LOSS_COMPLETED)

    self.network.train()
    self.optimizer.zero_grad()
    if self.amp and self.scaler is not None:
        with torch.cuda.amp.autocast():
            _compute_pred_loss()
        self.scaler.scale(
            engine.state.output[self.keys["LOSS"]]).backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        _compute_pred_loss()
        engine.state.output[self.keys["LOSS"]].backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        self.optimizer.step()
    engine.fire_event(IterationEvents.MODEL_COMPLETED)

    return engine.state.output
def __call__(self, engine: Engine):
    """Record each newly finished episode and fire the matching events.

    For every ``(reward, steps)`` pair from
    ``self._exp_source.pop_rewards_steps()``:

    - update the episode counter and the reward/step fields on
      ``engine.state``, then update the smoothed metrics;
    - fire ``EpisodeEvents.EPISODE_COMPLETED`` for every episode, or only
      every ``self._subsample_end_of_episode``-th one when subsampling is set;
    - fire ``BOUND_REWARD_REACHED`` once ``metrics['avg_reward']`` reaches
      ``self._bound_avg_reward``;
    - fire ``BEST_REWARD_REACHED`` when the average reward improves on the
      best value seen so far (the first value only initialises it).
    """
    for reward, steps in self._exp_source.pop_rewards_steps():
        engine.state.episode = getattr(engine.state, "episode", 0) + 1
        engine.state.episode_reward, engine.state.episode_steps = reward, steps
        engine.state.metrics['reward'] = reward
        engine.state.metrics['steps'] = steps
        self._update_smoothed_metrics(engine, reward, steps)

        subsample_every = self._subsample_end_of_episode
        if subsample_every is None or engine.state.episode % subsample_every == 0:
            engine.fire_event(EpisodeEvents.EPISODE_COMPLETED)

        reward_bound = self._bound_avg_reward
        if reward_bound is not None and engine.state.metrics['avg_reward'] >= reward_bound:
            engine.fire_event(EpisodeEvents.BOUND_REWARD_REACHED)

        if self._best_avg_reward is None:
            self._best_avg_reward = engine.state.metrics['avg_reward']
        elif self._best_avg_reward < engine.state.metrics['avg_reward']:
            engine.fire_event(EpisodeEvents.BEST_REWARD_REACHED)
            self._best_avg_reward = engine.state.metrics['avg_reward']
def _iteration(self, engine: Engine, batchdata: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
    The prediction is mapped back to the original image geometry and written to disk.

    Steps:

    1. Run the network under ``eval_mode``, with CUDA autocast if ``self.amp``.
    2. Take a softmax over dim 1. If ``self.tta_val`` is set, average with the
       softmax of the 7 flipped variants along dims 2, 3 and 4.
    3. Apply ``self.post_pred``.
    4. If ``batchdata["resample_flag"]`` is set, resample back to the cropped
       shape with ``recovery_prediction``.
    5. Take the argmax over the class axis.
    6. Paste the result into a zero volume of ``original_shape`` at
       ``batchdata["bbox"]``.
    7. Write it as a ``uint8`` NIfTI file into ``self.output_dir``.

    Return below item in a dictionary:
        - PRED: the padded label map (a numpy array) that was written to disk.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
        batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) == 2:
        inputs, _ = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, _, args, kwargs = batch

    def _compute_pred():
        ct = 1.0
        pred = self.inferer(inputs, self.network, *args, **kwargs).cpu()
        pred = nn.functional.softmax(pred, dim=1)
        if not self.tta_val:
            return pred
        # Test-time augmentation over every non-empty combination of flips of
        # dims 2, 3 and 4 (assumes 5-D (N, C, H, W, D) inputs - TODO confirm).
        # NOTE(review): unlike the pass above, the flipped passes do not
        # forward *args/**kwargs; confirm this is intended.
        for dims in [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]:
            flip_inputs = torch.flip(inputs, dims=dims)
            flip_pred = torch.flip(self.inferer(flip_inputs, self.network).cpu(), dims=dims)
            flip_pred = nn.functional.softmax(flip_pred, dim=1)
            del flip_inputs
            pred += flip_pred
            del flip_pred
            ct += 1
        return pred / ct

    # execute forward computation
    with eval_mode(self.network):
        if self.amp:
            with torch.cuda.amp.autocast():
                predictions = _compute_pred()
        else:
            predictions = _compute_pred()

    # The former ``inputs = inputs.cpu()`` was a dead store. ``inputs`` is not
    # used past this point, and ``batch`` still keeps the device tensor alive,
    # so it only cost a full-volume device->host copy.
    predictions = self.post_pred(predictions)

    affine = batchdata["image_meta_dict"]["affine"].numpy()[0]
    resample_flag = batchdata["resample_flag"]
    anisotrophy_flag = batchdata["anisotrophy_flag"]
    crop_shape = batchdata["crop_shape"][0].tolist()
    original_shape = batchdata["original_shape"][0].tolist()

    if resample_flag:
        # convert the prediction back to the original (after cropped) shape
        predictions = recovery_prediction(predictions.numpy()[0], [self.n_classes, *crop_shape], anisotrophy_flag)
    else:
        predictions = predictions.numpy()[0]

    # Collapse class scores to a label map.
    predictions = np.argmax(predictions, axis=0)

    # pad the prediction back to the original shape
    predictions_org = np.zeros([*original_shape])
    box_start, box_end = batchdata["bbox"][0]
    h_start, w_start, d_start = box_start
    h_end, w_end, d_end = box_end
    predictions_org[h_start:h_end, w_start:w_end, d_start:d_end] = predictions
    del predictions

    # os.path.basename matches split("/")[-1] on POSIX and also handles the
    # native separator on other platforms.
    filename = os.path.basename(batchdata["image_meta_dict"]["filename_or_obj"][0])
    print("save {} with shape: {}, mean values: {}".format(
        filename, predictions_org.shape, predictions_org.mean()))
    write_nifti(
        data=predictions_org,
        file_name=os.path.join(self.output_dir, filename),
        affine=affine,
        resample=False,
        output_dtype=np.uint8,
    )
    engine.fire_event(IterationEvents.FORWARD_COMPLETED)
    return {"pred": predictions_org}
def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
    """Run one multi-task training step: forward, loss, backward, update.

    The network must return a tuple of predictions and the targets must be a
    tuple of the same length. ``self.loss_function`` receives both tuples.
    ``engine.state.output`` holds inputs, targets, predictions and loss under
    the ``"IMAGE"``, ``"LABEL"``, ``"PRED"`` and ``"LOSS"`` entries of
    ``self.keys``.

    Raises:
        ValueError: When ``batchdata`` is None.
        ValueError: When predictions or targets are not tuples, or their
            lengths differ.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) == 2:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, targets, args, kwargs = batch

    # put iteration outputs into engine.state
    engine.state.output = {
        self.keys["IMAGE"]: inputs,
        self.keys["LABEL"]: targets
    }

    def _compute_pred_loss():
        preds = self.inferer(inputs, self.network, *args, **kwargs)
        engine.state.output[self.keys["PRED"]] = preds
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        if not isinstance(preds, tuple):
            # A single message string, rather than two positional arguments
            # that would render as a tuple.
            raise ValueError(
                "Predictions must be tuple in multi-task framework, "
                f"but got {type(preds)}"
            )
        if not isinstance(targets, tuple):
            raise ValueError(
                f"Targets must be tuple in multi-task framework, but got {type(targets)}"
            )
        if len(preds) != len(targets):
            raise ValueError(
                f"Predictions len must equal to targets, but got {len(preds)} != {len(targets)}"
            )
        loss = self.loss_function(preds, targets)
        engine.state.output[self.keys["LOSS"]] = loss
        engine.fire_event(IterationEvents.LOSS_COMPLETED)

    self.network.train()
    # `set_to_none` only work from PyTorch 1.7.0
    if not pytorch_after(1, 7):
        self.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)

    if self.amp and self.scaler is not None:
        with torch.cuda.amp.autocast():
            _compute_pred_loss()
        self.scaler.scale(
            engine.state.output[self.keys["LOSS"]]).backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        _compute_pred_loss()
        engine.state.output[self.keys["LOSS"]].backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        self.optimizer.step()
    engine.fire_event(IterationEvents.MODEL_COMPLETED)

    return engine.state.output
def _iteration(self, engine: Engine, batchdata: Dict[str, Any]):
    """
    Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
    Return below items in a dictionary:
        - IMAGE: image Tensor data for model input, already moved to device.
        - LABEL: label Tensor data corresponding to the image, already moved to device.
        - PRED: prediction result of model (a tuple of heads in deep supervision mode).
        - LOSS: loss value computed by loss function.

    Deep supervision mode applies when the network output has exactly one more
    dimension than the targets. The output is then unbound along dim 1 into
    heads whose losses are weighted 1, 1/2, 1/4, ... and summed. Gradients are
    clipped to a total norm of 12 before each optimizer step.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
        batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) == 2:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, targets, args, kwargs = batch

    # put iteration outputs into engine.state
    engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

    def _compute_pred_loss():
        preds = self.inferer(inputs, self.network, *args, **kwargs)
        if len(preds.size()) - len(targets.size()) == 1:
            # deep supervision mode, need to unbind feature maps first.
            preds = torch.unbind(preds, dim=1)
        engine.state.output[Keys.PRED] = preds
        del preds
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        stored_preds = engine.state.output[Keys.PRED]
        # A plain prediction tensor is a single head. Enumerating the tensor
        # itself would iterate over its batch dimension instead.
        heads = stored_preds if isinstance(stored_preds, (tuple, list)) else (stored_preds,)
        engine.state.output[Keys.LOSS] = sum(
            0.5**i * self.loss_function.forward(p, targets) for i, p in enumerate(heads))
        engine.fire_event(IterationEvents.LOSS_COMPLETED)

    def _clip_gradients():
        # When wrapped in DDP, clip the wrapped module's parameters.
        if isinstance(self.network, DistributedDataParallel):
            params = self.network.module.parameters()
        else:
            params = self.network.parameters()
        torch.nn.utils.clip_grad_norm_(params, 12)

    self.network.train()
    self.optimizer.zero_grad()
    if self.amp and self.scaler is not None:
        with torch.cuda.amp.autocast():
            _compute_pred_loss()
        self.scaler.scale(engine.state.output[Keys.LOSS]).backward()
        # Fired here too, matching the non-AMP branch.
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        # Unscale first so clipping applies to the true gradient magnitudes.
        self.scaler.unscale_(self.optimizer)
        _clip_gradients()
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        _compute_pred_loss()
        engine.state.output[Keys.LOSS].backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        _clip_gradients()
        self.optimizer.step()
    engine.fire_event(IterationEvents.MODEL_COMPLETED)

    return engine.state.output
def eval_and_log(e: Engine):
    """Run evaluation and publish the validation metrics on the engine.

    ``self.eval()`` is called, with ``self`` taken from the enclosing scope.
    The ``'accuracy'`` and ``'avg_loss'`` metrics of its ``'val'`` result are
    copied into ``e.state.metrics`` as ``'val_accuracy'`` / ``'val_loss'``.
    The full result is kept on ``e.state.eval_results``, and the
    ``"EVAL_DONE"`` event is fired.
    """
    results = self.eval()
    val_metrics = results['val'].metrics
    e.state.metrics['val_accuracy'] = val_metrics['accuracy']
    e.state.metrics['val_loss'] = val_metrics['avg_loss']
    e.state.eval_results = results
    e.fire_event("EVAL_DONE")
def _iteration(self, engine: Engine, batchdata: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Validate one batch and map the prediction back to the original geometry.

    The returned dict (also stored as ``engine.state.output``) contains:

    - ``Keys.IMAGE``: the model input, moved to CPU;
    - ``Keys.LABEL``: the post-processed label of the first batch item, with
      a leading dim of size 1 added;
    - ``Keys.PRED``: a zero tensor of shape ``[1, num_classes, *original_shape]``
      with the (optionally resampled) prediction pasted in at
      ``batchdata["bbox"]``.

    Predictions are softmax scores. With ``self.tta_val`` set, they are
    averaged with the 7 flipped variants along dims 2, 3 and 4.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
        batchdata: input data for this iteration, usually can be dictionary
            or tuple of Tensor data.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) != 2:
        inputs, targets, args, kwargs = batch
    else:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}
    targets = targets.cpu()

    def _compute_pred():
        scores = nn.functional.softmax(
            self.inferer(inputs, self.network, *args, **kwargs).cpu(), dim=1)
        if not self.tta_val:
            return scores
        n_views = 1.0
        # NOTE(review): the flipped passes do not forward *args/**kwargs,
        # unlike the pass above; confirm this is intended.
        for axes in ([2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)):
            flipped = torch.flip(inputs, dims=axes)
            view_scores = torch.flip(self.inferer(flipped, self.network).cpu(), dims=axes)
            del flipped
            scores += nn.functional.softmax(view_scores, dim=1)
            del view_scores
            n_views += 1
        return scores / n_views

    # execute forward computation
    with eval_mode(self.network):
        if not self.amp:
            predictions = _compute_pred()
        else:
            with torch.cuda.amp.autocast():
                predictions = _compute_pred()

    inputs = inputs.cpu()
    predictions = self.post_pred(decollate_batch(predictions)[0])
    targets = self.post_label(decollate_batch(targets)[0])

    resample_flag = batchdata["resample_flag"]
    anisotrophy_flag = batchdata["anisotrophy_flag"]
    crop_shape = batchdata["crop_shape"][0].tolist()
    original_shape = batchdata["original_shape"][0].tolist()

    if resample_flag:
        # convert the prediction back to the original (after cropped) shape
        resampled = recovery_prediction(predictions.numpy(), [self.num_classes, *crop_shape], anisotrophy_flag)
        predictions = torch.tensor(resampled)

    # pad the prediction back to the original shape
    full_pred = torch.zeros([1, self.num_classes, *original_shape])
    (h_start, w_start, d_start), (h_end, w_end, d_end) = batchdata["bbox"][0]
    full_pred[0, :, h_start:h_end, w_start:w_end, d_start:d_end] = predictions
    del predictions

    # put iteration outputs into engine.state
    engine.state.output = {
        Keys.IMAGE: inputs,
        Keys.LABEL: targets.unsqueeze(0),
        Keys.PRED: full_pred,
    }
    engine.fire_event(IterationEvents.FORWARD_COMPLETED)
    engine.fire_event(IterationEvents.MODEL_COMPLETED)

    return engine.state.output
def __call__(self, engine: Engine):
    """Fire each interval event whose period divides the current iteration.

    ``self.INTERVAL_TO_EVENT`` maps a period to an event. An event is fired
    when ``engine.state.iteration`` is a multiple of its period. Events are
    fired in the mapping's iteration order.
    """
    for period, event in self.INTERVAL_TO_EVENT.items():
        if engine.state.iteration % period != 0:
            continue
        engine.fire_event(event)