def __init__(self, initialize_fn, next_fn):
        """Validates the computations of a `DistributionProcess`.

        The generic checks are delegated to the base-class constructor (called
        with `next_is_multi_arg=True`); this constructor then enforces the
        federated type signature expected of a distribution process.

        Args:
          initialize_fn: A no-arg `tff.Computation` whose result must be a
            federated value placed at SERVER.
          next_fn: A `tff.Computation` operating only on federated types. It
            must take exactly two input arguments, the second of which is
            placed at SERVER. Its return type must have a `result` attribute
            placed at CLIENTS and a `measurements` attribute placed at SERVER.

        Raises:
          TemplateNotFederatedError: If `initialize_fn` does not return a
            federated type, or any flattened parameter/result type of `next_fn`
            is not federated.
          TemplatePlacementError: If the state, the second input argument of
            `next_fn`, or the `result` / `measurements` attributes of the
            return type of `next_fn` are not placed as described above.
          TemplateNextFnNumArgsError: If the parameter of `next_fn` is not a
            struct of exactly two elements.
        """
        super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

        if not initialize_fn.type_signature.result.is_federated():
            raise errors.TemplateNotFederatedError(
                f'Provided `initialize_fn` must return a federated type, but found '
                f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
                f'see a collection of federated types, try wrapping the returned '
                f'value in `tff.federated_zip` before returning.')
        next_types = (structure.flatten(next_fn.type_signature.parameter) +
                      structure.flatten(next_fn.type_signature.result))
        non_federated_types = [t for t in next_types if not t.is_federated()]
        if non_federated_types:
            # `str.join` accepts only strings; joining the type objects
            # directly would raise a `TypeError` that masks this error.
            offending_types = '\n- '.join(str(t) for t in non_federated_types)
            raise errors.TemplateNotFederatedError(
                f'Provided `next_fn` must be a *federated* computation, that is, '
                f'operate on `tff.FederatedType`s, but found\n'
                f'next_fn with type signature:\n{next_fn.type_signature}\n'
                f'The non-federated types are:\n {offending_types}.')

        if initialize_fn.type_signature.result.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The state controlled by an `DistributionProcess` must be placed at '
                f'the SERVER, but found type: {initialize_fn.type_signature.result}.'
            )
        # Note that state of next_fn being placed at SERVER is now ensured by the
        # assertions in base class which would otherwise raise
        # TemplateStateNotAssignableError.

        next_fn_param = next_fn.type_signature.parameter
        next_fn_result = next_fn.type_signature.result
        if not next_fn_param.is_struct():
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly two input arguments, but found '
                f'the following input type which is not a Struct: {next_fn_param}.'
            )
        if len(next_fn_param) != 2:
            next_param_str = '\n- '.join([str(t) for t in next_fn_param])
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly two input arguments, but found '
                f'{len(next_fn_param)} input arguments:\n{next_param_str}')
        if next_fn_param[1].placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The second input argument of `next_fn` must be placed at SERVER '
                f'but found {next_fn_param[1]}.')

        if next_fn_result.result.placement != placements.CLIENTS:
            raise errors.TemplatePlacementError(
                f'The "result" attribute of return type of `next_fn` must be placed '
                f'at CLIENTS, but found {next_fn_result.result}.')
        if next_fn_result.measurements.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The "measurements" attribute of return type of `next_fn` must be '
                f'placed at SERVER, but found {next_fn_result.measurements}.')
Example #2
0
def _infer_state_type(initialize_result_type, next_parameter_type,
                      next_is_multi_arg):
    """Infers the state type from the `initialize` and `next` types.

    Args:
      initialize_result_type: The result type of `initialize_fn`.
      next_parameter_type: The whole parameter type of `next_fn`.
      next_is_multi_arg: `None` to accept either the whole parameter or (if it
        is a non-empty struct) its first element as the state type, trying the
        whole parameter first; a truthy value to require the first element of
        a non-empty struct parameter; any other falsy value to require the
        whole parameter.

    Returns:
      The whole `next_parameter_type` or its first element, whichever is the
      selected candidate that is assignable from `initialize_result_type`.

    Raises:
      TemplateNextFnNumArgsError: If `next_is_multi_arg` is truthy and
        `next_parameter_type` is not a non-empty struct.
      TemplateStateNotAssignableError: If no candidate state type is
        assignable from `initialize_result_type`.
    """
    if next_is_multi_arg is None:
        # `state_type` may be `next_parameter_type` or
        # `next_parameter_type[0]`, depending on which one was assignable from
        # `initialize_result_type`.
        if next_parameter_type.is_assignable_from(initialize_result_type):
            return next_parameter_type
        if (_is_nonempty_struct(next_parameter_type)
                and next_parameter_type[0].is_assignable_from(
                    initialize_result_type)):
            return next_parameter_type[0]
        raise errors.TemplateStateNotAssignableError(
            'The return type of `initialize_fn` must be assignable to either\n'
            'the whole argument to `next_fn` or the first argument to `next_fn`,\n'
            'but found `initialize_fn` return type:\n'
            f'{initialize_result_type}\n'
            'and `next_fn` with whole argument type:\n'
            f'{next_parameter_type}')
    elif next_is_multi_arg:
        if not _is_nonempty_struct(next_parameter_type):
            raise errors.TemplateNextFnNumArgsError(
                'Expected `next_parameter_type` to be a structure type of at least '
                f'length one, but found type:\n{next_parameter_type}')
        if next_parameter_type[0].is_assignable_from(initialize_result_type):
            return next_parameter_type[0]
        # Report the first argument itself, which is what the message labels
        # and what was actually compared above.
        raise errors.TemplateStateNotAssignableError(
            'The return type of `initialize_fn` must be assignable to the first\n'
            'argument to `next_fn`, but found `initialize_fn` return type:\n'
            f'{initialize_result_type}\n'
            'and `next_fn` whose first argument type is:\n'
            f'{next_parameter_type[0]}')
    else:
        # `next_is_multi_arg` is `False`
        if next_parameter_type.is_assignable_from(initialize_result_type):
            return next_parameter_type
        # The whole argument (not its first element) is what was compared.
        raise errors.TemplateStateNotAssignableError(
            'The return type of `initialize_fn` must be assignable to the whole\n'
            'argument to `next_fn`, but found `initialize_fn` return type:\n'
            f'{initialize_result_type}\n'
            'and `next_fn` whose whole argument type is:\n'
            f'{next_parameter_type}')
  def __init__(self, initialize_fn, next_fn):
    """Validates the computations of a `ClientWorkProcess`.

    The generic checks are delegated to the base-class constructor (called
    with `next_is_multi_arg=True`); this constructor then enforces the
    federated type signature expected of a client work process.

    Args:
      initialize_fn: A no-arg `tff.Computation` whose result must be a
        federated value placed at SERVER.
      next_fn: A `tff.Computation` operating only on federated types. It must
        take exactly three input arguments: the state, client-placed values
        with the `tff.learning.ModelWeights` container, and client-placed
        sequence data. Its return type must have a `result` attribute that is
        a client-placed `ClientResult` whose `update` is assignable to the
        `trainable` attribute of the model weights, and a `measurements`
        attribute placed at SERVER.

    Raises:
      TemplateNotFederatedError: If `initialize_fn` does not return a
        federated type, or any flattened parameter/result type of `next_fn` is
        not federated.
      TemplatePlacementError: If the state, the second or third input argument
        of `next_fn`, or the `result` / `measurements` attributes of its
        return type are not placed as described above.
      TemplateNextFnNumArgsError: If the parameter of `next_fn` is not a
        struct of exactly three elements.
      ModelWeightsTypeError: If the second input argument does not have the
        `tff.learning.ModelWeights` container.
      ClientDataTypeError: If the member of the third input argument is not a
        sequence type.
      ClientResultTypeError: If the `result` attribute does not have the
        `ClientResult` container, or its `update` is not assignable to the
        `trainable` attribute of the model weights.
    """
    super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

    if not initialize_fn.type_signature.result.is_federated():
      raise errors.TemplateNotFederatedError(
          f'Provided `initialize_fn` must return a federated type, but found '
          f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
          f'see a collection of federated types, try wrapping the returned '
          f'value in `tff.federated_zip` before returning.')
    next_types = (
        structure.flatten(next_fn.type_signature.parameter) +
        structure.flatten(next_fn.type_signature.result))
    non_federated_types = [t for t in next_types if not t.is_federated()]
    if non_federated_types:
      # `str.join` accepts only strings; joining the type objects directly
      # would raise a `TypeError` that masks this error.
      offending_types = '\n- '.join(str(t) for t in non_federated_types)
      raise errors.TemplateNotFederatedError(
          f'Provided `next_fn` must be a *federated* computation, that is, '
          f'operate on `tff.FederatedType`s, but found\n'
          f'next_fn with type signature:\n{next_fn.type_signature}\n'
          f'The non-federated types are:\n {offending_types}.')

    if initialize_fn.type_signature.result.placement != placements.SERVER:
      raise errors.TemplatePlacementError(
          f'The state controlled by a `ClientWorkProcess` must be placed at '
          f'the SERVER, but found type: {initialize_fn.type_signature.result}.')
    # Note that state of next_fn being placed at SERVER is now ensured by the
    # assertions in base class which would otherwise raise
    # TemplateStateNotAssignableError.

    next_fn_param = next_fn.type_signature.parameter
    if not next_fn_param.is_struct():
      raise errors.TemplateNextFnNumArgsError(
          f'The `next_fn` must have exactly three input arguments, but found '
          f'the following input type which is not a Struct: {next_fn_param}.')
    if len(next_fn_param) != 3:
      next_param_str = '\n- '.join([str(t) for t in next_fn_param])
      raise errors.TemplateNextFnNumArgsError(
          f'The `next_fn` must have exactly three input arguments, but found '
          f'{len(next_fn_param)} input arguments:\n{next_param_str}')
    model_weights_param = next_fn_param[1]
    client_data_param = next_fn_param[2]
    if model_weights_param.placement != placements.CLIENTS:
      raise errors.TemplatePlacementError(
          f'The second input argument of `next_fn` must be placed at CLIENTS '
          f'but found {model_weights_param}.')
    if (not model_weights_param.member.is_struct_with_python() or
        model_weights_param.member.python_container
        is not model_utils.ModelWeights):
      raise ModelWeightsTypeError(
          f'The second input argument of `next_fn` must have the '
          f'`tff.learning.ModelWeights` container but found '
          f'{model_weights_param}')
    if client_data_param.placement != placements.CLIENTS:
      raise errors.TemplatePlacementError(
          f'The third input argument of `next_fn` must be placed at CLIENTS '
          f'but found {client_data_param}.')
    if not client_data_param.member.is_sequence():
      raise ClientDataTypeError(
          f'The third input argument of `next_fn` must be a sequence but found '
          f'{client_data_param}.')

    next_fn_result = next_fn.type_signature.result
    if (not next_fn_result.result.is_federated() or
        next_fn_result.result.placement != placements.CLIENTS):
      raise errors.TemplatePlacementError(
          f'The "result" attribute of the return type of `next_fn` must be '
          f'placed at CLIENTS, but found {next_fn_result.result}.')
    if (not next_fn_result.result.member.is_struct_with_python() or
        next_fn_result.result.member.python_container is not ClientResult):
      raise ClientResultTypeError(
          f'The "result" attribute of the return type of `next_fn` must have '
          f'the `ClientResult` container, but found {next_fn_result.result}.')
    # The client update must fit into the trainable part of the model weights
    # that were passed in as the second argument.
    if not model_weights_param.member.trainable.is_assignable_from(
        next_fn_result.result.member.update):
      raise ClientResultTypeError(
          f'The "update" attribute of returned `ClientResult` must match '
          f'the "trainable" attribute of the `tff.learning.ModelWeights` '
          f'expected as second input argument of the `next_fn`. Found:\n'
          f'Second input argument: {model_weights_param.member.trainable}\n'
          f'Update attribute of result: {next_fn_result.result.member.update}.')
    if next_fn_result.measurements.placement != placements.SERVER:
      raise errors.TemplatePlacementError(
          f'The "measurements" attribute of return type of `next_fn` must be '
          f'placed at SERVER, but found {next_fn_result.measurements}.')
Example #4
0
    def __init__(self, initialize_fn: computation_base.Computation,
                 next_fn: computation_base.Computation,
                 report_fn: computation_base.Computation):
        """Creates a `LearningProcess`.

        Args:
          initialize_fn: A no-arg `tff.Computation` that creates the initial
            state of the learning process.
          next_fn: A `tff.Computation` that defines an iterated function. Given
            that `initialize_fn` returns a type `S@SERVER`, the `next_fn` must
            return a `LearningProcessOutput` where the `state` attribute
            matches the type `S@SERVER`, and accepts two arguments of types
            `S@SERVER` and `{D*}@CLIENTS`.
          report_fn: A non-federated `tff.Computation` that accepts an input
            `S` where the output of `initialize_fn` is of type `S@SERVER`. This
            computation is used to create a representation of the state that
            can be used for downstream tasks without requiring access to the
            entire server state. For example, `report_fn` could be used to
            extract model weights for computing metrics on held-out data.

        Raises:
          TypeError: If `initialize_fn`, `next_fn` or `report_fn` are not
            instances of `tff.Computation`.
          TemplateInitFnParamNotEmptyError: If `initialize_fn` has any input
            arguments.
          TemplateStateNotAssignableError: If the `state` returned by either
            `initialize_fn` or `next_fn` is not assignable to the first input
            argument of `next_fn`.
          TemplateNextFnNumArgsError: If `next_fn` does not have exactly two
            input arguments.
          LearningProcessPlacementError: If the placements of `initialize_fn`
            and `next_fn` do not match the expected type placements, or if the
            parameter or result of `report_fn` is a federated type.
          LearningProcessOutputError: If `next_fn` does not return a
            `LearningProcessOutput`.
          LearningProcessSequenceTypeError: If the second argument to `next_fn`
            is not a sequence type.
          ReportFnTypeSignatureError: If `report_fn` takes no input, or its
            input type is not assignable from the member type of the output of
            `initialize_fn`.
        """
        super().__init__(initialize_fn, next_fn)

        init_fn_result = initialize_fn.type_signature.result
        if init_fn_result.placement != placements.SERVER:
            raise LearningProcessPlacementError(
                f'The result of `initialize_fn` must be placed at `SERVER` but found '
                f'placement {init_fn_result.placement}.')

        next_result_type = next_fn.type_signature.result
        if not (isinstance(next_result_type,
                           computation_types.StructWithPythonType) and
                next_result_type.python_container is LearningProcessOutput):
            raise LearningProcessOutputError(
                f'The `next_fn` of a `LearningProcess` must return a '
                f'`LearningProcessOutput` object, but returns {next_result_type!r}'
            )

        # We perform a more strict type check on the inputs to `next_fn` than in the
        # base class.
        next_fn_param = next_fn.type_signature.parameter
        if not next_fn_param.is_struct() or len(next_fn_param) != 2:
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have two input arguments, but found an input '
                f'of type {next_fn_param}.')
        if next_fn_param[1].placement != placements.CLIENTS:
            raise LearningProcessPlacementError(
                f'The second input argument of `next_fn` must be placed at `CLIENTS`,'
                f' but found placement {next_fn_param[1].placement}.')
        if not next_fn_param[1].member.is_sequence():
            raise LearningProcessSequenceTypeError(
                f'The member type of the second input argument to `next_fn` must be a'
                f' `tff.SequenceType` but found {next_fn_param[1].member} instead.'
            )

        next_fn_result = next_fn.type_signature.result
        if next_fn_result.metrics.placement != placements.SERVER:
            raise LearningProcessPlacementError(
                f'The result of `next_fn` must be placed at `SERVER` but found '
                f'placement {next_fn_result.metrics.placement} for `metrics`.')

        py_typecheck.check_type(report_fn, computation_base.Computation)

        report_fn_type = report_fn.type_signature
        report_fn_param = report_fn_type.parameter
        # A computation's `type_signature` is a function type (it exposes
        # `parameter` and `result`), so asking the signature itself whether it
        # is federated never fires. The "not federated" constraint is checked
        # on the types it consumes and produces instead. Only the top-level
        # placement of each is inspected.
        if ((report_fn_param is not None and report_fn_param.is_federated()) or
                report_fn_type.result.is_federated()):
            raise LearningProcessPlacementError(
                f'The `report_fn` must not be a federated computation, '
                f'but found `report_fn` with type signature:\n'
                f'{report_fn_type}')

        state_type_without_placement = initialize_fn.type_signature.result.member
        # A no-arg `report_fn` has a `None` parameter and cannot accept state.
        if (report_fn_param is None or
                not report_fn_param.is_assignable_from(
                    state_type_without_placement)):
            raise ReportFnTypeSignatureError(
                f'The input type of `report_fn` must be assignable from '
                f'the member type of the output of `initialize_fn`, but found input '
                f'type {report_fn_param}, which is not assignable from '
                f'{state_type_without_placement}.')

        self._report_fn = report_fn
    def __init__(self, initialize_fn: computation_base.Computation,
                 next_fn: computation_base.Computation):
        """Creates a `tff.templates.AggregationProcess`.

        Args:
          initialize_fn: A no-arg `tff.Computation` producing the initial state
            of the aggregation process. That state must be a federated value
            placed at SERVER; its type is referred to as `S@SERVER` below.
          next_fn: A `tff.Computation` representing one iteration. It accepts
            at least two arguments: first the state of type `S@SERVER`, then
            client-placed data of type `V@CLIENTS`. It returns a
            `MeasuredProcessOutput` whose `state` attribute matches `S@SERVER`,
            whose `result` attribute is placed at SERVER, and whose
            `measurements` attribute is placed at SERVER.

        Raises:
          TypeError: If `initialize_fn` and `next_fn` are not instances of
            `tff.Computation`.
          TemplateInitFnParamNotEmptyError: If `initialize_fn` takes any input
            arguments.
          TemplateStateNotAssignableError: If the `state` returned by either
            `initialize_fn` or `next_fn` is not assignable to the first input
            argument of `next_fn`.
          TemplateNotMeasuredProcessOutputError: If `next_fn` does not return a
            `MeasuredProcessOutput`.
          TemplateNextFnNumArgsError: If `next_fn` takes fewer than two input
            arguments.
          AggregationNotFederatedError: If `initialize_fn` or `next_fn` operate
            on any non-federated type.
          AggregationPlacementError: If the placements of `initialize_fn` and
            `next_fn` do not match the expected type signature.
        """
        # The base-class constructor runs first: once it returns,
        # `next_fn.type_signature.result` is known to be a
        # `MeasuredProcessOutput`, so its attributes can be read directly below.
        super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

        state_type = initialize_fn.type_signature.result
        next_signature = next_fn.type_signature

        if not state_type.is_federated():
            raise AggregationNotFederatedError(
                f'Provided `initialize_fn` must return a federated type, but found '
                f'return type:\n{state_type}\nTip: If you '
                f'see a collection of federated types, try wrapping the returned '
                f'value in `tff.federated_zip` before returning.')

        flat_next_types = (structure.flatten(next_signature.parameter) +
                           structure.flatten(next_signature.result))
        unplaced_types = [
            flat_type for flat_type in flat_next_types
            if not flat_type.is_federated()
        ]
        if unplaced_types:
            unplaced_str = '\n- '.join(map(str, unplaced_types))
            raise AggregationNotFederatedError(
                f'Provided `next_fn` must both be a *federated* computations, that '
                f'is, operate on `tff.FederatedType`s, but found\n'
                f'next_fn with type signature:\n{next_signature}\n'
                f'The non-federated types are:\n {unplaced_str}.')

        if state_type.placement != placements.SERVER:
            raise AggregationPlacementError(
                f'The state controlled by an `AggregationProcess` must be placed at '
                f'the SERVER, but found type: {state_type}.')
        # The base class already guarantees that the state accepted and
        # returned by `next_fn` is placed at SERVER (it would otherwise have
        # raised errors.TemplateStateNotAssignableError).

        params = next_signature.parameter
        if len(params) < 2:
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have at least two input arguments, but found '
                f'the following input type: {params}.')

        client_input_type = params[_INPUT_PARAM_INDEX]
        if client_input_type.placement != placements.CLIENTS:
            raise AggregationPlacementError(
                f'The second input argument of `next_fn` must be placed at CLIENTS '
                f'but found {client_input_type}.')

        # Both the aggregate and the measurements must end up at SERVER; they
        # are checked in this order.
        output_type = next_signature.result
        for attr_name in ('result', 'measurements'):
            attr_type = getattr(output_type, attr_name)
            if attr_type.placement != placements.SERVER:
                raise AggregationPlacementError(
                    f'The "{attr_name}" attribute of return type of `next_fn` '
                    f'must be placed at SERVER, but found {attr_type}.')
Example #6
0
    def __init__(self, initialize_fn: computation_base.Computation,
                 next_fn: computation_base.Computation,
                 get_model_weights: computation_base.Computation,
                 set_model_weights: computation_base.Computation):
        """Creates a `LearningProcess`.

        Args:
          initialize_fn: A no-arg `tff.Computation` that creates the initial
            state of the learning process.
          next_fn: A `tff.Computation` that defines an iterated function. Given
            that `initialize_fn` returns a type `S@SERVER`, the `next_fn` must
            return a `LearningProcessOutput` where the `state` attribute is
            assignable from values with type `S@SERVER`, and accepts two
            arguments with types assignable from values with type `S@SERVER`
            and `{D*}@CLIENTS`. The client data may also be a (nested)
            structure of sequences.
          get_model_weights: A `tff.Computation` that accepts an input `S`
            whose type is equivalent to the member type of the first input
            argument of `next_fn`. This computation is used to create a
            representation of the state that can be used for downstream tasks
            without requiring access to the entire server state. For example,
            `get_model_weights` could be used to extract model weights suitable
            for computing evaluation metrics on held-out data.
          set_model_weights: A `tff.Computation` that accepts two inputs `S`
            and `M`, where the type of `S` is equivalent to the member type of
            the first input argument of `next_fn` and `M` is a representation
            of the model weights stored in `S`. This updates the model weights
            representation within the state with the incoming value and returns
            a new value whose type must be assignable to that member type.

        Raises:
          TypeError: If `initialize_fn`, `next_fn`, `get_model_weights` or
            `set_model_weights` are not instances of `tff.Computation`.
          TemplateInitFnParamNotEmptyError: If `initialize_fn` has any input
            arguments.
          TemplateStateNotAssignableError: If the `state` returned by either
            `initialize_fn` or `next_fn` is not assignable to the first input
            argument of `next_fn`.
          TemplateNextFnNumArgsError: If `next_fn` does not have exactly two
            input arguments.
          LearningProcessPlacementError: If the placements of `initialize_fn`
            and `next_fn` do not match the expected type placements.
          LearningProcessOutputError: If `next_fn` does not return a
            `LearningProcessOutput`.
          LearningProcessSequenceTypeError: If the member of the second
            argument to `next_fn` is not a sequence of TensorFlow-compatible
            elements or a nested structure of such sequences.
          GetModelWeightsTypeSignatureError: If `get_model_weights` takes no
            input, or its input type is not equivalent to the member type of
            the first input argument of `next_fn`.
          SetModelWeightsTypeSignatureError: If `set_model_weights` does not
            take a non-empty structure of inputs, its first input type is not
            equivalent to the member type of the first input argument of
            `next_fn`, or its output is not assignable to that member type.
        """
        super().__init__(initialize_fn, next_fn)

        init_fn_result = initialize_fn.type_signature.result
        if init_fn_result.placement != placements.SERVER:
            raise LearningProcessPlacementError(
                f'The result of `initialize_fn` must be placed at `SERVER` but found '
                f'placement {init_fn_result.placement}.')

        next_result_type = next_fn.type_signature.result
        if not (isinstance(next_result_type,
                           computation_types.StructWithPythonType) and
                next_result_type.python_container is LearningProcessOutput):
            raise LearningProcessOutputError(
                f'The `next_fn` of a `LearningProcess` must return a '
                f'`LearningProcessOutput` object, but returns {next_result_type!r}'
            )
        # We perform a more strict type check on the inputs to `next_fn` than in the
        # base class.
        next_fn_param = next_fn.type_signature.parameter
        if not next_fn_param.is_struct() or len(next_fn_param) != 2:
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have two input arguments, but found an input '
                f'of type {next_fn_param}.')
        if next_fn_param[1].placement != placements.CLIENTS:
            raise LearningProcessPlacementError(
                f'The second input argument of `next_fn` must be placed at `CLIENTS`,'
                f' but found placement {next_fn_param[1].placement}.')

        def is_allowed_client_data_type(
                type_spec: computation_types.Type) -> bool:
            """Returns `True` if the type is a valid client dataset type.

            Valid types are sequences of TensorFlow-compatible elements, and
            structs whose children are all (recursively) valid types.
            """
            if type_spec.is_sequence():
                return type_analysis.is_tensorflow_compatible_type(
                    type_spec.element)
            elif type_spec.is_struct():
                return all(
                    is_allowed_client_data_type(element_type)
                    for element_type in type_spec.children())
            else:
                return False

        if not is_allowed_client_data_type(next_fn_param[1].member):
            raise LearningProcessSequenceTypeError(
                f'The member type of the second input argument to `next_fn` must be a'
                f' `tff.SequenceType` or a nested `tff.StructType` of sequence types '
                f'but found {next_fn_param[1].member} instead.')
        next_fn_result = next_fn.type_signature.result
        if next_fn_result.metrics.placement != placements.SERVER:
            raise LearningProcessPlacementError(
                f'The result of `next_fn` must be placed at `SERVER` but found '
                f'placement {next_fn_result.metrics.placement} for `metrics`.')

        py_typecheck.check_type(get_model_weights,
                                computation_base.Computation)
        get_model_weights_param = get_model_weights.type_signature.parameter
        # Unplaced state type that both model-weights computations operate on.
        next_fn_state_param = next_fn_param[0].member
        # A no-arg `get_model_weights` has a `None` parameter and cannot accept
        # the state.
        if (get_model_weights_param is None or
                not get_model_weights_param.is_equivalent_to(
                    next_fn_state_param)):
            raise GetModelWeightsTypeSignatureError(
                f'The input type of `get_model_weights` must be equivalent to '
                f'the member type of the first input argument of `next_fn`, but '
                f'found input type {get_model_weights_param}, which is not '
                f'equivalent to {next_fn_state_param}.')
        self._get_model_weights = get_model_weights

        py_typecheck.check_type(set_model_weights,
                                computation_base.Computation)
        set_model_weights_type = set_model_weights.type_signature
        set_model_weights_param = set_model_weights_type.parameter
        # The state is read as the first element of the parameter; report a
        # descriptive error rather than failing inside the indexing below.
        if (set_model_weights_param is None or
                not set_model_weights_param.is_struct() or
                len(set_model_weights_param) < 1):
            raise SetModelWeightsTypeSignatureError(
                f'The `set_model_weights` computation must accept two input '
                f'arguments `S` and `M`, but found input type '
                f'{set_model_weights_param}.')
        set_model_weights_state_param = set_model_weights_param[0]
        if not set_model_weights_state_param.is_equivalent_to(
                next_fn_state_param):
            raise SetModelWeightsTypeSignatureError(
                f'The first input type of `set_model_weights` must be equivalent '
                f'to the member type of the first input argument of `next_fn`, '
                f'but found input type {set_model_weights_state_param}, which is '
                f'not equivalent to {next_fn_state_param}.')
        set_model_weights_result = set_model_weights_type.result
        if not next_fn_state_param.is_assignable_from(
                set_model_weights_result):
            raise SetModelWeightsTypeSignatureError(
                f'The output type of `set_model_weights` must be assignable to '
                f'the first parameter of `next_fn`, but found output '
                f'type {set_model_weights_result}, which is not assignable to '
                f'{next_fn_state_param}.')
        self._set_model_weights = set_model_weights
Example #7
0
    def __init__(self, initialize_fn, next_fn):
        """Validates the computations of a `ClientWorkProcess`.

        The generic checks are delegated to the base-class constructor (called
        with `next_is_multi_arg=True`); this constructor then enforces the
        federated type signature expected of a client work process.

        Args:
          initialize_fn: A no-arg `tff.Computation` whose result must be a
            federated value placed at SERVER.
          next_fn: A `tff.Computation` operating only on federated types. It
            must take exactly three input arguments: the state, a client-placed
            second argument, and client-placed data that is a sequence of
            TensorFlow-compatible elements or a nested structure of such
            sequences. Its return type must have a `result` attribute that is
            a client-placed `ClientResult` and a `measurements` attribute
            placed at SERVER.

        Raises:
          TemplateNotFederatedError: If `initialize_fn` does not return a
            federated type, or any flattened parameter/result type of `next_fn`
            is not federated.
          TemplatePlacementError: If the state, the second or third input
            argument of `next_fn`, or the `result` / `measurements` attributes
            of its return type are not placed as described above.
          TemplateNextFnNumArgsError: If the parameter of `next_fn` is not a
            struct of exactly three elements.
          ClientDataTypeError: If the member of the third input argument is not
            a valid client data type as described above.
          ClientResultTypeError: If the `result` attribute does not have the
            `ClientResult` container.
        """
        super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

        if not initialize_fn.type_signature.result.is_federated():
            raise errors.TemplateNotFederatedError(
                f'Provided `initialize_fn` must return a federated type, but found '
                f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
                f'see a collection of federated types, try wrapping the returned '
                f'value in `tff.federated_zip` before returning.')
        next_types = (structure.flatten(next_fn.type_signature.parameter) +
                      structure.flatten(next_fn.type_signature.result))
        non_federated_types = [t for t in next_types if not t.is_federated()]
        if non_federated_types:
            # `str.join` accepts only strings; joining the type objects
            # directly would raise a `TypeError` that masks this error.
            offending_types = '\n- '.join(str(t) for t in non_federated_types)
            raise errors.TemplateNotFederatedError(
                f'Provided `next_fn` must be a *federated* computation, that is, '
                f'operate on `tff.FederatedType`s, but found\n'
                f'next_fn with type signature:\n{next_fn.type_signature}\n'
                f'The non-federated types are:\n {offending_types}.')

        if initialize_fn.type_signature.result.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The state controlled by a `ClientWorkProcess` must be placed at '
                f'the SERVER, but found type: {initialize_fn.type_signature.result}.'
            )
        # Note that state of next_fn being placed at SERVER is now ensured by the
        # assertions in base class which would otherwise raise
        # TemplateStateNotAssignableError.

        next_fn_param = next_fn.type_signature.parameter
        if not next_fn_param.is_struct():
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly three input arguments, but found '
                f'the following input type which is not a Struct: {next_fn_param}.'
            )
        if len(next_fn_param) != 3:
            next_param_str = '\n- '.join([str(t) for t in next_fn_param])
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly three input arguments, but found '
                f'{len(next_fn_param)} input arguments:\n{next_param_str}')
        second_next_param = next_fn_param[1]
        client_data_param = next_fn_param[2]
        if second_next_param.placement != placements.CLIENTS:
            raise errors.TemplatePlacementError(
                f'The second input argument of `next_fn` must be placed at CLIENTS '
                f'but found {second_next_param}.')
        if client_data_param.placement != placements.CLIENTS:
            raise errors.TemplatePlacementError(
                f'The third input argument of `next_fn` must be placed at CLIENTS '
                f'but found {client_data_param}.')

        def is_allowed_client_data_type(
                type_spec: computation_types.Type) -> bool:
            """Returns `True` if the type is a valid client dataset type.

            Valid types are sequences of TensorFlow-compatible elements, and
            structs whose children are all (recursively) valid types.
            """
            if type_spec.is_sequence():
                return type_analysis.is_tensorflow_compatible_type(
                    type_spec.element)
            elif type_spec.is_struct():
                return all(
                    is_allowed_client_data_type(element_type)
                    for element_type in type_spec.children())
            else:
                return False

        if not is_allowed_client_data_type(client_data_param.member):
            raise ClientDataTypeError(
                f'The third input argument of `next_fn` must be a sequence or '
                f'a structure of sequences, but found {client_data_param}.')

        next_fn_result = next_fn.type_signature.result
        if (not next_fn_result.result.is_federated()
                or next_fn_result.result.placement != placements.CLIENTS):
            raise errors.TemplatePlacementError(
                f'The "result" attribute of the return type of `next_fn` must be '
                f'placed at CLIENTS, but found {next_fn_result.result}.')
        if (not next_fn_result.result.member.is_struct_with_python()
                or next_fn_result.result.member.python_container
                is not ClientResult):
            raise ClientResultTypeError(
                f'The "result" attribute of the return type of `next_fn` must have '
                f'the `ClientResult` container, but found {next_fn_result.result}.'
            )
        if next_fn_result.measurements.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The "measurements" attribute of return type of `next_fn` must be '
                f'placed at SERVER, but found {next_fn_result.measurements}.')