コード例 #1
0
def apply(transform_fn: computation_base.Computation,
          arg_process: EstimationProcess):
    """Builds an `EstimationProcess` by applying `transform_fn` to `arg_process`.

    The returned process reuses `arg_process.initialize` and
    `arg_process.next` unchanged; only the estimate computation is replaced by
    one that passes the output of `arg_process.get_estimate` to
    `transform_fn`.

    Args:
      transform_fn: A `computation_base.Computation` to apply to the estimate of
        the arg_process. Its parameter type must be assignable from the result
        type of `arg_process.get_estimate`.
      arg_process: An `EstimationProcess` to which the transformation will be
        applied.

    Returns:
      An estimation process that applies `transform_fn` to the result of calling
      `arg_process.get_estimate`.

    Raises:
      TypeError: If `transform_fn` is not a `computation_base.Computation` or
        `arg_process` is not an `EstimationProcess`.
      TemplateStateNotAssignableError: If `transform_fn` takes no argument, or
        its parameter type is not assignable from the return type of
        `arg_process.get_estimate`.
    """
    py_typecheck.check_type(transform_fn, computation_base.Computation)
    py_typecheck.check_type(arg_process, EstimationProcess)

    arg_process_estimate_type = arg_process.get_estimate.type_signature.result
    transform_fn_arg_type = transform_fn.type_signature.parameter

    # A no-arg `transform_fn` has a `None` parameter type. Check for it
    # explicitly so callers get the template error rather than an
    # AttributeError from `None.is_assignable_from`.
    if (transform_fn_arg_type is None or
            not transform_fn_arg_type.is_assignable_from(
                arg_process_estimate_type)):
        raise errors.TemplateStateNotAssignableError(
            f'The return type of `get_estimate` of `arg_process` must be '
            f'assignable to the input argument of `transform_fn`, but '
            f'`get_estimate` returns type:\n{arg_process_estimate_type}\n'
            f'and the argument of `transform_fn` is:\n'
            f'{transform_fn_arg_type}')

    # The new estimate computation takes the process state (typed by
    # `arg_process.state_type`) and chains the two computations.
    transformed_estimate_fn = computations.tf_computation(
        lambda state: transform_fn(arg_process.get_estimate(state)),
        arg_process.state_type)

    return EstimationProcess(initialize_fn=arg_process.initialize,
                             next_fn=arg_process.next,
                             get_estimate_fn=transformed_estimate_fn)
コード例 #2
0
    def __init__(self,
                 initialize_fn: computation_base.Computation,
                 next_fn: computation_base.Computation,
                 next_is_multi_arg: Optional[bool] = None):
        """Creates a `tff.templates.IterativeProcess`.

        The state type `S` is inferred from the result of `initialize_fn` and
        the parameter of `next_fn`. `next_fn` must then return either `S`
        itself or a non-empty structure whose first element is `S`.

        Args:
          initialize_fn: A no-arg `tff.Computation` producing the initial state
            `S` of the iterative process.
          next_fn: A `tff.Computation` for the iterated function. Its whole (or
            first) argument must match `S`, and so must its whole (or first)
            return value.
          next_is_multi_arg: Optional flag: `True` if `next_fn` receives
            arguments beyond the state, `False` if the state is its only
            argument. Mainly used to produce better error messages.

        Raises:
          TypeError: If `initialize_fn` or `next_fn` is not a `tff.Computation`.
          TemplateInitFnParamNotEmptyError: If `initialize_fn` takes any input
            arguments.
          TemplateStateNotAssignableError: If the state returned by
            `initialize_fn` or `next_fn` is not assignable to the first input
            argument of `next_fn`.
        """
        py_typecheck.check_type(initialize_fn, computation_base.Computation)
        init_signature = initialize_fn.type_signature
        if init_signature.parameter is not None:
            raise errors.TemplateInitFnParamNotEmptyError(
                f'Provided `initialize_fn` must be a no-arg function, but found '
                f'input argument(s) {init_signature.parameter}.')
        init_result = init_signature.result

        py_typecheck.check_type(next_fn, computation_base.Computation)
        next_signature = next_fn.type_signature
        state_type = _infer_state_type(init_result, next_signature.parameter,
                                       next_is_multi_arg)

        result_type = next_signature.result
        # `next_fn` may return the state outright, or a non-empty structure
        # whose first element is the state.
        whole_result_is_state = state_type.is_assignable_from(result_type)
        first_result_is_state = (
            not whole_result_is_state and _is_nonempty_struct(result_type) and
            state_type.is_assignable_from(result_type[0]))
        if not (whole_result_is_state or first_result_is_state):
            raise errors.TemplateStateNotAssignableError(
                f'The first return argument of `next_fn` must be '
                f'assignable to its first input argument, but found\n'
                f'`next_fn` which returns type:\n{result_type}\n'
                f'which does not match its first input argument:\n{state_type}'
            )

        self._state_type = state_type
        self._initialize_fn = initialize_fn
        self._next_fn = next_fn
コード例 #3
0
def _infer_state_type(initialize_result_type, next_parameter_type,
                      next_is_multi_arg):
    """Infers the state type from the `initialize` and `next` types.

    Args:
      initialize_result_type: The result type of `initialize_fn`.
      next_parameter_type: The parameter type of `next_fn`, or `None` if
        `next_fn` takes no arguments.
      next_is_multi_arg: `None` to accept the state as either the whole argument
        of `next_fn` or its first argument (the whole argument is tried first);
        `True` to require the state to be the first element of a non-empty
        structure argument; `False` to require it to be the whole argument.

    Returns:
      Either `next_parameter_type` or `next_parameter_type[0]`, whichever the
      mode selected by `next_is_multi_arg` accepts `initialize_result_type` as.

    Raises:
      TemplateNextFnNumArgsError: If `next_is_multi_arg` is `True` and
        `next_parameter_type` is not a non-empty structure (including when
        `next_fn` takes no arguments).
      TemplateStateNotAssignableError: If `initialize_result_type` is not
        assignable to the required part of `next_parameter_type`, or if
        `next_fn` takes no arguments while `next_is_multi_arg` is not `True`.
    """
    # A no-arg `next_fn` has a `None` parameter type. Every branch below
    # checks for it before calling methods on `next_parameter_type`, so the
    # caller gets a template error instead of an AttributeError.
    has_parameter = next_parameter_type is not None

    if next_is_multi_arg is None:
        # `state_type` may be `next_parameter_type` or
        # `next_parameter_type[0]`, depending on which one was assignable from
        # `initialize_result_type`.
        if has_parameter:
            if next_parameter_type.is_assignable_from(initialize_result_type):
                return next_parameter_type
            if (_is_nonempty_struct(next_parameter_type)
                    and next_parameter_type[0].is_assignable_from(
                        initialize_result_type)):
                return next_parameter_type[0]
        raise errors.TemplateStateNotAssignableError(
            'The return type of `initialize_fn` must be assignable to either\n'
            'the whole argument to `next_fn` or the first argument to `next_fn`,\n'
            'but found `initialize_fn` return type:\n'
            f'{initialize_result_type}\n'
            'and `next_fn` with whole argument type:\n'
            f'{next_parameter_type}')
    elif next_is_multi_arg:
        if not has_parameter or not _is_nonempty_struct(next_parameter_type):
            raise errors.TemplateNextFnNumArgsError(
                'Expected `next_parameter_type` to be a structure type of at least '
                f'length one, but found type:\n{next_parameter_type}')
        if next_parameter_type[0].is_assignable_from(initialize_result_type):
            return next_parameter_type[0]
        # The message labels the first argument, so print the first argument.
        raise errors.TemplateStateNotAssignableError(
            'The return type of `initialize_fn` must be assignable to the first\n'
            'argument to `next_fn`, but found `initialize_fn` return type:\n'
            f'{initialize_result_type}\n'
            'and `next_fn` whose first argument type is:\n'
            f'{next_parameter_type[0]}')
    else:
        # `next_is_multi_arg` is `False`: the state must be the whole argument.
        if (has_parameter and
                next_parameter_type.is_assignable_from(initialize_result_type)):
            return next_parameter_type
        raise errors.TemplateStateNotAssignableError(
            'The return type of `initialize_fn` must be assignable to the whole\n'
            'argument to `next_fn`, but found `initialize_fn` return type:\n'
            f'{initialize_result_type}\n'
            'and `next_fn` with whole argument type:\n'
            f'{next_parameter_type}')
コード例 #4
0
    def __init__(self, initialize_fn: computation_base.Computation,
                 next_fn: computation_base.Computation):
        """Creates a `tff.templates.MeasuredProcess`.

        Args:
          initialize_fn: A no-arg `tff.Computation` that creates the initial
            state of the measured process.
          next_fn: A `tff.Computation` that defines an iterated function. If
            `initialize_fn` returns a non-federated type `S`, then `next_fn`
            must return a `MeasuredProcessOutput` where the `state` attribute
            matches the non-federated type `S`, and accept either a single
            argument of non-federated type `S` or multiple arguments where the
            first argument must be of non-federated type `S`.

        Raises:
          TypeError: If `initialize_fn` and `next_fn` are not instances of
            `tff.Computation`.
          TemplateInitFnParamNotEmptyError: If `initialize_fn` has any input
            arguments.
          TemplateStateNotAssignableError: If the `state` returned by either
            `initialize_fn` or `next_fn` (for the latter, its `state`
            attribute) is not assignable to the first input argument of
            `next_fn`.
          TemplateNotMeasuredProcessOutputError: If `next_fn` does not return a
            `MeasuredProcessOutput`.
        """
        super().__init__(initialize_fn, next_fn)
        next_result_type = next_fn.type_signature.result
        if not (isinstance(next_result_type,
                           computation_types.StructWithPythonType) and
                next_result_type.python_container is MeasuredProcessOutput):
            raise errors.TemplateNotMeasuredProcessOutputError(
                f'The `next_fn` of a `MeasuredProcess` must return a '
                f'`MeasuredProcessOutput` object, but returns {next_result_type!r}'
            )

        # Perform a more strict type check on state than the base class. Base class
        # ensures that state returned by initialize_fn is accepted as input argument
        # of next_fn, and that this is in the returned structure. For
        # MeasuredProcess, this explicitly needs to be in the state attribute. See
        # `test_measured_process_output_as_state_raises` for an example.
        # The state type is re-derived with the same rule the base class uses:
        # the whole parameter if it accepts the initial state, else its first
        # element.
        next_parameter_type = next_fn.type_signature.parameter
        if next_parameter_type.is_assignable_from(
                initialize_fn.type_signature.result):
            state_type = next_parameter_type
        else:
            state_type = next_parameter_type[0]
        next_state_type = next_result_type.state
        if not state_type.is_assignable_from(next_state_type):
            # Report the `state` attribute's type, not the whole output type.
            raise errors.TemplateStateNotAssignableError(
                f'The state attribute of returned MeasuredProcessOutput must be '
                f'assignable to its first input argument, but found\n'
                f'`next_fn` which returns MeasuredProcessOutput with state attribute '
                f'of type:\n{next_state_type}\n'
                f'which does not match its first input argument:\n{state_type}'
            )
コード例 #5
0
    def __init__(self,
                 initialize_fn: computation_base.Computation,
                 next_fn: computation_base.Computation,
                 next_is_multi_arg: Optional[bool] = None):
        """Creates a `tff.templates.MeasuredProcess`.

        Args:
          initialize_fn: A no-arg `tff.Computation` that returns the initial
            state of the measured process. Let the type of this state be called
            `S`.
          next_fn: A `tff.Computation` that represents the iterated function.
            The first or only argument must match the state type `S`. The
            return value must be a `MeasuredProcessOutput` whose `state` member
            matches the state type `S`.
          next_is_multi_arg: An optional boolean indicating that `next_fn` will
            receive more than just the state argument (if `True`) or only the
            state argument (if `False`). This parameter is primarily used to
            provide better error messages.

        Raises:
          TypeError: If `initialize_fn` and `next_fn` are not instances of
            `tff.Computation`.
          TemplateInitFnParamNotEmptyError: If `initialize_fn` has any input
            arguments.
          TemplateStateNotAssignableError: If the `state` returned by either
            `initialize_fn` or `next_fn` (for the latter, its `state`
            attribute) is not assignable to the first input argument of
            `next_fn`.
          TemplateNotMeasuredProcessOutputError: If `next_fn` does not return a
            `MeasuredProcessOutput`.
        """
        super().__init__(initialize_fn, next_fn, next_is_multi_arg)
        next_result_type = next_fn.type_signature.result
        if not (isinstance(next_result_type,
                           computation_types.StructWithPythonType) and
                next_result_type.python_container is MeasuredProcessOutput):
            raise errors.TemplateNotMeasuredProcessOutputError(
                f'The `next_fn` of a `MeasuredProcess` must return a '
                f'`MeasuredProcessOutput` object, but returns {next_result_type!r}'
            )

        # Perform a more strict type check on state than the base class. Base class
        # ensures that state returned by initialize_fn is accepted as input argument
        # of next_fn, and that this is in the returned structure. For
        # MeasuredProcess, this explicitly needs to be in the state attribute. See
        # `test_measured_process_output_as_state_raises` for an example.
        state_type = self.state_type
        next_state_type = next_result_type.state
        if not state_type.is_assignable_from(next_state_type):
            # Report the `state` attribute's type, not the whole output type.
            raise errors.TemplateStateNotAssignableError(
                f'The state attribute of returned MeasuredProcessOutput must be '
                f'assignable to its first input argument, but found\n'
                f'`next_fn` which returns MeasuredProcessOutput with state attribute '
                f'of type:\n{next_state_type}\n'
                f'which does not match its first input argument:\n{state_type}'
            )
コード例 #6
0
    def __init__(self, initialize_fn: computation_base.Computation,
                 next_fn: computation_base.Computation,
                 get_estimate_fn: computation_base.Computation):
        """Creates a `tff.templates.EstimationProcess`.

        Args:
          initialize_fn: A no-arg `tff.Computation` that returns the initial
            state of the process. Passed unchanged to the base class
            initializer.
          next_fn: A `tff.Computation` that represents the iterated function.
            Passed unchanged to the base class initializer.
          get_estimate_fn: A `tff.Computation` that computes the estimate from
            the process state. Its parameter type must be assignable from the
            state type of the process.

        Raises:
          TypeError: If `get_estimate_fn` is not a `tff.Computation`.
          TemplateStateNotAssignableError: If `get_estimate_fn` takes no
            argument, or its parameter type is not assignable from the state
            type of the process.
          Any error raised by the base class initializer for `initialize_fn`
          and `next_fn`.
        """
        super().__init__(initialize_fn, next_fn)

        py_typecheck.check_type(get_estimate_fn, computation_base.Computation)
        estimate_fn_arg_type = get_estimate_fn.type_signature.parameter
        # A no-arg `get_estimate_fn` has a `None` parameter type; reject it
        # here instead of failing with an AttributeError on `None`.
        if (estimate_fn_arg_type is None or
                not estimate_fn_arg_type.is_assignable_from(self.state_type)):
            raise errors.TemplateStateNotAssignableError(
                f'The state type of the process must be assignable to the '
                f'input argument of `get_estimate_fn`, but the state type is: '
                f'{self.state_type}\n'
                f'and the argument of `get_estimate_fn` is:\n'
                f'{estimate_fn_arg_type}')

        self._get_estimate_fn = get_estimate_fn
コード例 #7
0
    def __init__(self,
                 initialize_fn: computation_base.Computation,
                 next_fn: computation_base.Computation,
                 report_fn: computation_base.Computation,
                 next_is_multi_arg: Optional[bool] = None):
        """Creates a `tff.templates.EstimationProcess`.

        Args:
          initialize_fn: A no-arg `tff.Computation` that returns the initial
            state of the estimation process. Let the type of this state be
            called `S`.
          next_fn: A `tff.Computation` that represents the iterated function.
            The first or only argument must match the state type `S`. The first
            or only return value must also match state type `S`.
          report_fn: A `tff.Computation` that represents the estimation based on
            state. Its input argument must match the state type `S`.
          next_is_multi_arg: An optional boolean indicating that `next_fn` will
            receive more than just the state argument (if `True`) or only the
            state argument (if `False`). This parameter is primarily used to
            provide better error messages.

        Raises:
          TypeError: If `initialize_fn`, `next_fn` and `report_fn` are not
            instances of `tff.Computation`.
          TemplateInitFnParamNotEmptyError: If `initialize_fn` has any input
            arguments.
          TemplateStateNotAssignableError: If the `state` returned by either
            `initialize_fn` or `next_fn` is not assignable to the first input
            argument of `next_fn` and `report_fn`, or if `report_fn` takes no
            argument.
        """
        super().__init__(initialize_fn,
                         next_fn,
                         next_is_multi_arg=next_is_multi_arg)

        py_typecheck.check_type(report_fn, computation_base.Computation)
        report_fn_arg_type = report_fn.type_signature.parameter
        # A no-arg `report_fn` has a `None` parameter type; reject it here
        # instead of failing with an AttributeError on `None`.
        if (report_fn_arg_type is None or
                not report_fn_arg_type.is_assignable_from(self.state_type)):
            raise errors.TemplateStateNotAssignableError(
                f'The state type of the process must be assignable to the '
                f'input argument of `report_fn`, but the state type is: '
                f'{self.state_type}\n'
                f'and the argument of `report_fn` is:\n'
                f'{report_fn_arg_type}')

        self._report_fn = report_fn
コード例 #8
0
    def __init__(self, initialize_fn: computation_base.Computation,
                 next_fn: computation_base.Computation,
                 report_fn: computation_base.Computation):
        """Creates a `tff.templates.EstimationProcess`.

        Args:
          initialize_fn: A no-arg `tff.Computation` that creates the initial
            state of the computation.
          next_fn: A `tff.Computation` that represents the iterated function. If
            `initialize_fn` returns a type `T`, then `next_fn` must either
            return a type `U` which is compatible with `T` or multiple values
            where the first type is `U`, and accept either a single argument of
            type `U` or multiple arguments where the first argument must be of
            type `U`.
          report_fn: A `tff.Computation` that represents the estimation based on
            state. Its input argument must be assignable from the state type of
            the process (the state type inferred from `initialize_fn` and
            `next_fn`).

        Raises:
          TypeError: If `initialize_fn`, `next_fn` and `report_fn` are not
            instances of `tff.Computation`.
          TemplateInitFnParamNotEmptyError: If `initialize_fn` has any input
            arguments.
          TemplateStateNotAssignableError: If the `state` returned by either
            `initialize_fn` or `next_fn` is not assignable to the first input
            argument of `next_fn` and `report_fn`, or if `report_fn` takes no
            argument.
        """
        super().__init__(initialize_fn, next_fn)

        py_typecheck.check_type(report_fn, computation_base.Computation)
        report_fn_arg_type = report_fn.type_signature.parameter
        # A no-arg `report_fn` has a `None` parameter type; reject it here
        # instead of failing with an AttributeError on `None`.
        if (report_fn_arg_type is None or
                not report_fn_arg_type.is_assignable_from(self.state_type)):
            raise errors.TemplateStateNotAssignableError(
                f'The state type of the process must be assignable to the '
                f'input argument of `report_fn`, but the state type is: '
                f'{self.state_type}\n'
                f'and the argument of `report_fn` is:\n'
                f'{report_fn_arg_type}')

        self._report_fn = report_fn
コード例 #9
0
  def __init__(self, initialize_fn: computation_base.Computation,
               next_fn: computation_base.Computation):
    """Creates a `tff.templates.IterativeProcess`.

    Args:
      initialize_fn: A no-arg `tff.Computation` that creates the initial state
        of the computation.
      next_fn: A `tff.Computation` that represents the iterated function. If
        `initialize_fn` returns a type `T`, then `next_fn` must either return a
        type `U` which is compatible with `T` or multiple values where the first
        type is `U`, and accept either a single argument of type `U` or multiple
        arguments where the first argument must be of type `U`.

    Raises:
      TypeError: If `initialize_fn` and `next_fn` are not instances of
        `tff.Computation`.
      TemplateInitFnParamNotEmptyError: If `initialize_fn` has any input
        arguments.
      TemplateStateNotAssignableError: If the `state` returned by either
        `initialize_fn` or `next_fn` is not assignable to the first input
        argument of `next_fn`, or if `next_fn` takes no input arguments.
    """
    py_typecheck.check_type(initialize_fn, computation_base.Computation)
    if initialize_fn.type_signature.parameter is not None:
      raise errors.TemplateInitFnParamNotEmptyError(
          f'Provided `initialize_fn` must be a no-arg function, but found '
          f'input argument(s) {initialize_fn.type_signature.parameter}.')
    initialize_result_type = initialize_fn.type_signature.result

    py_typecheck.check_type(next_fn, computation_base.Computation)
    next_parameter_type = next_fn.type_signature.parameter
    # A no-arg `next_fn` has a `None` parameter type and cannot accept the
    # state; reject it here instead of failing with an AttributeError below.
    if next_parameter_type is None:
      raise errors.TemplateStateNotAssignableError(
          f'The return type of `initialize_fn` must be assignable to '
          f'the first input argument of `next_fn`, but:\n'
          f'`initialize_fn` returned type:\n{initialize_result_type}\n'
          f'and `next_fn` takes no input arguments.')
    # `state_type` may be `next_parameter_type` or `next_parameter_type[0]`,
    # depending on which one is assignable from `initialize_result_type`; the
    # whole parameter is tried first.
    if next_parameter_type.is_assignable_from(initialize_result_type):
      # The only argument is the state type
      state_type = next_parameter_type
    elif (next_parameter_type.is_struct() and next_parameter_type and
          next_parameter_type[0].is_assignable_from(initialize_result_type)):
      # The first argument is the state type
      state_type = next_parameter_type[0]
    else:
      raise errors.TemplateStateNotAssignableError(
          f'The return type of `initialize_fn` must be assignable to '
          f'the first input argument of `next_fn`, but:\n'
          f'`initialize_fn` returned type:\n{initialize_result_type}\n'
          f'and the first input argument of `next_fn` is:\n'
          f'{next_parameter_type}')

    next_result_type = next_fn.type_signature.result
    if state_type.is_assignable_from(next_result_type):
      # The whole return value is the state type
      pass
    elif (next_result_type.is_struct() and next_result_type and
          state_type.is_assignable_from(next_result_type[0])):
      # The first return value is state type
      pass
    else:
      raise errors.TemplateStateNotAssignableError(
          f'The first return argument of `next_fn` must be '
          f'assignable to its first input argument, but found\n'
          f'`next_fn` which returns type:\n{next_result_type}\n'
          f'which does not match its first input argument:\n{state_type}')

    self._state_type = state_type
    self._initialize_fn = initialize_fn
    self._next_fn = next_fn