def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Serializes `target_fn` into a TensorFlow-backed TFF computation.

  Args:
    target_fn: The Python callable containing the TensorFlow logic.
    parameter_type: The TFF type of the computation's parameter. It must be
      accepted by `type_utils.is_tensorflow_compatible_type`.
    unpack: Forwarded unchanged to
      `function_utils.wrap_as_zero_or_one_arg_callable` (presumably controls
      whether the parameter is unpacked into several arguments; see that
      helper).
    name: Accepted but ignored.

  Returns:
    A `computation_impl.ComputationImpl` built from the proto and extra type
    spec produced by
    `tensorflow_serialization.serialize_py_fn_as_tf_computation`, bound to the
    current `context_stack_impl.context_stack`.

  Raises:
    TypeError: If `parameter_type` is not TensorFlow-compatible.
  """
  del name  # Unused.
  # The callable is wrapped before the type check, matching the historical
  # order of operations.
  zero_or_one_arg_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  if type_utils.is_tensorflow_compatible_type(parameter_type):
    current_stack = context_stack_impl.context_stack
    serialize = tensorflow_serialization.serialize_py_fn_as_tf_computation
    serialized_proto, extra_type_spec = serialize(
        zero_or_one_arg_fn, parameter_type, current_stack)
    return computation_impl.ComputationImpl(serialized_proto, current_stack,
                                            extra_type_spec)
  raise TypeError('`tf_computation`s can accept only parameter types with '
                  'constituents `SequenceType`, `NamedTupleType` '
                  'and `TensorType`; you have attempted to create one '
                  'with the type {}.'.format(parameter_type))
def _is_valid_before_aggregate_type_signature(type_spec,
                                              previously_packed_types):
  """Returns whether `type_spec` has the form expected of `before_aggregate`.

  The expected form is `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`.
  Every structural check short-circuits before any indexing or attribute
  access that depends on it. A malformed `type_spec` therefore yields `False`
  rather than an `AttributeError`/`IndexError`.

  Args:
    type_spec: The candidate `type_signature` of `before_aggregate`.
    previously_packed_types: Dict providing `s1_type`, `c1_type` and
      `s2_type` to validate against.

  Returns:
    `True` if `type_spec` is well-formed and consistent with
    `previously_packed_types`, `False` otherwise.
  """
  if not (isinstance(type_spec, computation_types.FunctionType) and
          isinstance(type_spec.parameter, computation_types.NamedTupleType)
          and len(type_spec.parameter) == 2):
    return False

  # Parameter element 0 must be `<s1,c1>`, matching previously packed types.
  s1_and_c1 = type_spec.parameter[0]
  if not (isinstance(s1_and_c1, computation_types.NamedTupleType) and
          len(s1_and_c1) == 2 and
          s1_and_c1[0] == previously_packed_types['s1_type'] and
          s1_and_c1[1] == previously_packed_types['c1_type']):
    return False

  # Parameter element 1 is `c2`: the broadcast of `s2` onto the clients.
  c2_type = type_spec.parameter[1]
  if not (isinstance(c2_type, computation_types.FederatedType) and
          c2_type.placement == placements.CLIENTS and
          c2_type.member == previously_packed_types['s2_type'].member):
    return False

  result = type_spec.result
  if not (isinstance(result, computation_types.NamedTupleType) and
          len(result) == 5):
    return False

  c5_type = result[0]
  accumulator_type = result[1]
  if not (isinstance(c5_type, computation_types.FederatedType) and
          c5_type.placement == placements.CLIENTS and
          type_utils.is_tensorflow_compatible_type(accumulator_type)):
    return False

  # `accumulate` : <A, c5.member> -> A and `merge` : <A, A> -> A, where A is
  # the accumulator type produced by `zero`.
  if not (result[2] == computation_types.FunctionType(
      [accumulator_type, c5_type.member], accumulator_type) and
          result[3] == computation_types.FunctionType(
              [accumulator_type, accumulator_type], accumulator_type)):
    return False

  report_type = result[4]
  return (isinstance(report_type, computation_types.FunctionType) and
          report_type.parameter == accumulator_type and
          type_utils.is_tensorflow_compatible_type(report_type.result))


def check_and_pack_before_aggregate_type_signature(type_spec,
                                                   previously_packed_types):
  """Checks types inferred from `before_aggregate` and packs them.

  After splitting the `after_broadcast` portion of a
  `tff.utils.IterativeProcess` into `before_aggregate` and `after_aggregate`,
  `before_aggregate` should have type signature
  `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`. This function
  validates `s1`, `c1` and `c2` against the existing entries in
  `previously_packed_types`. It then packs `c5`, `zero`, `accumulate`,
  `merge` and `report`, together with the derived types `c2`, `c3`, `c4`,
  `s3` and `work`.

  Args:
    type_spec: The `type_signature` attribute of the `before_aggregate` portion
      of the `tff.utils.IterativeProcess` from which we are looking to extract
      an instance of `canonical_form.CanonicalForm`.
    previously_packed_types: Dict containing the information from `next` and
      `before_broadcast` in the iterative process we are parsing. It is read
      for the keys `s1_type`, `c1_type`, `s2_type` and `c6_type`, and is not
      modified.

  Returns:
    A new `dict` with every entry of `previously_packed_types` plus the types
    inferred from `type_spec`. Newly inferred entries take precedence on key
    collisions.

  Raises:
    TypeError: If `type_spec` does not have the expected form or is
      incompatible with `previously_packed_types`.
  """
  # Validate the whole signature before indexing into it, so malformed input
  # always produces the TypeError below rather than an AttributeError or
  # IndexError part-way through the checks.
  if not _is_valid_before_aggregate_type_signature(type_spec,
                                                   previously_packed_types):
    # TODO(b/121290421): These error messages, and indeed the 'validate and
    # raise once' logic of these methods as well, is intended to be provisional
    # and revisited when we've seen the compilation pipeline fail more clearly,
    # or maybe preferably iteratively improved as new failure modes are
    # encountered.
    raise TypeError(
        'Encountered a type error while checking '
        '`before_aggregate`. Expected a type signature of the '
        'form `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`, '
        'where `s1` matches {}, `c1` matches {}, and `c2` matches '
        'the result of broadcasting {}, as defined in '
        '`canonical_form.CanonicalForm`. Found type signature {}.'.format(
            previously_packed_types['s1_type'],
            previously_packed_types['c1_type'],
            previously_packed_types['s2_type'], type_spec))

  c2_type = type_spec.parameter[1]
  # `c3` pairs the members of `c1` and `c2`, placed at CLIENTS.
  c3_type = computation_types.FederatedType(
      [previously_packed_types['c1_type'].member, c2_type.member],
      placements.CLIENTS)
  c5_type = type_spec.result[0]
  report_type = type_spec.result[4]
  # `c4` pairs the members of `c5` and the previously packed `c6`, at CLIENTS.
  c4_type = computation_types.FederatedType(
      [c5_type.member, previously_packed_types['c6_type'].member],
      placements.CLIENTS)

  newly_determined_types = {
      'c2_type': c2_type,
      'c3_type': c3_type,
      'c5_type': c5_type,
      # `zero` takes no argument and returns the accumulator type.
      'zero_type': computation_types.FunctionType(None, type_spec.result[1]),
      'accumulate_type': type_spec.result[2],
      'merge_type': type_spec.result[3],
      'report_type': report_type,
      # `s3` is the result of `report`, placed at SERVER.
      's3_type': computation_types.FederatedType(report_type.result,
                                                 placements.SERVER),
      'c4_type': c4_type,
      'work_type': computation_types.FunctionType(c3_type.member,
                                                  c4_type.member),
  }
  packed_types = dict(previously_packed_types)
  packed_types.update(newly_determined_types)
  return packed_types