def test_keras_model_lookup_table(self):
  """Checks metrics and weight round-tripping for a string-lookup Keras model.

  A compiled lookup-table Keras model is wrapped with
  `keras_utils.from_compiled_keras_model`, `train_on_batch` is run on a small
  batch a fixed number of times, and the reported local outputs are checked.
  The wrapped model's weights are then copied into a freshly built Keras model,
  and both wrapped models must produce the same forward-pass loss.
  """

  def compile_for_test(keras_model):
    # Both wrapped models share the same optimizer, loss and metric counters;
    # fresh metric instances are created on every call.
    keras_model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
        loss=tf.keras.losses.MSE,
        metrics=[NumBatchesCounter(), NumExamplesCounter()])

  model = model_examples.build_lookup_table_keras_model()
  compile_for_test(model)
  dummy_batch = collections.OrderedDict(
      x=tf.constant([['G']], dtype=tf.string),
      y=tf.zeros([1, 1], dtype=tf.float32),
  )
  tff_model = keras_utils.from_compiled_keras_model(
      keras_model=model, dummy_batch=dummy_batch)

  batch_size = 3
  batch = dict(
      x=tf.constant([['G'], ['B'], ['R']], dtype=tf.string),
      y=tf.constant([[1.0], [2.0], [3.0]], dtype=tf.float32),
  )
  num_iterations = 2
  for _ in range(num_iterations):
    self.evaluate(tff_model.train_on_batch(batch))

  expected_examples = batch_size * num_iterations
  metrics = self.evaluate(tff_model.report_local_outputs())
  self.assertEqual(metrics['num_batches'], [num_iterations])
  self.assertEqual(metrics['num_examples'], [expected_examples])
  # The first loss entry must be positive; the second must equal the total
  # number of examples passed to train_on_batch.
  self.assertGreater(metrics['loss'][0], 0.0)
  self.assertEqual(metrics['loss'][1], expected_examples)

  # Copy the wrapped model's weights into a freshly built Keras model, wrap it
  # the same way, and compare the forward-pass losses of the two models.
  tff_weights = model_utils.ModelWeights.from_model(tff_model)
  fresh_keras_model = model_examples.build_lookup_table_keras_model()
  tff_weights.assign_weights_to(fresh_keras_model)
  compile_for_test(fresh_keras_model)
  loaded_model = keras_utils.from_compiled_keras_model(
      keras_model=fresh_keras_model, dummy_batch=dummy_batch)

  orig_model_output = tff_model.forward_pass(batch)
  loaded_model_output = loaded_model.forward_pass(batch)
  self.assertAlmostEqual(
      self.evaluate(orig_model_output.loss),
      self.evaluate(loaded_model_output.loss))
def model_fn():
  """Wraps a freshly built lookup-table Keras model as a TFF model.

  The model is trained with mean-squared-error loss and reports a
  `NumExamplesCounter` metric.

  NOTE(review): `ds` is a free variable resolved from an enclosing scope that
  is not visible here; its `element_spec` is used as the model's input spec.
  """
  return keras_utils.from_keras_model(
      model_examples.build_lookup_table_keras_model(),
      loss=tf.keras.losses.MeanSquaredError(),
      input_spec=ds.element_spec,
      metrics=[NumExamplesCounter()])
def test_keras_model_lookup_table(self):
  """Checks metrics and weight round-tripping for a string-lookup Keras model.

  A lookup-table Keras model is wrapped with `keras_utils.from_keras_model`
  and `forward_pass` is run on a small batch a fixed number of times, after
  which the reported local outputs are checked. The wrapped model's weights
  are then copied into a freshly built Keras model, and both wrapped models
  must produce the same forward-pass loss.
  """

  def spec_for(dtype):
    # Every input column is a rank-2 tensor with an unknown batch dimension.
    return tf.TensorSpec(shape=[None, 1], dtype=dtype)

  model = model_examples.build_lookup_table_keras_model()
  input_spec = collections.OrderedDict(
      x=spec_for(tf.string), y=spec_for(tf.float32))

  def wrap(keras_model):
    # Both wrapped models use identical input spec, loss and fresh metrics.
    return keras_utils.from_keras_model(
        keras_model=keras_model,
        input_spec=input_spec,
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=[NumBatchesCounter(), NumExamplesCounter()])

  tff_model = wrap(model)

  batch_size = 3
  batch = collections.OrderedDict(
      x=tf.constant([['G'], ['B'], ['R']], dtype=tf.string),
      y=tf.constant([[1.0], [2.0], [3.0]], dtype=tf.float32))
  # Each forward_pass call is expected to be counted as one batch.
  num_passes = 2
  for _ in range(num_passes):
    self.evaluate(tff_model.forward_pass(batch))

  expected_examples = batch_size * num_passes
  metrics = self.evaluate(tff_model.report_local_outputs())
  self.assertEqual(metrics['num_batches'], [num_passes])
  self.assertEqual(metrics['num_examples'], [expected_examples])
  # The first loss entry must be positive; the second must equal the total
  # number of examples seen by forward_pass.
  self.assertGreater(metrics['loss'][0], 0.0)
  self.assertEqual(metrics['loss'][1], expected_examples)

  # Copy the wrapped model's weights into a freshly built Keras model, wrap it
  # the same way, and compare the forward-pass losses of the two models.
  tff_weights = model_utils.ModelWeights.from_model(tff_model)
  fresh_keras_model = model_examples.build_lookup_table_keras_model()
  tff_weights.assign_weights_to(fresh_keras_model)
  loaded_model = wrap(fresh_keras_model)

  orig_model_output = tff_model.forward_pass(batch)
  loaded_model_output = loaded_model.forward_pass(batch)
  self.assertAlmostEqual(
      self.evaluate(orig_model_output.loss),
      self.evaluate(loaded_model_output.loss))
def model_fn():
  """Wraps a freshly built lookup-table Keras model as a TFF model.

  A one-row dummy batch (`x` = [['R']] as strings, `y` = float zeros of shape
  [1, 1]) is passed to `from_keras_model`, together with a mean-squared-error
  loss and no extra metrics.
  """
  dummy_batch = collections.OrderedDict([
      ('x', tf.constant([['R']], tf.string)),
      ('y', tf.zeros([1, 1], tf.float32)),
  ])
  return keras_utils.from_keras_model(
      model_examples.build_lookup_table_keras_model(),
      dummy_batch,
      loss=tf.keras.losses.MeanSquaredError(),
      metrics=[])
def model_fn():
  """Builds, compiles and wraps a lookup-table Keras model as a TFF model.

  The Keras model is compiled with SGD (learning rate 0.1), mean-squared-error
  loss and no extra metrics. It is then wrapped via
  `from_compiled_keras_model` using a one-row dummy batch (`x` = [['R']] as
  strings, `y` = float zeros of shape [1, 1]).
  """
  dummy_batch = collections.OrderedDict(
      x=tf.constant([['R']], tf.string),
      y=tf.zeros([1, 1], tf.float32))
  keras_model = model_examples.build_lookup_table_keras_model()
  sgd = tf.keras.optimizers.SGD(learning_rate=0.1)
  keras_model.compile(
      optimizer=sgd,
      loss=tf.keras.losses.MeanSquaredError(),
      metrics=[])
  return keras_utils.from_compiled_keras_model(keras_model, dummy_batch)