Пример #1
0
    def test_keras_model_multiple_outputs(self):
        """Tests `keras_utils.from_keras_model` on a model with three outputs.

        Subtests cover:
          * a loss list shorter than the number of outputs (`ValueError`);
          * a loss of the wrong type, the int `3` (`TypeError`);
          * a per-output loss list without an optimizer;
          * `loss_weights` as a list, both valid and of the wrong length
            (`ValueError`);
          * `loss_weights` as a dict containing an unknown key (`TypeError`).
        """
        keras_model = model_examples.build_multiple_outputs_keras_model()
        # Two model inputs (`x`) and three labels (`y`), one label per output.
        input_spec = collections.OrderedDict(
            x=[
                tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
                tf.TensorSpec(shape=[None, 1], dtype=tf.float32)
            ],
            y=[
                tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
                tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
                tf.TensorSpec(shape=[None, 1], dtype=tf.float32)
            ])
        # Built before any subTest because several subtests use it. subTest
        # records a failure and keeps going, so defining this inside one
        # subtest could leave it unbound (NameError) in the later ones.
        # Labels are [0, 1, 1]. The expected losses below (2.0 unweighted,
        # 0.5 with weights [0.1, 0.2, 0.3]) are consistent with every output
        # predicting 0 for this all-zero input -- assumption, confirm against
        # model_examples.build_multiple_outputs_keras_model.
        dummy_batch = collections.OrderedDict(
            x=[
                np.zeros([1, 1], dtype=np.float32),
                np.zeros([1, 1], dtype=np.float32)
            ],
            y=[
                np.zeros([1, 1], dtype=np.float32),
                np.ones([1, 1], dtype=np.float32),
                np.ones([1, 1], dtype=np.float32)
            ])

        with self.subTest('loss_output_len_mismatch'):
            # Two losses for a model with three outputs.
            with self.assertRaises(ValueError):
                _ = keras_utils.from_keras_model(
                    keras_model=keras_model,
                    input_spec=input_spec,
                    loss=[
                        tf.keras.losses.MeanSquaredError(),
                        tf.keras.losses.MeanSquaredError()
                    ])

        with self.subTest('invalid_loss'):
            with self.assertRaises(TypeError):
                _ = keras_utils.from_keras_model(keras_model=keras_model,
                                                 input_spec=input_spec,
                                                 loss=3)

        with self.subTest('loss_list_no_opt'):
            tff_model = keras_utils.from_keras_model(
                keras_model=keras_model,
                input_spec=input_spec,
                loss=[
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError()
                ])

            self.assertIsInstance(tff_model, model_utils.EnhancedModel)
            # Unweighted sum of per-output MSE: 0 + 1 + 1.
            output = tff_model.forward_pass(dummy_batch)
            self.assertAllClose(output.loss, 2.0)

        # Rebuild the model for the loss-weight subtests.
        keras_model = model_examples.build_multiple_outputs_keras_model()
        with self.subTest('loss_weights_as_list'):
            tff_model = keras_utils.from_keras_model(
                keras_model=keras_model,
                input_spec=input_spec,
                loss=[
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError()
                ],
                loss_weights=[0.1, 0.2, 0.3])

            # Weighted sum: 0.1 * 0 + 0.2 * 1 + 0.3 * 1.
            output = tff_model.forward_pass(dummy_batch)
            self.assertAllClose(output.loss, 0.5)

            # A second pass over the same batch must report the same loss.
            output = tff_model.forward_pass(dummy_batch)
            self.assertAllClose(output.loss, 0.5)

        with self.subTest('loss_weights_assert_fail_list'):
            # Two weights for three losses/outputs.
            with self.assertRaises(ValueError):
                _ = keras_utils.from_keras_model(
                    keras_model=keras_model,
                    input_spec=input_spec,
                    loss=[
                        tf.keras.losses.MeanSquaredError(),
                        tf.keras.losses.MeanSquaredError(),
                        tf.keras.losses.MeanSquaredError()
                    ],
                    loss_weights=[0.1, 0.2])

        with self.subTest('loss_weights_assert_fail_dict'):
            # Dict-valued weights, one of which has an unknown key ('dummy').
            with self.assertRaises(TypeError):
                _ = keras_utils.from_keras_model(
                    keras_model=keras_model,
                    input_spec=input_spec,
                    loss=[
                        tf.keras.losses.MeanSquaredError(),
                        tf.keras.losses.MeanSquaredError(),
                        tf.keras.losses.MeanSquaredError()
                    ],
                    loss_weights={
                        'dense_5': 0.1,
                        'dense_6': 0.2,
                        'dummy': 0.4
                    })
Пример #2
0
  def test_keras_model_multiple_outputs(self, input_spec):
    """Tests `keras_utils.from_keras_model` on a model with three outputs.

    Subtests cover:
      * a loss list shorter than the number of outputs (`ValueError`);
      * a loss of the wrong type, the int `3` (`TypeError`);
      * a dict-valued loss with an unknown key, 'whimsy' (`TypeError`);
      * a per-output loss list without an optimizer;
      * a single custom `Loss` object applied to all outputs;
      * `loss_weights` as a list, both valid and of the wrong length
        (`ValueError`);
      * `loss_weights` as a dict containing an unknown key (`TypeError`).

    Args:
      input_spec: Input spec passed straight through to
        `keras_utils.from_keras_model` in every subtest.
    """
    keras_model = model_examples.build_multiple_outputs_keras_model()
    # Built before any subTest because several subtests use it. subTest
    # records a failure and keeps going, so defining this inside one subtest
    # could leave it unbound (NameError) in the later ones.
    # Labels are [0, 1, 1]. The expected losses below (2.0 unweighted, 0.5
    # with weights [0.1, 0.2, 0.3]) are consistent with every output
    # predicting 0 for this all-zero input -- assumption, confirm against
    # model_examples.build_multiple_outputs_keras_model.
    example_batch = collections.OrderedDict(
        x=[
            np.zeros([1, 1], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32)
        ],
        y=[
            np.zeros([1, 1], dtype=np.float32),
            np.ones([1, 1], dtype=np.float32),
            np.ones([1, 1], dtype=np.float32)
        ])

    with self.subTest('loss_output_len_mismatch'):
      # Two losses for a model with three outputs.
      with self.assertRaises(ValueError):
        _ = keras_utils.from_keras_model(
            keras_model=keras_model,
            input_spec=input_spec,
            loss=[
                tf.keras.losses.MeanSquaredError(),
                tf.keras.losses.MeanSquaredError()
            ])

    with self.subTest('invalid_loss'):
      with self.assertRaises(TypeError):
        _ = keras_utils.from_keras_model(
            keras_model=keras_model, input_spec=input_spec, loss=3)

    with self.subTest('loss_as_dict_fails'):
      with self.assertRaises(TypeError):
        _ = keras_utils.from_keras_model(
            keras_model=keras_model,
            input_spec=input_spec,
            loss={
                'dense_5': tf.keras.losses.MeanSquaredError(),
                'dense_6': tf.keras.losses.MeanSquaredError(),
                'whimsy': tf.keras.losses.MeanSquaredError()
            })

    with self.subTest('loss_list_no_opt'):
      tff_model = keras_utils.from_keras_model(
          keras_model=keras_model,
          input_spec=input_spec,
          loss=[
              tf.keras.losses.MeanSquaredError(),
              tf.keras.losses.MeanSquaredError(),
              tf.keras.losses.MeanSquaredError()
          ])

      self.assertIsInstance(tff_model, model_utils.EnhancedModel)
      # Unweighted sum of per-output MSE: 0 + 1 + 1.
      output = tff_model.forward_pass(example_batch)
      self.assertAllClose(output.loss, 2.0)

    class CustomLoss(tf.keras.losses.Loss):
      """Sums MSE over the pairs obtained by zipping `y_true` with `y_pred`."""

      def __init__(self):
        super().__init__(name='custom_loss')

      def call(self, y_true, y_pred):
        # NOTE(review): presumably one (label, prediction) pair per model
        # output when given the multi-output structures -- confirm.
        loss = tf.constant(0.0)
        for label, prediction in zip(y_true, y_pred):
          loss += tf.keras.losses.MeanSquaredError()(label, prediction)
        return loss

    keras_model = model_examples.build_multiple_outputs_keras_model()
    with self.subTest('single_custom_loss_can_work_with_multiple_outputs'):
      tff_model = keras_utils.from_keras_model(
          keras_model=keras_model, input_spec=input_spec, loss=CustomLoss())

      output = tff_model.forward_pass(example_batch)
      self.assertAllClose(output.loss, 2.0)

    keras_model = model_examples.build_multiple_outputs_keras_model()
    with self.subTest('loss_weights_as_list'):
      tff_model = keras_utils.from_keras_model(
          keras_model=keras_model,
          input_spec=input_spec,
          loss=[
              tf.keras.losses.MeanSquaredError(),
              tf.keras.losses.MeanSquaredError(),
              tf.keras.losses.MeanSquaredError()
          ],
          loss_weights=[0.1, 0.2, 0.3])

      # Weighted sum: 0.1 * 0 + 0.2 * 1 + 0.3 * 1.
      output = tff_model.forward_pass(example_batch)
      self.assertAllClose(output.loss, 0.5)

      # A second pass over the same batch must report the same loss.
      output = tff_model.forward_pass(example_batch)
      self.assertAllClose(output.loss, 0.5)

    with self.subTest('loss_weights_assert_fail_list'):
      # Two weights for three losses/outputs.
      with self.assertRaises(ValueError):
        _ = keras_utils.from_keras_model(
            keras_model=keras_model,
            input_spec=input_spec,
            loss=[
                tf.keras.losses.MeanSquaredError(),
                tf.keras.losses.MeanSquaredError(),
                tf.keras.losses.MeanSquaredError()
            ],
            loss_weights=[0.1, 0.2])

    with self.subTest('loss_weights_assert_fail_dict'):
      # Dict-valued weights, one of which has an unknown key ('whimsy').
      with self.assertRaises(TypeError):
        _ = keras_utils.from_keras_model(
            keras_model=keras_model,
            input_spec=input_spec,
            loss=[
                tf.keras.losses.MeanSquaredError(),
                tf.keras.losses.MeanSquaredError(),
                tf.keras.losses.MeanSquaredError()
            ],
            loss_weights={
                'dense_5': 0.1,
                'dense_6': 0.2,
                'whimsy': 0.4
            })
Пример #3
0
    def test_keras_model_multiple_outputs(self):
        """Tests `keras_utils.from_keras_model` on a model with three outputs.

        Subtests cover:
          * a loss list shorter than the number of outputs (`ValueError`);
          * a loss of the wrong type, the int `3` (`TypeError`);
          * per-output losses given as a list and as a dict keyed by output
            name, without an optimizer;
          * a trainable model, built by passing an optimizer;
          * `loss_weights` as a list and as a dict keyed by output name;
          * `loss_weights` of the wrong length (`ValueError`) and a dict
            missing an output name (`KeyError`).
        """
        keras_model = model_examples.build_multiple_outputs_keras_model()
        # Labels are [0, 1, 1]. The expected losses below (2.0 unweighted,
        # 0.5 with weights [0.1, 0.2, 0.3]) are consistent with every output
        # predicting 0 for this all-zero input -- assumption, confirm against
        # model_examples.build_multiple_outputs_keras_model.
        dummy_batch = collections.OrderedDict([
            ('x', [
                np.zeros([1, 1], dtype=np.float32),
                np.zeros([1, 1], dtype=np.float32)
            ]),
            ('y', [
                np.zeros([1, 1], dtype=np.float32),
                np.ones([1, 1], dtype=np.float32),
                np.ones([1, 1], dtype=np.float32)
            ]),
        ])

        with self.subTest('loss_output_len_mismatch'):
            # Two losses for a model with three outputs.
            with self.assertRaises(ValueError):
                _ = keras_utils.from_keras_model(
                    keras_model=keras_model,
                    dummy_batch=dummy_batch,
                    loss=[
                        tf.keras.losses.MeanSquaredError(),
                        tf.keras.losses.MeanSquaredError()
                    ])

        with self.subTest('invalid_loss'):
            with self.assertRaises(TypeError):
                _ = keras_utils.from_keras_model(keras_model=keras_model,
                                                 dummy_batch=dummy_batch,
                                                 loss=3)

        with self.subTest('loss_list_no_opt'):
            tff_model = keras_utils.from_keras_model(
                keras_model=keras_model,
                dummy_batch=dummy_batch,
                loss=[
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError()
                ])

            self.assertIsInstance(tff_model, model_utils.EnhancedModel)
            # Unweighted sum of per-output MSE: 0 + 1 + 1.
            output = tff_model.forward_pass(dummy_batch)
            self.assertAllClose(output.loss, 2.0)

        with self.subTest('loss_dict_no_opt'):
            # Key by the model's real output names instead of hard-coding the
            # auto-generated layer names ('dense', 'dense_1', ...). Those come
            # from Keras's process-global name counter, so they change with
            # whatever layers were built earlier in the process.
            tff_model = keras_utils.from_keras_model(
                keras_model=keras_model,
                dummy_batch=dummy_batch,
                loss={
                    name: tf.keras.losses.MeanSquaredError()
                    for name in keras_model.output_names
                })

            self.assertIsInstance(tff_model, model_utils.EnhancedModel)
            output = tff_model.forward_pass(dummy_batch)
            self.assertAllClose(output.loss, 2.0)

        with self.subTest('trainable_model'):
            tff_model = keras_utils.from_keras_model(
                keras_model=keras_model,
                dummy_batch=dummy_batch,
                loss=[
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError()
                ],
                optimizer=tf.keras.optimizers.SGD(learning_rate=0.01))

            self.assertIsInstance(tff_model,
                                  model_utils.EnhancedTrainableModel)
            self.assertTrue(hasattr(tff_model._model._keras_model,
                                    'optimizer'))
            output = tff_model.forward_pass(dummy_batch)
            self.assertAllClose(output.loss, 2.0)

        # Rebuild the model for the loss-weight subtests; its layers get new
        # auto-generated names, which is why names are read from the model.
        keras_model = model_examples.build_multiple_outputs_keras_model()
        with self.subTest('loss_weights_as_list'):
            tff_model = keras_utils.from_keras_model(
                keras_model=keras_model,
                dummy_batch=dummy_batch,
                loss=[
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError()
                ],
                loss_weights=[0.1, 0.2, 0.3])

            # Weighted sum: 0.1 * 0 + 0.2 * 1 + 0.3 * 1.
            output = tff_model.forward_pass(dummy_batch)
            self.assertAllClose(output.loss, 0.5)

        with self.subTest('loss_weights_as_dict'):
            # Same weights as the list case, paired with outputs in order.
            tff_model = keras_utils.from_keras_model(
                keras_model=keras_model,
                dummy_batch=dummy_batch,
                loss=[
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError(),
                    tf.keras.losses.MeanSquaredError()
                ],
                loss_weights=dict(
                    zip(keras_model.output_names, [0.1, 0.2, 0.3])))

            output = tff_model.forward_pass(dummy_batch)
            self.assertAllClose(output.loss, 0.5)

        with self.subTest('loss_weights_assert_fail_list'):
            # Two weights for three losses/outputs.
            with self.assertRaises(ValueError):
                _ = keras_utils.from_keras_model(
                    keras_model=keras_model,
                    dummy_batch=dummy_batch,
                    loss=[
                        tf.keras.losses.MeanSquaredError(),
                        tf.keras.losses.MeanSquaredError(),
                        tf.keras.losses.MeanSquaredError()
                    ],
                    loss_weights=[0.1, 0.2])

        with self.subTest('loss_weights_assert_fail_dict'):
            # Two real output names plus an unknown key, so exactly one output
            # (the third) is missing from the weights dict.
            with self.assertRaises(KeyError):
                _ = keras_utils.from_keras_model(
                    keras_model=keras_model,
                    dummy_batch=dummy_batch,
                    loss=[
                        tf.keras.losses.MeanSquaredError(),
                        tf.keras.losses.MeanSquaredError(),
                        tf.keras.losses.MeanSquaredError()
                    ],
                    loss_weights={
                        keras_model.output_names[0]: 0.1,
                        keras_model.output_names[1]: 0.2,
                        'dummy': 0.4
                    })