def _simple_fedavg_model_fn():
  """Wraps the test CNN in a `dp_fedavg.KerasModelWrapper`.

  The wrapper is given an input spec whose `x` entry is a float32 tensor of
  shape `[None, 28, 28, 1]` and whose `y` entry is an int32 tensor of shape
  `[None]`.

  Returns:
    A `dp_fedavg.KerasModelWrapper` around a fresh `_create_test_cnn_model()`
    with a default-configured `SparseCategoricalCrossentropy` loss.
  """
  cnn = _create_test_cnn_model()
  # Default construction, i.e. from_logits=False.
  # NOTE(review): confirm the test CNN ends in a softmax.
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
  batch_spec = collections.OrderedDict(
      x=tf.TensorSpec([None, 28, 28, 1], tf.float32),
      y=tf.TensorSpec([None], tf.int32),
  )
  return dp_fedavg.KerasModelWrapper(
      keras_model=cnn,
      input_spec=batch_spec,
      loss=loss_fn,
  )
def tff_model_fn():
  """Wraps a flag-configured recurrent model in a `dp_fedavg.KerasModelWrapper`.

  The architecture hyperparameters (vocab size, embedding size, latent size,
  number of layers, embedding sharing) are read from `FLAGS`. `input_spec` is
  not defined in this function; it is resolved from an enclosing or module
  scope (NOTE(review): confirm where it is bound).

  Returns:
    A `dp_fedavg.KerasModelWrapper` using sparse categorical cross-entropy
    computed with `from_logits=True`.
  """
  rnn = models.create_recurrent_model(
      vocab_size=FLAGS.vocab_size,
      embedding_size=FLAGS.embedding_size,
      latent_size=FLAGS.latent_size,
      num_layers=FLAGS.num_layers,
      shared_embedding=FLAGS.shared_embedding,
  )
  # from_logits=True: the loss applies softmax itself, so the model is
  # presumably expected to emit unnormalized scores -- confirm in `models`.
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  return dp_fedavg.KerasModelWrapper(
      keras_model=rnn,
      input_spec=input_spec,
      loss=loss_fn,
  )
def _rnn_model_fn(use_tff_learning=False) -> tff.learning.Model:
  """Wraps the test RNN with either TFF's Keras adapter or `dp_fedavg`'s.

  Args:
    use_tff_learning: If truthy, wrap with `tff.learning.from_keras_model`;
      otherwise wrap with `dp_fedavg.KerasModelWrapper`.

  Returns:
    The wrapped model. Both wrappers receive the same arguments: a fresh
    `_create_test_rnn_model()`, an input spec with int32 `x` and `y` tensors
    of shape `[None, 5]`, and a sparse categorical cross-entropy loss with
    `from_logits=True`.
    NOTE(review): the annotation says `tff.learning.Model`; confirm that
    `dp_fedavg.KerasModelWrapper` actually satisfies that type.
  """
  rnn = _create_test_rnn_model()
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  spec = collections.OrderedDict(
      x=tf.TensorSpec([None, 5], tf.int32),
      y=tf.TensorSpec([None, 5], tf.int32),
  )
  # Only the selected branch of the conditional expression is evaluated, so
  # the unused wrapper attribute is never looked up.
  wrapper_factory = (
      tff.learning.from_keras_model
      if use_tff_learning
      else dp_fedavg.KerasModelWrapper
  )
  return wrapper_factory(keras_model=rnn, input_spec=spec, loss=loss_fn)
def tff_model_fn():
  """Wraps the original FedAvg CNN in a `dp_fedavg.KerasModelWrapper`.

  `FLAGS.only_digits` is forwarded to `_create_original_fedavg_cnn_model`.
  The wrapper's input spec is `test_data.element_spec`. Both `FLAGS` and
  `test_data` are resolved from an enclosing or module scope, not defined
  here.

  Returns:
    A `dp_fedavg.KerasModelWrapper` using sparse categorical cross-entropy
    with `from_logits=False`, i.e. model outputs are treated as probabilities.
    NOTE(review): confirm the CNN's final layer applies a softmax.
  """
  cnn = _create_original_fedavg_cnn_model(FLAGS.only_digits)
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
  return dp_fedavg.KerasModelWrapper(
      keras_model=cnn,
      input_spec=test_data.element_spec,
      loss=loss_fn,
  )