def attach(self, engine: Engine) -> None:
    # Inherit the engine's logger when no explicit handler name was configured.
    if self._name is None:
        self.logger = engine.logger
    # Register self as a per-iteration handler, avoiding double registration.
    if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)
    # NOTE(review): the guard checks for ``self.saver.finalize`` but a fresh
    # lambda is registered instead, so ``has_event_handler`` can never match
    # what was actually added — attaching twice would register a duplicate
    # COMPLETED handler. Confirm whether ``finalize`` can be registered
    # directly (i.e. whether it accepts the engine argument).
    if not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
        engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize())
def attach(self, engine: Engine, name: str) -> None:
    """Attach this metric to ``engine`` under the key ``name``.

    When the engine's run ends, ``engine.state.metrics[name]`` will hold
    the metric's computed value.

    Args:
        engine (Engine): the engine to which the metric must be attached
        name (str): the name of the metric to attach

    Example:

    .. code-block:: python

        metric = ...
        metric.attach(engine, "mymetric")

        assert "mymetric" in engine.run(data).metrics

        assert metric.is_attached(engine)
    """
    # The completed handler is added unconditionally: several names may map
    # to the same metric instance.
    engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
    # Accumulation handlers are shared — register each at most once.
    accumulation_handlers = (
        (Events.EPOCH_STARTED, self.started),
        (Events.ITERATION_COMPLETED, self.iteration_completed),
    )
    for event, handler in accumulation_handlers:
        if not engine.has_event_handler(handler, event):
            engine.add_event_handler(event, handler)
def test_remove_event_handler_on_callable_events():
    """Removing handlers must work for plain and filtered (callable) events."""
    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    assert not engine.has_event_handler(foo)

    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)
    engine.remove_event_handler(foo, Events.EPOCH_STARTED)
    assert not engine.has_event_handler(foo)

    def bar(e):
        pass

    # A handler added via a filtered event is removable via the base event.
    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED)
    # BUG FIX: the original asserted on ``foo`` here, which had already been
    # removed above, so the check passed vacuously and never verified that
    # ``bar`` was actually removed.
    assert not engine.has_event_handler(bar)

    # Removing by the filtered event itself is rejected.
    with pytest.raises(
            TypeError,
            match=r"Argument event_name should not be a filtered event"):
        engine.remove_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
def detach(self, engine: Engine) -> None:
    """Detach this metric from ``engine`` so no computation happens during the run.

    Together with :meth:`~ignite.metrics.Metric.attach`, this is useful when
    several metrics must be computed with different periods — e.g. one metric
    every training epoch and a more expensive one every n-th epoch.

    Args:
        engine (Engine): the engine from which the metric must be detached

    Example:

    .. code-block:: python

        metric = ...
        engine = ...
        metric.detach(engine)

        assert "mymetric" not in engine.run(data).metrics

        assert not metric.is_attached(engine)
    """
    registrations = (
        (self.completed, Events.EPOCH_COMPLETED),
        (self.started, Events.EPOCH_STARTED),
        (self.iteration_completed, Events.ITERATION_COMPLETED),
    )
    for handler, event in registrations:
        if engine.has_event_handler(handler, event):
            engine.remove_event_handler(handler, event)
def attach(self, engine: Engine):
    """Attach the lr_finder to ``engine``.

    Prefer ``with lr_finder.attach(engine)`` over explicitly calling
    ``lr_finder.detach()`` afterwards.

    Args:
        engine: lr_finder is attached to this engine

    Notes:
        lr_finder cannot be attached to more than one engine at a time

    Returns:
        self
    """
    # Refuse to attach twice — the finder tracks exactly one engine.
    if self._engine:
        raise AlreadyAttachedError(
            "This LRFinder is already attached. create a new one or use lr_finder.detach()"
        )
    # Weak reference: do not keep the engine alive through the finder.
    self._engine = weakref.ref(engine)

    for event, callback in (
        (Events.STARTED, self._run),
        (Events.COMPLETED, self._warning),
        (Events.COMPLETED, self._reset),
    ):
        if not engine.has_event_handler(callback):
            engine.add_event_handler(event, callback)
    return self
def attach(self, engine: Engine) -> None:
    """Register this handler's event callbacks on the given Ignite engine.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    if self.name is None:
        self.logger = engine.logger

    # Warn if neither the engine logger nor the root logger would emit INFO
    # records — otherwise the stats would silently never be logged.
    log_level_too_high = (
        self.logger.getEffectiveLevel() > logging.INFO
        or logging.root.getEffectiveLevel() > logging.INFO
    )
    if log_level_too_high:
        warnings.warn(
            "the effective log level of engine logger or RootLogger is higher than INFO, may not record log,"
            " please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it."
        )

    want_iteration = self.iteration_log and not engine.has_event_handler(
        self.iteration_completed, Events.ITERATION_COMPLETED)
    if want_iteration:
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)

    want_epoch = self.epoch_log and not engine.has_event_handler(
        self.epoch_completed, Events.EPOCH_COMPLETED)
    if want_epoch:
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)

    if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED):
        engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)
def attach(self, engine: Engine, name: str) -> None:
    """Attach the metric's handlers to ``engine``; results land in
    ``engine.state.metrics[name]`` at each epoch end."""
    # One completed handler per name; accumulation handlers at most once.
    engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
    if not engine.has_event_handler(self.started, Events.EPOCH_STARTED):
        engine.add_event_handler(Events.EPOCH_STARTED, self.started)
    already_accumulating = engine.has_event_handler(
        self.iteration_completed, Events.ITERATION_COMPLETED)
    if not already_accumulating:
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
def _run(
    self,
    trainer: Engine,
    optimizer: Optimizer,
    output_transform: Callable,
    num_iter: int,
    start_lr: float,
    end_lr: float,
    step_mode: str,
    smooth_f: float,
    diverge_th: float,
) -> None:
    # Reset recorded history and divergence state for a fresh LR search.
    self._history = {"lr": [], "loss": []}
    self._best_loss = None
    self._diverge_flag = False

    # attach LRScheduler to trainer.
    if num_iter is None:
        # Default: run over the entire configured training length.
        num_iter = trainer.state.epoch_length * trainer.state.max_epochs
    else:
        max_iter = trainer.state.epoch_length * trainer.state.max_epochs  # type: ignore[operator]
        if max_iter < num_iter:
            max_iter = num_iter
        # NOTE(review): ``max_iter`` is computed/raised above but never used
        # afterwards — looks like leftover from a removed warning; confirm.
        # Extend the run so the requested number of iterations is reachable.
        trainer.state.max_iters = num_iter
        trainer.state.max_epochs = ceil(
            num_iter / trainer.state.epoch_length)  # type: ignore[operator]

    # Stop the run once the requested iteration count is reached.
    if not trainer.has_event_handler(self._reached_num_iterations):
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  self._reached_num_iterations, num_iter)

    # attach loss and lr logging
    if not trainer.has_event_handler(self._log_lr_and_loss):
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  self._log_lr_and_loss, output_transform,
                                  smooth_f, diverge_th)

    self.logger.debug(f"Running LR finder for {num_iter} iterations")

    if start_lr is None:
        # Fall back to the optimizer's current learning rate.
        start_lr = optimizer.param_groups[0]["lr"]
    # Initialize the proper learning rate policy
    if step_mode.lower() == "exp":
        # Exponential ramp needs one start LR per parameter group.
        start_lr = [start_lr] * len(optimizer.param_groups)  # type: ignore
        self._lr_schedule = LRScheduler(
            _ExponentialLR(optimizer, start_lr, end_lr, num_iter))
    else:
        # Linear ramp from start_lr at iteration 0 to end_lr at num_iter.
        self._lr_schedule = PiecewiseLinear(optimizer,
                                            param_name="lr",
                                            milestones_values=[
                                                (0, start_lr),
                                                (num_iter, end_lr)
                                            ])
    if not trainer.has_event_handler(self._lr_schedule):
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  self._lr_schedule, num_iter)
def _internal_attach(self, engine: Engine) -> None:
    """Recursively attach component metrics' accumulation handlers to ``engine``.

    ``MetricsLambda`` children are attached recursively; plain ``Metric``
    children get their ``started``/``iteration_completed`` handlers
    registered (at most once each). No ``completed`` handler is added here.

    Args:
        engine (Engine): the engine the component metrics are attached to
    """
    # Fix: the original wrapped the chain in enumerate() but never used the
    # index — iterate the chain directly.
    for metric in itertools.chain(self.args, self.kwargs.values()):
        if isinstance(metric, MetricsLambda):
            metric._internal_attach(engine)
        elif isinstance(metric, Metric):
            if not engine.has_event_handler(metric.started, Events.EPOCH_STARTED):
                engine.add_event_handler(Events.EPOCH_STARTED, metric.started)
            if not engine.has_event_handler(metric.iteration_completed, Events.ITERATION_COMPLETED):
                engine.add_event_handler(Events.ITERATION_COMPLETED, metric.iteration_completed)
def _run(
    self,
    trainer: Engine,
    optimizer: Optimizer,
    output_transform: Callable,
    num_iter: int,
    end_lr: float,
    step_mode: str,
    smooth_f: float,
    diverge_th: float,
):
    # Reset recorded history and divergence state for a fresh LR search.
    self._history = {"lr": [], "loss": []}
    self._best_loss = None
    self._diverge_flag = False

    # attach LRScheduler to trainer.
    if num_iter is None:
        # Default: run over the entire configured training length.
        num_iter = trainer.state.epoch_length * trainer.state.max_epochs
    else:
        max_iter = trainer.state.epoch_length * trainer.state.max_epochs
        # Unlike the variant that extends max_epochs, this version only warns
        # when the requested iteration count cannot be reached.
        if num_iter > max_iter:
            warnings.warn(
                "Desired num_iter {} is unreachable with the current run setup of {} iteration "
                "({} epochs)".format(num_iter, max_iter,
                                     trainer.state.max_epochs),
                UserWarning,
            )

    # Stop the run once the requested iteration count is reached.
    if not trainer.has_event_handler(self._reached_num_iterations):
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  self._reached_num_iterations, num_iter)

    # attach loss and lr logging
    if not trainer.has_event_handler(self._log_lr_and_loss):
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  self._log_lr_and_loss, output_transform,
                                  smooth_f, diverge_th)

    self.logger.debug(
        "Running LR finder for {} iterations".format(num_iter))

    # Initialize the proper learning rate policy
    if step_mode.lower() == "exp":
        self._lr_schedule = LRScheduler(
            _ExponentialLR(optimizer, end_lr, num_iter))
    else:
        # Linear ramp from the optimizer's current LR to end_lr.
        start_lr = optimizer.param_groups[0]["lr"]
        self._lr_schedule = PiecewiseLinear(optimizer,
                                            param_name="lr",
                                            milestones_values=[
                                                (0, start_lr),
                                                (num_iter, end_lr)
                                            ])
    if not trainer.has_event_handler(self._lr_schedule):
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  self._lr_schedule, num_iter)
def attach(self, engine: Engine):
    """Register a set of Ignite Event-Handlers to a specified Ignite engine.

    Args:
        engine (ignite.engine): Ignite Engine, it can be a trainer, validator
            or evaluator.
    """
    handler_table = (
        (Events.ITERATION_COMPLETED, self.iteration_completed),
        (Events.EPOCH_COMPLETED, self.epoch_completed),
    )
    # Register each callback at most once.
    for event, callback in handler_table:
        if not engine.has_event_handler(callback, event):
            engine.add_event_handler(event, callback)
def attach(self, engine: Engine) -> None:
    """
    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    # Inherit the engine's logger when no explicit handler name was configured.
    if self._name is None:
        self.logger = engine.logger
    # Register self as a per-iteration handler, avoiding double registration.
    if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)
    # NOTE(review): the guard checks for ``self.saver.finalize`` but a fresh
    # lambda is registered, so ``has_event_handler`` can never match what was
    # added — a second attach would register a duplicate COMPLETED handler.
    # Confirm whether ``finalize`` can be registered directly.
    if not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
        engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize())
def _internal_attach(self, engine: Engine, usage: MetricUsage) -> None:
    """Recursively attach component metrics' accumulation handlers to ``engine``.

    ``MetricsLambda`` children are attached recursively; plain ``Metric``
    children get their ``started``/``iteration_completed`` handlers
    registered for the given ``usage`` (at most once each).

    Args:
        engine (Engine): the engine the component metrics are attached to
        usage (MetricUsage): defines which engine events trigger the handlers
    """
    self.engine = engine
    # Fix: the original wrapped the chain in enumerate() but never used the
    # index — iterate the chain directly.
    for metric in itertools.chain(self.args, self.kwargs.values()):
        if isinstance(metric, MetricsLambda):
            metric._internal_attach(engine, usage)
        elif isinstance(metric, Metric):
            # NB : metrics is attached partially
            # We must not use is_attached() but rather if these events exist
            if not engine.has_event_handler(metric.started, usage.STARTED):
                engine.add_event_handler(usage.STARTED, metric.started)
            if not engine.has_event_handler(metric.iteration_completed, usage.ITERATION_COMPLETED):
                engine.add_event_handler(usage.ITERATION_COMPLETED, metric.iteration_completed)
def _detach(self, trainer: Engine):
    """Detach lr_finder's handlers from the given trainer.

    Args:
        trainer: the trainer to detach form.
    """
    registered = (
        (self._run, Events.STARTED),
        (self._warning, Events.COMPLETED),
        (self._reset, Events.COMPLETED),
    )
    # Remove only the handlers that are actually registered.
    for callback, event in registered:
        if trainer.has_event_handler(callback, event):
            trainer.remove_event_handler(callback, event)
def attach(self, engine: Engine) -> None:
    """
    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    # Inherit the engine's logger when no explicit name was configured.
    if self._name is None:
        self.logger = engine.logger
    # Register the epoch lifecycle callbacks exactly once each.
    for event, callback in (
        (Events.EPOCH_STARTED, self._started),
        (Events.ITERATION_COMPLETED, self),
        (Events.EPOCH_COMPLETED, self._finalize),
    ):
        if not engine.has_event_handler(callback, event):
            engine.add_event_handler(event, callback)
def _internal_is_attached(self, engine: Engine, usage: MetricUsage) -> bool:
    """Return True only if every component metric is attached to ``engine``
    for the given ``usage``; False if ``engine`` is None or any component
    (checked recursively for nested ``MetricsLambda``) is missing a handler."""
    # if no engine, metrics is not attached
    if engine is None:
        return False
    # Short-circuit on the first component that is not fully attached.
    for metric in itertools.chain(self.args, self.kwargs.values()):
        if isinstance(metric, MetricsLambda):
            if not metric._internal_is_attached(engine, usage):
                return False
        elif isinstance(metric, Metric):
            if not engine.has_event_handler(metric.started, usage.STARTED):
                return False
            if not engine.has_event_handler(metric.iteration_completed, usage.ITERATION_COMPLETED):
                return False
    return True
def attach(self, engine: Engine) -> None:
    """Register a set of Ignite Event-Handlers to a specified Ignite engine.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    if self._name is None:
        self.logger = engine.logger
    # Each callback is registered at most once per event.
    for event, callback in (
        (Events.ITERATION_COMPLETED, self.iteration_completed),
        (Events.EPOCH_COMPLETED, self.epoch_completed),
        (Events.EXCEPTION_RAISED, self.exception_raised),
    ):
        if not engine.has_event_handler(callback, event):
            engine.add_event_handler(event, callback)
def detach(
        self,
        engine: Engine,
        usage: Union[str, MetricUsage] = EpochWise()) -> None:
    """Detach this metric from ``engine``; no computation happens during the run.

    Together with :meth:`~ignite.metrics.Metric.attach`, this is useful when
    several metrics must be computed with different periods — e.g. one metric
    every training epoch and a more expensive one every n-th epoch.

    Args:
        engine (Engine): the engine from which the metric must be detached
        usage (str or MetricUsage, optional): the usage of the metric. Valid
            string values should be 'epoch_wise' (default) or 'batch_wise'.

    Example:

    .. code-block:: python

        metric = ...
        engine = ...
        metric.detach(engine)

        assert "mymetric" not in engine.run(data).metrics

        assert not metric.is_attached(engine)

    Example with usage:

    .. code-block:: python

        metric = ...
        engine = ...
        metric.detach(engine, usage="batch_wise")

        assert "mymetric" not in engine.run(data).metrics

        assert not metric.is_attached(engine, usage="batch_wise")
    """
    usage = self._check_usage(usage)
    for callback, event in (
        (self.completed, usage.COMPLETED),
        (self.started, usage.STARTED),
        (self.iteration_completed, usage.ITERATION_COMPLETED),
    ):
        if engine.has_event_handler(callback, event):
            engine.remove_event_handler(callback, event)
def attach(
        self,
        engine: Engine,
        name: str,
        usage: Union[str, MetricUsage] = EpochWise()) -> None:
    """Attach this metric to ``engine`` under the key ``name``.

    When the engine's run ends, ``engine.state.metrics[name]`` will hold
    the metric's computed value.

    Args:
        engine (Engine): the engine to which the metric must be attached
        name (str): the name of the metric to attach
        usage (str or MetricUsage, optional): the usage of the metric. Valid
            string values should be :attr:`ignite.metrics.EpochWise.usage_name`
            (default) or :attr:`ignite.metrics.BatchWise.usage_name`.

    Example:

    .. code-block:: python

        metric = ...
        metric.attach(engine, "mymetric")

        assert "mymetric" in engine.run(data).metrics

        assert metric.is_attached(engine)

    Example with usage:

    .. code-block:: python

        metric = ...
        metric.attach(engine, "mymetric", usage=BatchWise.usage_name)

        assert "mymetric" in engine.run(data).metrics

        assert metric.is_attached(engine, usage=BatchWise.usage_name)
    """
    usage = self._check_usage(usage)
    # Accumulation handlers are shared and registered at most once each.
    for event, callback in (
        (usage.STARTED, self.started),
        (usage.ITERATION_COMPLETED, self.iteration_completed),
    ):
        if not engine.has_event_handler(callback, event):
            engine.add_event_handler(event, callback)
    # The completed handler is added per name.
    engine.add_event_handler(usage.COMPLETED, self.completed, name)
def attach(self, engine: Engine):
    # Validate the argument type early with a clear error message.
    if not isinstance(engine, Engine):
        raise TypeError("Argument engine should be ignite.engine.Engine, "
                        "but given {}".format(type(engine)))
    # Insert the bootstrap handler at position 0 of the STARTED handler list
    # so it runs before every other STARTED handler. This deliberately pokes
    # the engine's private ``_event_handlers`` rather than using
    # ``add_event_handler`` (which appends).
    if not engine.has_event_handler(self._as_first_started):
        engine._event_handlers[Events.STARTED].insert(
            0, (self._as_first_started, (engine, ), {}))
def test_remove_event_handler_on_callable_events():
    """Handlers added through plain or filtered events are removable both
    via the base event and via an equivalent filtered event."""
    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    # Not registered yet.
    assert not engine.has_event_handler(foo)

    # Plain event: add then remove.
    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)
    engine.remove_event_handler(foo, Events.EPOCH_STARTED)
    assert not engine.has_event_handler(foo)

    def bar(e):
        pass

    # Filtered event: removable via the base event...
    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED)
    assert not engine.has_event_handler(bar)

    # ...and via the same filtered event.
    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
    assert not engine.has_event_handler(bar)
def is_attached(self, engine: Engine) -> bool:
    """Check whether this metric is attached to ``engine``.

    If attached, the metric's computed value is written to the
    ``engine.state.metrics`` dictionary.

    Args:
        engine (Engine): the engine checked from which the metric should be
            attached
    """
    # The completed handler is the one registration that defines attachment.
    attached = engine.has_event_handler(self.completed, Events.EPOCH_COMPLETED)
    return attached
def attach(self, engine: Engine) -> None:
    """
    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    # Pre-allocate a probability map per case from the dataset's info table.
    dataset_info = engine.data_loader.dataset.info
    for name, info in dataset_info.items():
        self.case_names.append(name)
        self.probs_maps[name] = np.zeros(info['mask_dims'])
        self.levels[name] = info['level']
    if self._name is None:
        self.logger = engine.logger
    if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)
    # NOTE(review): the guard checks ``self.finalize`` but a lambda wrapper is
    # registered, so the check never matches the registered handler — confirm.
    if not engine.has_event_handler(self.finalize, Events.COMPLETED):
        engine.add_event_handler(Events.COMPLETED, lambda engine: self.finalize())
def attach(self, engine: Engine) -> None:
    """
    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    samples = engine.data_loader.dataset.data
    self.num_images = len(samples)
    # Pre-allocate one probability map per image and record its metadata.
    for sample in samples:
        key = sample["name"]
        self.prob_map[key] = np.zeros(sample["mask_shape"], dtype=self.dtype)
        self.counter[key] = len(sample["mask_locations"])
        self.level[key] = sample["level"]
    if self._name is None:
        self.logger = engine.logger
    if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)
    if not engine.has_event_handler(self.finalize, Events.COMPLETED):
        engine.add_event_handler(Events.COMPLETED, self.finalize)
def is_attached(self,
                engine: Engine,
                usage: Union[str, MetricUsage] = EpochWise()) -> bool:
    """Check whether this metric is attached to ``engine`` for ``usage``.

    If attached, the metric's computed value is written to the
    ``engine.state.metrics`` dictionary.

    Args:
        engine: the engine checked from which the metric should be attached
        usage: the usage of the metric. Valid string values should be
            'epoch_wise' (default) or 'batch_wise'.
    """
    resolved_usage = self._check_usage(usage)
    # Attachment is defined by the presence of the completed handler.
    return engine.has_event_handler(self.completed, resolved_usage.COMPLETED)
def attach(self, engine: Engine) -> None:
    """
    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    image_data = engine.data_loader.dataset.image_data  # type: ignore
    self.num_images = len(image_data)

    # Initialized probability maps for all the images
    for sample in image_data:
        image_name = sample[ProbMapKeys.NAME]
        self.counter[image_name] = sample[ProbMapKeys.COUNT]
        self.prob_map[image_name] = np.zeros(sample[ProbMapKeys.SIZE],
                                             dtype=self.dtype)

    if self._name is None:
        self.logger = engine.logger
    if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)
    if not engine.has_event_handler(self.finalize, Events.COMPLETED):
        engine.add_event_handler(Events.COMPLETED, self.finalize)
def attach(self, engine: Engine) -> None:
    """Attach HandlersTimeProfiler to the given engine.

    Args:
        engine: the instance of Engine to attach
    """
    if not isinstance(engine, Engine):
        raise TypeError(
            f"Argument engine should be ignite.engine.Engine, but given {type(engine)}"
        )
    # Insert the bootstrap handler at position 0 of the STARTED handler list
    # so it runs before every other STARTED handler. This deliberately pokes
    # the engine's private ``_event_handlers`` rather than using
    # ``add_event_handler`` (which appends).
    if not engine.has_event_handler(self._as_first_started):
        engine._event_handlers[Events.STARTED].insert(
            0, (self._as_first_started, (engine, ), {}))
class Trainer(object):
    """Wraps an ignite training loop: builds trainer/evaluator engines,
    attaches metrics, checkpointing, early stopping, progress bars and
    graceful CTRL-C handling around a model + optimizer pair."""

    def __init__(self: TrainerType,
                 model: nn.Module,
                 optimizer: Optimizer,
                 checkpoint_dir: str = '../../checkpoints',
                 experiment_name: str = 'experiment',
                 model_checkpoint: Optional[str] = None,
                 optimizer_checkpoint: Optional[str] = None,
                 metrics: types.GenericDict = None,
                 patience: int = 10,
                 validate_every: int = 1,
                 accumulation_steps: int = 1,
                 loss_fn: Union[_Loss, DataParallelCriterion] = None,
                 non_blocking: bool = True,
                 retain_graph: bool = False,
                 dtype: torch.dtype = torch.float,
                 device: str = 'cpu',
                 parallel: bool = False) -> None:
        self.dtype = dtype
        self.retain_graph = retain_graph
        self.non_blocking = non_blocking
        self.device = device
        self.loss_fn = loss_fn
        self.validate_every = validate_every
        self.patience = patience
        self.accumulation_steps = accumulation_steps
        self.checkpoint_dir = checkpoint_dir

        # Resolve checkpoint locations (may download when given a URL).
        model_checkpoint = self._check_checkpoint(model_checkpoint)
        optimizer_checkpoint = self._check_checkpoint(optimizer_checkpoint)

        # Restore weights on CPU first, then cast/move to the target device.
        self.model = cast(
            nn.Module,
            from_checkpoint(model_checkpoint,
                            model,
                            map_location=torch.device('cpu')))
        self.model = self.model.type(dtype).to(device)
        self.optimizer = from_checkpoint(optimizer_checkpoint, optimizer)
        self.parallel = parallel
        if parallel:
            if device == 'cpu':
                raise ValueError("parallel can be used only with cuda device")
            # Wrap model and criterion for multi-GPU execution.
            self.model = DataParallelModel(self.model).to(device)
            self.loss_fn = DataParallelCriterion(self.loss_fn)  # type: ignore
        if metrics is None:
            metrics = {}
        if 'loss' not in metrics:
            # Always track a 'loss' metric; the parallel criterion returns
            # per-device losses that must be averaged.
            if self.parallel:
                metrics['loss'] = Loss(
                    lambda x, y: self.loss_fn(x, y).mean())  # type: ignore
            else:
                metrics['loss'] = Loss(self.loss_fn)
        self.trainer = Engine(self.train_step)
        self.train_evaluator = Engine(self.eval_step)
        self.valid_evaluator = Engine(self.eval_step)
        # Metrics are attached to the evaluators only, not the trainer.
        for name, metric in metrics.items():
            metric.attach(self.train_evaluator, name)
            metric.attach(self.valid_evaluator, name)

        self.pbar = ProgressBar()
        self.val_pbar = ProgressBar(desc='Validation')

        if checkpoint_dir is not None:
            # Keep the 2 best checkpoints by (negated) validation loss.
            self.checkpoint = CheckpointHandler(
                checkpoint_dir,
                experiment_name,
                score_name='validation_loss',
                score_function=self._score_fn,
                n_saved=2,
                require_empty=False,
                save_as_state_dict=True)

        self.early_stop = EarlyStopping(patience, self._score_fn,
                                        self.trainer)

        self.val_handler = EvaluationHandler(pbar=self.pbar,
                                             validate_every=1,
                                             early_stopping=self.early_stop)
        self.attach()
        log.info(
            f'Trainer configured to run {experiment_name}\n'
            f'\tpretrained model: {model_checkpoint} {optimizer_checkpoint}\n'
            f'\tcheckpoint directory: {checkpoint_dir}\n'
            f'\tpatience: {patience}\n'
            f'\taccumulation steps: {accumulation_steps}\n'
            f'\tnon blocking: {non_blocking}\n'
            f'\tretain graph: {retain_graph}\n'
            f'\tdevice: {device}\n'
            f'\tmodel dtype: {dtype}\n'
            f'\tparallel: {parallel}')

    def _check_checkpoint(self: TrainerType,
                          ckpt: Optional[str]) -> Optional[str]:
        # Resolve a checkpoint reference to a local path. URLs are downloaded
        # into the checkpoint directory first.
        if ckpt is None:
            return ckpt
        if system.is_url(ckpt):
            ckpt = system.download_url(cast(str, ckpt), self.checkpoint_dir)
        # NOTE(review): the path is joined with checkpoint_dir even after a
        # download — confirm download_url returns a name relative to
        # checkpoint_dir rather than an absolute path.
        ckpt = os.path.join(self.checkpoint_dir, ckpt)
        return ckpt

    @staticmethod
    def _score_fn(engine: Engine) -> float:
        """Returns the scoring metric for checkpointing and early stopping

        Args:
            engine (ignite.engine.Engine): The engine that calculates
            the val loss

        Returns:
            (float): The validation loss
        """
        # Negated so that "higher is better" for checkpoint/early-stop logic.
        negloss: float = -engine.state.metrics['loss']
        return negloss

    def parse_batch(self: TrainerType,
                    batch: List[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
        # Move (inputs, targets) to the configured device.
        inputs = to_device(batch[0],
                           device=self.device,
                           non_blocking=self.non_blocking)
        targets = to_device(batch[1],
                            device=self.device,
                            non_blocking=self.non_blocking)
        return inputs, targets

    def get_predictions_and_targets(
            self: TrainerType,
            batch: List[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
        # Forward pass on a parsed batch.
        inputs, targets = self.parse_batch(batch)
        y_pred = self.model(inputs)
        return y_pred, targets

    def train_step(self: TrainerType, engine: Engine,
                   batch: List[torch.Tensor]) -> float:
        # Single optimization step with gradient accumulation support.
        self.model.train()
        y_pred, targets = self.get_predictions_and_targets(batch)
        loss = self.loss_fn(y_pred, targets.long())  # type: ignore
        if self.parallel:
            # Parallel criterion returns per-device losses.
            loss = loss.mean()
        # Scale so accumulated gradients match a full-batch step.
        loss = loss / self.accumulation_steps
        loss.backward(retain_graph=self.retain_graph)
        if (self.trainer.state.iteration + 1) % self.accumulation_steps == 0:
            self.optimizer.step()  # type: ignore
            self.optimizer.zero_grad()
        loss_value: float = loss.item()
        return loss_value

    def eval_step(
            self: TrainerType, engine: Engine,
            batch: List[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
        # Evaluation forward pass without gradient tracking; the attached
        # metrics consume the (y_pred, targets) output.
        self.model.eval()
        with torch.no_grad():
            y_pred, targets = self.get_predictions_and_targets(batch)
            return y_pred, targets

    def predict(self: TrainerType, dataloader: DataLoader) -> State:
        # Run the validation evaluator over an arbitrary dataloader.
        return self.valid_evaluator.run(dataloader)

    def fit(self: TrainerType,
            train_loader: DataLoader,
            val_loader: DataLoader,
            epochs: int = 50) -> State:
        # Full training run with periodic evaluation on both loaders.
        log.info('Trainer will run for\n'
                 f'model: {self.model}\n'
                 f'optimizer: {self.optimizer}\n'
                 f'loss: {self.loss_fn}')
        self.val_handler.attach(self.trainer,
                                self.train_evaluator,
                                train_loader,
                                validation=False)
        self.val_handler.attach(self.trainer,
                                self.valid_evaluator,
                                val_loader,
                                validation=True)
        self.model.zero_grad()
        self.trainer.run(train_loader, max_epochs=epochs)
        # Best score is the un-negated early-stopping score when available.
        best_score = (-self.early_stop.best_score if self.early_stop else
                      self.valid_evaluator.state.metrics['loss'])
        return best_score

    def overfit_single_batch(self: TrainerType,
                             train_loader: DataLoader) -> State:
        # Sanity check: train on one batch only; validation is disabled by
        # removing the evaluation handler and evaluating on the same batch.
        single_batch = [next(iter(train_loader))]

        if self.trainer.has_event_handler(self.val_handler,
                                          Events.EPOCH_COMPLETED):
            self.trainer.remove_event_handler(self.val_handler,
                                              Events.EPOCH_COMPLETED)

        self.val_handler.attach(
            self.trainer,
            self.train_evaluator,
            single_batch,  # type: ignore
            validation=False)
        out = self.trainer.run(single_batch, max_epochs=100)
        return out

    def fit_debug(self: TrainerType, train_loader: DataLoader,
                  val_loader: DataLoader) -> State:
        # Quick smoke test: two batches from each loader, six epochs.
        train_loader = iter(train_loader)
        train_subset = [next(train_loader), next(train_loader)]
        val_loader = iter(val_loader)  # type: ignore
        val_subset = [next(val_loader), next(val_loader)]  # type ignore
        out = self.fit(train_subset, val_subset, epochs=6)  # type: ignore
        return out

    def _attach_checkpoint(self: TrainerType) -> TrainerType:
        # Save model+optimizer after each validation pass (if enabled).
        ckpt = {'model': self.model, 'optimizer': self.optimizer}
        if self.checkpoint_dir is not None:
            self.valid_evaluator.add_event_handler(Events.COMPLETED,
                                                   self.checkpoint, ckpt)
        return self

    def attach(self: TrainerType) -> TrainerType:
        # Wire up progress bars, running-average loss display, early
        # stopping, checkpointing and graceful exception handling.
        ra = RunningAverage(output_transform=lambda x: x)
        ra.attach(self.trainer, "Train Loss")
        self.pbar.attach(self.trainer, ['Train Loss'])
        self.val_pbar.attach(self.train_evaluator)
        self.val_pbar.attach(self.valid_evaluator)
        self.valid_evaluator.add_event_handler(Events.COMPLETED,
                                               self.early_stop)
        self = self._attach_checkpoint()

        def graceful_exit(engine, e):
            # Terminate cleanly on CTRL-C; re-raise everything else.
            if isinstance(e, KeyboardInterrupt):
                engine.terminate()
                log.warn("CTRL-C caught. Exiting gracefully...")
            else:
                raise (e)

        self.trainer.add_event_handler(Events.EXCEPTION_RAISED, graceful_exit)
        self.train_evaluator.add_event_handler(Events.EXCEPTION_RAISED,
                                               graceful_exit)
        self.valid_evaluator.add_event_handler(Events.EXCEPTION_RAISED,
                                               graceful_exit)
        return self
def attach(
    self,
    trainer: Engine,
    to_save: Mapping,
    output_transform: Callable = lambda output: output,
    num_iter: Optional[int] = None,
    end_lr: float = 10.0,
    step_mode: str = "exp",
    smooth_f: float = 0.05,
    diverge_th: float = 5.0,
):
    """Attaches lr_finder to a given trainer. It also resets model and
    optimizer at the end of the run.

    Usage:

    .. code-block:: python

        to_save = {"model": model, "optimizer": optimizer}
        with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
            trainer_with_lr_finder.run(dataloader)`

    Args:
        trainer (Engine): lr_finder is attached to this trainer. Please,
            keep in mind that all attached handlers will be executed.
        to_save (Mapping): dictionary with optimizer and other objects that
            needs to be restored after running the LR finder. For example,
            `to_save={'optimizer': optimizer, 'model': model}`. All objects
            should implement `state_dict` and `load_state_dict` methods.
        output_transform (callable, optional): function that transforms the
            trainer's `state.output` after each iteration. It must return the
            loss of that iteration.
        num_iter (int, optional): number of iterations for lr schedule between
            base lr and end_lr. Default, it will run for
            `trainer.state.epoch_length * trainer.state.max_epochs`.
        end_lr (float, optional): upper bound for lr search. Default, 10.0.
        step_mode (str, optional): "exp" or "linear", which way should the lr
            be increased from optimizer's initial lr to `end_lr`. Default,
            "exp".
        smooth_f (float, optional): loss smoothing factor in range `[0, 1)`.
            Default, 0.05
        diverge_th (float, optional): Used for stopping the search when
            `current loss > diverge_th * best_loss`. Default, 5.0.

    Note:
        lr_finder cannot be attached to more than one trainer at a time.

    Returns:
        trainer_with_lr_finder: trainer used for finding the lr
    """
    # Validate arguments up front, before touching the trainer.
    if not isinstance(to_save, Mapping):
        raise TypeError("Argument to_save should be a mapping, but given {}".format(type(to_save)))

    Checkpoint._check_objects(to_save, "state_dict")
    Checkpoint._check_objects(to_save, "load_state_dict")

    if "optimizer" not in to_save:
        raise ValueError("Mapping to_save should contain 'optimizer' key")

    if not isinstance(to_save["optimizer"], torch.optim.Optimizer):
        raise TypeError(
            "Object to_save['optimizer'] should be torch optimizer, but given {}".format(type(to_save["optimizer"]))
        )

    if smooth_f < 0 or smooth_f >= 1:
        raise ValueError("smooth_f is outside the range [0, 1]")
    if diverge_th < 1:
        raise ValueError("diverge_th should be larger than 1")
    if step_mode not in ["exp", "linear"]:
        raise ValueError("step_mode should be 'exp' or 'linear', but given {}".format(step_mode))

    if num_iter is not None:
        if not isinstance(num_iter, int):
            raise TypeError("if provided, num_iter should be an integer, but give {}".format(num_iter))
        if num_iter <= 0:
            raise ValueError("if provided, num_iter should be positive, but give {}".format(num_iter))

    # store to_save
    # Snapshot all state_dicts (plus the trainer's state) to a temp file so
    # everything can be restored after the LR sweep.
    with tempfile.TemporaryDirectory() as tmpdirname:
        obj = {k: o.state_dict() for k, o in to_save.items()}
        # add trainer
        obj["trainer"] = trainer.state_dict()
        cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
        torch.save(obj, cache_filepath.as_posix())

        optimizer = to_save["optimizer"]
        # Attach handlers
        if not trainer.has_event_handler(self._run):
            trainer.add_event_handler(
                Events.STARTED,
                self._run,
                optimizer,
                output_transform,
                num_iter,
                end_lr,
                step_mode,
                smooth_f,
                diverge_th,
            )
        if not trainer.has_event_handler(self._warning):
            trainer.add_event_handler(Events.COMPLETED, self._warning)
        if not trainer.has_event_handler(self._reset):
            trainer.add_event_handler(Events.COMPLETED, self._reset)

        # Generator-based context manager: the caller runs the trainer here.
        yield trainer

        self._detach(trainer)
        # restore to_save and reset trainer's state
        obj = torch.load(cache_filepath.as_posix())
        trainer.load_state_dict(obj["trainer"])
        for k, o in obj.items():
            if k in to_save:
                to_save[k].load_state_dict(o)
class IgniteJunction(Junction):
    """
    This abstracts the functionality of an Ignite Engine into Junction format.
    See the Ignite documentation (https://github.com/pytorch/ignite) for more
    details.
    """
    required_components = {'model': Model, 'dataset': object}

    # These dictionaries describe the allowed optimizers and learning rate schedulers, along with the keyword arguments that each can accept.
    # These are all the optimizers/schedulers that PyTorch includes.
    optimizers = {
        'Adadelta': optim.Adadelta,
        'Adagrad': optim.Adagrad,
        'Adam': optim.Adam,
        'SparseAdam': optim.SparseAdam,
        'Adamax': optim.Adamax,
        'ASGD': optim.ASGD,
        'LBFGS': optim.LBFGS,
        'RMSprop': optim.RMSprop,
        'Rprop': optim.Rprop,
        'SGD': optim.SGD,
    }
    optimizer_kwargs = {
        'Adadelta': ['lr', 'rho', 'eps', 'weight_decay'],
        'Adagrad': ['lr', 'lr_decay', 'weight_decay',
                    'initial_accumulator_value'],
        'Adam': ['lr', 'betas', 'eps', 'weight_decay', 'amsgrad'],
        'SparseAdam': ['lr', 'betas', 'eps'],
        'Adamax': ['lr', 'betas', 'eps', 'weight_decay'],
        'ASGD': ['lr', 'lambd', 'alpha', 't0', 'weight_decay'],
        'LBFGS': ['lr', 'max_iter', 'max_eval', 'tolerance_grad',
                  'tolerance_change', 'history_size', 'line_search_fn'],
        'RMSprop': ['lr', 'alpha', 'eps', 'weight_decay', 'momentum',
                    'centered'],
        'Rprop': ['lr', 'etas', 'step_sizes'],
        'SGD': ['lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'],
    }
    schedulers = {
        'LambdaLR': optim.lr_scheduler.LambdaLR,
        'StepLR': optim.lr_scheduler.StepLR,
        'MultiStepLR': optim.lr_scheduler.MultiStepLR,
        'ExponentialLR': optim.lr_scheduler.ExponentialLR,
        'CosineAnnealingLR': optim.lr_scheduler.CosineAnnealingLR,
        'ReduceLROnPlateau': optim.lr_scheduler.ReduceLROnPlateau,
    }
    scheduler_kwargs = {
        'LambdaLR': ['lr_lambda', 'last_epoch'],
        'StepLR': ['step_size', 'gamma', 'last_epoch'],
        'MultiStepLR': ['milestones', 'gamma', 'last_epoch'],
        'ExponentialLR': ['gamma', 'last_epoch'],
        'CosineAnnealingLR': ['T_max', 'eta_min', 'last_epoch'],
        'ReduceLROnPlateau': ['mode', 'factor', 'patience', 'verbose',
                              'threshold', 'threshold_mode', 'cooldown',
                              'min_lr', 'eps'],
    }

    def __init__(self, components, loss, optimizer, scheduler=None,
                 update_function=default_training_closure, visdom=True,
                 environment='default', description='', **kwargs):
        Junction.__init__(self, components=components)
        # Only trainable parameters are handed to the optimizer.
        parameters = [x for x in self.model.all_parameters()
                      if x.requires_grad]
        # Initialize engine
        # Optimizer/scheduler are looked up by name; kwargs are filtered to
        # the keywords each constructor accepts.
        self.optimizer = self.optimizers[optimizer](
            parameters, **subset_dict(kwargs, self.optimizer_kwargs[optimizer]))
        if scheduler is not None:
            # NOTE(review): the scheduler is assigned over self.optimizer —
            # after this, self.optimizer holds the scheduler object; confirm
            # downstream code expects that.
            self.optimizer = self.schedulers[scheduler](
                self.optimizer,
                **subset_dict(kwargs, self.scheduler_kwargs[scheduler]))
        self.loss = loss
        self.update_function = update_function(self.model, self.optimizer,
                                               self.loss)
        self.engine = Engine(self.update_function)
        # Configure metrics and events
        if not visdom:
            environment = None
        self.attach_events(environment=environment, description=description)

    def train(self, dataset=None, max_epochs=10):
        # Fall back to the junction's own dataset component.
        dataset = dataset or self.dataset
        self.engine.run(dataset, max_epochs=max_epochs)

    def run(self, *args, **kwargs):
        """ Alias for train """
        self.train(*args, **kwargs)

    def add_event_handler(self, *args, **kwargs):
        # Delegate directly to the wrapped ignite engine.
        self.engine.add_event_handler(*args, **kwargs)

    def has_event_handler(self, *args, **kwargs):
        # Delegate directly to the wrapped ignite engine.
        return self.engine.has_event_handler(*args, **kwargs)

    def attach_events(self, description, environment=None, save_file=None):
        # Wall-clock timer for console logging.
        tim = Timer()
        tim.attach(
            self.engine,
            start=Events.STARTED,
            step=Events.ITERATION_COMPLETED,
        )
        log_interval = 100
        plot_interval = 10

        @self.engine.on(Events.ITERATION_COMPLETED)
        def print_training_loss(engine):
            iter = (engine.state.iteration - 1)
            if iter % log_interval == 0:
                print("Epoch[{}] Iteration: {} Time: {} Loss: {:.2f}".format(
                    engine.state.epoch, iter,
                    str(datetime.timedelta(seconds=int(tim.value()))),
                    engine.state.output['loss']
                ))

        if environment:
            # Optional live loss plotting through a visdom server.
            vis = visdom.Visdom(env=environment)

            def create_plot_window(vis, xlabel, ylabel, title):
                return vis.line(X=np.array([1]), Y=np.array([np.nan]),
                                opts=dict(xlabel=xlabel, ylabel=ylabel,
                                          title=title))

            train_loss_window = create_plot_window(
                vis, '#Iterations', 'Loss',
                'Training Loss {0}'.format(description))

            @self.engine.on(Events.ITERATION_COMPLETED)
            def plot_training_loss(engine):
                iter = (engine.state.iteration - 1)
                if iter % plot_interval == 0:
                    vis.line(X=np.array([engine.state.iteration]),
                             Y=np.array([engine.state.output['loss']]),
                             update='append', win=train_loss_window)