Example #1
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Prepares the callbacks for selected stage.

        Args:
            stage: stage name

        Returns:
            dictionary with stage callbacks
        """
        callbacks = super().get_callbacks(stage=stage)
        is_callback_exists = lambda callback_fn: any(
            callback_isinstance(x, callback_fn) for x in callbacks.values()
        )
        if isinstance(self._criterion, Criterion) and not is_callback_exists(ICriterionCallback):
            callbacks["_criterion"] = CriterionCallback(
                input_key=self._output_key, target_key=self._target_key, metric_key=self._loss_key,
            )
        if isinstance(self._optimizer, Optimizer) and not is_callback_exists(IOptimizerCallback):
            callbacks["_optimizer"] = OptimizerCallback(metric_key=self._loss_key)
        if isinstance(self._scheduler, (Scheduler, ReduceLROnPlateau)) and not is_callback_exists(
            ISchedulerCallback
        ):
            callbacks["_scheduler"] = SchedulerCallback(
                loader_key=self._valid_loader, metric_key=self._valid_metric
            )
        return callbacks
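This hook only fills in defaults: CriterionCallback, OptimizerCallback and SchedulerCallback are registered under the reserved `_criterion` / `_optimizer` / `_scheduler` keys only when the corresponding component is configured and no user-supplied callback implementing that interface already exists. A minimal usage sketch, assuming Catalyst's SupervisedRunner; the model, data and hyperparameters are illustrative:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from catalyst.dl import SupervisedRunner

model = nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters())
loaders = {
    "train": DataLoader(
        TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,))),
        batch_size=8,
    ),
    "valid": DataLoader(
        TensorDataset(torch.randn(32, 16), torch.randint(0, 4, (32,))),
        batch_size=8,
    ),
}

runner = SupervisedRunner()
# No callbacks are passed explicitly, so get_callbacks() injects the default
# "_criterion" and "_optimizer" entries for the criterion/optimizer below.
runner.train(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
)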
Example #2
    def get_callbacks(self) -> "OrderedDict[str, Callback]":
        """Returns the callbacks for the experiment."""
        callbacks = sort_callbacks_by_order(super().get_callbacks())
        callback_exists = lambda callback_fn: any(
            callback_isinstance(x, callback_fn) for x in callbacks.values())
        if isinstance(
                self._criterion,
                TorchCriterion) and not callback_exists(ICriterionCallback):
            callbacks["_criterion"] = CriterionCallback(
                input_key=self._output_key,
                target_key=self._target_key,
                metric_key=self._loss_key,
            )
        if isinstance(
                self._optimizer,
                TorchOptimizer) and not callback_exists(IBackwardCallback):
            callbacks["_backward"] = BackwardCallback(
                metric_key=self._loss_key)
        if isinstance(
                self._optimizer,
                TorchOptimizer) and not callback_exists(IOptimizerCallback):
            callbacks["_optimizer"] = OptimizerCallback(
                metric_key=self._loss_key)
        if isinstance(
                self._scheduler,
                TorchScheduler) and not callback_exists(ISchedulerCallback):
            callbacks["_scheduler"] = SchedulerCallback(
                loader_key=self._valid_loader, metric_key=self._valid_metric)
        return callbacks
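Compared with Example #1, this variant first sorts the inherited callbacks and splits the backward pass (BackwardCallback, guarded by IBackwardCallback) from the optimizer step (OptimizerCallback, guarded by IOptimizerCallback). A user-supplied callback of the same interface suppresses the matching default, as in the following sketch (assuming Catalyst's dl API; model and loaders are as in the sketch after Example #1):

import torch
from torch import nn

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters()),
    loaders=loaders,
    num_epochs=1,
    callbacks={
        # An explicit OptimizerCallback is an IOptimizerCallback instance, so
        # callback_exists(IOptimizerCallback) is True and the "_optimizer"
        # default is skipped; "_criterion" and "_backward" are still injected.
        "my_optimizer": dl.OptimizerCallback(metric_key="loss"),
    },
)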
Example #3
import collections
import shutil

from torch import nn
from torch.optim import Adam

# NOTE: the Catalyst runner and callbacks, as well as the `_`-prefixed helpers
# (_get_loaders, _SimpleNet, _OnBatchEndCheckGradsCallback), are imported or
# defined elsewhere in the original test module.


def test_save_model_grads():
    """
    Tests a feature of `OptimizerCallback` for saving model gradients
    """
    logdir = "./logs"
    dataset_root = "./data"
    loaders = _get_loaders(root=dataset_root, batch_size=4, num_workers=1)
    images, _ = next(iter(loaders["train"]))
    _, c, h, w = images.shape
    input_shape = (c, h, w)

    model = _SimpleNet(input_shape)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters())

    criterion_callback = CriterionCallback()
    optimizer_callback = OptimizerCallback()
    save_model_grads_callback = GradNormLogger()
    prefix = save_model_grads_callback.grad_norm_prefix
    test_callback = _OnBatchEndCheckGradsCallback(prefix)

    callbacks = collections.OrderedDict(
        loss=criterion_callback,
        optimizer=optimizer_callback,
        grad_norm=save_model_grads_callback,
        test_callback=test_callback,
    )

    runner = SupervisedRunner()
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir=logdir,
        callbacks=callbacks,
        check=True,  # run Catalyst's shortened "check" pass instead of full training
        verbose=True,
    )

    shutil.rmtree(logdir)
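Here the default wiring is bypassed: an explicit OrderedDict of callbacks is passed to runner.train, with GradNormLogger expected to log gradient norms under its grad_norm_prefix and the test-local _OnBatchEndCheckGradsCallback asserting on those metrics at the end of each batch. check=True requests Catalyst's shortened sanity run rather than full training, and the log directory is removed once the run finishes.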
Example #4
    def get_callbacks(self) -> "OrderedDict[str, Callback]":
        """Prepares the callbacks for selected stage.

        Args:
            stage: stage name

        Returns:
            dictionary with stage callbacks
        """
        callbacks = super().get_callbacks()
        callback_exists = lambda callback_fn: any(
            callback_isinstance(x, callback_fn) for x in callbacks.values())
        if isinstance(
                self._criterion,
                TorchCriterion) and not callback_exists(ICriterionCallback):
            callbacks["_criterion"] = CriterionCallback(
                input_key=f"{self.loss_mode_prefix}_left",
                target_key=f"{self.loss_mode_prefix}_right",
                metric_key=self._loss_key,
            )
        if isinstance(
                self._optimizer,
                TorchOptimizer) and not callback_exists(IBackwardCallback):
            callbacks["_backward"] = BackwardCallback(
                metric_key=self._loss_key)
        if isinstance(
                self._optimizer,
                TorchOptimizer) and not callback_exists(IOptimizerCallback):
            callbacks["_optimizer"] = OptimizerCallback(
                metric_key=self._loss_key)
        if isinstance(
                self._scheduler,
                TorchScheduler) and not callback_exists(ISchedulerCallback):
            callbacks["_scheduler"] = SchedulerCallback(
                loader_key=self._valid_loader, metric_key=self._valid_metric)
        return callbacks
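This variant applies the same default-injection logic as Example #2, but points the criterion at paired batch keys: CriterionCallback reads its input and target from the f"{self.loss_mode_prefix}_left" and f"{self.loss_mode_prefix}_right" entries, which presumably come from a two-branch (e.g. Siamese or contrastive) runner; the backward, optimizer and scheduler defaults are unchanged.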