Example 1
 def __call__(self, engine: Engine) -> None:
     elapsed_time = time.time() - self.start_time
     if elapsed_time > self.limit_sec:
         self.logger.info(
             "Reached the time limit: {} sec. Stop training".format(
                 self.limit_sec))
         engine.terminate()
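The example above shows only the handler's __call__ body. Below is a minimal sketch of how such a time-limit handler could be wired up; the TimeLimitHandler class, the dummy process function, and the attribute names are illustrative assumptions rather than part of the original example (recent ignite versions also ship a ready-made ignite.handlers.TimeLimit with the same behaviour).

import logging
import time

from ignite.engine import Engine, Events


class TimeLimitHandler:
    """Illustrative handler: stops training once a wall-clock budget is exhausted."""

    def __init__(self, limit_sec: float) -> None:
        self.limit_sec = limit_sec
        self.start_time = time.time()
        self.logger = logging.getLogger(self.__class__.__name__)

    def __call__(self, engine: Engine) -> None:
        elapsed_time = time.time() - self.start_time
        if elapsed_time > self.limit_sec:
            self.logger.info("Reached the time limit: {} sec. Stop training".format(self.limit_sec))
            engine.terminate()


trainer = Engine(lambda engine, batch: None)  # dummy process function
trainer.add_event_handler(Events.ITERATION_COMPLETED, TimeLimitHandler(limit_sec=3600))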
Example 2
 def __call__(self, engine: Engine):
     engine.terminate()
     # save current iteration for next round
     engine.state.dataloader_iter = engine._dataloader_iter
     if engine.state.iteration % engine.state.epoch_length == 0:
         # if current iteration is end of 1 epoch, manually trigger epoch completed event
         engine._fire_event(Events.EPOCH_COMPLETED)
Example 3
    def __call__(self, engine: Engine) -> None:
        output = self._output_transform(engine.state.output)

        def raise_error(x: Union[float, torch.Tensor]) -> None:

            if isinstance(x, numbers.Number):
                x = torch.tensor(x)

            if isinstance(x, torch.Tensor) and not bool(torch.isfinite(x).all()):
                raise RuntimeError("Infinite or NaN tensor found.")

        try:
            apply_to_type(output, (numbers.Number, torch.Tensor), raise_error)
        except RuntimeError:
            self.logger.warning(f"{self.__class__.__name__}: Output '{output}' contains NaN or Inf. Stop training")
            engine.terminate()
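This handler follows the same pattern as ignite's built-in TerminateOnNan. A minimal sketch of attaching it to a trainer follows; the dummy process function is an illustrative assumption.

from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan

# Dummy process function standing in for a real training step; its return value
# is what TerminateOnNan inspects via engine.state.output.
trainer = Engine(lambda engine, batch: float(batch))

# Inspect every iteration's output and call engine.terminate() on NaN/Inf
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())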
Example 4
    def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f: float, diverge_th: float):
        output = trainer.state.output
        loss = output_transform(output)
        lr = self._lr_schedule.get_param()
        self._history["lr"].append(lr)
        if trainer.state.iteration == 1:
            self._best_loss = loss
        else:
            if smooth_f > 0:
                loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
            if loss < self._best_loss:
                self._best_loss = loss
        self._history["loss"].append(loss)

        # Check if the loss has diverged; if it has, stop the trainer
        if self._history["loss"][-1] > diverge_th * self._best_loss:
            self._diverge_flag = True
            self.logger.info("Stopping early, the loss has diverged")
            trainer.terminate()
Example 5
    def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable,
                         smooth_f: float, diverge_th: float) -> None:
        output = trainer.state.output
        loss = output_transform(output)
        if not isinstance(loss, float):
            if isinstance(loss, torch.Tensor):
                if (loss.ndimension() == 0) or (loss.ndimension() == 1
                                                and len(loss) == 1):
                    loss = loss.item()
                else:
                    raise ValueError(
                        "if output of the engine is torch.Tensor, then "
                        "it must be 0d torch.Tensor or 1d torch.Tensor with 1 element, "
                        f"but got torch.Tensor of shape {loss.shape}")
            else:
                raise TypeError(
                    "output of the engine should be of type float or 0d torch.Tensor "
                    "or 1d torch.Tensor with 1 element, "
                    f"but got output of type {type(loss).__name__}")
        loss = idist.all_reduce(loss)
        lr = self._lr_schedule.get_param()  # type: ignore[union-attr]
        self._history["lr"].append(lr)
        if trainer.state.iteration == 1:
            self._best_loss = loss
        else:
            if smooth_f > 0:
                loss = smooth_f * loss + (1 -
                                          smooth_f) * self._history["loss"][-1]
            if loss < self._best_loss:
                self._best_loss = loss
        self._history["loss"].append(loss)

        # Check if the loss has diverged; if it has, stop the trainer
        if self._history["loss"][
                -1] > diverge_th * self._best_loss:  # type: ignore[operator]
            self._diverge_flag = True
            self.logger.info("Stopping early, the loss has diverged")
            trainer.terminate()
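Examples 4 and 5 are internals of ignite's FastaiLRFinder, which terminates the trainer once the smoothed loss diverges. Below is a minimal sketch of driving it through the public API, assuming the documented attach()/lr_suggestion() interface; the toy model, optimizer, and data are illustrative.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import create_supervised_trainer
from ignite.handlers import FastaiLRFinder

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
trainer = create_supervised_trainer(model, optimizer, nn.MSELoss())

loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=8)

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}
# attach() returns a context manager; inside it the trainer runs the LR range test,
# and _log_lr_and_loss() terminates it early if the loss diverges.
with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(loader)

print(lr_finder.lr_suggestion())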
Example 6
 def _reached_num_iterations(self, trainer: Engine, num_iter: int):
     if trainer.state.iteration > num_iter:
         trainer.terminate()
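Extra positional arguments such as num_iter are supplied when the handler is registered. A minimal sketch follows; the dummy trainer and the value 100 are illustrative.

from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: None)  # dummy process function


def reached_num_iterations(engine: Engine, num_iter: int) -> None:
    # Stop once the requested number of iterations has been exceeded
    if engine.state.iteration > num_iter:
        engine.terminate()


# Positional args given after the handler are forwarded to it on every event
trainer.add_event_handler(Events.ITERATION_COMPLETED, reached_num_iterations, 100)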
Example 7
def test_terminate():
    engine = Engine(lambda e, b: 1)
    assert not engine.should_terminate
    engine.terminate()
    assert engine.should_terminate
Example 8
 def handle_exception(engine: Engine, e: Exception):
     if isinstance(e, KeyboardInterrupt) and engine.state.iteration > 1:
         logger.warning("KeyboardInterapt caught. Exiting.")
         engine.terminate()
     else:
         raise e
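Example 8 shows only the handler body. Below is a minimal sketch of registering it on the EXCEPTION_RAISED event; the logger and dummy trainer are illustrative.

import logging

from ignite.engine import Engine, Events

logger = logging.getLogger(__name__)
trainer = Engine(lambda engine, batch: None)  # dummy process function


def handle_exception(engine: Engine, e: Exception) -> None:
    # Swallow Ctrl-C after the first iteration and stop gracefully; re-raise anything else
    if isinstance(e, KeyboardInterrupt) and engine.state.iteration > 1:
        logger.warning("KeyboardInterrupt caught. Exiting.")
        engine.terminate()
    else:
        raise e


# EXCEPTION_RAISED fires when the process function or another handler raises,
# letting this handler decide between terminating and re-raising.
trainer.add_event_handler(Events.EXCEPTION_RAISED, handle_exception)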
Example 9
class Trainer:
    _STEPS_PER_LOSS_WRITE = 10
    _STEPS_PER_GRAD_WRITE = 10
    _STEPS_PER_LR_WRITE = 10

    def __init__(
            self,

            module,
            device,

            train_loss,
            train_loader,
            opt,
            lr_scheduler,
            max_epochs,
            max_grad_norm,

            test_metrics,
            test_loader,
            epochs_per_test,

            early_stopping,
            valid_loss,
            valid_loader,
            max_bad_valid_epochs,

            visualizer,

            writer,
            should_checkpoint_latest,
            should_checkpoint_best_valid
    ):
        self._module = module
        self._module.to(device)
        self._device = device

        self._train_loss = train_loss
        self._train_loader = train_loader
        self._opt = opt
        self._lr_scheduler = lr_scheduler
        self._max_epochs = max_epochs
        self._max_grad_norm = max_grad_norm

        self._test_metrics = test_metrics
        self._test_loader = test_loader
        self._epochs_per_test = epochs_per_test

        self._valid_loss = valid_loss
        self._valid_loader = valid_loader
        self._max_bad_valid_epochs = max_bad_valid_epochs
        self._best_valid_loss = float("inf")
        self._num_bad_valid_epochs = 0

        self._visualizer = visualizer

        self._writer = writer
        self._should_checkpoint_best_valid = should_checkpoint_best_valid

        ### Training

        self._trainer = Engine(self._train_batch)

        AverageMetric().attach(self._trainer)
        ProgressBar(persist=True).attach(self._trainer, ["loss"])

        self._trainer.add_event_handler(Events.EPOCH_STARTED, lambda _: self._module.train())
        self._trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
        self._trainer.add_event_handler(Events.ITERATION_COMPLETED, self._log_training_info)

        if should_checkpoint_latest:
            self._trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self._save_checkpoint("latest"))

        ### Validation

        if early_stopping:
            self._validator = Engine(self._validate_batch)

            AverageMetric().attach(self._validator)
            ProgressBar(persist=False, desc="Validating").attach(self._validator)

            self._trainer.add_event_handler(Events.EPOCH_COMPLETED, self._validate)
            self._validator.add_event_handler(Events.EPOCH_STARTED, lambda _: self._module.eval())

        ### Testing

        self._tester = Engine(self._test_batch)

        AverageMetric().attach(self._tester)
        ProgressBar(persist=False, desc="Testing").attach(self._tester)

        self._trainer.add_event_handler(Events.EPOCH_COMPLETED, self._test)
        self._tester.add_event_handler(Events.EPOCH_STARTED, lambda _: self._module.eval())

    def train(self):
        self._trainer.run(data=self._train_loader, max_epochs=self._max_epochs)

    def _train_batch(self, engine, batch):
        x, _ = batch # TODO: Potentially pass y also for genericity
        x = x.to(self._device)

        self._opt.zero_grad()

        loss = self._train_loss(self._module, x).mean()
        loss.backward()

        if self._max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self._module.parameters(), self._max_grad_norm)

        self._opt.step()

        self._lr_scheduler.step()

        return {"loss": loss}

    @torch.no_grad()
    def _test(self, engine):
        epoch = engine.state.epoch
        if (epoch - 1) % self._epochs_per_test == 0: # Test after first epoch
            state = self._tester.run(data=self._test_loader)

            for k, v in state.metrics.items():
                self._writer.write_scalar(f"test/{k}", v, global_step=engine.state.epoch)

            self._visualizer.visualize(self._module, epoch)

    def _test_batch(self, engine, batch):
        x, _ = batch
        x = x.to(self._device)
        return self._test_metrics(self._module, x)

    @torch.no_grad()
    def _validate(self, engine):
        state = self._validator.run(data=self._valid_loader)
        valid_loss = state.metrics["loss"]

        if valid_loss < self._best_valid_loss:
            print(f"Best validation loss {valid_loss} after epoch {engine.state.epoch}")
            self._num_bad_valid_epochs = 0
            self._best_valid_loss = valid_loss

            if self._should_checkpoint_best_valid:
                self._save_checkpoint(tag="best_valid")

        else:
            self._num_bad_valid_epochs += 1

            # We do this manually (i.e. don't use Ignite's early stopping) to permit
            # saving/resuming more easily
            if self._num_bad_valid_epochs > self._max_bad_valid_epochs:
                print(
                    f"No validation improvement after {self._num_bad_valid_epochs} epochs. Terminating."
                )
                self._trainer.terminate()

    def _validate_batch(self, engine, batch):
        x, _ = batch
        x = x.to(self._device)
        return {"loss": self._valid_loss(self._module, x)}

    def _log_training_info(self, engine):
        i = engine.state.iteration

        if i % self._STEPS_PER_LOSS_WRITE == 0:
            loss = engine.state.output["loss"]
            self._writer.write_scalar("train/loss", loss, global_step=i)

        # TODO: Inefficient to recompute this if we are doing gradient clipping
        if i % self._STEPS_PER_GRAD_WRITE == 0:
            self._writer.write_scalar("train/grad-norm", self._get_grad_norm(), global_step=i)

        # TODO: We should do this _before_ calling self._lr_scheduler.step(), since
        # otherwise this will not correspond to the learning rate used at iteration i
        if i % self._STEPS_PER_LR_WRITE == 0:
            self._writer.write_scalar("train/lr", self._get_lr(), global_step=i)

    def _get_grad_norm(self):
        norm = 0
        for param in self._module.parameters():
            if param.grad is not None:
                norm += param.grad.norm().item()**2
        return np.sqrt(norm)

    def _get_lr(self):
        param_group, = self._opt.param_groups
        return param_group["lr"]

    def _save_checkpoint(self, tag):
        # We do this manually (i.e. don't use Ignite's checkpointing) because
        # Ignite only allows saving objects, not scalars (e.g. the current epoch) 
        checkpoint = {
            "epoch": self._trainer.state.epoch,
            "iteration": self._trainer.state.iteration,
            "module_state_dict": self._module.state_dict(),
            "opt_state_dict": self._opt.state_dict(),
            "best_valid_loss": self._best_valid_loss,
            "num_bad_valid_epochs": self._num_bad_valid_epochs
        }

        self._writer.write_checkpoint(tag, checkpoint)
Example 10
 def loss_diverged(engine: Engine, finder):
     if finder.history["loss"][-1] > diverge_th * finder.best_loss:
         engine.terminate()
         finder.logger.info("Stopping early, the loss has diverged")
Example 11
    def __call__(self, engine: Engine) -> None:
        for i, batchdata in enumerate(self.data_loader):
            batch = self.prepare_batch_fn(batchdata, self.device, False)
            if len(batch) == 2:
                inputs, targets = batch
            else:
                raise NotImplementedError

            if isinstance(inputs, (tuple, list)):
                self.logger.warn(
                    f"Got multiple inputs with size of {len(batch)},"
                    "select the first one as image data.")
                origin_img = inputs[0].cpu().detach().numpy().squeeze(1)
            else:
                origin_img = inputs.cpu().detach().numpy().squeeze(1)

            self.logger.debug(
                f"Input len: {len(inputs)}, shape: {origin_img.shape}")

            cam_result = self.cam(
                inputs,
                class_idx=self.target_class,
                img_spatial_size=origin_img.shape[1:],
            )

            self.logger.debug(f"Image batchdata shape: {origin_img.shape}, "
                              f"CAM batchdata shape: {cam_result.shape}")

            if len(origin_img.shape[1:]) == 3:
                for j, (img_slice,
                        cam_slice) in enumerate(zip(origin_img, cam_result)):
                    file_name = (
                        f"batch{i}_{j}_cam_{self.suffix}_{self.target_layers}.nii.gz"
                    )
                    nib.save(
                        nib.Nifti1Image(img_slice.squeeze(), np.eye(4)),
                        self.save_dir / f"batch{i}_{j}_images.nii.gz",
                    )

                    if cam_slice.shape[0] > 1 and self.fusion:
                        output_cam = cam_slice.mean(axis=0).squeeze()
                    elif self.hierarchical:
                        output_cam = np.flip(cam_slice.transpose(1, 2, 3, 0),
                                             3).squeeze()
                    else:
                        output_cam = cam_slice.transpose(1, 2, 3, 0).squeeze()

                    nib.save(
                        nib.Nifti1Image(output_cam, np.eye(4)),
                        self.save_dir / file_name,
                    )

            elif len(origin_img.shape[1:]) == 2:
                cam_result = np.uint8(cam_result.squeeze(1) * 255)
                for j, (img_slice,
                        cam_slice) in enumerate(zip(origin_img, cam_result)):
                    img_slice = np.uint8(Normalize2(img_slice) * 255)

                    img_slice = Image.fromarray(img_slice)
                    no_trans_heatmap, heatmap_on_image = apply_colormap_on_image(
                        img_slice, cam_slice, "hsv")

                    heatmap_on_image.save(self.save_dir /
                                          f"batch{i}_{j}_heatmap_on_img.png")
            else:
                raise NotImplementedError(
                    f"Cannot support ({origin_img.shape}) data.")

        engine.terminate()