Пример #1
0
 def test_raises_no_repeat_and_no_take(self):
     """Checks that -1 epochs together with -1 max batches is rejected.

     `construct_character_level_datasets` is expected to raise a ValueError
     whose message mentions that `client_epochs_per_round` is set to -1.
     """
     invalid_kwargs = dict(
         client_batch_size=10,
         client_epochs_per_round=-1,
         max_batches_per_client=-1)
     expected_message = 'Argument client_epochs_per_round is set to -1'

     with self.assertRaisesRegex(ValueError, expected_message):
         dataset.construct_character_level_datasets(**invalid_kwargs)
Пример #2
0
def main(argv):
    """Computes and logs the Yogi initial accumulator estimate.

    Builds the character-level training data, wraps the Keras model returned
    by `model_builder` as a TFF model using a sample batch, and logs the value
    returned by `optimizer_utils.compute_yogi_init`.

    Args:
        argv: Command-line argument list; more than one element is rejected.

    Raises:
        app.UsageError: If `argv` has more than one element.
        ValueError: If no client dataset yields a batch.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    train_clientdata, _ = dataset.construct_character_level_datasets(
        FLAGS.client_batch_size, epochs=1)
    loss_fn_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    # Need to iterate until we find a client with data.
    for client_id in train_clientdata.client_ids:
        client_dataset = train_clientdata.create_tf_dataset_for_client(
            client_id)
        try:
            sample_batch = next(iter(client_dataset))
        except StopIteration:
            continue  # Client had no batches.
        break
    else:
        # Without this, `sample_batch` would be unbound below and the failure
        # would surface as an opaque UnboundLocalError.
        raise ValueError('No client in the training data produced a batch.')
    sample_batch = tf.nest.map_structure(lambda t: t.numpy(), sample_batch)

    tff_model = tff.learning.from_keras_model(keras_model=model_builder(),
                                              dummy_batch=sample_batch,
                                              loss=loss_fn_builder())

    yogi_init_accum_estimate = optimizer_utils.compute_yogi_init(
        train_clientdata, tff_model, num_clients=FLAGS.num_clients)
    logging.info('Yogi initializer: %s',
                 format(yogi_init_accum_estimate, '10.6E'))
Пример #3
0
 def test_take_with_repeat(self):
     """Checks client count and per-client dataset length with a batch cap.

     Datasets are built with client_epochs_per_round=-1 and
     max_batches_per_client=10; the first ten clients are each expected to
     have a dataset whose computed length is 10.
     """
     dataset_kwargs = dict(
         client_batch_size=10,
         client_epochs_per_round=-1,
         max_batches_per_client=10,
         shuffle_buffer_size=0)
     train_data = dataset.construct_character_level_datasets(**dataset_kwargs)
     self.assertEqual(len(train_data.client_ids), 715)

     for client_index in range(10):
         client_id = train_data.client_ids[client_index]
         per_client_ds = train_data.create_tf_dataset_for_client(client_id)
         observed_length = _compute_length_of_dataset(per_client_ds)
         self.assertEqual(observed_length, 10)
Пример #4
0
def main(argv):
    """Builds the training process from flags and runs the training loop.

    Constructs the character-level train/test datasets, finds a sample batch
    from the first client that has data, builds the iterative process via
    `iterative_process_builder.from_flags`, and hands it to
    `training_loop.run` together with client-sampling and evaluation
    functions.

    Args:
        argv: Command-line argument list; more than one element is rejected.

    Raises:
        app.UsageError: If `argv` has more than one element.
        ValueError: If no client dataset yields a batch.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    train_clientdata, test_dataset = dataset.construct_character_level_datasets(
        FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
        FLAGS.sequence_length)
    test_dataset = test_dataset.cache()

    loss_fn_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    # Need to iterate until we find a client with data.
    for client_id in train_clientdata.client_ids:
        client_dataset = train_clientdata.create_tf_dataset_for_client(
            client_id)
        try:
            sample_batch = next(iter(client_dataset))
        except StopIteration:
            continue  # Client had no batches.
        break
    else:
        # Without this, `sample_batch` would be unbound below and the failure
        # would surface as an opaque UnboundLocalError.
        raise ValueError('No client in the training data produced a batch.')
    sample_batch = tf.nest.map_structure(lambda t: t.numpy(), sample_batch)

    def client_weight_fn(local_outputs):
        # Num_tokens is a tensor with type int64[1], to use as a weight need
        # a float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_fn_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    logging.info('Training model:')
    # Assumes model_builder() returns a Keras model: Model.summary() prints
    # its lines and returns None, so logging its return value would only log
    # "None". Send each summary line to the log instead.
    model_builder().summary(print_fn=logging.info)

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=training_utils.build_client_datasets_fn(
            train_clientdata, FLAGS.clients_per_round),
        evaluate_fn=training_utils.build_evaluate_fn(
            eval_dataset=test_dataset,
            model_builder=model_builder,
            loss_builder=loss_fn_builder,
            metrics_builder=metrics_builder),
    )
Пример #5
0
def main(argv):
    """Builds the training process from flags and runs the training loop.

    Constructs the character-level train/test datasets, takes the input spec
    from the first client's dataset, builds the iterative process via
    `iterative_process_builder.from_flags`, and hands it to
    `training_loop.run` together with client-sampling and evaluation
    functions.

    Args:
        argv: Command-line argument list; more than one element is rejected.

    Raises:
        app.UsageError: If `argv` has more than one element.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()

    train_clientdata, test_dataset = dataset.construct_character_level_datasets(
        FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
        FLAGS.sequence_length)
    test_dataset = test_dataset.cache()

    loss_fn_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    input_spec = train_clientdata.create_tf_dataset_for_client(
        train_clientdata.client_ids[0]).element_spec

    def client_weight_fn(local_outputs):
        # Num_tokens is a tensor with type int64[1], to use as a weight need
        # a float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    training_process = iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_fn_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset=train_clientdata,
        train_clients_per_round=FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=test_dataset,
        model_builder=model_builder,
        loss_builder=loss_fn_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    # Assumes model_builder() returns a Keras model: Model.summary() prints
    # its lines and returns None, so logging its return value would only log
    # "None". Send each summary line to the log instead.
    model_builder().summary(print_fn=logging.info)

    training_loop.run(iterative_process=training_process,
                      client_datasets_fn=client_datasets_fn,
                      evaluate_fn=evaluate_fn)
Пример #6
0
def main(argv):
    """Computes and logs the Yogi initial accumulator estimate.

    Builds the character-level training data, wraps the Keras model returned
    by `model_builder` as a TFF model using the element spec of the first
    client's dataset, and logs the value returned by
    `optimizer_utils.compute_yogi_init`.

    Args:
        argv: Command-line argument list; more than one element is rejected.

    Raises:
        app.UsageError: If `argv` has more than one element.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    train_clientdata, _ = dataset.construct_character_level_datasets(
        FLAGS.client_batch_size, epochs=1)
    loss_fn_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    # `input_spec` must describe the dataset's elements, not be the dataset
    # object itself, so take the dataset's `element_spec`.
    input_spec = train_clientdata.create_tf_dataset_for_client(
        train_clientdata.client_ids[0]).element_spec

    tff_model = tff.learning.from_keras_model(keras_model=model_builder(),
                                              input_spec=input_spec,
                                              loss=loss_fn_builder())

    yogi_init_accum_estimate = optimizer_utils.compute_yogi_init(
        train_clientdata, tff_model, num_clients=FLAGS.num_clients)
    logging.info('Yogi initializer: %s',
                 format(yogi_init_accum_estimate, '10.6E'))