示例#1
0
class ModelDictOutputTest(ModelFittingTestCase):
    """Tests of ``Model`` around a network whose prediction is indexed by key.

    The loss reads the prediction through the keys ``'out1'`` and ``'out2'``
    and pairs them with the first and second element of the target tuple.
    """

    def setUp(self):
        """Seed torch and build the network, optimizer and ``Model`` under test."""
        super().setUp()
        torch.manual_seed(42)
        self.pytorch_network = DictOutputModel()
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)

        def dict_loss(y_p, y_t):
            # Sum of the two per-output losses, 'out1' evaluated first.
            first = self.loss_function(y_p['out1'], y_t[0])
            second = self.loss_function(y_p['out2'], y_t[1])
            return first + second

        self.model = Model(self.pytorch_network,
                           self.optimizer,
                           dict_loss,
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)

    def test_fitting_with_tensor_multi_output_dict(self):
        """Fit on tensors with a tuple target and check the callback calls."""
        batch_size = ModelDictOutputTest.batch_size
        train_steps = 30
        # The training set is 7 samples short of 30 full batches.
        train_size = train_steps * batch_size - 7
        train_x = torch.rand(train_size, 1)
        train_y = (torch.rand(train_size, 1), torch.rand(train_size, 1))

        # The fit method uses the training batch size for validation as well;
        # the validation set is 3 samples short of 10 full batches.
        valid_size = 10 * batch_size - 3
        valid_x = torch.rand(valid_size, 1)
        valid_y = (torch.rand(valid_size, 1), torch.rand(valid_size, 1))

        fit_options = {
            'validation_data': (valid_x, valid_y),
            'epochs': ModelDictOutputTest.epochs,
            'batch_size': batch_size,
            'steps_per_epoch': None,
            'validation_steps': None,
            'callbacks': [self.mock_callback],
        }
        logs = self.model.fit(train_x, train_y, **fit_options)

        expected_params = {'epochs': ModelDictOutputTest.epochs, 'steps': train_steps}
        self._test_callbacks_train(expected_params, logs)

    def test_ndarray_train_on_batch_dict_output(self):
        """``train_on_batch`` on float32 ndarrays must return a plain ``float``."""
        shape = (ModelDictOutputTest.batch_size, 1)
        x = np.random.rand(*shape).astype(np.float32)
        y1 = np.random.rand(*shape).astype(np.float32)
        y2 = np.random.rand(*shape).astype(np.float32)

        loss = self.model.train_on_batch(x, (y1, y2))

        self.assertIs(type(loss), float)

    def test_evaluate_with_pred_dict_output(self):
        """``evaluate(..., return_pred=True)`` yields one full-length prediction per key."""
        num_samples = ModelDictOutputTest.evaluate_dataset_len
        y = (torch.rand(num_samples, 1), torch.rand(num_samples, 1))
        x = torch.rand(num_samples, 1)

        # We also test the unpacking.
        _, pred_y = self.model.evaluate(x,
                                        y,
                                        batch_size=ModelDictOutputTest.batch_size,
                                        return_pred=True)

        expected_shape = (num_samples, 1)
        for pred in pred_y.values():
            self.assertEqual(pred.shape, expected_shape)
示例#2
0
 def test_evaluate_with_only_one_metric(self):
     """Evaluate a model built with only the first batch metric.

     Checks that ``evaluate`` unpacks into a loss and a single metric, that
     both are plain floats, and that the metric equals ``some_metric_1_value``.
     """
     single_metric_model = Model(self.pytorch_network,
                                 self.optimizer,
                                 self.loss_function,
                                 batch_metrics=self.batch_metrics[:1])
     num_samples = ModelTest.evaluate_dataset_len
     x = torch.rand(num_samples, 1)
     y = torch.rand(num_samples, 1)

     results = single_metric_model.evaluate(x, y, batch_size=ModelTest.batch_size)
     loss, first_metric = results

     self.assertIs(type(loss), float)
     self.assertIs(type(first_metric), float)
     self.assertEqual(first_metric, some_metric_1_value)
示例#3
0
class ModelMultiIOTest(ModelFittingTestCase):
    """Tests of ``Model`` with tuple inputs and tuple outputs.

    The network is a ``MultiIOModel`` built with ``num_input=2`` and
    ``num_output=2``; the loss adds the ``MSELoss`` of each output/target pair.
    """

    def setUp(self):
        """Seed torch and build the network, optimizer and ``Model`` under test."""
        super().setUp()
        torch.manual_seed(42)
        self.pytorch_network = MultiIOModel(num_input=2, num_output=2)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)

        def summed_loss(y_pred, y_true):
            # Loss of the first pair is computed before the second one.
            first = self.loss_function(y_pred[0], y_true[0])
            second = self.loss_function(y_pred[1], y_true[1])
            return first + second

        self.model = Model(self.pytorch_network,
                           self.optimizer,
                           summed_loss,
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)

    def test_fitting_tensor_generator_multi_io(self):
        """Fit from two generators and check the callback calls against the logs."""
        batch_size = ModelMultiIOTest.batch_size
        steps = ModelMultiIOTest.steps_per_epoch
        train_generator = some_data_tensor_generator_multi_io(batch_size)
        valid_generator = some_data_tensor_generator_multi_io(batch_size)

        logs = self.model.fit_generator(train_generator,
                                        valid_generator,
                                        epochs=ModelMultiIOTest.epochs,
                                        steps_per_epoch=steps,
                                        validation_steps=steps,
                                        callbacks=[self.mock_callback])

        expected_params = {'epochs': ModelMultiIOTest.epochs, 'steps': steps}
        self._test_callbacks_train(expected_params, logs)

    def test_fitting_with_tensor_multi_io(self):
        """Fit on tuple tensors and check the callback calls against the logs."""
        batch_size = ModelMultiIOTest.batch_size
        train_steps = 30
        # The training set is 7 samples short of 30 full batches.
        train_size = train_steps * batch_size - 7
        train_x = (torch.rand(train_size, 1), torch.rand(train_size, 1))
        train_y = (torch.rand(train_size, 1), torch.rand(train_size, 1))

        # The fit method uses the training batch size for validation as well;
        # the validation set is 3 samples short of 10 full batches.
        valid_size = 10 * batch_size - 3
        valid_x = (torch.rand(valid_size, 1), torch.rand(valid_size, 1))
        valid_y = (torch.rand(valid_size, 1), torch.rand(valid_size, 1))

        fit_options = {
            'validation_data': (valid_x, valid_y),
            'epochs': ModelMultiIOTest.epochs,
            'batch_size': batch_size,
            'steps_per_epoch': None,
            'validation_steps': None,
            'callbacks': [self.mock_callback],
        }
        logs = self.model.fit(train_x, train_y, **fit_options)

        expected_params = {'epochs': ModelMultiIOTest.epochs, 'steps': train_steps}
        self._test_callbacks_train(expected_params, logs)

    def test_tensor_train_on_batch_multi_io(self):
        """``train_on_batch`` on tuple tensors must return a plain ``float``."""
        batch_size = ModelMultiIOTest.batch_size
        # Drawn in the order x1, x2, y1, y2.
        x1, x2, y1, y2 = (torch.rand(batch_size, 1) for _ in range(4))

        loss = self.model.train_on_batch((x1, x2), (y1, y2))

        self.assertIs(type(loss), float)

    def test_ndarray_train_on_batch_multi_io(self):
        """``train_on_batch`` on tuple float32 ndarrays must return a plain ``float``."""
        batch_size = ModelMultiIOTest.batch_size
        # Drawn in the order x1, x2, y1, y2.
        x1, x2, y1, y2 = (np.random.rand(batch_size, 1).astype(np.float32) for _ in range(4))

        loss = self.model.train_on_batch((x1, x2), (y1, y2))

        self.assertIs(type(loss), float)

    def test_evaluate_with_pred_multi_io(self):
        """``evaluate(..., return_pred=True)`` yields full-length predictions."""
        num_samples = ModelMultiIOTest.evaluate_dataset_len
        x = (torch.rand(num_samples, 1), torch.rand(num_samples, 1))
        y = (torch.rand(num_samples, 1), torch.rand(num_samples, 1))

        # We also test the unpacking.
        _, pred_y = self.model.evaluate(x,
                                        y,
                                        batch_size=ModelMultiIOTest.batch_size,
                                        return_pred=True)

        expected_shape = (num_samples, 1)
        for pred in pred_y:
            self.assertEqual(pred.shape, expected_shape)

    def test_tensor_evaluate_on_batch_multi_io(self):
        """``evaluate_on_batch`` on tuple tensors must return a plain ``float``."""
        batch_size = ModelMultiIOTest.batch_size
        y = (torch.rand(batch_size, 1), torch.rand(batch_size, 1))
        x = (torch.rand(batch_size, 1), torch.rand(batch_size, 1))

        loss = self.model.evaluate_on_batch(x, y)

        self.assertIs(type(loss), float)

    def test_predict_with_np_array_multi_io(self):
        """``predict`` on a tuple of ndarrays yields full-length predictions."""
        num_samples = ModelMultiIOTest.evaluate_dataset_len
        x1 = np.random.rand(num_samples, 1).astype(np.float32)
        x2 = np.random.rand(num_samples, 1).astype(np.float32)

        pred_y = self.model.predict((x1, x2), batch_size=ModelMultiIOTest.batch_size)

        expected_shape = (num_samples, 1)
        for pred in pred_y:
            self.assertEqual(pred.shape, expected_shape)

    def test_predict_generator_multi_io(self):
        """``predict_generator`` over 10 steps yields ``10 * batch_size`` rows per output."""
        num_steps = 10
        batches = some_data_tensor_generator_multi_io(ModelMultiIOTest.batch_size)
        # Keep only the inputs of each (inputs, targets) pair.
        inputs_only = (x for x, _ in batches)

        pred_y = self.model.predict_generator(inputs_only, steps=num_steps)

        expected_shape = (num_steps * ModelMultiIOTest.batch_size, 1)
        for pred in pred_y:
            self.assertEqual(pred.shape, expected_shape)

    def test_tensor_predict_on_batch_multi_io(self):
        """``predict_on_batch`` on tuple tensors is checked for size and type."""
        batch_size = ModelMultiIOTest.batch_size
        x = (torch.rand(batch_size, 1), torch.rand(batch_size, 1))

        pred_y = self.model.predict_on_batch(x)

        self._test_size_and_type_for_generator(pred_y, (batch_size, 1))
# `import torch.nn as nn` only binds the name `nn`; `torch` itself must be
# imported explicitly because `torch.rand` / `torch.randint` are used below.
import torch
import torch.nn as nn
import torch.optim as optim
from poutyne import Model

num_features = 100
num_classes = 10

# Fully connected classifier: num_features inputs, 64 hidden units, num_classes outputs.
net = nn.Sequential(
    nn.Linear(num_features, 64),
    nn.ReLU(),
    nn.Linear(64, num_classes)
)

# Our training dataset with 80000 samples.
num_train_samples = 80000
x_train = torch.rand(num_train_samples, num_features)
y_train = torch.randint(num_classes, (num_train_samples,), dtype=torch.long)

# Our test dataset with 200 samples.
num_test_samples = 200
x_test = torch.rand(num_test_samples, num_features)
y_test = torch.randint(num_classes, (num_test_samples,), dtype=torch.long)


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Train for 5 epochs, then evaluate on the test tensors; `evaluate` is unpacked
# into the loss and the single 'accuracy' batch metric.
model = Model(net, optimizer, criterion, batch_metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5,  batch_size=32)
loss, accuracy = model.evaluate(x_test, y_test, batch_size=32)
示例#5
0
class ModelFittingTestCaseProgress(ModelFittingTestCase):
    # pylint: disable=too-many-public-methods
    num_steps = 5

    def setUp(self):
        """Create the generators, the seeded linear ``Model`` and start output capture."""
        super().setUp()
        batch_size = ModelFittingTestCase.batch_size
        self.train_generator = some_data_tensor_generator(batch_size)
        self.valid_generator = some_data_tensor_generator(batch_size)
        self.test_generator = some_data_tensor_generator(batch_size)

        torch.manual_seed(42)
        self.pytorch_network = nn.Linear(1, 1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)

        # Batch metrics with their names and values, kept as parallel lists.
        self.batch_metrics = [
            some_batch_metric_1,
            ('custom_name', some_batch_metric_2),
            repeat_batch_metric,
            repeat_batch_metric,
        ]
        self.batch_metrics_names = [
            'some_batch_metric_1',
            'custom_name',
            'repeat_batch_metric1',
            'repeat_batch_metric2',
        ]
        self.batch_metrics_values = [
            some_metric_1_value,
            some_metric_2_value,
            repeat_batch_metric_value,
            repeat_batch_metric_value,
        ]

        # Same layout for the single epoch metric.
        self.epoch_metrics = [SomeConstantEpochMetric()]
        self.epoch_metrics_names = ['some_constant_epoch_metric']
        self.epoch_metrics_values = [some_constant_epoch_metric_value]

        model_kwargs = {
            'batch_metrics': self.batch_metrics,
            'epoch_metrics': self.epoch_metrics,
        }
        self.model = Model(self.pytorch_network, self.optimizer, self.loss_function, **model_kwargs)

        self._capture_output()

    def assertStdoutContains(self, values):
        """Assert that every item of ``values`` is a substring of the stripped ``self.test_out`` text."""
        for expected in values:
            # Re-read the buffer for each item, as the text is fetched per assertion.
            captured = self.test_out.getvalue().strip()
            self.assertIn(expected, captured)

    def assertStdoutNotContains(self, values):
        """Assert that no item of ``values`` is a substring of the stripped ``self.test_out`` text."""
        for unexpected in values:
            # Re-read the buffer for each item, as the text is fetched per assertion.
            captured = self.test_out.getvalue().strip()
            self.assertNotIn(unexpected, captured)

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_default_coloring(self):
        """Fit without progress options and expect four ANSI color codes in the output."""
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        expected_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutContains(expected_codes)

    def test_fitting_with_progress_bar_show_epoch(self):
        """Fit with default verbosity and expect the epoch counter in the output."""
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        # NOTE(review): "1/5" and "2/5" presumably rely on ModelFittingTestCase.epochs == 5 - confirm.
        expected_fragments = ["Epoch", "1/5", "2/5"]
        self.assertStdoutContains(expected_fragments)

    def test_fitting_with_progress_bar_show_steps(self):
        """Fit with default verbosity and expect the step count in the output."""
        steps = ModelFittingTestCase.steps_per_epoch
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': steps,
            'validation_steps': steps,
            'callbacks': [self.mock_callback],
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        expected_fragments = ["steps", f"{ModelFittingTestCase.steps_per_epoch}"]
        self.assertStdoutContains(expected_fragments)

    def test_fitting_with_progress_bar_show_train_val_final_steps(self):
        """Fit with default verbosity and expect the train/val step labels in the output."""
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        expected_labels = ["Val steps", "Train steps"]
        self.assertStdoutContains(expected_labels)

    def test_fitting_with_no_progress_bar_dont_show_epoch(self):
        """Fit with ``verbose=False`` and expect no epoch counter in the output."""
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
            'verbose': False,
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        absent_fragments = ["Epoch", "1/5", "2/5"]
        self.assertStdoutNotContains(absent_fragments)

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_user_coloring(self):
        """Fit with every coloring entry set to BLACK and expect the ``[30m`` code."""
        color_keys = ("text_color", "ratio_color", "metric_value_color", "time_color", "progress_bar_color")
        coloring = {key: "BLACK" for key in color_keys}
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
            'progress_options': {'coloring': coloring},
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        self.assertStdoutContains(["[30m"])

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_user_partial_coloring(self):
        """Fit with only ``text_color`` and ``ratio_color`` overridden to BLACK."""
        partial_coloring = {"text_color": "BLACK", "ratio_color": "BLACK"}
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
            'progress_options': {'coloring': partial_coloring},
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        # The BLACK code must appear alongside three of the codes checked in the default-coloring tests.
        expected_codes = ["[30m", "[32m", "[35m", "[94m"]
        self.assertStdoutContains(expected_codes)

    def test_fitting_with_user_coloring_invalid(self):
        """Fitting with an unknown coloring key must raise ``KeyError``."""
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
            'progress_options': {'coloring': {"invalid_name": 'A COLOR'}},
        }
        with self.assertRaises(KeyError):
            self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

    def test_fitting_with_no_coloring(self):
        """Fit with ``coloring=False`` and expect none of the ANSI color codes."""
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
            'progress_options': {'coloring': False},
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        absent_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutNotContains(absent_codes)

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_progress_bar_default_color(self):
        """Fit with coloring and progress bar on; expect ``%``, color codes and block glyphs."""
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
            'progress_options': {'coloring': True, 'progress_bar': True},
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        expected_fragments = ["%", "[32m", "[35m", "[36m", "[94m", "\u2588"]
        self.assertStdoutContains(expected_fragments)

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_progress_bar_user_color(self):
        """Fit with an all-BLACK coloring and the progress bar on."""
        color_keys = ("text_color", "ratio_color", "metric_value_color", "time_color", "progress_bar_color")
        coloring = {key: "BLACK" for key in color_keys}
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
            'progress_options': {'coloring': coloring, 'progress_bar': True},
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        expected_fragments = ["%", "[30m", "\u2588"]
        self.assertStdoutContains(expected_fragments)

    def test_fitting_with_progress_bar_no_color(self):
        """Fit with the progress bar on and coloring off."""
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
            'progress_options': {'coloring': False, 'progress_bar': True},
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        # Bar fragments must be present, color codes must not.
        bar_fragments = ["%", "\u2588"]
        color_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutContains(bar_fragments)
        self.assertStdoutNotContains(color_codes)

    def test_fitting_with_no_progress_bar(self):
        """Fit with both the progress bar and coloring turned off."""
        fit_kwargs = {
            'epochs': ModelFittingTestCase.epochs,
            'steps_per_epoch': ModelFittingTestCase.steps_per_epoch,
            'validation_steps': ModelFittingTestCase.steps_per_epoch,
            'callbacks': [self.mock_callback],
            'progress_options': {'coloring': False, 'progress_bar': False},
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        # Neither bar fragments nor color codes may appear.
        bar_fragments = ["%", "\u2588"]
        color_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutNotContains(bar_fragments)
        self.assertStdoutNotContains(color_codes)

    def test_progress_bar_with_step_is_none(self):
        """Fit without ``steps_per_epoch``/``validation_steps`` while requesting a progress bar.

        The output must contain ``s/step`` but no color code, block glyph or ``%``.
        """
        batch_size = ModelFittingTestCase.batch_size
        train_generator = SomeDataGeneratorUsingStopIteration(batch_size, 10)
        valid_generator = SomeDataGeneratorUsingStopIteration(batch_size, 10)
        options = {'coloring': False, 'progress_bar': True}

        self.model.fit_generator(train_generator,
                                 valid_generator,
                                 epochs=ModelFittingTestCase.epochs,
                                 progress_options=options)

        self.assertStdoutContains(["s/step"])
        absent_fragments = ["[32m", "[35m", "[36m", "[94m", "\u2588", "%"]
        self.assertStdoutNotContains(absent_fragments)

    def test_evaluate_without_progress_output(self):
        """Evaluate with ``verbose=False`` and expect none of the ANSI color codes."""
        num_samples = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(num_samples, 1)
        y = torch.rand(num_samples, 1)

        # The two-target unpacking also requires evaluate() to return exactly two values.
        _, _ = self.model.evaluate(x, y, verbose=False, batch_size=ModelFittingTestCase.batch_size)

        absent_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutNotContains(absent_codes)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_default_coloring(self):
        """Evaluate with default options and expect four ANSI color codes in the output."""
        num_samples = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(num_samples, 1)
        y = torch.rand(num_samples, 1)

        # The two-target unpacking also requires evaluate() to return exactly two values.
        _, _ = self.model.evaluate(x, y, batch_size=ModelFittingTestCase.batch_size)

        expected_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutContains(expected_codes)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_user_coloring(self):
        """Evaluate with every coloring entry set to BLACK and expect the ``[30m`` code."""
        num_samples = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(num_samples, 1)
        y = torch.rand(num_samples, 1)

        color_keys = ("text_color", "ratio_color", "metric_value_color", "time_color", "progress_bar_color")
        coloring = {key: "BLACK" for key in color_keys}
        eval_kwargs = {
            'batch_size': ModelFittingTestCase.batch_size,
            'progress_options': {'coloring': coloring},
        }

        # The two-target unpacking also requires evaluate() to return exactly two values.
        _, _ = self.model.evaluate(x, y, **eval_kwargs)

        self.assertStdoutContains(["[30m"])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_user_partial_coloring(self):
        """Evaluate with only ``text_color`` and ``ratio_color`` overridden to BLACK."""
        num_samples = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(num_samples, 1)
        y = torch.rand(num_samples, 1)

        partial_coloring = {"text_color": "BLACK", "ratio_color": "BLACK"}
        eval_kwargs = {
            'batch_size': ModelFittingTestCase.batch_size,
            'progress_options': {'coloring': partial_coloring},
        }

        # The two-target unpacking also requires evaluate() to return exactly two values.
        _, _ = self.model.evaluate(x, y, **eval_kwargs)

        expected_codes = ["[30m", "[32m", "[35m", "[94m"]
        self.assertStdoutContains(expected_codes)

    def test_evaluate_with_user_coloring_invalid(self):
        """Evaluating with an unknown coloring key must raise ``KeyError``."""
        num_samples = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(num_samples, 1)
        y = torch.rand(num_samples, 1)

        eval_kwargs = {
            'batch_size': ModelFittingTestCase.batch_size,
            'progress_options': {'coloring': {"invalid_name": 'A COLOR'}},
        }
        with self.assertRaises(KeyError):
            _, _ = self.model.evaluate(x, y, **eval_kwargs)

    def test_evaluate_with_no_coloring(self):
        """Evaluate with ``coloring=False`` and expect none of the ANSI color codes."""
        num_samples = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(num_samples, 1)
        y = torch.rand(num_samples, 1)

        eval_kwargs = {
            'batch_size': ModelFittingTestCase.batch_size,
            'progress_options': {'coloring': False},
        }
        # The two-target unpacking also requires evaluate() to return exactly two values.
        _, _ = self.model.evaluate(x, y, **eval_kwargs)

        absent_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutNotContains(absent_codes)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_progress_bar_default_coloring(self):
        """Evaluate with coloring and progress bar on; expect ``%``, color codes and block glyphs."""
        num_samples = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(num_samples, 1)
        y = torch.rand(num_samples, 1)

        eval_kwargs = {
            'batch_size': ModelFittingTestCase.batch_size,
            'progress_options': {'coloring': True, 'progress_bar': True},
        }
        # The two-target unpacking also requires evaluate() to return exactly two values.
        _, _ = self.model.evaluate(x, y, **eval_kwargs)

        expected_fragments = ["%", "[32m", "[35m", "[36m", "[94m", "\u2588"]
        self.assertStdoutContains(expected_fragments)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_progress_bar_user_coloring(self):
        """Evaluate with an all-BLACK coloring and the progress bar on."""
        num_samples = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(num_samples, 1)
        y = torch.rand(num_samples, 1)

        color_keys = ("text_color", "ratio_color", "metric_value_color", "time_color", "progress_bar_color")
        coloring = {key: "BLACK" for key in color_keys}
        eval_kwargs = {
            'batch_size': ModelFittingTestCase.batch_size,
            'progress_options': {'coloring': coloring, 'progress_bar': True},
        }

        # The two-target unpacking also requires evaluate() to return exactly two values.
        _, _ = self.model.evaluate(x, y, **eval_kwargs)

        expected_fragments = ["%", "[30m", "\u2588"]
        self.assertStdoutContains(expected_fragments)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_progress_bar_user_no_color(self):
        """Evaluate with the progress bar on and coloring off."""
        num_samples = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(num_samples, 1)
        y = torch.rand(num_samples, 1)

        eval_kwargs = {
            'batch_size': ModelFittingTestCase.batch_size,
            'progress_options': {'coloring': False, 'progress_bar': True},
        }
        # The two-target unpacking also requires evaluate() to return exactly two values.
        _, _ = self.model.evaluate(x, y, **eval_kwargs)

        # Bar fragments must be present, color codes must not.
        bar_fragments = ["%", "\u2588"]
        color_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutContains(bar_fragments)
        self.assertStdoutNotContains(color_codes)

    def test_evaluate_with_no_progress_bar(self):
        """Evaluate with both the progress bar and coloring turned off."""
        num_samples = ModelFittingTestCase.evaluate_dataset_len
        x = torch.rand(num_samples, 1)
        y = torch.rand(num_samples, 1)

        eval_kwargs = {
            'batch_size': ModelFittingTestCase.batch_size,
            'progress_options': {'coloring': False, 'progress_bar': False},
        }
        # The two-target unpacking also requires evaluate() to return exactly two values.
        _, _ = self.model.evaluate(x, y, **eval_kwargs)

        # Neither bar fragments nor color codes may appear.
        bar_fragments = ["%", "\u2588"]
        color_codes = ["[32m", "[35m", "[36m", "[94m"]
        self.assertStdoutNotContains(bar_fragments)
        self.assertStdoutNotContains(color_codes)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_data_loader_with_progress_bar_coloring(self):
        """Evaluate a ``DataLoader`` verbosely and expect a colored progress bar.

        Skipped when colorama is unavailable, like the other tests of this
        class that assert ANSI color codes in the captured output.
        """
        x = torch.rand(ModelFittingTestCase.evaluate_dataset_len, 1)
        y = torch.rand(ModelFittingTestCase.evaluate_dataset_len, 1)
        dataset = TensorDataset(x, y)
        generator = DataLoader(dataset, ModelFittingTestCase.batch_size)

        # The two-target unpacking also requires exactly two returned values.
        _, _ = self.model.evaluate_generator(generator, verbose=True)

        self.assertStdoutContains(["%", "[32m", "[35m", "[36m", "[94m", "\u2588"])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_generator_with_progress_bar_coloring(self):
        """Evaluate a generator for ``num_steps`` steps and expect a colored progress bar.

        Skipped when colorama is unavailable, like the other tests of this
        class that assert ANSI color codes in the captured output.
        """
        generator = some_data_tensor_generator(ModelFittingTestCase.batch_size)

        # The two-target unpacking also requires exactly two returned values.
        _, _ = self.model.evaluate_generator(generator, steps=ModelFittingTestCaseProgress.num_steps, verbose=True)

        self.assertStdoutContains(["%", "[32m", "[35m", "[36m", "[94m", "\u2588"])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_generator_with_callback_and_progress_bar_coloring(self):
        """Evaluate a generator with a callback attached and expect a colored progress bar.

        Skipped when colorama is unavailable, like the other tests of this
        class that assert ANSI color codes in the captured output.
        """
        generator = some_data_tensor_generator(ModelFittingTestCase.batch_size)

        # The two-target unpacking also requires exactly two returned values.
        _, _ = self.model.evaluate_generator(generator,
                                             steps=ModelFittingTestCaseProgress.num_steps,
                                             callbacks=[self.mock_callback],
                                             verbose=True)

        self.assertStdoutContains(["%", "[32m", "[35m", "[36m", "[94m", "\u2588"])

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_complete_display_test_with_progress_bar_coloring(self):
        """Check every step update of a one-epoch fit with an all-WHITE coloring.

        The captured output is split on carriage returns: the first
        ``num_steps`` pieces are matched as training updates, the following
        ones (except the last) as validation updates, and the last piece as
        the final summary line. Skipped when colorama is unavailable because
        the summary regex requires the ``[37m`` color code.
        """
        # we use the same color for all components for simplicity
        coloring = {
            "text_color": 'WHITE',
            "ratio_color": "WHITE",
            "metric_value_color": "WHITE",
            "time_color": "WHITE",
            "progress_bar_color": "WHITE"
        }
        num_steps = ModelFittingTestCaseProgress.num_steps
        # NOTE(review): progress_bar=False although the test name and the template
        # below refer to a progress bar - confirm the intended option.
        _ = self.model.fit_generator(self.train_generator,
                                     self.valid_generator,
                                     epochs=1,
                                     steps_per_epoch=num_steps,
                                     validation_steps=num_steps,
                                     callbacks=[self.mock_callback],
                                     progress_options=dict(coloring=coloring, progress_bar=False))

        # We split per step update
        steps_update = self.test_out.getvalue().strip().split("\r")

        # we don't validate the templating of metrics since tested before
        # NOTE(review): the two `|` are unescaped, so this is an alternation of three
        # branches and assertRegex passes as soon as any one of them matches.
        template_format = r".*Epoch:.*{}\/1.*\[37mStep:.*{}\/5.*{:6.2f}\%.*|{}|.*ETA:"
        epoch = 1

        # The train updates first, then the val updates (the last piece is the summary).
        for updates in (steps_update[:num_steps], steps_update[num_steps:-1]):
            for step, step_update in enumerate(updates, start=1):
                # Two filled glyphs per step, padded with spaces to 20 characters.
                filled = step * 2
                progress_bar = "\u2588" * filled + " " * (20 - filled)
                regex_filled = template_format.format(epoch, step, step / num_steps * 100, progress_bar)
                self.assertRegex(step_update, regex_filled)

        # last print update templating different
        last_print_regex = r".*\[37mTrain steps:.*5.*Val steps:.*5.*[0-9]*\.[0-9][0-9]s"
        self.assertRegex(steps_update[-1], last_print_regex)

    def test_fitting_complete_display_test_with_progress_bar_no_coloring(self):
        """Check every step update of a one-epoch fit with the progress bar on and no coloring.

        The captured output is split on carriage returns: the first
        ``num_steps`` pieces are matched as training updates, the following
        ones (except the last) as validation updates, and the last piece as
        the final summary line.
        """
        num_steps = ModelFittingTestCaseProgress.num_steps
        fit_kwargs = {
            'epochs': 1,
            'steps_per_epoch': num_steps,
            'validation_steps': num_steps,
            'callbacks': [self.mock_callback],
            'progress_options': {'coloring': False, 'progress_bar': True},
        }
        self.model.fit_generator(self.train_generator, self.valid_generator, **fit_kwargs)

        # One piece per step update.
        steps_update = self.test_out.getvalue().strip().split("\r")

        # Metric templating is not validated here.
        # NOTE(review): the two `|` are unescaped, so this is an alternation of three
        # branches and assertRegex passes as soon as any one of them matches.
        template_format = r".*Epoch:.*{}\/1.*Step:.*{}\/5.*{:6.2f}\%.*|{}|.*ETA:"
        epoch = 1

        # The train updates first, then the val updates (the last piece is the summary).
        train_updates = steps_update[:num_steps]
        val_updates = steps_update[num_steps:-1]
        for updates in (train_updates, val_updates):
            for step, step_update in enumerate(updates, start=1):
                # Two filled glyphs per step, padded with spaces to 20 characters.
                filled = step * 2
                bar = "\u2588" * filled + " " * (20 - filled)
                pattern = template_format.format(epoch, step, step / num_steps * 100, bar)
                self.assertRegex(step_update, pattern)

        # The final update uses a different template.
        last_print_regex = r".*Train steps:.*5.*Val steps:.*5.*[0-9]*\.[0-9][0-9]s"
        self.assertRegex(steps_update[-1], last_print_regex)

    def test_fitting_complete_display_test_with_no_progress_bar_no_coloring(self):
        # Fit for a single epoch with both the progress bar and coloring turned
        # off, then check every carriage-return-separated update written to
        # self.test_out.
        num_steps = ModelFittingTestCaseProgress.num_steps
        _ = self.model.fit_generator(self.train_generator,
                                     self.valid_generator,
                                     epochs=1,
                                     steps_per_epoch=num_steps,
                                     validation_steps=num_steps,
                                     callbacks=[self.mock_callback],
                                     progress_options=dict(coloring=False, progress_bar=False))

        # One entry per step update.
        displayed = self.test_out.getvalue().strip().split("\r")

        # The metric templating is not validated here since it is tested separately.
        # Only the epoch and the step are templated: without a progress bar there
        # is no percentage placeholder, so no percentage is computed or passed.
        step_pattern = r".*Epoch:.*{}\/1.*Step:.*{}\/5.*ETA:"
        epoch = 1

        # The first num_steps updates are the train steps; the following ones,
        # up to but excluding the final summary, are the validation steps.
        for phase_updates in (displayed[:num_steps], displayed[num_steps:-1]):
            for step, update in enumerate(phase_updates, start=1):
                self.assertRegex(update, step_pattern.format(epoch, step))

        # The last printed update follows a different (summary) template.
        summary_pattern = r".*Train steps:.*5.*Val steps:.*5.*[0-9]*\.[0-9][0-9]s"
        self.assertRegex(displayed[-1], summary_pattern)

    def test_evaluate_complete_display_test_with_progress_bar_coloring(self):
        # Evaluate with the progress bar on and every display component colored
        # WHITE (one single color keeps the expected patterns simple).
        coloring = dict.fromkeys(
            ("text_color", "ratio_color", "metric_value_color", "time_color", "progress_bar_color"),
            "WHITE")
        num_steps = ModelFittingTestCaseProgress.num_steps

        # The two-way unpacking also checks that exactly two values are returned.
        _, _ = self.model.evaluate_generator(self.test_generator,
                                             steps=num_steps,
                                             callbacks=[self.mock_callback],
                                             verbose=True,
                                             progress_options=dict(coloring=coloring, progress_bar=True))

        # One entry per step update.
        displayed = self.test_out.getvalue().strip().split("\r")

        # The metric templating is not validated here since it is tested separately.
        # `\[37m` is presumably the tail of the ANSI escape for white - confirm.
        # NOTE(review): the unescaped `|` characters are regex alternation, so an
        # update passes as soon as any one of the three alternatives matches;
        # confirm whether literal bar delimiters (`\|`) were intended.
        step_pattern = r".*\[37mStep:.*{}\/5.*{:6.2f}\%.*|{}|.*ETA:"
        for step, update in enumerate(displayed[:-1], start=1):
            filled = step * 2
            bar = "\u2588" * filled + " " * (20 - filled)
            percent = step / num_steps * 100
            self.assertRegex(update, step_pattern.format(step, percent, bar))

        # The last printed update follows a different (summary) template.
        summary_pattern = r".*\[37mTest steps:.*5.*[0-9]*\.[0-9][0-9]s"
        self.assertRegex(displayed[-1], summary_pattern)

    def test_evaluate_complete_display_test_with_progress_bar_no_coloring(self):
        # Evaluate with the progress bar on and coloring off, then check every
        # carriage-return-separated update written to self.test_out.
        num_steps = ModelFittingTestCaseProgress.num_steps

        # The two-way unpacking also checks that exactly two values are returned.
        _, _ = self.model.evaluate_generator(self.test_generator,
                                             steps=num_steps,
                                             callbacks=[self.mock_callback],
                                             verbose=True,
                                             progress_options=dict(coloring=False, progress_bar=True))

        # One entry per step update.
        displayed = self.test_out.getvalue().strip().split("\r")

        # The metric templating is not validated here since it is tested separately.
        # NOTE(review): the unescaped `|` characters are regex alternation, so an
        # update passes as soon as any one of the three alternatives matches;
        # confirm whether literal bar delimiters (`\|`) were intended.
        step_pattern = r".*Step:.*{}\/5.*{:6.2f}\%.*|{}|.*ETA:"
        for step, update in enumerate(displayed[:-1], start=1):
            filled = step * 2
            bar = "\u2588" * filled + " " * (20 - filled)
            percent = step / num_steps * 100
            self.assertRegex(update, step_pattern.format(step, percent, bar))

        # The last printed update follows a different (summary) template.
        summary_pattern = r".*Test steps:.*5.*[0-9]*\.[0-9][0-9]s"
        self.assertRegex(displayed[-1], summary_pattern)

    def test_evaluate_complete_display_test_with_no_progress_bar_no_coloring(self):
        # Evaluate with both the progress bar and coloring turned off, then check
        # every carriage-return-separated update written to self.test_out.
        _, _ = self.model.evaluate_generator(self.test_generator,
                                             steps=ModelFittingTestCaseProgress.num_steps,
                                             callbacks=[self.mock_callback],
                                             verbose=True,
                                             progress_options=dict(coloring=False, progress_bar=False))

        # One entry per step update.
        displayed = self.test_out.getvalue().strip().split("\r")

        # The metric templating is not validated here since it is tested separately.
        # Only the step number is templated: without a progress bar there is no
        # percentage placeholder, so no percentage is computed or passed.
        step_pattern = r".*Step:.*{}\/5.*ETA:"
        for step, update in enumerate(displayed[:-1], start=1):
            self.assertRegex(update, step_pattern.format(step))

        # The last printed update follows a different (summary) template.
        summary_pattern = r".*Test steps:.*5.*[0-9]*\.[0-9][0-9]s"
        self.assertRegex(displayed[-1], summary_pattern)
示例#6
0
 def test_evaluate_with_no_metric(self):
     # A Model built without batch or epoch metrics is expected to return the
     # loss alone from evaluate(), as a plain Python float.
     metricless_model = Model(self.pytorch_network, self.optimizer, self.loss_function)
     sample_shape = (ModelTest.evaluate_dataset_len, 1)
     inputs = torch.rand(*sample_shape)
     targets = torch.rand(*sample_shape)
     loss = metricless_model.evaluate(inputs, targets, batch_size=ModelTest.batch_size)
     self.assertIs(type(loss), float)
示例#7
0
class ModelTest(ModelFittingTestCase):
    # pylint: disable=too-many-public-methods

    def setUp(self):
        # Build the fixtures shared by the tests: a seeded 1-in/1-out linear
        # network with MSE loss and Adam, the metric fixtures, and the Model
        # wrapping them all.
        super().setUp()
        torch.manual_seed(42)
        network = nn.Linear(1, 1)
        self.pytorch_network = network
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

        # One row per batch metric:
        # (metric as given to Model, name the tests expect, value the tests expect).
        batch_metric_table = [
            (some_batch_metric_1, 'some_batch_metric_1', some_metric_1_value),
            (('custom_name', some_batch_metric_2), 'custom_name', some_metric_2_value),
            (repeat_batch_metric, 'repeat_batch_metric1', repeat_batch_metric_value),
            (repeat_batch_metric, 'repeat_batch_metric2', repeat_batch_metric_value),
        ]
        self.batch_metrics = [metric for metric, _, _ in batch_metric_table]
        self.batch_metrics_names = [name for _, name, _ in batch_metric_table]
        self.batch_metrics_values = [value for _, _, value in batch_metric_table]

        self.epoch_metrics = [SomeConstantEpochMetric()]
        self.epoch_metrics_names = ['some_constant_epoch_metric']
        self.epoch_metrics_values = [some_constant_epoch_metric_value]

        self.model = Model(network,
                           self.optimizer,
                           self.loss_function,
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)

    def test_fitting_tensor_generator(self):
        # Fit from two tensor generators for a fixed number of train and
        # validation steps, then delegate the checks to _test_callbacks_train.
        train_generator = some_data_tensor_generator(ModelTest.batch_size)
        valid_generator = some_data_tensor_generator(ModelTest.batch_size)
        n_epochs = ModelTest.epochs
        n_steps = ModelTest.steps_per_epoch

        logs = self.model.fit_generator(train_generator,
                                        valid_generator,
                                        epochs=n_epochs,
                                        steps_per_epoch=n_steps,
                                        validation_steps=n_steps,
                                        callbacks=[self.mock_callback])

        expected_params = dict(epochs=n_epochs, steps=n_steps, valid_steps=n_steps)
        self._test_callbacks_train(expected_params, logs, valid_steps=n_steps)

    def test_fitting_without_valid_generator(self):
        # Fit with None in place of the validation generator; the callback
        # checks are then run with has_valid=False.
        train_generator = some_data_tensor_generator(ModelTest.batch_size)
        n_epochs = ModelTest.epochs
        n_steps = ModelTest.steps_per_epoch

        logs = self.model.fit_generator(train_generator,
                                        None,
                                        epochs=n_epochs,
                                        steps_per_epoch=n_steps,
                                        callbacks=[self.mock_callback])

        expected_params = dict(epochs=n_epochs, steps=n_steps)
        self._test_callbacks_train(expected_params, logs, has_valid=False)

    def test_correct_optim_calls_1_batch_per_step(self):
        # A single step made of a single batch must call the optimizer's step()
        # and zero_grad() exactly once each.
        train_generator = some_data_tensor_generator(ModelTest.batch_size)
        optimizer_mock = some_mocked_optimizer()

        model_under_test = Model(self.pytorch_network,
                                 optimizer_mock,
                                 self.loss_function,
                                 batch_metrics=self.batch_metrics,
                                 epoch_metrics=self.epoch_metrics)
        model_under_test.fit_generator(train_generator, None, epochs=1, steps_per_epoch=1, batches_per_step=1)

        for optimizer_method in (optimizer_mock.step, optimizer_mock.zero_grad):
            self.assertEqual(1, optimizer_method.call_count)

    def test_correct_optim_calls__valid_n_batches_per_step(self):
        # Feed n_batches batches with batches_per_step=n_batches: together they
        # form a single step, so step() and zero_grad() must each be called once.
        n_batches = 5
        batch_shape = (n_batches, int(ModelTest.batch_size / n_batches), 1)

        x = torch.rand(*batch_shape)
        y = torch.rand(*batch_shape)
        # Zipping along the first dimension yields one (x, y) pair per batch.
        batches = list(zip(x, y))

        optimizer_mock = some_mocked_optimizer()
        model_under_test = Model(self.pytorch_network,
                                 optimizer_mock,
                                 self.loss_function,
                                 batch_metrics=self.batch_metrics,
                                 epoch_metrics=self.epoch_metrics)
        model_under_test.fit_generator(batches, None, epochs=1, batches_per_step=n_batches)

        for optimizer_method in (optimizer_mock.step, optimizer_mock.zero_grad):
            self.assertEqual(1, optimizer_method.call_count)

    def test_fitting_generator_n_batches_per_step(self):
        # Train once on the data as a single batch to get reference weights, then
        # retrain from the same initial weights with the data reshaped into
        # several mini-batches grouped in one step, and compare the weights.
        total_batch_size = 6

        x = torch.rand(1, total_batch_size, 1)
        y = torch.rand(1, total_batch_size, 1)

        starting_weights = self.model.get_weight_copies()
        self.model.fit_generator(list(zip(x, y)), None, epochs=1, batches_per_step=1)
        reference_weights = self.model.get_weight_copies()

        # NOTE(review): 5 does not divide 6, so for that size int(6 / 5) == 1 and
        # the tensors are resized to hold only 5 of the 6 samples - confirm that
        # this size is intended.
        for mini_batch_size in (1, 2, 5):
            self.model.set_weights(starting_weights)

            n_batches_per_step = int(total_batch_size / mini_batch_size)
            mini_batched_shape = (n_batches_per_step, mini_batch_size, 1)
            # In-place resize: x and y keep their new shape for the next iteration.
            x.resize_(mini_batched_shape)
            y.resize_(mini_batched_shape)

            self.model.fit_generator(list(zip(x, y)), None, epochs=1, batches_per_step=n_batches_per_step)
            obtained_weights = self.model.get_weight_copies()

            self.assertEqual(obtained_weights.keys(), reference_weights.keys())
            for name, reference in reference_weights.items():
                np.testing.assert_almost_equal(obtained_weights[name].numpy(), reference.numpy(), decimal=4)

    def test_fitting_generator_n_batches_per_step_higher_than_num_batches(self):
        # Train on a one-batch dataset with batches_per_step=1 to get reference
        # weights, then retrain from the same initial weights asking for 2
        # batches per step (more than the dataset holds) and compare the weights.
        total_batch_size = 6

        x = torch.rand(1, total_batch_size, 1)
        y = torch.rand(1, total_batch_size, 1)

        starting_weights = self.model.get_weight_copies()
        self.model.fit_generator(list(zip(x, y)), None, epochs=1, batches_per_step=1)
        reference_weights = self.model.get_weight_copies()

        self.model.set_weights(starting_weights)
        self.model.fit_generator(list(zip(x, y)), None, epochs=1, batches_per_step=2)
        obtained_weights = self.model.get_weight_copies()

        self.assertEqual(obtained_weights.keys(), reference_weights.keys())
        for name, reference in reference_weights.items():
            np.testing.assert_almost_equal(obtained_weights[name].numpy(), reference.numpy(), decimal=4)

    def test_fitting_generator_n_batches_per_step_uneven_batches(self):
        # Train once on the data as a single batch to get reference weights, then
        # retrain from the same initial weights with the samples split into
        # chunks of unequal sizes grouped in one step, and compare the weights.
        total_batch_size = 6

        x = torch.rand(1, total_batch_size, 1)
        y = torch.rand(1, total_batch_size, 1)

        starting_weights = self.model.get_weight_copies()
        self.model.fit_generator(list(zip(x, y)), None, epochs=1, batches_per_step=1)
        reference_weights = self.model.get_weight_copies()

        # Drop the leading batch dimension (in place) so that split() cuts along
        # the samples.
        x.squeeze_(dim=0)
        y.squeeze_(dim=0)

        # Neither 4 nor 5 divides 6, so the last chunk is smaller than the others.
        for chunk_size in (4, 5):
            self.model.set_weights(starting_weights)

            chunks = list(zip(x.split(chunk_size), y.split(chunk_size)))
            n_batches_per_step = ceil(total_batch_size / chunk_size)

            self.model.fit_generator(chunks, None, epochs=1, batches_per_step=n_batches_per_step)
            obtained_weights = self.model.get_weight_copies()

            self.assertEqual(obtained_weights.keys(), reference_weights.keys())
            for name, reference in reference_weights.items():
                np.testing.assert_almost_equal(obtained_weights[name].numpy(), reference.numpy(), decimal=4)

    def test_fitting_ndarray_generator(self):
        # Same scenario as the tensor-generator fit, but with generators built by
        # some_ndarray_generator.
        train_generator = some_ndarray_generator(ModelTest.batch_size)
        valid_generator = some_ndarray_generator(ModelTest.batch_size)
        n_epochs = ModelTest.epochs
        n_steps = ModelTest.steps_per_epoch

        logs = self.model.fit_generator(train_generator,
                                        valid_generator,
                                        epochs=n_epochs,
                                        steps_per_epoch=n_steps,
                                        validation_steps=n_steps,
                                        callbacks=[self.mock_callback])

        expected_params = dict(epochs=n_epochs, steps=n_steps, valid_steps=n_steps)
        self._test_callbacks_train(expected_params, logs, valid_steps=n_steps)

    def test_fitting_with_data_loader(self):
        # Fit from DataLoaders without giving explicit step counts; the expected
        # 'steps' and 'valid_steps' are the loaders' numbers of batches.
        # Each dataset is a few samples short of a whole number of batches, so
        # its final batch is partial (as long as the shortfall is smaller than
        # the batch size).
        train_steps, train_batch_size, train_shortfall = 30, ModelTest.batch_size, 7
        train_size = train_steps * train_batch_size - train_shortfall
        train_x = torch.rand(train_size, 1)
        train_y = torch.rand(train_size, 1)
        train_loader = DataLoader(TensorDataset(train_x, train_y), train_batch_size)

        valid_steps, valid_batch_size, valid_shortfall = 10, 15, 3
        valid_size = valid_steps * valid_batch_size - valid_shortfall
        valid_x = torch.rand(valid_size, 1)
        valid_y = torch.rand(valid_size, 1)
        valid_loader = DataLoader(TensorDataset(valid_x, valid_y), valid_batch_size)

        logs = self.model.fit_generator(train_loader,
                                        valid_loader,
                                        epochs=ModelTest.epochs,
                                        steps_per_epoch=None,
                                        validation_steps=None,
                                        callbacks=[self.mock_callback])

        expected_params = dict(epochs=ModelTest.epochs, steps=train_steps, valid_steps=valid_steps)
        self._test_callbacks_train(expected_params, logs)

    def test_fitting_generator_calls(self):
        # Wrap both DataLoaders in IterableMock and check the exact sequence of
        # calls recorded on each: one __len__, then for every epoch one __iter__
        # followed by one __next__ per batch.
        train_steps, train_batch_size, train_shortfall = 30, ModelTest.batch_size, 7
        train_size = train_steps * train_batch_size - train_shortfall
        train_x = torch.rand(train_size, 1)
        train_y = torch.rand(train_size, 1)
        train_loader = DataLoader(TensorDataset(train_x, train_y), train_batch_size)

        valid_steps, valid_batch_size, valid_shortfall = 10, 15, 3
        valid_size = valid_steps * valid_batch_size - valid_shortfall
        valid_x = torch.rand(valid_size, 1)
        valid_y = torch.rand(valid_size, 1)
        valid_loader = DataLoader(TensorDataset(valid_x, valid_y), valid_batch_size)

        train_mock = IterableMock(train_loader)
        valid_mock = IterableMock(valid_loader)
        self.model.fit_generator(train_mock, valid_mock, epochs=ModelTest.epochs)

        def expected_calls(steps_per_epoch):
            calls_per_epoch = ['__iter__'] + ['__next__'] * steps_per_epoch
            return ['__len__'] + calls_per_epoch * ModelTest.epochs

        self.assertEqual(train_mock.calls, expected_calls(train_steps))
        self.assertEqual(valid_mock.calls, expected_calls(valid_steps))

    def test_fitting_generator_calls_with_longer_validation_set(self):
        # Same call-sequence check as test_fitting_generator_calls, but with a
        # validation loader that has more batches (40) than the training one (30).
        train_steps, train_batch_size, train_shortfall = 30, ModelTest.batch_size, 7
        train_size = train_steps * train_batch_size - train_shortfall
        train_x = torch.rand(train_size, 1)
        train_y = torch.rand(train_size, 1)
        train_loader = DataLoader(TensorDataset(train_x, train_y), train_batch_size)

        valid_steps, valid_batch_size, valid_shortfall = 40, 15, 3
        valid_size = valid_steps * valid_batch_size - valid_shortfall
        valid_x = torch.rand(valid_size, 1)
        valid_y = torch.rand(valid_size, 1)
        valid_loader = DataLoader(TensorDataset(valid_x, valid_y), valid_batch_size)

        train_mock = IterableMock(train_loader)
        valid_mock = IterableMock(valid_loader)
        self.model.fit_generator(train_mock, valid_mock, epochs=ModelTest.epochs)

        def expected_calls(steps_per_epoch):
            # One __len__, then per epoch one __iter__ and one __next__ per batch.
            calls_per_epoch = ['__iter__'] + ['__next__'] * steps_per_epoch
            return ['__len__'] + calls_per_epoch * ModelTest.epochs

        self.assertEqual(train_mock.calls, expected_calls(train_steps))
        self.assertEqual(valid_mock.calls, expected_calls(valid_steps))

    def test_fitting_with_tensor(self):
        # Fit directly from tensors with fit(); no step counts are given, and the
        # expected 'steps' / 'valid_steps' are 30 and 10.
        batch_size = ModelTest.batch_size

        train_steps, train_shortfall = 30, 7
        train_size = train_steps * batch_size - train_shortfall
        train_x = torch.rand(train_size, 1)
        train_y = torch.rand(train_size, 1)

        # The validation data is sized with the training batch size because the
        # fit method uses the same batch size for validation.
        valid_steps, valid_shortfall = 10, 3
        valid_size = valid_steps * batch_size - valid_shortfall
        valid_x = torch.rand(valid_size, 1)
        valid_y = torch.rand(valid_size, 1)

        logs = self.model.fit(train_x,
                              train_y,
                              validation_data=(valid_x, valid_y),
                              epochs=ModelTest.epochs,
                              batch_size=batch_size,
                              steps_per_epoch=None,
                              validation_steps=None,
                              callbacks=[self.mock_callback])

        expected_params = dict(epochs=ModelTest.epochs, steps=train_steps, valid_steps=valid_steps)
        self._test_callbacks_train(expected_params, logs)

    def test_fitting_with_np_array(self):
        # Same scenario as test_fitting_with_tensor, but the data are float32
        # NumPy arrays instead of tensors.
        batch_size = ModelTest.batch_size

        def random_column(n_rows):
            return np.random.rand(n_rows, 1).astype(np.float32)

        train_steps, train_shortfall = 30, 7
        train_size = train_steps * batch_size - train_shortfall
        train_x = random_column(train_size)
        train_y = random_column(train_size)

        # The validation data is sized with the training batch size because the
        # fit method uses the same batch size for validation.
        valid_steps, valid_shortfall = 10, 3
        valid_size = valid_steps * batch_size - valid_shortfall
        valid_x = random_column(valid_size)
        valid_y = random_column(valid_size)

        logs = self.model.fit(train_x,
                              train_y,
                              validation_data=(valid_x, valid_y),
                              epochs=ModelTest.epochs,
                              batch_size=batch_size,
                              steps_per_epoch=None,
                              validation_steps=None,
                              callbacks=[self.mock_callback])

        expected_params = dict(epochs=ModelTest.epochs, steps=train_steps, valid_steps=valid_steps)
        self._test_callbacks_train(expected_params, logs)

    def test_fitting_with_generator_with_len(self):
        # Fit from SomeDataGeneratorWithLen generators without explicit step
        # counts; the expected 'steps' / 'valid_steps' are the lengths the
        # generators are built with.
        train_steps, valid_steps = 30, 10
        train_generator = SomeDataGeneratorWithLen(batch_size=ModelTest.batch_size,
                                                   length=train_steps,
                                                   num_missing_samples=7)
        valid_generator = SomeDataGeneratorWithLen(batch_size=15,
                                                   length=valid_steps,
                                                   num_missing_samples=3)

        logs = self.model.fit_generator(train_generator,
                                        valid_generator,
                                        epochs=ModelTest.epochs,
                                        steps_per_epoch=None,
                                        validation_steps=None,
                                        callbacks=[self.mock_callback])

        expected_params = dict(epochs=ModelTest.epochs, steps=train_steps, valid_steps=valid_steps)
        self._test_callbacks_train(expected_params, logs)

    def test_fitting_with_generator_with_stop_iteration(self):
        # Fit from SomeDataGeneratorUsingStopIteration generators without explicit
        # step counts. The expected callback params carry 'steps': None, while the
        # actual number of train steps is handed to the checker separately.
        train_steps = 30
        train_generator = SomeDataGeneratorUsingStopIteration(batch_size=ModelTest.batch_size,
                                                              length=train_steps)
        valid_generator = SomeDataGeneratorUsingStopIteration(batch_size=15, length=10)

        logs = self.model.fit_generator(train_generator,
                                        valid_generator,
                                        epochs=ModelTest.epochs,
                                        steps_per_epoch=None,
                                        validation_steps=None,
                                        callbacks=[self.mock_callback])

        expected_params = dict(epochs=ModelTest.epochs, steps=None)
        self._test_callbacks_train(expected_params, logs, steps=train_steps)

    def test_tensor_train_on_batch(self):
        # train_on_batch on tensors must return a float loss and an ndarray
        # holding the expected batch metric values.
        batch_shape = (ModelTest.batch_size, 1)
        inputs = torch.rand(*batch_shape)
        targets = torch.rand(*batch_shape)

        loss, metrics = self.model.train_on_batch(inputs, targets)

        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), self.batch_metrics_values)

    def test_train_on_batch_with_pred(self):
        # With return_pred=True, train_on_batch also returns the predictions,
        # which must have one row per sample of the batch.
        batch_shape = (ModelTest.batch_size, 1)
        inputs = torch.rand(*batch_shape)
        targets = torch.rand(*batch_shape)

        loss, metrics, pred_y = self.model.train_on_batch(inputs, targets, return_pred=True)

        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), self.batch_metrics_values)
        self.assertEqual(pred_y.shape, batch_shape)

    def test_ndarray_train_on_batch(self):
        # train_on_batch must accept float32 NumPy arrays and return a float loss
        # plus an ndarray holding the expected batch metric values.
        batch_shape = (ModelTest.batch_size, 1)
        inputs = np.random.rand(*batch_shape).astype(np.float32)
        targets = np.random.rand(*batch_shape).astype(np.float32)

        loss, metrics = self.model.train_on_batch(inputs, targets)

        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), self.batch_metrics_values)

    def test_evaluate(self):
        # evaluate() on tensors must return a float loss and an ndarray holding
        # the batch metric values followed by the epoch metric values.
        data_shape = (ModelTest.evaluate_dataset_len, 1)
        inputs = torch.rand(*data_shape)
        targets = torch.rand(*data_shape)

        loss, metrics = self.model.evaluate(inputs, targets, batch_size=ModelTest.batch_size)

        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), expected_metrics)

    def test_evaluate_with_pred(self):
        # With return_pred=True, evaluate() returns a third value: predictions
        # with one row per evaluated sample.
        data_shape = (ModelTest.evaluate_dataset_len, 1)
        inputs = torch.rand(*data_shape)
        targets = torch.rand(*data_shape)

        # The three-way unpacking is itself part of what is tested.
        _, _, pred_y = self.model.evaluate(inputs,
                                           targets,
                                           batch_size=ModelTest.batch_size,
                                           return_pred=True)

        self.assertEqual(pred_y.shape, data_shape)

    def test_evaluate_with_callback(self):
        # Same as test_evaluate_with_pred but with a callback passed to
        # evaluate(); only the shape of the predictions is asserted here.
        data_shape = (ModelTest.evaluate_dataset_len, 1)
        inputs = torch.rand(*data_shape)
        targets = torch.rand(*data_shape)

        # The three-way unpacking is itself part of what is tested.
        _, _, pred_y = self.model.evaluate(inputs,
                                           targets,
                                           batch_size=ModelTest.batch_size,
                                           return_pred=True,
                                           callbacks=[self.mock_callback])

        self.assertEqual(pred_y.shape, data_shape)

    def test_evaluate_with_return_dict(self):
        # With return_dict_format=True, the result of evaluate() is handed to
        # _test_return_dict_logs for checking.
        data_shape = (ModelTest.evaluate_dataset_len, 1)
        inputs = torch.rand(*data_shape)
        targets = torch.rand(*data_shape)

        logs = self.model.evaluate(inputs,
                                   targets,
                                   batch_size=ModelTest.batch_size,
                                   return_dict_format=True)

        self._test_return_dict_logs(logs)

    def test_evaluate_with_np_array(self):
        # evaluate() must accept float32 NumPy arrays; checks the loss type, the
        # metric values and the shape of the returned predictions.
        data_shape = (ModelTest.evaluate_dataset_len, 1)
        inputs = np.random.rand(*data_shape).astype(np.float32)
        targets = np.random.rand(*data_shape).astype(np.float32)

        loss, metrics, pred_y = self.model.evaluate(inputs,
                                                    targets,
                                                    batch_size=ModelTest.batch_size,
                                                    return_pred=True)

        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), expected_metrics)
        self.assertEqual(pred_y.shape, data_shape)

    def test_evaluate_data_loader(self):
        # evaluate_generator() on a DataLoader, with no explicit step count:
        # checks the loss type, the metric values and that one prediction row is
        # returned per sample of the dataset.
        data_shape = (ModelTest.evaluate_dataset_len, 1)
        inputs = torch.rand(*data_shape)
        targets = torch.rand(*data_shape)
        loader = DataLoader(TensorDataset(inputs, targets), ModelTest.batch_size)

        loss, metrics, pred_y = self.model.evaluate_generator(loader, return_pred=True)

        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), expected_metrics)
        self.assertEqual(pred_y.shape, data_shape)

    def test_evaluate_generator(self):
        # evaluate_generator() over a fixed number of steps: the predictions must
        # be an ndarray with one row per sample seen (steps * batch size).
        num_steps = 10
        generator = some_data_tensor_generator(ModelTest.batch_size)

        loss, metrics, pred_y = self.model.evaluate_generator(generator,
                                                              steps=num_steps,
                                                              return_pred=True)

        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        expected_pred_shape = (num_steps * ModelTest.batch_size, 1)
        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), expected_metrics)
        self.assertIs(type(pred_y), np.ndarray)
        self.assertEqual(pred_y.shape, expected_pred_shape)

    def test_evaluate_generator_with_stop_iteration(self):
        # No `steps` argument is passed to evaluate_generator() here; the
        # generator is presumably finite (as its name suggests) -- verify
        # against SomeDataGeneratorUsingStopIteration's definition.
        finite_generator = SomeDataGeneratorUsingStopIteration(
            ModelTest.batch_size, 10)

        # Unpacking into exactly two values is part of what is being checked.
        loss, _ = self.model.evaluate_generator(finite_generator)

        self.assertIs(type(loss), float)

    def test_evaluate_generator_with_callback(self):
        # Callbacks handed to evaluate_generator() are checked through
        # _test_callbacks_test() against the number of evaluation steps.
        num_steps = 10
        generator = some_data_tensor_generator(ModelTest.batch_size)
        self.model.evaluate_generator(generator,
                                      steps=num_steps,
                                      callbacks=[self.mock_callback])

        # The expected step count is the `steps` value requested above, not
        # the unrelated number of training epochs (which could only match by
        # coincidence).
        params = {'steps': num_steps}
        self._test_callbacks_test(params)

    def test_evaluate_generator_with_return_dict(self):
        # With return_dict_format=True the result of evaluate_generator() is
        # validated by the shared _test_return_dict_logs() helper.
        num_steps = 10
        batches = some_data_tensor_generator(ModelTest.batch_size)

        returned_logs = self.model.evaluate_generator(
            batches, steps=num_steps, return_dict_format=True)

        self._test_return_dict_logs(returned_logs)

    def test_evaluate_generator_with_ground_truth(self):
        # Requesting both predictions and ground truths must append two
        # ndarrays, each with steps * batch_size rows, to the usual
        # (loss, metrics) result.
        num_steps = 10
        batches = some_data_tensor_generator(ModelTest.batch_size)

        loss, metrics, pred_y, true_y = self.model.evaluate_generator(
            batches,
            steps=num_steps,
            return_pred=True,
            return_ground_truth=True)

        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertEqual(metrics.tolist(), expected_metrics)
        self.assertIs(type(pred_y), np.ndarray)
        self.assertIs(type(true_y), np.ndarray)
        expected_shape = (num_steps * ModelTest.batch_size, 1)
        self.assertEqual(pred_y.shape, expected_shape)
        self.assertEqual(true_y.shape, expected_shape)

    def test_evaluate_generator_with_no_concatenation(self):
        # With concatenate_returns=False the predictions and ground truths
        # must come back as lists of per-batch ndarrays rather than as one
        # concatenated array.
        num_steps = 10
        batches = some_data_tensor_generator(ModelTest.batch_size)

        loss, metrics, pred_y, true_y = self.model.evaluate_generator(
            batches,
            steps=num_steps,
            return_pred=True,
            return_ground_truth=True,
            concatenate_returns=False)

        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        expected_metrics = self.batch_metrics_values + self.epoch_metrics_values
        self.assertEqual(metrics.tolist(), expected_metrics)

        # Predictions are checked first, then ground truths.
        batch_shape = (ModelTest.batch_size, 1)
        for per_batch_arrays in (pred_y, true_y):
            self.assertIs(type(per_batch_arrays), list)
            for batch_array in per_batch_arrays:
                self.assertIs(type(batch_array), np.ndarray)
                self.assertEqual(batch_array.shape, batch_shape)

    def test_evaluate_with_only_one_metric(self):
        # With a single batch metric configured, evaluate() must return the
        # metric as a plain float next to the loss instead of an array.
        single_metric = self.batch_metrics[:1]
        model = Model(self.pytorch_network,
                      self.optimizer,
                      self.loss_function,
                      batch_metrics=single_metric)
        n_samples = ModelTest.evaluate_dataset_len
        x = torch.rand(n_samples, 1)
        y = torch.rand(n_samples, 1)

        loss, first_metric = model.evaluate(
            x, y, batch_size=ModelTest.batch_size)

        self.assertIs(type(loss), float)
        self.assertIs(type(first_metric), float)
        self.assertEqual(first_metric, some_metric_1_value)

    def test_metrics_integration(self):
        # Trains and then evaluates a model whose only batch metric is
        # F.mse_loss; the evaluation must return two plain floats.
        num_steps = 10
        model = Model(self.pytorch_network,
                      self.optimizer,
                      self.loss_function,
                      batch_metrics=[F.mse_loss])
        train_batches = some_data_tensor_generator(ModelTest.batch_size)
        valid_batches = some_data_tensor_generator(ModelTest.batch_size)
        steps_per_epoch = ModelTest.steps_per_epoch
        model.fit_generator(train_batches,
                            valid_batches,
                            epochs=ModelTest.epochs,
                            steps_per_epoch=steps_per_epoch,
                            validation_steps=steps_per_epoch,
                            callbacks=[self.mock_callback])

        eval_batches = some_data_tensor_generator(ModelTest.batch_size)
        loss, mse = model.evaluate_generator(eval_batches, steps=num_steps)

        self.assertIs(type(loss), float)
        self.assertIs(type(mse), float)

    def test_epoch_metrics_integration(self):
        # After a one-epoch fit with SomeEpochMetric as the only epoch
        # metric, the last log entry must report the value 5 for both the
        # training and the validation metric.
        model = Model(self.pytorch_network,
                      self.optimizer,
                      self.loss_function,
                      epoch_metrics=[SomeEpochMetric()])
        train_batches = some_data_tensor_generator(ModelTest.batch_size)
        valid_batches = some_data_tensor_generator(ModelTest.batch_size)
        steps_per_epoch = ModelTest.steps_per_epoch

        logs = model.fit_generator(train_batches,
                                   valid_batches,
                                   epochs=1,
                                   steps_per_epoch=steps_per_epoch,
                                   validation_steps=steps_per_epoch)

        last_epoch_log = logs[-1]
        train_value = last_epoch_log['some_epoch_metric']
        valid_value = last_epoch_log['val_some_epoch_metric']
        expected_value = 5
        self.assertEqual(valid_value, expected_value)
        self.assertEqual(train_value, expected_value)

    def test_evaluate_with_no_metric(self):
        # A model built without any metric must make evaluate() return the
        # loss alone, as a float.
        metricless_model = Model(self.pytorch_network, self.optimizer,
                                 self.loss_function)
        n_samples = ModelTest.evaluate_dataset_len
        x = torch.rand(n_samples, 1)
        y = torch.rand(n_samples, 1)

        loss = metricless_model.evaluate(x, y, batch_size=ModelTest.batch_size)

        self.assertIs(type(loss), float)

    def test_tensor_evaluate_on_batch(self):
        # evaluate_on_batch() on tensors returns a float loss and an ndarray
        # holding only the batch metric values.
        n_rows = ModelTest.batch_size
        x = torch.rand(n_rows, 1)
        y = torch.rand(n_rows, 1)

        loss, metrics = self.model.evaluate_on_batch(x, y)

        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), self.batch_metrics_values)

    def test_evaluate_on_batch_with_pred(self):
        # With return_pred=True, evaluate_on_batch() additionally returns the
        # predictions, one row per sample of the batch.
        n_rows = ModelTest.batch_size
        x = torch.rand(n_rows, 1)
        y = torch.rand(n_rows, 1)

        loss, metrics, pred_y = self.model.evaluate_on_batch(
            x, y, return_pred=True)

        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), self.batch_metrics_values)
        self.assertEqual(pred_y.shape, (n_rows, 1))

    def test_ndarray_evaluate_on_batch(self):
        # evaluate_on_batch() must accept float32 NumPy arrays and return a
        # float loss plus an ndarray of the batch metric values.
        n_rows = ModelTest.batch_size
        x = np.random.rand(n_rows, 1).astype(np.float32)
        y = np.random.rand(n_rows, 1).astype(np.float32)

        loss, metrics = self.model.evaluate_on_batch(x, y)

        self.assertIs(type(loss), float)
        self.assertIs(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), self.batch_metrics_values)

    def test_predict(self):
        # predict() on a tensor must return one prediction row per input row.
        n_samples = ModelTest.evaluate_dataset_len
        x = torch.rand(n_samples, 1)

        pred_y = self.model.predict(x, batch_size=ModelTest.batch_size)

        self.assertEqual(pred_y.shape, (n_samples, 1))

    def test_predict_with_np_array(self):
        # predict() on a float32 NumPy array must return an ndarray with one
        # prediction row per input row.
        n_samples = ModelTest.evaluate_dataset_len
        x = np.random.rand(n_samples, 1).astype(np.float32)

        pred_y = self.model.predict(x, batch_size=ModelTest.batch_size)

        self.assertIs(type(pred_y), np.ndarray)
        self.assertEqual(pred_y.shape, (n_samples, 1))

    def test_predict_data_loader(self):
        # predict_generator() fed by a DataLoader built directly on a tensor
        # must return an ndarray with one prediction row per sample.
        n_samples = ModelTest.evaluate_dataset_len
        x = torch.rand(n_samples, 1)
        loader = DataLoader(x, ModelTest.batch_size)

        pred_y = self.model.predict_generator(loader)

        self.assertIs(type(pred_y), np.ndarray)
        self.assertEqual(pred_y.shape, (n_samples, 1))

    def test_predict_generator(self):
        # predict_generator() limited to `steps` batches must return an
        # ndarray holding steps * batch_size prediction rows.
        num_steps = 10
        labelled_batches = some_data_tensor_generator(ModelTest.batch_size)
        # Prediction only needs the inputs, so the targets are dropped lazily.
        inputs_only = (inputs for inputs, _ in labelled_batches)

        pred_y = self.model.predict_generator(inputs_only, steps=num_steps)

        self.assertIs(type(pred_y), np.ndarray)
        expected_shape = (num_steps * ModelTest.batch_size, 1)
        self.assertEqual(pred_y.shape, expected_shape)

    def test_predict_generator_with_no_concatenation(self):
        # With concatenate_returns=False, predict_generator() must return a
        # list of per-batch ndarrays instead of a single concatenated array.
        num_steps = 10
        labelled_batches = some_data_tensor_generator(ModelTest.batch_size)
        # Prediction only needs the inputs, so the targets are dropped lazily.
        inputs_only = (inputs for inputs, _ in labelled_batches)

        pred_y = self.model.predict_generator(inputs_only,
                                              steps=num_steps,
                                              concatenate_returns=False)

        self.assertIs(type(pred_y), list)
        batch_shape = (ModelTest.batch_size, 1)
        for batch_pred in pred_y:
            self.assertIs(type(batch_pred), np.ndarray)
            self.assertEqual(batch_pred.shape, batch_shape)

    def test_tensor_predict_on_batch(self):
        # predict_on_batch() on a tensor returns one prediction row per
        # sample of the batch.
        n_rows = ModelTest.batch_size
        batch = torch.rand(n_rows, 1)

        pred_y = self.model.predict_on_batch(batch)

        self.assertEqual(pred_y.shape, (n_rows, 1))

    def test_ndarray_predict_on_batch(self):
        # predict_on_batch() on a float32 NumPy array returns one prediction
        # row per sample of the batch.
        n_rows = ModelTest.batch_size
        batch = np.random.rand(n_rows, 1).astype(np.float32)

        pred_y = self.model.predict_on_batch(batch)

        self.assertEqual(pred_y.shape, (n_rows, 1))

    @skipIf(not torch.cuda.is_available(), "no gpu available")
    def test_cpu_cuda(self):
        # Moves the model between GPU and CPU with cuda(), cpu() and to(),
        # running the same fit after every move. From the second move on,
        # the device is also verified with _test_device() before fitting.
        train_generator = some_data_tensor_generator(ModelTest.batch_size)
        valid_generator = some_data_tensor_generator(ModelTest.batch_size)

        def fit_once():
            # Every stage trains with exactly the same arguments and keeps
            # consuming the same two generators.
            self.model.fit_generator(
                train_generator,
                valid_generator,
                epochs=ModelTest.epochs,
                steps_per_epoch=ModelTest.steps_per_epoch,
                validation_steps=ModelTest.steps_per_epoch,
                callbacks=[self.mock_callback])

        self._capture_output()

        gpu_index = ModelTest.cuda_device
        gpu = torch.device('cuda:' + str(gpu_index))
        cpu = torch.device('cpu')

        # Stage 1: cuda() without an explicit index, inside the device context.
        with torch.cuda.device(gpu_index):
            self.model.cuda()
            fit_once()

        # The context manager is also used here because of this bug:
        # https://github.com/pytorch/pytorch/issues/7320
        with torch.cuda.device(gpu_index):
            # Stage 2: cuda() with an explicit device index.
            self.model.cuda(gpu_index)
            self._test_device(gpu)
            fit_once()

            # Stage 3: back to the CPU through cpu().
            self.model.cpu()
            self._test_device(cpu)
            fit_once()

            # Stage 4: to the GPU through to().
            self.model.to(gpu)
            self._test_device(gpu)
            fit_once()

            # Stage 5: back to the CPU through to().
            self.model.to(cpu)
            self._test_device(cpu)
            fit_once()

    def test_get_batch_size(self):
        # Probes get_batch_size() with bare arrays, tuples, lists, dicts and
        # OrderedDicts whose members have different lengths, and with dicts
        # that carry an explicit 'batch_size' entry.
        batch_size = ModelTest.batch_size
        x = np.random.rand(batch_size, 1).astype(np.float32)
        y = np.random.rand(batch_size, 1).astype(np.float32)

        batch_size2 = ModelTest.batch_size + 1
        x2 = np.random.rand(batch_size2, 1).astype(np.float32)
        y2 = np.random.rand(batch_size2, 1).astype(np.float32)

        other_batch_size = batch_size2 + 1

        # Each entry is (x argument, y argument, expected batch size); the
        # entries are evaluated in the order listed.
        cases = [
            # Bare arrays.
            (x, y, batch_size),
            (x2, y2, batch_size2),
            (x, y2, batch_size),
            (x2, y, batch_size2),
            # Tuples.
            ((x, x2), y, batch_size),
            ((x2, x), y, batch_size),
            ((x, x2), (y, y2), batch_size),
            ((x2, x), (y, y2), batch_size2),
            # Lists.
            ([x, x2], y, batch_size),
            ([x2, x], y, batch_size),
            ([x, x2], [y, y2], batch_size),
            ([x2, x], [y, y2], batch_size2),
            # Dicts carrying an explicit 'batch_size' entry.
            ({'batch_size': other_batch_size, 'x': x}, {'y': y},
             other_batch_size),
            ({'x': x}, {'batch_size': other_batch_size, 'y': y},
             other_batch_size),
            # Dicts without an explicit entry.
            ({'x': x}, {'y': y}, batch_size),
            (OrderedDict([('x1', x), ('x2', x2)]), {'y': y}, batch_size),
            (OrderedDict([('x1', x2), ('x2', x)]), {'y': y}, batch_size2),
            # Non-array x combined with a dict y.
            ([1, 2, 3], {'y': y}, batch_size),
        ]

        for x_arg, y_arg, expected in cases:
            inf_batch_size = self.model.get_batch_size(x_arg, y_arg)
            self.assertEqual(inf_batch_size, expected)

    def test_get_batch_size_warning(self):
        # When neither argument holds anything to infer a size from,
        # get_batch_size() must return 1 and emit exactly one warning, unless
        # warning_settings['batch_size'] is set to 'ignore'.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            inf_batch_size = self.model.get_batch_size([1, 2, 3], [4, 5, 6])
            self.assertEqual(inf_batch_size, 1)
            self.assertEqual(len(w), 1)

        # warning_settings is module-level state shared by every test, so the
        # previous value is remembered and put back even if an assertion
        # below fails; otherwise 'ignore' would leak into later tests.
        missing = object()
        previous_setting = warning_settings.get('batch_size', missing)
        try:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                warning_settings['batch_size'] = 'ignore'
                inf_batch_size = self.model.get_batch_size([1, 2, 3],
                                                           [4, 5, 6])
                self.assertEqual(inf_batch_size, 1)
                self.assertEqual(len(w), 0)
        finally:
            if previous_setting is missing:
                warning_settings.pop('batch_size', None)
            else:
                warning_settings['batch_size'] = previous_setting
class ModelMultiInputTest(ModelFittingTestCase):
    def setUp(self):
        # Builds the fixtures shared by the multi-input tests: a seeded
        # MultiIOModel network, an MSE loss, an SGD optimizer and the Model
        # wrapping them together with the inherited metrics.
        super().setUp()
        # Seed before the network is created so that its construction is
        # reproducible from one test to the next.
        torch.manual_seed(42)

        network = MultiIOModel(num_input=1, num_output=1)
        self.pytorch_network = network
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(network.parameters(), lr=1e-3)

        self.model = Model(
            network,
            self.optimizer,
            self.loss_function,
            batch_metrics=self.batch_metrics,
            epoch_metrics=self.epoch_metrics,
        )

    def test_fitting_tensor_generator_multi_input(self):
        # Fits from multi-input generators with explicit step counts and
        # checks the callback calls and logs through _test_callbacks_train().
        batch_size = ModelMultiInputTest.batch_size
        steps = ModelMultiInputTest.steps_per_epoch
        n_epochs = ModelMultiInputTest.epochs
        train_batches = some_data_tensor_generator_multi_input(batch_size)
        valid_batches = some_data_tensor_generator_multi_input(batch_size)

        logs = self.model.fit_generator(
            train_batches,
            valid_batches,
            epochs=n_epochs,
            steps_per_epoch=steps,
            validation_steps=steps,
            callbacks=[self.mock_callback],
        )

        expected_params = {
            'epochs': n_epochs,
            'steps': steps,
            'valid_steps': steps,
        }
        self._test_callbacks_train(expected_params, logs, valid_steps=steps)

    def test_fitting_with_tensor_multi_input(self):
        # Fits on in-memory tensors with a pair of inputs. Both dataset sizes
        # are deliberately not multiples of the batch size, so the last batch
        # of each epoch is incomplete.
        train_real_steps_per_epoch = 30
        train_batch_size = ModelMultiInputTest.batch_size
        train_final_batch_missing_samples = 7
        train_full_size = train_real_steps_per_epoch * train_batch_size
        train_size = train_full_size - train_final_batch_missing_samples
        train_first_input = torch.rand(train_size, 1)
        train_second_input = torch.rand(train_size, 1)
        train_x = (train_first_input, train_second_input)
        train_y = torch.rand(train_size, 1)

        valid_real_steps_per_epoch = 10
        # valid_batch_size will be the same as train_batch_size in the fit method.
        valid_batch_size = train_batch_size
        valid_final_batch_missing_samples = 3
        valid_full_size = valid_real_steps_per_epoch * valid_batch_size
        valid_size = valid_full_size - valid_final_batch_missing_samples
        valid_first_input = torch.rand(valid_size, 1)
        valid_second_input = torch.rand(valid_size, 1)
        valid_x = (valid_first_input, valid_second_input)
        valid_y = torch.rand(valid_size, 1)

        n_epochs = ModelMultiInputTest.epochs
        logs = self.model.fit(
            train_x,
            train_y,
            validation_data=(valid_x, valid_y),
            epochs=n_epochs,
            batch_size=train_batch_size,
            steps_per_epoch=None,
            validation_steps=None,
            callbacks=[self.mock_callback],
        )

        expected_params = {
            'epochs': n_epochs,
            'steps': train_real_steps_per_epoch,
            'valid_steps': valid_real_steps_per_epoch,
        }
        self._test_callbacks_train(expected_params, logs)

    def test_tensor_train_on_batch_multi_input(self):
        # train_on_batch() with a tuple of two input tensors returns the loss
        # as a float.
        n_rows = ModelMultiInputTest.batch_size
        x1 = torch.rand(n_rows, 1)
        x2 = torch.rand(n_rows, 1)
        y = torch.rand(n_rows, 1)

        loss = self.model.train_on_batch((x1, x2), y)

        self.assertIs(type(loss), float)

    def test_train_on_batch_with_pred_multi_input(self):
        # With return_pred=True, train_on_batch() on a pair of inputs also
        # returns the predictions, one row per sample of the batch.
        n_rows = ModelMultiInputTest.batch_size
        x1 = torch.rand(n_rows, 1)
        x2 = torch.rand(n_rows, 1)
        y = torch.rand(n_rows, 1)

        loss, pred_y = self.model.train_on_batch((x1, x2), y, return_pred=True)

        self.assertIs(type(loss), float)
        self.assertEqual(pred_y.shape, (n_rows, 1))

    def test_ndarray_train_on_batch_multi_input(self):
        # train_on_batch() must accept a tuple of float32 NumPy arrays and
        # return the loss as a float.
        n_rows = ModelMultiInputTest.batch_size
        x1 = np.random.rand(n_rows, 1).astype(np.float32)
        x2 = np.random.rand(n_rows, 1).astype(np.float32)
        y = np.random.rand(n_rows, 1).astype(np.float32)

        loss = self.model.train_on_batch((x1, x2), y)

        self.assertIs(type(loss), float)

    def test_evaluate_multi_input(self):
        # evaluate() with a tuple of two input tensors; only the type of the
        # returned loss is checked.
        n_samples = ModelMultiInputTest.evaluate_dataset_len
        first_input = torch.rand(n_samples, 1)
        second_input = torch.rand(n_samples, 1)
        x = (first_input, second_input)
        y = torch.rand(n_samples, 1)

        loss = self.model.evaluate(
            x, y, batch_size=ModelMultiInputTest.batch_size)

        self.assertIs(type(loss), float)

    def test_evaluate_with_pred_multi_input(self):
        # evaluate() with a pair of inputs and return_pred=True must return
        # one prediction row per evaluated sample.
        n_samples = ModelMultiInputTest.evaluate_dataset_len
        first_input = torch.rand(n_samples, 1)
        second_input = torch.rand(n_samples, 1)
        x = (first_input, second_input)
        y = torch.rand(n_samples, 1)

        # We also test the unpacking.
        _, pred_y = self.model.evaluate(
            x, y, batch_size=ModelMultiInputTest.batch_size, return_pred=True)

        self.assertEqual(pred_y.shape, (n_samples, 1))

    def test_evaluate_with_np_array_multi_input(self):
        # evaluate() with a tuple of float32 NumPy inputs and
        # return_pred=True must return a float loss and one prediction row
        # per evaluated sample.
        n_samples = ModelMultiInputTest.evaluate_dataset_len
        x1 = np.random.rand(n_samples, 1).astype(np.float32)
        x2 = np.random.rand(n_samples, 1).astype(np.float32)
        x = (x1, x2)
        y = np.random.rand(n_samples, 1).astype(np.float32)

        loss, pred_y = self.model.evaluate(
            x, y, batch_size=ModelMultiInputTest.batch_size, return_pred=True)

        self.assertIs(type(loss), float)
        self.assertEqual(pred_y.shape, (n_samples, 1))

    def test_evaluate_data_loader_multi_input(self):
        # evaluate_generator() fed by a DataLoader over a TensorDataset whose
        # first element is a tuple of two input tensors.
        # NOTE(review): passing a tuple relies on the TensorDataset imported
        # by this file accepting nested tuples -- confirm which one it is.
        n_samples = ModelMultiInputTest.evaluate_dataset_len
        x1 = torch.rand(n_samples, 1)
        x2 = torch.rand(n_samples, 1)
        y = torch.rand(n_samples, 1)
        loader = DataLoader(TensorDataset((x1, x2), y),
                            ModelMultiInputTest.batch_size)

        loss, pred_y = self.model.evaluate_generator(loader, return_pred=True)

        self.assertIs(type(loss), float)
        self.assertEqual(pred_y.shape, (n_samples, 1))

    def test_evaluate_generator_multi_input(self):
        # evaluate_generator() on a multi-input generator limited to `steps`
        # batches must return a float loss and steps * batch_size prediction
        # rows.
        num_steps = 10
        batch_size = ModelMultiInputTest.batch_size
        batches = some_data_tensor_generator_multi_input(batch_size)

        loss, pred_y = self.model.evaluate_generator(
            batches, steps=num_steps, return_pred=True)

        self.assertIs(type(loss), float)
        self.assertEqual(pred_y.shape, (num_steps * batch_size, 1))

    def test_tensor_evaluate_on_batch_multi_input(self):
        # evaluate_on_batch() with a tuple of two input tensors; only the
        # type of the returned loss is checked.
        n_rows = ModelMultiInputTest.batch_size
        x1 = torch.rand(n_rows, 1)
        x2 = torch.rand(n_rows, 1)
        y = torch.rand(n_rows, 1)

        loss = self.model.evaluate_on_batch((x1, x2), y)

        self.assertIs(type(loss), float)

    def test_predict_multi_input(self):
        # predict() with a tuple of two input tensors must return one
        # prediction row per input row.
        n_samples = ModelMultiInputTest.evaluate_dataset_len
        first_input = torch.rand(n_samples, 1)
        second_input = torch.rand(n_samples, 1)
        x = (first_input, second_input)

        pred_y = self.model.predict(
            x, batch_size=ModelMultiInputTest.batch_size)

        self.assertEqual(pred_y.shape, (n_samples, 1))

    def test_predict_with_np_array_multi_input(self):
        # predict() with a tuple of two float32 NumPy inputs must return one
        # prediction row per input row.
        n_samples = ModelMultiInputTest.evaluate_dataset_len
        x1 = np.random.rand(n_samples, 1).astype(np.float32)
        x2 = np.random.rand(n_samples, 1).astype(np.float32)
        x = (x1, x2)

        pred_y = self.model.predict(
            x, batch_size=ModelMultiInputTest.batch_size)

        self.assertEqual(pred_y.shape, (n_samples, 1))

    def test_predict_generator_multi_input(self):
        # predict_generator() on a multi-input generator limited to `steps`
        # batches must return an ndarray with steps * batch_size rows.
        num_steps = 10
        batch_size = ModelMultiInputTest.batch_size
        labelled_batches = some_data_tensor_generator_multi_input(batch_size)
        # Prediction only needs the inputs, so the targets are dropped lazily.
        inputs_only = (inputs for inputs, _ in labelled_batches)

        pred_y = self.model.predict_generator(inputs_only, steps=num_steps)

        self.assertIs(type(pred_y), np.ndarray)
        self.assertEqual(pred_y.shape, (num_steps * batch_size, 1))

    def test_tensor_predict_on_batch_multi_input(self):
        # predict_on_batch() with a tuple of two input tensors returns one
        # prediction row per sample of the batch.
        n_rows = ModelMultiInputTest.batch_size
        x1 = torch.rand(n_rows, 1)
        x2 = torch.rand(n_rows, 1)

        pred_y = self.model.predict_on_batch((x1, x2))

        self.assertEqual(pred_y.shape, (n_rows, 1))