Ejemplo n.º 1
0
 def test_learning_process_can_be_reconstructed(self):
     """Feeding a process's own computations back into the constructor works.

     A `LearningProcess` is built, and its `initialize`, `next` and `report`
     attributes are passed to a second `LearningProcess` constructor, which
     must not raise.
     """
     process = learning_process.LearningProcess(test_init_fn, test_next_fn,
                                                test_report_fn)
     try:
         learning_process.LearningProcess(process.initialize, process.next,
                                          process.report)
     except Exception:  # pylint: disable=broad-except
         # `Exception` instead of a bare `except:` so KeyboardInterrupt and
         # SystemExit propagate rather than being reported as a test failure.
         # The caught error stays attached as the failure's `__context__`.
         self.fail('Could not reconstruct the LearningProcess.')
Ejemplo n.º 2
0
 def test_next_not_tff_computation_raises(self):
     """A plain Python callable is rejected as `next_fn` with a TypeError."""
     plain_python_next = (
         lambda state, client_data: LearningProcessOutput(state, ()))
     expected_message = r'Expected .*\.Computation, .*'
     with self.assertRaisesRegex(TypeError, expected_message):
         learning_process.LearningProcess(
             initialize_fn=test_init_fn,
             next_fn=plain_python_next,
             report_fn=test_report_fn,
         )
Ejemplo n.º 3
0
  def test_next_fn_with_client_placed_metrics_result_raises(self):
    """A `next_fn` whose metrics are the client-placed input is rejected.

    The client sequences are returned unchanged as metrics, and construction
    is expected to raise `LearningProcessPlacementError`.
    """

    @computations.federated_computation(test_state_type, ClientIntSequenceType)
    def next_fn(state, metrics):
      return LearningProcessOutput(state, metrics)

    with self.assertRaises(learning_process.LearningProcessPlacementError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=next_fn,
          report_fn=test_report_fn,
      )
Ejemplo n.º 4
0
    def test_next_fn_with_one_parameter_raises(self):
        """A single-parameter `next_fn` raises `TemplateNextFnNumArgsError`."""

        @computations.federated_computation(test_state_type)
        def next_fn(state):
            return LearningProcessOutput(state, 0)

        with self.assertRaises(errors.TemplateNextFnNumArgsError):
            learning_process.LearningProcess(
                initialize_fn=test_init_fn,
                next_fn=next_fn,
                report_fn=test_report_fn,
            )
Ejemplo n.º 5
0
    def test_init_param_not_empty_raises(self):
        """An `initialize_fn` that takes an argument is rejected."""
        server_int_type = computation_types.FederatedType(
            tf.int32, placements.SERVER)

        @computations.federated_computation(server_int_type)
        def one_arg_initialize_fn(x):
            return x

        with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
            learning_process.LearningProcess(
                initialize_fn=one_arg_initialize_fn,
                next_fn=test_next_fn,
                report_fn=test_report_fn,
            )
Ejemplo n.º 6
0
  def test_next_fn_with_non_sequence_second_arg_raises(self):
    """Client data of a non-sequence type (plain int32 at CLIENTS) fails."""
    ints_at_clients = computation_types.FederatedType(
        tf.int32, placements.CLIENTS)

    @computations.federated_computation(test_state_type, ints_at_clients)
    def next_fn(state, client_values):
      return LearningProcessOutput(
          state, intrinsics.federated_sum(client_values))

    with self.assertRaises(learning_process.LearningProcessSequenceTypeError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=next_fn,
          report_fn=test_report_fn,
      )
Ejemplo n.º 7
0
  def test_next_fn_with_three_parameters_raises(self):
    """A `next_fn` with an extra third parameter is rejected."""

    @computations.federated_computation(test_state_type, ClientIntSequenceType,
                                        test_state_type)
    def next_fn(state, client_values, second_state):  # pylint: disable=unused-argument
      per_client_sums = intrinsics.federated_map(sum_sequence, client_values)
      total = intrinsics.federated_sum(per_client_sums)
      return LearningProcessOutput(state, total)

    with self.assertRaises(errors.TemplateNextFnNumArgsError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=next_fn,
          report_fn=test_report_fn,
      )
Ejemplo n.º 8
0
    def test_next_return_odict_raises(self):
        """Returning an OrderedDict instead of `LearningProcessOutput` fails."""

        @computations.federated_computation(test_state_type,
                                            ClientIntSequenceType)
        def odict_next_fn(state, client_values):
            per_client_sums = intrinsics.federated_map(sum_sequence,
                                                       client_values)
            total = intrinsics.federated_sum(per_client_sums)
            return collections.OrderedDict(state=state, metrics=total)

        with self.assertRaises(learning_process.LearningProcessOutputError):
            learning_process.LearningProcess(
                initialize_fn=test_init_fn,
                next_fn=odict_next_fn,
                report_fn=test_report_fn,
            )
Ejemplo n.º 9
0
    def test_init_fn_with_client_placed_state_raises(self):
        """An `initialize_fn` producing CLIENTS-placed state is rejected."""

        init_fn = computations.federated_computation(
            lambda: intrinsics.federated_value(0, placements.CLIENTS))
        client_state_type = init_fn.type_signature.result

        @computations.federated_computation(client_state_type,
                                            ClientIntSequenceType)
        def next_fn(state, client_values):
            return LearningProcessOutput(state, client_values)

        with self.assertRaises(learning_process.LearningProcessPlacementError):
            learning_process.LearningProcess(
                initialize_fn=init_fn,
                next_fn=next_fn,
                report_fn=test_report_fn,
            )
Ejemplo n.º 10
0
  def test_next_fn_with_non_client_placed_second_arg_raises(self):
    """A data argument placed at SERVER instead of CLIENTS is rejected."""
    server_placed_sequence = computation_types.FederatedType(
        computation_types.SequenceType(tf.int32), placements.SERVER)

    @computations.federated_computation(test_state_type, server_placed_sequence)
    def next_fn(state, server_values):
      return LearningProcessOutput(
          state, intrinsics.federated_map(sum_sequence, server_values))

    with self.assertRaises(learning_process.LearningProcessPlacementError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=next_fn,
          report_fn=test_report_fn,
      )
Ejemplo n.º 11
0
  def test_construction_with_empty_state_does_not_raise(self):
    """A process whose SERVER state is the empty structure `()` is accepted."""

    @computations.federated_computation()
    def empty_initialize_fn():
      return intrinsics.federated_value((), placements.SERVER)

    next_fn = build_next_fn(empty_initialize_fn)
    report_fn = build_report_fn(empty_initialize_fn)

    try:
      learning_process.LearningProcess(empty_initialize_fn, next_fn, report_fn)
    except Exception:  # pylint: disable=broad-except
      # `Exception` instead of a bare `except:` so KeyboardInterrupt and
      # SystemExit propagate rather than being reported as a test failure.
      self.fail('Could not construct a LearningProcess with empty state.')
Ejemplo n.º 12
0
  def test_next_return_namedtuple_raises(self):
    """A look-alike namedtuple returned from `next_fn` is rejected.

    The local namedtuple has the same name and fields as
    `LearningProcessOutput`, but `collections.namedtuple` creates a distinct
    class, and construction is expected to raise `LearningProcessOutputError`.
    """
    impostor_output_cls = collections.namedtuple('LearningProcessOutput',
                                                 ['state', 'metrics'])

    @computations.federated_computation(test_state_type, ClientIntSequenceType)
    def namedtuple_next_fn(state, client_values):
      per_client_sums = intrinsics.federated_map(sum_sequence, client_values)
      return impostor_output_cls(state,
                                 intrinsics.federated_sum(per_client_sums))

    with self.assertRaises(learning_process.LearningProcessOutputError):
      learning_process.LearningProcess(
          initialize_fn=test_init_fn,
          next_fn=namedtuple_next_fn,
          report_fn=test_report_fn,
      )
Ejemplo n.º 13
0
  def test_construction_with_unknown_dimension_does_not_raise(self):
    """State built from an empty `tf.string` tensor is accepted.

    The failure message below describes this state as having a statically
    unknown shape.
    """
    create_empty_string = computations.tf_computation()(
        lambda: tf.constant([], dtype=tf.string))
    initialize_fn = computations.federated_computation()(
        lambda: intrinsics.federated_value(create_empty_string(),
                                           placements.SERVER))
    next_fn = build_next_fn(initialize_fn)
    report_fn = build_report_fn(initialize_fn)

    try:
      learning_process.LearningProcess(initialize_fn, next_fn, report_fn)
    except Exception:  # pylint: disable=broad-except
      # `Exception` instead of a bare `except:` so KeyboardInterrupt and
      # SystemExit propagate rather than being reported as a test failure.
      self.fail('Could not construct a LearningProcess with state type having '
                'statically unknown shape.')
Ejemplo n.º 14
0
 def test_construction_does_not_raise(self):
     """The module's reference init/next/report computations form a process."""
     try:
         learning_process.LearningProcess(test_init_fn, test_next_fn,
                                          test_report_fn)
     except Exception:  # pylint: disable=broad-except
         # `Exception` instead of a bare `except:` so KeyboardInterrupt and
         # SystemExit propagate rather than being reported as a test failure.
         self.fail('Could not construct a valid LearningProcess.')
Ejemplo n.º 15
0
 def test_report_param_not_assignable(self):
     """A `report_fn` taking float32 is rejected as a bad type signature.

     The test expects the process state not to be assignable to a float32
     parameter.
     """
     float_report_fn = computations.tf_computation(tf.float32)(lambda x: x)
     with self.assertRaises(learning_process.ReportFnTypeSignatureError):
         learning_process.LearningProcess(
             initialize_fn=test_init_fn,
             next_fn=test_next_fn,
             report_fn=float_report_fn,
         )
Ejemplo n.º 16
0
 def test_federated_report_fn_raises(self):
     """A `report_fn` built with `federated_computation` is rejected."""
     wrap_as_federated = computations.federated_computation(test_state_type)
     federated_report_fn = wrap_as_federated(lambda x: x)
     with self.assertRaises(learning_process.ReportFnTypeSignatureError):
         learning_process.LearningProcess(
             initialize_fn=test_init_fn,
             next_fn=test_next_fn,
             report_fn=federated_report_fn,
         )
Ejemplo n.º 17
0
 def test_non_tff_computation_report_fn_raises(self):
     """A plain Python callable is rejected as `report_fn` with a TypeError."""
     with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
         learning_process.LearningProcess(
             initialize_fn=test_init_fn,
             next_fn=test_next_fn,
             report_fn=lambda x: x,
         )
Ejemplo n.º 18
0
 def test_next_state_not_assignable(self):
     """A `next_fn` built for float state does not fit `test_init_fn`."""
     float_next_fn = build_next_fn(
         computations.federated_computation()(lambda: 0.0))
     with self.assertRaises(errors.TemplateStateNotAssignableError):
         learning_process.LearningProcess(
             initialize_fn=test_init_fn,
             next_fn=float_next_fn,
             report_fn=test_report_fn,
         )
Ejemplo n.º 19
0
def compose_learning_process(
    initial_model_weights_fn: computation_base.Computation,
    model_weights_distributor: distributors.DistributionProcess,
    client_work: client_works.ClientWorkProcess,
    model_update_aggregator: aggregation_process.AggregationProcess,
    model_finalizer: finalizers.FinalizerProcess
) -> learning_process.LearningProcess:
    """Composes specialized measured processes into a learning process.

    Given the 4 specialized measured processes that make up a learning process
    (as documented in [TODO(b/190334722)]) and a computation that returns the
    initial model weights used for training, this function validates that the
    processes fit together and returns a `LearningProcess`.

    The main purposes of the 4 measured processes are:
      * `model_weights_distributor`: Makes the global model weights at the
        server available as the starting point for learning work done at
        clients.
      * `client_work`: Produces an update to the model received at clients.
      * `model_update_aggregator`: Aggregates the model updates from clients
        to the server.
      * `model_finalizer`: Updates the global model weights using the
        aggregated model update at the server.

    The `next` computation of the created learning process is composed from
    the `next` computations of the 4 measured processes, in the order
    visualized below. The type signatures of the processes must be such that
    this chaining is possible. Each process also reports its own metrics.

    ```
    ┌─────────────────────────┐
    │model_weights_distributor│
    └△─┬─┬────────────────────┘
     │ │┌▽──────────┐
     │ ││client_work│
     │ │└┬─────┬────┘
     │┌▽─▽────┐│
     ││metrics││
     │└△─△────┘│
     │ │┌┴─────▽────────────────┐
     │ ││model_update_aggregator│
     │ │└┬──────────────────────┘
    ┌┴─┴─▽──────────┐
    │model_finalizer│
    └┬──────────────┘
    ┌▽─────┐
    │result│
    └──────┘
    ```

    The process state is the `federated_zip` of a `LearningAlgorithmState`
    holding the global model weights and the state of each of the 4 processes.
    Metrics are an `OrderedDict` keyed by process name (`distributor`,
    `client_work`, `aggregator`, `finalizer`). The `report` computation
    returns the `global_model_weights` field of the state.

    Args:
      initial_model_weights_fn: A `tff.Computation` that returns (unplaced)
        initial model weights.
      model_weights_distributor: A `DistributionProcess`.
      client_work: A `ClientWorkProcess`.
      model_update_aggregator: A `tff.templates.AggregationProcess`.
      model_finalizer: A `FinalizerProcess`.

    Returns:
      A `LearningProcess`.
    """
    # pyformat: enable
    _validate_args(initial_model_weights_fn, model_weights_distributor,
                   client_work, model_update_aggregator, model_finalizer)
    # `client_work.next` is called as (state, model weights, client data); the
    # client data type is its third parameter.
    client_data_spec = client_work.next.type_signature.parameter[2]

    @computations.federated_computation()
    def init_fn():
        weights_at_server = intrinsics.federated_eval(initial_model_weights_fn,
                                                      placements.SERVER)
        return intrinsics.federated_zip(
            LearningAlgorithmState(
                global_model_weights=weights_at_server,
                distributor=model_weights_distributor.initialize(),
                client_work=client_work.initialize(),
                aggregator=model_update_aggregator.initialize(),
                finalizer=model_finalizer.initialize()))

    # NOTE: the parameter names `state` and `client_data` are kept as-is; the
    # traced computation's parameter names are taken from them.
    @computations.federated_computation(init_fn.type_signature.result,
                                        client_data_spec)
    def next_fn(state, client_data):
        # Each stage consumes the previous stage's `result`; the statement
        # order below is the order in which the stages are chained.
        distribution = model_weights_distributor.next(
            state.distributor, state.global_model_weights)
        local_work = client_work.next(state.client_work, distribution.result,
                                      client_data)
        aggregation = model_update_aggregator.next(
            state.aggregator, local_work.result.update,
            local_work.result.update_weight)
        finalization = model_finalizer.next(state.finalizer,
                                            state.global_model_weights,
                                            aggregation.result)

        updated_state = LearningAlgorithmState(
            global_model_weights=finalization.result,
            distributor=distribution.state,
            client_work=local_work.state,
            aggregator=aggregation.state,
            finalizer=finalization.state)
        per_process_metrics = collections.OrderedDict(
            distributor=distribution.measurements,
            client_work=local_work.measurements,
            aggregator=aggregation.measurements,
            finalizer=finalization.measurements)
        # Arguments are evaluated left to right, so the state is zipped before
        # the metrics.
        return learning_process.LearningProcessOutput(
            intrinsics.federated_zip(updated_state),
            intrinsics.federated_zip(per_process_metrics))

    state_member_type = next_fn.type_signature.result.state.member

    @computations.tf_computation(state_member_type)
    def report_fn(state):
        return state.global_model_weights

    return learning_process.LearningProcess(
        initialize_fn=init_fn, next_fn=next_fn, report_fn=report_fn)