def test_keras_model_multiple_outputs(self):
    """Checks `keras_utils.from_keras_model` on a multiple-output Keras model.

    Subtests cover, in order: two losses supplied for three label specs
    (`ValueError`), an integer passed as `loss` (`TypeError`), a list of
    three losses, a list of three `loss_weights`, a two-element
    `loss_weights` list (`ValueError`), and `loss_weights` passed as a dict
    (`TypeError`).
    """

    def float_column_spec():
        """Returns a float32 spec of shape `[None, 1]`."""
        return tf.TensorSpec(shape=[None, 1], dtype=tf.float32)

    def mse_losses(count):
        """Returns `count` freshly constructed mean-squared-error losses."""
        return [tf.keras.losses.MeanSquaredError() for _ in range(count)]

    keras_model = model_examples.build_multiple_outputs_keras_model()
    input_spec = collections.OrderedDict(
        x=[float_column_spec() for _ in range(2)],
        y=[float_column_spec() for _ in range(3)],
    )

    def convert(model, **loss_kwargs):
        """Wraps `model` using the shared `input_spec`."""
        return keras_utils.from_keras_model(
            keras_model=model, input_spec=input_spec, **loss_kwargs)

    with self.subTest('loss_output_len_mismatch'), \
            self.assertRaises(ValueError):
        convert(keras_model, loss=mse_losses(2))

    with self.subTest('invalid_loss'), self.assertRaises(TypeError):
        convert(keras_model, loss=3)

    with self.subTest('loss_list_no_opt'):
        wrapped_model = convert(keras_model, loss=mse_losses(3))
        self.assertIsInstance(wrapped_model, model_utils.EnhancedModel)
        features = [
            np.zeros([1, 1], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        ]
        labels = [
            np.zeros([1, 1], dtype=np.float32),
            np.ones([1, 1], dtype=np.float32),
            np.ones([1, 1], dtype=np.float32),
        ]
        dummy_batch = collections.OrderedDict(x=features, y=labels)
        batch_output = wrapped_model.forward_pass(dummy_batch)
        # Labels are zeros/ones/ones; presumably the model predicts zeros
        # for zero inputs, giving 0 + 1 + 1 -- TODO confirm against
        # `model_examples`.
        self.assertAllClose(batch_output.loss, 2.0)

    # A fresh model is built before the weighted-loss subtests.
    keras_model = model_examples.build_multiple_outputs_keras_model()

    with self.subTest('loss_weights_as_list'):
        wrapped_model = convert(
            keras_model, loss=mse_losses(3), loss_weights=[0.1, 0.2, 0.3])
        # The forward pass is run twice and must report the same loss both
        # times.
        for _ in range(2):
            batch_output = wrapped_model.forward_pass(dummy_batch)
            self.assertAllClose(batch_output.loss, 0.5)

    with self.subTest('loss_weights_assert_fail_list'), \
            self.assertRaises(ValueError):
        convert(keras_model, loss=mse_losses(3), loss_weights=[0.1, 0.2])

    with self.subTest('loss_weights_assert_fail_dict'), \
            self.assertRaises(TypeError):
        convert(
            keras_model,
            loss=mse_losses(3),
            loss_weights={
                'dense_5': 0.1,
                'dense_6': 0.2,
                'dummy': 0.4,
            })
def test_keras_model_multiple_outputs(self, input_spec):
    """Checks `keras_utils.from_keras_model` on a multiple-output Keras model.

    Args:
      input_spec: The input specification forwarded to every
        `from_keras_model` call (supplied by the caller of this test).

    Subtests cover, in order: two losses in the `loss` list (`ValueError`),
    an integer passed as `loss` (`TypeError`), `loss` passed as a dict
    (`TypeError`), a list of three losses, a single custom loss that sums
    over all outputs, a list of three `loss_weights`, a two-element
    `loss_weights` list (`ValueError`), and `loss_weights` passed as a dict
    (`TypeError`).
    """

    def mse_losses(count):
        """Returns `count` freshly constructed mean-squared-error losses."""
        return [tf.keras.losses.MeanSquaredError() for _ in range(count)]

    def convert(model, **loss_kwargs):
        """Wraps `model` using the `input_spec` given to this test."""
        return keras_utils.from_keras_model(
            keras_model=model, input_spec=input_spec, **loss_kwargs)

    keras_model = model_examples.build_multiple_outputs_keras_model()

    with self.subTest('loss_output_len_mismatch'), \
            self.assertRaises(ValueError):
        convert(keras_model, loss=mse_losses(2))

    with self.subTest('invalid_loss'), self.assertRaises(TypeError):
        convert(keras_model, loss=3)

    with self.subTest('loss_as_dict_fails'), self.assertRaises(TypeError):
        convert(
            keras_model,
            loss={
                'dense_5': tf.keras.losses.MeanSquaredError(),
                'dense_6': tf.keras.losses.MeanSquaredError(),
                'whimsy': tf.keras.losses.MeanSquaredError(),
            })

    with self.subTest('loss_list_no_opt'):
        wrapped_model = convert(keras_model, loss=mse_losses(3))
        self.assertIsInstance(wrapped_model, model_utils.EnhancedModel)
        features = [
            np.zeros([1, 1], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        ]
        labels = [
            np.zeros([1, 1], dtype=np.float32),
            np.ones([1, 1], dtype=np.float32),
            np.ones([1, 1], dtype=np.float32),
        ]
        example_batch = collections.OrderedDict(x=features, y=labels)
        batch_output = wrapped_model.forward_pass(example_batch)
        # Labels are zeros/ones/ones; presumably the model predicts zeros
        # for zero inputs, giving 0 + 1 + 1 -- TODO confirm against
        # `model_examples`.
        self.assertAllClose(batch_output.loss, 2.0)

    class CustomLoss(tf.keras.losses.Loss):
        """Adds up a mean-squared-error term for each (label, prediction)."""

        def __init__(self):
            super().__init__(name='custom_loss')

        def call(self, y_true, y_pred):
            total = tf.constant(0.0)
            for target, estimate in zip(y_true, y_pred):
                total += tf.keras.losses.MeanSquaredError()(target, estimate)
            return total

    # A fresh model is built for the single-custom-loss subtest.
    keras_model = model_examples.build_multiple_outputs_keras_model()

    with self.subTest('single_custom_loss_can_work_with_multiple_outputs'):
        wrapped_model = convert(keras_model, loss=CustomLoss())
        batch_output = wrapped_model.forward_pass(example_batch)
        self.assertAllClose(batch_output.loss, 2.0)

    # And another fresh model before the weighted-loss subtests.
    keras_model = model_examples.build_multiple_outputs_keras_model()

    with self.subTest('loss_weights_as_list'):
        wrapped_model = convert(
            keras_model, loss=mse_losses(3), loss_weights=[0.1, 0.2, 0.3])
        # The forward pass is run twice and must report the same loss both
        # times.
        for _ in range(2):
            batch_output = wrapped_model.forward_pass(example_batch)
            self.assertAllClose(batch_output.loss, 0.5)

    with self.subTest('loss_weights_assert_fail_list'), \
            self.assertRaises(ValueError):
        convert(keras_model, loss=mse_losses(3), loss_weights=[0.1, 0.2])

    with self.subTest('loss_weights_assert_fail_dict'), \
            self.assertRaises(TypeError):
        convert(
            keras_model,
            loss=mse_losses(3),
            loss_weights={
                'dense_5': 0.1,
                'dense_6': 0.2,
                'whimsy': 0.4,
            })
def test_keras_model_multiple_outputs(self):
    """Checks `keras_utils.from_keras_model` on a multiple-output Keras model.

    A single example batch (two feature arrays, three label arrays) is used
    both as the `dummy_batch` argument and as the forward-pass input.

    Subtests cover, in order: two losses in the `loss` list (`ValueError`),
    an integer passed as `loss` (`TypeError`), a list of three losses, a
    dict of losses keyed by layer name, a trainable model with an SGD
    optimizer, `loss_weights` as a list and as a dict, a two-element
    `loss_weights` list (`ValueError`), and a `loss_weights` dict containing
    the key `'dummy'` (`KeyError`).
    """

    def mse_losses(count):
        """Returns `count` freshly constructed mean-squared-error losses."""
        return [tf.keras.losses.MeanSquaredError() for _ in range(count)]

    keras_model = model_examples.build_multiple_outputs_keras_model()

    features = [
        np.zeros([1, 1], dtype=np.float32),
        np.zeros([1, 1], dtype=np.float32),
    ]
    labels = [
        np.zeros([1, 1], dtype=np.float32),
        np.ones([1, 1], dtype=np.float32),
        np.ones([1, 1], dtype=np.float32),
    ]
    dummy_batch = collections.OrderedDict([('x', features), ('y', labels)])

    def convert(model, **extra_kwargs):
        """Wraps `model` using the shared `dummy_batch`."""
        return keras_utils.from_keras_model(
            keras_model=model, dummy_batch=dummy_batch, **extra_kwargs)

    with self.subTest('loss_output_len_mismatch'), \
            self.assertRaises(ValueError):
        convert(keras_model, loss=mse_losses(2))

    with self.subTest('invalid_loss'), self.assertRaises(TypeError):
        convert(keras_model, loss=3)

    with self.subTest('loss_list_no_opt'):
        wrapped_model = convert(keras_model, loss=mse_losses(3))
        self.assertIsInstance(wrapped_model, model_utils.EnhancedModel)
        batch_output = wrapped_model.forward_pass(dummy_batch)
        # Labels are zeros/ones/ones; presumably the model predicts zeros
        # for zero inputs, giving 0 + 1 + 1 -- TODO confirm against
        # `model_examples`.
        self.assertAllClose(batch_output.loss, 2.0)

    with self.subTest('loss_dict_no_opt'):
        wrapped_model = convert(
            keras_model,
            loss={
                'dense': tf.keras.losses.MeanSquaredError(),
                'dense_1': tf.keras.losses.MeanSquaredError(),
                'dense_2': tf.keras.losses.MeanSquaredError(),
            })
        self.assertIsInstance(wrapped_model, model_utils.EnhancedModel)
        batch_output = wrapped_model.forward_pass(dummy_batch)
        self.assertAllClose(batch_output.loss, 2.0)

    with self.subTest('trainable_model'):
        wrapped_model = convert(
            keras_model,
            loss=mse_losses(3),
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.01))
        self.assertIsInstance(
            wrapped_model, model_utils.EnhancedTrainableModel)
        # The wrapped Keras model is expected to expose an `optimizer`
        # attribute once an optimizer has been supplied.
        inner_keras_model = wrapped_model._model._keras_model
        self.assertTrue(hasattr(inner_keras_model, 'optimizer'))
        batch_output = wrapped_model.forward_pass(dummy_batch)
        self.assertAllClose(batch_output.loss, 2.0)

    # A fresh model is built before the weighted-loss subtests.
    keras_model = model_examples.build_multiple_outputs_keras_model()

    with self.subTest('loss_weights_as_list'):
        wrapped_model = convert(
            keras_model, loss=mse_losses(3), loss_weights=[0.1, 0.2, 0.3])
        batch_output = wrapped_model.forward_pass(dummy_batch)
        self.assertAllClose(batch_output.loss, 0.5)

    with self.subTest('loss_weights_as_dict'):
        wrapped_model = convert(
            keras_model,
            loss=mse_losses(3),
            loss_weights={
                'dense_5': 0.1,
                'dense_6': 0.2,
                'dense_7': 0.3,
            })
        batch_output = wrapped_model.forward_pass(dummy_batch)
        self.assertAllClose(batch_output.loss, 0.5)

    with self.subTest('loss_weights_assert_fail_list'), \
            self.assertRaises(ValueError):
        convert(keras_model, loss=mse_losses(3), loss_weights=[0.1, 0.2])

    with self.subTest('loss_weights_assert_fail_dict'), \
            self.assertRaises(KeyError):
        convert(
            keras_model,
            loss=mse_losses(3),
            loss_weights={
                'dense_5': 0.1,
                'dense_6': 0.2,
                'dummy': 0.4,
            })