def construct_example_training_comp():
  """Builds a federated averaging `computation_utils.IterativeProcess`.

  Seeds NumPy's global RNG, defines a one-example sample batch, and hands a
  model-building closure to `learning.build_federated_averaging_process`.

  Returns:
    Whatever `learning.build_federated_averaging_process(model_fn)` returns
    (per the original docstring, a `computation_utils.IterativeProcess`).
  """
  np.random.seed(0)

  # A single example with two float features and one integer label.
  x_example = np.array([[1., 1.]], dtype=np.float32)
  y_example = np.array([[0]], dtype=np.int32)
  sample_batch = collections.OrderedDict(x=x_example, y=y_example)

  def model_fn():
    """Builds, compiles and wraps a one-layer Keras model for the FL API."""
    # NOTE(review): a softmax over a single unit always outputs 1.0, so this
    # model cannot fit anything. Presumably fine for a fixture; confirm.
    dense = tf.keras.layers.Dense(
        1,
        activation=tf.nn.softmax,
        kernel_initializer='zeros',
        input_shape=(2,))
    keras_model = tf.keras.models.Sequential([dense])

    def loss_fn(y_true, y_pred):
      # Mean of the per-example sparse categorical cross-entropy.
      per_example = tf.keras.losses.sparse_categorical_crossentropy(
          y_true, y_pred)
      return tf.reduce_mean(per_example)

    sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    keras_model.compile(loss=loss_fn, optimizer=sgd, metrics=[accuracy])
    return learning.from_compiled_keras_model(keras_model, sample_batch)

  return learning.build_federated_averaging_process(model_fn)
def construct_example_training_comp():
  """Builds a federated averaging `tff.utils.IterativeProcess`.

  Seeds NumPy's global RNG, defines a one-example sample batch, and passes a
  model-building closure plus a client SGD optimizer factory to
  `learning.build_federated_averaging_process`.

  Returns:
    Whatever `learning.build_federated_averaging_process` returns (per the
    original docstring, a `tff.utils.IterativeProcess`).
  """
  np.random.seed(0)
  sample_batch = collections.OrderedDict([
      ('x', np.array([[1., 1.]], dtype=np.float32)),
      ('y', np.array([[0]], dtype=np.int32)),
  ])

  def model_fn():
    """Wraps an uncompiled one-layer Keras model for the FL API."""
    # NOTE(review): a softmax over a single unit always outputs 1.0, so this
    # model cannot fit anything. Presumably fine for a fixture; confirm.
    layers = [
        tf.keras.layers.Dense(
            1,
            activation=tf.nn.softmax,
            kernel_initializer='zeros',
            input_shape=(2,)),
    ]
    keras_model = tf.keras.models.Sequential(layers)
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    return learning.from_keras_model(
        keras_model, dummy_batch=sample_batch, loss=loss, metrics=metrics)

  def client_optimizer_fn():
    # A fresh optimizer is built each time the factory is invoked.
    return tf.keras.optimizers.SGD(learning_rate=0.01)

  return learning.build_federated_averaging_process(
      model_fn, client_optimizer_fn=client_optimizer_fn)