def testConfigForMultipleApplications(self):
    """Round-trips a two-application MinDiffModel through its config."""
    base_model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, activation="softmax")])

    # Two applications with arbitrarily chosen losses and weights.
    loss_per_key = {
        "k1": losses.MMDLoss(),
        "k2": losses.AbsoluteCorrelationLoss(),
    }
    weight_per_key = {"k1": 3.4, "k2": 2.9}
    model = min_diff_model.MinDiffModel(base_model, loss_per_key,
                                        weight_per_key)

    config = model.get_config()
    expected_keys = {"original_model", "loss", "loss_weight", "name"}
    self.assertSetEqual(set(config.keys()), expected_keys)
    self.assertDictEqual(config["loss"], loss_per_key)
    self.assertDictEqual(config["loss_weight"], weight_per_key)

    # Rebuilding from the config must carry both dictionaries over unchanged.
    rebuilt = min_diff_model.MinDiffModel.from_config(config)
    self.assertIsInstance(rebuilt, min_diff_model.MinDiffModel)
    self.assertIsInstance(rebuilt.original_model, tf.keras.Sequential)
    self.assertDictEqual(rebuilt._loss, loss_per_key)
    self.assertDictEqual(rebuilt._loss_weight, weight_per_key)
  def testSerializationWithTransformAndKernel(self):
    """Serializes a model with a predictions transform and Laplacian kernel."""
    base_model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, activation="softmax")])
    transform = lambda x: x * 5.1  # Arbitrary operation.
    weight = 2.3  # Arbitrary value.
    custom_name = "custom_model_name"  # Arbitrary name.

    model = min_diff_model.MinDiffModel(
        base_model,
        losses.MMDLoss("laplacian"),  # Non-default kernel.
        loss_weight=weight,
        predictions_transform=transform,
        name=custom_name)

    restored = tf.keras.layers.deserialize(
        tf.keras.utils.serialize_keras_object(model))

    self.assertIsInstance(restored, min_diff_model.MinDiffModel)
    # The restored transform must compute the same value as the original one.
    sample = 7  # Arbitrary value.
    self.assertEqual(
        restored._predictions_transform(sample), transform(sample))
    self.assertIsInstance(restored._loss, losses.MMDLoss)
    self.assertIsInstance(restored._loss.predictions_kernel,
                          losses.LaplacianKernel)
    self.assertEqual(restored._loss_weight, weight)
    self.assertEqual(restored.name, custom_name)
  def testMinDiffModelRaisesErrorWithBadKwarg(self):
    """An unrecognized constructor keyword argument raises a TypeError."""
    base_model = tf.keras.Sequential([
        tf.keras.layers.Dense(1, activation="softmax"),
    ])
    unexpected_kwargs = {"bad_kwarg": "some value"}

    with self.assertRaisesRegex(
        TypeError, "problem initializing the MinDiffModel instance"):
      min_diff_model.MinDiffModel(base_model, losses.MMDLoss(),
                                  **unexpected_kwargs)
  def testMultipleApplicationsEvalOutputs(self):
    """Checks test_step metric keys with and without min_diff_data."""
    base_model = tf.keras.Sequential([
        tf.keras.layers.Dense(1, activation="softmax"),
    ])
    loss_per_key = {"k1": losses.MMDLoss(), "k2": losses.MMDLoss()}
    model = min_diff_model.MinDiffModel(base_model, loss_per_key)
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["acc"])

    # A batch carrying min_diff_data reports one MinDiff loss per key.
    min_diff_batch = iter(self.multi_min_diff_dataset).get_next()
    metrics_with_min_diff = model.test_step(min_diff_batch)
    self.assertSetEqual(
        set(metrics_with_min_diff.keys()),
        {"loss", "acc", "k1_min_diff_loss", "k2_min_diff_loss"})

    # A plain batch only reports the original metrics.
    plain_batch = iter(self.original_dataset).get_next()
    metrics_without_min_diff = model.test_step(plain_batch)
    self.assertSetEqual(set(metrics_without_min_diff.keys()), {"loss", "acc"})
  def testMultipleApplicationsWithSequentialModel(self):
    """Fits, evaluates and predicts with two MinDiff applications."""
    base_model = tf.keras.Sequential([
        tf.keras.layers.Dense(1, activation="softmax"),
    ])
    loss_per_key = {"k1": losses.MMDLoss(), "k2": losses.MMDLoss()}
    model = min_diff_model.MinDiffModel(base_model, loss_per_key)
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["acc"])

    expected_history_keys = {
        "loss", "acc", "k1_min_diff_loss", "k2_min_diff_loss"
    }
    history = model.fit(self.multi_min_diff_dataset)
    self.assertSetEqual(set(history.history.keys()), expected_history_keys)

    # Evaluate first with min_diff_data, then without it; inference only uses
    # the original inputs.
    for eval_dataset in (self.multi_min_diff_dataset, self.original_dataset):
      model.evaluate(eval_dataset)
    model.predict(self.original_dataset)
  def testGetConfigErrorFromOriginalModel(self):
    """get_config raises when the wrapped model cannot provide a config."""

    class CustomModel(tf.keras.Model):
      pass  # Deliberately leaves get_config unimplemented.

    model = min_diff_model.MinDiffModel(CustomModel(), losses.MMDLoss())

    expected_message = ("MinDiffModel cannot create a config.*"
                        "original_model.*not implemented get_config.*"
                        "or has an error")
    with self.assertRaisesRegex(NotImplementedError, expected_message):
      model.get_config()
  def testConfig(self):
    """Round-trips a single-application MinDiffModel through its config."""
    original_model = tf.keras.Sequential([
        tf.keras.layers.Dense(1, activation="softmax"),
    ])

    model = min_diff_model.MinDiffModel(original_model, losses.MMDLoss())

    config = model.get_config()

    self.assertSetEqual(
        set(config.keys()),
        set(["original_model", "loss", "loss_weight", "name"]))

    # Test building the model from the config.
    model_from_config = min_diff_model.MinDiffModel.from_config(config)
    self.assertIsInstance(model_from_config, min_diff_model.MinDiffModel)
    self.assertIsInstance(model_from_config.original_model, tf.keras.Sequential)
    # The loss and its weight must survive the round trip too, not only the
    # wrapped model.
    self.assertIsInstance(model_from_config._loss, losses.MMDLoss)
    self.assertEqual(model_from_config._loss_weight, model._loss_weight)
  def testWithFunctionalModel(self):
    """A functional-API original model can be fit, evaluated and used."""
    inputs = tf.keras.Input(1)
    dense = tf.keras.layers.Dense(1, activation="softmax")
    base_model = tf.keras.Model(inputs=inputs, outputs=dense(inputs))

    model = min_diff_model.MinDiffModel(base_model, losses.MMDLoss())
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["acc"])

    history = model.fit(self.min_diff_dataset)
    self.assertSetEqual(set(history.history.keys()),
                        {"loss", "acc", "min_diff_loss"})

    # Evaluate with min_diff_data, then without it, then run inference.
    model.evaluate(self.min_diff_dataset)
    model.evaluate(self.original_dataset)
    model.predict(self.original_dataset)
  def testSerialization(self):
    """Keras (de)serialization keeps the loss, weight and name."""
    base_model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, activation="softmax")])
    weight = 2.3  # Arbitrary value.
    custom_name = "custom_model_name"  # Arbitrary name.
    model = min_diff_model.MinDiffModel(
        base_model, losses.MMDLoss(), loss_weight=weight, name=custom_name)

    serialized = tf.keras.utils.serialize_keras_object(model)
    restored = tf.keras.layers.deserialize(serialized)

    self.assertIsInstance(restored, min_diff_model.MinDiffModel)
    # No transform was passed in, so none should come back.
    self.assertIsNone(restored._predictions_transform)
    self.assertIsInstance(restored._loss, losses.MMDLoss)
    self.assertEqual(restored._loss_weight, weight)
    self.assertEqual(restored.name, custom_name)
  def testSerializationForMultipleApplications(self):
    """Keras (de)serialization keeps per-application losses and weights."""
    base_model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, activation="softmax")])

    # Arbitrary losses and weights, keyed by application.
    loss_per_key = {
        "k1": losses.MMDLoss(),
        "k2": losses.AbsoluteCorrelationLoss(),
    }
    weight_per_key = {"k1": 3.4, "k2": 2.9}
    custom_name = "custom_model_name"  # Arbitrary name.
    model = min_diff_model.MinDiffModel(
        base_model, loss_per_key, weight_per_key, name=custom_name)

    restored = tf.keras.layers.deserialize(
        tf.keras.utils.serialize_keras_object(model))

    self.assertIsInstance(restored, min_diff_model.MinDiffModel)
    self.assertIsNone(restored._predictions_transform)
    self.assertDictEqual(restored._loss, loss_per_key)
    self.assertDictEqual(restored._loss_weight, weight_per_key)
    self.assertEqual(restored.name, custom_name)
  def testConfigWithCustomBaseImplementation(self):
    """A MinDiffModel mixed with a custom base class keeps the base config."""

    class CustomModel(tf.keras.Model):

      def __init__(self, val, **kwargs):
        super().__init__(**kwargs)
        self.val = val

      def get_config(self):
        return {"val": self.val}

    class CustomMinDiffModel(min_diff_model.MinDiffModel, CustomModel):
      pass  # No additional implementation needed.

    wrapped_val = 4  # Arbitrary value for the wrapped model.
    wrapped_model = CustomModel(wrapped_val)

    outer_val = 5  # A different arbitrary value for the MinDiff model.
    model = CustomMinDiffModel(
        original_model=wrapped_model, loss=losses.MMDLoss(), val=outer_val)
    self.assertEqual(model.val, outer_val)

    config = model.get_config()
    self.assertSetEqual(
        set(config.keys()),
        {"original_model", "loss", "loss_weight", "name", "val"})
    self.assertEqual(config["val"], model.val)
    self.assertEqual(config["original_model"].val, wrapped_model.val)

    # Rebuild, then confirm each value ended up on the right object.
    rebuilt = CustomMinDiffModel.from_config(config)
    self.assertIsInstance(rebuilt, CustomMinDiffModel)
    self.assertIsInstance(rebuilt.original_model, CustomModel)
    self.assertEqual(rebuilt.val, outer_val)
    self.assertEqual(rebuilt.original_model.val, wrapped_val)
  def testSaveForContinuedTraining(self):
    """A saved and reloaded MinDiffModel can keep training and predicting."""
    base_model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, activation="softmax")])
    model = min_diff_model.MinDiffModel(base_model, losses.MMDLoss())
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["acc"])
    model.fit(self.min_diff_dataset)

    with tempfile.TemporaryDirectory() as tmp_dir:
      save_path = os.path.join(tmp_dir, "saved_model")
      model.save(save_path)
      loaded_model = tf.keras.models.load_model(save_path)

    self.assertIsInstance(loaded_model, min_diff_model.MinDiffModel)

    # Continue training, then run inference on the reloaded model.
    loaded_model.fit(self.min_diff_dataset)
    loaded_model.predict(self.original_dataset)