Esempio n. 1
0
def construct_example_training_comp():
    """Builds a federated averaging `computation_utils.IterativeProcess`.

    The wrapped model is a Keras `Sequential` holding one `Dense` layer with a
    single unit, softmax activation, a zero-initialized kernel and an input
    shape of `(2,)`. It is compiled with SGD (learning rate 0.01), a loss that
    averages the sparse categorical cross-entropy, and a
    `SparseCategoricalAccuracy` metric.

    Side effect: seeds NumPy's global random number generator with 0.

    Returns:
      Whatever `learning.build_federated_averaging_process(model_fn)` returns.
    """
    np.random.seed(0)

    # A one-example batch: two float32 features and an int32 label of 0.
    # Field order ('x' first, then 'y') is preserved by keyword arguments.
    features = np.array([[1., 1.]], dtype=np.float32)
    labels = np.array([[0]], dtype=np.int32)
    sample_batch = collections.OrderedDict(x=features, y=labels)

    def model_fn():
        """Compiles a fresh Keras model and wraps it for the FL API."""
        output_layer = tf.keras.layers.Dense(
            1,
            activation=tf.nn.softmax,
            kernel_initializer='zeros',
            input_shape=(2,))
        model = tf.keras.models.Sequential([output_layer])

        def loss_fn(y_true, y_pred):
            # Reduce the per-example cross-entropy values to a scalar mean.
            per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
                y_true, y_pred)
            return tf.reduce_mean(per_example_loss)

        model.compile(
            loss=loss_fn,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
        return learning.from_compiled_keras_model(model, sample_batch)

    return learning.build_federated_averaging_process(model_fn)
Esempio n. 2
0
def construct_example_training_comp():
    """Builds a federated averaging `tff.utils.IterativeProcess`.

    The wrapped model is a Keras `Sequential` holding one `Dense` layer with a
    single unit, softmax activation, a zero-initialized kernel and an input
    shape of `(2,)`. It is wrapped with a `SparseCategoricalCrossentropy`
    loss and a `SparseCategoricalAccuracy` metric. The optimizer passed as
    `client_optimizer_fn` is SGD with learning rate 0.01.

    Side effect: seeds NumPy's global random number generator with 0.

    Returns:
      Whatever `learning.build_federated_averaging_process` returns.
    """
    np.random.seed(0)

    # A one-example batch: two float32 features and an int32 label of 0.
    sample_batch = collections.OrderedDict(
        x=np.array([[1., 1.]], dtype=np.float32),
        y=np.array([[0]], dtype=np.int32),
    )

    def model_fn():
        """Wraps a freshly built Keras model with `learning.from_keras_model`."""
        output_layer = tf.keras.layers.Dense(
            1,
            activation=tf.nn.softmax,
            kernel_initializer='zeros',
            input_shape=(2,))
        model = tf.keras.models.Sequential([output_layer])
        return learning.from_keras_model(
            model,
            dummy_batch=sample_batch,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    def client_optimizer_fn():
        """Creates the SGD optimizer handed over as `client_optimizer_fn`."""
        return tf.keras.optimizers.SGD(learning_rate=0.01)

    return learning.build_federated_averaging_process(
        model_fn, client_optimizer_fn=client_optimizer_fn)