Example 1
 def attach(self, engine: Engine) -> None:
     if self._name is None:
         self.logger = engine.logger
     if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
         engine.add_event_handler(Events.ITERATION_COMPLETED, self)
     if not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
         engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize())
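Note that the guard pattern above only works when the same callable is checked and registered: the Events.COMPLETED branch checks for self.saver.finalize but registers a fresh lambda, so that handler is re-added on every attach call. The ITERATION_COMPLETED branch, by contrast, is genuinely idempotent. A minimal sketch of the idempotent form in isolation (the engine and handler below are illustrative, not taken from the source above):

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)

def log_iteration(engine):
    print(f"iteration {engine.state.iteration}")

# Guarded registration: a second call is a no-op instead of a duplicate.
for _ in range(2):
    if not engine.has_event_handler(log_iteration, Events.ITERATION_COMPLETED):
        engine.add_event_handler(Events.ITERATION_COMPLETED, log_iteration)

assert len(engine._event_handlers[Events.ITERATION_COMPLETED]) == 1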
Example 2
    def attach(self, engine: Engine, name: str) -> None:
        """
        Attaches the current metric to the provided engine. At the end of the engine's run,
        the `engine.state.metrics` dictionary will contain the computed metric's value under the provided name.

        Args:
            engine (Engine): the engine to which the metric must be attached
            name (str): the name of the metric to attach

        Example:

        .. code-block:: python

            metric = ...
            metric.attach(engine, "mymetric")

            assert "mymetric" in engine.run(data).metrics

            assert metric.is_attached(engine)
        """
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
        if not engine.has_event_handler(self.started, Events.EPOCH_STARTED):
            engine.add_event_handler(Events.EPOCH_STARTED, self.started)
        if not engine.has_event_handler(self.iteration_completed,
                                        Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED,
                                     self.iteration_completed)
Example 3
import pytest

from ignite.engine import Engine, Events


def test_remove_event_handler_on_callable_events():

    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    assert not engine.has_event_handler(foo)

    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)
    engine.remove_event_handler(foo, Events.EPOCH_STARTED)
    assert not engine.has_event_handler(foo)

    def bar(e):
        pass

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED)
    assert not engine.has_event_handler(bar)

    with pytest.raises(
            TypeError,
            match=r"Argument event_name should not be a filtered event"):
        engine.remove_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
Example 4
    def detach(self, engine: Engine) -> None:
        """
        Detaches the current metric from the engine so that no metric computation is done during the run.
        This method, in conjunction with :meth:`~ignite.metrics.Metric.attach`, can be useful if several
        metrics need to be computed with different periods. For example, one metric is computed every training
        epoch and another (e.g. a more expensive one) every n-th training epoch.

        Args:
            engine (Engine): the engine from which the metric must be detached

        Example:

        .. code-block:: python

            metric = ...
            engine = ...
            metric.detach(engine)

            assert "mymetric" not in engine.run(data).metrics

            assert not metric.is_attached(engine)
        """
        if engine.has_event_handler(self.completed, Events.EPOCH_COMPLETED):
            engine.remove_event_handler(self.completed, Events.EPOCH_COMPLETED)
        if engine.has_event_handler(self.started, Events.EPOCH_STARTED):
            engine.remove_event_handler(self.started, Events.EPOCH_STARTED)
        if engine.has_event_handler(self.iteration_completed,
                                    Events.ITERATION_COMPLETED):
            engine.remove_event_handler(self.iteration_completed,
                                        Events.ITERATION_COMPLETED)
Example 5
    def attach(self, engine: Engine):
        """
        Attaches the lr_finder to an engine. It is recommended to use
        `with lr_finder.attach(engine)` rather than detaching explicitly
        with `lr_finder.detach()`.

        Args:
            engine: lr_finder is attached to this engine

        Notes:
            lr_finder cannot be attached to more than one engine at a time

        Returns:
            self
        """
        if self._engine:
            raise AlreadyAttachedError(
                "This LRFinder is already attached. create a new one or use lr_finder.detach()"
            )
        self._engine = weakref.ref(engine)
        if not engine.has_event_handler(self._run):
            engine.add_event_handler(Events.STARTED, self._run)
        if not engine.has_event_handler(self._warning):
            engine.add_event_handler(Events.COMPLETED, self._warning)
        if not engine.has_event_handler(self._reset):
            engine.add_event_handler(Events.COMPLETED, self._reset)

        return self
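A hedged sketch of the life cycle this variant enforces; the LRFinder constructor and train_step function are assumptions, while detach() is the counterpart named in the docstring:

from ignite.engine import Engine

lr_finder = LRFinder()         # assumed constructor name
trainer = Engine(train_step)   # assumed process function

lr_finder.attach(trainer)      # registers _run, _warning and _reset
try:
    lr_finder.attach(trainer)  # a second attach on the same finder fails
except AlreadyAttachedError:
    pass
lr_finder.detach()             # releases the finder for reuse
lr_finder.attach(trainer)      # allowed again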
Example 6
    def attach(self, engine: Engine) -> None:
        """
        Register a set of Ignite Event-Handlers to a specified Ignite engine.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.

        """
        if self.name is None:
            self.logger = engine.logger
        if self.logger.getEffectiveLevel() > logging.INFO or logging.root.getEffectiveLevel() > logging.INFO:
            warnings.warn(
                "The effective log level of the engine logger or the root logger is higher than INFO;"
                " logs may not be recorded. Please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)`"
                " to enable them."
            )
        if self.iteration_log and not engine.has_event_handler(
                self.iteration_completed, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED,
                                     self.iteration_completed)
        if self.epoch_log and not engine.has_event_handler(
                self.epoch_completed, Events.EPOCH_COMPLETED):
            engine.add_event_handler(Events.EPOCH_COMPLETED,
                                     self.epoch_completed)
        if not engine.has_event_handler(self.exception_raised,
                                        Events.EXCEPTION_RAISED):
            engine.add_event_handler(Events.EXCEPTION_RAISED,
                                     self.exception_raised)
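The warning above points at a common pitfall: the attached handlers emit INFO-level records, which the default root logger configuration suppresses. A minimal setup sketch; the StatsHandler name and constructor arguments are assumptions inferred from the iteration_log/epoch_log attributes above:

import logging
import sys

# Make INFO-level records visible, as the warning message suggests.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

handler = StatsHandler(iteration_log=True, epoch_log=True)  # assumed signature
handler.attach(trainer)                                     # trainer assumed to exist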
Example 7
 def attach(self, engine: Engine, name: str) -> None:
     engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
     if not engine.has_event_handler(self.started, Events.EPOCH_STARTED):
         engine.add_event_handler(Events.EPOCH_STARTED, self.started)
     if not engine.has_event_handler(self.iteration_completed,
                                     Events.ITERATION_COMPLETED):
         engine.add_event_handler(Events.ITERATION_COMPLETED,
                                  self.iteration_completed)
Example 8
    def _run(
        self,
        trainer: Engine,
        optimizer: Optimizer,
        output_transform: Callable,
        num_iter: int,
        start_lr: float,
        end_lr: float,
        step_mode: str,
        smooth_f: float,
        diverge_th: float,
    ) -> None:

        self._history = {"lr": [], "loss": []}
        self._best_loss = None
        self._diverge_flag = False

        # attach LRScheduler to trainer.
        if num_iter is None:
            num_iter = trainer.state.epoch_length * trainer.state.max_epochs
        else:
            max_iter = trainer.state.epoch_length * trainer.state.max_epochs  # type: ignore[operator]
            if max_iter < num_iter:
                max_iter = num_iter
                trainer.state.max_iters = num_iter
                trainer.state.max_epochs = ceil(
                    num_iter /
                    trainer.state.epoch_length)  # type: ignore[operator]

        if not trainer.has_event_handler(self._reached_num_iterations):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._reached_num_iterations, num_iter)

        # attach loss and lr logging
        if not trainer.has_event_handler(self._log_lr_and_loss):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._log_lr_and_loss, output_transform,
                                      smooth_f, diverge_th)

        self.logger.debug(f"Running LR finder for {num_iter} iterations")
        if start_lr is None:
            start_lr = optimizer.param_groups[0]["lr"]
        # Initialize the proper learning rate policy
        if step_mode.lower() == "exp":
            start_lr = [start_lr] * len(optimizer.param_groups)  # type: ignore
            self._lr_schedule = LRScheduler(
                _ExponentialLR(optimizer, start_lr, end_lr, num_iter))
        else:
            self._lr_schedule = PiecewiseLinear(optimizer,
                                                param_name="lr",
                                                milestones_values=[
                                                    (0, start_lr),
                                                    (num_iter, end_lr)
                                                ])
        if not trainer.has_event_handler(self._lr_schedule):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._lr_schedule, num_iter)
Example 9
 def _internal_attach(self, engine: Engine) -> None:
     for metric in itertools.chain(self.args, self.kwargs.values()):
         if isinstance(metric, MetricsLambda):
             metric._internal_attach(engine)
         elif isinstance(metric, Metric):
             if not engine.has_event_handler(metric.started, Events.EPOCH_STARTED):
                 engine.add_event_handler(Events.EPOCH_STARTED, metric.started)
             if not engine.has_event_handler(metric.iteration_completed, Events.ITERATION_COMPLETED):
                 engine.add_event_handler(Events.ITERATION_COMPLETED, metric.iteration_completed)
Example 10
    def _run(
        self,
        trainer: Engine,
        optimizer: Optimizer,
        output_transform: Callable,
        num_iter: int,
        end_lr: float,
        step_mode: str,
        smooth_f: float,
        diverge_th: float,
    ):

        self._history = {"lr": [], "loss": []}
        self._best_loss = None
        self._diverge_flag = False

        # attach LRScheduler to trainer.
        if num_iter is None:
            num_iter = trainer.state.epoch_length * trainer.state.max_epochs
        else:
            max_iter = trainer.state.epoch_length * trainer.state.max_epochs
            if num_iter > max_iter:
                warnings.warn(
                    "Desired num_iter {} is unreachable with the current run setup of {} iteration "
                    "({} epochs)".format(num_iter, max_iter,
                                         trainer.state.max_epochs),
                    UserWarning,
                )

        if not trainer.has_event_handler(self._reached_num_iterations):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._reached_num_iterations, num_iter)

        # attach loss and lr logging
        if not trainer.has_event_handler(self._log_lr_and_loss):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._log_lr_and_loss, output_transform,
                                      smooth_f, diverge_th)

        self.logger.debug(
            "Running LR finder for {} iterations".format(num_iter))
        # Initialize the proper learning rate policy
        if step_mode.lower() == "exp":
            self._lr_schedule = LRScheduler(
                _ExponentialLR(optimizer, end_lr, num_iter))
        else:
            start_lr = optimizer.param_groups[0]["lr"]
            self._lr_schedule = PiecewiseLinear(optimizer,
                                                param_name="lr",
                                                milestones_values=[
                                                    (0, start_lr),
                                                    (num_iter, end_lr)
                                                ])
        if not trainer.has_event_handler(self._lr_schedule):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._lr_schedule, num_iter)
Example 11
    def attach(self, engine: Engine):
        """Register a set of Ignite Event-Handlers to a specified Ignite engine.

        Args:
            engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.

        """
        if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
        if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
            engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
Example 12
 def attach(self, engine: Engine) -> None:
     """
     Args:
         engine: Ignite Engine, it can be a trainer, validator or evaluator.
     """
     if self._name is None:
         self.logger = engine.logger
     if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
         engine.add_event_handler(Events.ITERATION_COMPLETED, self)
     if not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
         engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize())
Example 13
 def _internal_attach(self, engine: Engine, usage: MetricUsage) -> None:
     self.engine = engine
     for metric in itertools.chain(self.args, self.kwargs.values()):
         if isinstance(metric, MetricsLambda):
             metric._internal_attach(engine, usage)
         elif isinstance(metric, Metric):
             # NB : metrics is attached partially
             # We must not use is_attached() but rather if these events exist
             if not engine.has_event_handler(metric.started, usage.STARTED):
                 engine.add_event_handler(usage.STARTED, metric.started)
             if not engine.has_event_handler(metric.iteration_completed, usage.ITERATION_COMPLETED):
                 engine.add_event_handler(usage.ITERATION_COMPLETED, metric.iteration_completed)
Example 14
    def _detach(self, trainer: Engine):
        """
        Detaches lr_finder from trainer.

        Args:
            trainer: the trainer to detach from.
        """

        if trainer.has_event_handler(self._run, Events.STARTED):
            trainer.remove_event_handler(self._run, Events.STARTED)
        if trainer.has_event_handler(self._warning, Events.COMPLETED):
            trainer.remove_event_handler(self._warning, Events.COMPLETED)
        if trainer.has_event_handler(self._reset, Events.COMPLETED):
            trainer.remove_event_handler(self._reset, Events.COMPLETED)
Example 15
 def attach(self, engine: Engine) -> None:
     """
     Args:
         engine: Ignite Engine, it can be a trainer, validator or evaluator.
     """
     if self._name is None:
         self.logger = engine.logger
     if not engine.has_event_handler(self._started, Events.EPOCH_STARTED):
         engine.add_event_handler(Events.EPOCH_STARTED, self._started)
     if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
         engine.add_event_handler(Events.ITERATION_COMPLETED, self)
     if not engine.has_event_handler(self._finalize,
                                     Events.EPOCH_COMPLETED):
         engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize)
Example 16
 def _internal_is_attached(self, engine: Engine, usage: MetricUsage) -> bool:
     # if no engine, metrics is not attached
     if engine is None:
         return False
     # check recursively if metrics are attached
     is_detached = False
     for metric in itertools.chain(self.args, self.kwargs.values()):
         if isinstance(metric, MetricsLambda):
             if not metric._internal_is_attached(engine, usage):
                 is_detached = True
         elif isinstance(metric, Metric):
             if not engine.has_event_handler(metric.started, usage.STARTED):
                 is_detached = True
             if not engine.has_event_handler(metric.iteration_completed, usage.ITERATION_COMPLETED):
                 is_detached = True
     return not is_detached
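The recursion above mirrors how composed metrics attach: every leaf Metric must have both its STARTED and ITERATION_COMPLETED handlers registered for the composition to count as attached. A standard ignite composition that exercises this path (the engine is assumed to exist):

from ignite.metrics import MetricsLambda, Precision, Recall

precision = Precision(average=False)
recall = Recall(average=False)

# F1 as a lambda over two leaf metrics; attaching f1 attaches the leaves too.
f1 = MetricsLambda(lambda p, r: (2 * p * r / (p + r + 1e-20)).mean().item(),
                   precision, recall)

f1.attach(engine, "f1")
assert f1.is_attached(engine)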
Example 17
    def attach(self, engine: Engine) -> None:
        """
        Register a set of Ignite Event-Handlers to a specified Ignite engine.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.

        """
        if self._name is None:
            self.logger = engine.logger
        if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
        if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
            engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
        if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED):
            engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)
Example 18
    def detach(
        self, engine: Engine, usage: Union[str,
                                           MetricUsage] = EpochWise()) -> None:
        """
        Detaches the current metric from the engine so that no metric computation is done during the run.
        This method, in conjunction with :meth:`~ignite.metrics.Metric.attach`, can be useful if several
        metrics need to be computed with different periods. For example, one metric is computed every training
        epoch and another (e.g. a more expensive one) every n-th training epoch.

        Args:
            engine (Engine): the engine from which the metric must be detached
            usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be
                'epoch_wise' (default) or 'batch_wise'.

        Example:

        .. code-block:: python

            metric = ...
            engine = ...
            metric.detach(engine)

            assert "mymetric" not in engine.run(data).metrics

            assert not metric.is_attached(engine)

        Example with usage:

        .. code-block:: python

            metric = ...
            engine = ...
            metric.detach(engine, usage="batch_wise")

            assert "mymetric" not in engine.run(data).metrics

            assert not metric.is_attached(engine, usage="batch_wise")
        """
        usage = self._check_usage(usage)
        if engine.has_event_handler(self.completed, usage.COMPLETED):
            engine.remove_event_handler(self.completed, usage.COMPLETED)
        if engine.has_event_handler(self.started, usage.STARTED):
            engine.remove_event_handler(self.started, usage.STARTED)
        if engine.has_event_handler(self.iteration_completed,
                                    usage.ITERATION_COMPLETED):
            engine.remove_event_handler(self.iteration_completed,
                                        usage.ITERATION_COMPLETED)
Example 19
    def attach(
        self,
        engine: Engine,
        name: str,
        usage: Union[str, MetricUsage] = EpochWise()) -> None:
        """
        Attaches the current metric to the provided engine. At the end of the engine's run, the
        `engine.state.metrics` dictionary will contain the computed metric's value under the provided name.

        Args:
            engine (Engine): the engine to which the metric must be attached
            name (str): the name of the metric to attach
            usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be
                :attr:`ignite.metrics.EpochWise.usage_name` (default) or
                :attr:`ignite.metrics.BatchWise.usage_name`.

        Example:

        .. code-block:: python

            metric = ...
            metric.attach(engine, "mymetric")

            assert "mymetric" in engine.run(data).metrics

            assert metric.is_attached(engine)

        Example with usage:

        .. code-block:: python

            metric = ...
            metric.attach(engine, "mymetric", usage=BatchWise.usage_name)

            assert "mymetric" in engine.run(data).metrics

            assert metric.is_attached(engine, usage=BatchWise.usage_name)
        """
        usage = self._check_usage(usage)
        if not engine.has_event_handler(self.started, usage.STARTED):
            engine.add_event_handler(usage.STARTED, self.started)
        if not engine.has_event_handler(self.iteration_completed,
                                        usage.ITERATION_COMPLETED):
            engine.add_event_handler(usage.ITERATION_COMPLETED,
                                     self.iteration_completed)
        engine.add_event_handler(usage.COMPLETED, self.completed, name)
Example 20
    def attach(self, engine: Engine):
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))

        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (engine, ), {}))
Example 21
import pytest

from ignite.engine import Engine, Events


def test_remove_event_handler_on_callable_events():

    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    assert not engine.has_event_handler(foo)

    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)
    engine.remove_event_handler(foo, Events.EPOCH_STARTED)
    assert not engine.has_event_handler(foo)

    def bar(e):
        pass

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED)
    assert not engine.has_event_handler(bar)

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
    assert not engine.has_event_handler(bar)
Example 22
    def is_attached(self, engine: Engine) -> bool:
        """
        Checks if the current metric is attached to the provided engine. If attached, the metric's
        computed value is written to the `engine.state.metrics` dictionary.

        Args:
            engine (Engine): the engine to check for metric attachment
        """
        return engine.has_event_handler(self.completed, Events.EPOCH_COMPLETED)
Example 23
    def attach(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """

        for name, info in engine.data_loader.dataset.info.items():
            self.case_names.append(name)
            self.probs_maps[name] = np.zeros(info['mask_dims'])
            self.levels[name] = info['level']

        if self._name is None:
            self.logger = engine.logger
        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self)
        if not engine.has_event_handler(self.finalize, Events.COMPLETED):
            engine.add_event_handler(Events.COMPLETED,
                                     lambda engine: self.finalize())
Example 24
    def attach(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """

        self.num_images = len(engine.data_loader.dataset.data)

        for sample in engine.data_loader.dataset.data:
            name = sample["name"]
            self.prob_map[name] = np.zeros(sample["mask_shape"], dtype=self.dtype)
            self.counter[name] = len(sample["mask_locations"])
            self.level[name] = sample["level"]

        if self._name is None:
            self.logger = engine.logger
        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self)
        if not engine.has_event_handler(self.finalize, Events.COMPLETED):
            engine.add_event_handler(Events.COMPLETED, self.finalize)
Example 25
    def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise()) -> bool:
        """
        Checks if the current metric is attached to the provided engine. If attached, the metric's
        computed value is written to the `engine.state.metrics` dictionary.

        Args:
            engine: the engine to check for metric attachment
            usage: the usage of the metric. Valid string values should be
                'epoch_wise' (default) or 'batch_wise'.
        """
        usage = self._check_usage(usage)
        return engine.has_event_handler(self.completed, usage.COMPLETED)
Example 26
    def attach(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """

        image_data = engine.data_loader.dataset.image_data  # type: ignore
        self.num_images = len(image_data)

        # Initialized probability maps for all the images
        for sample in image_data:
            name = sample[ProbMapKeys.NAME]
            self.counter[name] = sample[ProbMapKeys.COUNT]
            self.prob_map[name] = np.zeros(sample[ProbMapKeys.SIZE],
                                           dtype=self.dtype)

        if self._name is None:
            self.logger = engine.logger
        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self)
        if not engine.has_event_handler(self.finalize, Events.COMPLETED):
            engine.add_event_handler(Events.COMPLETED, self.finalize)
Example 27
    def attach(self, engine: Engine) -> None:
        """Attach HandlersTimeProfiler to the given engine.

        Args:
            engine: the instance of Engine to attach
        """
        if not isinstance(engine, Engine):
            raise TypeError(
                f"Argument engine should be ignite.engine.Engine, but given {type(engine)}"
            )

        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (engine, ), {}))
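Both profiler variants bypass add_event_handler and splice _as_first_started into the front of the Events.STARTED handler list, so the profiler starts timing before any user handler runs. A hedged usage sketch, assuming ignite's HandlersTimeProfiler and an existing trainer and data loader:

from ignite.handlers import HandlersTimeProfiler

profiler = HandlersTimeProfiler()
profiler.attach(trainer)                 # must precede trainer.run(...)
trainer.run(train_loader, max_epochs=2)

# Per-handler timing summary collected by the first-started hook.
profiler.print_results(profiler.get_results())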
Example 28
class Trainer(object):
    def __init__(self: TrainerType,
                 model: nn.Module,
                 optimizer: Optimizer,
                 checkpoint_dir: str = '../../checkpoints',
                 experiment_name: str = 'experiment',
                 model_checkpoint: Optional[str] = None,
                 optimizer_checkpoint: Optional[str] = None,
                 metrics: Optional[types.GenericDict] = None,
                 patience: int = 10,
                 validate_every: int = 1,
                 accumulation_steps: int = 1,
                 loss_fn: Optional[Union[_Loss, DataParallelCriterion]] = None,
                 non_blocking: bool = True,
                 retain_graph: bool = False,
                 dtype: torch.dtype = torch.float,
                 device: str = 'cpu',
                 parallel: bool = False) -> None:
        self.dtype = dtype
        self.retain_graph = retain_graph
        self.non_blocking = non_blocking
        self.device = device
        self.loss_fn = loss_fn
        self.validate_every = validate_every
        self.patience = patience
        self.accumulation_steps = accumulation_steps
        self.checkpoint_dir = checkpoint_dir

        model_checkpoint = self._check_checkpoint(model_checkpoint)
        optimizer_checkpoint = self._check_checkpoint(optimizer_checkpoint)

        self.model = cast(
            nn.Module,
            from_checkpoint(model_checkpoint,
                            model,
                            map_location=torch.device('cpu')))
        self.model = self.model.type(dtype).to(device)
        self.optimizer = from_checkpoint(optimizer_checkpoint, optimizer)
        self.parallel = parallel
        if parallel:
            if device == 'cpu':
                raise ValueError("parallel can be used only with cuda device")
            self.model = DataParallelModel(self.model).to(device)
            self.loss_fn = DataParallelCriterion(self.loss_fn)  # type: ignore
        if metrics is None:
            metrics = {}
        if 'loss' not in metrics:
            if self.parallel:
                metrics['loss'] = Loss(
                    lambda x, y: self.loss_fn(x, y).mean())  # type: ignore
            else:
                metrics['loss'] = Loss(self.loss_fn)
        self.trainer = Engine(self.train_step)
        self.train_evaluator = Engine(self.eval_step)
        self.valid_evaluator = Engine(self.eval_step)
        for name, metric in metrics.items():
            metric.attach(self.train_evaluator, name)
            metric.attach(self.valid_evaluator, name)

        self.pbar = ProgressBar()
        self.val_pbar = ProgressBar(desc='Validation')

        if checkpoint_dir is not None:
            self.checkpoint = CheckpointHandler(checkpoint_dir,
                                                experiment_name,
                                                score_name='validation_loss',
                                                score_function=self._score_fn,
                                                n_saved=2,
                                                require_empty=False,
                                                save_as_state_dict=True)

        self.early_stop = EarlyStopping(patience, self._score_fn, self.trainer)

        self.val_handler = EvaluationHandler(pbar=self.pbar,
                                             validate_every=1,
                                             early_stopping=self.early_stop)
        self.attach()
        log.info(
            f'Trainer configured to run {experiment_name}\n'
            f'\tpretrained model: {model_checkpoint} {optimizer_checkpoint}\n'
            f'\tcheckpoint directory: {checkpoint_dir}\n'
            f'\tpatience: {patience}\n'
            f'\taccumulation steps: {accumulation_steps}\n'
            f'\tnon blocking: {non_blocking}\n'
            f'\tretain graph: {retain_graph}\n'
            f'\tdevice: {device}\n'
            f'\tmodel dtype: {dtype}\n'
            f'\tparallel: {parallel}')

    def _check_checkpoint(self: TrainerType,
                          ckpt: Optional[str]) -> Optional[str]:
        if ckpt is None:
            return ckpt
        if system.is_url(ckpt):
            ckpt = system.download_url(cast(str, ckpt), self.checkpoint_dir)
        ckpt = os.path.join(self.checkpoint_dir, ckpt)
        return ckpt

    @staticmethod
    def _score_fn(engine: Engine) -> float:
        """Returns the scoring metric for checkpointing and early stopping

        Args:
            engine (ignite.engine.Engine): The engine that calculates
            the val loss

        Returns:
            (float): The validation loss
        """
        negloss: float = -engine.state.metrics['loss']
        return negloss

    def parse_batch(self: TrainerType,
                    batch: List[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
        inputs = to_device(batch[0],
                           device=self.device,
                           non_blocking=self.non_blocking)
        targets = to_device(batch[1],
                            device=self.device,
                            non_blocking=self.non_blocking)
        return inputs, targets

    def get_predictions_and_targets(
            self: TrainerType,
            batch: List[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
        inputs, targets = self.parse_batch(batch)
        y_pred = self.model(inputs)
        return y_pred, targets

    def train_step(self: TrainerType, engine: Engine,
                   batch: List[torch.Tensor]) -> float:
        self.model.train()

        y_pred, targets = self.get_predictions_and_targets(batch)
        loss = self.loss_fn(y_pred, targets.long())  # type: ignore
        if self.parallel:
            loss = loss.mean()
        loss = loss / self.accumulation_steps
        loss.backward(retain_graph=self.retain_graph)
        if (self.trainer.state.iteration + 1) % self.accumulation_steps == 0:
            self.optimizer.step()  # type: ignore
            self.optimizer.zero_grad()
        loss_value: float = loss.item()
        return loss_value

    def eval_step(self: TrainerType, engine: Engine,
                  batch: List[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
        self.model.eval()
        with torch.no_grad():
            y_pred, targets = self.get_predictions_and_targets(batch)
            return y_pred, targets

    def predict(self: TrainerType, dataloader: DataLoader) -> State:
        return self.valid_evaluator.run(dataloader)

    def fit(self: TrainerType,
            train_loader: DataLoader,
            val_loader: DataLoader,
            epochs: int = 50) -> State:
        log.info('Trainer will run for\n'
                 f'model: {self.model}\n'
                 f'optimizer: {self.optimizer}\n'
                 f'loss: {self.loss_fn}')
        self.val_handler.attach(self.trainer,
                                self.train_evaluator,
                                train_loader,
                                validation=False)
        self.val_handler.attach(self.trainer,
                                self.valid_evaluator,
                                val_loader,
                                validation=True)
        self.model.zero_grad()
        self.trainer.run(train_loader, max_epochs=epochs)
        best_score = (-self.early_stop.best_score if self.early_stop else
                      self.valid_evaluator.state.metrics['loss'])
        return best_score

    def overfit_single_batch(self: TrainerType,
                             train_loader: DataLoader) -> State:
        single_batch = [next(iter(train_loader))]

        if self.trainer.has_event_handler(self.val_handler,
                                          Events.EPOCH_COMPLETED):
            self.trainer.remove_event_handler(self.val_handler,
                                              Events.EPOCH_COMPLETED)

        self.val_handler.attach(
            self.trainer,
            self.train_evaluator,
            single_batch,  # type: ignore
            validation=False)
        out = self.trainer.run(single_batch, max_epochs=100)
        return out

    def fit_debug(self: TrainerType, train_loader: DataLoader,
                  val_loader: DataLoader) -> State:
        train_loader = iter(train_loader)
        train_subset = [next(train_loader), next(train_loader)]
        val_loader = iter(val_loader)  # type: ignore
        val_subset = [next(val_loader), next(val_loader)]  # type: ignore
        out = self.fit(train_subset, val_subset, epochs=6)  # type: ignore
        return out

    def _attach_checkpoint(self: TrainerType) -> TrainerType:
        ckpt = {'model': self.model, 'optimizer': self.optimizer}
        if self.checkpoint_dir is not None:
            self.valid_evaluator.add_event_handler(Events.COMPLETED,
                                                   self.checkpoint, ckpt)
        return self

    def attach(self: TrainerType) -> TrainerType:
        ra = RunningAverage(output_transform=lambda x: x)
        ra.attach(self.trainer, "Train Loss")
        self.pbar.attach(self.trainer, ['Train Loss'])
        self.val_pbar.attach(self.train_evaluator)
        self.val_pbar.attach(self.valid_evaluator)
        self.valid_evaluator.add_event_handler(Events.COMPLETED,
                                               self.early_stop)
        self = self._attach_checkpoint()

        def graceful_exit(engine, e):
            if isinstance(e, KeyboardInterrupt):
                engine.terminate()
                log.warning("CTRL-C caught. Exiting gracefully...")
            else:
                raise e

        self.trainer.add_event_handler(Events.EXCEPTION_RAISED, graceful_exit)
        self.train_evaluator.add_event_handler(Events.EXCEPTION_RAISED,
                                               graceful_exit)
        self.valid_evaluator.add_event_handler(Events.EXCEPTION_RAISED,
                                               graceful_exit)
        return self
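A hedged sketch of driving the Trainer above end to end; the model, loss and loaders are placeholders:

model = MyModel()  # any nn.Module, placeholder
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

trainer = Trainer(model,
                  optimizer,
                  loss_fn=nn.CrossEntropyLoss(),
                  checkpoint_dir='./checkpoints',
                  patience=5)
# fit() runs training with periodic validation, checkpointing and early
# stopping, and returns the best validation score tracked by EarlyStopping.
best_score = trainer.fit(train_loader, val_loader, epochs=20)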
Example 29
    @contextmanager
    def attach(
        self,
        trainer: Engine,
        to_save: Mapping,
        output_transform: Callable = lambda output: output,
        num_iter: Optional[int] = None,
        end_lr: float = 10.0,
        step_mode: str = "exp",
        smooth_f: float = 0.05,
        diverge_th: float = 5.0,
    ):
        """Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run.

        Usage:

        .. code-block:: python

            to_save = {"model": model, "optimizer": optimizer}
            with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
                trainer_with_lr_finder.run(dataloader)

        Args:
            trainer (Engine): lr_finder is attached to this trainer. Please keep in mind that all attached handlers
                will be executed.
            to_save (Mapping): dictionary with optimizer and other objects that need to be restored after running
                the LR finder. For example, `to_save={'optimizer': optimizer, 'model': model}`. All objects should
                implement `state_dict` and `load_state_dict` methods.
            output_transform (callable, optional): function that transforms the trainer's `state.output` after each
                iteration. It must return the loss of that iteration.
            num_iter (int, optional): number of iterations for lr schedule between base lr and end_lr. Default, it will
                run for `trainer.state.epoch_length * trainer.state.max_epochs`.
            end_lr (float, optional): upper bound for lr search. Default, 10.0.
            step_mode (str, optional): "exp" or "linear", which way should the lr be increased from optimizer's initial
                lr to `end_lr`. Default, "exp".
            smooth_f (float, optional): loss smoothing factor in range `[0, 1)`. Default, 0.05
            diverge_th (float, optional): Used for stopping the search when `current loss > diverge_th * best_loss`.
                Default, 5.0.

        Note:
            lr_finder cannot be attached to more than one trainer at a time.

        Returns:
            trainer_with_lr_finder: trainer used for finding the lr
        """
        if not isinstance(to_save, Mapping):
            raise TypeError("Argument to_save should be a mapping, but given {}".format(type(to_save)))

        Checkpoint._check_objects(to_save, "state_dict")
        Checkpoint._check_objects(to_save, "load_state_dict")

        if "optimizer" not in to_save:
            raise ValueError("Mapping to_save should contain 'optimizer' key")

        if not isinstance(to_save["optimizer"], torch.optim.Optimizer):
            raise TypeError(
                "Object to_save['optimizer'] should be torch optimizer, but given {}".format(type(to_save["optimizer"]))
            )

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError("smooth_f is outside the range [0, 1]")
        if diverge_th < 1:
            raise ValueError("diverge_th should be larger than 1")
        if step_mode not in ["exp", "linear"]:
            raise ValueError("step_mode should be 'exp' or 'linear', but given {}".format(step_mode))
        if num_iter is not None:
            if not isinstance(num_iter, int):
                raise TypeError("if provided, num_iter should be an integer, but given {}".format(num_iter))
            if num_iter <= 0:
                raise ValueError("if provided, num_iter should be positive, but given {}".format(num_iter))

        # store to_save
        with tempfile.TemporaryDirectory() as tmpdirname:
            obj = {k: o.state_dict() for k, o in to_save.items()}
            # add trainer
            obj["trainer"] = trainer.state_dict()
            cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
            torch.save(obj, cache_filepath.as_posix())

            optimizer = to_save["optimizer"]
            # Attach handlers
            if not trainer.has_event_handler(self._run):
                trainer.add_event_handler(
                    Events.STARTED,
                    self._run,
                    optimizer,
                    output_transform,
                    num_iter,
                    end_lr,
                    step_mode,
                    smooth_f,
                    diverge_th,
                )
            if not trainer.has_event_handler(self._warning):
                trainer.add_event_handler(Events.COMPLETED, self._warning)
            if not trainer.has_event_handler(self._reset):
                trainer.add_event_handler(Events.COMPLETED, self._reset)

            yield trainer
            self._detach(trainer)
            # restore to_save and reset trainer's state
            obj = torch.load(cache_filepath.as_posix())
            trainer.load_state_dict(obj["trainer"])
            for k, o in obj.items():
                if k in to_save:
                    to_save[k].load_state_dict(o)
Example 30
class IgniteJunction(Junction):
    """
    This abstracts the functionality of an Ignite Engine into Junction format. See the Ignite documentation (https://github.com/pytorch/ignite)
    for more details.
    """
    required_components = {'model': Model, 'dataset': object}

    # These dictionaries describe the allowed optimizers and learning rate schedulers, along with the keyword arguments that each can accept.
    # These are all the optimizers/schedulers that PyTorch includes.

    optimizers = {
        'Adadelta':  optim.Adadelta,
        'Adagrad': optim.Adagrad,
        'Adam': optim.Adam,
        'SparseAdam': optim.SparseAdam,
        'Adamax': optim.Adamax,
        'ASGD': optim.ASGD,
        'LBFGS': optim.LBFGS,
        'RMSprop': optim.RMSprop,
        'Rprop': optim.Rprop,
        'SGD': optim.SGD,
    }
    optimizer_kwargs = {
        'Adadelta':  ['lr', 'rho', 'eps', 'weight_decay'],
        'Adagrad': ['lr', 'lr_decay', 'weight_decay', 'initial_accumulator_value'],
        'Adam': ['lr', 'betas', 'eps', 'weight_decay', 'amsgrad'],
        'SparseAdam': ['lr', 'betas', 'eps'],
        'Adamax': ['lr', 'betas', 'eps', 'weight_decay'],
        'ASGD': ['lr', 'lambd', 'alpha', 't0', 'weight_decay'],
        'LBFGS': ['lr', 'max_iter', 'max_eval', 'tolerance_grad', 'tolerance_change', 'history_size', 'line_search_fn'],
        'RMSprop': ['lr', 'alpha', 'eps', 'weight_decay', 'momentum', 'centered'],
        'Rprop': ['lr', 'etas', 'step_sizes'],
        'SGD': ['lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'],
    }
    schedulers = {
        'LambdaLR': optim.lr_scheduler.LambdaLR,
        'StepLR': optim.lr_scheduler.StepLR,
        'MultiStepLR': optim.lr_scheduler.MultiStepLR,
        'ExponentialLR': optim.lr_scheduler.ExponentialLR,
        'CosineAnnealingLR': optim.lr_scheduler.CosineAnnealingLR,
        'ReduceLROnPlateau': optim.lr_scheduler.ReduceLROnPlateau,
    }
    scheduler_kwargs = {
        'LambdaLR': ['lr_lambda', 'last_epoch'],
        'StepLR': ['step_size', 'gamma', 'last_epoch'],
        'MultiStepLR': ['milestones', 'gamma', 'last_epoch'],
        'ExponentialLR': ['gamma', 'last_epoch'],
        'CosineAnnealingLR': ['T_max', 'eta_min', 'last_epoch'],
        'ReduceLROnPlateau': ['mode', 'factor', 'patience', 'verbose', 'threshold', 'threshold_mode', 'cooldown', 'min_lr', 'eps'],
    }

    def __init__(self, components, loss, optimizer, scheduler=None, update_function=default_training_closure, visdom=True, environment='default', description='', **kwargs):

        Junction.__init__(self, components=components)
        parameters = [x for x in self.model.all_parameters() if x.requires_grad]
        # Initialize engine
        self.optimizer = self.optimizers[optimizer](parameters, **subset_dict(kwargs, self.optimizer_kwargs[optimizer]))
        self.scheduler = None
        if scheduler is not None:
            self.scheduler = self.schedulers[scheduler](self.optimizer, **subset_dict(kwargs, self.scheduler_kwargs[scheduler]))
        self.loss = loss
        self.update_function = update_function(self.model, self.optimizer, self.loss)
        self.engine = Engine(self.update_function)

        # Configure metrics and events
        if not visdom:
            environment = None
        self.attach_events(environment=environment, description=description)

    def train(self, dataset=None, max_epochs=10):

        dataset = dataset or self.dataset
        self.engine.run(dataset, max_epochs=max_epochs)

    def run(self, *args, **kwargs):
        """ Alias for train """
        self.train(*args, **kwargs)

    def add_event_handler(self, *args, **kwargs):

        self.engine.add_event_handler(*args, **kwargs)

    def has_event_handler(self, *args, **kwargs):

        return self.engine.has_event_handler(*args, **kwargs)

    def attach_events(self, description, environment=None, save_file=None):

        tim = Timer()
        tim.attach(self.engine,
                   start=Events.STARTED,
                   step=Events.ITERATION_COMPLETED)

        log_interval = 100
        plot_interval = 10

        @self.engine.on(Events.ITERATION_COMPLETED)
        def print_training_loss(engine):
            iteration = engine.state.iteration - 1
            if iteration % log_interval == 0:
                print("Epoch[{}] Iteration: {} Time: {} Loss: {:.2f}".format(
                    engine.state.epoch, iteration,
                    str(datetime.timedelta(seconds=int(tim.value()))),
                    engine.state.output['loss']))
        if environment:

            vis = visdom.Visdom(env=environment)

            def create_plot_window(vis, xlabel, ylabel, title):
                return vis.line(X=np.array([1]), Y=np.array([np.nan]), opts=dict(xlabel=xlabel, ylabel=ylabel, title=title))

            train_loss_window = create_plot_window(vis, '#Iterations', 'Loss', 'Training Loss {0}'.format(description))

            @self.engine.on(Events.ITERATION_COMPLETED)
            def plot_training_loss(engine):
                iteration = engine.state.iteration - 1
                if iteration % plot_interval == 0:
                    vis.line(X=np.array([engine.state.iteration]),
                             Y=np.array([engine.state.output['loss']]),
                             update='append',
                             win=train_loss_window)
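A hedged construction sketch based on the signature above; the model and dataset components are placeholders, and keyword arguments are routed by name through optimizer_kwargs and scheduler_kwargs:

junction = IgniteJunction(
    components={'model': model, 'dataset': train_dataset},  # placeholders
    loss=nn.CrossEntropyLoss(),
    optimizer='Adam',
    scheduler='StepLR',
    visdom=False,   # skip Visdom plotting; console logging only
    lr=1e-3,        # picked up by optim.Adam via optimizer_kwargs
    step_size=10,   # picked up by StepLR via scheduler_kwargs
)
junction.train(max_epochs=10)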