示例#1
0
 def zero_fn():
     """Builds the all-zeros accumulator value.

     Creates `tf.zeros` tensors matching the tensor specs derived from
     `part_decoded_x_type` and from `state_update_tensors_type`, and
     combines the two structures with `_accumulator_value`.
     """
     def zeros_matching(tff_type):
         # One tf.zeros per leaf spec; the nested structure of the specs is
         # preserved by tf.nest.map_structure.
         specs = type_utils.type_to_tf_tensor_specs(tff_type)
         return tf.nest.map_structure(
             lambda spec: tf.zeros(spec.shape, spec.dtype), specs)

     # Values are built before the state-update tensors, keeping the same
     # op-creation order.
     zero_values = zeros_matching(part_decoded_x_type)
     zero_state_updates = zeros_matching(state_update_tensors_type)
     return _accumulator_value(zero_values, zero_state_updates)
示例#2
0
  def next_fn(global_state, value, weight=None):
    """Defines next_fn for StatefulAggregateFn.

    Runs one round of query-based aggregation:
      1. derives sample params from the global state (`derive_sample_params`),
      2. hands them to `intrinsics.federated_broadcast`,
      3. preprocesses each record with those params (`preprocess_record`),
      4. combines the preprocessed records with `intrinsics.federated_aggregate`
         using `zero`/`accumulate`/`merge`/`report`,
      5. calls `query.get_noised_result` on the aggregate and the global state
         (`post_process`).

    Args:
      global_state: Federated value holding the aggregator state. It is mapped
        through `derive_sample_params`, which is declared with
        `initialize_fn.type_signature.result` as its parameter type.
      value: Federated value to aggregate. Its member type
        (`value.type_signature.member`) is the record type consumed by
        `preprocess_record`. NOTE(review): presumably placed at CLIENTS, since
        it is zipped with the broadcast params — confirm with callers.
      weight: Ignored; weighted aggregation is not supported.

    Returns:
      The result of `intrinsics.federated_map(post_process, ...)`, whose member
      is the tuple `(new_global_state, result)`.
    """
    # Weighted aggregation is not supported.
    # TODO(b/140236959): Add an assertion that weight is None here, so the
    # contract of this method is better established. Will likely cause some
    # downstream breaks.
    del weight

    #######################################
    # Define local tf_computations

    # TODO(b/129567727): Make most of these tf_computations polymorphic
    # so type manipulation isn't needed.

    # The declared types of the computations below are chained: each one reads
    # an earlier computation's type_signature. Keep this definition order.
    global_state_type = initialize_fn.type_signature.result

    @computations.tf_computation(global_state_type)
    def derive_sample_params(global_state):
      return query.derive_sample_params(global_state)

    @computations.tf_computation(derive_sample_params.type_signature.result,
                                 value.type_signature.member)
    def preprocess_record(params, record):
      # TODO(b/123092620): Once TFF passes the expected container type (instead
      # of AnonymousTuple), we shouldn't need this.
      record = from_tff_result_fn(record)

      return query.preprocess_record(params, record)

    # TODO(b/123092620): We should have the expected container type here.
    value_type = value_type_fn(value)

    # Tensor specs of the value type; the query builds its initial sample
    # state from them.
    tensor_specs = type_utils.type_to_tf_tensor_specs(value_type)

    @computations.tf_computation
    def zero():
      return query.initial_sample_state(tensor_specs)

    # The sample-state type is whatever the query's initial state traces to.
    sample_state_type = zero.type_signature.result

    @computations.tf_computation(sample_state_type,
                                 preprocess_record.type_signature.result)
    def accumulate(sample_state, preprocessed_record):
      return query.accumulate_preprocessed_record(sample_state,
                                                  preprocessed_record)

    @computations.tf_computation(sample_state_type, sample_state_type)
    def merge(sample_state_1, sample_state_2):
      return query.merge_sample_states(sample_state_1, sample_state_2)

    # Identity report: the merged sample state is passed through unchanged.
    # Noising happens afterwards, in post_process.
    @computations.tf_computation(merge.type_signature.result)
    def report(sample_state):
      return sample_state

    # query.get_noised_result returns (result, new_global_state). This
    # computation swaps the order to (new_global_state, result), presumably
    # to match the StatefulAggregateFn next_fn contract — confirm.
    @computations.tf_computation(sample_state_type, global_state_type)
    def post_process(sample_state, global_state):
      result, new_global_state = query.get_noised_result(
          sample_state, global_state)
      return new_global_state, result

    #######################################
    # Orchestration logic

    sample_params = intrinsics.federated_map(derive_sample_params, global_state)
    client_sample_params = intrinsics.federated_broadcast(sample_params)
    # The (params, value) tuple is passed to federated_map as a single
    # argument; each member pair feeds preprocess_record(params, record).
    preprocessed_record = intrinsics.federated_map(
        preprocess_record, (client_sample_params, value))
    agg_result = intrinsics.federated_aggregate(preprocessed_record, zero(),
                                                accumulate, merge, report)

    return intrinsics.federated_map(post_process, (agg_result, global_state))