def test_print_summary_without_print_fn(self):
    """`print_summary` writes to stdout when no `print_fn` is passed.

    Interactive logging is enabled so the summary goes to stdout, and the
    captured text must contain the row for the `dense` layer.
    """
    model = keras.Sequential(
        [keras.layers.Dense(5, input_shape=(10,), name='dense')]
    )
    # Put the interactive-logging setting back the way we found it once the
    # test finishes (pass or fail) so it does not leak into other tests.
    self.addCleanup(
        io_utils.enable_interactive_logging
        if io_utils.is_interactive_logging_enabled()
        else io_utils.disable_interactive_logging
    )
    io_utils.enable_interactive_logging()
    with self.captureWritesToStream(sys.stdout) as printed:
        layer_utils.print_summary(model)
    self.assertIn('dense (Dense)', printed.contents())
def test_saving_after_fit(self):
    """A subclassed model saved after `fit` reloads with its custom parts.

    Trains a subclassed model for one epoch, saves it with `_save_new`,
    reloads it with `saving_lib.load`, and checks that the reloaded model
    runs the custom training step, is an instance of the custom classes,
    exposes the custom methods, and has the same optimizer/loss classes as
    the original model.
    """
    # `train_step_message`, `CustomModelX` and `MyDense` are not defined in
    # this method; presumably they live at module level in this test file.
    temp_dir = os.path.join(self.get_temp_dir(), "my_model")
    subclassed_model = self._get_subclassed_model()

    x = np.random.random((100, 32))
    y = np.random.random((100, 1))
    subclassed_model.fit(x, y, epochs=1)
    subclassed_model._save_new(temp_dir)
    loaded_model = saving_lib.load(temp_dir)

    # Restore the interactive-logging setting after the test (pass or fail)
    # so it does not leak into other tests.
    self.addCleanup(
        io_utils.enable_interactive_logging
        if io_utils.is_interactive_logging_enabled()
        else io_utils.disable_interactive_logging
    )
    io_utils.enable_interactive_logging()
    # `tf.print` writes to stderr. This is to make sure the custom training
    # step is used.
    with self.captureWritesToStream(sys.stderr) as printed:
        loaded_model.fit(x, y, epochs=1)
    self.assertRegex(printed.contents(), train_step_message)

    # Check that the custom classes do get used.
    self.assertIsInstance(loaded_model, CustomModelX)
    self.assertIsInstance(loaded_model.dense1, MyDense)
    # Check that the custom method is available.
    self.assertEqual(loaded_model.one(), 1)
    self.assertEqual(loaded_model.dense1.two(), 2)

    # Expected class of each entry in `compiled_loss._losses`, by position.
    expected_loss_classes = [
        keras.losses.LossFunctionWrapper,
        keras.losses.LossFunctionWrapper,
        keras.losses.MeanSquaredError,
        keras.losses.LossFunctionWrapper,
    ]
    # Everything should be the same class or function for the original model
    # and the loaded model.
    for model in [subclassed_model, loaded_model]:
        self.assertIs(
            model.optimizer.__class__,
            keras.optimizers.optimizer_v2.adam.Adam,
        )
        self.assertIs(
            model.compiled_loss.__class__,
            keras.engine.compile_utils.LossesContainer,
        )
        for index, expected_class in enumerate(expected_loss_classes):
            self.assertIs(
                model.compiled_loss._losses[index].__class__,
                expected_class,
                msg=f"compiled_loss._losses[{index}]",
            )
        self.assertIs(
            model.compiled_loss._total_loss_mean.__class__,
            keras.metrics.base_metric.Mean,
        )
def test_custom_object_scope_correct_class(self):
    """Loading inside `custom_object_scope` revives the custom model class.

    A locally defined `keras.Model` subclass is trained, saved in the "tf"
    format, and reloaded with `CustomModelX` registered in a
    `custom_object_scope`. The reloaded model must be a `CustomModelX`,
    expose its extra method, and (under TF2) run the custom `train_step`.
    """
    train_step_message = "This is my training step"
    temp_dir = os.path.join(self.get_temp_dir(), "my_model")

    class CustomModelX(keras.Model):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.dense1 = keras.layers.Dense(1)

        def call(self, inputs):
            return self.dense1(inputs)

        def train_step(self, data):
            # Marker printed on every step so the test can detect that this
            # `train_step` (not the default one) is running.
            tf.print(train_step_message)
            x, y = data
            with tf.GradientTape() as tape:
                y_pred = self(x)
                loss = self.compiled_loss(y, y_pred)

            gradients = tape.gradient(loss, self.trainable_variables)
            self.optimizer.apply_gradients(
                zip(gradients, self.trainable_variables)
            )
            return {}

        def func_that_returns_one(self):
            return 1

    subclassed_model = CustomModelX()
    subclassed_model.compile(optimizer="adam", loss="mse")

    x = np.random.random((100, 32))
    y = np.random.random((100, 1))
    subclassed_model.fit(x, y, epochs=1)
    subclassed_model.save(temp_dir, save_format="tf")

    with keras.utils.generic_utils.custom_object_scope(
        {"CustomModelX": CustomModelX}
    ):
        loaded_model = keras.models.load_model(temp_dir)

    # Restore the interactive-logging setting after the test (pass or fail)
    # so it does not leak into other tests.
    self.addCleanup(
        io_utils.enable_interactive_logging
        if io_utils.is_interactive_logging_enabled()
        else io_utils.disable_interactive_logging
    )
    io_utils.enable_interactive_logging()
    # `tf.print` writes to stderr.
    with self.captureWritesToStream(sys.stderr) as printed:
        loaded_model.fit(x, y, epochs=1)
    if tf.__internal__.tf2.enabled():
        # `tf.print` message is only available in stderr in TF2. Check that
        # custom `train_step` is used.
        self.assertRegex(printed.contents(), train_step_message)

    # Check that the custom class does get used.
    self.assertIsInstance(loaded_model, CustomModelX)
    # Check that the custom method is available.
    self.assertEqual(loaded_model.func_that_returns_one(), 1)
def test_finite_dataset_unknown_cardinality_no_step_with_train_and_val(
    self,
):
    """`fit` works on an unknown-cardinality dataset with validation data.

    No `steps_per_epoch` is passed and the dataset's cardinality is unknown.
    The test checks the final progress-bar line, the history length, the
    callback batch counts, and that `evaluate`/`predict` run on the same
    dataset (with `predict` returning one row per sample).
    """
    # Stdlib context manager that swaps `sys.stdout` for a buffer; replaces a
    # hand-rolled equivalent.
    import contextlib

    model = test_utils.get_small_mlp(1, 4, input_dim=3)
    model.compile(
        "rmsprop", "mse", run_eagerly=test_utils.should_run_eagerly()
    )

    inputs = np.zeros((100, 3), dtype=np.float32)
    targets = np.random.randint(0, 4, size=100, dtype=np.int32)
    dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
    # A pass-through `filter` leaves the data unchanged but makes the
    # cardinality unknown, as asserted right below.
    dataset = dataset.filter(lambda x, y: True).batch(10)
    self.assertEqual(
        keras.backend.get_value(tf.data.experimental.cardinality(dataset)),
        tf.data.experimental.UNKNOWN_CARDINALITY,
    )

    # `BatchCounterCallback` is defined outside this method.
    batch_counter = BatchCounterCallback()
    # Restore the interactive-logging setting after the test (pass or fail)
    # so it does not leak into other tests.
    self.addCleanup(
        io_utils.enable_interactive_logging
        if io_utils.is_interactive_logging_enabled()
        else io_utils.disable_interactive_logging
    )
    io_utils.enable_interactive_logging()
    stdout_buffer = io.StringIO()
    with contextlib.redirect_stdout(stdout_buffer):
        history = model.fit(
            dataset,
            epochs=2,
            callbacks=[batch_counter],
            validation_data=dataset.take(3),
        )

    lines = stdout_buffer.getvalue().splitlines()
    # Give a readable failure instead of an IndexError on `lines[-1]`.
    self.assertTrue(lines, "model.fit wrote nothing to stdout")
    # 100 samples in batches of 10 -> the progress bar should end at 10/10.
    self.assertIn("10/10", lines[-1])

    self.assertLen(history.history["loss"], 2)
    # One more batch-begin than batch-end; presumably the extra begin is the
    # step that finds the unknown-size dataset exhausted.
    self.assertEqual(batch_counter.batch_begin_count, 21)
    self.assertEqual(batch_counter.batch_end_count, 20)
    model.evaluate(dataset)
    out = model.predict(dataset)
    self.assertEqual(out.shape[0], 100)
def test_print_info_with_datasets(self):
    """Print training info should work with val datasets (b/133391839)."""
    model = keras.models.Sequential(
        [keras.layers.Dense(1, input_shape=(1,))]
    )
    model.compile(loss="mse", optimizer="sgd")

    # 100 training samples and 50 validation samples, in batches of 10.
    dataset = tf.data.Dataset.from_tensors(([1.0], [1.0])).repeat(100).batch(10)
    val_dataset = (
        tf.data.Dataset.from_tensors(([1.0], [1.0])).repeat(50).batch(10)
    )

    mock_stdout = io.StringIO()
    # Restore the interactive-logging setting after the test (pass or fail)
    # so it does not leak into other tests.
    self.addCleanup(
        io_utils.enable_interactive_logging
        if io_utils.is_interactive_logging_enabled()
        else io_utils.disable_interactive_logging
    )
    io_utils.enable_interactive_logging()
    with tf.compat.v1.test.mock.patch.object(sys, "stdout", mock_stdout):
        model.fit(dataset, epochs=2, validation_data=val_dataset)

    self.assertIn(
        "Train on 10 steps, validate on 5 steps", mock_stdout.getvalue()
    )
def test_print_msg(self):
    """`print_msg` logs at INFO when interactive logging is disabled, and
    writes the message plus a newline to stdout when it is enabled."""
    # Restore the caller's setting through a cleanup so it is put back even
    # when an assertion below fails; a trailing restore would be skipped.
    self.addCleanup(
        io_utils.enable_interactive_logging
        if io_utils.is_interactive_logging_enabled()
        else io_utils.disable_interactive_logging
    )

    io_utils.disable_interactive_logging()
    self.assertFalse(io_utils.is_interactive_logging_enabled())
    with self.assertLogs(level="INFO") as logged:
        io_utils.print_msg("Testing Message")
    self.assertIn("Testing Message", logged.output[0])

    io_utils.enable_interactive_logging()
    self.assertTrue(io_utils.is_interactive_logging_enabled())
    with self.captureWritesToStream(sys.stdout) as printed:
        io_utils.print_msg("Testing Message")
    self.assertEqual("Testing Message\n", printed.contents())