def test_augmentation_is_setup_correctly_vgg_model_is_used(
        self, mock_data_generators):
    # Checks which kwargs train_model forwards to the data generator
    # factory when augmentation is switched off but the model type is
    # utils.VGG_ARCHITECTURE.
    # `mock_data_generators` is presumably injected by a patch decorator
    # that is outside this view.
    # NOTE(review): augment is False yet AUGMENTATION_KWARGS is expected;
    # presumably the VGG architecture forces these kwargs -- confirm
    # against utils.train_model.
    self.args.model_type = utils.VGG_ARCHITECTURE
    self.args.augment = False

    utils.train_model(self.mock_model, self.args)

    expected_kwargs = utils.AUGMENTATION_KWARGS
    mock_data_generators.assert_called_once_with(self.args, expected_kwargs)
def test_csv_logger_callback_is_setup(self, mock_get_log_name,
                                      mock_csv_logger,
                                      mock_get_generators):
    # Checks that train_model builds the CSV logger from the log name and
    # hands it to fit_generator through the `callbacks` keyword.
    # The three mocks are presumably injected by patch decorators that are
    # outside this view.
    utils.train_model(self.mock_model, self.args)

    # The logger must be constructed exactly once, from the log name mock's
    # return value.
    mock_csv_logger.assert_called_once_with(mock_get_log_name.return_value)

    # call_args[1] holds the keyword arguments of the last fit_generator
    # call.
    fit_kwargs = self.mock_model.fit_generator.call_args[1]
    self.assertIn(mock_csv_logger.return_value, fit_kwargs['callbacks'])
def test_fit_generator_is_called(self, mock_get_generators):
    # Checks that train_model calls fit_generator exactly once, with the
    # training generator as first positional argument and the validation
    # generator and epoch count as keywords.
    # The generator names compared below are presumably the values the
    # patched `mock_get_generators` is configured to return -- that setup
    # is outside this view.
    utils.train_model(self.mock_model, self.args)

    fit_generator = self.mock_model.fit_generator
    fit_generator.assert_called_once()

    # call_args unpacks into (positional arguments, keyword arguments).
    positional, keyword = fit_generator.call_args
    self.assertEqual(positional[0], 'train_data_generator')
    self.assertEqual(keyword['validation_data'], 'val_data_generator')
    self.assertEqual(keyword['epochs'], self.args.epochs)
def test_model_is_saved(self, mock_storage_name, mock_get_generators):
    # Checks that train_model saves the model exactly once, passing the
    # value returned by the patched storage-name helper.
    # Both mocks are presumably injected by patch decorators that are
    # outside this view.
    utils.train_model(self.mock_model, self.args)

    expected_target = mock_storage_name.return_value
    self.mock_model.save.assert_called_once_with(expected_target)
def test_augmentation_is_not_setup_when_not_required_and_inceptionv3_model_is_used(
        self, mock_data_generators):
    # Checks that with augmentation switched off, train_model calls the
    # data generator factory with an empty kwargs dict.
    # model_type is left as the fixture provides it; presumably that is
    # the InceptionV3 architecture named in the test -- confirm in setUp.
    self.args.augment = False

    utils.train_model(self.mock_model, self.args)

    no_augmentation = {}
    mock_data_generators.assert_called_once_with(self.args, no_augmentation)
def test_augmentation_is_setup_correctly_when_required(
        self, mock_data_generators):
    # Checks that with augmentation switched on, train_model calls the
    # data generator factory with utils.AUGMENTATION_KWARGS.
    # `mock_data_generators` is presumably injected by a patch decorator
    # that is outside this view.
    self.args.augment = True

    utils.train_model(self.mock_model, self.args)

    expected_kwargs = utils.AUGMENTATION_KWARGS
    mock_data_generators.assert_called_once_with(self.args, expected_kwargs)