Ejemplo n.º 1
0
    def test_keras_model_lookup_table(self):
        """Round-trips TFF-trained lookup-table weights into a new Keras model.

        A compiled lookup-table Keras model is wrapped with
        `keras_utils.from_compiled_keras_model` and run through a few
        `train_on_batch` steps on a fixed string-keyed batch. The reported local
        metrics are checked, then the resulting weights are copied into a freshly
        built and identically compiled Keras model, whose loss on the same batch
        must match the original wrapper's loss.
        """

        def compile_model(keras_model):
            # Both Keras models get the same optimizer, loss and metric set so
            # that their losses on the same batch are comparable.
            keras_model.compile(
                optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
                loss=tf.keras.losses.MSE,
                metrics=[NumBatchesCounter(), NumExamplesCounter()])
            return keras_model

        source_keras_model = compile_model(
            model_examples.build_lookup_table_keras_model())

        # Single-example batch handed to the wrapper as `dummy_batch`.
        dummy_batch = collections.OrderedDict(
            x=tf.constant([['G']], dtype=tf.string),
            y=tf.zeros([1, 1], dtype=tf.float32))
        tff_model = keras_utils.from_compiled_keras_model(
            keras_model=source_keras_model, dummy_batch=dummy_batch)

        batch_size = 3
        train_batch = dict(
            x=tf.constant([['G'], ['B'], ['R']], dtype=tf.string),
            y=tf.constant([[1.0], [2.0], [3.0]], dtype=tf.float32))

        num_iterations = 2
        for _step in range(num_iterations):
            self.evaluate(tff_model.train_on_batch(train_batch))

        expected_examples = batch_size * num_iterations
        local_outputs = self.evaluate(tff_model.report_local_outputs())
        self.assertEqual(local_outputs['num_batches'], [num_iterations])
        self.assertEqual(local_outputs['num_examples'], [expected_examples])
        # The 'loss' output's first entry must be positive and its second entry
        # must equal the total number of examples processed.
        self.assertGreater(local_outputs['loss'][0], 0.0)
        self.assertEqual(local_outputs['loss'][1], expected_examples)

        # Copy the weights produced by the TFF wrapper into a brand-new Keras
        # model (weights first, then compile) and wrap it the same way.
        trained_weights = model_utils.ModelWeights.from_model(tff_model)
        fresh_keras_model = model_examples.build_lookup_table_keras_model()
        trained_weights.assign_weights_to(fresh_keras_model)
        reloaded_model = keras_utils.from_compiled_keras_model(
            keras_model=compile_model(fresh_keras_model),
            dummy_batch=dummy_batch)

        source_output = tff_model.forward_pass(train_batch)
        reloaded_output = reloaded_model.forward_pass(train_batch)
        self.assertAlmostEqual(self.evaluate(source_output.loss),
                               self.evaluate(reloaded_output.loss))
Ejemplo n.º 2
0
 def model_fn():
     """Builds the lookup-table Keras model and wraps it as a TFF model.

     The input spec is taken from `ds.element_spec`, where `ds` is a free
     variable resolved from an enclosing or module scope. Loss is mean squared
     error; the only metric is an example counter.
     """
     lookup_model = model_examples.build_lookup_table_keras_model()
     mse_loss = tf.keras.losses.MeanSquaredError()
     dataset_spec = ds.element_spec
     return keras_utils.from_keras_model(
         lookup_model,
         loss=mse_loss,
         input_spec=dataset_spec,
         metrics=[NumExamplesCounter()])
Ejemplo n.º 3
0
    def test_keras_model_lookup_table(self):
        """Checks metrics and weight transfer for a wrapped lookup-table model.

        The lookup-table Keras model is wrapped with
        `keras_utils.from_keras_model` and run forward on a fixed batch a few
        times. The accumulated local metrics are checked. Then the wrapped
        model's weights are copied into a freshly built Keras model, which must
        produce the same loss on the same batch.

        No optimizer is passed to `from_keras_model` and only `forward_pass` is
        called, so this test applies no training update. The transferred
        weights are presumably the freshly built model's initial weights.
        """
        input_spec = collections.OrderedDict(
            x=tf.TensorSpec(shape=[None, 1], dtype=tf.string),
            y=tf.TensorSpec(shape=[None, 1], dtype=tf.float32))

        def wrap_for_tff(keras_model):
            # Both wrappers share the same input spec, loss and metric set, so
            # their losses on the same batch are directly comparable.
            return keras_utils.from_keras_model(
                keras_model=keras_model,
                input_spec=input_spec,
                loss=tf.keras.losses.MeanSquaredError(),
                metrics=[NumBatchesCounter(), NumExamplesCounter()])

        tff_model = wrap_for_tff(
            model_examples.build_lookup_table_keras_model())

        batch_size = 3
        batch = collections.OrderedDict(
            x=tf.constant([['G'], ['B'], ['R']], dtype=tf.string),
            y=tf.constant([[1.0], [2.0], [3.0]], dtype=tf.float32))

        # Only forward passes run here (no optimizer is configured). They
        # accumulate metrics but are not expected to change the weights.
        num_forward_passes = 2
        for _ in range(num_forward_passes):
            self.evaluate(tff_model.forward_pass(batch))

        local_outputs = self.evaluate(tff_model.report_local_outputs())
        num_examples_seen = batch_size * num_forward_passes
        self.assertEqual(local_outputs['num_batches'], [num_forward_passes])
        self.assertEqual(local_outputs['num_examples'], [num_examples_seen])
        # The 'loss' output's first entry must be positive and its second entry
        # must equal the number of examples the loss was accumulated over.
        self.assertGreater(local_outputs['loss'][0], 0.0)
        self.assertEqual(local_outputs['loss'][1], num_examples_seen)

        # Ensure the wrapped model's current weights can be assigned to a new
        # Keras model and that the new wrapper reproduces the same loss.
        model_weights = model_utils.ModelWeights.from_model(tff_model)
        keras_model = model_examples.build_lookup_table_keras_model()
        model_weights.assign_weights_to(keras_model)
        loaded_model = wrap_for_tff(keras_model)

        orig_model_output = tff_model.forward_pass(batch)
        loaded_model_output = loaded_model.forward_pass(batch)
        self.assertAlmostEqual(self.evaluate(orig_model_output.loss),
                               self.evaluate(loaded_model_output.loss))
Ejemplo n.º 4
0
 def model_fn():
   """Wraps a fresh lookup-table Keras model with `keras_utils.from_keras_model`.

   Uses mean squared error loss and no metrics. A single-example dummy batch
   (string key 'R', zero float label) is passed as the second positional
   argument.
   """
   # NOTE(review): relies on `from_keras_model` taking the dummy batch as its
   # second positional parameter in the targeted TFF version -- confirm.
   sample_batch = collections.OrderedDict([
       ('x', tf.constant([['R']], tf.string)),
       ('y', tf.zeros([1, 1], tf.float32)),
   ])
   lookup_model = model_examples.build_lookup_table_keras_model()
   mse_loss = tf.keras.losses.MeanSquaredError()
   return keras_utils.from_keras_model(
       lookup_model, sample_batch, loss=mse_loss, metrics=[])
Ejemplo n.º 5
0
 def model_fn():
     """Builds, compiles and wraps the lookup-table Keras model for TFF.

     The Keras model is compiled with SGD (learning rate 0.1), mean squared
     error loss and no metrics. It is then wrapped via
     `keras_utils.from_compiled_keras_model` using a single-example dummy batch.
     """
     sample_batch = collections.OrderedDict(
         x=tf.constant([['R']], tf.string),
         y=tf.zeros([1, 1], tf.float32))
     lookup_model = model_examples.build_lookup_table_keras_model()
     sgd = tf.keras.optimizers.SGD(learning_rate=0.1)
     mse_loss = tf.keras.losses.MeanSquaredError()
     lookup_model.compile(optimizer=sgd, loss=mse_loss, metrics=[])
     return keras_utils.from_compiled_keras_model(lookup_model, sample_batch)