Example #1
0
def iterator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
             client_state_fn: CLIENT_STATE_FN,
             server_optimizer_fn: OPTIMIZER_FN,
             client_optimizer_fn: OPTIMIZER_FN):
    """Builds the federated training process with per-client state.

    Each round the server state is turned into a message and broadcast to the
    clients. Every client then trains on its dataset using its own state. The
    clients' weight deltas are averaged, weighted by each client's
    `client_weight`, and applied on the server.

    Args:
      coefficient_fn: Forwarded to `__update_client`.
      model_fn: No-argument model constructor. It is called once here to
        obtain the input spec and the metrics aggregation, and it is also
        forwarded to the server/client helpers.
      client_state_fn: No-argument callable. Its result is only used to
        derive the client-state type.
      server_optimizer_fn: Forwarded to `__initialize_server` and
        `__update_server`.
      client_optimizer_fn: Forwarded to `__update_client`.

    Returns:
      A `tff.templates.IterativeProcess`. Its `next` takes
      `(server_state, datasets, client_states)` and returns
      `(next_state, metrics, client_states)`.
    """
    model = model_fn()
    sample_client_state = client_state_fn()

    def _initialize():
        return __initialize_server(model_fn, server_optimizer_fn)

    init_tf = tff.tf_computation(_initialize)
    server_state_type = init_tf.type_signature.result
    client_state_type = tff.framework.type_from_tensors(sample_client_state)

    def _apply_delta(state, weights_delta):
        return __update_server(state, weights_delta, model_fn,
                               server_optimizer_fn, tf.function(server.update))

    update_server_tf = tff.tf_computation(
        _apply_delta, (server_state_type, server_state_type.model.trainable))

    def _to_message(state):
        return __state_to_message(state, tf.function(server.to_message))

    state_to_message_tf = tff.tf_computation(_to_message, server_state_type)

    dataset_type = tff.SequenceType(model.input_spec)
    server_message_type = state_to_message_tf.type_signature.result

    def _train_client(dataset, state, message):
        return __update_client(dataset, state, message, coefficient_fn,
                               model_fn, client_optimizer_fn,
                               tf.function(client.update))

    update_client_tf = tff.tf_computation(
        _train_client, (dataset_type, client_state_type, server_message_type))

    def init_tff():
        return tff.federated_value(init_tf(), tff.SERVER)

    def next_tff(server_state, datasets, client_states):
        # Server -> clients: derive the message from the state, then broadcast.
        broadcast = tff.federated_broadcast(
            tff.federated_map(state_to_message_tf, server_state))

        # Local training on every client.
        outputs = tff.federated_map(update_client_tf,
                                    (datasets, client_states, broadcast))

        # Clients -> server: weighted mean of the deltas and aggregated metrics.
        mean_delta = tff.federated_mean(outputs.weights_delta,
                                        weight=outputs.client_weight)
        metrics = model.federated_output_computation(outputs.metrics)

        next_state = tff.federated_map(update_server_tf,
                                       (server_state, mean_delta))
        return next_state, metrics, outputs.client_state

    next_parameter_types = (tff.type_at_server(server_state_type),
                            tff.type_at_clients(dataset_type),
                            tff.type_at_clients(client_state_type))
    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(init_tff),
        next_fn=tff.federated_computation(next_tff, next_parameter_types))
Example #2
0
    def test_twice_used_variable_keeps_separate_state(self):
        """Repeated calls must not share the traced variable's state.

        `count_one_twice` calls the first computation twice and the second
        once. All three calls are expected to return 1.
        """

        def increment_and_read():
            counter = tf.Variable(initial_value=0, name='var_of_interest')
            bump = counter.assign_add(1)
            # Force the increment to happen before the read.
            with tf.control_dependencies([bump]):
                return counter.read_value()

        first_counter = tff.tf_computation(increment_and_read)
        second_counter = tff.tf_computation(increment_and_read)

        @tff.tf_computation
        def count_one_twice():
            return first_counter(), first_counter(), second_counter()

        self.assertEqual((1, 1, 1), count_one_twice())
Example #3
0
def validator(
  model_fn: MODEL_FN,
  client_state_fn: CLIENT_STATE_FN
):
  """Builds a federated computation that validates every client locally.

  Args:
    model_fn: No-argument model constructor. It is called once here for the
      input spec and the metrics aggregation, and forwarded to
      `__validate_client`.
    client_state_fn: No-argument callable. Its result is only used to derive
      the client-state type.

  Returns:
    A `tff.federated_computation` taking client-placed `(datasets,
    client_states)` and returning the metrics aggregated by
    `model.federated_output_computation`.
  """
  model = model_fn()
  client_state = client_state_fn()

  dataset_type = tff.SequenceType(model.input_spec)
  client_state_type = tff.framework.type_from_tensors(client_state)

  def _validate_one(dataset, state):
    return __validate_client(dataset, state, model_fn,
                             tf.function(client.validate))

  validate_client_tf = tff.tf_computation(
    _validate_one, (dataset_type, client_state_type))

  def validate(datasets, client_states):
    outputs = tff.federated_map(validate_client_tf, (datasets, client_states))
    return model.federated_output_computation(outputs.metrics)

  parameter_types = (tff.type_at_clients(dataset_type),
                     tff.type_at_clients(client_state_type))
  return tff.federated_computation(validate, parameter_types)
Example #4
0
  def test_dp_momentum_training(self, model_fn, optimzer_fn, total_rounds=3):
    """Trains a few rounds with a DP momentum server optimizer.

    Each round must advance `round_num` by one. When the optimizer state is an
    `optimizer_utils.FTRLState`, the DP tree step index must equal the number
    of completed rounds. Finally, the mean loss over the later rounds must be
    below the first-round loss.
    """

    def server_optimizer_fn(model_weights):
      model_weight_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights)
      return optimzer_fn(
          learning_rate=1.0,
          momentum=0.9,
          noise_std=1e-5,
          model_weight_specs=model_weight_specs)

    it_process = dp_fedavg.build_federated_averaging_process(
        model_fn, server_optimizer_fn=server_optimizer_fn)
    server_state = it_process.initialize()

    def deterministic_batch():
      return collections.OrderedDict(
          x=np.ones([1, 28, 28, 1], dtype=np.float32),
          y=np.ones([1], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    loss_list = []
    for i in range(total_rounds):
      server_state, loss = it_process.next(server_state, federated_data)
      loss_list.append(loss)
      self.assertEqual(i + 1, server_state.round_num)
      # The previous `optimizer_state is FTRLState` compared an instance with
      # the class object itself. It was always False, so this check never ran.
      if isinstance(server_state.optimizer_state, optimizer_utils.FTRLState):
        self.assertEqual(
            i + 1,
            tree_aggregation.get_step_idx(
                server_state.optimizer_state.dp_tree_state))
    self.assertLess(np.mean(loss_list[1:]), loss_list[0])
Example #5
0
def evaluator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
    """Builds a federated computation that evaluates server weights on clients.

    Args:
      coefficient_fn: Forwarded to `__evaluate_client`.
      model_fn: No-argument model constructor. It is called once here for the
        input spec, the weights type and the metrics aggregation, and
        forwarded to `__evaluate_client`.
      client_state_fn: No-argument callable. Its result is only used to derive
        the client-state type.

    Returns:
      A `tff.federated_computation` taking `(weights, datasets,
      client_states)`: server-placed weights plus client-placed datasets and
      states. It returns a tuple of:
        * the summed confusion matrix,
        * the metrics aggregated by `model.federated_output_computation`,
        * the per-client metrics collected via `tff.federated_collect`.
    """
    model = model_fn()
    client_state = client_state_fn()

    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state)
    weights_type = tff.framework.type_from_tensors(
        tff.learning.ModelWeights.from_model(model))

    def _evaluate_one(dataset, state, weights):
        return __evaluate_client(dataset, state, weights, coefficient_fn,
                                 model_fn, tf.function(client.evaluate))

    evaluate_client_tf = tff.tf_computation(
        _evaluate_one, (dataset_type, client_state_type, weights_type))

    def evaluate(weights, datasets, client_states):
        outputs = tff.federated_map(
            evaluate_client_tf,
            (datasets, client_states, tff.federated_broadcast(weights)))
        per_client_metrics = outputs.metrics
        return (
            tff.federated_sum(outputs.confusion_matrix),
            model.federated_output_computation(per_client_metrics),
            tff.federated_collect(per_client_metrics),
        )

    parameter_types = (tff.type_at_server(weights_type),
                       tff.type_at_clients(dataset_type),
                       tff.type_at_clients(client_state_type))
    return tff.federated_computation(evaluate, parameter_types)
Example #6
0
def validator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
    """Builds a federated computation that validates server weights on clients.

    Args:
      coefficient_fn: Forwarded to `__validate_client`.
      model_fn: No-argument model constructor. It is called once here for the
        input spec, the weights type and the metrics aggregation, and
        forwarded to `__validate_client`.
      client_state_fn: No-argument callable. Its result is only used to derive
        the client-state type.

    Returns:
      A `tff.federated_computation` taking `(weights, datasets,
      client_states)` and returning the metrics aggregated by
      `model.federated_output_computation`.
    """
    model = model_fn()
    client_state = client_state_fn()

    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state)
    weights_type = tff.learning.framework.weights_type_from_model(model)

    def _validate_one(dataset, state, weights):
        return __validate_client(dataset, state, weights, coefficient_fn,
                                 model_fn, tf.function(client.validate))

    validate_client_tf = tff.tf_computation(
        _validate_one, (dataset_type, client_state_type, weights_type))

    def validate(weights, datasets, client_states):
        outputs = tff.federated_map(
            validate_client_tf,
            (datasets, client_states, tff.federated_broadcast(weights)))
        return model.federated_output_computation(outputs.metrics)

    parameter_types = (tff.type_at_server(weights_type),
                       tff.type_at_clients(dataset_type),
                       tff.type_at_clients(client_state_type))
    return tff.federated_computation(validate, parameter_types)
    def test_dpftal_training(self, total_rounds=5):
        """Runs DP-FTRL-M training on the RNN model and checks bookkeeping.

        Each round must bump both `round_num` and the DP tree step index by
        one. The mean of the later losses must fall below the first loss.
        """

        def make_server_optimizer(model_weights):
            return optimizer_utils.DPFTRLMServerOptimizer(
                learning_rate=0.1,
                momentum=0.9,
                noise_std=1e-5,
                model_weight_shape=tf.nest.map_structure(tf.shape,
                                                         model_weights))

        process = dp_fedavg.build_federated_averaging_process(
            _rnn_model_fn, server_optimizer_fn=make_server_optimizer)
        state = process.initialize()

        def deterministic_batch():
            tokens = np.array([[0, 1, 2, 3, 4]], dtype=np.int32)
            targets = np.array([[1, 2, 3, 4, 0]], dtype=np.int32)
            return collections.OrderedDict(x=tokens, y=targets)

        federated_data = [[tff.tf_computation(deterministic_batch)()]]

        losses = []
        for round_idx in range(total_rounds):
            state, loss = process.next(state, federated_data)
            losses.append(loss)
            completed_rounds = round_idx + 1
            self.assertEqual(completed_rounds, state.round_num)
            tree_state = state.optimizer_state['dp_tree_state']
            self.assertEqual(
                completed_rounds,
                tree_aggregation.get_step_idx(tree_state.level_state))
        self.assertLess(np.mean(losses[1:]), losses[0])
    def test_simple_training(self):
        """Three rounds must lower the loss.

        The server weights are also pushed into a Keras model before and
        after every round.
        """
        process = simple_fedavg_tff.build_federated_averaging_process(
            _model_fn)
        state = process.initialize()
        Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name

        # Test out manually setting weights:
        keras_model = _create_test_cnn_model(only_digits=True)

        def deterministic_batch():
            images = np.ones([1, 28, 28, 1], dtype=np.float32)
            labels = np.ones([1], dtype=np.int32)
            return Batch(x=images, y=labels)

        batch = tff.tf_computation(deterministic_batch)()
        federated_data = [[batch]]

        def keras_evaluate(current_state):
            # Copy the TFF weights into Keras and make sure inference runs.
            tff.learning.assign_weights_to_keras_model(
                keras_model, current_state.model_weights)
            keras_model.predict(batch.x)

        losses = []
        for _ in range(3):
            keras_evaluate(state)
            state, loss = process.next(state, federated_data)
            losses.append(loss)
        keras_evaluate(state)

        self.assertLess(np.mean(losses[1:]), losses[0])
Example #9
0
    def test_dp_momentum_training(self, model_fn, optimzer_fn, total_rounds=3):
        """Trains a few rounds with a DP momentum server optimizer.

        Each round must advance `round_num` by one. When the optimizer state
        carries a `'dp_tree_state'` entry, its tree step index must equal the
        number of completed rounds. Finally, the mean loss over the later
        rounds must be below the first-round loss.
        """

        def server_optimizer_fn(model_weights):
            model_weight_shape = tf.nest.map_structure(tf.shape, model_weights)
            return optimzer_fn(learning_rate=1.0,
                               momentum=0.9,
                               noise_std=1e-5,
                               model_weight_shape=model_weight_shape)

        # Debug `print` calls that used to surround these steps were removed.
        it_process = dp_fedavg.build_federated_averaging_process(
            model_fn, server_optimizer_fn=server_optimizer_fn)
        server_state = it_process.initialize()

        def deterministic_batch():
            return collections.OrderedDict(x=np.ones([1, 28, 28, 1],
                                                     dtype=np.float32),
                                           y=np.ones([1], dtype=np.int32))

        batch = tff.tf_computation(deterministic_batch)()
        federated_data = [[batch]]

        loss_list = []
        for i in range(total_rounds):
            server_state, loss = it_process.next(server_state, federated_data)
            loss_list.append(loss)
            self.assertEqual(i + 1, server_state.round_num)
            # Guard on the key that is actually read below. The old guard
            # tested an unrelated 'server_state_type' key, so this check was
            # effectively dead.
            if 'dp_tree_state' in server_state.optimizer_state:
                self.assertEqual(
                    i + 1,
                    tree_aggregation.get_step_idx(
                        server_state.optimizer_state['dp_tree_state']))
        self.assertLess(np.mean(loss_list[1:]), loss_list[0])
Example #10
0
    def test_simple_training(self):
        """Three rounds must lower the loss.

        The server model is also pushed into a compiled Keras model before
        and after every round.
        """
        process = build_federated_averaging_process(models.model_fn)
        state = process.initialize()
        Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name

        # Test out manually setting weights:
        keras_model = models.create_keras_model(compile_model=True)

        def deterministic_batch():
            features = np.ones([1, 784], dtype=np.float32)
            labels = np.ones([1, 1], dtype=np.int64)
            return Batch(x=features, y=labels)

        batch = tff.tf_computation(deterministic_batch)()
        federated_data = [[batch]]

        def keras_evaluate(current_state):
            tff.learning.assign_weights_to_keras_model(keras_model,
                                                       current_state.model)
            # N.B. The loss computed here won't match the loss computed by TFF
            # because of the Dropout layer.
            keras_model.test_on_batch(batch.x, batch.y)

        losses = []
        for _ in range(3):
            keras_evaluate(state)
            state, loss = process.next(state, federated_data)
            losses.append(loss)
        keras_evaluate(state)

        self.assertLess(np.mean(losses[1:]), losses[0])
    def _create_next_fn(self, inner_agg_next, state_type):
        """Wraps an inner aggregation's `next` with modular clipping.

        Client values are modular-clipped, using the range from
        `self._get_clip_range()` broadcast to the clients, before the inner
        aggregation. The aggregate is then clipped again with the same range.

        Args:
          inner_agg_next: The inner aggregator's `next` computation. Its
            second parameter type is reused as this computation's value type.
          state_type: Type of the state passed straight through to
            `inner_agg_next`.

        Returns:
          A `tff.federated_computation` of `(state, value)` returning a
          `tff.templates.MeasuredProcessOutput`. The inner aggregation's
          measurements are nested under the `agg_process` key.
        """
        value_type = inner_agg_next.type_signature.parameter[1]
        clip_fn = tff.tf_computation(modular_clip_by_value)

        @tff.federated_computation(state_type, value_type)
        def next_fn(state, value):
            lower, upper = self._get_clip_range()

            # Clip each client's value before it enters the inner aggregation.
            client_args = (value, tff.federated_broadcast(lower),
                           tff.federated_broadcast(upper))
            clipped_value = tff.federated_map(clip_fn, client_args)

            inner_output = inner_agg_next(state, clipped_value)
            agg_state, agg_result, agg_measurements = inner_output

            # Clip the aggregate to the same range again (not considering summands).
            clipped_result = tff.federated_map(clip_fn,
                                               (agg_result, lower, upper))

            return tff.templates.MeasuredProcessOutput(
                state=agg_state,
                result=clipped_result,
                measurements=tff.federated_zip(
                    collections.OrderedDict(agg_process=agg_measurements)))

        return next_fn
Example #12
0
def iterator(
  model_fn: MODEL_FN,
  client_state_fn: CLIENT_STATE_FN,
  client_optimizer_fn: OPTIMIZER_FN
):
  """Builds a client-only training process with an empty server state.

  Args:
    model_fn: No-argument model constructor. It is called once here for the
      input spec and the metrics aggregation, and forwarded to
      `__update_client`.
    client_state_fn: No-argument callable. Its result is only used to derive
      the client-state type.
    client_optimizer_fn: Forwarded to `__update_client`.

  Returns:
    A `tff.templates.IterativeProcess`. Its server state is the empty tuple,
    and its `next` takes `(server_state, datasets, client_states)` and returns
    `(server_state, metrics, client_states)`.
  """
  model = model_fn()
  client_state = client_state_fn()

  def _empty_server_state():
    return ()

  init_tf = tff.tf_computation(_empty_server_state)
  server_state_type = init_tf.type_signature.result
  client_state_type = tff.framework.type_from_tensors(client_state)
  dataset_type = tff.SequenceType(model.input_spec)

  def _train_one(dataset, state):
    return __update_client(dataset, state, model_fn, client_optimizer_fn,
                           tf.function(client.update))

  update_client_tf = tff.tf_computation(
    _train_one, (dataset_type, client_state_type))

  def init_tff():
    return tff.federated_value(init_tf(), tff.SERVER)

  def next_tff(server_state, datasets, client_states):
    outputs = tff.federated_map(update_client_tf, (datasets, client_states))
    metrics = model.federated_output_computation(outputs.metrics)
    # The server state is passed through unchanged.
    return server_state, metrics, outputs.client_state

  next_parameter_types = (
    tff.type_at_server(server_state_type),
    tff.type_at_clients(dataset_type),
    tff.type_at_clients(client_state_type),
  )
  return tff.templates.IterativeProcess(
    initialize_fn=tff.federated_computation(init_tff),
    next_fn=tff.federated_computation(next_tff, next_parameter_types)
  )
Example #13
0
    def test_inferred_type_assignable_to_type_spec(self):
        """The inferred result type of `create_sparse` matches its spec type.

        The type TFF infers for `create_sparse`'s result must be assignable
        to the type derived from its `tf.SparseTensorSpec`.
        """
        inferred_type = tff.tf_computation(create_sparse).type_signature.result
        spec_type = tff.to_type(
            tf.SparseTensorSpec.from_value(create_sparse()))
        spec_type.check_assignable_from(inferred_type)
Example #14
0
    def test_inferred_type_assignable_to_type_spec(self):
        """The inferred result type of `create_ragged` matches its spec type.

        The type TFF infers for `create_ragged`'s result must be assignable
        to the type derived from its `tf.RaggedTensorSpec`.
        """
        inferred_type = tff.tf_computation(create_ragged).type_signature.result
        spec_type = tff.to_type(
            tf.RaggedTensorSpec.from_value(create_ragged()))
        spec_type.check_assignable_from(inferred_type)
Example #15
0
  def test_dpftal_restart(self, total_rounds=3):
    """Restarting the DP tree after each round resets its step index to 0.

    After every round the optimizer state is replaced with the result of
    `restart_dp_tree`. `round_num` must still advance, and the mean of the
    later losses must be below the first loss.
    """

    def make_optimizer(model_weights):
      return optimizer_utils.DPFTRLMServerOptimizer(
          learning_rate=0.1,
          momentum=0.9,
          noise_std=1e-5,
          model_weight_specs=tf.nest.map_structure(
              lambda w: tf.TensorSpec(w.shape, w.dtype), model_weights),
          efficient_tree=True,
          use_nesterov=True)

    process = dp_fedavg.build_federated_averaging_process(
        _rnn_model_fn,
        server_optimizer_fn=make_optimizer,
        use_simulation_loop=True)
    state = process.initialize()

    restart_optimizer = make_optimizer(_rnn_model_fn().weights.trainable)

    def restart_tree(current):
      # Keep the model and round count; swap in a restarted optimizer state.
      return tff.structure.update_struct(
          current,
          model=current.model,
          optimizer_state=restart_optimizer.restart_dp_tree(
              current.model.trainable),
          round_num=current.round_num)

    def deterministic_batch():
      tokens = np.array([[0, 1, 2, 3, 4]], dtype=np.int32)
      targets = np.array([[1, 2, 3, 4, 0]], dtype=np.int32)
      return collections.OrderedDict(x=tokens, y=targets)

    federated_data = [[tff.tf_computation(deterministic_batch)()]]

    losses = []
    for round_idx in range(total_rounds):
      state, loss = process.next(state, federated_data)
      state = restart_tree(state)
      losses.append(loss)
      self.assertEqual(round_idx + 1, state.round_num)
      self.assertEqual(
          0, tree_aggregation.get_step_idx(state.optimizer_state.dp_tree_state))
    self.assertLess(np.mean(losses[1:]), losses[0])
Example #16
0
    def test_simple_training(self, model_fn):
        """Three `dp_fedavg` rounds on one fixed batch must lower the loss."""
        process = dp_fedavg.build_federated_averaging_process(model_fn)
        state = process.initialize()

        def deterministic_batch():
            images = np.ones([1, 28, 28, 1], dtype=np.float32)
            labels = np.ones([1], dtype=np.int32)
            return collections.OrderedDict(x=images, y=labels)

        federated_data = [[tff.tf_computation(deterministic_batch)()]]

        losses = []
        for _ in range(3):
            state, loss = process.next(state, federated_data)
            losses.append(loss)

        self.assertLess(np.mean(losses[1:]), losses[0])
Example #17
0
    def test_simple_training(self, model_fn):
        """Three `simple_fedavg_tff` rounds on one fixed batch must lower the loss."""
        process = simple_fedavg_tff.build_federated_averaging_process(
            model_fn)
        state = process.initialize()
        Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name

        def deterministic_batch():
            images = np.ones([1, 28, 28, 1], dtype=np.float32)
            labels = np.ones([1], dtype=np.int32)
            return Batch(x=images, y=labels)

        federated_data = [[tff.tf_computation(deterministic_batch)()]]

        losses = []
        for _ in range(3):
            state, loss = process.next(state, federated_data)
            losses.append(loss)

        self.assertLess(np.mean(losses[1:]), losses[0])
Example #18
0
  def test_client_adagrad_train(self):
    """Three rounds with an Adagrad client optimizer must lower the loss."""
    make_client_optimizer = functools.partial(
        tf.keras.optimizers.Adagrad, learning_rate=0.01)
    process = simple_fedavg_tff.build_federated_averaging_process(
        _rnn_model_fn, client_optimizer_fn=make_client_optimizer)
    state = process.initialize()

    def deterministic_batch():
      tokens = np.array([[0, 1, 2, 3, 4]], dtype=np.int32)
      targets = np.array([[1, 2, 3, 4, 0]], dtype=np.int32)
      return collections.OrderedDict(x=tokens, y=targets)

    federated_data = [[tff.tf_computation(deterministic_batch)()]]

    losses = []
    for _ in range(3):
      state, loss = process.next(state, federated_data)
      losses.append(loss)

    self.assertLess(np.mean(losses[1:]), losses[0])
Example #19
0
    def run_one_round(server_state, federated_dataset, client_weight):
        """Orchestration logic for one round of computation.

        Args:
          server_state: A `ServerState`. Its `model` and `round_num` are
            broadcast to the clients.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.
          client_weight: A per-client weight. It is mapped together with the
              clients' outputs, so it is expected to be placed at
              `tff.CLIENTS`. It is multiplied with each client's reported
              `client_weight` to form the weight passed to
              `aggregation_process.next`.

        Returns:
          A tuple of the updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`. When that result
          is a struct, it is zipped into a single federated value.
        """
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_model, client_round_num))

        # Scale each client's self-reported weight by the externally supplied
        # per-client weight before aggregating the model deltas.
        participant_client_weight = tff.federated_map(
            tff.tf_computation(lambda x, y: x * y),
            (client_weight, client_outputs.client_weight))

        aggregation_output = aggregation_process.next(
            server_state.aggregation_state, client_outputs.weights_delta,
            participant_client_weight)

        server_state = tff.federated_map(
            server_update_fn, (server_state, aggregation_output.result))

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Example #20
0
    def test_training_keras_model_converges(self):
        """Ten rounds on one fixed batch: loss never rises and ends below 0.1."""
        process = simple_fedavg_tff.build_federated_averaging_process(
            _tff_learning_model_fn)
        state = process.initialize()

        def deterministic_batch():
            images = np.ones([1, 28, 28, 1], dtype=np.float32)
            labels = np.ones([1], dtype=np.int32)
            return collections.OrderedDict(x=images, y=labels)

        batch = tff.tf_computation(deterministic_batch)()
        federated_data = [tf.data.Dataset.from_tensor_slices(batch).batch(1)]

        losses = []
        for _ in range(10):
            state, round_metrics = process.next(state, federated_data)
            losses.append(round_metrics['loss'])
            # Compare with the previous round as soon as one exists.
            if len(losses) > 1:
                self.assertLessEqual(losses[-1], losses[-2])
        self.assertLess(losses[-1], 0.1)
Example #21
0
  def test_simple_training(self, model_fn):
    """Three stateful rounds must lower the loss and keep iteration counts in sync.

    After training, the server's `total_iters_count` must equal the single
    client's `iters_count`.
    """
    process = stateful_fedavg_tff.build_federated_averaging_process(
        model_fn, _create_one_client_state)
    state = process.initialize()

    def deterministic_batch():
      images = np.ones([1, 28, 28, 1], dtype=np.float32)
      labels = np.ones([1], dtype=np.int32)
      return collections.OrderedDict(x=images, y=labels)

    federated_data = [[tff.tf_computation(deterministic_batch)()]]
    client_states = [_create_one_client_state()]

    losses = []
    for _ in range(3):
      state, loss, client_states = process.next(state, federated_data,
                                                client_states)
      losses.append(loss)

    self.assertEqual(state.total_iters_count, client_states[0].iters_count)
    self.assertLess(np.mean(losses[1:]), losses[0])
Example #22
0
def evaluator(
  model_fn: MODEL_FN,
  client_state_fn: CLIENT_STATE_FN
):
  """Builds a federated computation that evaluates every client locally.

  Args:
    model_fn: No-argument model constructor. It is called once here for the
      input spec and the metrics aggregation, and forwarded to
      `__evaluate_client`.
    client_state_fn: No-argument callable. Its result is only used to derive
      the client-state type.

  Returns:
    A `tff.federated_computation` taking client-placed `(datasets,
    client_states)` and returning a tuple of:
      * the summed confusion matrix,
      * the metrics aggregated by `model.federated_output_computation`,
      * the per-client metrics collected via `tff.federated_collect`.
  """
  model = model_fn()
  client_state = client_state_fn()

  dataset_type = tff.SequenceType(model.input_spec)
  client_state_type = tff.framework.type_from_tensors(client_state)

  def _evaluate_one(dataset, state):
    return __evaluate_client(dataset, state, model_fn,
                             tf.function(client.evaluate))

  evaluate_client_tf = tff.tf_computation(
    _evaluate_one, (dataset_type, client_state_type))

  def evaluate(datasets, client_states):
    outputs = tff.federated_map(evaluate_client_tf, (datasets, client_states))
    per_client_metrics = outputs.metrics
    return (
      tff.federated_sum(outputs.confusion_matrix),
      model.federated_output_computation(per_client_metrics),
      tff.federated_collect(per_client_metrics),
    )

  parameter_types = (tff.type_at_clients(dataset_type),
                     tff.type_at_clients(client_state_type))
  return tff.federated_computation(evaluate, parameter_types)