class ModelDictOutputTest(ModelFittingTestCase):
    """Tests for ``Model`` when the network returns a dict of outputs.

    ``batch_size``, ``epochs``, ``evaluate_dataset_len``, ``batch_metrics``,
    ``epoch_metrics`` and ``mock_callback`` are presumably provided by
    ``ModelFittingTestCase`` -- confirm against the base class.
    """

    def setUp(self):
        """Build a seeded dict-output network and a Model summing per-output losses."""
        super().setUp()
        torch.manual_seed(42)  # deterministic weight initialization
        self.pytorch_network = DictOutputModel()
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)

        # The loss maps the dict predictions ('out1', 'out2') onto the tuple of
        # targets (y_t[0], y_t[1]) and sums the two MSE terms.
        self.model = Model(
            self.pytorch_network,
            self.optimizer,
            lambda y_p, y_t: self.loss_function(y_p['out1'], y_t[0]) + self.loss_function(y_p['out2'], y_t[1]),
            batch_metrics=self.batch_metrics,
            epoch_metrics=self.epoch_metrics)

    def test_fitting_with_tensor_multi_output_dict(self):
        """Fit with tuple targets; last batch of each epoch is deliberately short."""
        train_real_steps_per_epoch = 30
        train_batch_size = ModelDictOutputTest.batch_size
        # Sizes are chosen so the dataset length is NOT a multiple of the batch
        # size, exercising the incomplete-final-batch path.
        train_final_batch_missing_samples = 7
        train_size = train_real_steps_per_epoch * train_batch_size - \
                     train_final_batch_missing_samples
        train_x = torch.rand(train_size, 1)
        train_y = (torch.rand(train_size, 1), torch.rand(train_size, 1))

        valid_real_steps_per_epoch = 10
        # valid_batch_size will be the same as train_batch_size in the fit method.
        valid_batch_size = train_batch_size
        valid_final_batch_missing_samples = 3
        valid_size = valid_real_steps_per_epoch * valid_batch_size - \
                     valid_final_batch_missing_samples
        valid_x = torch.rand(valid_size, 1)
        valid_y = (torch.rand(valid_size, 1), torch.rand(valid_size, 1))

        logs = self.model.fit(train_x,
                              train_y,
                              validation_data=(valid_x, valid_y),
                              epochs=ModelDictOutputTest.epochs,
                              batch_size=train_batch_size,
                              steps_per_epoch=None,
                              validation_steps=None,
                              callbacks=[self.mock_callback])
        params = {'epochs': ModelDictOutputTest.epochs, 'steps': train_real_steps_per_epoch}
        self._test_callbacks_train(params, logs)

    def test_ndarray_train_on_batch_dict_output(self):
        """train_on_batch accepts numpy inputs with tuple targets and returns a float loss."""
        x = np.random.rand(ModelDictOutputTest.batch_size, 1).astype(np.float32)
        y1 = np.random.rand(ModelDictOutputTest.batch_size, 1).astype(np.float32)
        y2 = np.random.rand(ModelDictOutputTest.batch_size, 1).astype(np.float32)
        loss = self.model.train_on_batch(x, (y1, y2))
        self.assertEqual(type(loss), float)

    def test_evaluate_with_pred_dict_output(self):
        """evaluate(return_pred=True) yields a dict of per-output prediction arrays."""
        y = (torch.rand(ModelDictOutputTest.evaluate_dataset_len,
                        1), torch.rand(ModelDictOutputTest.evaluate_dataset_len, 1))
        x = torch.rand(ModelDictOutputTest.evaluate_dataset_len, 1)
        # We also test the unpacking.
        _, pred_y = self.model.evaluate(x, y, batch_size=ModelDictOutputTest.batch_size, return_pred=True)
        for pred in pred_y.values():
            self.assertEqual(pred.shape, (ModelDictOutputTest.evaluate_dataset_len, 1))
# Example #2
    def setUp(self):
        """Create data generators, a seeded linear model and the metric fixtures.

        ``batch_size`` is presumably defined on ``ModelFittingTestCase`` --
        confirm against the base class.
        """
        super().setUp()
        self.train_generator = some_data_tensor_generator(ModelFittingTestCase.batch_size)
        self.valid_generator = some_data_tensor_generator(ModelFittingTestCase.batch_size)
        self.test_generator = some_data_tensor_generator(ModelFittingTestCase.batch_size)
        torch.manual_seed(42)  # deterministic weight initialization
        self.pytorch_network = nn.Linear(1, 1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)
        # A (name, metric) tuple overrides the default metric name; the repeated
        # metric appears twice so name disambiguation can be tested.
        self.batch_metrics = [
            some_batch_metric_1, ('custom_name', some_batch_metric_2), repeat_batch_metric, repeat_batch_metric
        ]
        # Expected names/values, in the same order as self.batch_metrics.
        self.batch_metrics_names = [
            'some_batch_metric_1', 'custom_name', 'repeat_batch_metric1', 'repeat_batch_metric2'
        ]
        self.batch_metrics_values = [
            some_metric_1_value, some_metric_2_value, repeat_batch_metric_value, repeat_batch_metric_value
        ]
        self.epoch_metrics = [SomeConstantEpochMetric()]
        self.epoch_metrics_names = ['some_constant_epoch_metric']
        self.epoch_metrics_values = [some_constant_epoch_metric_value]

        self.model = Model(self.pytorch_network,
                           self.optimizer,
                           self.loss_function,
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)

        self._capture_output()  # presumably captures stdout/stderr for assertions -- see base class
# Example #3
    def setUp(self):
        """Seeded linear-model fixture with metric lists, for multi-GPU tests.

        ``cuda_device`` is presumably defined on ``ModelTestMultiGPU`` --
        confirm against that class.
        """
        super().setUp()
        torch.manual_seed(42)  # deterministic weight initialization
        self.pytorch_network = nn.Linear(1, 1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(),
                                         lr=1e-3)
        # (name, metric) tuples override default names; the repeated metric
        # exercises duplicate-name handling.
        self.batch_metrics = [
            some_batch_metric_1, ('custom_name', some_batch_metric_2),
            repeat_batch_metric, repeat_batch_metric
        ]
        # Expected names/values, in the same order as self.batch_metrics.
        self.batch_metrics_names = [
            'some_batch_metric_1', 'custom_name', 'repeat_batch_metric1',
            'repeat_batch_metric2'
        ]
        self.batch_metrics_values = [
            some_metric_1_value, some_metric_2_value,
            repeat_batch_metric_value, repeat_batch_metric_value
        ]
        self.epoch_metrics = [SomeConstantEpochMetric()]
        self.epoch_metrics_names = ['some_constant_epoch_metric']
        self.epoch_metrics_values = [some_constant_epoch_metric_value]

        self.model = Model(self.pytorch_network,
                           self.optimizer,
                           self.loss_function,
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)

        self.default_main_device = ModelTestMultiGPU.cuda_device
# Example #4
 def setUp(self):
     """Assemble a minimal seeded 1->1 linear model fixture."""
     torch.manual_seed(42)  # fixed seed so weight init is reproducible
     network = nn.Linear(1, 1)
     self.pytorch_network = network
     self.loss_function = nn.MSELoss()
     self.optimizer = torch.optim.SGD(network.parameters(), lr=1e-3)
     self.model = Model(network, self.optimizer, self.loss_function)
# Example #5
 def setUp(self):
     """Seeded linear-model fixture with a temp dir for optimizer checkpoints."""
     torch.manual_seed(42)  # deterministic weight initialization
     self.pytorch_network = nn.Linear(1, 1)
     self.loss_function = nn.MSELoss()
     self.optimizer = torch.optim.Adam(self.pytorch_network.parameters(), lr=1e-3)
     self.model = Model(self.pytorch_network, self.optimizer, self.loss_function)
     self.temp_dir_obj = TemporaryDirectory()
     # '{epoch}' is presumably substituted by the checkpoint callback -- confirm.
     self.checkpoint_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint_{epoch}.optim')
# Example #6
 def setUp(self):
     """Two-layer sequential model fixture with a per-layer CSV filename template."""
     torch.manual_seed(42)  # deterministic weight initialization
     self.pytorch_network = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 1))
     self.loss_function = nn.MSELoss()
     self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)
     self.model = Model(self.pytorch_network, self.optimizer, self.loss_function)
     self.temp_dir_obj = TemporaryDirectory()
     # '{}' is presumably formatted with a layer index by the tests -- confirm.
     self.csv_filename = os.path.join(self.temp_dir_obj.name, 'layer_{}.csv')
# Example #7
 def setUp(self):
     """Seeded linear-model fixture logging to a CSV file in a temp dir."""
     torch.manual_seed(42)  # deterministic weight initialization
     self.pytorch_network = nn.Linear(1, 1)
     self.loss_function = nn.MSELoss()
     self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=BaseCSVLoggerTest.lr)
     self.model = Model(self.pytorch_network, self.optimizer, self.loss_function)
     self.temp_dir_obj = TemporaryDirectory()
     self.csv_filename = os.path.join(self.temp_dir_obj.name, 'my_log.csv')
def test_poutyne():
    """Smoke test: PlotLossesPoutyne runs as a Poutyne callback over two epochs."""
    plot_callback = PlotLossesPoutyne(outputs=(CheckOutput(), ))
    net = Network()
    opt = optim.Adam(params=net.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    loader = get_random_data()

    Model(net, opt, criterion).fit_generator(loader, epochs=2, callbacks=[plot_callback])
# Example #9
 def setUp(self):
     """Seeded linear-model fixture plus small train/valid data generators."""
     torch.manual_seed(42)  # deterministic weight initialization
     self.pytorch_network = nn.Linear(1, 1)
     self.loss_function = nn.MSELoss()
     self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(),
                                      lr=1e-3)
     self.model = Model(self.pytorch_network, self.optimizer,
                        self.loss_function)
     self.train_gen = some_data_generator(20)
     self.valid_gen = some_data_generator(20)
# Example #10
 def setUp(self):
     """Fixture wrapping a mocked Callback in a DelayCallback."""
     torch.manual_seed(42)  # deterministic weight initialization
     self.pytorch_network = nn.Linear(1, 1)
     self.loss_function = nn.MSELoss()
     self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)
     self.model = Model(self.pytorch_network, self.optimizer, self.loss_function)
     self.mock_callback = MagicMock(spec=Callback)
     self.delay_callback = DelayCallback(self.mock_callback)
     # ANY placeholders: assertions only check which keys are logged, not values.
     self.train_dict = {'loss': ANY, 'time': ANY}
     self.log_dict = {'loss': ANY, 'val_loss': ANY, 'time': ANY}
# Example #11
 def setUp(self):
     """Seeded linear model plus a SummaryWriter whose add_scalars is mocked."""
     torch.manual_seed(42)  # deterministic weight initialization
     self.pytorch_network = nn.Linear(1, 1)
     self.loss_function = nn.MSELoss()
     self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=BaseTensorBoardLoggerTest.lr)
     self.model = Model(self.pytorch_network, self.optimizer, self.loss_function)
     self.temp_dir_obj = TemporaryDirectory()
     # pylint: disable=not-callable
     self.writer = self.SummaryWriter(self.temp_dir_obj.name)
     # Mock add_scalars so tests can assert on logged scalars without real event files.
     self.writer.add_scalars = MagicMock()
 def test_epoch_metrics_with_multiple_names_returned_by_tensor_on_gpu(self):
     """An epoch metric returning a CUDA tensor is logged under multiple names.

     The (names, metric) pair presumably maps each tensor element to the
     corresponding name -- confirm against ``_test_history``.
     """
     with torch.cuda.device(MetricsModelIntegrationTest.cuda_device):
         epoch_metric = ConstEpochMetric(
             torch.tensor(self.metric_values).cuda())
         model = Model(self.pytorch_network,
                       self.optimizer,
                       self.loss_function,
                       epoch_metrics=[(self.metric_names, epoch_metric)])
         model.cuda()
         self._test_history(model, self.metric_names, self.metric_values)
# Example #13
    def setUp(self):
        """Fixture whose network consumes dicts keyed 'x1'/'x2' and emits 'y1'/'y2'."""
        super().setUp()
        torch.manual_seed(42)  # deterministic weight initialization
        self.pytorch_network = DictIOModel(['x1', 'x2'], ['y1', 'y2'])
        self.loss_function = dict_mse_loss  # presumably an MSE variant over dict outputs
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)

        self.model = Model(self.pytorch_network,
                           self.optimizer,
                           self.loss_function,
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)
# Example #14
    def test_tracking_N_layers_model_with_bias(self):
        """Gradient tracking over a 4-layer linear stack, biases included."""
        self.num_layer = 4
        self.pytorch_network = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=self.lr)
        self.model = Model(self.pytorch_network, self.optimizer, self.loss_function)

        keep_bias = True
        train_gen = some_data_generator(20)
        valid_gen = some_data_generator(20)
        tracker = TensorBoardGradientTracker(self.writer, keep_bias=keep_bias)
        self.model.fit_generator(train_gen, valid_gen, epochs=self.num_epochs, steps_per_epoch=5, callbacks=[tracker])
        self._test_tracking(keep_bias)
# Example #15
class LRSchedulersTest(TestCase):
    """Integration tests: each LR-scheduler callback is run through a short fit."""

    batch_size = 20
    epochs = 10
    steps_per_epoch = 5

    def setUp(self):
        """Create a seeded 1->1 linear model and train/valid data generators."""
        torch.manual_seed(42)  # fixed seed so weight init is reproducible
        network = nn.Linear(1, 1)
        self.pytorch_network = network
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(network.parameters(), lr=1e-3)
        self.model = Model(network, self.optimizer, self.loss_function)
        self.train_gen = some_data_generator(20)
        self.valid_gen = some_data_generator(20)

    def test_lambda_lr_integration(self):
        decay = lambda epoch: 0.95**epoch
        self._fit_with_callback_integration(LambdaLR(lr_lambda=[decay]))

    def test_step_lr_integration(self):
        self._fit_with_callback_integration(StepLR(step_size=3))

    def test_multistep_lr_integration(self):
        self._fit_with_callback_integration(MultiStepLR(milestones=[2, 5, 7]))

    def test_exponential_lr_integration(self):
        self._fit_with_callback_integration(ExponentialLR(gamma=0.01))

    def test_cosine_annealing_lr_integration(self):
        self._fit_with_callback_integration(CosineAnnealingLR(T_max=8))

    def test_reduce_lr_on_plateau_integration(self):
        self._fit_with_callback_integration(ReduceLROnPlateau(monitor='loss', patience=3))

    def _fit_with_callback_integration(self, callback):
        """Run the fixture model's fit_generator with *callback* attached."""
        self.model.fit_generator(self.train_gen,
                                 self.valid_gen,
                                 epochs=LRSchedulersTest.epochs,
                                 steps_per_epoch=LRSchedulersTest.steps_per_epoch,
                                 callbacks=[callback])

    def test_exception_is_thrown_on_optimizer_argument(self):
        """Scheduler callbacks must reject being given an optimizer directly."""
        with self.assertRaises(ValueError):
            StepLR(self.optimizer, step_size=3)
# Example #16
    def setUp(self):
        """Fixture with a single-input, single-output MultiIOModel."""
        super().setUp()
        torch.manual_seed(42)  # deterministic weight initialization
        self.pytorch_network = MultiIOModel(num_input=1, num_output=1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(),
                                         lr=1e-3)

        self.model = Model(self.pytorch_network,
                           self.optimizer,
                           self.loss_function,
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)
# Example #17
    def setUp(self):
        """Fixture with a two-input, two-output network and a summed MSE loss."""
        super().setUp()
        torch.manual_seed(42)  # deterministic weight initialization
        self.pytorch_network = MultiIOModel(num_input=2, num_output=2)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)

        # The loss pairs prediction i with target i and sums the two MSE terms.
        self.model = Model(
            self.pytorch_network,
            self.optimizer,
            lambda y_pred, y_true: self.loss_function(y_pred[0], y_true[0]) + self.loss_function(y_pred[1], y_true[1]),
            batch_metrics=self.batch_metrics,
            epoch_metrics=self.epoch_metrics)
# Example #18
    def setUp(self):
        """Fixture with a dict-output network and a loss summing per-output MSEs."""
        super().setUp()
        torch.manual_seed(42)  # deterministic weight initialization
        self.pytorch_network = DictOutputModel()
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)

        # Maps dict predictions ('out1', 'out2') onto tuple targets (y_t[0], y_t[1]).
        self.model = Model(
            self.pytorch_network,
            self.optimizer,
            lambda y_p, y_t: self.loss_function(y_p['out1'], y_t[0]) + self.loss_function(y_p['out2'], y_t[1]),
            batch_metrics=self.batch_metrics,
            epoch_metrics=self.epoch_metrics)
# Example #19
 def test_multiple_learning_rates(self):
     """CSV logging when the optimizer has two param groups with different LRs."""
     train_gen = some_data_generator(20)
     valid_gen = some_data_generator(20)
     logger = self.CSVLogger(self.csv_filename)
     lrs = [BaseCSVLoggerTest.lr, BaseCSVLoggerTest.lr / 2]
     # One param group per learning rate: weight at lrs[0], bias at lrs[1].
     optimizer = torch.optim.SGD(
         [dict(params=[self.pytorch_network.weight], lr=lrs[0]), dict(params=[self.pytorch_network.bias], lr=lrs[1])]
     )
     model = Model(self.pytorch_network, optimizer, self.loss_function)
     history = model.fit_generator(
         train_gen, valid_gen, epochs=self.num_epochs, steps_per_epoch=5, callbacks=[logger]
     )
     self._test_logging(history, lrs=lrs)
# Example #20
 def test_evaluate_with_only_one_metric(self):
     """With a single batch metric, evaluate() returns (loss, metric) as plain floats."""
     model = Model(self.pytorch_network,
                   self.optimizer,
                   self.loss_function,
                   batch_metrics=self.batch_metrics[:1])
     x = torch.rand(ModelTest.evaluate_dataset_len, 1)
     y = torch.rand(ModelTest.evaluate_dataset_len, 1)
     loss, first_metric = model.evaluate(x,
                                         y,
                                         batch_size=ModelTest.batch_size)
     self.assertEqual(type(loss), float)
     self.assertEqual(type(first_metric), float)
     self.assertEqual(first_metric, some_metric_1_value)
 def test_epoch_metrics_with_multiple_names_returned_by_tuple(self):
     """An epoch metric returning a tuple is logged under one name per element."""
     epoch_metric = ConstEpochMetric(tuple(self.metric_values))
     model = Model(self.pytorch_network,
                   self.optimizer,
                   self.loss_function,
                   epoch_metrics=[(self.metric_names, epoch_metric)])
     self._test_history(model, self.metric_names, self.metric_values)
 def test_batch_metrics_with_multiple_names_returned_by_list(self):
     """A batch metric returning a list is logged under one name per element."""
     batch_metric = get_const_batch_metric(list(self.metric_values))
     model = Model(self.pytorch_network,
                   self.optimizer,
                   self.loss_function,
                   batch_metrics=[(self.metric_names, batch_metric)])
     self._test_history(model, self.metric_names, self.metric_values)
# Example #23
 def setUp(self):
     """MNIST-sized flatten+linear fixture using string-named optimizer/loss/metrics.

     String values like 'sgd', 'cross_entropy', 'accuracy' and 'f1' are
     presumably resolved to concrete objects by Poutyne -- confirm.
     """
     super().setUp()
     torch.manual_seed(42)  # deterministic weight initialization
     self.pytorch_network = nn.Sequential(nn.Flatten(),
                                          nn.Linear(28 * 28, 10))
     self.batch_metrics = ['accuracy']
     self.batch_metrics_names = ['acc']
     self.batch_metrics_values = [ANY]  # exact value not asserted
     self.epoch_metrics = ['f1']
     self.epoch_metrics_names = ['fscore_micro']
     self.epoch_metrics_values = [ANY]  # exact value not asserted
     self.model = Model(self.pytorch_network,
                        'sgd',
                        'cross_entropy',
                        batch_metrics=self.batch_metrics,
                        epoch_metrics=self.epoch_metrics)
# Example #24
    def test_correct_optim_calls_1_batch_per_step(self):
        """With batches_per_step=1, the optimizer steps and zeroes grads once per batch."""
        train_generator = some_data_tensor_generator(ModelTest.batch_size)

        mocked_optimizer = some_mocked_optimizer()
        mocked_optim_model = Model(self.pytorch_network,
                                   mocked_optimizer,
                                   self.loss_function,
                                   batch_metrics=self.batch_metrics,
                                   epoch_metrics=self.epoch_metrics)
        # One epoch of one step => exactly one optimizer update is expected.
        mocked_optim_model.fit_generator(train_generator,
                                         None,
                                         epochs=1,
                                         steps_per_epoch=1,
                                         batches_per_step=1)

        self.assertEqual(1, mocked_optimizer.step.call_count)
        self.assertEqual(1, mocked_optimizer.zero_grad.call_count)
# Example #25
 def test_epoch_metrics_integration(self):
     """SomeEpochMetric's value shows up in both train and validation logs."""
     model = Model(self.pytorch_network,
                   self.optimizer,
                   self.loss_function,
                   epoch_metrics=[SomeEpochMetric()])
     train_generator = some_data_tensor_generator(ModelTest.batch_size)
     valid_generator = some_data_tensor_generator(ModelTest.batch_size)
     logs = model.fit_generator(train_generator,
                                valid_generator,
                                epochs=1,
                                steps_per_epoch=ModelTest.steps_per_epoch,
                                validation_steps=ModelTest.steps_per_epoch)
     actual_value = logs[-1]['some_epoch_metric']
     val_actual_value = logs[-1]['val_some_epoch_metric']
     # SomeEpochMetric presumably always evaluates to 5 -- confirm its definition.
     expected_value = 5
     self.assertEqual(val_actual_value, expected_value)
     self.assertEqual(actual_value, expected_value)
 def test_epoch_metrics_with_multiple_names_returned_by_dict(self):
     """An epoch metric returning a dict is logged under its keys."""
     d = dict(zip(self.metric_names, self.metric_values))
     epoch_metric = ConstEpochMetric(d)
     model = Model(self.pytorch_network,
                   self.optimizer,
                   self.loss_function,
                   epoch_metrics=[(self.metric_names, epoch_metric)])
     self._test_history(model, d.keys(), d.values())
 def test_repeated_batch_epoch_metrics_handling(self):
     """A batch metric and an epoch metric sharing a name get numbered suffixes."""
     expected_names = ['some_metric_name1', 'some_metric_name2']
     model = Model(self.pytorch_network,
                   self.optimizer,
                   self.loss_function,
                   batch_metrics=[get_batch_metric(1)],
                   epoch_metrics=[SomeMetricName(2)])
     self._test_history(model, expected_names, [1, 2])
# Example #28
 def test_metrics_integration(self):
     """F.mse_loss works as a batch metric; evaluate returns float loss and mse."""
     num_steps = 10
     model = Model(self.pytorch_network,
                   self.optimizer,
                   self.loss_function,
                   batch_metrics=[F.mse_loss])
     train_generator = some_data_tensor_generator(ModelTest.batch_size)
     valid_generator = some_data_tensor_generator(ModelTest.batch_size)
     model.fit_generator(train_generator,
                         valid_generator,
                         epochs=ModelTest.epochs,
                         steps_per_epoch=ModelTest.steps_per_epoch,
                         validation_steps=ModelTest.steps_per_epoch,
                         callbacks=[self.mock_callback])
     # Use a fresh generator for evaluation so it is independent of training.
     generator = some_data_tensor_generator(ModelTest.batch_size)
     loss, mse = model.evaluate_generator(generator, steps=num_steps)
     self.assertEqual(type(loss), float)
     self.assertEqual(type(mse), float)
# Example #29
    def setUp(self) -> None:
        """Fixture with mocked notificator objects plus the standard metric setup.

        ``batch_size`` and ``lr`` are presumably defined on
        ``NotificationCallbackTest`` -- confirm against that class.
        """
        super().setUp()
        self.notification_callback_mock = MagicMock()
        self.notificator_mock = MagicMock()

        self.train_generator = some_data_tensor_generator(NotificationCallbackTest.batch_size)
        self.valid_generator = some_data_tensor_generator(NotificationCallbackTest.batch_size)

        torch.manual_seed(42)  # deterministic weight initialization
        self.pytorch_network = nn.Linear(1, 1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=NotificationCallbackTest.lr)

        # (name, metric) tuples override default names; the repeated metric
        # exercises duplicate-name handling.
        self.batch_metrics = [
            some_batch_metric_1,
            ('custom_name', some_batch_metric_2),
            repeat_batch_metric,
            repeat_batch_metric,
        ]
        # Expected names/values, in the same order as self.batch_metrics.
        self.batch_metrics_names = [
            'some_batch_metric_1',
            'custom_name',
            'repeat_batch_metric1',
            'repeat_batch_metric2',
        ]
        self.batch_metrics_values = [
            some_metric_1_value,
            some_metric_2_value,
            repeat_batch_metric_value,
            repeat_batch_metric_value,
        ]
        self.epoch_metrics = [SomeConstantEpochMetric()]
        self.epoch_metrics_names = ['some_constant_epoch_metric']
        self.epoch_metrics_values = [some_constant_epoch_metric_value]

        self.model = Model(
            self.pytorch_network,
            self.optimizer,
            self.loss_function,
            batch_metrics=self.batch_metrics,
            epoch_metrics=self.epoch_metrics,
        )
# Example #30
class BaseTensorBoardLoggerTest:
    """Shared test logic for TensorBoard loggers.

    Subclasses must assign a concrete writer class to ``SummaryWriter``.
    """
    SummaryWriter = None
    batch_size = 20
    lr = 1e-3
    num_epochs = 10

    def setUp(self):
        """Build a seeded linear model and a writer with a mocked ``add_scalars``."""
        torch.manual_seed(42)  # deterministic weight initialization
        self.pytorch_network = nn.Linear(1, 1)
        self.loss_function = nn.MSELoss()
        # FIX: use this class's own ``lr`` instead of ``BaseCSVLoggerTest.lr`` --
        # that was a wrong-class coupling that only worked because both values
        # happened to be 1e-3.
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=self.lr)
        self.model = Model(self.pytorch_network, self.optimizer, self.loss_function)
        self.temp_dir_obj = TemporaryDirectory()
        # pylint: disable=not-callable
        self.writer = self.SummaryWriter(self.temp_dir_obj.name)
        # Mock add_scalars so tests assert on logged scalars, not real event files.
        self.writer.add_scalars = MagicMock()

    def tearDown(self):
        self.temp_dir_obj.cleanup()

    def test_logging(self):
        """Loss and a single learning rate are logged once per epoch."""
        train_gen = some_data_generator(20)
        valid_gen = some_data_generator(20)
        logger = TensorBoardLogger(self.writer)
        history = self.model.fit_generator(
            train_gen, valid_gen, epochs=self.num_epochs, steps_per_epoch=5, callbacks=[logger]
        )
        self._test_logging(history)

    def test_multiple_learning_rates(self):
        """Each optimizer param group's learning rate is logged separately."""
        train_gen = some_data_generator(20)
        valid_gen = some_data_generator(20)
        logger = TensorBoardLogger(self.writer)
        # FIX: derive both rates from this class's ``lr`` (was BaseCSVLoggerTest.lr).
        lrs = [self.lr, self.lr / 2]
        # One param group per learning rate: weight at lrs[0], bias at lrs[1].
        optimizer = torch.optim.SGD(
            [dict(params=[self.pytorch_network.weight], lr=lrs[0]), dict(params=[self.pytorch_network.bias], lr=lrs[1])]
        )
        model = Model(self.pytorch_network, optimizer, self.loss_function)
        history = model.fit_generator(
            train_gen, valid_gen, epochs=self.num_epochs, steps_per_epoch=5, callbacks=[logger]
        )
        self._test_logging(history, lrs=lrs)

    def _test_logging(self, history, lrs=None):
        """Assert add_scalars was called with the loss and lr values of every epoch.

        Args:
            history: list of per-epoch log dicts returned by ``fit_generator``.
            lrs: learning rates expected in the 'lr' scalars; defaults to
                the fixture's single ``self.lr``.
        """
        if lrs is None:
            lrs = [self.lr]
        calls = []
        for h in history:
            calls.append(call('loss', {'loss': h['loss'], 'val_loss': h['val_loss']}, h['epoch']))
            if len(lrs) == 1:
                # FIX: assert the lr that was actually passed in (was self.lr,
                # which would silently miss a wrong single custom lr).
                calls.append(call('lr', {'lr': lrs[0]}, h['epoch']))
            else:
                calls.append(call('lr', {f'lr_group_{i}': lr for i, lr in enumerate(lrs)}, h['epoch']))
        self.writer.add_scalars.assert_has_calls(calls, any_order=True)