Пример #1
0
    def __init__(self, initialize_fn, next_fn):
        """Constructs the process and validates its federated type signatures.

        The base class constructor runs first (with `next_is_multi_arg=True`).
        This constructor then additionally verifies that:

          * `initialize_fn` returns a single federated value placed at SERVER.
          * Every leaf of the flattened parameter and result types of
            `next_fn` is a federated type.
          * `next_fn` takes exactly two arguments, and the second one is a
            federated value placed at SERVER.
          * The `result` attribute of the `next_fn` return type is a federated
            value placed at CLIENTS, and its `measurements` attribute is a
            federated value placed at SERVER.

        Args:
          initialize_fn: A computation whose result type is checked as above.
          next_fn: A computation whose parameter and result types are checked
            as above.

        Raises:
          errors.TemplateNotFederatedError: If the `initialize_fn` result, or
            any flattened parameter/result type of `next_fn`, is not federated.
          errors.TemplatePlacementError: If one of the values listed above is
            not a federated value at the required placement.
          errors.TemplateNextFnNumArgsError: If `next_fn` does not take exactly
            two input arguments.
        """
        super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

        init_result_type = initialize_fn.type_signature.result
        if not init_result_type.is_federated():
            raise errors.TemplateNotFederatedError(
                f'Provided `initialize_fn` must return a federated type, but found '
                f'return type:\n{init_result_type}\nTip: If you '
                f'see a collection of federated types, try wrapping the returned '
                f'value in `tff.federated_zip` before returning.')

        next_types = (structure.flatten(next_fn.type_signature.parameter) +
                      structure.flatten(next_fn.type_signature.result))
        non_federated_types = [t for t in next_types if not t.is_federated()]
        if non_federated_types:
            # `str.join` accepts only strings, so render each type explicitly.
            # Joining the type objects directly raises a TypeError that would
            # mask the intended error below.
            offending_types = '\n- '.join(str(t) for t in non_federated_types)
            raise errors.TemplateNotFederatedError(
                f'Provided `next_fn` must be a *federated* computation, that is, '
                f'operate on `tff.FederatedType`s, but found\n'
                f'next_fn with type signature:\n{next_fn.type_signature}\n'
                f'The non-federated types are:\n- {offending_types}.')

        if init_result_type.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The state controlled by a `DistributionProcess` must be placed at '
                f'the SERVER, but found type: {init_result_type}.'
            )
        # Note that state of next_fn being placed at SERVER is now ensured by the
        # assertions in base class which would otherwise raise
        # TemplateStateNotAssignableError.

        next_fn_param = next_fn.type_signature.parameter
        next_fn_result = next_fn.type_signature.result
        if not next_fn_param.is_struct():
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly two input arguments, but found '
                f'the following input type which is not a Struct: {next_fn_param}.'
            )
        if len(next_fn_param) != 2:
            next_param_str = '\n- '.join(str(t) for t in next_fn_param)
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly two input arguments, but found '
                f'{len(next_fn_param)} input arguments:\n{next_param_str}')

        # The flattening check above only guarantees that the *leaves* are
        # federated. A struct of federated values is not itself federated and
        # has no single placement, so check `is_federated()` before reading
        # `placement`.
        second_param = next_fn_param[1]
        if (not second_param.is_federated()
                or second_param.placement != placements.SERVER):
            raise errors.TemplatePlacementError(
                f'The second input argument of `next_fn` must be placed at SERVER '
                f'but found {second_param}.')

        result_type = next_fn_result.result
        if (not result_type.is_federated()
                or result_type.placement != placements.CLIENTS):
            raise errors.TemplatePlacementError(
                f'The "result" attribute of return type of `next_fn` must be placed '
                f'at CLIENTS, but found {result_type}.')

        measurements_type = next_fn_result.measurements
        if (not measurements_type.is_federated()
                or measurements_type.placement != placements.SERVER):
            raise errors.TemplatePlacementError(
                f'The "measurements" attribute of return type of `next_fn` must be '
                f'placed at SERVER, but found {measurements_type}.')
Пример #2
0
  def __init__(self, initialize_fn, next_fn):
    """Constructs the client work process and validates its type signatures.

    The base class constructor runs first (with `next_is_multi_arg=True`).
    This constructor then additionally verifies that:

      * `initialize_fn` returns a single federated value placed at SERVER.
      * Every leaf of the flattened parameter and result types of `next_fn` is
        a federated type.
      * `next_fn` takes exactly three arguments:
          - the second is federated at CLIENTS with a
            `tff.learning.ModelWeights` (`model_utils.ModelWeights`) container;
          - the third is federated at CLIENTS and its member is a sequence.
      * The `result` attribute of the `next_fn` return type is federated at
        CLIENTS, has the `ClientResult` container, and its `update` type is
        assignable to the `trainable` type of the second argument.
      * The `measurements` attribute of the return type is federated at
        SERVER.

    Args:
      initialize_fn: A computation whose result type is checked as above.
      next_fn: A computation whose parameter and result types are checked as
        above.

    Raises:
      errors.TemplateNotFederatedError: If the `initialize_fn` result, or any
        flattened parameter/result type of `next_fn`, is not federated.
      errors.TemplatePlacementError: If one of the values listed above is not a
        federated value at the required placement.
      errors.TemplateNextFnNumArgsError: If `next_fn` does not take exactly
        three input arguments.
      ModelWeightsTypeError: If the second argument lacks the `ModelWeights`
        container.
      ClientDataTypeError: If the third argument's member is not a sequence.
      ClientResultTypeError: If the result lacks the `ClientResult` container
        or its `update` does not match the model weights' `trainable` type.
    """
    super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

    init_result_type = initialize_fn.type_signature.result
    if not init_result_type.is_federated():
      raise errors.TemplateNotFederatedError(
          f'Provided `initialize_fn` must return a federated type, but found '
          f'return type:\n{init_result_type}\nTip: If you '
          f'see a collection of federated types, try wrapping the returned '
          f'value in `tff.federated_zip` before returning.')

    next_types = (
        structure.flatten(next_fn.type_signature.parameter) +
        structure.flatten(next_fn.type_signature.result))
    non_federated_types = [t for t in next_types if not t.is_federated()]
    if non_federated_types:
      # `str.join` accepts only strings, so render each type explicitly.
      # Joining the type objects directly raises a TypeError that would mask
      # the intended error below.
      offending_types = '\n- '.join(str(t) for t in non_federated_types)
      raise errors.TemplateNotFederatedError(
          f'Provided `next_fn` must be a *federated* computation, that is, '
          f'operate on `tff.FederatedType`s, but found\n'
          f'next_fn with type signature:\n{next_fn.type_signature}\n'
          f'The non-federated types are:\n- {offending_types}.')

    if init_result_type.placement != placements.SERVER:
      raise errors.TemplatePlacementError(
          f'The state controlled by a `ClientWorkProcess` must be placed at '
          f'the SERVER, but found type: {init_result_type}.')
    # Note that state of next_fn being placed at SERVER is now ensured by the
    # assertions in base class which would otherwise raise
    # TemplateStateNotAssignableError.

    next_fn_param = next_fn.type_signature.parameter
    if not next_fn_param.is_struct():
      raise errors.TemplateNextFnNumArgsError(
          f'The `next_fn` must have exactly three input arguments, but found '
          f'the following input type which is not a Struct: {next_fn_param}.')
    if len(next_fn_param) != 3:
      next_param_str = '\n- '.join(str(t) for t in next_fn_param)
      raise errors.TemplateNextFnNumArgsError(
          f'The `next_fn` must have exactly three input arguments, but found '
          f'{len(next_fn_param)} input arguments:\n{next_param_str}')

    # The flattening check above only guarantees that the *leaves* are
    # federated. A struct of federated values is not itself federated and has
    # no single placement, so check `is_federated()` before reading
    # `placement` or `member`.
    model_weights_param = next_fn_param[1]
    client_data_param = next_fn_param[2]
    if (not model_weights_param.is_federated() or
        model_weights_param.placement != placements.CLIENTS):
      raise errors.TemplatePlacementError(
          f'The second input argument of `next_fn` must be placed at CLIENTS '
          f'but found {model_weights_param}.')
    model_weights_member = model_weights_param.member
    if (not model_weights_member.is_struct_with_python() or
        model_weights_member.python_container is not model_utils.ModelWeights):
      raise ModelWeightsTypeError(
          f'The second input argument of `next_fn` must have the '
          f'`tff.learning.ModelWeights` container but found '
          f'{model_weights_param}')
    if (not client_data_param.is_federated() or
        client_data_param.placement != placements.CLIENTS):
      raise errors.TemplatePlacementError(
          f'The third input argument of `next_fn` must be placed at CLIENTS '
          f'but found {client_data_param}.')
    if not client_data_param.member.is_sequence():
      raise ClientDataTypeError(
          f'The third input argument of `next_fn` must be a sequence but found '
          f'{client_data_param}.')

    next_fn_result = next_fn.type_signature.result
    result_type = next_fn_result.result
    if (not result_type.is_federated() or
        result_type.placement != placements.CLIENTS):
      raise errors.TemplatePlacementError(
          f'The "result" attribute of the return type of `next_fn` must be '
          f'placed at CLIENTS, but found {result_type}.')
    result_member = result_type.member
    if (not result_member.is_struct_with_python() or
        result_member.python_container is not ClientResult):
      raise ClientResultTypeError(
          f'The "result" attribute of the return type of `next_fn` must have '
          f'the `ClientResult` container, but found {result_type}.')
    # The client's update must be applicable to the trainable weights it was
    # computed from.
    if not model_weights_member.trainable.is_assignable_from(
        result_member.update):
      raise ClientResultTypeError(
          f'The "update" attribute of returned `ClientResult` must match '
          f'the "trainable" attribute of the `tff.learning.ModelWeights` '
          f'expected as second input argument of the `next_fn`. Found:\n'
          f'Second input argument: {model_weights_member.trainable}\n'
          f'Update attribute of result: {result_member.update}.')

    measurements_type = next_fn_result.measurements
    if (not measurements_type.is_federated() or
        measurements_type.placement != placements.SERVER):
      raise errors.TemplatePlacementError(
          f'The "measurements" attribute of return type of `next_fn` must be '
          f'placed at SERVER, but found {measurements_type}.')
Пример #3
0
    def __init__(self, initialize_fn, next_fn):
        """Constructs the client work process and validates its type signatures.

        The base class constructor runs first (with `next_is_multi_arg=True`).
        This constructor then additionally verifies that:

          * `initialize_fn` returns a single federated value placed at SERVER.
          * Every leaf of the flattened parameter and result types of
            `next_fn` is a federated type.
          * `next_fn` takes exactly three arguments. The second and third are
            federated at CLIENTS. The third's member is either a sequence whose
            element type passes `type_analysis.is_tensorflow_compatible_type`,
            or a (possibly nested) struct whose leaves are all such sequences.
          * The `result` attribute of the `next_fn` return type is federated at
            CLIENTS and has the `ClientResult` container.
          * The `measurements` attribute of the return type is federated at
            SERVER.

        Args:
          initialize_fn: A computation whose result type is checked as above.
          next_fn: A computation whose parameter and result types are checked
            as above.

        Raises:
          errors.TemplateNotFederatedError: If the `initialize_fn` result, or
            any flattened parameter/result type of `next_fn`, is not federated.
          errors.TemplatePlacementError: If one of the values listed above is
            not a federated value at the required placement.
          errors.TemplateNextFnNumArgsError: If `next_fn` does not take exactly
            three input arguments.
          ClientDataTypeError: If the third argument's member is not an allowed
            sequence / structure of sequences.
          ClientResultTypeError: If the result lacks the `ClientResult`
            container.
        """
        super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

        init_result_type = initialize_fn.type_signature.result
        if not init_result_type.is_federated():
            raise errors.TemplateNotFederatedError(
                f'Provided `initialize_fn` must return a federated type, but found '
                f'return type:\n{init_result_type}\nTip: If you '
                f'see a collection of federated types, try wrapping the returned '
                f'value in `tff.federated_zip` before returning.')

        next_types = (structure.flatten(next_fn.type_signature.parameter) +
                      structure.flatten(next_fn.type_signature.result))
        non_federated_types = [t for t in next_types if not t.is_federated()]
        if non_federated_types:
            # `str.join` accepts only strings, so render each type explicitly.
            # Joining the type objects directly raises a TypeError that would
            # mask the intended error below.
            offending_types = '\n- '.join(str(t) for t in non_federated_types)
            raise errors.TemplateNotFederatedError(
                f'Provided `next_fn` must be a *federated* computation, that is, '
                f'operate on `tff.FederatedType`s, but found\n'
                f'next_fn with type signature:\n{next_fn.type_signature}\n'
                f'The non-federated types are:\n- {offending_types}.')

        if init_result_type.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The state controlled by a `ClientWorkProcess` must be placed at '
                f'the SERVER, but found type: {init_result_type}.'
            )
        # Note that state of next_fn being placed at SERVER is now ensured by the
        # assertions in base class which would otherwise raise
        # TemplateStateNotAssignableError.

        next_fn_param = next_fn.type_signature.parameter
        if not next_fn_param.is_struct():
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly three input arguments, but found '
                f'the following input type which is not a Struct: {next_fn_param}.'
            )
        if len(next_fn_param) != 3:
            next_param_str = '\n- '.join(str(t) for t in next_fn_param)
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly three input arguments, but found '
                f'{len(next_fn_param)} input arguments:\n{next_param_str}')

        # The flattening check above only guarantees that the *leaves* are
        # federated. A struct of federated values is not itself federated and
        # has no single placement, so check `is_federated()` before reading
        # `placement` or `member`.
        second_next_param = next_fn_param[1]
        client_data_param = next_fn_param[2]
        if (not second_next_param.is_federated()
                or second_next_param.placement != placements.CLIENTS):
            raise errors.TemplatePlacementError(
                f'The second input argument of `next_fn` must be placed at CLIENTS '
                f'but found {second_next_param}.')
        if (not client_data_param.is_federated()
                or client_data_param.placement != placements.CLIENTS):
            raise errors.TemplatePlacementError(
                f'The third input argument of `next_fn` must be placed at CLIENTS '
                f'but found {client_data_param}.')

        def is_allowed_client_data_type(
                type_spec: computation_types.Type) -> bool:
            """Returns whether `type_spec` is a TF-compatible sequence, or a
            (possibly nested) struct whose leaves are all such sequences."""
            if type_spec.is_sequence():
                return type_analysis.is_tensorflow_compatible_type(
                    type_spec.element)
            if type_spec.is_struct():
                return all(
                    is_allowed_client_data_type(element_type)
                    for element_type in type_spec.children())
            return False

        if not is_allowed_client_data_type(client_data_param.member):
            raise ClientDataTypeError(
                f'The third input argument of `next_fn` must be a sequence or '
                f'a structure of sequences, but found {client_data_param}.')

        next_fn_result = next_fn.type_signature.result
        result_type = next_fn_result.result
        if (not result_type.is_federated()
                or result_type.placement != placements.CLIENTS):
            raise errors.TemplatePlacementError(
                f'The "result" attribute of the return type of `next_fn` must be '
                f'placed at CLIENTS, but found {result_type}.')
        if (not result_type.member.is_struct_with_python()
                or result_type.member.python_container is not ClientResult):
            raise ClientResultTypeError(
                f'The "result" attribute of the return type of `next_fn` must have '
                f'the `ClientResult` container, but found {result_type}.'
            )

        measurements_type = next_fn_result.measurements
        if (not measurements_type.is_federated()
                or measurements_type.placement != placements.SERVER):
            raise errors.TemplatePlacementError(
                f'The "measurements" attribute of return type of `next_fn` must be '
                f'placed at SERVER, but found {measurements_type}.')