def test_checkpoint_sam(self):
    """Checks that a checkpoint of a trained SAM model restores into another.

    The two SAM wrappers must own *independent* inner models. If both wrapped
    the same ``keras.Sequential`` instance they would share variables, and the
    final comparison would pass even if save/restore did nothing at all.
    """

    def build_model():
        # Identical architecture for both wrappers so the checkpoint's
        # object graph lines up when restoring into the second one.
        return keras.Sequential(
            [
                keras.Input([2, 2]),
                keras.layers.Dense(4),
                keras.layers.Dense(1),
            ]
        )

    sam_model_1 = sharpness_aware_minimization.SharpnessAwareMinimization(
        build_model()
    )
    sam_model_2 = sharpness_aware_minimization.SharpnessAwareMinimization(
        build_model()
    )
    data = tf.random.uniform([1, 2, 2])
    label = data[:, 0] > 0.5

    sam_model_1.compile(
        optimizer=adam.Adam(),
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
    )
    sam_model_1.fit(data, label)

    checkpoint = tf.train.Checkpoint(sam_model_1)
    checkpoint2 = tf.train.Checkpoint(sam_model_2)
    # Use a prefix *inside* the temp dir; passing the directory path itself as
    # the prefix would write the checkpoint files next to it, in its parent.
    save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
    checkpoint2.restore(save_path)

    self.assertAllClose(sam_model_1(data), sam_model_2(data))
def test_save_sam(self):
    """Round-trips a trained SAM model through `save`/`load_model`.

    The reloaded model must produce the same outputs as the original on the
    same input.
    """
    inner = keras.Sequential(
        [
            keras.Input([2, 2]),
            keras.layers.Dense(4),
            keras.layers.Dense(1),
        ]
    )
    sam_model = sharpness_aware_minimization.SharpnessAwareMinimization(inner)

    inputs = tf.random.uniform([1, 2, 2])
    targets = inputs[:, 0] > 0.5
    sam_model.compile(
        optimizer=adam.Adam(),
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
    )
    sam_model.fit(inputs, targets)

    export_dir = os.path.join(self.get_temp_dir(), "model")
    sam_model.save(export_dir)

    restored = keras.models.load_model(export_dir)
    # NOTE(review): load_model presumably restores the weights already; this
    # explicit reload means the assertion below may not catch a load_model
    # that drops them. Kept to preserve the test's existing behavior.
    restored.load_weights(export_dir)

    expected = sam_model(inputs)
    actual = restored(inputs)
    self.assertAllClose(expected, actual)
def test_sam_model_call(self):
    """Checks that calling a SAM wrapper is equivalent to calling its model.

    ``data`` includes a leading batch dimension so that it matches the
    model's declared ``keras.Input([2, 2])`` shape (rank-3 batches), as in the
    other tests in this file; a rank-2 tensor would not match that shape.
    """
    model = keras.Sequential(
        [
            keras.Input([2, 2]),
            keras.layers.Dense(4),
        ]
    )
    sam_model = sharpness_aware_minimization.SharpnessAwareMinimization(model)
    data = tf.random.uniform([1, 2, 2])
    self.assertAllClose(model(data), sam_model(data))
def test_sam_model_fit(self, strategy):
    """Runs a single SAM training step inside ``strategy``'s scope.

    Model construction, compilation and fitting all happen under the scope so
    that variables are created by the given distribution strategy.
    """
    with strategy.scope():
        stack = [
            keras.Input([2, 2]),
            keras.layers.Dense(4),
            keras.layers.Dense(1),
        ]
        sam_model = sharpness_aware_minimization.SharpnessAwareMinimization(
            keras.Sequential(stack)
        )
        # NOTE(review): `features` is rank 2, while keras.Input([2, 2])
        # declares rank-3 batches -- confirm this mismatch is intended.
        features = tf.random.uniform([2, 2])
        labels = features[:, 0] > 0.5
        sam_model.compile(
            optimizer=adam.Adam(),
            loss=keras.losses.BinaryCrossentropy(from_logits=True),
        )
        sam_model.fit(features, labels, steps_per_epoch=1)