コード例 #1
0
 def test_metrics_integration(self):
     # Passing a plain function (F.mse_loss) as a batch metric: after fitting,
     # evaluate_generator should return the loss and that metric, both as floats.
     batch_size = ModelTest.batch_size
     steps = ModelTest.steps_per_epoch
     model = Model(self.pytorch_network,
                   self.optimizer,
                   self.loss_function,
                   batch_metrics=[F.mse_loss])
     model.fit_generator(some_data_tensor_generator(batch_size),
                         some_data_tensor_generator(batch_size),
                         epochs=ModelTest.epochs,
                         steps_per_epoch=steps,
                         validation_steps=steps,
                         callbacks=[self.mock_callback])
     loss, mse = model.evaluate_generator(some_data_tensor_generator(batch_size), steps=10)
     for value in (loss, mse):
         self.assertEqual(type(value), float)
コード例 #2
0
class ModelMultiDictIOTest(ModelFittingTestCase):
    """Model tests for a network whose inputs and targets are dictionaries.

    The network is built as ``DictIOModel(['x1', 'x2'], ['y1', 'y2'])`` and trained
    with ``dict_mse_loss``; the tests exercise the fit/train/evaluate/predict entry points.
    """

    def setUp(self):
        super().setUp()
        torch.manual_seed(42)
        network = DictIOModel(['x1', 'x2'], ['y1', 'y2'])
        self.pytorch_network = network
        self.loss_function = dict_mse_loss
        self.optimizer = torch.optim.SGD(network.parameters(), lr=1e-3)

        self.model = Model(network,
                           self.optimizer,
                           self.loss_function,
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)

    def test_fitting_tensor_generator_multi_dict_io(self):
        # Fit from dict-producing generators, then check the callback sequence and logs.
        batch_size = ModelMultiDictIOTest.batch_size
        epochs = ModelMultiDictIOTest.epochs
        steps = ModelMultiDictIOTest.steps_per_epoch
        logs = self.model.fit_generator(some_data_tensor_generator_dict_io(batch_size),
                                        some_data_tensor_generator_dict_io(batch_size),
                                        epochs=epochs,
                                        steps_per_epoch=steps,
                                        validation_steps=steps,
                                        callbacks=[self.mock_callback])
        expected_params = dict(epochs=epochs, steps=steps, valid_steps=steps)
        self._test_callbacks_train(expected_params, logs, valid_steps=steps)

    def test_tensor_train_on_batch_multi_dict_io(self):
        # One training step on a dict batch reports the loss as a Python float.
        inputs, targets = get_batch(ModelMultiDictIOTest.batch_size)
        loss = self.model.train_on_batch(inputs, targets)
        self.assertEqual(type(loss), float)

    def test_train_on_batch_with_pred_multi_dict_io(self):
        # With return_pred=True, every predicted output has shape (batch_size, 1).
        batch_size = ModelMultiDictIOTest.batch_size
        inputs, targets = get_batch(batch_size)
        loss, predictions = self.model.train_on_batch(inputs, targets, return_pred=True)
        self.assertEqual(type(loss), float)
        expected_shape = (batch_size, 1)
        for prediction in predictions.values():
            self.assertEqual(prediction.shape, expected_shape)

    def test_ndarray_train_on_batch_multi_dict_io(self):
        # NumPy float32 arrays are accepted as dict values in place of tensors.
        batch_size = ModelMultiDictIOTest.batch_size

        def random_column():
            return np.random.rand(batch_size, 1).astype(np.float32)

        inputs = dict(x1=random_column(), x2=random_column())
        targets = dict(y1=random_column(), y2=random_column())
        loss = self.model.train_on_batch(inputs, targets)
        self.assertEqual(type(loss), float)

    def test_evaluate_generator_multi_dict_io(self):
        # Evaluating num_steps batches gives a float loss and predictions covering
        # num_steps * batch_size samples.
        num_steps = 10
        batch_size = ModelMultiDictIOTest.batch_size
        generator = some_data_tensor_generator_dict_io(batch_size)
        loss, pred_y = self.model.evaluate_generator(generator, steps=num_steps, return_pred=True)
        self.assertEqual(type(loss), float)
        self._test_size_and_type_for_generator(pred_y, (num_steps * batch_size, 1))

    def test_tensor_evaluate_on_batch_multi_dict_io(self):
        # Evaluating a single dict batch reports the loss as a Python float.
        inputs, targets = get_batch(ModelMultiDictIOTest.batch_size)
        self.assertEqual(type(self.model.evaluate_on_batch(inputs, targets)), float)

    def test_predict_generator_multi_dict_io(self):
        # Predict from a generator that yields only the input dicts.
        num_steps = 10
        batch_size = ModelMultiDictIOTest.batch_size
        inputs_only = (inputs for inputs, _ in some_data_tensor_generator_dict_io(batch_size))
        pred_y = self.model.predict_generator(inputs_only, steps=num_steps)
        self._test_size_and_type_for_generator(pred_y, (num_steps * batch_size, 1))

    def test_tensor_predict_on_batch_multi_dict_io(self):
        # Predicting a single dict batch yields outputs of shape (batch_size, 1).
        batch_size = ModelMultiDictIOTest.batch_size
        inputs = dict(x1=torch.rand(batch_size, 1), x2=torch.rand(batch_size, 1))
        pred_y = self.model.predict_on_batch(inputs)
        self._test_size_and_type_for_generator(pred_y, (batch_size, 1))
コード例 #3
0
from poutyne import Model

# Wrap the network with an optimizer and a loss selected by name ('sgd',
# 'cross_entropy'), tracking accuracy as a batch metric and F1 as an epoch metric.
model = Model(
    network,
    'sgd',
    'cross_entropy',
    batch_metrics=['accuracy'],
    epoch_metrics=['f1'],
)
model.to(device)

# Train for num_epochs epochs with valid_loader as the validation data.
model.fit_generator(
    train_loader,
    valid_loader,
    epochs=num_epochs,
    callbacks=callbacks,
)

# The evaluation result unpacks into the loss and the (accuracy, F1) pair.
test_loss, (test_acc, test_f1) = model.evaluate_generator(test_loader)
print('Test: Loss: {}, Accuracy: {}, F1: {}'.format(test_loss, test_acc, test_f1))
コード例 #4
0
class LambdaTest(ModelFittingTestCase):
    """Integration tests for ``LambdaCallback`` driven by a one-feature linear ``Model``."""

    def setUp(self):
        super().setUp()
        torch.manual_seed(42)
        self.pytorch_network = nn.Linear(1, 1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.pytorch_network.parameters(), lr=1e-3)
        self.model = Model(self.pytorch_network, self.optimizer, self.loss_function)

    def test_integration_zero_args(self):
        # A LambdaCallback without any hook must run through fit and evaluate without error.
        lambda_callback = LambdaCallback()
        batch_size = LambdaTest.batch_size
        train_generator, valid_generator, test_generator = [
            some_data_tensor_generator(batch_size) for _ in range(3)
        ]
        steps = LambdaTest.steps_per_epoch
        self.model.fit_generator(train_generator,
                                 valid_generator,
                                 epochs=LambdaTest.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[lambda_callback])

        self.model.evaluate_generator(test_generator, steps=10, callbacks=[lambda_callback])

    def test_with_only_on_epoch_end_arg(self):
        # Only on_epoch_end is supplied: it must be called once per training epoch,
        # numbered from 1, with that epoch's log, and evaluation must not add calls.
        on_epoch_end = Mock()
        lambda_callback = LambdaCallback(on_epoch_end=on_epoch_end)
        batch_size = LambdaTest.batch_size
        train_generator, valid_generator, test_generator = [
            some_data_tensor_generator(batch_size) for _ in range(3)
        ]
        steps = LambdaTest.steps_per_epoch
        logs = self.model.fit_generator(train_generator,
                                        valid_generator,
                                        epochs=LambdaTest.epochs,
                                        steps_per_epoch=steps,
                                        validation_steps=steps,
                                        callbacks=[lambda_callback])

        self.model.evaluate_generator(test_generator, steps=10, callbacks=[lambda_callback])

        expected_calls = [call(epoch, epoch_log) for epoch, epoch_log in enumerate(logs, start=1)]
        self._assert_same_calls(expected_calls, on_epoch_end.mock_calls)

    def test_lambda_test_calls(self):
        # During evaluation, the LambdaCallback's hooks must record the same calls
        # as the regular mock callback registered next to it.
        lambda_callback, hook_mock = self._get_lambda_callback_with_mock_args()
        test_generator = some_data_tensor_generator(LambdaTest.batch_size)
        self.model.evaluate_generator(test_generator,
                                      steps=10,
                                      callbacks=[lambda_callback, self.mock_callback])

        # The first two calls recorded on self.mock_callback are skipped; presumably
        # set_params/set_model, which LambdaCallback has no hooks for -- TODO confirm.
        self._assert_same_calls(self.mock_callback.method_calls[2:], hook_mock.method_calls)

    def test_lambda_train_calls(self):
        # During training, the LambdaCallback's hooks must record the same calls
        # as the regular mock callback registered next to it.
        lambda_callback, hook_mock = self._get_lambda_callback_with_mock_args()
        batch_size = LambdaTest.batch_size
        train_generator = some_data_tensor_generator(batch_size)
        valid_generator = some_data_tensor_generator(batch_size)
        steps = LambdaTest.steps_per_epoch
        self.model.fit_generator(train_generator,
                                 valid_generator,
                                 epochs=LambdaTest.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[lambda_callback, self.mock_callback])

        # Same leading-call offset as in test_lambda_test_calls.
        self._assert_same_calls(self.mock_callback.method_calls[2:], hook_mock.method_calls)

    def _assert_same_calls(self, expected_calls, actual_calls):
        """Assert that two recorded call lists have the same length and the same entries."""
        self.assertEqual(len(expected_calls), len(actual_calls))
        self.assertEqual(expected_calls, actual_calls)

    def _get_lambda_callback_with_mock_args(self):
        """Return a LambdaCallback whose every hook forwards to one mock, plus that mock.

        The mock is specced on ``Callback()``, so only attributes that exist on a
        ``Callback`` can be accessed on it.
        """
        mock_callback = Mock(spec=Callback())
        hook_names = (
            'on_epoch_begin',
            'on_epoch_end',
            'on_train_batch_begin',
            'on_train_batch_end',
            'on_valid_batch_begin',
            'on_valid_batch_end',
            'on_test_batch_begin',
            'on_test_batch_end',
            'on_train_begin',
            'on_train_end',
            'on_valid_begin',
            'on_valid_end',
            'on_test_begin',
            'on_test_end',
            'on_backward_end',
        )
        lambda_callback = LambdaCallback(**{name: getattr(mock_callback, name) for name in hook_names})
        return lambda_callback, mock_callback
コード例 #5
0
class ModelMultiInputTest(ModelFittingTestCase):
    """Model tests where the network input is a tuple of two tensors or arrays."""

    def setUp(self):
        super().setUp()
        torch.manual_seed(42)
        # NOTE(review): the tests below feed two inputs (x1, x2) although the network is
        # built with num_input=1 -- confirm how MultiIOModel treats the extra input.
        self.pytorch_network = MultiIOModel(num_input=1, num_output=1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)

        self.model = Model(self.pytorch_network,
                           self.optimizer,
                           self.loss_function,
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)

    @staticmethod
    def _rand_tensor_inputs(length):
        """Return a pair of random ``(length, 1)`` tensors used as the two network inputs."""
        return torch.rand(length, 1), torch.rand(length, 1)

    @staticmethod
    def _rand_float32_array(length):
        """Return a random ``(length, 1)`` float32 NumPy array."""
        return np.random.rand(length, 1).astype(np.float32)

    def test_fitting_tensor_generator_multi_input(self):
        # Fit from tuple-input generators, then check the callback sequence and logs.
        batch_size = ModelMultiInputTest.batch_size
        epochs = ModelMultiInputTest.epochs
        steps = ModelMultiInputTest.steps_per_epoch
        logs = self.model.fit_generator(some_data_tensor_generator_multi_input(batch_size),
                                        some_data_tensor_generator_multi_input(batch_size),
                                        epochs=epochs,
                                        steps_per_epoch=steps,
                                        validation_steps=steps,
                                        callbacks=[self.mock_callback])
        self._test_callbacks_train({'epochs': epochs, 'steps': steps}, logs)

    def test_fitting_with_tensor_multi_input(self):
        # fit() must derive the step count from the dataset size when the last batch is short.
        batch_size = ModelMultiInputTest.batch_size

        # 30 training batches, the last one missing 7 samples.
        train_steps = 30
        train_size = train_steps * batch_size - 7
        train_x = self._rand_tensor_inputs(train_size)
        train_y = torch.rand(train_size, 1)

        # fit() reuses batch_size for validation: 10 batches, the last one missing 3 samples.
        valid_size = 10 * batch_size - 3
        valid_x = self._rand_tensor_inputs(valid_size)
        valid_y = torch.rand(valid_size, 1)

        logs = self.model.fit(train_x,
                              train_y,
                              validation_data=(valid_x, valid_y),
                              epochs=ModelMultiInputTest.epochs,
                              batch_size=batch_size,
                              steps_per_epoch=None,
                              validation_steps=None,
                              callbacks=[self.mock_callback])
        self._test_callbacks_train({'epochs': ModelMultiInputTest.epochs, 'steps': train_steps}, logs)

    def test_tensor_train_on_batch_multi_input(self):
        # One training step on a tuple of tensors reports the loss as a Python float.
        batch_size = ModelMultiInputTest.batch_size
        x = self._rand_tensor_inputs(batch_size)
        y = torch.rand(batch_size, 1)
        self.assertEqual(type(self.model.train_on_batch(x, y)), float)

    def test_train_on_batch_with_pred_multi_input(self):
        # With return_pred=True, the prediction has shape (batch_size, 1).
        batch_size = ModelMultiInputTest.batch_size
        x = self._rand_tensor_inputs(batch_size)
        y = torch.rand(batch_size, 1)
        loss, pred_y = self.model.train_on_batch(x, y, return_pred=True)
        self.assertEqual(type(loss), float)
        self.assertEqual(pred_y.shape, (batch_size, 1))

    def test_ndarray_train_on_batch_multi_input(self):
        # NumPy float32 arrays are accepted in place of tensors.
        batch_size = ModelMultiInputTest.batch_size
        x = (self._rand_float32_array(batch_size), self._rand_float32_array(batch_size))
        y = self._rand_float32_array(batch_size)
        self.assertEqual(type(self.model.train_on_batch(x, y)), float)

    def test_evaluate_multi_input(self):
        # evaluate() without return_pred returns only a float loss.
        length = ModelMultiInputTest.evaluate_dataset_len
        x = self._rand_tensor_inputs(length)
        y = torch.rand(length, 1)
        loss = self.model.evaluate(x, y, batch_size=ModelMultiInputTest.batch_size)
        self.assertEqual(type(loss), float)

    def test_evaluate_with_pred_multi_input(self):
        # evaluate(return_pred=True) unpacks into (loss, predictions) covering the dataset.
        length = ModelMultiInputTest.evaluate_dataset_len
        x = self._rand_tensor_inputs(length)
        y = torch.rand(length, 1)
        _, pred_y = self.model.evaluate(x, y, batch_size=ModelMultiInputTest.batch_size, return_pred=True)
        self.assertEqual(pred_y.shape, (length, 1))

    def test_evaluate_with_np_array_multi_input(self):
        # evaluate() on NumPy inputs returns a float loss and full-length predictions.
        length = ModelMultiInputTest.evaluate_dataset_len
        x = (self._rand_float32_array(length), self._rand_float32_array(length))
        y = self._rand_float32_array(length)
        loss, pred_y = self.model.evaluate(x, y, batch_size=ModelMultiInputTest.batch_size, return_pred=True)
        self.assertEqual(type(loss), float)
        self.assertEqual(pred_y.shape, (length, 1))

    def test_evaluate_data_loader_multi_input(self):
        # evaluate_generator accepts a DataLoader over a dataset with tuple inputs.
        length = ModelMultiInputTest.evaluate_dataset_len
        dataset = TensorDataset(self._rand_tensor_inputs(length), torch.rand(length, 1))
        loader = DataLoader(dataset, ModelMultiInputTest.batch_size)
        loss, pred_y = self.model.evaluate_generator(loader, return_pred=True)
        self.assertEqual(type(loss), float)
        self.assertEqual(pred_y.shape, (length, 1))

    def test_evaluate_generator_multi_input(self):
        # Evaluating num_steps generator batches yields num_steps * batch_size predictions.
        num_steps = 10
        batch_size = ModelMultiInputTest.batch_size
        generator = some_data_tensor_generator_multi_input(batch_size)
        loss, pred_y = self.model.evaluate_generator(generator, steps=num_steps, return_pred=True)
        self.assertEqual(type(loss), float)
        self.assertEqual(pred_y.shape, (num_steps * batch_size, 1))

    def test_tensor_evaluate_on_batch_multi_input(self):
        # Evaluating a single tuple batch reports the loss as a Python float.
        batch_size = ModelMultiInputTest.batch_size
        x = self._rand_tensor_inputs(batch_size)
        y = torch.rand(batch_size, 1)
        self.assertEqual(type(self.model.evaluate_on_batch(x, y)), float)

    def test_predict_multi_input(self):
        # predict() on tensor inputs returns one prediction per sample.
        length = ModelMultiInputTest.evaluate_dataset_len
        pred_y = self.model.predict(self._rand_tensor_inputs(length), batch_size=ModelMultiInputTest.batch_size)
        self.assertEqual(pred_y.shape, (length, 1))

    def test_predict_with_np_array_multi_input(self):
        # predict() on NumPy inputs returns one prediction per sample.
        length = ModelMultiInputTest.evaluate_dataset_len
        x = (self._rand_float32_array(length), self._rand_float32_array(length))
        pred_y = self.model.predict(x, batch_size=ModelMultiInputTest.batch_size)
        self.assertEqual(pred_y.shape, (length, 1))

    def test_predict_generator_multi_input(self):
        # predict_generator on an inputs-only generator returns a NumPy array.
        num_steps = 10
        batch_size = ModelMultiInputTest.batch_size
        inputs_only = (x for x, _ in some_data_tensor_generator_multi_input(batch_size))
        pred_y = self.model.predict_generator(inputs_only, steps=num_steps)
        self.assertEqual(type(pred_y), np.ndarray)
        self.assertEqual(pred_y.shape, (num_steps * batch_size, 1))

    def test_tensor_predict_on_batch_multi_input(self):
        # predict_on_batch on a tuple of tensors yields shape (batch_size, 1).
        batch_size = ModelMultiInputTest.batch_size
        pred_y = self.model.predict_on_batch(self._rand_tensor_inputs(batch_size))
        self.assertEqual(pred_y.shape, (batch_size, 1))
コード例 #6
0
from poutyne import Model

# Wrap the network with its optimizer and loss function, then move it to `device`.
model = Model(network, optimizer, loss_function)
model.to(device)

# Train for num_epochs epochs with valid_loader as the validation data.
model.fit_generator(
    train_loader,
    valid_loader,
    epochs=num_epochs,
    callbacks=callbacks,
)

# No metrics were passed to Model, so only the loss is captured here.
test_loss = model.evaluate_generator(test_loader)
コード例 #7
0
class ModelFittingTestCaseProgress(ModelFittingTestCase):
    # pylint: disable=too-many-public-methods
    # NOTE(review): num_steps is not referenced by the methods visible in this chunk;
    # presumably used by tests further down the class -- confirm.
    num_steps = 5

    def setUp(self):
        """Build a linear model with known batch/epoch metrics and capture its printed output."""
        super().setUp()
        batch_size = ModelFittingTestCase.batch_size
        self.train_generator = some_data_tensor_generator(batch_size)
        self.valid_generator = some_data_tensor_generator(batch_size)
        self.test_generator = some_data_tensor_generator(batch_size)
        torch.manual_seed(42)
        self.pytorch_network = nn.Linear(1, 1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)

        # The two unnamed repeat_batch_metric entries are expected to be reported
        # as 'repeat_batch_metric1' and 'repeat_batch_metric2'.
        self.batch_metrics = [some_batch_metric_1, ('custom_name', some_batch_metric_2)] + [repeat_batch_metric] * 2
        self.batch_metrics_names = ['some_batch_metric_1', 'custom_name', 'repeat_batch_metric1', 'repeat_batch_metric2']
        self.batch_metrics_values = [some_metric_1_value, some_metric_2_value] + [repeat_batch_metric_value] * 2
        self.epoch_metrics = [SomeConstantEpochMetric()]
        self.epoch_metrics_names = ['some_constant_epoch_metric']
        self.epoch_metrics_values = [some_constant_epoch_metric_value]

        self.model = Model(self.pytorch_network,
                           self.optimizer,
                           self.loss_function,
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)

        # Presumably redirects stdout into self.test_out, which the assertStdout* helpers read.
        self._capture_output()

    def assertStdoutContains(self, values):
        """Assert that every string in ``values`` occurs in the captured output.

        Args:
            values: Iterable of substrings that must each appear in the stripped
                text of ``self.test_out`` (presumably filled by ``_capture_output``).
        """
        # The captured text does not change while it is being checked, so read it once
        # instead of re-reading and re-stripping the whole buffer for every value.
        output = self.test_out.getvalue().strip()
        for value in values:
            self.assertIn(value, output)

    def assertStdoutNotContains(self, values):
        """Assert that no string in ``values`` occurs in the captured output.

        Args:
            values: Iterable of substrings that must all be absent from the stripped
                text of ``self.test_out`` (presumably filled by ``_capture_output``).
        """
        # Read the captured text once; it is invariant across the loop.
        output = self.test_out.getvalue().strip()
        for value in values:
            self.assertNotIn(value, output)

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_default_coloring(self):
        # Without progress options, the fit output contains the default ANSI color codes.
        steps = ModelFittingTestCase.steps_per_epoch
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback])

        self.assertStdoutContains(["[32m", "[35m", "[36m", "[94m"])

    def test_fitting_with_progress_bar_show_epoch(self):
        # The default fit output labels the epochs.
        # NOTE(review): "1/5" and "2/5" are hard-coded; confirm they track the class's
        # epochs/steps_per_epoch rather than matching by coincidence.
        fit_kwargs = dict(epochs=ModelFittingTestCase.epochs,
                          steps_per_epoch=ModelFittingTestCase.steps_per_epoch,
                          validation_steps=ModelFittingTestCase.steps_per_epoch,
                          callbacks=[self.mock_callback])
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        self.assertStdoutContains(["Epoch", "1/5", "2/5"])

    def test_fitting_with_progress_bar_show_steps(self):
        # The fit output mentions "steps" together with the configured steps per epoch.
        steps = ModelFittingTestCase.steps_per_epoch
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback])

        self.assertStdoutContains(["steps", f"{steps}"])

    def test_fitting_with_progress_bar_show_train_val_final_steps(self):
        # The fit output reports both the training and the validation step summaries.
        steps = ModelFittingTestCase.steps_per_epoch
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback])

        self.assertStdoutContains(["Val steps", "Train steps"])

    def test_fitting_with_no_progress_bar_dont_show_epoch(self):
        # With verbose=False, no epoch progress text reaches stdout.
        steps = ModelFittingTestCase.steps_per_epoch
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback],
                                 verbose=False)

        self.assertStdoutNotContains(["Epoch", "1/5", "2/5"])

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_user_coloring(self):
        # Setting every coloring entry to BLACK makes the black ANSI code appear.
        coloring = dict.fromkeys(
            ("text_color", "ratio_color", "metric_value_color", "time_color", "progress_bar_color"),
            "BLACK",
        )
        steps = ModelFittingTestCase.steps_per_epoch
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback],
                                 progress_options=dict(coloring=coloring))

        self.assertStdoutContains(["[30m"])

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_user_partial_coloring(self):
        # Overriding only some colors keeps the defaults for the remaining ones.
        partial_coloring = dict.fromkeys(("text_color", "ratio_color"), "BLACK")
        steps = ModelFittingTestCase.steps_per_epoch
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback],
                                 progress_options=dict(coloring=partial_coloring))

        self.assertStdoutContains(["[30m", "[32m", "[35m", "[94m"])

    def test_fitting_with_user_coloring_invalid(self):
        # An unknown key in the coloring options is rejected with a KeyError.
        steps = ModelFittingTestCase.steps_per_epoch
        invalid_coloring = {"invalid_name": 'A COLOR'}
        with self.assertRaises(KeyError):
            self.model.fit_generator(self.train_generator,
                                     self.valid_generator,
                                     epochs=ModelFittingTestCase.epochs,
                                     steps_per_epoch=steps,
                                     validation_steps=steps,
                                     callbacks=[self.mock_callback],
                                     progress_options=dict(coloring=invalid_coloring))

    def test_fitting_with_no_coloring(self):
        # coloring=False removes every default ANSI color code from the fit output.
        steps = ModelFittingTestCase.steps_per_epoch
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback],
                                 progress_options=dict(coloring=False))

        self.assertStdoutNotContains(["[32m", "[35m", "[36m", "[94m"])

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_progress_bar_default_color(self):
        # A colored progress bar shows a percentage, the bar glyph and the default colors.
        steps = ModelFittingTestCase.steps_per_epoch
        options = dict(coloring=True, progress_bar=True)
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback],
                                 progress_options=options)

        self.assertStdoutContains(["%", "[32m", "[35m", "[36m", "[94m", "\u2588"])

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_progress_bar_user_color(self):
        # A progress bar with all-BLACK user colors shows the black code, percentage and glyph.
        coloring = dict.fromkeys(
            ("text_color", "ratio_color", "metric_value_color", "time_color", "progress_bar_color"),
            "BLACK",
        )
        steps = ModelFittingTestCase.steps_per_epoch
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback],
                                 progress_options=dict(coloring=coloring, progress_bar=True))

        self.assertStdoutContains(["%", "[30m", "\u2588"])

    def test_fitting_with_progress_bar_no_color(self):
        # An uncolored progress bar still draws the percentage and glyph, without color codes.
        steps = ModelFittingTestCase.steps_per_epoch
        options = dict(coloring=False, progress_bar=True)
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback],
                                 progress_options=options)

        self.assertStdoutContains(["%", "\u2588"])
        self.assertStdoutNotContains(["[32m", "[35m", "[36m", "[94m"])

    def test_fitting_with_no_progress_bar(self):
        # With neither bar nor colors, the output has no percentage, glyph or color codes.
        steps = ModelFittingTestCase.steps_per_epoch
        options = dict(coloring=False, progress_bar=False)
        self.model.fit_generator(self.train_generator,
                                 self.valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 steps_per_epoch=steps,
                                 validation_steps=steps,
                                 callbacks=[self.mock_callback],
                                 progress_options=options)

        self.assertStdoutNotContains(["%", "\u2588", "[32m", "[35m", "[36m", "[94m"])

    def test_progress_bar_with_step_is_none(self):
        # Fitting without steps_per_epoch on generators that (judging by their class
        # name) end via StopIteration: the output shows a per-step rate ("s/step")
        # but no bar glyph, percentage or color codes.
        batch_size = ModelFittingTestCase.batch_size
        train_generator = SomeDataGeneratorUsingStopIteration(batch_size, 10)
        valid_generator = SomeDataGeneratorUsingStopIteration(batch_size, 10)
        self.model.fit_generator(train_generator,
                                 valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 progress_options=dict(coloring=False, progress_bar=True))

        self.assertStdoutContains(["s/step"])
        self.assertStdoutNotContains(["[32m", "[35m", "[36m", "[94m", "\u2588", "%"])

    def test_evaluate_without_progress_output(self):
        # evaluate(verbose=False) prints no colored progress output.
        length = ModelFittingTestCase.evaluate_dataset_len
        x, y = torch.rand(length, 1), torch.rand(length, 1)

        # The two-name unpacking also checks that evaluate returns a pair here.
        _, _ = self.model.evaluate(x, y, batch_size=ModelFittingTestCase.batch_size, verbose=False)

        self.assertStdoutNotContains(["[32m", "[35m", "[36m", "[94m"])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_default_coloring(self):
        # By default, evaluate prints progress with the default ANSI color codes.
        length = ModelFittingTestCase.evaluate_dataset_len
        x, y = torch.rand(length, 1), torch.rand(length, 1)

        _, _ = self.model.evaluate(x, y, batch_size=ModelFittingTestCase.batch_size)

        self.assertStdoutContains(["[32m", "[35m", "[36m", "[94m"])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_user_coloring(self):
        """A complete user coloring (every component BLACK) shows up as the "[30m" code."""
        dataset_len = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(dataset_len, 1)
        y = torch.rand(dataset_len, 1)

        components = ("text_color", "ratio_color", "metric_value_color", "time_color", "progress_bar_color")
        all_black = dict.fromkeys(components, "BLACK")

        _, _ = self.model.evaluate(x, y,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=dict(coloring=all_black))

        self.assertStdoutContains(["[30m"])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_user_partial_coloring(self):
        """A partial coloring dict overrides only the components it names.

        text/ratio are set to BLACK ("[30m"). The other asserted codes are among
        those expected with the default coloring.
        """
        dataset_len = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(dataset_len, 1)
        y = torch.rand(dataset_len, 1)

        partial_coloring = {"text_color": 'BLACK', "ratio_color": "BLACK"}
        options = dict(coloring=partial_coloring)

        _, _ = self.model.evaluate(x, y,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=options)
        self.assertStdoutContains(["[30m", "[32m", "[35m", "[94m"])

    def test_evaluate_with_user_coloring_invalid(self):
        """An unknown key in the coloring dict makes evaluate raise KeyError."""
        dataset_len = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(dataset_len, 1)
        y = torch.rand(dataset_len, 1)
        invalid_options = dict(coloring={"invalid_name": 'A COLOR'})

        with self.assertRaises(KeyError):
            _, _ = self.model.evaluate(x, y,
                                       batch_size=ModelFittingTestCase.batch_size,
                                       progress_options=invalid_options)

    def test_evaluate_with_no_coloring(self):
        """coloring=False keeps every color code out of the evaluate output."""
        dataset_len = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(dataset_len, 1)
        y = torch.rand(dataset_len, 1)
        options = dict(coloring=False)

        _, _ = self.model.evaluate(x, y,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=options)

        color_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutNotContains(color_codes)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_progress_bar_default_coloring(self):
        """With coloring and the progress bar on, the bar, its percentage and the default colors are printed."""
        dataset_len = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(dataset_len, 1)
        y = torch.rand(dataset_len, 1)
        options = dict(coloring=True, progress_bar=True)

        _, _ = self.model.evaluate(x, y,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=options)

        expected = ["%", "[32m", "[35m", "[36m", "[94m", "\u2588"]
        self.assertStdoutContains(expected)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_progress_bar_user_coloring(self):
        """With the progress bar and an all-BLACK user coloring, the bar and "[30m" appear."""
        dataset_len = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(dataset_len, 1)
        y = torch.rand(dataset_len, 1)

        components = ("text_color", "ratio_color", "metric_value_color", "time_color", "progress_bar_color")
        all_black = dict.fromkeys(components, "BLACK")
        options = dict(coloring=all_black, progress_bar=True)

        _, _ = self.model.evaluate(x, y,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=options)

        self.assertStdoutContains(["%", "[30m", "\u2588"])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_progress_bar_user_no_color(self):
        """The progress bar still shows with coloring=False, but without any color code."""
        dataset_len = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(dataset_len, 1)
        y = torch.rand(dataset_len, 1)
        options = dict(coloring=False, progress_bar=True)

        _, _ = self.model.evaluate(x, y,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=options)

        bar_markers = ["%", "\u2588"]
        color_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutContains(bar_markers)
        self.assertStdoutNotContains(color_codes)

    def test_evaluate_with_no_progress_bar(self):
        """With coloring and the progress bar off, evaluate prints no bar markers and no color codes."""
        dataset_len = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(dataset_len, 1)
        y = torch.rand(dataset_len, 1)
        options = dict(coloring=False, progress_bar=False)

        _, _ = self.model.evaluate(x, y,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=options)

        bar_markers = ["%", "\u2588"]
        color_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutNotContains(bar_markers)
        self.assertStdoutNotContains(color_codes)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_data_loader_with_progress_bar_coloring(self):
        """evaluate_generator(verbose=True) on a DataLoader prints a colored progress bar.

        The test is skipped when colorama cannot be imported, like the other
        tests that assert color codes. Without colorama the codes asserted below
        presumably cannot be emitted.
        """
        x = torch.rand(ModelFittingTestCase.evaluate_dataset_len, 1)
        y = torch.rand(ModelFittingTestCase.evaluate_dataset_len, 1)
        dataset = TensorDataset(x, y)
        generator = DataLoader(dataset, ModelFittingTestCase.batch_size)

        _, _ = self.model.evaluate_generator(generator, verbose=True)

        self.assertStdoutContains(["%", "[32m", "[35m", "[36m", "[94m", "\u2588"])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_generator_with_progress_bar_coloring(self):
        """evaluate_generator(verbose=True) on a plain generator prints a colored progress bar.

        The test is skipped without colorama, consistent with the other tests
        that assert color codes.
        """
        generator = some_data_tensor_generator(ModelFittingTestCase.batch_size)

        _, _ = self.model.evaluate_generator(generator, steps=ModelFittingTestCaseProgress.num_steps, verbose=True)

        self.assertStdoutContains(["%", "[32m", "[35m", "[36m", "[94m", "\u2588"])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_generator_with_callback_and_progress_bar_coloring(self):
        """Passing a callback to evaluate_generator does not suppress the colored progress bar.

        The test is skipped without colorama, consistent with the other tests
        that assert color codes.
        """
        generator = some_data_tensor_generator(ModelFittingTestCase.batch_size)

        _, _ = self.model.evaluate_generator(generator,
                                             steps=ModelFittingTestCaseProgress.num_steps,
                                             callbacks=[self.mock_callback],
                                             verbose=True)

        self.assertStdoutContains(["%", "[32m", "[35m", "[36m", "[94m", "\u2588"])

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_complete_display_test_with_progress_bar_coloring(self):
        """Check each step update and the final summary of a 1-epoch fit that uses
        a user coloring (all WHITE, "[37m") with the progress bar enabled.

        The "|" delimiters around the bar are escaped. Unescaped, they are regex
        alternation, and every step assertion passed for any line containing
        "ETA:". ``progress_bar=True`` is passed because this test checks the
        percentage and the bar.
        """
        num_steps = ModelFittingTestCaseProgress.num_steps
        epoch = 1
        # we use the same color for all components for simplicity
        coloring = {
            "text_color": 'WHITE',
            "ratio_color": "WHITE",
            "metric_value_color": "WHITE",
            "time_color": "WHITE",
            "progress_bar_color": "WHITE"
        }
        _ = self.model.fit_generator(self.train_generator,
                                     self.valid_generator,
                                     epochs=1,
                                     steps_per_epoch=num_steps,
                                     validation_steps=num_steps,
                                     callbacks=[self.mock_callback],
                                     progress_options=dict(coloring=coloring, progress_bar=True))

        # We split per step update
        steps_update = self.test_out.getvalue().strip().split("\r")

        # We don't validate the templating of metrics since it is tested before.
        # ".*" is allowed between the percentage and "%" in case a color code
        # separates them. The bar length is not configured here, so instead of
        # an exact fill we require at least one filled block between the "|".
        template_format = (r".*Epoch:.*{epoch}\/1.*\[37mStep:.*{step}\/{steps}.*{percent:6.2f}.*%"
                           r".*\|[^|]*\u2588[^|]*\|.*ETA:")

        def assert_step_updates(updates):
            # The i-th update (1-based) reports step i of num_steps and the matching percentage.
            for step, step_update in enumerate(updates, start=1):
                regex_filled = template_format.format(epoch=epoch, step=step, steps=num_steps,
                                                      percent=step / num_steps * 100)
                self.assertRegex(step_update, regex_filled)

        # the train steps, then the validation steps; the last update is the epoch summary
        assert_step_updates(steps_update[:num_steps])
        assert_step_updates(steps_update[num_steps:-1])

        # last print update templating different
        last_print_regex = r".*\[37mTrain steps:.*{0}.*Val steps:.*{0}.*[0-9]*\.[0-9][0-9]s".format(num_steps)
        self.assertRegex(steps_update[-1], last_print_regex)

    def test_fitting_complete_display_test_with_progress_bar_no_coloring(self):
        """Check each step update and the final summary of an uncolored 1-epoch fit
        with the progress bar enabled.

        The "|" delimiters around the bar are escaped. Unescaped, they are regex
        alternation, and every step assertion passed for any line containing
        "ETA:".
        """
        num_steps = ModelFittingTestCaseProgress.num_steps
        epoch = 1
        _ = self.model.fit_generator(self.train_generator,
                                     self.valid_generator,
                                     epochs=1,
                                     steps_per_epoch=num_steps,
                                     validation_steps=num_steps,
                                     callbacks=[self.mock_callback],
                                     progress_options=dict(coloring=False, progress_bar=True))

        # We split per step update
        steps_update = self.test_out.getvalue().strip().split("\r")

        # We don't validate the templating of metrics since it is tested before.
        # The bar length is not configured here, so instead of an exact fill we
        # require at least one filled block between the "|" delimiters.
        template_format = (r".*Epoch:.*{epoch}\/1.*Step:.*{step}\/{steps}.*{percent:6.2f}.*%"
                           r".*\|[^|]*\u2588[^|]*\|.*ETA:")

        def assert_step_updates(updates):
            # The i-th update (1-based) reports step i of num_steps and the matching percentage.
            for step, step_update in enumerate(updates, start=1):
                regex_filled = template_format.format(epoch=epoch, step=step, steps=num_steps,
                                                      percent=step / num_steps * 100)
                self.assertRegex(step_update, regex_filled)

        # the train steps, then the validation steps; the last update is the epoch summary
        assert_step_updates(steps_update[:num_steps])
        assert_step_updates(steps_update[num_steps:-1])

        # last print update templating different
        last_print_regex = r".*Train steps:.*{0}.*Val steps:.*{0}.*[0-9]*\.[0-9][0-9]s".format(num_steps)
        self.assertRegex(steps_update[-1], last_print_regex)

    def test_fitting_complete_display_test_with_no_progress_bar_no_coloring(self):
        """Check each step update and the final summary of an uncolored 1-epoch fit without the progress bar.

        The template has no percentage field. The previous version still
        computed a percentage and passed it to ``str.format``, which silently
        ignored it.
        """
        num_steps = ModelFittingTestCaseProgress.num_steps
        epoch = 1
        _ = self.model.fit_generator(self.train_generator,
                                     self.valid_generator,
                                     epochs=1,
                                     steps_per_epoch=num_steps,
                                     validation_steps=num_steps,
                                     callbacks=[self.mock_callback],
                                     progress_options=dict(coloring=False, progress_bar=False))

        # We split per step update
        steps_update = self.test_out.getvalue().strip().split("\r")

        # we don't validate the templating of metrics since tested before
        template_format = r".*Epoch:.*{epoch}\/1.*Step:.*{step}\/{steps}.*ETA:"

        # the train steps, then the validation steps; the last update is the epoch summary
        for updates in (steps_update[:num_steps], steps_update[num_steps:-1]):
            for step, step_update in enumerate(updates, start=1):
                regex_filled = template_format.format(epoch=epoch, step=step, steps=num_steps)
                self.assertRegex(step_update, regex_filled)

        # last print update templating different
        last_print_regex = r".*Train steps:.*{0}.*Val steps:.*{0}.*[0-9]*\.[0-9][0-9]s".format(num_steps)
        self.assertRegex(steps_update[-1], last_print_regex)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_complete_display_test_with_progress_bar_coloring(self):
        """Check each step update and the final summary of evaluate_generator with
        an all-WHITE ("[37m") user coloring and the progress bar enabled.

        The "|" delimiters around the bar are escaped. Unescaped, they are regex
        alternation, and every step assertion passed for any line containing
        "ETA:". The test is skipped without colorama because it asserts "[37m".
        """
        num_steps = ModelFittingTestCaseProgress.num_steps
        # we use the same color for all components for simplicity
        coloring = {
            "text_color": 'WHITE',
            "ratio_color": "WHITE",
            "metric_value_color": "WHITE",
            "time_color": "WHITE",
            "progress_bar_color": "WHITE"
        }

        _, _ = self.model.evaluate_generator(self.test_generator,
                                             steps=num_steps,
                                             callbacks=[self.mock_callback],
                                             verbose=True,
                                             progress_options=dict(coloring=coloring, progress_bar=True))

        # We split per step update
        steps_update = self.test_out.getvalue().strip().split("\r")

        # We don't validate the templating of metrics since it is tested before.
        # ".*" is allowed between the percentage and "%" in case a color code
        # separates them. The bar length is not configured here, so instead of
        # an exact fill we require at least one filled block between the "|".
        template_format = (r".*\[37mStep:.*{step}\/{steps}.*{percent:6.2f}.*%"
                           r".*\|[^|]*\u2588[^|]*\|.*ETA:")
        # every update except the last (the summary) is a step update
        for step, step_update in enumerate(steps_update[:-1], start=1):
            regex_filled = template_format.format(step=step, steps=num_steps, percent=step / num_steps * 100)
            self.assertRegex(step_update, regex_filled)

        # last print update templating different
        last_print_regex = r".*\[37mTest steps:.*{0}.*[0-9]*\.[0-9][0-9]s".format(num_steps)
        self.assertRegex(steps_update[-1], last_print_regex)

    def test_evaluate_complete_display_test_with_progress_bar_no_coloring(self):
        """Check each step update and the final summary of an uncolored
        evaluate_generator run with the progress bar enabled.

        The "|" delimiters around the bar are escaped. Unescaped, they are regex
        alternation, and every step assertion passed for any line containing
        "ETA:".
        """
        num_steps = ModelFittingTestCaseProgress.num_steps
        _, _ = self.model.evaluate_generator(self.test_generator,
                                             steps=num_steps,
                                             callbacks=[self.mock_callback],
                                             verbose=True,
                                             progress_options=dict(coloring=False, progress_bar=True))

        # We split per step update
        steps_update = self.test_out.getvalue().strip().split("\r")

        # We don't validate the templating of metrics since it is tested before.
        # The bar length is not configured here, so instead of an exact fill we
        # require at least one filled block between the "|" delimiters.
        template_format = (r".*Step:.*{step}\/{steps}.*{percent:6.2f}.*%"
                           r".*\|[^|]*\u2588[^|]*\|.*ETA:")
        # every update except the last (the summary) is a step update
        for step, step_update in enumerate(steps_update[:-1], start=1):
            regex_filled = template_format.format(step=step, steps=num_steps, percent=step / num_steps * 100)
            self.assertRegex(step_update, regex_filled)

        # last print update templating different
        last_print_regex = r".*Test steps:.*{0}.*[0-9]*\.[0-9][0-9]s".format(num_steps)
        self.assertRegex(steps_update[-1], last_print_regex)

    def test_evaluate_complete_display_test_with_no_progress_bar_no_coloring(self):
        """Check each step update and the final summary of an uncolored evaluate_generator run without the bar.

        The template has no percentage field. The previous version still passed
        a computed percentage to ``str.format``, which silently ignored it.
        """
        num_steps = ModelFittingTestCaseProgress.num_steps
        _, _ = self.model.evaluate_generator(self.test_generator,
                                             steps=num_steps,
                                             callbacks=[self.mock_callback],
                                             verbose=True,
                                             progress_options=dict(coloring=False, progress_bar=False))

        # We split per step update
        steps_update = self.test_out.getvalue().strip().split("\r")

        # we don't validate the templating of metrics since tested before
        template_format = r".*Step:.*{step}\/{steps}.*ETA:"
        # every update except the last (the summary) is a step update
        for step, step_update in enumerate(steps_update[:-1], start=1):
            regex_filled = template_format.format(step=step, steps=num_steps)
            self.assertRegex(step_update, regex_filled)

        # last print update templating different
        last_print_regex = r".*Test steps:.*{0}.*[0-9]*\.[0-9][0-9]s".format(num_steps)
        self.assertRegex(steps_update[-1], last_print_regex)
コード例 #8
0
class ModelTest(ModelFittingTestCase):
    # pylint: disable=too-many-public-methods

    def setUp(self):
        """Build a seeded single-feature linear network trained with MSE, wrapped in a Model with metrics."""
        super().setUp()
        torch.manual_seed(42)
        self.pytorch_network = nn.Linear(1, 1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.pytorch_network.parameters(), lr=1e-3)

        # Rows: (metric spec given to Model, name expected in logs, value expected in logs).
        # The metric passed twice is expected under numbered names.
        batch_metric_table = [
            (some_batch_metric_1, 'some_batch_metric_1', some_metric_1_value),
            (('custom_name', some_batch_metric_2), 'custom_name', some_metric_2_value),
            (repeat_batch_metric, 'repeat_batch_metric1', repeat_batch_metric_value),
            (repeat_batch_metric, 'repeat_batch_metric2', repeat_batch_metric_value),
        ]
        self.batch_metrics = [spec for spec, _, _ in batch_metric_table]
        self.batch_metrics_names = [name for _, name, _ in batch_metric_table]
        self.batch_metrics_values = [value for _, _, value in batch_metric_table]

        self.epoch_metrics = [SomeConstantEpochMetric()]
        self.epoch_metrics_names = ['some_constant_epoch_metric']
        self.epoch_metrics_values = [some_constant_epoch_metric_value]

        self.model = Model(
            self.pytorch_network,
            self.optimizer,
            self.loss_function,
            batch_metrics=self.batch_metrics,
            epoch_metrics=self.epoch_metrics,
        )

    def test_fitting_tensor_generator(self):
        """Fit on tensor generators; callbacks and logs are checked by ``_test_callbacks_train``."""
        steps = ModelTest.steps_per_epoch
        train_generator = some_data_tensor_generator(ModelTest.batch_size)
        valid_generator = some_data_tensor_generator(ModelTest.batch_size)
        logs = self.model.fit_generator(train_generator, valid_generator,
                                        epochs=ModelTest.epochs,
                                        steps_per_epoch=steps,
                                        validation_steps=steps,
                                        callbacks=[self.mock_callback])
        expected_params = dict(epochs=ModelTest.epochs, steps=steps, valid_steps=steps)
        self._test_callbacks_train(expected_params, logs, valid_steps=steps)

    def test_fitting_without_valid_generator(self):
        """Fit with ``None`` as validation generator; callbacks are checked with ``has_valid=False``."""
        train_generator = some_data_tensor_generator(ModelTest.batch_size)
        logs = self.model.fit_generator(train_generator, None,
                                        epochs=ModelTest.epochs,
                                        steps_per_epoch=ModelTest.steps_per_epoch,
                                        callbacks=[self.mock_callback])
        expected_params = dict(epochs=ModelTest.epochs, steps=ModelTest.steps_per_epoch)
        self._test_callbacks_train(expected_params, logs, has_valid=False)

    def test_correct_optim_calls_1_batch_per_step(self):
        """One step of one batch calls optimizer.step() and optimizer.zero_grad() exactly once each."""
        train_generator = some_data_tensor_generator(ModelTest.batch_size)

        optimizer_mock = some_mocked_optimizer()
        model_with_mock = Model(self.pytorch_network,
                                optimizer_mock,
                                self.loss_function,
                                batch_metrics=self.batch_metrics,
                                epoch_metrics=self.epoch_metrics)
        model_with_mock.fit_generator(train_generator, None,
                                      epochs=1, steps_per_epoch=1, batches_per_step=1)

        for mocked_method in (optimizer_mock.step, optimizer_mock.zero_grad):
            self.assertEqual(1, mocked_method.call_count)

    def test_correct_optim_calls__valid_n_batches_per_step(self):
        """Accumulating all n_batches batches into one step gives exactly one step() and one zero_grad() call."""
        n_batches = 5
        items_per_batch = int(ModelTest.batch_size / n_batches)

        x = torch.rand(n_batches, items_per_batch, 1)
        y = torch.rand(n_batches, items_per_batch, 1)

        optimizer_mock = some_mocked_optimizer()
        model_with_mock = Model(self.pytorch_network,
                                optimizer_mock,
                                self.loss_function,
                                batch_metrics=self.batch_metrics,
                                epoch_metrics=self.epoch_metrics)
        # Each element of the list is one (inputs, targets) batch.
        batches = list(zip(x, y))
        model_with_mock.fit_generator(batches, None, epochs=1, batches_per_step=n_batches)

        for mocked_method in (optimizer_mock.step, optimizer_mock.zero_grad):
            self.assertEqual(1, mocked_method.call_count)

    def test_fitting_generator_n_batches_per_step(self):
        """Splitting one batch into equal mini-batches that are accumulated over a
        single step gives the same weights as one step on the whole batch.

        Each mini-batch size must divide ``total_batch_size``. The previous list
        contained 5: ``int(6 / 5) == 1`` and ``resize_((1, 5, 1))`` silently kept
        only the first 5 of the 6 samples, so that case trained on different data
        than the reference fit. Uneven splits are covered by
        ``test_fitting_generator_n_batches_per_step_uneven_batches``.
        """
        total_batch_size = 6

        x = torch.rand(1, total_batch_size, 1)
        y = torch.rand(1, total_batch_size, 1)

        initial_params = self.model.get_weight_copies()

        self.model.fit_generator(list(zip(x, y)),
                                 None,
                                 epochs=1,
                                 batches_per_step=1)

        expected_params = self.model.get_weight_copies()

        for mini_batch_size in [1, 2, 3]:
            self.model.set_weights(initial_params)

            n_batches_per_step = total_batch_size // mini_batch_size

            # reshape (unlike the in-place resize_) keeps every sample and leaves x/y untouched
            split_x = x.reshape(n_batches_per_step, mini_batch_size, 1)
            split_y = y.reshape(n_batches_per_step, mini_batch_size, 1)

            self.model.fit_generator(list(zip(split_x, split_y)),
                                     None,
                                     epochs=1,
                                     batches_per_step=n_batches_per_step)

            returned_params = self.model.get_weight_copies()

            self.assertEqual(returned_params.keys(), expected_params.keys())
            for k in expected_params.keys():
                np.testing.assert_almost_equal(returned_params[k].numpy(),
                                               expected_params[k].numpy(),
                                               decimal=4)

    def test_fitting_generator_n_batches_per_step_higher_than_num_batches(
            self):
        """batches_per_step larger than the number of available batches gives the same weights as batches_per_step=1."""
        total_batch_size = 6

        x = torch.rand(1, total_batch_size, 1)
        y = torch.rand(1, total_batch_size, 1)
        batches = list(zip(x, y))

        initial_params = self.model.get_weight_copies()

        self.model.fit_generator(batches, None, epochs=1, batches_per_step=1)
        expected_params = self.model.get_weight_copies()

        # Restart from the same weights, now asking for 2 batches per step with only 1 batch available.
        self.model.set_weights(initial_params)
        self.model.fit_generator(list(zip(x, y)), None, epochs=1, batches_per_step=2)
        returned_params = self.model.get_weight_copies()

        self.assertEqual(returned_params.keys(), expected_params.keys())
        for name, expected in expected_params.items():
            np.testing.assert_almost_equal(returned_params[name].numpy(),
                                           expected.numpy(),
                                           decimal=4)

    def test_fitting_generator_n_batches_per_step_uneven_batches(self):
        """Accumulating unevenly sized chunks of one batch gives the same weights as one full-batch step."""
        total_batch_size = 6

        x = torch.rand(1, total_batch_size, 1)
        y = torch.rand(1, total_batch_size, 1)

        initial_params = self.model.get_weight_copies()

        self.model.fit_generator(list(zip(x, y)), None, epochs=1, batches_per_step=1)
        expected_params = self.model.get_weight_copies()

        # Drop the leading batch dimension in place: x and y become (total_batch_size, 1).
        x.squeeze_(dim=0)
        y.squeeze_(dim=0)

        for chunk_size in (4, 5):
            self.model.set_weights(initial_params)

            # The last chunk is smaller than chunk_size, hence ceil for the number of batches.
            chunks = list(zip(x.split(chunk_size), y.split(chunk_size)))
            self.model.fit_generator(chunks,
                                     None,
                                     epochs=1,
                                     batches_per_step=ceil(total_batch_size / chunk_size))

            returned_params = self.model.get_weight_copies()

            self.assertEqual(returned_params.keys(), expected_params.keys())
            for name, expected in expected_params.items():
                np.testing.assert_almost_equal(returned_params[name].numpy(),
                                               expected.numpy(),
                                               decimal=4)

    def test_fitting_ndarray_generator(self):
        """Fit on NumPy ndarray generators; callbacks and logs are checked by ``_test_callbacks_train``."""
        steps = ModelTest.steps_per_epoch
        train_generator = some_ndarray_generator(ModelTest.batch_size)
        valid_generator = some_ndarray_generator(ModelTest.batch_size)
        logs = self.model.fit_generator(train_generator, valid_generator,
                                        epochs=ModelTest.epochs,
                                        steps_per_epoch=steps,
                                        validation_steps=steps,
                                        callbacks=[self.mock_callback])
        expected_params = dict(epochs=ModelTest.epochs, steps=steps, valid_steps=steps)
        self._test_callbacks_train(expected_params, logs, valid_steps=steps)

    def test_fitting_with_data_loader(self):
        """With steps left as None, fit_generator takes the step counts from DataLoaders whose last batch is short."""
        def make_loader(real_steps, batch_size, missing_samples):
            # The final batch lacks missing_samples items. Assuming missing_samples < batch_size,
            # the loader still yields real_steps batches.
            size = real_steps * batch_size - missing_samples
            inputs = torch.rand(size, 1)
            targets = torch.rand(size, 1)
            return DataLoader(TensorDataset(inputs, targets), batch_size)

        train_real_steps_per_epoch = 30
        valid_real_steps_per_epoch = 10
        train_generator = make_loader(train_real_steps_per_epoch, ModelTest.batch_size, 7)
        valid_generator = make_loader(valid_real_steps_per_epoch, 15, 3)

        logs = self.model.fit_generator(train_generator,
                                        valid_generator,
                                        epochs=ModelTest.epochs,
                                        steps_per_epoch=None,
                                        validation_steps=None,
                                        callbacks=[self.mock_callback])
        expected_params = dict(epochs=ModelTest.epochs,
                               steps=train_real_steps_per_epoch,
                               valid_steps=valid_real_steps_per_epoch)
        self._test_callbacks_train(expected_params, logs)

    def test_fitting_generator_calls(self):
        """fit_generator calls __len__ once, then __iter__ plus one __next__ per batch each epoch, on both loaders."""
        def make_loader(real_steps, batch_size, missing_samples):
            # The final batch lacks missing_samples items (assumed fewer than batch_size).
            size = real_steps * batch_size - missing_samples
            inputs = torch.rand(size, 1)
            targets = torch.rand(size, 1)
            return DataLoader(TensorDataset(inputs, targets), batch_size)

        def expected_calls(real_steps):
            epoch_calls = ['__iter__'] + ['__next__'] * real_steps
            return ['__len__'] + epoch_calls * ModelTest.epochs

        train_real_steps_per_epoch = 30
        valid_real_steps_per_epoch = 10
        mock_train_generator = IterableMock(make_loader(train_real_steps_per_epoch, ModelTest.batch_size, 7))
        mock_valid_generator = IterableMock(make_loader(valid_real_steps_per_epoch, 15, 3))

        self.model.fit_generator(mock_train_generator,
                                 mock_valid_generator,
                                 epochs=ModelTest.epochs)

        self.assertEqual(mock_train_generator.calls, expected_calls(train_real_steps_per_epoch))
        self.assertEqual(mock_valid_generator.calls, expected_calls(valid_real_steps_per_epoch))

    def test_fitting_generator_calls_with_longer_validation_set(self):
        """Same call pattern as test_fitting_generator_calls when the validation loader has more batches than training."""
        def make_loader(real_steps, batch_size, missing_samples):
            # The final batch lacks missing_samples items (assumed fewer than batch_size).
            size = real_steps * batch_size - missing_samples
            inputs = torch.rand(size, 1)
            targets = torch.rand(size, 1)
            return DataLoader(TensorDataset(inputs, targets), batch_size)

        def expected_calls(real_steps):
            epoch_calls = ['__iter__'] + ['__next__'] * real_steps
            return ['__len__'] + epoch_calls * ModelTest.epochs

        train_real_steps_per_epoch = 30
        valid_real_steps_per_epoch = 40
        mock_train_generator = IterableMock(make_loader(train_real_steps_per_epoch, ModelTest.batch_size, 7))
        mock_valid_generator = IterableMock(make_loader(valid_real_steps_per_epoch, 15, 3))

        self.model.fit_generator(mock_train_generator,
                                 mock_valid_generator,
                                 epochs=ModelTest.epochs)

        self.assertEqual(mock_train_generator.calls, expected_calls(train_real_steps_per_epoch))
        self.assertEqual(mock_valid_generator.calls, expected_calls(valid_real_steps_per_epoch))

    def test_fitting_with_tensor(self):
        """fit() on raw tensors computes step counts from batch_size, including the short last batch."""
        batch_size = ModelTest.batch_size
        # fit() uses the same batch_size for the validation data.
        train_real_steps_per_epoch, train_missing = 30, 7
        valid_real_steps_per_epoch, valid_missing = 10, 3

        train_size = train_real_steps_per_epoch * batch_size - train_missing
        train_x = torch.rand(train_size, 1)
        train_y = torch.rand(train_size, 1)

        valid_size = valid_real_steps_per_epoch * batch_size - valid_missing
        valid_x = torch.rand(valid_size, 1)
        valid_y = torch.rand(valid_size, 1)

        logs = self.model.fit(train_x,
                              train_y,
                              validation_data=(valid_x, valid_y),
                              epochs=ModelTest.epochs,
                              batch_size=batch_size,
                              steps_per_epoch=None,
                              validation_steps=None,
                              callbacks=[self.mock_callback])
        expected_params = dict(epochs=ModelTest.epochs,
                               steps=train_real_steps_per_epoch,
                               valid_steps=valid_real_steps_per_epoch)
        self._test_callbacks_train(expected_params, logs)

    def test_fitting_with_np_array(self):
        """Fit on float32 NumPy arrays and check the step counts given to callbacks.

        Mirrors the tensor variant: the last batch of both the training and the
        validation data is deliberately incomplete.
        """

        def random_column(num_rows):
            # One float32 feature column of uniform random values.
            return np.random.rand(num_rows, 1).astype(np.float32)

        batch_size = ModelTest.batch_size
        num_train_steps, num_valid_steps = 30, 10

        num_train_samples = num_train_steps * batch_size - 7
        train_inputs = random_column(num_train_samples)
        train_targets = random_column(num_train_samples)

        # fit() reuses the training batch size for the validation data.
        num_valid_samples = num_valid_steps * batch_size - 3
        valid_inputs = random_column(num_valid_samples)
        valid_targets = random_column(num_valid_samples)

        history = self.model.fit(
            train_inputs,
            train_targets,
            validation_data=(valid_inputs, valid_targets),
            epochs=ModelTest.epochs,
            batch_size=batch_size,
            steps_per_epoch=None,
            validation_steps=None,
            callbacks=[self.mock_callback],
        )
        expected_params = dict(
            epochs=ModelTest.epochs,
            steps=num_train_steps,
            valid_steps=num_valid_steps,
        )
        self._test_callbacks_train(expected_params, history)

    def test_fitting_with_generator_with_len(self):
        """Fit on generators exposing ``__len__`` and check the inferred step counts.

        ``steps_per_epoch`` and ``validation_steps`` are left as ``None`` so the
        counts must come from the generators' declared lengths.
        """
        num_train_steps = 30
        num_valid_steps = 10
        train_source = SomeDataGeneratorWithLen(batch_size=ModelTest.batch_size,
                                                length=num_train_steps,
                                                num_missing_samples=7)
        valid_source = SomeDataGeneratorWithLen(batch_size=15,
                                                length=num_valid_steps,
                                                num_missing_samples=3)

        history = self.model.fit_generator(train_source,
                                           valid_source,
                                           epochs=ModelTest.epochs,
                                           steps_per_epoch=None,
                                           validation_steps=None,
                                           callbacks=[self.mock_callback])

        expected_params = dict(epochs=ModelTest.epochs,
                               steps=num_train_steps,
                               valid_steps=num_valid_steps)
        self._test_callbacks_train(expected_params, history)

    def test_fitting_with_generator_with_stop_iteration(self):
        """Fit on generators whose end is signalled only by ``StopIteration``.

        The expected callback params carry ``steps=None``; the number of steps
        actually run is checked separately through the ``steps`` keyword.
        """
        num_train_steps = 30
        train_source = SomeDataGeneratorUsingStopIteration(
            batch_size=ModelTest.batch_size, length=num_train_steps)
        valid_source = SomeDataGeneratorUsingStopIteration(batch_size=15, length=10)

        history = self.model.fit_generator(train_source,
                                           valid_source,
                                           epochs=ModelTest.epochs,
                                           steps_per_epoch=None,
                                           validation_steps=None,
                                           callbacks=[self.mock_callback])

        expected_params = dict(epochs=ModelTest.epochs, steps=None)
        self._test_callbacks_train(expected_params, history, steps=num_train_steps)

    def test_tensor_train_on_batch(self):
        """``train_on_batch`` on tensors returns a float loss and an array of batch metrics."""
        batch_shape = (ModelTest.batch_size, 1)
        inputs = torch.rand(*batch_shape)
        targets = torch.rand(*batch_shape)

        loss, batch_metrics = self.model.train_on_batch(inputs, targets)

        self.assertEqual(type(loss), float)
        self.assertEqual(type(batch_metrics), np.ndarray)
        self.assertEqual(batch_metrics.tolist(), self.batch_metrics_values)

    def test_train_on_batch_with_pred(self):
        """``train_on_batch(..., return_pred=True)`` also returns one prediction per sample."""
        batch_shape = (ModelTest.batch_size, 1)
        inputs = torch.rand(*batch_shape)
        targets = torch.rand(*batch_shape)

        loss, batch_metrics, predictions = self.model.train_on_batch(
            inputs, targets, return_pred=True)

        self.assertEqual(type(loss), float)
        self.assertEqual(type(batch_metrics), np.ndarray)
        self.assertEqual(batch_metrics.tolist(), self.batch_metrics_values)
        self.assertEqual(predictions.shape, batch_shape)

    def test_ndarray_train_on_batch(self):
        """``train_on_batch`` accepts float32 NumPy arrays as well as tensors."""
        batch_shape = (ModelTest.batch_size, 1)
        inputs = np.random.rand(*batch_shape).astype(np.float32)
        targets = np.random.rand(*batch_shape).astype(np.float32)

        loss, batch_metrics = self.model.train_on_batch(inputs, targets)

        self.assertEqual(type(loss), float)
        self.assertEqual(type(batch_metrics), np.ndarray)
        self.assertEqual(batch_metrics.tolist(), self.batch_metrics_values)

    def test_evaluate(self):
        """``evaluate`` returns a float loss and the batch metrics followed by the epoch metrics."""
        num_samples = ModelTest.evaluate_dataset_len
        inputs = torch.rand(num_samples, 1)
        targets = torch.rand(num_samples, 1)

        loss, metrics = self.model.evaluate(inputs, targets, batch_size=ModelTest.batch_size)

        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertEqual(type(loss), float)
        self.assertEqual(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), expected_metrics)

    def test_evaluate_with_pred(self):
        """``evaluate(..., return_pred=True)`` returns exactly three values, the last being predictions.

        The strict three-way unpacking is itself part of what is tested.
        """
        num_samples = ModelTest.evaluate_dataset_len
        inputs = torch.rand(num_samples, 1)
        targets = torch.rand(num_samples, 1)

        _, _, predictions = self.model.evaluate(inputs,
                                                targets,
                                                batch_size=ModelTest.batch_size,
                                                return_pred=True)

        self.assertEqual(predictions.shape, (num_samples, 1))

    def test_evaluate_with_callback(self):
        """``evaluate`` accepts a ``callbacks`` list alongside ``return_pred``.

        Only the shape of the returned predictions is checked here; the result
        must still unpack into exactly three values.
        """
        num_samples = ModelTest.evaluate_dataset_len
        inputs = torch.rand(num_samples, 1)
        targets = torch.rand(num_samples, 1)

        _, _, predictions = self.model.evaluate(inputs,
                                                targets,
                                                batch_size=ModelTest.batch_size,
                                                return_pred=True,
                                                callbacks=[self.mock_callback])

        self.assertEqual(predictions.shape, (num_samples, 1))

    def test_evaluate_with_return_dict(self):
        """``evaluate(..., return_dict_format=True)`` returns logs checked by ``_test_return_dict_logs``."""
        num_samples = ModelTest.evaluate_dataset_len
        inputs = torch.rand(num_samples, 1)
        targets = torch.rand(num_samples, 1)

        result_logs = self.model.evaluate(inputs,
                                          targets,
                                          batch_size=ModelTest.batch_size,
                                          return_dict_format=True)

        self._test_return_dict_logs(result_logs)

    def test_evaluate_with_np_array(self):
        """``evaluate`` on float32 NumPy arrays returns loss, all metrics and predictions."""
        num_samples = ModelTest.evaluate_dataset_len
        inputs = np.random.rand(num_samples, 1).astype(np.float32)
        targets = np.random.rand(num_samples, 1).astype(np.float32)

        loss, metrics, predictions = self.model.evaluate(inputs,
                                                         targets,
                                                         batch_size=ModelTest.batch_size,
                                                         return_pred=True)

        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertEqual(type(loss), float)
        self.assertEqual(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), expected_metrics)
        self.assertEqual(predictions.shape, (num_samples, 1))

    def test_evaluate_data_loader(self):
        """``evaluate_generator`` consumes a ``DataLoader`` and returns per-sample predictions."""
        num_samples = ModelTest.evaluate_dataset_len
        dataset = TensorDataset(torch.rand(num_samples, 1), torch.rand(num_samples, 1))
        loader = DataLoader(dataset, ModelTest.batch_size)

        loss, metrics, predictions = self.model.evaluate_generator(loader, return_pred=True)

        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertEqual(type(loss), float)
        self.assertEqual(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), expected_metrics)
        self.assertEqual(predictions.shape, (num_samples, 1))

    def test_evaluate_generator(self):
        """Evaluate a fixed number of steps from a plain generator and check the concatenated predictions."""
        num_steps = 10
        data_stream = some_data_tensor_generator(ModelTest.batch_size)

        loss, metrics, predictions = self.model.evaluate_generator(data_stream,
                                                                   steps=num_steps,
                                                                   return_pred=True)

        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertEqual(type(loss), float)
        self.assertEqual(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), expected_metrics)
        self.assertEqual(type(predictions), np.ndarray)
        self.assertEqual(predictions.shape, (num_steps * ModelTest.batch_size, 1))

    def test_evaluate_generator_with_stop_iteration(self):
        """Evaluation runs until the generator raises ``StopIteration`` when ``steps`` is omitted."""
        finite_source = SomeDataGeneratorUsingStopIteration(ModelTest.batch_size, 10)

        loss, _ = self.model.evaluate_generator(finite_source)

        self.assertEqual(type(loss), float)

    def test_evaluate_generator_with_callback(self):
        """Evaluating a generator for a fixed number of steps reports to the callbacks.

        The expected callback ``params`` must describe the evaluation that was
        actually run, so ``steps`` is the value passed as ``steps=`` below.
        """
        num_steps = 10
        generator = some_data_tensor_generator(ModelTest.batch_size)
        self.model.evaluate_generator(generator,
                                      steps=num_steps,
                                      callbacks=[self.mock_callback])

        # Use num_steps, the number of evaluation steps requested above, rather
        # than ModelTest.epochs, which is a training setting unrelated to this run.
        params = {'steps': num_steps}
        self._test_callbacks_test(params)

    def test_evaluate_generator_with_return_dict(self):
        """``evaluate_generator(..., return_dict_format=True)`` returns logs checked by ``_test_return_dict_logs``."""
        num_steps = 10
        data_stream = some_data_tensor_generator(ModelTest.batch_size)

        result_logs = self.model.evaluate_generator(data_stream,
                                                    steps=num_steps,
                                                    return_dict_format=True)

        self._test_return_dict_logs(result_logs)

    def test_evaluate_generator_with_ground_truth(self):
        """With ``return_ground_truth=True``, targets are returned concatenated like the predictions."""
        num_steps = 10
        data_stream = some_data_tensor_generator(ModelTest.batch_size)

        loss, metrics, pred_y, true_y = self.model.evaluate_generator(data_stream,
                                                                      steps=num_steps,
                                                                      return_pred=True,
                                                                      return_ground_truth=True)

        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertEqual(type(loss), float)
        self.assertEqual(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), expected_metrics)

        # Predictions and targets both span every sample of every step.
        expected_shape = (num_steps * ModelTest.batch_size, 1)
        for returned in (pred_y, true_y):
            self.assertEqual(type(returned), np.ndarray)
            self.assertEqual(returned.shape, expected_shape)

    def test_evaluate_generator_with_no_concatenation(self):
        """With ``concatenate_returns=False``, predictions and targets come back as per-batch lists."""
        num_steps = 10
        data_stream = some_data_tensor_generator(ModelTest.batch_size)

        loss, metrics, pred_y, true_y = self.model.evaluate_generator(data_stream,
                                                                      steps=num_steps,
                                                                      return_pred=True,
                                                                      return_ground_truth=True,
                                                                      concatenate_returns=False)

        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertEqual(type(loss), float)
        self.assertEqual(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), expected_metrics)

        # Each returned collection is a list holding one array per batch.
        for collected in (pred_y, true_y):
            self.assertEqual(type(collected), list)
            for chunk in collected:
                self.assertEqual(type(chunk), np.ndarray)
                self.assertEqual(chunk.shape, (ModelTest.batch_size, 1))

    def test_evaluate_with_only_one_metric(self):
        """With a single batch metric, ``evaluate`` returns that metric as a bare float."""
        single_metric_model = Model(self.pytorch_network,
                                    self.optimizer,
                                    self.loss_function,
                                    batch_metrics=self.batch_metrics[:1])
        num_samples = ModelTest.evaluate_dataset_len
        inputs = torch.rand(num_samples, 1)
        targets = torch.rand(num_samples, 1)

        loss, first_metric = single_metric_model.evaluate(inputs,
                                                          targets,
                                                          batch_size=ModelTest.batch_size)

        self.assertEqual(type(loss), float)
        self.assertEqual(type(first_metric), float)
        self.assertEqual(first_metric, some_metric_1_value)

    def test_metrics_integration(self):
        """A real loss function (``F.mse_loss``) used as a batch metric works through fit and evaluate."""
        batch_size = ModelTest.batch_size
        mse_model = Model(self.pytorch_network,
                          self.optimizer,
                          self.loss_function,
                          batch_metrics=[F.mse_loss])

        mse_model.fit_generator(some_data_tensor_generator(batch_size),
                                some_data_tensor_generator(batch_size),
                                epochs=ModelTest.epochs,
                                steps_per_epoch=ModelTest.steps_per_epoch,
                                validation_steps=ModelTest.steps_per_epoch,
                                callbacks=[self.mock_callback])

        num_steps = 10
        loss, mse = mse_model.evaluate_generator(some_data_tensor_generator(batch_size),
                                                 steps=num_steps)

        self.assertEqual(type(loss), float)
        self.assertEqual(type(mse), float)

    def test_epoch_metrics_integration(self):
        """An epoch metric is logged for both training and validation in the last epoch's logs."""
        epoch_metric_model = Model(self.pytorch_network,
                                   self.optimizer,
                                   self.loss_function,
                                   epoch_metrics=[SomeEpochMetric()])
        train_stream = some_data_tensor_generator(ModelTest.batch_size)
        valid_stream = some_data_tensor_generator(ModelTest.batch_size)

        history = epoch_metric_model.fit_generator(train_stream,
                                                   valid_stream,
                                                   epochs=1,
                                                   steps_per_epoch=ModelTest.steps_per_epoch,
                                                   validation_steps=ModelTest.steps_per_epoch)

        last_epoch_logs = history[-1]
        train_value = last_epoch_logs['some_epoch_metric']
        valid_value = last_epoch_logs['val_some_epoch_metric']
        # Value SomeEpochMetric is expected to produce (the fixture is defined elsewhere in the file).
        expected_value = 5
        self.assertEqual(valid_value, expected_value)
        self.assertEqual(train_value, expected_value)

    def test_evaluate_with_no_metric(self):
        """Without any metric, ``evaluate`` returns only the loss, as a float."""
        bare_model = Model(self.pytorch_network, self.optimizer, self.loss_function)
        num_samples = ModelTest.evaluate_dataset_len
        inputs = torch.rand(num_samples, 1)
        targets = torch.rand(num_samples, 1)

        loss = bare_model.evaluate(inputs, targets, batch_size=ModelTest.batch_size)

        self.assertEqual(type(loss), float)

    def test_tensor_evaluate_on_batch(self):
        """``evaluate_on_batch`` on tensors returns a float loss and an array of batch metrics."""
        batch_shape = (ModelTest.batch_size, 1)
        inputs = torch.rand(*batch_shape)
        targets = torch.rand(*batch_shape)

        loss, batch_metrics = self.model.evaluate_on_batch(inputs, targets)

        self.assertEqual(type(loss), float)
        self.assertEqual(type(batch_metrics), np.ndarray)
        self.assertEqual(batch_metrics.tolist(), self.batch_metrics_values)

    def test_evaluate_on_batch_with_pred(self):
        """``evaluate_on_batch(..., return_pred=True)`` also returns one prediction per sample."""
        batch_shape = (ModelTest.batch_size, 1)
        inputs = torch.rand(*batch_shape)
        targets = torch.rand(*batch_shape)

        loss, batch_metrics, predictions = self.model.evaluate_on_batch(
            inputs, targets, return_pred=True)

        self.assertEqual(type(loss), float)
        self.assertEqual(type(batch_metrics), np.ndarray)
        self.assertEqual(batch_metrics.tolist(), self.batch_metrics_values)
        self.assertEqual(predictions.shape, batch_shape)

    def test_ndarray_evaluate_on_batch(self):
        """``evaluate_on_batch`` accepts float32 NumPy arrays as well as tensors."""
        batch_shape = (ModelTest.batch_size, 1)
        inputs = np.random.rand(*batch_shape).astype(np.float32)
        targets = np.random.rand(*batch_shape).astype(np.float32)

        loss, batch_metrics = self.model.evaluate_on_batch(inputs, targets)

        self.assertEqual(type(loss), float)
        self.assertEqual(type(batch_metrics), np.ndarray)
        self.assertEqual(batch_metrics.tolist(), self.batch_metrics_values)

    def test_predict(self):
        """``predict`` on a tensor returns one prediction row per input sample."""
        num_samples = ModelTest.evaluate_dataset_len
        samples = torch.rand(num_samples, 1)
        predictions = self.model.predict(samples, batch_size=ModelTest.batch_size)
        self.assertEqual(predictions.shape, (num_samples, 1))

    def test_predict_with_np_array(self):
        """``predict`` on a float32 NumPy array returns a NumPy array of predictions."""
        num_samples = ModelTest.evaluate_dataset_len
        samples = np.random.rand(num_samples, 1).astype(np.float32)
        predictions = self.model.predict(samples, batch_size=ModelTest.batch_size)
        self.assertEqual(type(predictions), np.ndarray)
        self.assertEqual(predictions.shape, (num_samples, 1))

    def test_predict_data_loader(self):
        """``predict_generator`` consumes a ``DataLoader`` over a bare input tensor."""
        num_samples = ModelTest.evaluate_dataset_len
        # The tensor itself serves as the dataset: each item is one input row.
        loader = DataLoader(torch.rand(num_samples, 1), ModelTest.batch_size)
        predictions = self.model.predict_generator(loader)
        self.assertEqual(type(predictions), np.ndarray)
        self.assertEqual(predictions.shape, (num_samples, 1))

    def test_predict_generator(self):
        """``predict_generator`` on an inputs-only generator concatenates predictions over all steps."""
        num_steps = 10
        # Drop the targets: prediction only consumes inputs.
        inputs_only = (batch_x for batch_x, _ in some_data_tensor_generator(ModelTest.batch_size))
        predictions = self.model.predict_generator(inputs_only, steps=num_steps)
        self.assertEqual(type(predictions), np.ndarray)
        self.assertEqual(predictions.shape, (num_steps * ModelTest.batch_size, 1))

    def test_predict_generator_with_no_concatenation(self):
        """With ``concatenate_returns=False``, ``predict_generator`` returns one array per batch."""
        num_steps = 10
        # Drop the targets: prediction only consumes inputs.
        inputs_only = (batch_x for batch_x, _ in some_data_tensor_generator(ModelTest.batch_size))
        per_batch_predictions = self.model.predict_generator(inputs_only,
                                                             steps=num_steps,
                                                             concatenate_returns=False)
        self.assertEqual(type(per_batch_predictions), list)
        for chunk in per_batch_predictions:
            self.assertEqual(type(chunk), np.ndarray)
            self.assertEqual(chunk.shape, (ModelTest.batch_size, 1))

    def test_tensor_predict_on_batch(self):
        """``predict_on_batch`` on a tensor returns one prediction row per sample."""
        batch_shape = (ModelTest.batch_size, 1)
        predictions = self.model.predict_on_batch(torch.rand(*batch_shape))
        self.assertEqual(predictions.shape, batch_shape)

    def test_ndarray_predict_on_batch(self):
        """``predict_on_batch`` accepts a float32 NumPy array as input."""
        batch_shape = (ModelTest.batch_size, 1)
        samples = np.random.rand(*batch_shape).astype(np.float32)
        predictions = self.model.predict_on_batch(samples)
        self.assertEqual(predictions.shape, batch_shape)

    @skipIf(not torch.cuda.is_available(), "no gpu available")
    def test_cpu_cuda(self):
        """Train while moving the model between the configured GPU and the CPU.

        After the initial ``cuda()`` run, every move is followed by a
        ``_test_device`` check and a full ``fit_generator`` run on the same
        pair of generators.
        """
        train_stream = some_data_tensor_generator(ModelTest.batch_size)
        valid_stream = some_data_tensor_generator(ModelTest.batch_size)
        gpu_device = torch.device(f'cuda:{ModelTest.cuda_device}')
        cpu_device = torch.device('cpu')

        def fit_once():
            # One full training run with the shared settings used throughout this test.
            self.model.fit_generator(train_stream,
                                     valid_stream,
                                     epochs=ModelTest.epochs,
                                     steps_per_epoch=ModelTest.steps_per_epoch,
                                     validation_steps=ModelTest.steps_per_epoch,
                                     callbacks=[self.mock_callback])

        self._capture_output()

        with torch.cuda.device(ModelTest.cuda_device):
            self.model.cuda()
            fit_once()

        # The context manager is also used here because of this bug:
        # https://github.com/pytorch/pytorch/issues/7320
        with torch.cuda.device(ModelTest.cuda_device):
            self.model.cuda(ModelTest.cuda_device)
            self._test_device(gpu_device)
            fit_once()

            # Each entry: how to move the model, and where it should end up.
            transitions = (
                (lambda: self.model.cpu(), cpu_device),
                (lambda: self.model.to(gpu_device), gpu_device),
                (lambda: self.model.to(cpu_device), cpu_device),
            )
            for move_model, expected_device in transitions:
                move_model()
                self._test_device(expected_device)
                fit_once()

    def test_get_batch_size(self):
        """Check ``get_batch_size`` against a table of input/target structures.

        Covers plain arrays, tuples, lists, dicts carrying an explicit
        ``'batch_size'`` entry, ``OrderedDict`` inputs and non-array inputs.
        Two array sizes that differ by one are mixed to pin which element the
        inferred size is taken from.
        """
        size_a = ModelTest.batch_size
        x_a = np.random.rand(size_a, 1).astype(np.float32)
        y_a = np.random.rand(size_a, 1).astype(np.float32)

        size_b = ModelTest.batch_size + 1
        x_b = np.random.rand(size_b, 1).astype(np.float32)
        y_b = np.random.rand(size_b, 1).astype(np.float32)

        # A size matching neither array, passed explicitly through a dict key.
        explicit_size = size_b + 1

        # ((inputs, targets), expected inferred batch size)
        cases = [
            ((x_a, y_a), size_a),
            ((x_b, y_b), size_b),
            ((x_a, y_b), size_a),
            ((x_b, y_a), size_b),
            (((x_a, x_b), y_a), size_a),
            (((x_b, x_a), y_a), size_a),
            (((x_a, x_b), (y_a, y_b)), size_a),
            (((x_b, x_a), (y_a, y_b)), size_b),
            (([x_a, x_b], y_a), size_a),
            (([x_b, x_a], y_a), size_a),
            (([x_a, x_b], [y_a, y_b]), size_a),
            (([x_b, x_a], [y_a, y_b]), size_b),
            (({'batch_size': explicit_size, 'x': x_a}, {'y': y_a}), explicit_size),
            (({'x': x_a}, {'batch_size': explicit_size, 'y': y_a}), explicit_size),
            (({'x': x_a}, {'y': y_a}), size_a),
            ((OrderedDict([('x1', x_a), ('x2', x_b)]), {'y': y_a}), size_a),
            ((OrderedDict([('x1', x_b), ('x2', x_a)]), {'y': y_a}), size_b),
            (([1, 2, 3], {'y': y_a}), size_a),
        ]

        for (inputs, targets), expected_size in cases:
            self.assertEqual(self.model.get_batch_size(inputs, targets), expected_size)

    def test_get_batch_size_warning(self):
        """Check the warning emitted when no batch size can be inferred.

        With non-array inputs, ``get_batch_size`` returns 1. It emits exactly
        one warning by default and none once ``warning_settings['batch_size']``
        is set to ``'ignore'``.

        ``warning_settings`` is module-level state, so the ``'ignore'`` setting
        is restored to its previous value afterwards. Otherwise it would leak
        into other tests that expect the warning.
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            inf_batch_size = self.model.get_batch_size([1, 2, 3], [4, 5, 6])
            self.assertEqual(inf_batch_size, 1)
            self.assertEqual(len(w), 1)

        previous_setting = warning_settings['batch_size']
        try:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                warning_settings['batch_size'] = 'ignore'
                inf_batch_size = self.model.get_batch_size([1, 2, 3], [4, 5, 6])
                self.assertEqual(inf_batch_size, 1)
                self.assertEqual(len(w), 0)
        finally:
            # Undo the global change even if an assertion above failed.
            warning_settings['batch_size'] = previous_setting