Example #1
    def test_dp_momentum_training(self, model_fn, optimzer_fn, total_rounds=3):
        """Trains for `total_rounds` rounds and checks the loss decreases.

        Also verifies that the DP tree step index advances with the round
        number whenever the optimizer state carries a tree-aggregation state.
        """

        def server_optimizer_fn(model_weights):
            # The optimizer needs the weight shapes to size its noise tensors.
            model_weight_shape = tf.nest.map_structure(tf.shape, model_weights)
            return optimzer_fn(learning_rate=1.0,
                               momentum=0.9,
                               noise_std=1e-5,
                               model_weight_shape=model_weight_shape)

        print('defining it process')
        it_process = dp_fedavg.build_federated_averaging_process(
            model_fn, server_optimizer_fn=server_optimizer_fn)
        print('next type', it_process.next.type_signature.parameter[0])
        server_state = it_process.initialize()

        def deterministic_batch():
            # Fixed all-ones batch so every round sees identical data.
            return collections.OrderedDict(x=np.ones([1, 28, 28, 1],
                                                     dtype=np.float32),
                                           y=np.ones([1], dtype=np.int32))

        batch = tff.tf_computation(deterministic_batch)()
        federated_data = [[batch]]

        loss_list = []
        for i in range(total_rounds):
            print('round', i)
            server_state, loss = it_process.next(server_state, federated_data)
            loss_list.append(loss)
            self.assertEqual(i + 1, server_state.round_num)
            # Bug fix: the original tested for the key 'server_state_type',
            # which is never present in `optimizer_state`, so the step-index
            # assertion was silently skipped. The tree state lives under
            # 'dp_tree_state' (as accessed below and in test_dpftal_training).
            if 'dp_tree_state' in server_state.optimizer_state:
                self.assertEqual(
                    i + 1,
                    tree_aggregation.get_step_idx(
                        server_state.optimizer_state['dp_tree_state']))
        self.assertLess(np.mean(loss_list[1:]), loss_list[0])
    def test_dpftal_training(self, total_rounds=5):
        """DPFTRL-M training on a tiny RNN should reduce the loss."""

        def make_server_optimizer(model_weights):
            weight_shapes = tf.nest.map_structure(tf.shape, model_weights)
            return optimizer_utils.DPFTRLMServerOptimizer(
                learning_rate=0.1,
                momentum=0.9,
                noise_std=1e-5,
                model_weight_shape=weight_shapes)

        process = dp_fedavg.build_federated_averaging_process(
            _rnn_model_fn, server_optimizer_fn=make_server_optimizer)
        state = process.initialize()

        def fixed_batch():
            # Deterministic token/label batch so rounds are reproducible.
            tokens = np.array([[0, 1, 2, 3, 4]], dtype=np.int32)
            labels = np.array([[1, 2, 3, 4, 0]], dtype=np.int32)
            return collections.OrderedDict(x=tokens, y=labels)

        federated_data = [[tff.tf_computation(fixed_batch)()]]

        losses = []
        for round_idx in range(total_rounds):
            state, round_loss = process.next(state, federated_data)
            losses.append(round_loss)
            self.assertEqual(round_idx + 1, state.round_num)
            # The tree step index must track the round number.
            self.assertEqual(
                round_idx + 1,
                tree_aggregation.get_step_idx(
                    state.optimizer_state['dp_tree_state'].level_state))
        self.assertLess(np.mean(losses[1:]), losses[0])
  def test_dp_momentum_training(self, model_fn, optimzer_fn, total_rounds=3):
    """Trains for `total_rounds` rounds and checks the loss decreases.

    When the server optimizer state is an FTRL state, also verifies the DP
    tree step index advances with the round number.
    """

    def server_optimzier_fn(model_weights):
      # TensorSpecs (shape + dtype) let the optimizer size its noise tensors.
      model_weight_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights)
      return optimzer_fn(
          learning_rate=1.0,
          momentum=0.9,
          noise_std=1e-5,
          model_weight_specs=model_weight_specs)

    it_process = dp_fedavg.build_federated_averaging_process(
        model_fn, server_optimizer_fn=server_optimzier_fn)
    server_state = it_process.initialize()

    def deterministic_batch():
      # Fixed all-ones batch so every round sees identical data.
      return collections.OrderedDict(
          x=np.ones([1, 28, 28, 1], dtype=np.float32),
          y=np.ones([1], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    loss_list = []
    for i in range(total_rounds):
      server_state, loss = it_process.next(server_state, federated_data)
      loss_list.append(loss)
      self.assertEqual(i + 1, server_state.round_num)
      # Bug fix: the original wrote `optimizer_state is optimizer_utils
      # .FTRLState`, comparing an instance to the class object itself —
      # always False, so the step-index check never ran. Use isinstance.
      if isinstance(server_state.optimizer_state, optimizer_utils.FTRLState):
        self.assertEqual(
            i + 1,
            tree_aggregation.get_step_idx(
                server_state.optimizer_state.dp_tree_state))
    self.assertLess(np.mean(loss_list[1:]), loss_list[0])
 def test_something(self, model_fn):
   """The built process exposes the expected client-data type signature."""
   process = dp_fedavg.build_federated_averaging_process(model_fn)
   self.assertIsInstance(process, tff.templates.IterativeProcess)
   client_data_type = process.next.type_signature.parameter[1]
   expected = '{<x=float32[?,28,28,1],y=int32[?]>*}@CLIENTS'
   self.assertEqual(str(client_data_type), expected)
Example #5
    def test_self_contained_example_custom_model(self):
        """Two rounds of FedAvg on MnistModel should lower the loss."""
        train_data = [_create_client_data()()]

        process = dp_fedavg.build_federated_averaging_process(MnistModel)
        state = process.initialize()

        round_losses = []
        for _ in range(2):
            state, round_loss = process.next(state, train_data)
            round_losses.append(round_loss)
        self.assertLess(round_losses[1], round_losses[0])
  def test_dpftal_restart(self, total_rounds=3):
    """Restarting the DP tree each round resets the step index to zero."""

    def server_optimizer_fn(model_weights):
      specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights)
      return optimizer_utils.DPFTRLMServerOptimizer(
          learning_rate=0.1,
          momentum=0.9,
          noise_std=1e-5,
          model_weight_specs=specs,
          efficient_tree=True,
          use_nesterov=True)

    process = dp_fedavg.build_federated_averaging_process(
        _rnn_model_fn,
        server_optimizer_fn=server_optimizer_fn,
        use_simulation_loop=True)
    state = process.initialize()

    model = _rnn_model_fn()
    optimizer = server_optimizer_fn(model.weights.trainable)

    def restart_tree(state):
      # Rebuilds the server state with a freshly restarted DP tree.
      return tff.structure.update_struct(
          state,
          model=state.model,
          optimizer_state=optimizer.restart_dp_tree(state.model.trainable),
          round_num=state.round_num)

    def fixed_batch():
      # Deterministic token/label batch so rounds are reproducible.
      return collections.OrderedDict(
          x=np.array([[0, 1, 2, 3, 4]], dtype=np.int32),
          y=np.array([[1, 2, 3, 4, 0]], dtype=np.int32))

    federated_data = [[tff.tf_computation(fixed_batch)()]]

    losses = []
    for round_idx in range(total_rounds):
      state, round_loss = process.next(state, federated_data)
      state = restart_tree(state)
      losses.append(round_loss)
      self.assertEqual(round_idx + 1, state.round_num)
      # After a restart the tree step index is back at zero.
      self.assertEqual(
          0,
          tree_aggregation.get_step_idx(state.optimizer_state.dp_tree_state))
    self.assertLess(np.mean(losses[1:]), losses[0])
Example #7
    def test_simple_training(self, model_fn):
        """Three rounds of training should reduce the average loss."""
        process = dp_fedavg.build_federated_averaging_process(model_fn)
        state = process.initialize()

        def fixed_batch():
            # Fixed all-ones batch so every round sees identical data.
            features = np.ones([1, 28, 28, 1], dtype=np.float32)
            labels = np.ones([1], dtype=np.int32)
            return collections.OrderedDict(x=features, y=labels)

        federated_data = [[tff.tf_computation(fixed_batch)()]]

        losses = []
        for _ in range(3):
            state, round_loss = process.next(state, federated_data)
            losses.append(round_loss)

        self.assertLess(np.mean(losses[1:]), losses[0])
  def test_client_adagrad_train(self):
    """Training with an Adagrad client optimizer should reduce the loss."""
    process = dp_fedavg.build_federated_averaging_process(
        _rnn_model_fn,
        client_optimizer_fn=functools.partial(
            tf.keras.optimizers.Adagrad, learning_rate=0.01))
    state = process.initialize()

    def fixed_batch():
      # Deterministic token/label batch so rounds are reproducible.
      return collections.OrderedDict(
          x=np.array([[0, 1, 2, 3, 4]], dtype=np.int32),
          y=np.array([[1, 2, 3, 4, 0]], dtype=np.int32))

    federated_data = [[tff.tf_computation(fixed_batch)()]]

    losses = []
    for _ in range(3):
      state, round_loss = process.next(state, federated_data)
      losses.append(round_loss)

    self.assertLess(np.mean(losses[1:]), losses[0])
Example #9
    def test_tff_learning_evaluate(self):
        """keras_evaluate on a model loaded from server state yields a valid accuracy."""
        it_process = dp_fedavg.build_federated_averaging_process(
            _tff_learning_model_fn)
        server_state = it_process.initialize()

        keras_model = _create_test_cnn_model()
        server_state.model.assign_weights_to(keras_model)

        # Bug fix: the original built this identical `sample_data` list twice;
        # the first copy was dead code (immediately shadowed) and is removed.
        sample_data = [
            collections.OrderedDict(x=np.ones([1, 28, 28, 1],
                                              dtype=np.float32),
                                    y=np.ones([1], dtype=np.int32))
        ]
        metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
        metrics = dp_fedavg.keras_evaluate(keras_model, sample_data, metrics)
        accuracy = metrics[0].result()
        self.assertIsInstance(accuracy, tf.Tensor)
        # Accuracy is a fraction, so it must lie in [0, 1].
        self.assertBetween(accuracy, 0.0, 1.0)
Example #10
def train_and_eval():
    """Train and evaluate StackOver NWP task.

    Reads all hyperparameters from the module-level `FLAGS`, builds a DP
    federated-averaging process for the StackOverflow next-word-prediction
    task, and drives training/evaluation via `training_loop.run`.
    """
    logging.info('Show FLAGS for debugging:')
    for f in HPARAM_FLAGS:
        logging.info('%s=%s', f, FLAGS[f].value)

    # Preprocessing yields a dataset-construction computation plus the
    # train/validation/test splits.
    train_dataset_computation, train_set, validation_set, test_set = _preprocess_stackoverflow(
        FLAGS.vocab_size, FLAGS.num_oov_buckets, FLAGS.sequence_length,
        FLAGS.num_validation_examples, FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round, FLAGS.max_elements_per_user)

    # Element structure of a preprocessed client dataset; used as the model's
    # input spec.
    input_spec = train_dataset_computation.type_signature.result.element

    def tff_model_fn():
        # Builds a fresh recurrent Keras model wrapped for TFF on each call.
        keras_model = models.create_recurrent_model(
            vocab_size=FLAGS.vocab_size,
            embedding_size=FLAGS.embedding_size,
            latent_size=FLAGS.latent_size,
            num_layers=FLAGS.num_layers,
            shared_embedding=FLAGS.shared_embedding)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        return dp_fedavg.KerasModelWrapper(keras_model, input_spec, loss)

    # Server-side Gaussian noise scale: clip_norm * multiplier / clients.
    noise_std = FLAGS.clip_norm * FLAGS.noise_multiplier / float(
        FLAGS.clients_per_round)
    server_optimizer_fn = functools.partial(_server_optimizer_fn,
                                            name=FLAGS.server_optimizer,
                                            learning_rate=FLAGS.server_lr,
                                            noise_std=noise_std)
    client_optimizer_fn = functools.partial(_client_optimizer_fn,
                                            name=FLAGS.client_optimizer,
                                            learning_rate=FLAGS.client_lr)
    iterative_process = dp_fedavg.build_federated_averaging_process(
        tff_model_fn,
        dp_clip_norm=FLAGS.clip_norm,
        server_optimizer_fn=server_optimizer_fn,
        client_optimizer_fn=client_optimizer_fn)
    # Fuse dataset construction into the process so `next` can be fed client
    # ids instead of materialized datasets.
    iterative_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        dataset_computation=train_dataset_computation,
        process=iterative_process)

    keras_metics = _get_stackoverflow_metrics(FLAGS.vocab_size,
                                              FLAGS.num_oov_buckets)
    model = tff_model_fn()

    def evaluate_fn(model_weights, dataset):
        # Loads `model_weights` into the local model, then evaluates it and
        # returns {metric_name: value} as plain numpy scalars.
        model.from_weights(model_weights)
        metrics = dp_fedavg.keras_evaluate(model.keras_model, dataset,
                                           keras_metics)
        return collections.OrderedDict(
            (metric.name, metric.result().numpy()) for metric in metrics)

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in HPARAM_FLAGS])

    if FLAGS.total_epochs is None:
        # No epoch budget: sample clients independently each round.

        def client_dataset_ids_fn(round_num: int, epoch: int):
            return _sample_client_ids(FLAGS.clients_per_round, train_set,
                                      round_num, epoch)

        logging.info('Sample clients for max %d rounds', FLAGS.total_rounds)
        total_epochs = 0
    else:
        # Epoch budget given: shuffle client ids and walk through them.
        client_shuffer = training_loop.ClientIDShuffler(
            FLAGS.clients_per_round, train_set)
        client_dataset_ids_fn = client_shuffer.sample_client_ids
        logging.info('Shuffle clients for max %d epochs and %d rounds',
                     FLAGS.total_epochs, FLAGS.total_rounds)
        total_epochs = FLAGS.total_epochs

    training_loop.run(iterative_process,
                      client_dataset_ids_fn,
                      validation_fn=functools.partial(evaluate_fn,
                                                      dataset=validation_set),
                      total_epochs=total_epochs,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      train_eval_fn=None,
                      test_fn=functools.partial(evaluate_fn, dataset=test_set),
                      root_output_dir=FLAGS.root_output_dir,
                      hparam_dict=hparam_dict,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_train_eval=2000)
Example #11
 def test_build_fedavg_process(self):
     """The RNN process is an IterativeProcess with the expected client type."""
     process = dp_fedavg.build_federated_averaging_process(_rnn_model_fn)
     self.assertIsInstance(process, tff.templates.IterativeProcess)
     param_types = process.next.type_signature.parameter
     self.assertEqual(
         str(param_types[1]), '{<x=int32[?,5],y=int32[?,5]>*}@CLIENTS')
Example #12
def main(argv):
    """Entry point: configure the TFF runtime and train DP-FedAvg on EMNIST.

    Reads all hyperparameters from the module-level `FLAGS`.

    Raises:
      app.UsageError: if any positional command-line argument is supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    # Place client work on GPUs (when present) and the server on the CPU.
    client_devices = tf.config.list_logical_devices('GPU')
    server_device = tf.config.list_logical_devices('CPU')[0]
    tff.backends.native.set_local_execution_context(
        max_fanout=2 * FLAGS.clients_per_round,
        server_tf_device=server_device,
        client_tf_devices=client_devices,
        clients_per_thread=FLAGS.clients_per_thread)

    logging.info('Show FLAGS for debugging:')
    for f in HPARAM_FLAGS:
        logging.info('%s=%s', f, FLAGS[f].value)

    train_data, test_data = _get_emnist_dataset(
        FLAGS.only_digits,
        FLAGS.client_epochs_per_round,
        FLAGS.client_batch_size,
    )

    def tff_model_fn():
        # Builds a fresh CNN wrapped for TFF on each call.
        keras_model = _create_original_fedavg_cnn_model(FLAGS.only_digits)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
        return dp_fedavg.KerasModelWrapper(keras_model, test_data.element_spec,
                                           loss)

    # Server-side Gaussian noise scale: clip_norm * multiplier / clients.
    noise_std = FLAGS.clip_norm * FLAGS.noise_multiplier / float(
        FLAGS.clients_per_round)
    server_optimizer_fn = functools.partial(_server_optimizer_fn,
                                            name=FLAGS.server_optimizer,
                                            learning_rate=FLAGS.server_lr,
                                            noise_std=noise_std)
    client_optimizer_fn = functools.partial(_client_optimizer_fn,
                                            name=FLAGS.client_optimizer,
                                            learning_rate=FLAGS.client_lr)
    iterative_process = dp_fedavg.build_federated_averaging_process(
        tff_model_fn,
        dp_clip_norm=FLAGS.clip_norm,
        server_optimizer_fn=server_optimizer_fn,
        client_optimizer_fn=client_optimizer_fn)

    keras_metics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    model = tff_model_fn()

    def evaluate_fn(model_weights, dataset):
        # Loads `model_weights` into the local model, then evaluates it and
        # returns {metric_name: value} as plain numpy scalars.
        model.from_weights(model_weights)
        metrics = dp_fedavg.keras_evaluate(model.keras_model, dataset,
                                           keras_metics)
        return collections.OrderedDict(
            (metric.name, metric.result().numpy()) for metric in metrics)

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in HPARAM_FLAGS])
    # total_epochs == 0 means "sample clients per round" in training_loop.
    total_epochs = 0 if FLAGS.total_epochs is None else FLAGS.total_epochs
    # NOTE(review): test_data is used both as validation_fn and test_fn —
    # presumably intentional for this example, but verify against the
    # training_loop contract.
    training_loop.run(iterative_process,
                      client_datasets_fn=_get_client_datasets_fn(train_data),
                      validation_fn=functools.partial(evaluate_fn,
                                                      dataset=test_data),
                      total_rounds=FLAGS.total_rounds,
                      total_epochs=total_epochs,
                      experiment_name=FLAGS.experiment_name,
                      train_eval_fn=None,
                      test_fn=functools.partial(evaluate_fn,
                                                dataset=test_data),
                      root_output_dir=FLAGS.root_output_dir,
                      hparam_dict=hparam_dict,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_train_eval=2000)