Example No. 1
def _simple_fedavg_model_fn():
  """Wraps the test CNN in a `dp_fedavg.KerasModelWrapper`.

  The wrapper receives:
    * the Keras model returned by `_create_test_cnn_model()`;
    * an input spec whose `x` entry is a float32 tensor of shape
      [batch, 28, 28, 1] and whose `y` entry is an int32 tensor of shape
      [batch] (batch size left unspecified);
    * a `SparseCategoricalCrossentropy` loss constructed with its default
      arguments.

  Returns:
    The `dp_fedavg.KerasModelWrapper` built from the pieces above.
  """
  model = _create_test_cnn_model()
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
  features_spec = tf.TensorSpec([None, 28, 28, 1], tf.float32)
  labels_spec = tf.TensorSpec([None], tf.int32)
  spec = collections.OrderedDict(x=features_spec, y=labels_spec)
  return dp_fedavg.KerasModelWrapper(
      keras_model=model,
      input_spec=spec,
      loss=loss_fn,
  )
Example No. 2
 def tff_model_fn():
     """Wraps a recurrent Keras model in a `dp_fedavg.KerasModelWrapper`.

     The architecture hyperparameters are read from `FLAGS`. `FLAGS` and
     `input_spec` are free names resolved outside this function (enclosing
     or module scope), so they are not defined here.

     Returns:
       A `dp_fedavg.KerasModelWrapper` pairing the recurrent model with
       `input_spec` and a `SparseCategoricalCrossentropy` loss that treats
       the model outputs as logits (`from_logits=True`).
     """
     rnn_hparams = dict(
         vocab_size=FLAGS.vocab_size,
         embedding_size=FLAGS.embedding_size,
         latent_size=FLAGS.latent_size,
         num_layers=FLAGS.num_layers,
         shared_embedding=FLAGS.shared_embedding,
     )
     model = models.create_recurrent_model(**rnn_hparams)
     logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(
         from_logits=True)
     return dp_fedavg.KerasModelWrapper(model, input_spec, logits_loss)
Example No. 3
def _rnn_model_fn(use_tff_learning=False) -> tff.learning.Model:
  """Wraps the test RNN with either TFF's Keras adapter or `dp_fedavg`.

  Both wrappers receive the same three pieces:
    * the Keras model returned by `_create_test_rnn_model()`;
    * an input spec whose `x` and `y` entries are int32 tensors of shape
      [batch, 5] (batch size left unspecified);
    * a `SparseCategoricalCrossentropy` loss with `from_logits=True`.

  Args:
    use_tff_learning: When truthy, build the model with
      `tff.learning.from_keras_model`; otherwise build it with
      `dp_fedavg.KerasModelWrapper`.

  Returns:
    The wrapped model produced by the selected constructor.
  """
  model = _create_test_rnn_model()
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  spec = collections.OrderedDict(
      x=tf.TensorSpec([None, 5], tf.int32),
      y=tf.TensorSpec([None, 5], tf.int32),
  )
  # Only the selected constructor is looked up, matching the if/else
  # semantics of choosing one wrapper or the other.
  wrap = (tff.learning.from_keras_model
          if use_tff_learning else dp_fedavg.KerasModelWrapper)
  return wrap(keras_model=model, input_spec=spec, loss=loss_fn)
Example No. 4
 def tff_model_fn():
     """Wraps the original FedAvg CNN in a `dp_fedavg.KerasModelWrapper`.

     `FLAGS` and `test_data` are free names resolved outside this function
     (enclosing or module scope). The input spec is taken from
     `test_data.element_spec`.

     Returns:
       A `dp_fedavg.KerasModelWrapper` pairing the CNN (built with
       `FLAGS.only_digits`) with that spec and a
       `SparseCategoricalCrossentropy` loss with `from_logits=False`.
     """
     model = _create_original_fedavg_cnn_model(FLAGS.only_digits)
     # The loss is built before the element spec is read, preserving the
     # original evaluation order.
     loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
         from_logits=False)
     spec = test_data.element_spec
     return dp_fedavg.KerasModelWrapper(model, spec, loss_fn)