def model_fn():
  """Returns a TFF model wrapping a freshly built 2-feature Keras model.

  No optimizer is passed to `from_keras_model`: FedSGD does not use one, so
  the Keras model is intentionally not compiled with an optimizer here.
  """
  return model_utils.from_keras_model(
      keras_model=build_keras_model_fn(feature_dims=2),
      dummy_batch=dummy_batch,
      loss=tf.keras.losses.MeanSquaredError())
def test_dummy_batch_types(self, dummy_batch):
  """`from_keras_model` accepts the given `dummy_batch` and yields an EnhancedModel.

  Args:
    dummy_batch: the batch structure under test, supplied by the test's
      parameters (presumably several container types — see the decorator).
  """
  wrapped = model_utils.from_keras_model(
      keras_model=model_examples.build_linear_regresion_keras_functional_model(
          feature_dims=1),
      dummy_batch=dummy_batch,
      loss=tf.keras.losses.MeanSquaredError())
  self.assertIsInstance(wrapped, model_utils.EnhancedModel)
def test_keras_model_and_optimizer(self):
  """Passing an optimizer yields a trainable model over a compiled Keras model."""
  # Expect TFF to compile the keras model if given an optimizer.
  keras_model = model_examples.build_linear_regresion_keras_functional_model(
      feature_dims=1)
  tff_model = model_utils.from_keras_model(
      keras_model=keras_model,
      dummy_batch=_create_dummy_batch(1),
      loss=tf.keras.losses.MeanSquaredError(),
      optimizer=gradient_descent.SGD(learning_rate=0.01))
  self.assertIsInstance(tff_model, model_utils.EnhancedTrainableModel)
  # A bare `hasattr` check would also pass if the attribute existed but was
  # None (e.g. on a model that was never compiled). Require an actual
  # optimizer object so the test really verifies that compilation happened.
  # pylint: disable=internal-access
  wrapped_keras_model = tff_model._model._keras_model
  self.assertIsNotNone(getattr(wrapped_keras_model, 'optimizer', None))
def test_tff_model_from_keras_model(self, feature_dims, model_fn):
  """Wraps a Keras model without an optimizer and checks forward pass/metrics.

  Args:
    feature_dims: number of input features; passed to `model_fn` and used to
      build both the dummy batch and the two-example test batch.
    model_fn: callable taking `feature_dims` and returning a Keras model. The
      prediction/loss assertions below assume it initializes all weights and
      biases to zero — TODO confirm for every parameterization.
  """
  keras_model = model_fn(feature_dims)
  tff_model = model_utils.from_keras_model(
      keras_model=keras_model,
      dummy_batch=_create_dummy_batch(feature_dims),
      loss=tf.keras.losses.MeanSquaredError(),
      metrics=[NumBatchesCounter(), NumExamplesCounter()])
  self.assertIsInstance(tff_model, model_utils.EnhancedModel)

  # Metrics should be zero, though the model wrapper internally executes the
  # forward pass once.
  self.assertSequenceEqual(self.evaluate(tff_model.local_variables),
                           [0, 0, 0.0, 0.0])

  # Two examples: an all-zeros input labelled 0.0 and an all-ones input
  # labelled 1.0.
  batch = {
      'x': np.stack([
          np.zeros(feature_dims, np.float32),
          np.ones(feature_dims, np.float32)
      ]),
      'y': [[0.0], [1.0]],
  }
  # from_keras_model() was called without an optimizer, which creates a
  # tff.Model. There is no train_on_batch() method available in tff.Model.
  with self.assertRaisesRegex(AttributeError, 'no attribute \'train_on_batch\''):
    tff_model.train_on_batch(batch)

  output = tff_model.forward_pass(batch)
  # Since the model initializes all weights and biases to zero, we expect
  # all predictions to be zero:
  # 0*x1 + 0*x2 + ... + 0 = 0
  self.assertAllEqual(output.predictions, [[0.0], [0.0]])
  # For the single batch:
  #
  # Example | Prediction | Label | Residual | Loss
  # --------+------------+-------+----------+ -----
  # 1       | 0.0        | 0.0   | 0.0      | 0.0
  # 2       | 0.0        | 1.0   | 1.0      | 1.0
  #
  # Total loss: 1.0
  # Batch average loss: 0.5
  self.assertEqual(self.evaluate(output.loss), 0.5)

  # After exactly one forward pass over the 2-example batch, the counters
  # report one batch and two examples.
  metrics = self.evaluate(tff_model.report_local_outputs())
  self.assertEqual(metrics['num_batches'], [1])
  self.assertEqual(metrics['num_examples'], [2])
  # NOTE(review): metrics['loss'] appears to be a [loss accumulator,
  # example count] pair — confirm against the wrapper's loss metric.
  self.assertGreater(metrics['loss'][0], 0)
  self.assertEqual(metrics['loss'][1], 2)
def test_non_keras_model(self):
  """Passing a non-Keras object as `keras_model` raises a TypeError.

  The error message is expected to mention the required `keras...Model` type.
  """
  # `assertRaisesRegexp` is a deprecated alias that was removed in Python 3.12;
  # `assertRaisesRegex` is the supported spelling and matches the rest of the
  # file.
  with self.assertRaisesRegex(TypeError, r'keras\..*\.Model'):
    model_utils.from_keras_model(
        keras_model=0,  # not a tf.keras.Model
        dummy_batch=_create_dummy_batch(1),
        loss=tf.keras.losses.MeanSquaredError())