def __init__(self, model: Module, optimizer: Optimizer, criterion,
             alpha: Union[float, Sequence[float]], temperature: float,
             train_mb_size: int = 1, train_epochs: int = 1,
             eval_mb_size: Optional[int] = None, device=None,
             plugins: Optional[List[StrategyPlugin]] = None,
             evaluator: EvaluationPlugin = default_logger, eval_every=-1):
    """ Learning without Forgetting strategy.

    See LwF plugin for details.
    This strategy does not use task identities.

    :param model: The model.
    :param optimizer: The optimizer to use.
    :param criterion: The loss criterion to use.
    :param alpha: distillation hyperparameter. It can be either a float
        number or a list containing alpha for each experience.
    :param temperature: softmax temperature for distillation
    :param train_mb_size: The train minibatch size. Defaults to 1.
    :param train_epochs: The number of training epochs. Defaults to 1.
    :param eval_mb_size: The eval minibatch size. Defaults to None, which
        is forwarded unchanged to the base-class constructor.
    :param device: The device to use. Defaults to None (cpu).
    :param plugins: Plugins to be added. Defaults to None. The given
        sequence is not modified: an ``LwFPlugin`` is appended to a copy.
    :param evaluator: (optional) instance of EvaluationPlugin for logging
        and metric computations.
    :param eval_every: the frequency of the calls to `eval` inside the
        training loop.
            if -1: no evaluation during training.
            if 0: calls `eval` after the final epoch of each training
                experience.
            if >0: calls `eval` every `eval_every` epochs and at the end
                of all the epochs for a single experience.
    """
    lwf = LwFPlugin(alpha, temperature)
    # Build a fresh list instead of appending to the caller's: mutating the
    # argument would leak this LwF plugin into any other strategy the
    # caller later constructs from the same list.
    if plugins is None:
        plugins = [lwf]
    else:
        plugins = list(plugins) + [lwf]
    super().__init__(
        model, optimizer, criterion,
        train_mb_size=train_mb_size, train_epochs=train_epochs,
        eval_mb_size=eval_mb_size, device=device, plugins=plugins,
        evaluator=evaluator, eval_every=eval_every)
def __init__(self, model: Module, optimizer: Optimizer, criterion,
             alpha: Union[float, Sequence[float]], temperature: float,
             train_mb_size: int = 1, train_epochs: int = 1,
             eval_mb_size: Optional[int] = None, device=None,
             plugins: Optional[List[SupervisedPlugin]] = None,
             evaluator: EvaluationPlugin = default_evaluator, eval_every=-1,
             **base_kwargs):
    """Init.

    :param model: The model.
    :param optimizer: The optimizer to use.
    :param criterion: The loss criterion to use.
    :param alpha: distillation hyperparameter. It can be either a float
        number or a list containing alpha for each experience.
    :param temperature: softmax temperature for distillation
    :param train_mb_size: The train minibatch size. Defaults to 1.
    :param train_epochs: The number of training epochs. Defaults to 1.
    :param eval_mb_size: The eval minibatch size. Defaults to None, which
        is forwarded unchanged to the base-class constructor.
    :param device: The device to use. Defaults to None (cpu).
    :param plugins: Plugins to be added. Defaults to None. The given
        sequence is not modified: an ``LwFPlugin`` is appended to a copy.
    :param evaluator: (optional) instance of EvaluationPlugin for logging
        and metric computations.
    :param eval_every: the frequency of the calls to `eval` inside the
        training loop. -1 disables the evaluation. 0 means `eval` is called
        only at the end of the learning experience. Values >0 mean that
        `eval` is called every `eval_every` epochs and at the end of the
        learning experience.
    :param **base_kwargs: any additional
        :class:`~avalanche.training.BaseTemplate` constructor arguments.
    """
    lwf = LwFPlugin(alpha, temperature)
    # Build a fresh list instead of appending to the caller's: mutating the
    # argument would leak this LwF plugin into any other strategy the
    # caller later constructs from the same list.
    if plugins is None:
        plugins = [lwf]
    else:
        plugins = list(plugins) + [lwf]
    super().__init__(
        model, optimizer, criterion,
        train_mb_size=train_mb_size, train_epochs=train_epochs,
        eval_mb_size=eval_mb_size, device=device, plugins=plugins,
        evaluator=evaluator, eval_every=eval_every, **base_kwargs)
def test_warning_slda_lwf(self):
    """Constructing StreamingLDA with LwF + Replay plugins emits a Warning.

    Only the presence of some ``Warning`` is asserted; its message is not
    inspected.
    """
    model, _, criterion, _ = self.init_sit()
    with self.assertWarns(Warning):
        # Plugins are created inside the context so that any warning raised
        # during their construction is also captured, as in the original.
        slda_plugins = [LwFPlugin(), ReplayPlugin()]
        StreamingLDA(
            model,
            criterion,
            input_size=10,
            output_layer_name="features",
            num_classes=10,
            plugins=slda_plugins,
        )
def test_warning_slda_lwf(self):
    """Constructing StreamingLDA with an LwFPlugin logs exactly one WARNING.

    The record is expected on the 'avalanche.training.strategies' logger
    and must mention that LwF's before_backward callback is disabled by
    StreamingLDA.
    """
    model, _, criterion, _ = self.init_sit()
    expected_msg = ("LwFPlugin seems to use the callback before_backward"
                    " which is disabled by StreamingLDA")
    with self.assertLogs('avalanche.training.strategies',
                         "WARNING") as log_ctx:
        StreamingLDA(model, criterion,
                     input_size=10,
                     output_layer_name='features',
                     num_classes=10,
                     plugins=[LwFPlugin(), ReplayPlugin()])
    emitted = log_ctx.output
    self.assertEqual(1, len(emitted))
    self.assertIn(expected_msg, emitted[0])