Example #1
def assign_and_compute():
  tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                        initial_model_weights)
  py_typecheck.check_callable(baseline_evaluate_fn)
  return baseline_evaluate_fn(model, test_data)
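For context, a minimal self-contained sketch of the assignment pattern above, using hypothetical plain-TF stand-ins for the free variables `model_weights` and `initial_model_weights`:

import tensorflow as tf

# Hypothetical stand-ins: two structurally identical nests, one holding
# `tf.Variable`s and one holding the tensors to copy into them.
model_weights = {'kernel': tf.Variable(tf.zeros([2, 2])),
                 'bias': tf.Variable(tf.zeros([2]))}
initial_model_weights = {'kernel': tf.ones([2, 2]), 'bias': tf.ones([2])}

# `tf.nest.map_structure` walks both nests in parallel and assigns each
# source tensor into the matching variable.
tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                      initial_model_weights)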
Example #2
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]):
    """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Transformation function to apply to each node in the type tree
      of `type_signature`. Must be instance of Python function type.

  Returns:
    A possibly transformed version of `type_signature`, with each node in its
    tree the result of applying `transform_fn` to the corresponding node in
    `type_signature`.

  Raises:
    TypeError: If the types don't match the specification above.
  """
    py_typecheck.check_type(type_signature, computation_types.Type)
    py_typecheck.check_callable(transform_fn)
    if type_signature.is_federated():
        transformed_member, member_mutated = transform_type_postorder(
            type_signature.member, transform_fn)
        if member_mutated:
            type_signature = computation_types.FederatedType(
                transformed_member, type_signature.placement,
                type_signature.all_equal)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or member_mutated
    elif type_signature.is_sequence():
        transformed_element, element_mutated = transform_type_postorder(
            type_signature.element, transform_fn)
        if element_mutated:
            type_signature = computation_types.SequenceType(
                transformed_element)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or element_mutated
    elif type_signature.is_function():
        if type_signature.parameter is not None:
            transformed_parameter, parameter_mutated = transform_type_postorder(
                type_signature.parameter, transform_fn)
        else:
            transformed_parameter, parameter_mutated = (None, False)
        transformed_result, result_mutated = transform_type_postorder(
            type_signature.result, transform_fn)
        if parameter_mutated or result_mutated:
            type_signature = computation_types.FunctionType(
                transformed_parameter, transformed_result)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, (type_signature_mutated or parameter_mutated
                                or result_mutated)
    elif type_signature.is_struct():
        elements = []
        elements_mutated = False
        for element in structure.iter_elements(type_signature):
            transformed_element, element_mutated = transform_type_postorder(
                element[1], transform_fn)
            elements_mutated = elements_mutated or element_mutated
            elements.append((element[0], transformed_element))
        if elements_mutated:
            if type_signature.is_struct_with_python():
                type_signature = computation_types.StructWithPythonType(
                    elements, type_signature.python_container)
            else:
                type_signature = computation_types.StructType(elements)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or elements_mutated
    elif (type_signature.is_abstract() or type_signature.is_placement() or
          type_signature.is_tensor()):
        return transform_fn(type_signature)
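A hypothetical usage sketch: rewriting every `int32` tensor leaf of a type tree as `float32`. The import path below is an assumption and varies across TFF versions.

import tensorflow as tf
from tensorflow_federated.python.core.impl.types import computation_types

def _int32_to_float32(type_spec):
    # Rewrite int32 tensor leaves as float32; report whether a change was made.
    if type_spec.is_tensor() and type_spec.dtype == tf.int32:
        return computation_types.TensorType(tf.float32, type_spec.shape), True
    return type_spec, False

original = computation_types.StructType([tf.int32, tf.float32])
transformed, mutated = transform_type_postorder(original, _int32_to_float32)
# transformed is <float32,float32>; mutated is True.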
Example #3
def add_measurements(
    inner_agg_factory: factory.AggregationFactory,
    measurement_fn: Callable[..., Dict[str, Any]],
) -> factory.AggregationFactory:
  """Wraps `AggregationFactory` to report additional measurements.

  The function `measurement_fn` is a python callable that will be called on
  `value` (if `inner_agg_factory` is an `UnweightedAggregationFactory`) or
  `(value, weight)` (if `inner_agg_factory` is a `WeightedAggregationFactory`)
  in the `next` function of the `AggregationProcess` produced by the returned
  factory to generate additional measurements. It must be traceable by TFF and
  expect `tff.Value` objects placed at `CLIENTS` as inputs, and return
  `collections.OrderedDicts` mapping string names to tensor values placed at
  `SERVER`, which will be added to the measurement dict produced by the
  `inner_agg_factory`.

  Args:
    inner_agg_factory: The factory to wrap and add measurements.
    measurement_fn: A python callable that will be called on `value` (and/or
      `weight`) provided to the `next` function to compute additional
      measurements.

  Returns:
    An `AggregationFactory` that reports additional measurements.
  """
  py_typecheck.check_callable(measurement_fn)

  if isinstance(inner_agg_factory, factory.UnweightedAggregationFactory):
    if len(inspect.signature(measurement_fn).parameters) != 1:
      raise ValueError('`measurement_fn` must take a single parameter if '
                       '`inner_agg_factory` is unweighted.')
  elif isinstance(inner_agg_factory, factory.WeightedAggregationFactory):
    if len(inspect.signature(measurement_fn).parameters) != 2:
      raise ValueError('`measurement_fn` must take two parameters if '
                       '`inner_agg_factory` is weighted.')
  else:
    raise TypeError(
        f'`inner_agg_factory` must be of type `UnweightedAggregationFactory` '
        f'or `WeightedAggregationFactory`. Found {type(inner_agg_factory)}.')

  @computations.tf_computation()
  def dict_update(orig_dict, new_values):
    if not orig_dict:
      return new_values
    orig_dict.update(new_values)
    return orig_dict

  if isinstance(inner_agg_factory, factory.WeightedAggregationFactory):

    class WeightedWrappedFactory(factory.WeightedAggregationFactory):
      """Wrapper for `WeightedAggregationFactory` that adds new measurements."""

      def create(
          self, value_type: factory.ValueType, weight_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        py_typecheck.check_type(value_type, factory.ValueType.__args__)
        py_typecheck.check_type(weight_type, factory.ValueType.__args__)

        inner_agg_process = inner_agg_factory.create(value_type, weight_type)
        init_fn = inner_agg_process.initialize

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type),
            computation_types.at_clients(weight_type))
        def next_fn(state, value, weight):
          inner_agg_output = inner_agg_process.next(state, value, weight)
          extra_measurements = measurement_fn(value, weight)
          measurements = intrinsics.federated_map(
              dict_update, (inner_agg_output.measurements, extra_measurements))
          return measured_process.MeasuredProcessOutput(
              state=inner_agg_output.state,
              result=inner_agg_output.result,
              measurements=measurements)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return WeightedWrappedFactory()
  else:

    class UnweightedWrappedFactory(factory.UnweightedAggregationFactory):
      """Wrapper for `UnweightedAggregationFactory` that adds new measurements."""

      def create(
          self, value_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        py_typecheck.check_type(value_type, factory.ValueType.__args__)

        inner_agg_process = inner_agg_factory.create(value_type)
        init_fn = inner_agg_process.initialize

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
          inner_agg_output = inner_agg_process.next(state, value)
          extra_measurements = measurement_fn(value)
          measurements = intrinsics.federated_map(
              dict_update, (inner_agg_output.measurements, extra_measurements))
          return measured_process.MeasuredProcessOutput(
              state=inner_agg_output.state,
              result=inner_agg_output.result,
              measurements=measurements)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return UnweightedWrappedFactory()
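A hypothetical usage sketch of `add_measurements`, assuming this module's `intrinsics` import plus hypothetical `placements` and `sum_factory` imports: the wrapped factory additionally reports how many clients participated in the round.

import collections

def _count_clients(value):
  # `value` is ignored; we sum an all-equal 1 placed at CLIENTS, which
  # yields the number of participating clients at the SERVER.
  del value
  one_at_clients = intrinsics.federated_value(1, placements.CLIENTS)
  return collections.OrderedDict(
      num_clients=intrinsics.federated_sum(one_at_clients))

measured_sum_factory = add_measurements(sum_factory.SumFactory(),
                                        _count_clients)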
Example #4
def _client_fn(model,
               initial_model_weights,
               train_data,
               test_data,
               personalize_fn_dict,
               baseline_evaluate_fn,
               context=None):
  """The main `tf.function` that runs on device.

  This function first evaluates the initial model and gets the baseline metrics.
  Then starting from the same initial model, this function iterates over the
  personalization strategies defined in `personalize_fn_dict`, trains and
  evaluates the personalized models, and returns the evaluation metrics.

  Args:
    model: A `tff.learning.Model`.
    initial_model_weights: A `tff.learning.framework.ModelWeights` containing
      `tf.Tensor`s that hold trainable and non-trainable weights.
    train_data: A `tf.data.Dataset` used for training.
    test_data: A `tf.data.Dataset` used for evaluation.
    personalize_fn_dict: This is the same argument specified in the function
      `build_personalization_eval` above; see its documentation for details.
    baseline_evaluate_fn: This is the same argument specified in the function
      `build_personalization_eval` above; see its documentation for details.
    context: An optional object used in `personalize_fn_dict`. If used, its
      `tff.Type` must be provided by passing the correct `context_tff_type`
      argument to the `build_personalization_eval` function.

  Returns:
    An `OrderedDict` that maps a string 'baseline_metrics' to the evaluation
    metrics of the initial model (computed by `baseline_evaluate_fn`), and maps
    keys (strategy names) in `personalize_fn_dict` to the evaluation metrics of
    the corresponding personalization strategies.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
  """
  # Wrap the input model as an `EnhancedModel` for easy access to its weights.
  model = model_utils.enhance(model)

  final_metrics = collections.OrderedDict()
  tff.utils.assign(model.weights, initial_model_weights)
  py_typecheck.check_callable(baseline_evaluate_fn)
  final_metrics['baseline_metrics'] = baseline_evaluate_fn(model, test_data)

  py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
  if 'baseline_metrics' in personalize_fn_dict:
    raise ValueError('baseline_metrics should not be used as a key in '
                     'personalize_fn_dict.')

  for name, personalize_fn_builder in personalize_fn_dict.items():
    py_typecheck.check_type(name, str)
    tff.utils.assign(model.weights, initial_model_weights)

    # Construct the `personalize_fn` (and the associated `tf.Variable`s) here.
    # Once `_client_fn` is decorated with `tff.tf_computation`, construction of
    # the new variables will happen in a scope controlled by TFF. Ensuring
    # `tf.Variable`s are created in the graphs that TFF controls is the reason
    # we need `personalize_fn_dict` to contain no-argument functions that build
    # the desired `tf.function`s, rather than already built `tf.function`s.
    py_typecheck.check_callable(personalize_fn_builder)
    personalize_fn = personalize_fn_builder()

    py_typecheck.check_callable(personalize_fn)
    final_metrics[name] = personalize_fn(model, train_data, test_data, context)

  return final_metrics
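For context, a hypothetical builder of the kind `personalize_fn_dict` expects. The optimizer (and hence its `tf.Variable`s) is created only when the builder is invoked inside the TFF-controlled scope; the datasets arrive unbatched, so the strategy batches them itself. `evaluate_fn` stands in for an evaluation helper and is an assumption, not part of the API above.

def build_sgd_personalize_fn(learning_rate=0.01, batch_size=16, num_epochs=3):
  """Returns a personalization strategy that fine-tunes with SGD."""
  optimizer = tf.keras.optimizers.SGD(learning_rate)

  @tf.function
  def personalize_fn(model, train_data, test_data, context=None):
    del context  # This simple strategy does not use extra context.
    for _ in range(num_epochs):
      for batch in train_data.batch(batch_size):
        with tf.GradientTape() as tape:
          output = model.forward_pass(batch)
        grads = tape.gradient(output.loss, model.weights.trainable)
        optimizer.apply_gradients(zip(grads, model.weights.trainable))
    return evaluate_fn(model, test_data)  # Hypothetical metrics helper.

  return personalize_fn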
Example #5
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
    """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    traced_fn: The Python function containing JAX code to be traced by JAX and
      serialized as a TFF computation containing XLA code.
    arg_fn: An unpacking function that takes a TFF argument, and returns a
      combo of (args, kwargs) to invoke `traced_fn` with (e.g., the one
      constructed by `function_utils.create_argument_unpacking_fn`).
    parameter_type: An instance of `computation_types.Type` that represents the
      TFF type of the computation parameter, or `None` if the function does not
      take any parameters.
    context_stack: The context stack to use during serialization.

  Returns:
    An instance of `pb.Computation` with the constructed computation.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
    py_typecheck.check_callable(traced_fn)
    py_typecheck.check_callable(arg_fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    if parameter_type is not None:
        parameter_type = computation_types.to_type(parameter_type)
        packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
    else:
        packed_arg = None

    args, kwargs = arg_fn(packed_arg)

    # While the fake parameters are fed via args/kwargs during serialization,
    # it is possible for them to get reordered in the actually generated XLA
    # code. We use here the same flattening function as the one used by the
    # JAX serializer to determine the ordering, and allow it to be captured
    # in the parameter binding. We do not need to do anything special for the
    # results, since the results, if multiple, are always returned as a tuple.
    flattened_obj, _ = jax.tree_util.tree_flatten((args, kwargs))
    tensor_indexes = list(np.argsort([x.tensor_index for x in flattened_obj]))

    def _adjust_arg(x):
        if isinstance(x, structure.Struct):
            return type_conversions.type_to_py_container(x, x.type_signature)
        else:
            return x

    args = [_adjust_arg(x) for x in args]
    kwargs = {k: _adjust_arg(v) for k, v in kwargs.items()}

    context = jax_computation_context.JaxComputationContext()
    with context_stack.install(context):
        tracer_callable = jax.xla_computation(traced_fn,
                                              tuple_args=True,
                                              return_shape=True)
        compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

    if isinstance(returned_shape, jax.ShapeDtypeStruct):
        returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(
            returned_shape)
    else:
        returned_type_spec = computation_types.to_type(
            structure.map_structure(
                _jax_shape_dtype_struct_to_tff_tensor,
                structure.from_container(returned_shape, recursive=True)))

    computation_type = computation_types.FunctionType(parameter_type,
                                                      returned_type_spec)
    return xla_serialization.create_xla_tff_computation(
        compiled_xla, tensor_indexes, computation_type)
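The tracing step above relies on `jax.xla_computation` (an older JAX API, since deprecated in favor of `jax.jit(...).lower(...)`). A standalone sketch of what that call does, outside of TFF:

import jax
import jax.numpy as jnp

def add(x, y):
    return jnp.add(x, y)

# `tuple_args=True` packs the parameters into a single XLA tuple;
# `return_shape=True` also yields a ShapeDtypeStruct for the result.
tracer_callable = jax.xla_computation(add, tuple_args=True, return_shape=True)
compiled_xla, returned_shape = tracer_callable(jnp.float32(1.0),
                                               jnp.float32(2.0))
# returned_shape == jax.ShapeDtypeStruct(shape=(), dtype=float32)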
Example #6
def create_binary_operator_with_upcast(
        type_signature: computation_types.StructType,
        operator: Callable[[Any, Any], Any]) -> ProtoAndType:
    """Creates TF computation upcasting its argument and applying `operator`.

  Args:
    type_signature: A `computation_types.StructType` with two elements, both
      only containing structs or tensors in their type tree. The first and
      second element must match in structure, or the second element may be a
      single tensor type that is broadcasted (upcast) to the leaves of the
      structure of the first type.
    operator: Callable defining the operator.

  Returns:
    A `ProtoAndType`: a serialized `pb.Computation` encapsulating a function
    which upcasts the second element of its argument and applies the binary
    operator, paired with its TFF type.
  """
    py_typecheck.check_type(type_signature, computation_types.StructType)
    py_typecheck.check_callable(operator)
    type_analysis.check_tensorflow_compatible_type(type_signature)
    if not type_signature.is_struct() or len(type_signature) != 2:
        raise TypeError(
            'To apply a binary operator, we must by definition have an '
            'argument which is a `StructType` with 2 elements; '
            'asked to create a binary operator for type: {t}'.format(
                t=type_signature))
    if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
        raise TypeError('Applying binary operators in TensorFlow is only '
                        'supported on Tensors and StructTypes; you '
                        'passed {t} which contains a SequenceType.'.format(
                            t=type_signature))

    def _pack_into_type(to_pack, type_spec):
        """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
        if type_spec.is_struct():
            elem_iter = structure.iter_elements(type_spec)
            return structure.Struct([(elem_name,
                                      _pack_into_type(to_pack, elem_type))
                                     for elem_name, elem_type in elem_iter])
        elif type_spec.is_tensor():
            return tf.broadcast_to(to_pack, type_spec.shape)

    with tf.Graph().as_default() as graph:
        first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', type_signature[0], graph)
        operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
            'y', type_signature[1], graph)

        if type_signature[0].is_struct() and type_signature[1].is_struct():
            # If both the first and second arguments are structs with the same
            # structure, simply re-use `operand_2_value`, as
            # `structure.map_structure` below will map the binary operator
            # pointwise to the leaves of the structure.
            if structure.is_same_structure(type_signature[0],
                                           type_signature[1]):
                second_arg = operand_2_value
            else:
                raise TypeError(
                    'Cannot upcast one structure to a different structure. '
                    '{x} -> {y}'.format(x=type_signature[1],
                                        y=type_signature[0]))
        elif type_signature[0].is_equivalent_to(type_signature[1]):
            second_arg = operand_2_value
        else:
            second_arg = _pack_into_type(operand_2_value, type_signature[0])

        if type_signature[0].is_tensor():
            result_value = operator(first_arg, second_arg)
        elif type_signature[0].is_struct():
            result_value = structure.map_structure(operator, first_arg,
                                                   second_arg)
        else:
            raise TypeError(
                'Encountered unexpected type {t}; can only handle Tensor '
                'and StructTypes.'.format(t=type_signature[0]))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

    type_signature = computation_types.FunctionType(type_signature,
                                                    result_type)
    parameter_binding = pb.TensorFlow.Binding(
        struct=pb.TensorFlow.StructBinding(
            element=[operand_1_binding, operand_2_binding]))
    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=parameter_binding,
                               result=result_binding)
    return _tensorflow_comp(tensorflow, type_signature)
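A hypothetical usage sketch: building an addition whose scalar second operand is upcast (broadcast) to every leaf of the first operand's structure.

import tensorflow as tf

struct_of_tensors = computation_types.StructType(
    [computation_types.TensorType(tf.float32, [2]),
     computation_types.TensorType(tf.float32, [3])])
scalar = computation_types.TensorType(tf.float32)

# Argument type <<float32[2],float32[3]>,float32>: the scalar is broadcast
# to both leaves before `tf.add` is applied pointwise.
arg_type = computation_types.StructType([struct_of_tensors, scalar])
proto, function_type = create_binary_operator_with_upcast(arg_type, tf.add)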
Example #7
def create_binary_operator_with_upcast(
    type_signature: computation_types.NamedTupleType,
    operator: Callable[[Any, Any], Any]) -> pb.Computation:
  """Creates TF computation upcasting its argument and applying `operator`.

  Args:
    type_signature: A `computation_types.NamedTupleType` with two elements, both
      of the same type or the second able to be upcast to the first, as
      explained in `apply_binary_operator_with_upcast`, and both containing only
      tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A serialized `pb.Computation` encapsulating a function which upcasts the
    second element of its argument and applies the binary operator.
  """
  py_typecheck.check_type(type_signature, computation_types.NamedTupleType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if not type_signature.is_tuple() or len(type_signature) != 2:
    raise TypeError('To apply a binary operator, we must by definition have an '
                    'argument which is a `NamedTupleType` with 2 elements; '
                    'asked to create a binary operator for type: {t}'.format(
                        t=type_signature))
  if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
    raise TypeError(
        'Applying binary operators in TensorFlow is only '
        'supported on Tensors and NamedTupleTypes; you '
        'passed {t} which contains a SequenceType.'.format(t=type_signature))

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if type_spec.is_tuple():
      elem_iter = anonymous_tuple.iter_elements(type_spec)
      return anonymous_tuple.AnonymousTuple([
          (elem_name, _pack_into_type(to_pack, elem_type))
          for elem_name, elem_type in elem_iter
      ])
    elif type_spec.is_tensor():
      return tf.broadcast_to(to_pack, type_spec.shape)

  with tf.Graph().as_default() as graph:
    first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', type_signature[0], graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', type_signature[1], graph)
    if type_signature[0].is_equivalent_to(type_signature[1]):
      second_arg = operand_2_value
    else:
      second_arg = _pack_into_type(operand_2_value, type_signature[0])

    if type_signature[0].is_tensor():
      result_value = operator(first_arg, second_arg)
    elif type_signature[0].is_tuple():
      result_value = anonymous_tuple.map_structure(operator, first_arg,
                                                   second_arg)
    else:
      raise TypeError('Encountered unexpected type {t}; can only handle Tensor '
                      'and NamedTupleTypes.'.format(t=type_signature[0]))

  result_type, result_binding = tensorflow_utils.capture_result_from_graph(
      result_value, graph)

  type_signature = computation_types.FunctionType(type_signature, result_type)
  parameter_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)
Example #8
def build_weighted_mime_lite(
    model_fn: Callable[[], model_lib.Model],
    base_optimizer: optimizer_base.Optimizer,
    server_optimizer: optimizer_base.Optimizer = sgdm.build_sgdm(1.0),
    client_weighting: Optional[
        client_weight_lib.
        ClientWeighting] = client_weight_lib.ClientWeighting.NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
    """Builds a learning process that performs Mime Lite.

  This function creates a `tff.learning.templates.LearningProcess` that
  performs the Mime Lite algorithm on client models. The iterative process has
  the following methods inherited from
  `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is communicated to
  each client using the provided `model_distributor`. For each client, local
  training is performed using `base_optimizer`, whose state is communicated by
  the server and kept intact during local training. The state is updated only
  at the server based on the full gradient evaluated by the clients based on the
  current server model state. The client full gradients are aggregated by
  weighted `full_gradient_aggregator`. Each client computes the difference
  between the client model after training and its initial model. These model
  deltas are then aggregated by weighted `model_aggregator`. Both of the
  aggregations are weighted, according to `client_weighting`. The aggregate
  model delta is added to the existing server model state.

  The Mime Lite algorithm is based on the paper
  "Breaking the centralized barrier for cross-device federated learning."
    Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank
    Reddi, Sebastian U. Stich, and Ananda Theertha Suresh.
    Advances in Neural Information Processing Systems 34 (2021).
    https://proceedings.neurips.cc/paper/2021/file/f0e6be4ce76ccfa73c5a540d992d0756-Paper.pdf

  Note that Keras optimizers are not supported. This is due to the Mime Lite
  algorithm applying the optimizer without changing its state at clients
  (optimizer's `tf.Variable`s in the case of Keras), which is not possible with
  Keras optimizers without reaching into private implementation details and
  incurring additional computation and memory cost at clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    base_optimizer: A `tff.learning.optimizers.Optimizer` which will be used for
      both creating and updating a global optimizer state, as well as
      optimization at clients given the global state, which is fixed during the
      optimization.
    server_optimizer: A `tff.learning.optimizers.Optimizer` which will be used
      for applying the aggregate model update to the global model weights.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method. By default, weighting by number of examples
      is used.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    full_gradient_aggregator: An optional
      `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
      gradients on client datasets. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(base_optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(server_optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)

    @tensorflow_computation.tf_computation()
    def initial_model_weights_fn():
        return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result
    if model_distributor is None:
        model_distributor = distributors.build_broadcast_process(
            model_weights_type)
    if model_aggregator is None:
        model_aggregator = mean.MeanFactory()
    py_typecheck.check_type(model_aggregator,
                            factory.WeightedAggregationFactory)
    model_aggregator = model_aggregator.create(
        model_weights_type.trainable, computation_types.TensorType(tf.float32))
    if full_gradient_aggregator is None:
        full_gradient_aggregator = mean.MeanFactory()
    py_typecheck.check_type(full_gradient_aggregator,
                            factory.WeightedAggregationFactory)

    client_work = _build_mime_lite_client_work(
        model_fn=model_fn,
        optimizer=base_optimizer,
        client_weighting=client_weighting,
        full_gradient_aggregator=full_gradient_aggregator,
        metrics_aggregator=metrics_aggregator,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer, model_weights_type)
    return composers.compose_learning_process(initial_model_weights_fn,
                                              model_distributor, client_work,
                                              model_aggregator, finalizer)
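A hypothetical driver loop, assuming `model_fn` and `client_datasets` are defined for your task:

process = build_weighted_mime_lite(
    model_fn=model_fn,
    base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9))

state = process.initialize()
for round_num in range(10):
    output = process.next(state, client_datasets)
    state = output.state
    print(f'round {round_num}: {output.metrics}')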
Example #9
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[model_lib.Model], ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    stateful_delta_aggregate_fn: tff.utils.
    StatefulAggregateFn = build_stateless_mean(),
    stateful_model_broadcast_fn: tff.utils.
    StatefulBroadcastFn = build_stateless_broadcaster(),
) -> tff.templates.IterativeProcess:
    """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: a no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: a function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: a no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    stateful_delta_aggregate_fn: a `tff.utils.StatefulAggregateFn` where the
      `next_fn` performs a federated aggregation and updates state. That is, it
      has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where the `value` type is
      `tff.learning.framework.ModelWeights.trainable` corresponding to the
      object returned by `model_fn`.
    stateful_model_broadcast_fn: a `tff.utils.StatefulBroadcastFn` where the
      `next_fn` performs a federated broadcast and updates state. That is, it has
      TFF type `(state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS)`,
      where the `value` type is `tff.learning.framework.ModelWeights`
      corresponding to the object returned by `model_fn`.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_callable(model_to_client_delta_fn)
    py_typecheck.check_callable(server_optimizer_fn)
    py_typecheck.check_type(stateful_delta_aggregate_fn,
                            tff.utils.StatefulAggregateFn)
    py_typecheck.check_type(stateful_model_broadcast_fn,
                            tff.utils.StatefulBroadcastFn)

    initialize_computation = _build_initialize_computaiton(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        delta_aggregate_fn=stateful_delta_aggregate_fn,
        model_broadcast_fn=stateful_model_broadcast_fn)

    delta_aggregate_state_type = initialize_computation.type_signature.result.member.delta_aggregate_state
    model_broadcast_state_type = initialize_computation.type_signature.result.member.model_broadcast_state
    run_one_round_computation = _build_one_round_computation(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        model_to_client_delta_fn=model_to_client_delta_fn,
        delta_aggregate_fn=stateful_delta_aggregate_fn,
        model_broadcast_fn=stateful_model_broadcast_fn,
        delta_aggregate_state_type=delta_aggregate_state_type,
        model_broadcast_state_type=model_broadcast_state_type)

    return tff.templates.IterativeProcess(initialize_fn=initialize_computation,
                                          next_fn=run_one_round_computation)
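A hypothetical wiring sketch, where `MyClientFedAvg` stands in for a concrete `ClientDeltaFn` implementation and `model_fn` and `client_datasets` are assumed:

import tensorflow as tf

iterative_process = build_model_delta_optimizer_process(
    model_fn=model_fn,
    model_to_client_delta_fn=lambda mf: MyClientFedAvg(mf()),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

state = iterative_process.initialize()
state, metrics = iterative_process.next(state, client_datasets)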
Example #10
def build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn,
                                        server_optimizer_fn):
    """Constructs `tff.utils.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a ClientDeltaFn
  that specifies how weight_deltas are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a model_fn to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.

  Returns:
    A `tff.utils.IterativeProcess`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_callable(model_to_client_delta_fn)
    py_typecheck.check_callable(server_optimizer_fn)

    # TODO(b/122081673): would be nice not to have to construct a throwaway model
    # here just to get the types. After fully moving to TF2.0 and eager-mode, we
    # should re-evaluate what happens here and where `g` is used below.
    with tf.Graph().as_default() as g:
        dummy_model_for_metadata = model_utils.enhance(model_fn())

    @tff.federated_computation
    def server_init_tff():
        """Orchestration logic for server model initialization."""
        no_arg_server_init_fn = lambda: server_init(model_fn,
                                                    server_optimizer_fn)
        server_init_tf = tff.tf_computation(no_arg_server_init_fn)
        return tff.federated_value(server_init_tf(), tff.SERVER)

    federated_server_state_type = server_init_tff.type_signature.result
    server_state_type = federated_server_state_type.member

    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
    federated_dataset_type = tff.FederatedType(tf_dataset_type,
                                               tff.CLIENTS,
                                               all_equal=False)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round_tff(server_state, federated_dataset):
        """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
    """
        model_weights_type = federated_server_state_type.member.model

        @tff.tf_computation(tf_dataset_type, model_weights_type)
        def client_delta_tf(tf_dataset, initial_model_weights):
            """Performs client local model optimization.

      Args:
        tf_dataset: a `tf.data.Dataset` that provides training examples.
        initial_model_weights: a `model_utils.ModelWeights` containing the
          starting weights.

      Returns:
        A `ClientOutput` structure.
      """
            client_delta_fn = model_to_client_delta_fn(model_fn)

            # TODO(b/123092620): this can be removed once AnonymousTuple works with
            # tf.contrib.framework.nest, or the following behavior is moved to
            # anonymous_tuple module.
            if isinstance(initial_model_weights,
                          anonymous_tuple.AnonymousTuple):
                initial_model_weights = model_utils.ModelWeights.from_tff_value(
                    initial_model_weights)

            client_output = client_delta_fn(tf_dataset, initial_model_weights)
            return client_output

        client_outputs = tff.federated_map(
            client_delta_tf,
            (federated_dataset, tff.federated_broadcast(server_state.model)))

        @tff.tf_computation(server_state_type, model_weights_type.trainable)
        def server_update_model_tf(server_state, model_delta):
            """Converts args to correct python types and calls server_update_model."""
            # We need to convert TFF types to the types server_update_model expects.
            # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
            # pretty, fold this into anonymous_tuple module or get working with
            # tf.contrib.framework.nest.
            py_typecheck.check_type(model_delta,
                                    anonymous_tuple.AnonymousTuple)
            model_delta = anonymous_tuple.to_odict(model_delta)
            py_typecheck.check_type(server_state,
                                    anonymous_tuple.AnonymousTuple)
            server_state = ServerState(
                model=model_utils.ModelWeights.from_tff_value(
                    server_state.model),
                optimizer_state=list(server_state.optimizer_state))

            return server_update_model(server_state,
                                       model_delta,
                                       model_fn=model_fn,
                                       optimizer_fn=server_optimizer_fn)

        # TODO(b/124070381): We hope to remove this explicit cast once we have a
        # full solution for type analysis in multiplications and divisions
        # inside TFF
        fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
        py_typecheck.check_type(fed_weight_type, tff.TensorType)
        if fed_weight_type.dtype.is_integer:

            @tff.tf_computation(fed_weight_type)
            def _cast_to_float(x):
                return tf.cast(x, tf.float32)

            weight_denom = tff.federated_map(
                _cast_to_float, client_outputs.weights_delta_weight)
        else:
            weight_denom = client_outputs.weights_delta_weight
        round_model_delta = tff.federated_average(client_outputs.weights_delta,
                                                  weight=weight_denom)

        # TODO(b/123408447): remove tff.federated_apply and call
        # server_update_model_tf directly once T <-> T@SERVER isomorphism is
        # supported.
        server_state = tff.federated_apply(server_update_model_tf,
                                           (server_state, round_model_delta))

        # Re-use graph used to construct `model`, since it has the variables, which
        # need to be read in federated_output_computation to get the correct shapes
        # and types for the federated aggregation.
        with g.as_default():
            aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
                client_outputs.model_output)

        # Promote the FederatedType outside the NamedTupleType
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs

    return tff.utils.IterativeProcess(initialize_fn=server_init_tff,
                                      next_fn=run_one_round_tff)
Example #11
def _build_mime_lite_client_work(
    model_fn: Callable[[], model_lib.Model],
    optimizer: optimizer_base.Optimizer,
    client_weighting: client_weight_lib.ClientWeighting,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
    """Creates a `ClientWorkProcess` for Mime Lite.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    optimizer: A `tff.learning.optimizers.Optimizer` which will be used for both
      creating and updating a global optimizer state, as well as optimization at
      clients given the global state, which is fixed during the optimization.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method.
    full_gradient_aggregator: An optional
      `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
      gradients on client datasets. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `ClientWorkProcess`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)
    if full_gradient_aggregator is None:
        full_gradient_aggregator = mean.MeanFactory()
    py_typecheck.check_type(full_gradient_aggregator,
                            factory.WeightedAggregationFactory)
    if metrics_aggregator is None:
        metrics_aggregator = metric_aggregator.sum_then_finalize

    with tf.Graph().as_default():
        # Wrap model construction in a graph to avoid polluting the global context
        # with variables created for this model.
        model = model_fn()
        unfinalized_metrics_type = type_conversions.type_from_tensors(
            model.report_local_unfinalized_metrics())
        metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                    unfinalized_metrics_type)
    data_type = computation_types.SequenceType(model.input_spec)
    weights_type = model_utils.weights_type_from_model(model)
    weight_tensor_specs = type_conversions.type_to_tf_tensor_specs(
        weights_type)

    full_gradient_aggregator = full_gradient_aggregator.create(
        weights_type.trainable, computation_types.TensorType(tf.float32))

    @federated_computation.federated_computation
    def init_fn():
        specs = weight_tensor_specs.trainable
        optimizer_state = intrinsics.federated_eval(
            tensorflow_computation.tf_computation(
                lambda: optimizer.initialize(specs)), placements.SERVER)
        aggregator_state = full_gradient_aggregator.initialize()
        return intrinsics.federated_zip((optimizer_state, aggregator_state))

    client_update_fn = _build_client_update_fn_for_mime_lite(
        model_fn, optimizer, client_weighting,
        use_experimental_simulation_loop)

    @tensorflow_computation.tf_computation(
        init_fn.type_signature.result.member[0], weights_type.trainable)
    def update_optimizer_state(state, aggregate_gradient):
        whimsy_weights = tf.nest.map_structure(
            lambda g: tf.zeros(g.shape, g.dtype), aggregate_gradient)
        updated_state, _ = optimizer.next(state, whimsy_weights,
                                          aggregate_gradient)
        return updated_state

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.at_clients(weights_type),
        computation_types.at_clients(data_type))
    def next_fn(state, weights, client_data):
        optimizer_state, aggregator_state = state
        optimizer_state_at_clients = intrinsics.federated_broadcast(
            optimizer_state)
        client_result, model_outputs, full_gradient = (
            intrinsics.federated_map(
                client_update_fn,
                (optimizer_state_at_clients, weights, client_data)))
        full_gradient_agg_output = full_gradient_aggregator.next(
            aggregator_state, full_gradient, client_result.update_weight)
        updated_optimizer_state = intrinsics.federated_map(
            update_optimizer_state,
            (optimizer_state, full_gradient_agg_output.result))

        new_state = intrinsics.federated_zip(
            (updated_optimizer_state, full_gradient_agg_output.state))
        train_metrics = metrics_aggregation_fn(model_outputs)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(train=train_metrics))
        return measured_process.MeasuredProcessOutput(new_state, client_result,
                                                      measurements)

    return client_works.ClientWorkProcess(init_fn, next_fn)
Example #12
def build_weighted_fed_avg_with_optimizer_schedule(
    model_fn: Callable[[], model_lib.Model],
    client_learning_rate_fn: Callable[[int], float],
    client_optimizer_fn: Callable[[float], TFFOrKerasOptimizer],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [],
        tf.keras.optimizers.Optimizer]] = fed_avg.DEFAULT_SERVER_OPTIMIZER_FN,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
  """Builds a learning process for FedAvg with client optimizer scheduling.

  This function creates a `LearningProcess` that performs federated averaging on
  client models. The iterative process has the following methods inherited from
  `LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a `LearningAlgorithmState` representing the
      initial state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is broadcast to each
  client using a broadcast function. For each client, local training is
  performed using `client_optimizer_fn`. Each client computes the difference
  between the client model after training and the initial broadcast model.
  These model deltas are then aggregated at the server using a weighted
  aggregation function. Clients are weighted by the number of examples they see
  throughout local training. The aggregate model delta is applied at the server
  using a server optimizer.

  The primary purpose of this implementation of FedAvg is that it allows for the
  client optimizer to be scheduled across rounds. The process keeps track of how
  many iterations of `.next` have occurred (starting at `0`), and for each such
  `round_num`, the clients will use
  `client_optimizer_fn(client_learning_rate_fn(round_num))` to perform local
  optimization. This allows learning rate scheduling (e.g. starting with a
  large learning rate and decaying it over time) as well as optimizer
  scheduling (e.g. switching optimizers as learning progresses).

  Note: the default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 1.0, which corresponds to adding the model delta to
  the current server model. This recovers the original FedAvg algorithm in
  [McMahan et al., 2017](https://arxiv.org/abs/1602.05629). More
  sophisticated federated averaging procedures may use different learning rates
  or server optimizers.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    client_learning_rate_fn: A callable accepting an integer round number and
      returning a float to be used as a learning rate for the optimizer. The
      client work will call `optimizer_fn(learning_rate_fn(round_num))` where
      `round_num` is the integer round number. Note that the round numbers
      supplied will start at `0` and increment by one each time `.next` is
      called on the resulting process. Also note that this function must be
      serializable by TFF.
    client_optimizer_fn: A callable accepting a float learning rate, and
      returning a `tff.learning.optimizers.Optimizer` or a `tf.keras.Optimizer`.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. By default, this uses
      `tf.keras.optimizers.SGD` with a learning rate of 1.0.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `LearningProcess`.
  """
  py_typecheck.check_callable(model_fn)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(model_weights_type)

  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)
  aggregator = model_aggregator.create(model_weights_type.trainable,
                                       computation_types.TensorType(tf.float32))
  process_signature = aggregator.next.type_signature
  input_client_value_type = process_signature.parameter[1]
  result_server_value_type = process_signature.result[1]
  if input_client_value_type.member != result_server_value_type.member:
    raise TypeError('`model_aggregator` does not produce a '
                    'compatible `AggregationProcess`. The processes must '
                    'retain the type structure of the inputs on the '
                    f'server, but got {input_client_value_type.member} != '
                    f'{result_server_value_type.member}.')

  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize
  client_work = build_scheduled_client_work(model_fn, client_learning_rate_fn,
                                            client_optimizer_fn,
                                            metrics_aggregator,
                                            use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer_fn, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            model_distributor, client_work,
                                            aggregator, finalizer)
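A hypothetical schedule sketch, assuming `model_fn` is defined for your task: the client learning rate starts at 0.1 and is halved every 10 rounds. The schedule is built from TF ops so that TFF can serialize it.

import tensorflow as tf

def client_learning_rate_fn(round_num):
  return 0.1 * tf.pow(0.5, tf.cast(round_num // 10, tf.float32))

process = build_weighted_fed_avg_with_optimizer_schedule(
    model_fn=model_fn,
    client_learning_rate_fn=client_learning_rate_fn,
    client_optimizer_fn=lambda lr: tf.keras.optimizers.SGD(learning_rate=lr))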
Example #13
def extract_nodes_consuming(tree, predicate):
    """Returns the set of AST nodes which consume nodes matching `predicate`.

  Notice we adopt the convention that a node which itself satisfies the
  predicate is in this set.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
      abstract syntax tree, and construct the set of nodes in this tree having a
      dependency on nodes matching `predicate`; that is, the set of nodes whose
      value depends on evaluating nodes matching `predicate`.
    predicate: One-arg callable, accepting arguments of type
      `building_blocks.ComputationBuildingBlock` and returning a `bool`
      indicating match or mismatch with the desired pattern.

  Returns:
    A `set` of `building_blocks.ComputationBuildingBlock` instances
    representing the nodes in `tree` dependent on nodes matching `predicate`.
  """
    py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
    py_typecheck.check_callable(predicate)

    class _NodeSet:
        def __init__(self):
            self.mapping = {}

        def add(self, comp):
            self.mapping[id(comp)] = comp

        def to_set(self):
            return set(self.mapping.values())

    dependent_nodes = _NodeSet()

    def _are_children_in_dependent_set(comp, symbol_tree):
        """Checks if the dependencies of `comp` are present in `dependent_nodes`."""
        if (comp.is_intrinsic() or comp.is_data() or comp.is_placement()
                or comp.is_compiled_computation()):
            return False
        elif comp.is_lambda():
            return id(comp.result) in dependent_nodes.mapping
        elif comp.is_block():
            return any(
                id(x[1]) in dependent_nodes.mapping
                for x in comp.locals) or id(
                    comp.result) in dependent_nodes.mapping
        elif comp.is_struct():
            return any(id(x) in dependent_nodes.mapping for x in comp)
        elif comp.is_selection():
            return id(comp.source) in dependent_nodes.mapping
        elif comp.is_call():
            return id(comp.function) in dependent_nodes.mapping or id(
                comp.argument) in dependent_nodes.mapping
        elif comp.is_reference():
            return _is_reference_dependent(comp, symbol_tree)

    def _is_reference_dependent(comp, symbol_tree):
        payload = symbol_tree.get_payload_with_name(comp.name)
        if payload is None:
            return False
        # The postorder traversal ensures that we process any
        # bindings before we process the references to those bindings.
        return id(payload.value) in dependent_nodes.mapping

    def _populate_dependent_set(comp, symbol_tree):
        """Populates `dependent_nodes` with all nodes dependent on `predicate`."""
        if predicate(comp):
            dependent_nodes.add(comp)
        elif _are_children_in_dependent_set(comp, symbol_tree):
            dependent_nodes.add(comp)
        return comp, False

    symbol_tree = transformation_utils.SymbolTree(
        transformation_utils.ReferenceCounter)
    transformation_utils.transform_postorder_with_symbol_bindings(
        tree, _populate_dependent_set, symbol_tree)
    return dependent_nodes.to_set()
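A minimal usage sketch for the function above; the two-node AST is hypothetical. The enclosing `Struct` consumes the `Reference`, so both nodes appear in the result:

int_type = computation_types.TensorType(tf.int32)
tree = building_blocks.Struct([building_blocks.Reference('x', int_type)])
consumers = extract_nodes_consuming(
    tree, lambda node: node.is_reference() and node.name == 'x')
assert len(consumers) == 2  # the Reference itself plus the enclosing Struct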
예제 #14
0
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_clients=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from `tff.SERVER` to
  `tff.CLIENTS`. Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_clients`
  participating clients are collected to the server.

  NOTE: The functions in `personalize_fn_dict` and `baseline_evaluate_fn` are
  expected to take as input *unbatched* datasets, and are responsible for
  applying batching, if any, to the provided input datasets.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the given model
      weights when users invoke the returned TFF computation), an unbatched
      `tf.data.Dataset` for train, an unbatched `tf.data.Dataset` for test, and
      an arbitrary context object (which is used to hold any extra information
      that a personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are represented as
      an `OrderedDict` (or a nested `OrderedDict`) of `string` metric names to
      scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when users
      invoke the returned TFF computation), and an unbatched `tf.data.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are represented as an `OrderedDict` (or a nested
      `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s. This
      function is *only* used to compute the baseline metrics of the initial
      model.
    max_num_clients: A positive `int` specifying the maximum number of clients
      to collect metrics in a round (default is 100). The clients are sampled
      without replacement. For each sampled client, all the personalization
      metrics from this client will be collected. If the number of participating
      clients in a round is smaller than this value, then metrics from all
      clients will be collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the training
      dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` with the functional type signature
    `(<model_weights@SERVER, input@CLIENTS> -> personalization_metrics@SERVER)`:

    *   `model_weights` is a `tff.learning.ModelWeights`.
    *   Each client's input is an `OrderedDict` of two required keys
        `train_data` and `test_data`; each key is mapped to an unbatched
        `tf.data.Dataset`. If extra context (e.g., extra datasets) is used in
        `personalize_fn_dict`, then client input has a third key `context` that
        is mapped to an object whose `tff.Type` is provided by the
        `context_tff_type` argument.
    *   `personalization_metrics` is an `OrderedDict` that maps a key
        'baseline_metrics' to the evaluation metrics of the initial model
        (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
        `personalize_fn_dict` to the evaluation metrics of the corresponding
        personalization strategies.
    *   Note: only metrics from at most `max_num_clients` participating clients
        (sampled without replacement) are collected to the SERVER. All collected
        metrics are stored in a single `OrderedDict` (`personalization_metrics`
        shown above), where each metric is mapped to a list of scalars (each
        scalar comes from one client). Metric values at the same position, e.g.,
        metric_1[i], metric_2[i]..., all come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
    ValueError: If `max_num_clients` is not positive.
  """
  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    py_typecheck.check_callable(model_fn)
    model = model_fn()
    model_weights_type = model_utils.weights_type_from_model(model)
    batch_tff_type = computation_types.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input. Since batching (as well as
  # other preprocessing of datasets) is done within each personalization
  # strategy (i.e., by functions in `personalize_fn_dict`), the client-side
  # input should contain unbatched elements.
  element_tff_type = _remove_batch_dim(batch_tff_type)
  client_input_type = collections.OrderedDict(
      train_data=computation_types.SequenceType(element_tff_type),
      test_data=computation_types.SequenceType(element_tff_type))
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, computation_types.Type)
    client_input_type['context'] = context_tff_type
  client_input_type = computation_types.to_type(client_input_type)

  @computations.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    context = client_input.get('context', None)

    final_metrics = collections.OrderedDict()
    # Compute the evaluation metrics of the initial model.
    final_metrics['baseline_metrics'] = _compute_baseline_metrics(
        model_fn, initial_model_weights, test_data, baseline_evaluate_fn)

    py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
    if 'baseline_metrics' in personalize_fn_dict:
      raise ValueError('baseline_metrics should not be used as a key in '
                       'personalize_fn_dict.')

    # Compute the evaluation metrics of the personalized models. The returned
    # `p13n_metrics` is an `OrderedDict` that maps keys (strategy names) in
    # `personalize_fn_dict` to the evaluation metrics of the corresponding
    # personalization strategies.
    p13n_metrics = _compute_p13n_metrics(model_fn, initial_model_weights,
                                         train_data, test_data,
                                         personalize_fn_dict, context)
    final_metrics.update(p13n_metrics)
    return final_metrics

  py_typecheck.check_type(max_num_clients, int)
  if max_num_clients <= 0:
    raise ValueError('max_num_clients must be a positive integer.')

  reservoir_sampling_factory = sampling.UnweightedReservoirSamplingFactory(
      sample_size=max_num_clients)
  aggregation_process = reservoir_sampling_factory.create(
      _client_computation.type_signature.result)

  @computations.federated_computation(
      computation_types.FederatedType(model_weights_type, placements.SERVER),
      computation_types.FederatedType(client_input_type, placements.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = intrinsics.federated_broadcast(server_model_weights)
    client_final_metrics = intrinsics.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    sampling_output = aggregation_process.next(
        aggregation_process.initialize(),  # No state.
        client_final_metrics)
    # In the future we may want to output `sampling_output.measurements` also
    # but currently it is empty.
    return sampling_output.result

  return personalization_eval
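A usage sketch; `my_model_fn`, `my_baseline_evaluate_fn`, and the fine-tuning body are hypothetical stand-ins, and the datasets passed in must be unbatched:

def build_fine_tune_fn():

  @tf.function
  def fine_tune_fn(model, train_data, test_data, context=None):
    # Train `model` on `train_data`, evaluate on `test_data`, and return an
    # OrderedDict of scalar metrics (body elided in this sketch).
    ...

  return fine_tune_fn

p13n_eval = build_personalization_eval(
    model_fn=my_model_fn,
    personalize_fn_dict=collections.OrderedDict(fine_tune=build_fine_tune_fn),
    baseline_evaluate_fn=my_baseline_evaluate_fn,
    max_num_clients=50)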
예제 #15
0
def zero_or_one_arg_fn_to_building_block(
    fn,
    parameter_name: Optional[str],
    parameter_type: Optional[computation_types.Type],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
) -> Tuple[building_blocks.ComputationBuildingBlock, computation_types.Type]:
    """Converts a zero- or one-argument `fn` into a computation building block.

  Args:
    fn: A function with 0 or 1 arguments that contains orchestration logic,
      i.e., that expects zero or one `values_base.Value` and returns a result
      convertible to the same.
    parameter_name: The name of the parameter, or `None` if there isn't any.
    parameter_type: The `tff.Type` of the parameter, or `None` if there's none.
    context_stack: The context stack to use.
    suggested_name: The optional suggested name to use for the federated context
      that will be used to serialize this function's body (ideally the name of
      the underlying Python function). It might be modified to avoid conflicts.

  Returns:
    A tuple of `(building_blocks.ComputationBuildingBlock,
    computation_types.Type)`, where the first element contains the logic from
    `fn`, and the second element contains potentially annotated type information
    for the result of `fn`.

  Raises:
    ValueError: if `fn` is incompatible with `parameter_type`.
  """
    py_typecheck.check_callable(fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, str)
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, str)
        parameter_name = '{}_{}'.format(context.name, str(parameter_name))
    with context_stack.install(context):
        if parameter_type is not None:
            result = fn(
                value_impl.ValueImpl(
                    building_blocks.Reference(parameter_name, parameter_type),
                    context_stack))
        else:
            result = fn()
        if result is None:
            raise ValueError(
                'The function defined on line {} of file {} has returned a '
                '`NoneType`, but all TFF functions must return some non-`None` '
                'value.'.format(fn.__code__.co_firstlineno,
                                fn.__code__.co_filename))
        annotated_result_type = type_conversions.infer_type(result)
        result = value_impl.to_value(result, annotated_result_type,
                                     context_stack)
        result_comp = value_impl.ValueImpl.get_comp(result)
        symbols_bound_in_context = context_stack.current.symbol_bindings
        if symbols_bound_in_context:
            result_comp = building_blocks.Block(
                local_symbols=symbols_bound_in_context, result=result_comp)
        annotated_type = computation_types.FunctionType(
            parameter_type, annotated_result_type)
        return building_blocks.Lambda(parameter_name, parameter_type,
                                      result_comp), annotated_type
예제 #16
0
def create_executor_factory(
    executor_stack_fn: Callable[[CardinalitiesType], executor_base.Executor]
) -> ExecutorFactory:
    """Create an `ExecutorFactory` for a given executor stack function."""
    py_typecheck.check_callable(executor_stack_fn)
    return ExecutorFactoryImpl(executor_stack_fn)
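A usage sketch; `my_executor` stands in for a pre-built `executor_base.Executor`, and the `placements` module is assumed to be in scope. The factory simply forwards the requested cardinalities to the given stack function:

factory = create_executor_factory(lambda cardinalities: my_executor)
executor = factory.create_executor({placements.CLIENTS: 3})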
예제 #17
0
def build_jax_federated_averaging_process(batch_type, model_type, loss_fn,
                                          step_size):
    """Constructs an iterative process that implements simple federated averaging.

  Args:
    batch_type: An instance of `tff.Type` that represents the type of a single
      batch of data to use for training. This type should be constructed with
      standard Python containers (such as `collections.OrderedDict`) of the sort
      that are expected as parameters to `loss_fn`.
    model_type: An instance of `tff.Type` that represents the type of the model.
      Similarly to `batch_type`, this type should be constructed with standard
      Python containers (such as `collections.OrderedDict`) of the sort that are
      expected as parameters to `loss_fn`.
    loss_fn: A loss function for the model. Must be a Python function that takes
      two parameters, one of them being the model, and the other being a single
      batch of data (with types matching `batch_type` and `model_type`).
    step_size: The step size to use during training (an `np.float32`).

  Returns:
    An instance of `tff.templates.IterativeProcess` that implements federated
    training in JAX.
  """
    batch_type = computation_types.to_type(batch_type)
    model_type = computation_types.to_type(model_type)

    py_typecheck.check_type(batch_type, computation_types.Type)
    py_typecheck.check_type(model_type, computation_types.Type)
    py_typecheck.check_callable(loss_fn)
    py_typecheck.check_type(step_size, np.float)

    def _tensor_zeros(tensor_type):
        return jax.numpy.zeros(tensor_type.shape.dims,
                               dtype=tensor_type.dtype.as_numpy_dtype)

    @experimental_computations.jax_computation
    def _create_zero_model():
        model_zeros = structure.map_structure(_tensor_zeros, model_type)
        return type_conversions.type_to_py_container(model_zeros, model_type)

    @computations.federated_computation
    def _create_zero_model_on_server():
        return intrinsics.federated_eval(_create_zero_model, placements.SERVER)

    def _apply_update(model_param, param_delta):
        return model_param - step_size * param_delta

    @experimental_computations.jax_computation(model_type, batch_type)
    def _train_on_one_batch(model, batch):
        params = structure.flatten(
            structure.from_container(model, recursive=True))
        grads = structure.flatten(
            structure.from_container(jax.api.grad(loss_fn)(model, batch)))
        updated_params = [_apply_update(x, y) for (x, y) in zip(params, grads)]
        trained_model = structure.pack_sequence_as(model_type, updated_params)
        return type_conversions.type_to_py_container(trained_model, model_type)

    local_dataset_type = computation_types.SequenceType(batch_type)

    @computations.federated_computation(model_type, local_dataset_type)
    def _train_on_one_client(model, batches):
        return intrinsics.sequence_reduce(batches, model, _train_on_one_batch)

    @computations.federated_computation(
        computation_types.FederatedType(model_type, placements.SERVER),
        computation_types.FederatedType(local_dataset_type,
                                        placements.CLIENTS))
    def _train_one_round(model, federated_data):
        locally_trained_models = intrinsics.federated_map(
            _train_on_one_client,
            collections.OrderedDict([('model',
                                      intrinsics.federated_broadcast(model)),
                                     ('batches', federated_data)]))
        return intrinsics.federated_mean(locally_trained_models)

    return iterative_process.IterativeProcess(
        initialize_fn=_create_zero_model_on_server, next_fn=_train_one_round)
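A usage sketch for the JAX trainer above, assuming a tiny linear model; the types and loss below are illustrative only:

batch_type = collections.OrderedDict(
    x=computation_types.TensorType(np.float32, (5, 2)),
    y=computation_types.TensorType(np.float32, (5,)))
model_type = collections.OrderedDict(
    weights=computation_types.TensorType(np.float32, (2,)),
    bias=computation_types.TensorType(np.float32, ()))

def loss_fn(model, batch):
  preds = jax.numpy.dot(batch['x'], model['weights']) + model['bias']
  return jax.numpy.mean((preds - batch['y']) ** 2)

trainer = build_jax_federated_averaging_process(
    batch_type, model_type, loss_fn, step_size=0.01)
state = trainer.initialize()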
예제 #18
0
    def __init__(self,
                 raw_client_data: client_data.ClientData,
                 make_transform_fn: Callable[[str, int], Callable[[Any], Any]],
                 num_transformed_clients: Optional[int] = None):
        """Initializes the TransformingClientData.

    Args:
      raw_client_data: A ClientData to expand.
      make_transform_fn: A function that returns a callable that maps datapoint
        x to a new datapoint x'. make_transform_fn will be called as
        make_transform_fn(raw_client_id, i) where i is an integer index, and
        should return a function fn(x)->x'. For example, if x is an image, then
        make_transform_fn("client_a", 0)(x) might be the identity, while
        make_transform_fn("client_a", 1)(x) could be a random rotation of the
        image with the angle determined by a hash of "client_a" and "1". If
        `make_transform_fn` returns `None`, no transformation is performed.
        Typically by convention the index 0 corresponds to the identity function
        if the identity is supported.
      num_transformed_clients: The total number of transformed clients to
        produce. If `None`, only the original clients will be transformed. If it
        is an integer multiple k of the number of real clients, there will be
        exactly k pseudo-clients per real client, with indices 0...k-1. Any
        remainder g will be generated from the first g real clients and will be
        given index k.
    """
        py_typecheck.check_type(raw_client_data, client_data.ClientData)
        py_typecheck.check_callable(make_transform_fn)

        raw_client_ids = raw_client_data.client_ids
        if not raw_client_ids:
            raise ValueError('`raw_client_data` must be non-empty.')

        if num_transformed_clients is None:
            num_transformed_clients = len(raw_client_ids)
        else:
            py_typecheck.check_type(num_transformed_clients, int)
            if num_transformed_clients <= 0:
                raise ValueError('`num_transformed_clients` must be positive.')

        self._raw_client_data = raw_client_data
        self._make_transform_fn = make_transform_fn

        self._has_pseudo_clients = num_transformed_clients > len(
            raw_client_ids)

        if self._has_pseudo_clients:
            num_digits = len(str(num_transformed_clients - 1))
            format_str = '{}_{:0' + str(num_digits) + '}'

            k = num_transformed_clients // len(raw_client_ids)
            self._client_ids = []
            for raw_client_id in raw_client_ids:
                for i in range(k):
                    self._client_ids.append(format_str.format(
                        raw_client_id, i))
            num_extra_client_ids = num_transformed_clients - k * len(
                raw_client_ids)
            for c in range(num_extra_client_ids):
                self._client_ids.append(format_str.format(
                    raw_client_ids[c], k))
        else:
            self._client_ids = raw_client_ids

        # Already sorted if raw_client_data.client_ids are, but just to be sure...
        self._client_ids = sorted(self._client_ids)
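A construction sketch; `my_raw_client_data` is a hypothetical `ClientData`, and the additive offset transform is purely illustrative. Index 0 follows the identity convention from the docstring:

def make_transform_fn(raw_client_id, index):
  if index == 0:
    return None  # no transformation for the identity pseudo-client
  offset = float(index)
  return lambda x: x + offset  # assumes datapoints support addition

transformed_data = TransformingClientData(
    raw_client_data=my_raw_client_data,
    make_transform_fn=make_transform_fn,
    num_transformed_clients=3 * len(my_raw_client_data.client_ids))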
예제 #19
0
def create_binary_operator(
    operator, operand_type: computation_types.Type) -> pb.Computation:
  """Returns a tensorflow computation computing a binary operation.

  The returned computation has the type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the result of applying the `operator` to a tuple of
  type `<T,T>`.

  Note: If `operand_type` is a `computation_types.NamedTupleType`, then
  `operator` will be applied pointwise. This places the burden on callers of
  this function to construct the correct values to pass into the returned
  function. For example, to divide `[2, 2]` by `2`, first `2` must be packed
  into the data structure `[x, x]`, before the division operator of the
  appropriate type is called.

  Args:
    operator: A callable taking two arguments representing the operation to
      encode. For example: `tf.math.add`, `tf.math.multiply`, and
      `tf.math.divide`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only named tuples and tensor
      types.

  Raises:
    TypeError: If the constraints of `operand_type` are violated or `operator`
      is not callable.
  """
  if not type_analysis.is_generic_op_compatible_type(operand_type):
    raise TypeError(
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.NamedTupleType`; this is disallowed in the '
        'generic operators.'.format(operand_type))
  py_typecheck.check_callable(operator)
  with tf.Graph().as_default() as graph:
    operand_1_value, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', operand_type, graph)

    if operand_type is not None:
      if operand_type.is_tensor():
        result_value = operator(operand_1_value, operand_2_value)
      elif operand_type.is_tuple():
        result_value = anonymous_tuple.map_structure(operator, operand_1_value,
                                                     operand_2_value)
    else:
      raise TypeError(
          'Operand type {} cannot be used in generic operations. The call to '
          '`type_analysis.is_generic_op_compatible_type` has allowed it to '
          'pass, and should be updated.'.format(operand_type))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  type_signature = computation_types.FunctionType(
      computation_types.NamedTupleType((operand_type, operand_type)),
      result_type)
  parameter_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)
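A usage sketch: a scalar addition computation, plus a pointwise variant over a two-element named tuple type (per the note above, callers must pack operands to match):

add_pb = create_binary_operator(
    tf.math.add, computation_types.TensorType(tf.int32))
pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
pointwise_add_pb = create_binary_operator(tf.math.add, pair_type)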
예제 #20
0
 def __init__(self, compilation_fn: Callable[[computation_base.Computation],
                                             Any]):
   py_typecheck.check_callable(compilation_fn)
   self._compilation_fn = compilation_fn
예제 #21
0
def extract_nodes_consuming(tree, predicate):
    """Returns the set of AST nodes which consume nodes matching `predicate`.

  Notice we adopt the convention that a node which itself satisfies the
  predicate is in this set.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
      abstract syntax tree, and construct the set of nodes in this tree having a
      dependency on nodes matching `predicate`; that is, the set of nodes whose
      value depends on evaluating nodes matching `predicate`.
    predicate: One-arg callable, accepting arguments of type
      `building_blocks.ComputationBuildingBlock` and returning a `bool`
      indicating match or mismatch with the desired pattern.

  Returns:
    A `set` of `building_blocks.ComputationBuildingBlock` instances
    representing the nodes in `tree` dependent on nodes matching `predicate`.
  """
    py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
    py_typecheck.check_callable(predicate)
    dependent_nodes = set()

    def _are_children_in_dependent_set(comp, symbol_tree):
        """Checks if the dependencies of `comp` are present in `dependent_nodes`."""
        if isinstance(
                comp,
            (building_blocks.Intrinsic, building_blocks.Data,
             building_blocks.Placement, building_blocks.CompiledComputation)):
            return False
        elif isinstance(comp, building_blocks.Lambda):
            return comp.result in dependent_nodes
        elif isinstance(comp, building_blocks.Block):
            return any(x[1] in dependent_nodes
                       for x in comp.locals) or comp.result in dependent_nodes
        elif isinstance(comp, building_blocks.Tuple):
            return any(x in dependent_nodes for x in comp)
        elif isinstance(comp, building_blocks.Selection):
            return comp.source in dependent_nodes
        elif isinstance(comp, building_blocks.Call):
            return comp.function in dependent_nodes or comp.argument in dependent_nodes
        elif isinstance(comp, building_blocks.Reference):
            return _is_reference_dependent(comp, symbol_tree)

    def _is_reference_dependent(comp, symbol_tree):
        payload = symbol_tree.get_payload_with_name(comp.name)
        if payload is None:
            return False
        # The postorder traversal ensures that we process any
        # bindings before we process the references to those bindings.
        return payload.value in dependent_nodes

    def _populate_dependent_set(comp, symbol_tree):
        """Populates `dependent_nodes` with all nodes dependent on `predicate`."""
        if predicate(comp):
            dependent_nodes.add(comp)
        elif _are_children_in_dependent_set(comp, symbol_tree):
            dependent_nodes.add(comp)
        return comp, False

    symbol_tree = transformation_utils.SymbolTree(
        transformation_utils.ReferenceCounter)
    transformation_utils.transform_postorder_with_symbol_bindings(
        tree, _populate_dependent_set, symbol_tree)
    return dependent_nodes
예제 #22
0
def transform_preorder(
    comp: building_blocks.ComputationBuildingBlock,
    transform: Callable[[building_blocks.ComputationBuildingBlock],
                        TransformReturnType]
) -> TransformReturnType:
  """Walks the AST of `comp` preorder, calling `transform` on the way down.

  Notice that this function will stop walking the tree when its transform
  function modifies a node; this is to prevent the caller from unexpectedly
  kicking off an infinite recursion. For this purpose the transform function
  must identify when it has transformed the structure of a building block; if
  the structure of the building block is modified but `False` is returned as
  the second element of the tuple returned by `transform`, `transform_preorder`
  may result in an infinite recursion.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` to be
      transformed in a preorder fashion.
    transform: Transform function to be applied to the nodes of `comp`. Must
      return a two-tuple whose first element is a
      `building_blocks.ComputationBuildingBlock` and whose second element is a
      Boolean. If the computation passed to `transform` is returned in a
      modified state, `transform` must return `True` for the second element.

  Returns:
    A two-tuple, whose first element is modified version of `comp`, and
    whose second element is a Boolean indicating whether `comp` was transformed
    during the walk.

  Raises:
    TypeError: If the argument types don't match those specified above.
  """

  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(transform)
  inner_comp, modified = transform(comp)
  if modified:
    return inner_comp, modified
  if (inner_comp.is_compiled_computation() or inner_comp.is_data() or
      inner_comp.is_intrinsic() or inner_comp.is_placement() or
      inner_comp.is_reference()):
    return inner_comp, modified
  elif inner_comp.is_lambda():
    transformed_result, result_modified = transform_preorder(
        inner_comp.result, transform)
    if not (modified or result_modified):
      return inner_comp, False
    return building_blocks.Lambda(inner_comp.parameter_name,
                                  inner_comp.parameter_type,
                                  transformed_result), True
  elif inner_comp.is_tuple():
    elements_modified = False
    elements = []
    for name, val in anonymous_tuple.iter_elements(inner_comp):
      result, result_modified = transform_preorder(val, transform)
      elements_modified = elements_modified or result_modified
      elements.append((name, result))
    if not (modified or elements_modified):
      return inner_comp, False
    return building_blocks.Tuple(elements), True
  elif inner_comp.is_selection():
    transformed_source, source_modified = transform_preorder(
        inner_comp.source, transform)
    if not (modified or source_modified):
      return inner_comp, False
    return building_blocks.Selection(transformed_source, inner_comp.name,
                                     inner_comp.index), True
  elif inner_comp.is_call():
    transformed_fn, fn_modified = transform_preorder(inner_comp.function,
                                                     transform)
    if inner_comp.argument is not None:
      transformed_arg, arg_modified = transform_preorder(
          inner_comp.argument, transform)
    else:
      transformed_arg = None
      arg_modified = False
    if not (modified or fn_modified or arg_modified):
      return inner_comp, False
    return building_blocks.Call(transformed_fn, transformed_arg), True
  elif inner_comp.is_block():
    transformed_variables = []
    values_modified = False
    for key, value in inner_comp.locals:
      transformed_value, value_modified = transform_preorder(value, transform)
      transformed_variables.append((key, transformed_value))
      values_modified = values_modified or value_modified
    transformed_result, result_modified = transform_preorder(
        inner_comp.result, transform)
    if not (modified or values_modified or result_modified):
      return inner_comp, False
    return building_blocks.Block(transformed_variables,
                                 transformed_result), True
  else:
    raise NotImplementedError(
        'Unrecognized computation building block: {}'.format(str(inner_comp)))
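A usage sketch; `tree` is a hypothetical building block containing `Data` nodes. Because the walk is preorder and halts on modification, each replacement is not itself revisited:

def _data_to_reference(comp):
  if comp.is_data():
    return building_blocks.Reference('injected', comp.type_signature), True
  return comp, False

transformed_tree, modified = transform_preorder(tree, _data_to_reference)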
예제 #23
0
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from SERVER to CLIENTS.
  Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the SERVER.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the provided
      model weights when users invoke the returned TFF computation), a training
      `tf.data.Dataset`, a test `tf.data.Dataset`, and an arbitrary
      context object (which is used to hold any extra information that a
      personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are usually
      represented as an `OrderedDict` (or a nested `OrderedDict`) of `string`
      metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when users
      invoke the returned TFF computation), and a `tf.data.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are usually represented as an `OrderedDict` (or a
      nested `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s.
      This function is *only* used to compute the baseline metrics of the
      initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in a
      round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalization_fn_dict`. We use a
      context object to hold any extra information (in addition to the training
      dataset) that personalization may use. If context is used in
      `personalization_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` that maps
    < model_weights@SERVER, input@CLIENTS > -> personalization_metrics@SERVER,
    where:
    - model_weights is a `tff.learning.framework.ModelWeights`.
    - each client's input is an `OrderedDict` of at least two keys `train_data`
      and `test_data`, and each key is mapped to a `tf.data.Dataset`. If
      context is used in `personalize_fn_dict`, then client input has a third
      key `context` that is mapped to an object whose `tff.Type` is provided by
      the `context_tff_type` argument.
    - personalization_metrics is an `OrderedDict` that maps a key
      'baseline_metrics' to the evaluation metrics of the initial model
      (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
      `personalize_fn_dict` to the evaluation metrics of the corresponding
      personalization strategies.
    - Note: only metrics from at most `max_num_samples` participating clients
      are collected to the SERVER. All collected metrics are stored in a
      single `OrderedDict` (the personalization_metrics shown above), where each
      metric is mapped to a list of scalars (each scalar comes from one client).
      Metric values at the same position, e.g., metric_1[i], metric_2[i]..., all
      come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    py_typecheck.check_callable(model_fn)
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input.
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(batch_type)),
      ('test_data', tff.SequenceType(batch_type))
  ])
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    model = model_fn()
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    context = client_input.get('context', None)
    return _client_fn(model, initial_model_weights, train_data, test_data,
                      personalize_fn_dict, baseline_evaluate_fn, context)

  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval
예제 #24
0
def build_model_delta_client_work(
    model_fn: Callable[[], model_lib.Model],
    optimizer: Union[optimizer_base.Optimizer,
                     Callable[[], tf.keras.optimizers.Optimizer]],
    client_weighting: client_weight_lib.ClientWeighting,
    delta_l2_regularizer: float = 0.0,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    *,
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
    """Creates a `ClientWorkProcess` for federated averaging.

  This client work is constructed in slightly different manners depending on
  whether `optimizer` is a `tff.learning.optimizers.Optimizer`, or a no-arg
  callable returning a `tf.keras.optimizers.Optimizer`.

  If it is a `tff.learning.optimizers.Optimizer`, we avoid creating
  `tf.Variable`s associated with the optimizer state within the scope of the
  client work, as they are not necessary. This also means that the client's
  model weights are updated by computing `optimizer.next` and then assigning
  the result to the model weights (while a `tf.keras.optimizers.Optimizer` will
  modify the model weight in place using `optimizer.apply_gradients`).

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    optimizer: A `tff.learning.optimizers.Optimizer`, or a no-arg callable that
      returns a `tf.keras.optimizers.Optimizer`.
    client_weighting:  A `tff.learning.ClientWeighting` value.
    delta_l2_regularizer: A nonnegative float representing the parameter of the
      L2-regularization term applied to the delta from initial model weights
      during training. Values larger than 0.0 prevent clients from moving too
      far from the server model during local training.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for the
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `ClientWorkProcess`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)
    py_typecheck.check_type(delta_l2_regularizer, float)
    if delta_l2_regularizer < 0.0:
        raise ValueError(f'Provided delta_l2_regularizer must be non-negative, '
                         f'but found: {delta_l2_regularizer}')
    if not (isinstance(optimizer, optimizer_base.Optimizer)
            or callable(optimizer)):
        raise TypeError(
            'Provided optimizer must be either a tff.learning.optimizers.Optimizer '
            'or a no-arg callable returning a tf.keras.optimizers.Optimizer.')

    if metrics_aggregator is None:
        metrics_aggregator = aggregator.sum_then_finalize

    with tf.Graph().as_default():
        # Wrap model construction in a graph to avoid polluting the global context
        # with variables created for this model.
        model = model_fn()
        unfinalized_metrics_type = type_conversions.type_from_tensors(
            model.report_local_unfinalized_metrics())
        metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                    unfinalized_metrics_type)
    data_type = computation_types.SequenceType(model.input_spec)
    weights_type = model_utils.weights_type_from_model(model)

    if isinstance(optimizer, optimizer_base.Optimizer):

        @tensorflow_computation.tf_computation(weights_type, data_type)
        def client_update_computation(initial_model_weights, dataset):
            client_update = build_model_delta_update_with_tff_optimizer(
                model_fn=model_fn,
                weighting=client_weighting,
                delta_l2_regularizer=delta_l2_regularizer,
                use_experimental_simulation_loop=
                use_experimental_simulation_loop)
            return client_update(optimizer, initial_model_weights, dataset)

    else:

        @tensorflow_computation.tf_computation(weights_type, data_type)
        def client_update_computation(initial_model_weights, dataset):
            keras_optimizer = optimizer()
            client_update = build_model_delta_update_with_keras_optimizer(
                model_fn=model_fn,
                weighting=client_weighting,
                delta_l2_regularizer=delta_l2_regularizer,
                use_experimental_simulation_loop=
                use_experimental_simulation_loop)
            return client_update(keras_optimizer, initial_model_weights,
                                 dataset)

    @federated_computation.federated_computation
    def init_fn():
        return intrinsics.federated_value((), placements.SERVER)

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.at_clients(weights_type),
        computation_types.at_clients(data_type))
    def next_fn(state, weights, client_data):
        client_result, model_outputs = intrinsics.federated_map(
            client_update_computation, (weights, client_data))
        train_metrics = metrics_aggregation_fn(model_outputs)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(train=train_metrics))
        return measured_process.MeasuredProcessOutput(state, client_result,
                                                      measurements)

    return client_works.ClientWorkProcess(init_fn, next_fn)
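A construction sketch using the Keras-optimizer path; `my_model_fn` is a hypothetical no-arg model constructor:

client_work = build_model_delta_client_work(
    model_fn=my_model_fn,
    optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
    delta_l2_regularizer=0.01)
state = client_work.initialize()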
예제 #25
0
def build_model_delta_optimizer_process(
    model_fn,
    model_to_client_delta_fn,
    server_optimizer_fn,
    stateful_delta_aggregate_fn=build_stateless_mean(),
    stateful_model_broadcast_fn=build_stateless_broadcaster()):
    """Constructs `tff.utils.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a ClientDeltaFn
  that specifies how weight_deltas are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a model_fn to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    stateful_delta_aggregate_fn: A `tff.utils.StatefulAggregateFn` where the
      next_fn performs a federated aggregation and updates state. That is, it
      has TFF type (state@SERVER, value@CLIENTS) -> (state@SERVER,
      aggregate@SERVER).
    stateful_model_broadcast_fn: A `tff.utils.StatefulBroadcastFn` where the
      next_fn performs a federated broadcast and updates state. That is, it has
      TFF type (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS).

  Returns:
    A `tff.utils.IterativeProcess`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_callable(model_to_client_delta_fn)
    py_typecheck.check_callable(server_optimizer_fn)
    py_typecheck.check_type(stateful_delta_aggregate_fn,
                            tff.utils.StatefulAggregateFn)
    py_typecheck.check_type(stateful_model_broadcast_fn,
                            tff.utils.StatefulBroadcastFn)

    # TODO(b/122081673): would be nice not to have to construct a throwaway
    # model here just to get the types. After fully moving to TF2.0 and
    # eager-mode, we should re-evaluate what happens here.
    with tf.Graph().as_default():
        dummy_model_for_metadata = model_utils.enhance(model_fn())

    @tff.tf_computation
    def tf_init_fn():
        return server_init(model_fn, server_optimizer_fn,
                           stateful_delta_aggregate_fn.initialize(),
                           stateful_model_broadcast_fn.initialize())

    @tff.federated_computation
    def server_init_tff():
        """Orchestration logic for server model initialization."""
        return tff.federated_value(tf_init_fn(), tff.SERVER)

    federated_server_state_type = server_init_tff.type_signature.result
    server_state_type = tf_init_fn.type_signature.result

    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
    federated_dataset_type = tff.FederatedType(tf_dataset_type,
                                               tff.CLIENTS,
                                               all_equal=False)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round_tff(server_state, federated_dataset):
        """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.data.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
    """
        model_weights_type = server_state_type.model

        @tff.tf_computation(tf_dataset_type, model_weights_type)
        def client_delta_tf(tf_dataset, initial_model_weights):
            """Performs client local model optimization.

      Args:
        tf_dataset: a `tf.data.Dataset` that provides training examples.
        initial_model_weights: a `model_utils.ModelWeights` containing the
          starting weights.

      Returns:
        A `ClientOutput` structure.
      """
            client_delta_fn = model_to_client_delta_fn(model_fn)
            client_output = client_delta_fn(tf_dataset, initial_model_weights)
            return client_output

        new_broadcaster_state, client_model = stateful_model_broadcast_fn(
            server_state.model_broadcast_state, server_state.model)

        client_outputs = tff.federated_map(client_delta_tf,
                                           (federated_dataset, client_model))

        @tff.tf_computation(
            server_state_type, model_weights_type.trainable,
            server_state.delta_aggregate_state.type_signature.member,
            server_state.model_broadcast_state.type_signature.member)
        def server_update_tf(server_state, model_delta,
                             new_delta_aggregate_state, new_broadcaster_state):
            """Converts args to correct python types and calls server_update_model."""
            py_typecheck.check_type(server_state, ServerState)
            server_state = ServerState(
                model=server_state.model,
                optimizer_state=list(server_state.optimizer_state),
                delta_aggregate_state=new_delta_aggregate_state,
                model_broadcast_state=new_broadcaster_state)

            return server_update_model(server_state,
                                       model_delta,
                                       model_fn=model_fn,
                                       optimizer_fn=server_optimizer_fn)

        # TODO(b/124070381): We hope to remove this explicit cast once we have a
        # full solution for type analysis in multiplications and divisions
        # inside TFF
        fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
        py_typecheck.check_type(fed_weight_type, tff.TensorType)
        if fed_weight_type.dtype.is_integer:

            @tff.tf_computation(fed_weight_type)
            def _cast_to_float(x):
                return tf.cast(x, tf.float32)

            weight_denom = tff.federated_map(
                _cast_to_float, client_outputs.weights_delta_weight)
        else:
            weight_denom = client_outputs.weights_delta_weight

        new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
            server_state.delta_aggregate_state,
            client_outputs.weights_delta,
            weight=weight_denom)

        # TODO(b/123408447): remove tff.federated_apply and call
        # server_update_tf directly once T <-> T@SERVER isomorphism is
        # supported.
        server_state = tff.federated_apply(
            server_update_tf,
            (server_state, round_model_delta, new_delta_aggregate_state,
             new_broadcaster_state))

        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)

        # Promote the FederatedType outside the NamedTupleType
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs

    return tff.utils.IterativeProcess(initialize_fn=server_init_tff,
                                      next_fn=run_one_round_tff)
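A wiring sketch; `my_model_fn` is a hypothetical model constructor and `my_model_to_client_delta_fn` a hypothetical factory returning a `ClientDeltaFn`:

iterative_process = build_model_delta_optimizer_process(
    model_fn=my_model_fn,
    model_to_client_delta_fn=my_model_to_client_delta_fn,
    server_optimizer_fn=lambda: tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate=1.0))
server_state = iterative_process.initialize()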
예제 #26
0
def serialize_tf2_as_tf_computation(target, parameter_type, unpack=None):
  """Serializes the 'target' as a TF computation with a given parameter type.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function or `tf.function`, with arguments
      matching the 'parameter_type'.
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    unpack: Whether to always unpack the parameter_type. Necessary for support
      of polymorphic tf2_computations.

  Returns:
    A tuple of the constructed `pb.Computation` instance (with the
    `pb.TensorFlow` variant set) and the annotated `FunctionType` of the
    computation.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  py_typecheck.check_callable(target)
  parameter_type = computation_types.to_type(parameter_type)
  signature = function_utils.get_signature(target)
  if signature.parameters and parameter_type is None:
    raise ValueError(
        'Expected the target to declare no parameters, found {!r}.'.format(
            signature.parameters))

  # In the codepath for TF V1 based serialization (tff.tf_computation),
  # we get the "wrapped" function to serialize. Here, target is the
  # raw function to be wrapped; however, we still need to know if
  # the parameter_type should be unpacked into multiple args and kwargs
  # in order to construct the TensorSpecs to be passed in the call
  # to get_concrete_fn below.
  unpack = function_utils.infer_unpack_needed(target, parameter_type, unpack)
  arg_typespecs, kwarg_typespecs, parameter_binding = (
      tensorflow_utils.get_tf_typespec_and_binding(
          parameter_type,
          arg_names=list(signature.parameters.keys()),
          unpack=unpack))

  # Pseudo-global to be appended to once when target_poly below is traced.
  type_and_binding_slot = []

  # N.B. To serialize a tf.function or eager python code,
  # the return type must be a flat list, tuple, or dict. However, the
  # tff.tf_computation must be able to handle structured inputs and outputs.
  # Thus, we intercept the result of calling the original target fn, introspect
  # its structure to create a result_type and bindings, and then return a
  # flat dict output. It is this new "unpacked" tf.function that we will
  # serialize using tf.saved_model.save.
  #
  # TODO(b/117428091): The return type limitation is primarily a limitation of
  # SignatureDefs and therefore of the signatures argument to
  # tf.saved_model.save. tf.functions attached to objects and loaded back with
  # tf.saved_model.load can take/return nests; this might offer a better
  # approach to the one taken here.

  @tf.function
  def target_poly(*args, **kwargs):
    result = target(*args, **kwargs)
    result_dict, result_type, result_binding = (
        tensorflow_utils.get_tf2_result_dict_and_binding(result))
    assert not type_and_binding_slot
    # A "side channel" python output.
    type_and_binding_slot.append((result_type, result_binding))
    return result_dict

  # Triggers tracing so that type_and_binding_slot is filled.
  cc_fn = target_poly.get_concrete_function(*arg_typespecs, **kwarg_typespecs)
  assert len(type_and_binding_slot) == 1
  result_type, result_binding = type_and_binding_slot[0]

  # N.B. Note that cc_fn does *not* accept the same args and kwargs as the
  # Python target_poly; instead, it must be called with **kwargs based on the
  # unique names embedded in the TensorSpecs inside arg_typespecs and
  # kwarg_typespecs. The (preliminary) parameter_binding tracks the mapping
  # between these tensor names and the components of the (possibly nested) TFF
  # input type. When cc_fn is serialized, concrete tensors for each input are
  # introduced, and the call finalize_binding(parameter_binding,
  # sigs['serving_default'].inputs) updates the bindings to reference these
  # concrete tensors.

  # Associate vars with unique names and explicitly attach to the Checkpoint:
  var_dict = {
      'var{:02d}'.format(i): v for i, v in enumerate(cc_fn.graph.variables)
  }
  saveable = tf.train.Checkpoint(fn=target_poly, **var_dict)

  try:
    # TODO(b/122081673): All we really need is the meta graph def, we could
    # probably just load that directly, e.g., using parse_saved_model from
    # tensorflow/python/saved_model/loader_impl.py, but I'm not sure we want to
    # depend on that presumably non-public symbol. Perhaps TF can expose a way
    # to just get the MetaGraphDef directly without saving to a tempfile? This
    # looks like a small change to v2.saved_model.save().
    outdir = tempfile.mkdtemp('savedmodel')
    tf.saved_model.save(saveable, outdir, signatures=cc_fn)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
      mgd = tf.compat.v1.saved_model.load(
          sess, tags=[tf.saved_model.SERVING], export_dir=outdir)
  finally:
    shutil.rmtree(outdir)
  sigs = mgd.signature_def

  # TODO(b/123102455): Figure out how to support the init_op. The meta graph def
  # contains sigs['__saved_model_init_op'].outputs['__saved_model_init_op']. It
  # probably won't do what we want, because it will want to read from
  # Checkpoints, not just run Variable initializers (?). The right solution may
  # be to grab the target_poly.get_initialization_function(), and save a sig for
  # that.

  # Now, traverse the signature from the MetaGraphDef to find the actual
  # tensor names and write them into the bindings.
  finalize_binding(parameter_binding, sigs['serving_default'].inputs)
  finalize_binding(result_binding, sigs['serving_default'].outputs)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(mgd.graph_def),
          parameter=parameter_binding,
          result=result_binding)), annotated_type
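A minimal usage sketch for the serializer above. It assumes the function is exposed as `serialize_tf2_as_tf_computation(target, parameter_type, unpack)`; that name is inferred from context rather than shown in this snippet, so treat it as an assumption.

import tensorflow as tf

@tf.function
def add_one(x):
  return x + 1.0

# `tf.float32` is convertible to a TFF type by `computation_types.to_type`,
# so it is a valid `parameter_type` for the one-argument function above.
comp_proto, annotated_type = serialize_tf2_as_tf_computation(
    add_one, parameter_type=tf.float32, unpack=None)
# `annotated_type` is the `computation_types.FunctionType` built at the end
# of the serializer, e.g. (float32 -> float32).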
Example #27
def build_model_delta_optimizer_process(
    model_fn,
    model_to_client_delta_fn,
    server_optimizer_fn,
    stateful_delta_aggregate_fn=build_stateless_mean(),
    stateful_model_broadcast_fn=build_stateless_broadcaster()):
    """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    stateful_delta_aggregate_fn: A `tff.utils.StatefulAggregateFn` where the
      `next_fn` performs a federated aggregation and updates state. That is, it
      has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where the `value` type is
      `tff.learning.framework.ModelWeights.trainable` corresponding to the
      object returned by `model_fn`.
    stateful_model_broadcast_fn: A `tff.utils.StatefulBroadcastFn` where the
      `next_fn` performs a federated broadcast and updates state. That is, it has
      TFF type `(state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS)`,
      where the `value` type is `tff.learning.framework.ModelWeights`
      corresponding to the object returned by `model_fn`.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_callable(model_to_client_delta_fn)
    py_typecheck.check_callable(server_optimizer_fn)
    py_typecheck.check_type(stateful_delta_aggregate_fn,
                            tff.utils.StatefulAggregateFn)
    py_typecheck.check_type(stateful_model_broadcast_fn,
                            tff.utils.StatefulBroadcastFn)

    # TODO(b/124477628): would be nice not to have to construct a throwaway model
    # here just to get the types. After fully moving to TF2.0 and eager-mode, we
    # should re-evaluate what happens here.
    # TODO(b/144382142): Keras name uniquification is probably the main reason we
    # still need this.
    with tf.Graph().as_default():
        dummy_model_for_metadata = model_utils.enhance(model_fn())

    # ===========================================================================
    # TensorFlow Computations

    @tff.tf_computation
    def tf_init_fn():
        return server_init(model_fn, server_optimizer_fn,
                           stateful_delta_aggregate_fn.initialize(),
                           stateful_model_broadcast_fn.initialize())

    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
    server_state_type = tf_init_fn.type_signature.result

    @tff.tf_computation(tf_dataset_type, server_state_type.model)
    def tf_client_delta(tf_dataset, initial_model_weights):
        """Performs client local model optimization.

    Args:
      tf_dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
        client_delta_fn = model_to_client_delta_fn(model_fn)
        client_output = client_delta_fn(tf_dataset, initial_model_weights)
        return client_output

    @tff.tf_computation(server_state_type, server_state_type.model.trainable,
                        server_state_type.delta_aggregate_state,
                        server_state_type.model_broadcast_state)
    def tf_server_update(server_state, model_delta, new_delta_aggregate_state,
                         new_broadcaster_state):
        """Converts args to correct python types and calls server_update_model."""
        py_typecheck.check_type(server_state, ServerState)
        server_state = ServerState(
            model=server_state.model,
            optimizer_state=list(server_state.optimizer_state),
            delta_aggregate_state=new_delta_aggregate_state,
            model_broadcast_state=new_broadcaster_state)

        return server_update_model(server_state,
                                   model_delta,
                                   model_fn=model_fn,
                                   optimizer_fn=server_optimizer_fn)

    weight_type = tf_client_delta.type_signature.result.weights_delta_weight

    @tff.tf_computation(weight_type)
    def _cast_weight_to_float(x):
        return tf.cast(x, tf.float32)

    # ===========================================================================
    # Federated Computations

    @tff.federated_computation
    def server_init_tff():
        """Orchestration logic for server model initialization."""
        return tff.federated_value(tf_init_fn(), tff.SERVER)

    federated_server_state_type = server_init_tff.type_signature.result
    federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round_tff(server_state, federated_dataset):
        """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.data.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
    """
        new_broadcaster_state, client_model = stateful_model_broadcast_fn(
            server_state.model_broadcast_state, server_state.model)

        client_outputs = tff.federated_map(tf_client_delta,
                                           (federated_dataset, client_model))

        # TODO(b/124070381): We hope to remove this explicit cast once we have a
        # full solution for type analysis in multiplications and divisions
        # inside TFF.
        weight_denom = tff.federated_map(_cast_weight_to_float,
                                         client_outputs.weights_delta_weight)
        new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
            server_state.delta_aggregate_state,
            client_outputs.weights_delta,
            weight=weight_denom)

        server_state = tff.federated_map(
            tf_server_update,
            (server_state, round_model_delta, new_delta_aggregate_state,
             new_broadcaster_state))

        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)

        if isinstance(aggregated_outputs.type_signature, tff.NamedTupleType):
            # Promote the FederatedType outside the NamedTupleType.
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs

    return tff.templates.IterativeProcess(initialize_fn=server_init_tff,
                                          next_fn=run_one_round_tff)
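A hypothetical driver for the process constructed above; `my_model_fn`, `client_delta_cls`, and `federated_train_data` are caller-supplied placeholders, not TFF APIs, and the lambda wiring is a sketch rather than the library's prescribed usage.

import tensorflow as tf

def run_training(my_model_fn, client_delta_cls, federated_train_data,
                 num_rounds=10):
  # `client_delta_cls` stands in for any `ClientDeltaFn` implementation that
  # takes a constructed model; adapt the lambda to the actual constructor.
  process = build_model_delta_optimizer_process(
      model_fn=my_model_fn,
      model_to_client_delta_fn=lambda model_fn: client_delta_cls(model_fn()),
      server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
  state = process.initialize()
  metrics = None
  for _ in range(num_rounds):
    # Each round broadcasts the model, runs client updates, and aggregates.
    state, metrics = process.next(state, federated_train_data)
  return state, metrics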
Example #28
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[model_lib.Model], ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    stateful_delta_aggregate_fn: Optional[
        tff.utils.StatefulAggregateFn] = None,
    stateful_model_broadcast_fn: Optional[
        tff.utils.StatefulBroadcastFn] = None,
    *,
    broadcast_process: Optional[tff.templates.MeasuredProcess] = None,
    aggregation_process: Optional[tff.templates.MeasuredProcess] = None,
) -> tff.templates.IterativeProcess:
    """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: a no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: a function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: a no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    stateful_delta_aggregate_fn: a `tff.utils.StatefulAggregateFn` where the
      `next_fn` performs a federated aggregation and updates state. That is, it
      has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where the `value` type is
      `tff.learning.framework.ModelWeights.trainable` corresponding to the
      object returned by `model_fn`.
    stateful_model_broadcast_fn: a `tff.utils.StatefulBroadcastFn` where the
      `next_fn` performs a federated broadcast and updates state. That is, it
      has TFF type `(state@SERVER, value@SERVER) -> (state@SERVER,
      value@CLIENTS)`, where the `value` type is
      `tff.learning.framework.ModelWeights`
      corresponding to the object returned by `model_fn`.
    broadcast_process: a `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)`.
    aggregation_process: a `tff.templates.MeasuredProcess` that aggregates the
      model updates on the clients back to the server. It must support the
      signature `({input_values}@CLIENTS -> output_values@SERVER)`.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ProcessTypeError: if `broadcast_process` or `aggregation_process` do not
      conform to the signature of broadcast (SERVER->CLIENTS) or aggregation
      (CLIENTS->SERVER).
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_callable(model_to_client_delta_fn)
    py_typecheck.check_callable(server_optimizer_fn)

    with tf.Graph().as_default():
        dummy_model_for_metadata = model_fn()
        model_weights_type = tff.framework.type_from_tensors(
            model_utils.ModelWeights.from_model(dummy_model_for_metadata))

    # TODO(b/159138779): remove the StatefulFn arguments and these validation
    # functions once all callers are migrated.
    def validate_disjoint_optional_arguments(
        stateful_fn: Optional[Union[tff.utils.StatefulBroadcastFn,
                                    tff.utils.StatefulAggregateFn]],
        process: Optional[tff.templates.MeasuredProcess],
        process_input_type: Union[tff.NamedTupleType, tff.TensorType],
    ) -> Optional[tff.templates.MeasuredProcess]:
        """Validate that only one of two arguments is specified.

    This validates that only the `tff.templates.MeasuredProcess` or the
    `tff.utils.StatefulFn` is specified, and converts the `tff.utils.StatefulFn`
    to a `tff.templates.MeasuredProcess` if possible. This is a bridge while
    transitioning to `tff.templates.MeasuredProcess`.

    Args:
      stateful_fn: an optional `tff.utils.StatefulFn` that will be wrapped if
        specified.
      process: an optional `tff.templates.MeasuredProcess` that will be returned
        as-is.
      process_input_type: the input type used when wrapping `stateful_fn`.

    Returns:
      `None` if neither argument is specified, otherwise a
      `tff.templates.MeasuredProcess`.

    Raises:
      DisjointArgumentError: if both `stateful_fn` and `process` are not `None`.
    """
        if stateful_fn is not None:
            py_typecheck.check_type(
                stateful_fn,
                (tff.utils.StatefulBroadcastFn, tff.utils.StatefulAggregateFn))
            if process is not None:
                raise DisjointArgumentError(
                    'Specifying both arguments is an error. Only one may be used'
                )
            return _wrap_in_measured_process(stateful_fn,
                                             input_type=process_input_type)
        return process

    try:
        aggregation_process = validate_disjoint_optional_arguments(
            stateful_delta_aggregate_fn, aggregation_process,
            model_weights_type.trainable)
    except DisjointArgumentError as e:
        raise DisjointArgumentError(
            'Specifying both `stateful_delta_aggregate_fn` and '
            '`aggregation_process` is an error. Only one may be used') from e

    try:
        broadcast_process = validate_disjoint_optional_arguments(
            stateful_model_broadcast_fn, broadcast_process, model_weights_type)
    except DisjointArgumentError as e:
        raise DisjointArgumentError(
            'Specifying both `stateful_model_broadcast_fn` and '
            '`broadcast_process` is an error. Only one may be used') from e

    if broadcast_process is None:
        broadcast_process = build_stateless_broadcaster(
            model_weights_type=model_weights_type)
    if not _is_valid_broadcast_process(broadcast_process):
        raise ProcessTypeError(
            'broadcast_process type signature does not conform to expected '
            'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
            ' Got: {t}'.format(t=broadcast_process.next.type_signature))

    if aggregation_process is None:
        aggregation_process = build_stateless_mean(
            model_delta_type=model_weights_type.trainable)
    if not _is_valid_aggregation_process(aggregation_process):
        raise ProcessTypeError(
            'aggregation_process type signature does not conform to expected '
            'signature (<state@S, input@C> -> <state@S, result@S, measurements@S>).'
            ' Got: {t}'.format(t=aggregation_process.next.type_signature))

    initialize_computation = _build_initialize_computation(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        broadcast_process=broadcast_process,
        aggregation_process=aggregation_process)

    run_one_round_computation = _build_one_round_computation(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        model_to_client_delta_fn=model_to_client_delta_fn,
        broadcast_process=broadcast_process,
        aggregation_process=aggregation_process)

    return tff.templates.IterativeProcess(initialize_fn=initialize_computation,
                                          next_fn=run_one_round_computation)
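The disjoint-argument bridge above can be distilled into a standalone sketch of the general pattern, accepting either a legacy argument or its replacement but never both. Everything below is illustrative and not a TFF API.

class MutuallyExclusiveArgumentError(Exception):
  """Raised when a legacy argument and its replacement are both given."""

def pick_one(legacy_arg, new_arg, convert):
  # Returns `new_arg` as-is, or `legacy_arg` converted to the new
  # representation; specifying both is an error, and neither yields None.
  if legacy_arg is not None:
    if new_arg is not None:
      raise MutuallyExclusiveArgumentError(
          'Only one of the two arguments may be set.')
    return convert(legacy_arg)
  return new_arg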
Example #29
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    *,
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[
        factory.AggregationFactory] = None,
) -> iterative_process.IterativeProcess:
    """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)`.
    aggregation_process: A `tff.templates.MeasuredProcess` that aggregates the
      model updates on the clients back to the server. It must support the
      signature `({input_values}@CLIENTS -> output_values@SERVER)`. Must be
      `None` if `model_update_aggregation_factory` is not `None`.
    model_update_aggregation_factory: An optional
      `tff.aggregators.WeightedAggregationFactory` that constructs
      `tff.templates.AggregationProcess` for aggregating the client model
      updates on the server. If `None`, uses a default constructed
      `tff.aggregators.MeanFactory`, creating a stateless mean aggregation. Must
      be `None` if `aggregation_process` is not `None`.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ProcessTypeError: if `broadcast_process` or `aggregation_process` do not
      conform to the signature of broadcast (SERVER->CLIENTS) or aggregation
      (CLIENTS->SERVER).
    DisjointArgumentError: if both `aggregation_process` and
      `model_update_aggregation_factory` are not `None`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_callable(model_to_client_delta_fn)
    py_typecheck.check_callable(server_optimizer_fn)

    model_weights_type = model_utils.weights_type_from_model(model_fn)

    if broadcast_process is None:
        broadcast_process = build_stateless_broadcaster(
            model_weights_type=model_weights_type)
    if not _is_valid_broadcast_process(broadcast_process):
        raise ProcessTypeError(
            'broadcast_process type signature does not conform to expected '
            'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
            ' Got: {t}'.format(t=broadcast_process.next.type_signature))

    if (model_update_aggregation_factory is not None
            and aggregation_process is not None):
        raise DisjointArgumentError(
            'Must specify only one of `model_update_aggregation_factory` and '
            '`aggregation_process`.')

    if aggregation_process is None:
        if model_update_aggregation_factory is None:
            model_update_aggregation_factory = mean_factory.MeanFactory()

        py_typecheck.check_type(model_update_aggregation_factory,
                                factory.AggregationFactory.__args__)

        if isinstance(model_update_aggregation_factory,
                      factory.WeightedAggregationFactory):
            aggregation_process = model_update_aggregation_factory.create(
                model_weights_type.trainable,
                computation_types.TensorType(tf.float32))
        else:
            aggregation_process = model_update_aggregation_factory.create(
                model_weights_type.trainable)
    else:
        next_num_args = len(aggregation_process.next.type_signature.parameter)
        if next_num_args not in [2, 3]:
            raise ValueError(
                f'`next` function of `aggregation_process` must take two (for '
                f'unweighted aggregation) or three (for weighted aggregation) '
                f'arguments. Found {next_num_args}.')

    if not _is_valid_aggregation_process(aggregation_process):
        raise ProcessTypeError(
            'aggregation_process type signature does not conform to expected '
            'signature (<state@S, input@C> -> <state@S, result@S, measurements@S>).'
            ' Got: {t}'.format(t=aggregation_process.next.type_signature))

    initialize_computation = _build_initialize_computation(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        broadcast_process=broadcast_process,
        aggregation_process=aggregation_process)

    run_one_round_computation = _build_one_round_computation(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        model_to_client_delta_fn=model_to_client_delta_fn,
        broadcast_process=broadcast_process,
        aggregation_process=aggregation_process)

    return iterative_process.IterativeProcess(
        initialize_fn=initialize_computation,
        next_fn=run_one_round_computation)
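A short sketch of the factory dispatch above in isolation, reusing the `computation_types` and `mean_factory` modules this file already imports; the concrete value shape is illustrative.

value_type = computation_types.TensorType(tf.float32, [3])
weight_type = computation_types.TensorType(tf.float32)

# A `WeightedAggregationFactory`'s `create` takes both a value type and a
# weight type; an unweighted factory is instead called as `create(value_type)`.
weighted_factory = mean_factory.MeanFactory()
aggregation = weighted_factory.create(value_type, weight_type)
# `aggregation.next` then has the (<state@S, value@C, weight@C> -> ...) shape
# that `_is_valid_aggregation_process` checks above.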
Example #30
def to_representation_for_type(value, type_spec, callable_handler=None):
  """Verifies or converts the `value` representation to match `type_spec`.

  This method first tries to determine whether `value` is a valid representation
  of TFF type `type_spec`. If so, it is returned unchanged. If not, but if it
  can be converted into a valid representation, it is converted to such, and the
  valid representation is returned. If no conversion to a valid representation
  is possible, TypeError is raised.

  The accepted forms of `value` for various TFF types are as follows:

  *   For TFF tensor types listed in
      `tensorflow_utils.TENSOR_REPRESENTATION_TYPES`.

  *   For TFF named tuple types, instances of `anonymous_tuple.AnonymousTuple`.

  *   For TFF sequences, Python lists.

  *   For TFF functional types, Python callables that accept a single argument
      that is an instance of `ComputedValue` (if the function has a parameter)
      or `None` (otherwise), and return a `ComputedValue` instance as a result.
      This function only verifies that `value` is a callable.

  *   For TFF abstract types, there is no valid representation. The reference
      executor requires all types in an executable computation to be concrete.

  *   For TFF placement types, the valid representations are the placement
      literals (currently only `tff.SERVER` and `tff.CLIENTS`).

  *   For TFF federated types with `all_equal` set to `True`, the representation
      is the same as the representation of the member constituent (thus, e.g.,
      a valid representation of `int32@SERVER` is the same as that of `int32`).
      For those types that have `all_equal` set to `False`, the representation
      is a Python list of member constituents.

      NOTE: This function does not attempt to validate that the sizes of lists
      that represent federated values match the corresponding placements. The
      cardinality analysis is a separate step, handled by the reference executor
      at a different point. As long as values can be packed into a Python list,
      they are accepted as they are.

  Args:
    value: The raw representation of a value to compare against `type_spec` and
      potentially to be converted into a canonical form for the given TFF type.
    type_spec: The TFF type, an instance of `tff.Type` or something convertible
      to it that determines what the valid representation should be.
    callable_handler: The function to invoke to handle TFF functional types. If
      this is `None`, functional types are not supported. The function must
      accept `value` and `type_spec` as arguments and return the converted valid
      representation, just as `to_representation_for_type` does.

  Returns:
    Either `value` itself, or the `value` converted into a valid representation
    for `type_spec`.

  Raises:
    TypeError: If `value` is not a valid representation for given `type_spec`.
    NotImplementedError: If verification for `type_spec` is not supported.
  """
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  if callable_handler is not None:
    py_typecheck.check_callable(callable_handler)

  # NOTE: We do not simply call `type_utils.infer_type()` on `value`, as the
  # representations of values in the reference executor are only a subset of
  # the Python types recognized by that helper function.

  if isinstance(type_spec, computation_types.TensorType):
    if tf.executing_eagerly() and isinstance(value, (tf.Tensor, tf.Variable)):
      value = value.numpy()
    py_typecheck.check_type(value, tensorflow_utils.TENSOR_REPRESENTATION_TYPES)
    inferred_type_spec = type_utils.infer_type(value)
    if not type_utils.is_assignable_from(type_spec, inferred_type_spec):
      raise TypeError(
          'The tensor type {} of the value representation does not match '
          'the type spec {}.'.format(inferred_type_spec, type_spec))
    return value
  elif isinstance(type_spec, computation_types.NamedTupleType):
    type_spec_elements = anonymous_tuple.to_elements(type_spec)
    # Special-casing unordered dictionaries to allow their elements to be fed in
    # the order in which they're defined in the named tuple type.
    if (isinstance(value, dict) and
        (set(value.keys()) == set(k for k, _ in type_spec_elements))):
      value = collections.OrderedDict([
          (k, value[k]) for k, _ in type_spec_elements
      ])
    value = anonymous_tuple.from_container(value)
    value_elements = anonymous_tuple.to_elements(value)
    if len(value_elements) != len(type_spec_elements):
      raise TypeError(
          'The number of elements {} in the value tuple {} does not match the '
          'number of elements {} in the type spec {}.'.format(
              len(value_elements), value, len(type_spec_elements), type_spec))
    result_elements = []
    for index, (type_elem_name, type_elem) in enumerate(type_spec_elements):
      value_elem_name, value_elem = value_elements[index]
      if value_elem_name not in [type_elem_name, None]:
        raise TypeError(
            'Found element named `{}` where `{}` was expected at position {} '
            'in the value tuple. Value: {}. Type: {}'.format(
                value_elem_name, type_elem_name, index, value, type_spec))
      converted_value_elem = to_representation_for_type(value_elem, type_elem,
                                                        callable_handler)
      result_elements.append((type_elem_name, converted_value_elem))
    return anonymous_tuple.AnonymousTuple(result_elements)
  elif isinstance(type_spec, computation_types.SequenceType):
    if isinstance(value, tf.data.Dataset):
      inferred_type_spec = computation_types.SequenceType(
          computation_types.to_type(tf.data.experimental.get_structure(value)))
      if not type_utils.is_assignable_from(type_spec, inferred_type_spec):
        raise TypeError(
            'Value of type {!s} not assignable to expected type {!s}'.format(
                inferred_type_spec, type_spec))
      if tf.executing_eagerly():
        return [
            to_representation_for_type(v, type_spec.element, callable_handler)
            for v in value
        ]
      else:
        raise ValueError(
            'Processing `tf.data.Datasets` outside of eager mode is not '
            'currently supported.')
    return [
        to_representation_for_type(v, type_spec.element, callable_handler)
        for v in value
    ]
  elif isinstance(type_spec, computation_types.FunctionType):
    if callable_handler is not None:
      return callable_handler(value, type_spec)
    else:
      raise TypeError(
          'Values that are callables have been explicitly disallowed '
          'in this context. If you would like to supply here a function '
          'as a parameter, please construct a computation that contains '
          'this call.')
  elif isinstance(type_spec, computation_types.AbstractType):
    raise TypeError(
        'Abstract types are not supported by the reference executor.')
  elif isinstance(type_spec, computation_types.PlacementType):
    py_typecheck.check_type(value, placement_literals.PlacementLiteral)
    return value
  elif isinstance(type_spec, computation_types.FederatedType):
    if type_spec.all_equal:
      return to_representation_for_type(value, type_spec.member,
                                        callable_handler)
    elif type_spec.placement is not placements.CLIENTS:
      raise TypeError(
          'Unable to determine a valid value representation for a federated '
          'type with non-equal members placed at {}.'.format(
              type_spec.placement))
    elif not isinstance(value, (list, tuple)):
      raise ValueError('Please pass a list or tuple to any function that'
                       ' expects a federated type placed at {};'
                       ' you passed {}'.format(type_spec.placement, value))
    else:
      return [
          to_representation_for_type(v, type_spec.member, callable_handler)
          for v in value
      ]
  else:
    raise NotImplementedError(
        'Unable to determine valid value representation for {} for what '
        'is currently an unsupported TFF type {}.'.format(value, type_spec))
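Two illustrative calls, sketched under the assumption of the same imports this module uses, showing the normalization the function performs.

# A tensor type: a plain Python int is already a valid representation and is
# returned unchanged after the assignability check.
ten = to_representation_for_type(10, computation_types.TensorType(tf.int32))

# A named tuple type: a dict whose keys match is reordered to the type's
# element order and converted into an `anonymous_tuple.AnonymousTuple`.
pair = to_representation_for_type(
    {'b': 2.0, 'a': 1}, [('a', tf.int32), ('b', tf.float32)])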