    def test_iterative_process_type_signature(self):
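        """Tests that `initialize` and `next` have the expected TFF type
        signatures."""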
        iterative_process = decay_iterative_process_builder.from_flags(
            input_spec=get_input_spec(),
            model_builder=model_builder,
            loss_builder=loss_builder,
            metrics_builder=metrics_builder)

        dummy_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=FLAGS.client_learning_rate,
            decay_factor=FLAGS.client_decay_factor,
            min_delta=FLAGS.min_delta,
            min_lr=FLAGS.min_lr,
            window_size=FLAGS.window_size,
            patience=FLAGS.patience)
        lr_callback_type = tff.framework.type_from_tensors(dummy_lr_callback)

        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(
                model=tff.learning.ModelWeights(
                    trainable=(tff.TensorType(tf.float32, [1, 1]),
                               tff.TensorType(tf.float32, [1])),
                    non_trainable=()),
                optimizer_state=[tf.int64],
                client_lr_callback=lr_callback_type,
                server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        dataset_type = tff.FederatedType(
            tff.SequenceType(
                collections.OrderedDict(
                    x=tff.TensorType(tf.float32, [None, 1]),
                    y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)

        metrics_type = tff.FederatedType(
            collections.OrderedDict(
                mean_squared_error=tff.TensorType(tf.float32),
                loss=tff.TensorType(tf.float32)), tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)

        expected_result_type = (server_state_type, output_type)
        expected_type = tff.FunctionType(parameter=(server_state_type,
                                                    dataset_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        self.assertTrue(actual_type.is_equivalent_to(expected_type))

    def test_iterative_process_decreases_loss(self):
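        """Tests that a few rounds of training decrease the loss while both
        learning rates remain at 0.1."""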
        iterative_process = decay_iterative_process_builder.from_flags(
            input_spec=get_input_spec(),
            model_builder=model_builder,
            loss_builder=loss_builder,
            metrics_builder=metrics_builder)

        state, train_outputs = self._run_rounds(iterative_process, 4)
        self.assertLess(train_outputs[-1]['before_training']['loss'],
                        train_outputs[0]['before_training']['loss'])
        self.assertLess(train_outputs[-1]['during_training']['loss'],
                        train_outputs[0]['during_training']['loss'])
        self.assertNear(state.client_lr_callback.learning_rate, 0.1, 1e-8)
        self.assertNear(state.server_lr_callback.learning_rate, 0.1, 1e-8)

    def test_client_decay_schedule(self):
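        """Tests that the client learning rate decays to `min_lr` while the
        server learning rate (decay factor 1.0) stays at 0.1."""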
        FLAGS.client_decay_factor = 0.5
        FLAGS.server_decay_factor = 1.0
        FLAGS.min_delta = 0.5
        FLAGS.min_lr = 0.05
        FLAGS.window_size = 1
        FLAGS.patience = 1

        iterative_process = decay_iterative_process_builder.from_flags(
            input_spec=get_input_spec(),
            model_builder=model_builder,
            loss_builder=loss_builder,
            metrics_builder=metrics_builder)

        state, train_outputs = self._run_rounds(iterative_process, 10)
        self.assertLess(train_outputs[-1]['before_training']['loss'],
                        train_outputs[0]['before_training']['loss'])
        self.assertLess(train_outputs[-1]['during_training']['loss'],
                        train_outputs[0]['during_training']['loss'])
        self.assertNear(state.client_lr_callback.learning_rate, 0.05, 1e-8)
        self.assertNear(state.server_lr_callback.learning_rate, 0.1, 1e-8)

def main(argv):
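    """Trains a Stack Overflow next-word prediction model with adaptive
    learning rate decay."""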
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_dataset.get_special_tokens(FLAGS.vocab_size)

    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # The BOS token never appears in the ground truth, so it does not
            # need to be masked.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
        ]

    train_set, validation_set, test_set = stackoverflow_dataset.construct_word_level_datasets(
        FLAGS.vocab_size,
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        FLAGS.sequence_length,
        FLAGS.max_elements_per_user,
        FLAGS.num_validation_examples,
        max_batches_per_user=FLAGS.max_batches_per_client)

    input_spec = validation_set.element_spec

    if FLAGS.client_weight == 'uniform':

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0

    elif FLAGS.client_weight == 'num_tokens':

        def client_weight_fn(local_outputs):
            # `num_tokens` is an int64 tensor of shape [1]; squeeze and cast it
            # to a float32 scalar so it can be used as a weight.
            return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    else:
        raise ValueError(
            'Unsupported client_weight flag [{!s}]. Currently only '
            '`uniform` and `num_tokens` are supported.'.format(
                FLAGS.client_weight))

    training_process = decay_iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_set,
        FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = adaptive_fed_avg.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_set,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    test_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_set.concatenate(test_set),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(training_process,
                      client_datasets_fn,
                      evaluate_fn,
                      test_fn=test_fn)

def main(argv):
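    """Trains a Stack Overflow tag-prediction (logistic regression) model with
    adaptive learning rate decay."""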
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    stackoverflow_train, stackoverflow_validation, stackoverflow_test = stackoverflow_lr_dataset.get_stackoverflow_datasets(
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size,
        client_batch_size=FLAGS.client_batch_size,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        max_training_elements_per_user=FLAGS.max_elements_per_user,
        max_batches_per_user=FLAGS.max_batches_per_client,
        num_validation_examples=FLAGS.num_validation_examples)

    input_spec = stackoverflow_train.create_tf_dataset_for_client(
        stackoverflow_train.client_ids[0]).element_spec

    model_builder = functools.partial(
        stackoverflow_lr_models.create_logistic_model,
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size)

    loss_builder = functools.partial(tf.keras.losses.BinaryCrossentropy,
                                     from_logits=False,
                                     reduction=tf.keras.losses.Reduction.SUM)
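
    # The original example calls `metrics_builder` below without defining it.
    # A minimal sketch, assuming precision and recall@5 are the metrics of
    # interest for this multi-label tag-prediction task:
    metrics_builder = lambda: [
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
    ]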

    if FLAGS.client_weight == 'uniform':

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0

    elif FLAGS.client_weight == 'num_samples':
        client_weight_fn = None
    else:
        raise ValueError(
            'Unsupported client_weight flag [{!s}]. Currently only '
            '`uniform` and `num_samples` are supported.'.format(
                FLAGS.client_weight))

    training_process = decay_iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        stackoverflow_train,
        FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = adaptive_fed_avg.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=stackoverflow_validation,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    test_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=stackoverflow_validation.concatenate(stackoverflow_test),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(training_process,
                      client_datasets_fn,
                      evaluate_fn,
                      test_fn=test_fn)
Example #6
def main(argv):
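    """Trains a Shakespeare character-level language model with adaptive
    learning rate decay."""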
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    train_clientdata = shakespeare_dataset.construct_character_level_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        sequence_length=FLAGS.sequence_length,
        max_batches_per_client=FLAGS.max_batches_per_client,
        shuffle_buffer_size=0)
    eval_train_dataset, eval_test_dataset = (
        shakespeare_dataset.construct_centralized_datasets())

    loss_fn_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)
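
    # The original example calls `model_builder` and `metrics_builder` below
    # without defining them. A minimal sketch; the architecture and the
    # character vocabulary size of 90 are assumptions, not from the source:
    def model_builder():
        return tf.keras.Sequential([
            tf.keras.layers.Embedding(input_dim=90, output_dim=8),
            tf.keras.layers.LSTM(256, return_sequences=True),
            tf.keras.layers.Dense(90),  # logits, matching from_logits=True
        ])

    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]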

    input_spec = train_clientdata.create_tf_dataset_for_client(
        train_clientdata.client_ids[0]).element_spec

    if FLAGS.client_weight == 'uniform':

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0

    elif FLAGS.client_weight == 'num_tokens':

        def client_weight_fn(local_outputs):
            # `num_tokens` is an int64 tensor of shape [1]; squeeze and cast it
            # to a float32 scalar so it can be used as a weight.
            return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    else:
        raise ValueError(
            'Unsupported client_weight flag [{!s}]. Currently only '
            '`uniform` and `num_tokens` are supported.'.format(
                FLAGS.client_weight))

    training_process = decay_iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_fn_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_clientdata,
        FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = adaptive_fed_avg.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=eval_test_dataset,
        model_builder=model_builder,
        loss_builder=loss_fn_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    train_evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=eval_train_dataset,
        model_builder=model_builder,
        loss_builder=loss_fn_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        validation_fn=evaluate_fn,
        train_eval_fn=train_evaluate_fn,
    )
Example #7
def main(argv):
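    """Trains an EMNIST character recognition model with adaptive learning
    rate decay."""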
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    emnist_train, emnist_test = emnist_dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        max_batches_per_client=FLAGS.max_batches_per_client,
        only_digits=FLAGS.only_digits)

    central_emnist_train, _ = emnist_dataset.get_centralized_emnist_datasets(
        batch_size=100, only_digits=FLAGS.only_digits, shuffle_train=False)

    input_spec = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0]).element_spec

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(
            emnist_models.create_conv_dropout_model,
            only_digits=FLAGS.only_digits)
    elif FLAGS.model == 'orig_cnn':
        model_builder = functools.partial(
            emnist_models.create_original_fedavg_cnn_model,
            only_digits=FLAGS.only_digits)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(
            emnist_models.create_two_hidden_layer_model,
            only_digits=FLAGS.only_digits)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    if FLAGS.client_weight == 'uniform':

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0

    elif FLAGS.client_weight == 'num_samples':
        client_weight_fn = None
    else:
        raise ValueError(
            'Unsupported client_weight flag [{!s}]. Currently only '
            '`uniform` and `num_samples` are supported.'.format(
                FLAGS.client_weight))

    training_process = decay_iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train,
        FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = adaptive_fed_avg.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    train_evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=central_emnist_train,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        validation_fn=evaluate_fn,
        train_eval_fn=train_evaluate_fn,
    )
Example #8
def main(argv):
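    """Trains a ResNet-18 on federated CIFAR-100 with adaptive learning rate
    decay."""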
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    cifar_train, _ = cifar100_dataset.get_federated_cifar100(
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        train_batch_size=FLAGS.client_batch_size,
        crop_shape=CROP_SHAPE,
        max_batches_per_client=FLAGS.max_batches_per_client)

    central_cifar_train, cifar_test = cifar100_dataset.get_centralized_cifar100(
        100, crop_shape=CROP_SHAPE)

    input_spec = cifar_train.create_tf_dataset_for_client(
        cifar_train.client_ids[0]).element_spec

    model_builder = functools.partial(resnet_models.create_resnet18,
                                      input_shape=CROP_SHAPE,
                                      num_classes=NUM_CLASSES)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    if FLAGS.client_weight == 'uniform':

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0

    elif FLAGS.client_weight == 'num_samples':
        client_weight_fn = None
    else:
        raise ValueError(
            'Unsupported client_weight flag [{!s}]. Currently only '
            '`uniform` and `num_samples` are supported.'.format(
                FLAGS.client_weight))

    training_process = decay_iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        cifar_train,
        FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = adaptive_fed_avg.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=cifar_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    train_evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=central_cifar_train,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    training_loop.run(iterative_process=training_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      train_eval_fn=train_evaluate_fn)