def test_checkpoint_sam(self):
    """Checks that SAM model weights round-trip through `tf.train.Checkpoint`.

    Each SAM wrapper gets its own, independently initialised inner model.
    If both wrappers shared one `keras.Sequential`, they would share every
    variable and the final comparison would pass without any restore at all.
    """

    def build_model():
        # Fresh layers on every call, so each SAM wrapper owns its variables.
        return keras.Sequential([
            keras.Input([2, 2]),
            keras.layers.Dense(4),
            keras.layers.Dense(1),
        ])

    sam_model_1 = sharpness_aware_minimization.SharpnessAwareMinimization(
        build_model())
    sam_model_2 = sharpness_aware_minimization.SharpnessAwareMinimization(
        build_model())
    data = tf.random.uniform([1, 2, 2])
    label = data[:, 0] > 0.5

    sam_model_1.compile(
        optimizer=adam.Adam(),
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
    )

    sam_model_1.fit(data, label)

    # Sanity check: before the restore the two independently initialised
    # models should disagree; otherwise the final assertion proves nothing.
    self.assertNotAllClose(sam_model_1(data), sam_model_2(data))

    checkpoint = tf.train.Checkpoint(sam_model_1)
    checkpoint2 = tf.train.Checkpoint(sam_model_2)
    # `save` takes a file *prefix*; put it inside the temp dir rather than
    # using the directory path itself as the prefix.
    save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
    checkpoint2.restore(save_path)

    self.assertAllClose(sam_model_1(data), sam_model_2(data))
# Example #2
    def test_save_sam(self):
        """Round-trips a trained SAM model through `save` / `load_model`.

        After reloading, weights are also loaded explicitly from the same
        path; the reloaded model must reproduce the original predictions.
        """
        layers = [
            keras.Input([2, 2]),
            keras.layers.Dense(4),
            keras.layers.Dense(1),
        ]
        sam_model = sharpness_aware_minimization.SharpnessAwareMinimization(
            keras.Sequential(layers)
        )
        inputs = tf.random.uniform([1, 2, 2])
        targets = inputs[:, 0] > 0.5

        optimizer = adam.Adam()
        loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        sam_model.compile(optimizer=optimizer, loss=loss_fn)
        sam_model.fit(inputs, targets)

        save_dir = os.path.join(self.get_temp_dir(), "model")
        sam_model.save(save_dir)
        restored = keras.models.load_model(save_dir)
        # Also exercise `load_weights` against the saved directory.
        restored.load_weights(save_dir)

        expected = sam_model(inputs)
        self.assertAllClose(expected, restored(inputs))
 def test_sam_model_call(self):
   """Checks that calling a SAM wrapper matches calling its inner model.

   The input is a batch of one `[2, 2]` sample, so its shape `[1, 2, 2]`
   matches the declared `keras.Input([2, 2])`, as in the other SAM tests.
   """
   model = keras.Sequential([
       keras.Input([2, 2]),
       keras.layers.Dense(4),
   ])
   sam_model = sharpness_aware_minimization.SharpnessAwareMinimization(model)
   data = tf.random.uniform([1, 2, 2])
   self.assertAllClose(model(data), sam_model(data))
  def test_sam_model_fit(self, strategy):
    """Checks that a SAM model can be built, compiled and fit under `strategy`.

    Args:
      strategy: distribution strategy whose scope the model is created and
        trained in (supplied by the test parameterization).
    """
    with strategy.scope():
      model = keras.Sequential([
          keras.Input([2, 2]),
          keras.layers.Dense(4),
          keras.layers.Dense(1),
      ])
      sam_model = sharpness_aware_minimization.SharpnessAwareMinimization(model)
      # Two samples, each of shape [2, 2], matching keras.Input([2, 2]).
      # Keeping two samples presumably lets multi-replica strategies give
      # each replica data -- TODO confirm.
      data = tf.random.uniform([2, 2, 2])
      label = data[:, 0] > 0.5

      sam_model.compile(
          optimizer=adam.Adam(),
          loss=keras.losses.BinaryCrossentropy(from_logits=True),
      )

      history = sam_model.fit(data, label, steps_per_epoch=1)
      # Training should have recorded a loss for the single epoch.
      self.assertIn("loss", history.history)