Пример #1
0
 def test_learning_process_can_be_reconstructed(self):
   """A LearningProcess can be rebuilt from another process's computations."""
   process = learning_process.LearningProcess(test_init_fn, test_next_fn,
                                              test_get_model_weights_fn,
                                              test_set_model_weights_fn)
   try:
     learning_process.LearningProcess(process.initialize, process.next,
                                      process.get_model_weights,
                                      process.set_model_weights)
   except Exception:  # pylint: disable=broad-except
     # `Exception` rather than a bare `except:` so KeyboardInterrupt and
     # SystemExit propagate instead of being reported as a test failure. The
     # caught error stays attached as the failure's `__context__`.
     self.fail('Could not reconstruct the LearningProcess.')
Пример #2
0
 def test_federated_get_model_weights_raises(self):
   """A `get_model_weights` built over a server-placed type is rejected."""
   server_float_type = at_server(tf.float32)
   federated_getter = create_pass_through_get_model_weights(server_float_type)
   with self.assertRaises(learning_process.GetModelWeightsTypeSignatureError):
     learning_process.LearningProcess(
         initialize_fn=test_init_fn,
         next_fn=test_next_fn,
         get_model_weights=federated_getter,
         set_model_weights=test_set_model_weights_fn)
Пример #3
0
 def test_next_not_tff_computation_raises(self):
   """A plain Python callable is not accepted as `next_fn`."""
   python_next_fn = lambda state, client_data: LearningProcessOutput(state, ())
   with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
     learning_process.LearningProcess(test_init_fn, python_next_fn,
                                      test_get_model_weights_fn,
                                      test_set_model_weights_fn)
Пример #4
0
 def test_set_model_weights_result_not_assignable(self):
   """A `set_model_weights` built from mismatched types is rejected.

   The computation is created from `(tf.int32, tf.float32)`, and construction
   is expected to raise `SetModelWeightsTypeSignatureError`.
   """
   mismatched_setter = create_take_arg_set_model_weights(tf.int32, tf.float32)
   with self.assertRaises(learning_process.SetModelWeightsTypeSignatureError):
     learning_process.LearningProcess(
         initialize_fn=test_init_fn,
         next_fn=test_next_fn,
         get_model_weights=test_get_model_weights_fn,
         set_model_weights=mismatched_setter)
Пример #5
0
  def test_construction_with_nested_datasets_does_not_raise(self):
    """Clients may receive a nested structure of datasets as `next` input."""

    @federated_computation
    def initialize_fn():
      zero_fn = tf_computation(lambda: tf.constant(0.0, tf.float32))
      return intrinsics.federated_eval(zero_fn, placements.SERVER)

    # One dataset alongside a pair of datasets: clients get several sequences.
    nested_datasets_type = (
        SequenceType(tf.string),
        (SequenceType(tf.string), SequenceType(tf.string)),
    )

    @federated_computation(
        at_server(tf.float32), at_clients(nested_datasets_type))
    def next_fn(state, datasets):
      del datasets  # Unused.
      empty_metrics = intrinsics.federated_value((), placements.SERVER)
      return LearningProcessOutput(state=state, metrics=empty_metrics)

    try:
      learning_process.LearningProcess(
          initialize_fn=initialize_fn,
          next_fn=next_fn,
          get_model_weights=create_pass_through_get_model_weights(tf.float32),
          set_model_weights=create_take_arg_set_model_weights(
              tf.float32, tf.float32))
    except learning_process.LearningProcessSequenceTypeError:
      self.fail('Could not construct a LearningProcess with second parameter '
                'type having nested sequences.')
Пример #6
0
  def test_construction_with_unknown_dimension_does_not_raise(self):
    """State whose shape has a statically unknown (`None`) dimension is OK."""

    @federated_computation
    def initialize_fn():
      return intrinsics.federated_eval(
          tf_computation(lambda: tf.constant([], tf.string)), placements.SERVER)

    # This replicates a tensor that can grow in string length. The
    # `initialize_fn` will concretely start with shape `[0]`, but `next_fn` will
    # grow this, hence the need to define the shape as `[None]`.
    none_dimension_string_type = TensorType(shape=[None], dtype=tf.string)

    @federated_computation(
        at_server(none_dimension_string_type),
        at_clients(SequenceType(tf.string)))
    def next_fn(state, datasets):
      del datasets  # Unused.
      return LearningProcessOutput(
          state, intrinsics.federated_value((), placements.SERVER))

    try:
      learning_process.LearningProcess(
          initialize_fn, next_fn,
          create_pass_through_get_model_weights(none_dimension_string_type),
          create_take_arg_set_model_weights(none_dimension_string_type,
                                            none_dimension_string_type))
    except Exception:  # pylint: disable=broad-except
      # `Exception` rather than a bare `except:` so KeyboardInterrupt and
      # SystemExit propagate instead of being reported as a test failure.
      self.fail('Could not construct a LearningProcess with state type having '
                'statically unknown shape.')
Пример #7
0
 def test_construction_does_not_raise(self):
   """The module-level test computations form a valid LearningProcess."""
   try:
     learning_process.LearningProcess(test_init_fn, test_next_fn,
                                      test_get_model_weights_fn,
                                      test_set_model_weights_fn)
   except Exception:  # pylint: disable=broad-except
     # `Exception` rather than a bare `except:` so KeyboardInterrupt and
     # SystemExit propagate instead of being reported as a test failure.
     self.fail('Could not construct a valid LearningProcess.')
Пример #8
0
  def test_next_fn_with_one_parameter_raises(self):
    """A `next_fn` that takes only the state raises TemplateNextFnNumArgsError."""

    @federated_computation(at_server(tf.int32))
    def single_arg_next_fn(state):
      return LearningProcessOutput(state=state, metrics=0)

    with self.assertRaises(errors.TemplateNextFnNumArgsError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=single_arg_next_fn,
          get_model_weights=test_get_model_weights_fn,
          set_model_weights=test_set_model_weights_fn)
Пример #9
0
  def test_init_state_not_federated(self):
    """An initializer returning an unplaced Python float is rejected."""

    @federated_computation
    def unplaced_initialize_fn():
      return 0.0

    with self.assertRaises(errors.TemplateStateNotAssignableError):
      learning_process.LearningProcess(
          initialize_fn=unplaced_initialize_fn,
          next_fn=test_next_fn,
          get_model_weights=test_get_model_weights_fn,
          set_model_weights=test_set_model_weights_fn)
Пример #10
0
  def test_init_param_not_empty_raises(self):
    """An `initialize_fn` that takes a parameter is rejected."""

    @federated_computation(at_server(tf.int32))
    def parameterized_initialize_fn(x):
      return x

    with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
      learning_process.LearningProcess(
          initialize_fn=parameterized_initialize_fn,
          next_fn=test_next_fn,
          get_model_weights=test_get_model_weights_fn,
          set_model_weights=test_set_model_weights_fn)
Пример #11
0
  def test_next_fn_with_nonsequence_second_arg_raises(self):
    """Client data given as plain int32 tensors rather than sequences raises."""

    @federated_computation(at_server(tf.int32), at_clients(tf.int32))
    def next_fn(state, client_values):
      return LearningProcessOutput(
          state=state, metrics=intrinsics.federated_sum(client_values))

    with self.assertRaises(learning_process.LearningProcessSequenceTypeError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=next_fn,
          get_model_weights=test_get_model_weights_fn,
          set_model_weights=test_set_model_weights_fn)
Пример #12
0
  def test_next_fn_with_client_placed_metrics_result_raises(self):
    """Metrics returned at CLIENTS raise LearningProcessPlacementError."""

    @federated_computation(
        at_server(tf.int32), at_clients(SequenceType(tf.int32)))
    def client_metrics_next_fn(state, metrics):
      # The client-placed input is returned unchanged as the metrics output.
      return LearningProcessOutput(state=state, metrics=metrics)

    with self.assertRaises(learning_process.LearningProcessPlacementError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=client_metrics_next_fn,
          get_model_weights=test_get_model_weights_fn,
          set_model_weights=test_set_model_weights_fn)
Пример #13
0
  def test_next_state_not_federated(self):
    """A `next_fn` whose state parameter is unplaced is rejected.

    `float_next_fn` declares an unplaced `tf.float32` state parameter, so
    constructing the process is expected to raise
    `TemplateStateNotAssignableError`.
    """

    @federated_computation(tf.float32, at_clients(SequenceType(tf.float32)))
    def float_next_fn(state, datasets):
      del datasets  # Unused.
      return state

    # Pass the two-argument computation defined above directly. It must not
    # be rebound to some other computation, or this test stops exercising
    # the unplaced-state case.
    with self.assertRaises(errors.TemplateStateNotAssignableError):
      learning_process.LearningProcess(test_init_fn, float_next_fn,
                                       test_get_model_weights_fn,
                                       test_set_model_weights_fn)
Пример #14
0
  def test_next_fn_with_server_placed_second_arg_raises(self):
    """Data placed at SERVER instead of CLIENTS raises a placement error."""

    @federated_computation(
        at_server(tf.int32), at_server(SequenceType(tf.int32)))
    def next_fn(state, server_values):
      return LearningProcessOutput(
          state=state,
          metrics=intrinsics.federated_map(sum_dataset, server_values))

    with self.assertRaises(learning_process.LearningProcessPlacementError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=next_fn,
          get_model_weights=test_get_model_weights_fn,
          set_model_weights=test_set_model_weights_fn)
Пример #15
0
  def test_next_return_odict_raises(self):
    """Returning an OrderedDict instead of LearningProcessOutput raises."""

    @federated_computation(
        at_server(tf.int32), at_clients(SequenceType(tf.int32)))
    def odict_next_fn(state, client_values):
      mapped = intrinsics.federated_map(sum_dataset, client_values)
      return collections.OrderedDict(
          state=state, metrics=intrinsics.federated_sum(mapped))

    with self.assertRaises(learning_process.LearningProcessOutputError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=odict_next_fn,
          get_model_weights=test_get_model_weights_fn,
          set_model_weights=test_set_model_weights_fn)
Пример #16
0
  def test_init_fn_with_client_placed_state_raises(self):
    """State initialized at CLIENTS raises LearningProcessPlacementError."""

    @federated_computation
    def client_init_fn():
      return intrinsics.federated_value(0, placements.CLIENTS)

    @federated_computation(
        at_clients(tf.int32), at_clients(SequenceType(tf.int32)))
    def client_state_next_fn(state, client_values):
      return LearningProcessOutput(state=state, metrics=client_values)

    with self.assertRaises(learning_process.LearningProcessPlacementError):
      learning_process.LearningProcess(
          initialize_fn=client_init_fn,
          next_fn=client_state_next_fn,
          get_model_weights=test_get_model_weights_fn,
          set_model_weights=test_set_model_weights_fn)
Пример #17
0
  def test_next_fn_with_three_parameters_raises(self):
    """A `next_fn` taking three arguments raises TemplateNextFnNumArgsError."""

    @federated_computation(
        at_server(tf.int32), at_clients(SequenceType(tf.int32)),
        at_server(tf.int32))
    def three_arg_next_fn(state, client_values, second_state):
      del second_state  # Unused.
      mapped = intrinsics.federated_map(sum_dataset, client_values)
      return LearningProcessOutput(
          state=state, metrics=intrinsics.federated_sum(mapped))

    with self.assertRaises(errors.TemplateNextFnNumArgsError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=three_arg_next_fn,
          get_model_weights=test_get_model_weights_fn,
          set_model_weights=test_set_model_weights_fn)
Пример #18
0
  def test_next_return_namedtuple_raises(self):
    """A lookalike namedtuple output is not accepted as LearningProcessOutput."""

    # A fresh namedtuple that shares the name and fields of the real output
    # class but is a different type.
    lookalike_output = collections.namedtuple('LearningProcessOutput',
                                              ['state', 'metrics'])

    @federated_computation(
        at_server(tf.int32), at_clients(SequenceType(tf.int32)))
    def namedtuple_next_fn(state, client_values):
      mapped = intrinsics.federated_map(sum_dataset, client_values)
      return lookalike_output(state, intrinsics.federated_sum(mapped))

    with self.assertRaises(learning_process.LearningProcessOutputError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=namedtuple_next_fn,
          get_model_weights=test_get_model_weights_fn,
          set_model_weights=test_set_model_weights_fn)
Пример #19
0
  def test_returns_iterative_process_with_same_non_next_type_signatures(self):
    """Composition keeps the `get/set_model_weights` type signatures intact."""

    @tensorflow_computation.tf_computation()
    def make_zero():
      return tf.cast(0, tf.int64)

    @federated_computation.federated_computation()
    def initialize_fn():
      return intrinsics.federated_eval(make_zero, placements.SERVER)

    # Each client holds a sequence of int64 values. The server state type is
    # taken directly from the initializer's result type.
    client_data_type = computation_types.FederatedType(
        computation_types.SequenceType(tf.int64), placements.CLIENTS)

    @federated_computation.federated_computation(
        (initialize_fn.type_signature.result, client_data_type))
    def next_fn(server_state, client_data):
      del client_data
      empty_metrics = intrinsics.federated_value((), placements.SERVER)
      return learning_process.LearningProcessOutput(
          state=server_state, metrics=empty_metrics)

    @tensorflow_computation.tf_computation(tf.int64)
    def get_model_weights(server_state):
      return server_state + 1

    @tensorflow_computation.tf_computation(tf.int64, tf.int64)
    def set_model_weights(state, state_update):
      return state + state_update

    process = learning_process.LearningProcess(
        initialize_fn=initialize_fn,
        next_fn=next_fn,
        get_model_weights=get_model_weights,
        set_model_weights=set_model_weights)
    composed = iterative_process_compositions.compose_dataset_computation_with_learning_process(
        int_dataset_computation, process)

    self.assertIsInstance(composed, iterative_process.IterativeProcess)
    # Check presence first, then signature equivalence, for each accessor.
    for method_name in ('get_model_weights', 'set_model_weights'):
      self.assertTrue(hasattr(composed, method_name))
      composed_signature = getattr(composed, method_name).type_signature
      self.assertTrue(
          composed_signature.is_equivalent_to(
              getattr(process, method_name).type_signature))
Пример #20
0
  def test_construction_with_empty_state_does_not_raise(self):
    """A LearningProcess whose state is an empty tuple can be constructed."""
    empty_tuple = ()

    @federated_computation
    def empty_initialize_fn():
      return intrinsics.federated_value(empty_tuple, placements.SERVER)

    @federated_computation(
        at_server(empty_tuple), at_clients(SequenceType(tf.int32)))
    def next_fn(state, value):
      del value  # Unused.
      return LearningProcessOutput(
          state=state,
          metrics=intrinsics.federated_value(empty_tuple, placements.SERVER))

    try:
      learning_process.LearningProcess(
          empty_initialize_fn, next_fn,
          create_pass_through_get_model_weights(empty_tuple),
          create_take_arg_set_model_weights(empty_tuple, empty_tuple))
    except Exception:  # pylint: disable=broad-except
      # `Exception` rather than a bare `except:` so KeyboardInterrupt and
      # SystemExit propagate instead of being reported as a test failure.
      self.fail('Could not construct a LearningProcess with empty state.')
Пример #21
0
def compose_learning_process(
    initial_model_weights_fn: computation_base.Computation,
    model_weights_distributor: distributors.DistributionProcess,
    client_work: client_works.ClientWorkProcess,
    model_update_aggregator: aggregation_process.AggregationProcess,
    model_finalizer: finalizers.FinalizerProcess
) -> learning_process.LearningProcess:
    """Composes specialized measured processes into a learning process.

  Given 4 specialized measured processes (described below) that together make
  up a learning process, and a computation that returns the initial model
  weights to be used for training, this method validates that the processes
  fit together and returns a `LearningProcess`. Please see the tutorial at
  https://www.tensorflow.org/federated/tutorials/composing_learning_algorithms
  for more details on composing learning processes.

  The main purposes of the 4 measured processes are:
    * `model_weights_distributor`: Make the global model weights at the server
      available as the starting point for learning work to be done at clients.
    * `client_work`: Produce an update to the model received at clients.
    * `model_update_aggregator`: Aggregate the model updates from clients to
      the server.
    * `model_finalizer`: Update the global model weights using the aggregated
      model update at the server.

  The `next` computation of the created learning process is composed from the
  `next` computations of the 4 measured processes, in the order visualized
  below. The type signatures of the processes must be such that this chaining
  is possible. Each process also reports its own metrics.

  ```
  ┌─────────────────────────┐
  │model_weights_distributor│
  └△─┬─┬────────────────────┘
   │ │┌▽──────────┐
   │ ││client_work│
   │ │└┬─────┬────┘
   │┌▽─▽────┐│
   ││metrics││
   │└△─△────┘│
   │ │┌┴─────▽────────────────┐
   │ ││model_update_aggregator│
   │ │└┬──────────────────────┘
  ┌┴─┴─▽──────────┐
  │model_finalizer│
  └┬──────────────┘
  ┌▽─────┐
  │result│
  └──────┘
  ```

  Args:
    initial_model_weights_fn: A `tff.Computation` that returns (unplaced)
      initial model weights.
    model_weights_distributor: A `DistributionProcess`.
    client_work: A `ClientWorkProcess`.
    model_update_aggregator: A `tff.templates.AggregationProcess`.
    model_finalizer: A `FinalizerProcess`.

  Returns:
    A `LearningProcess`. Its state is a `LearningAlgorithmState` holding the
    global model weights together with the state of each of the 4 processes.
    Its metrics are an `OrderedDict` with one entry per process (`distributor`,
    `client_work`, `aggregator`, `finalizer`). Its `get_model_weights` reads the
    global model weights from the state, and its `set_model_weights` returns a
    copy of the state with those weights replaced.

  Raises:
    Whatever `_validate_args` (defined elsewhere in this module) raises when the
    arguments do not fit together.
  """
    # pyformat: enable
    _validate_args(initial_model_weights_fn, model_weights_distributor,
                   client_work, model_update_aggregator, model_finalizer)
    # The third parameter of `client_work.next` is the client data. Its type
    # becomes the second parameter type of the composed `next_fn` below.
    client_data_type = client_work.next.type_signature.parameter[2]

    @federated_computation.federated_computation()
    def init_fn():
        # Evaluate the (unplaced) initial weights computation at the server,
        # then bundle it with each process's initial state into a single
        # zipped `LearningAlgorithmState`.
        initial_model_weights = intrinsics.federated_eval(
            initial_model_weights_fn, placements.SERVER)
        return intrinsics.federated_zip(
            LearningAlgorithmState(initial_model_weights,
                                   model_weights_distributor.initialize(),
                                   client_work.initialize(),
                                   model_update_aggregator.initialize(),
                                   model_finalizer.initialize()))

    @federated_computation.federated_computation(init_fn.type_signature.result,
                                                 client_data_type)
    def next_fn(state, client_data):
        # Compose processes. Each stage consumes the `result` of the previous
        # stage plus its own slice of `state`.
        distributor_output = model_weights_distributor.next(
            state.distributor, state.global_model_weights)
        client_work_output = client_work.next(state.client_work,
                                              distributor_output.result,
                                              client_data)
        aggregator_output = model_update_aggregator.next(
            state.aggregator, client_work_output.result.update,
            client_work_output.result.update_weight)
        # The finalizer gets the weights from the start of this round
        # (`state.global_model_weights`) and the aggregated update.
        finalizer_output = model_finalizer.next(state.finalizer,
                                                state.global_model_weights,
                                                aggregator_output.result)

        # Form the learning process output.
        new_global_model_weights = finalizer_output.result
        new_state = intrinsics.federated_zip(
            LearningAlgorithmState(new_global_model_weights,
                                   distributor_output.state,
                                   client_work_output.state,
                                   aggregator_output.state,
                                   finalizer_output.state))
        metrics = intrinsics.federated_zip(
            collections.OrderedDict(
                distributor=distributor_output.measurements,
                client_work=client_work_output.measurements,
                aggregator=aggregator_output.measurements,
                finalizer=finalizer_output.measurements))

        return learning_process.LearningProcessOutput(new_state, metrics)

    # `.member` strips the placement. The TF computations below operate on the
    # unplaced state structure.
    state_parameter_type = next_fn.type_signature.parameter[0].member

    @tensorflow_computation.tf_computation(state_parameter_type)
    def get_model_weights_fn(state):
        return state.global_model_weights

    @tensorflow_computation.tf_computation(
        state_parameter_type, state_parameter_type.global_model_weights)
    def set_model_weights_fn(state, model_weights):
        # `attr.evolve` returns a copy with only the global weights replaced.
        # The other process states are carried over unchanged.
        return attr.evolve(state, global_model_weights=model_weights)

    return learning_process.LearningProcess(init_fn, next_fn,
                                            get_model_weights_fn,
                                            set_model_weights_fn)
Пример #22
0
 def test_init_not_tff_computation_raises(self):
   """A plain Python callable is not accepted as `initialize_fn`."""
   python_init_fn = lambda: 0
   with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
     learning_process.LearningProcess(
         initialize_fn=python_init_fn,
         next_fn=test_next_fn,
         get_model_weights=test_get_model_weights_fn,
         set_model_weights=test_set_model_weights_fn)
Пример #23
0
 def test_get_model_weights_param_not_equivalent_to_next_fn(self):
   """A `get_model_weights` over unplaced `tf.float32` is rejected."""
   float_getter = create_pass_through_get_model_weights(tf.float32)
   with self.assertRaises(learning_process.GetModelWeightsTypeSignatureError):
     learning_process.LearningProcess(
         initialize_fn=test_init_fn,
         next_fn=test_next_fn,
         get_model_weights=float_getter,
         set_model_weights=test_set_model_weights_fn)
Пример #24
0
 def test_non_tff_computation_get_model_weights_raises(self):
   """A plain Python callable is not accepted as `get_model_weights`."""
   identity_getter = lambda x: x
   with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
     learning_process.LearningProcess(
         initialize_fn=test_init_fn,
         next_fn=test_next_fn,
         get_model_weights=identity_getter,
         set_model_weights=test_set_model_weights_fn)
Пример #25
0
 def test_non_functional_get_model_weights_raises(self):
   """Passing `at_server(tf.int32)` as `get_model_weights` raises TypeError."""
   non_computation = at_server(tf.int32)
   with self.assertRaises(TypeError):
     learning_process.LearningProcess(
         initialize_fn=test_init_fn,
         next_fn=test_next_fn,
         get_model_weights=non_computation,
         set_model_weights=test_set_model_weights_fn)