def __init__(self, initialize_fn, next_fn):
  """Validates the type signatures of a `DistributionProcess`.

  Args:
    initialize_fn: A no-arg `tff.Computation` whose result is the
      SERVER-placed state of the process.
    next_fn: A `tff.Computation` that accepts exactly two arguments, the state
      and a SERVER-placed value. Its result must have a CLIENTS-placed
      `result` attribute and a SERVER-placed `measurements` attribute.

  Raises:
    TemplateNotFederatedError: If `initialize_fn` does not return a federated
      type, or if any flattened parameter or result type of `next_fn` is not
      federated.
    TemplatePlacementError: If the state, the second argument of `next_fn`, or
      the `result` / `measurements` attributes of the `next_fn` result are not
      federated values placed as described above.
    TemplateNextFnNumArgsError: If `next_fn` does not take exactly two
      arguments.
    Any error raised by the base-class constructor is propagated as well.
  """
  super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

  if not initialize_fn.type_signature.result.is_federated():
    raise errors.TemplateNotFederatedError(
        f'Provided `initialize_fn` must return a federated type, but found '
        f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
        f'see a collection of federated types, try wrapping the returned '
        f'value in `tff.federated_zip` before returning.')

  next_types = (
      structure.flatten(next_fn.type_signature.parameter) +
      structure.flatten(next_fn.type_signature.result))
  non_federated_types = [t for t in next_types if not t.is_federated()]
  if non_federated_types:
    # `str.join` only accepts strings; joining the `Type` objects directly
    # would raise a `TypeError` and hide the intended error below.
    offending_types = '\n- '.join(str(t) for t in non_federated_types)
    raise errors.TemplateNotFederatedError(
        f'Provided `next_fn` must be a *federated* computation, that is, '
        f'operate on `tff.FederatedType`s, but found\n'
        f'next_fn with type signature:\n{next_fn.type_signature}\n'
        f'The non-federated types are:\n {offending_types}.')

  if initialize_fn.type_signature.result.placement != placements.SERVER:
    raise errors.TemplatePlacementError(
        f'The state controlled by an `DistributionProcess` must be placed at '
        f'the SERVER, but found type: {initialize_fn.type_signature.result}.')

  # Note that state of next_fn being placed at SERVER is now ensured by the
  # assertions in base class which would otherwise raise
  # TemplateStateNotAssignableError.
  next_fn_param = next_fn.type_signature.parameter
  next_fn_result = next_fn.type_signature.result
  if not next_fn_param.is_struct():
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly two input arguments, but found '
        f'the following input type which is not a Struct: {next_fn_param}.')
  if len(next_fn_param) != 2:
    next_param_str = '\n- '.join([str(t) for t in next_fn_param])
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly two input arguments, but found '
        f'{len(next_fn_param)} input arguments:\n{next_param_str}')

  # The federated check above inspects flattened types, so an argument or
  # attribute that is itself a struct of federated values can still reach
  # this point. Such a struct has no `placement`; check `is_federated()` first
  # so it raises a placement error rather than an `AttributeError`.
  if (not next_fn_param[1].is_federated() or
      next_fn_param[1].placement != placements.SERVER):
    raise errors.TemplatePlacementError(
        f'The second input argument of `next_fn` must be placed at SERVER '
        f'but found {next_fn_param[1]}.')
  if (not next_fn_result.result.is_federated() or
      next_fn_result.result.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The "result" attribute of return type of `next_fn` must be placed '
        f'at CLIENTS, but found {next_fn_result.result}.')
  if (not next_fn_result.measurements.is_federated() or
      next_fn_result.measurements.placement != placements.SERVER):
    raise errors.TemplatePlacementError(
        f'The "measurements" attribute of return type of `next_fn` must be '
        f'placed at SERVER, but found {next_fn_result.measurements}.')
def _infer_state_type(initialize_result_type, next_parameter_type,
                      next_is_multi_arg):
  """Infers the state type from the `initialize` and `next` types.

  Args:
    initialize_result_type: The result type of `initialize_fn`.
    next_parameter_type: The parameter type of `next_fn`.
    next_is_multi_arg: `True` if the state is the first of several arguments
      to `next_fn`, `False` if the state is the whole argument to `next_fn`,
      or `None` if either is acceptable. With `None`, the whole argument is
      tried first, then the first argument.

  Returns:
    The state parameter type of `next_fn`: either `next_parameter_type` or
    `next_parameter_type[0]`, whichever was selected as assignable from
    `initialize_result_type`.

  Raises:
    TemplateNextFnNumArgsError: If `next_is_multi_arg` is `True` and
      `next_parameter_type` is not a non-empty structure.
    TemplateStateNotAssignableError: If `initialize_result_type` is not
      assignable to the candidate state parameter type(s).
  """
  if next_is_multi_arg is None:
    # `state_type` may be `next_parameter_type` or
    # `next_parameter_type[0]`, depending on which one was assignable from
    # `initialize_result_type`.
    if next_parameter_type.is_assignable_from(initialize_result_type):
      return next_parameter_type
    if (_is_nonempty_struct(next_parameter_type) and
        next_parameter_type[0].is_assignable_from(initialize_result_type)):
      return next_parameter_type[0]
    raise errors.TemplateStateNotAssignableError(
        'The return type of `initialize_fn` must be assignable to either\n'
        'the whole argument to `next_fn` or the first argument to `next_fn`,\n'
        'but found `initialize_fn` return type:\n'
        f'{initialize_result_type}\n'
        'and `next_fn` with whole argument type:\n'
        f'{next_parameter_type}')
  elif next_is_multi_arg:
    if not _is_nonempty_struct(next_parameter_type):
      raise errors.TemplateNextFnNumArgsError(
          'Expected `next_parameter_type` to be a structure type of at least '
          f'length one, but found type:\n{next_parameter_type}')
    if next_parameter_type[0].is_assignable_from(initialize_result_type):
      return next_parameter_type[0]
    # Report the first argument, which is what the message describes; the
    # structure is known to be non-empty at this point.
    raise errors.TemplateStateNotAssignableError(
        'The return type of `initialize_fn` must be assignable to the first\n'
        'argument to `next_fn`, but found `initialize_fn` return type:\n'
        f'{initialize_result_type}\n'
        'and `next_fn` whose first argument type is:\n'
        f'{next_parameter_type[0]}')
  else:
    # `next_is_multi_arg` is `False`: the state is the whole argument.
    if next_parameter_type.is_assignable_from(initialize_result_type):
      return next_parameter_type
    raise errors.TemplateStateNotAssignableError(
        'The return type of `initialize_fn` must be assignable to the whole\n'
        'argument to `next_fn`, but found `initialize_fn` return type:\n'
        f'{initialize_result_type}\n'
        'and `next_fn` with whole argument type:\n'
        f'{next_parameter_type}')
def __init__(self, initialize_fn, next_fn):
  """Validates the type signatures of a `ClientWorkProcess`.

  Args:
    initialize_fn: A no-arg `tff.Computation` whose result is the
      SERVER-placed state of the process.
    next_fn: A `tff.Computation` that accepts exactly three arguments: the
      state, CLIENTS-placed `tff.learning.ModelWeights`, and CLIENTS-placed
      sequence data. Its result must have:
      - a CLIENTS-placed `result` attribute with the `ClientResult`
        container, whose `update` attribute is assignable to the `trainable`
        attribute of the model weights;
      - a SERVER-placed `measurements` attribute.

  Raises:
    TemplateNotFederatedError: If `initialize_fn` does not return a federated
      type, or if any flattened parameter or result type of `next_fn` is not
      federated.
    TemplatePlacementError: If the state, an argument of `next_fn`, or the
      `result` / `measurements` attributes of its result are not federated
      values placed as described above.
    TemplateNextFnNumArgsError: If `next_fn` does not take exactly three
      arguments.
    ModelWeightsTypeError: If the second argument of `next_fn` does not have
      the `tff.learning.ModelWeights` container.
    ClientDataTypeError: If the third argument of `next_fn` is not a sequence.
    ClientResultTypeError: If the `result` attribute of the `next_fn` result
      has the wrong container, or its `update` does not match the
      `trainable` model weights.
  """
  super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

  if not initialize_fn.type_signature.result.is_federated():
    raise errors.TemplateNotFederatedError(
        f'Provided `initialize_fn` must return a federated type, but found '
        f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
        f'see a collection of federated types, try wrapping the returned '
        f'value in `tff.federated_zip` before returning.')

  next_types = (
      structure.flatten(next_fn.type_signature.parameter) +
      structure.flatten(next_fn.type_signature.result))
  non_federated_types = [t for t in next_types if not t.is_federated()]
  if non_federated_types:
    # `str.join` only accepts strings; joining the `Type` objects directly
    # would raise a `TypeError` and hide the intended error below.
    offending_types = '\n- '.join(str(t) for t in non_federated_types)
    raise errors.TemplateNotFederatedError(
        f'Provided `next_fn` must be a *federated* computation, that is, '
        f'operate on `tff.FederatedType`s, but found\n'
        f'next_fn with type signature:\n{next_fn.type_signature}\n'
        f'The non-federated types are:\n {offending_types}.')

  if initialize_fn.type_signature.result.placement != placements.SERVER:
    raise errors.TemplatePlacementError(
        f'The state controlled by a `ClientWorkProcess` must be placed at '
        f'the SERVER, but found type: {initialize_fn.type_signature.result}.')

  # Note that state of next_fn being placed at SERVER is now ensured by the
  # assertions in base class which would otherwise raise
  # TemplateStateNotAssignableError.
  next_fn_param = next_fn.type_signature.parameter
  if not next_fn_param.is_struct():
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'the following input type which is not a Struct: {next_fn_param}.')
  if len(next_fn_param) != 3:
    next_param_str = '\n- '.join([str(t) for t in next_fn_param])
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'{len(next_fn_param)} input arguments:\n{next_param_str}')

  model_weights_param = next_fn_param[1]
  client_data_param = next_fn_param[2]
  # The federated check above inspects flattened types, so an argument that
  # is itself a struct of federated values can still reach this point. Such a
  # struct has no `placement`; check `is_federated()` first so it raises a
  # placement error rather than an `AttributeError`.
  if (not model_weights_param.is_federated() or
      model_weights_param.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The second input argument of `next_fn` must be placed at CLIENTS '
        f'but found {model_weights_param}.')
  if (not model_weights_param.member.is_struct_with_python() or
      model_weights_param.member.python_container is not model_utils.ModelWeights):
    raise ModelWeightsTypeError(
        f'The second input argument of `next_fn` must have the '
        f'`tff.learning.ModelWeights` container but found '
        f'{model_weights_param}')
  if (not client_data_param.is_federated() or
      client_data_param.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The third input argument of `next_fn` must be placed at CLIENTS '
        f'but found {client_data_param}.')
  if not client_data_param.member.is_sequence():
    raise ClientDataTypeError(
        f'The third input argument of `next_fn` must be a sequence but found '
        f'{client_data_param}.')

  next_fn_result = next_fn.type_signature.result
  if (not next_fn_result.result.is_federated() or
      next_fn_result.result.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The "result" attribute of the return type of `next_fn` must be '
        f'placed at CLIENTS, but found {next_fn_result.result}.')
  if (not next_fn_result.result.member.is_struct_with_python() or
      next_fn_result.result.member.python_container is not ClientResult):
    raise ClientResultTypeError(
        f'The "result" attribute of the return type of `next_fn` must have '
        f'the `ClientResult` container, but found {next_fn_result.result}.')
  if not model_weights_param.member.trainable.is_assignable_from(
      next_fn_result.result.member.update):
    raise ClientResultTypeError(
        f'The "update" attribute of returned `ClientResult` must match '
        f'the "trainable" attribute of the `tff.learning.ModelWeights` '
        f'expected as second input argument of the `next_fn`. Found:\n'
        f'Second input argument: {model_weights_param.member.trainable}\n'
        f'Update attribute of result: {next_fn_result.result.member.update}.')
  if (not next_fn_result.measurements.is_federated() or
      next_fn_result.measurements.placement != placements.SERVER):
    raise errors.TemplatePlacementError(
        f'The "measurements" attribute of return type of `next_fn` must be '
        f'placed at SERVER, but found {next_fn_result.measurements}.')
def __init__(self, initialize_fn: computation_base.Computation,
             next_fn: computation_base.Computation,
             report_fn: computation_base.Computation):
  """Creates a `LearningProcess`.

  Args:
    initialize_fn: A no-arg `tff.Computation` that creates the initial state
      of the learning process.
    next_fn: A `tff.Computation` that defines an iterated function. Given that
      `initialize_fn` returns a type `S@SERVER`, the `next_fn` must return a
      `LearningProcessOutput` where the `state` attribute matches the type
      `S@SERVER`, and accepts two arguments of types `S@SERVER` and
      `{D*}@CLIENTS`.
    report_fn: A `tff.Computation` that accepts an input `S` where the output
      of `initialize_fn` is of type `S@SERVER`. This computation is used to
      create a representation of the state that can be used for downstream
      tasks without requiring access to the entire server state. For example,
      `report_fn` could be used to extract model weights for computing
      metrics on held-out data.

  Raises:
    TypeError: If `initialize_fn`, `next_fn` or `report_fn` are not instances
      of `tff.Computation`.
    TemplateInitFnParamNotEmptyError: If `initialize_fn` has any input
      arguments.
    TemplateStateNotAssignableError: If the `state` returned by either
      `initialize_fn` or `next_fn` is not assignable to the first input
      argument of `next_fn`.
    TemplateNextFnNumArgsError: If `next_fn` does not have exactly two input
      arguments.
    LearningProcessPlacementError: If the placements of `initialize_fn` and
      `next_fn` do not match the expected type placements, or if the
      parameter or result of `report_fn` is a federated type.
    LearningProcessOutputError: If `next_fn` does not return a
      `LearningProcessOutput`.
    LearningProcessSequenceTypeError: If the second argument to `next_fn` is
      not a sequence type.
    ReportFnTypeSignatureError: If `report_fn` takes no argument, or its input
      type is not assignable from the member type of the output of
      `initialize_fn`.
  """
  super().__init__(initialize_fn, next_fn)

  init_fn_result = initialize_fn.type_signature.result
  if init_fn_result.placement != placements.SERVER:
    raise LearningProcessPlacementError(
        f'The result of `initialize_fn` must be placed at `SERVER` but found '
        f'placement {init_fn_result.placement}.')

  next_result_type = next_fn.type_signature.result
  if not (isinstance(next_result_type, computation_types.StructWithPythonType)
          and next_result_type.python_container is LearningProcessOutput):
    raise LearningProcessOutputError(
        f'The `next_fn` of a `LearningProcess` must return a '
        f'`LearningProcessOutput` object, but returns {next_result_type!r}')

  # We perform a more strict type check on the inputs to `next_fn` than in the
  # base class.
  next_fn_param = next_fn.type_signature.parameter
  if not next_fn_param.is_struct() or len(next_fn_param) != 2:
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have two input arguments, but found an input '
        f'of type {next_fn_param}.')
  if next_fn_param[1].placement != placements.CLIENTS:
    raise LearningProcessPlacementError(
        f'The second input argument of `next_fn` must be placed at `CLIENTS`,'
        f' but found placement {next_fn_param[1].placement}.')
  if not next_fn_param[1].member.is_sequence():
    raise LearningProcessSequenceTypeError(
        f'The member type of the second input argument to `next_fn` must be a'
        f' `tff.SequenceType` but found {next_fn_param[1].member} instead.')

  next_fn_result = next_fn.type_signature.result
  if next_fn_result.metrics.placement != placements.SERVER:
    raise LearningProcessPlacementError(
        f'The result of `next_fn` must be placed at `SERVER` but found '
        f'placement {next_fn_result.metrics.placement} for `metrics`.')

  py_typecheck.check_type(report_fn, computation_base.Computation)
  report_fn_type = report_fn.type_signature
  report_fn_param = report_fn_type.parameter
  # A computation's type signature is a function type, which is never itself
  # federated, so calling `report_fn_type.is_federated()` could never fire.
  # Inspect the parameter and result instead; the parameter is `None` for a
  # no-arg computation.
  if any(t is not None and t.is_federated()
         for t in (report_fn_param, report_fn_type.result)):
    raise LearningProcessPlacementError(
        f'The `report_fn` must not be a federated computation, '
        f'but found `report_fn` with type signature:\n'
        f'{report_fn_type}')

  state_type_without_placement = init_fn_result.member
  if (report_fn_param is None or
      not report_fn_param.is_assignable_from(state_type_without_placement)):
    raise ReportFnTypeSignatureError(
        f'The input type of `report_fn` must be assignable from '
        f'the member type of the output of `initialize_fn`, but found input '
        f'type {report_fn_param}, which is not assignable from '
        f'{state_type_without_placement}.')
  self._report_fn = report_fn
def __init__(self, initialize_fn: computation_base.Computation,
             next_fn: computation_base.Computation):
  """Creates a `tff.templates.AggregationProcess`.

  Args:
    initialize_fn: A no-arg `tff.Computation` producing the initial state of
      the aggregation process. The state must be a federated value placed at
      SERVER; call its type `S@SERVER`.
    next_fn: A `tff.Computation` implementing one iteration of the process.
      It takes at least two arguments. The first is the state, of type
      `S@SERVER`. The second is client-placed data of type `V@CLIENTS`. It
      returns a `MeasuredProcessOutput` whose `state` matches `S@SERVER`,
      whose `result` matches `V@SERVER`, and whose `measurements` are placed
      at SERVER.

  Raises:
    TypeError: If `initialize_fn` or `next_fn` is not a `tff.Computation`.
    TemplateInitFnParamNotEmptyError: If `initialize_fn` takes any arguments.
    TemplateStateNotAssignableError: If the state returned by `initialize_fn`
      or `next_fn` cannot be assigned to the first argument of `next_fn`.
    TemplateNotMeasuredProcessOutputError: If `next_fn` does not return a
      `MeasuredProcessOutput`.
    TemplateNextFnNumArgsError: If `next_fn` takes fewer than two arguments.
    AggregationNotFederatedError: If `initialize_fn` or `next_fn` operates on
      non-federated types.
    AggregationPlacementError: If the placements in the signatures of
      `initialize_fn` and `next_fn` differ from the ones described above.
  """
  # The base class verifies that `next_fn` returns a `MeasuredProcessOutput`,
  # which makes the attribute accesses on its result below safe.
  super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

  init_state_type = initialize_fn.type_signature.result
  if not init_state_type.is_federated():
    raise AggregationNotFederatedError(
        f'Provided `initialize_fn` must return a federated type, but found '
        f'return type:\n{init_state_type}\nTip: If you '
        f'see a collection of federated types, try wrapping the returned '
        f'value in `tff.federated_zip` before returning.')

  next_signature = next_fn.type_signature
  flattened_types = [
      *structure.flatten(next_signature.parameter),
      *structure.flatten(next_signature.result),
  ]
  non_federated = [t for t in flattened_types if not t.is_federated()]
  if non_federated:
    offending_types_str = '\n- '.join(map(str, non_federated))
    raise AggregationNotFederatedError(
        'Provided `next_fn` must both be a *federated* computations, that '
        'is, operate on `tff.FederatedType`s, but found\n'
        f'next_fn with type signature:\n{next_signature}\n'
        f'The non-federated types are:\n {offending_types_str}.')

  if init_state_type.placement != placements.SERVER:
    raise AggregationPlacementError(
        f'The state controlled by an `AggregationProcess` must be placed at '
        f'the SERVER, but found type: {init_state_type}.')

  # The SERVER placement of the state argument of `next_fn` is already
  # guaranteed by the base class (it would otherwise raise
  # errors.TemplateStateNotAssignableError), so it is not rechecked here.
  param_types = next_signature.parameter
  if len(param_types) < 2:
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have at least two input arguments, but found '
        f'the following input type: {param_types}.')

  client_input_type = param_types[_INPUT_PARAM_INDEX]
  if client_input_type.placement != placements.CLIENTS:
    raise AggregationPlacementError(
        f'The second input argument of `next_fn` must be placed at CLIENTS '
        f'but found {client_input_type}.')

  # Both the aggregated value and the measurements must live at SERVER.
  output_type = next_signature.result
  for attr_name in ('result', 'measurements'):
    attr_type = getattr(output_type, attr_name)
    if attr_type.placement != placements.SERVER:
      raise AggregationPlacementError(
          f'The "{attr_name}" attribute of return type of `next_fn` must be '
          f'placed at SERVER, but found {attr_type}.')
def __init__(self, initialize_fn: computation_base.Computation,
             next_fn: computation_base.Computation,
             get_model_weights: computation_base.Computation,
             set_model_weights: computation_base.Computation):
  """Creates a `LearningProcess`.

  Args:
    initialize_fn: A no-arg `tff.Computation` that creates the initial state
      of the learning process.
    next_fn: A `tff.Computation` that defines an iterated function. Given that
      `initialize_fn` returns a type `S@SERVER`, the `next_fn` must return a
      `LearningProcessOutput` where the `state` attribute is assignable from
      values with type `S@SERVER`. It accepts two arguments with types
      assignable from values with type `S@SERVER` and `{D*}@CLIENTS`.
    get_model_weights: A `tff.Computation` that accepts an input `S` whose
      type is assignable from the result of `initialize_fn`. This computation
      is used to create a representation of the state that can be used for
      downstream tasks without requiring access to the entire server state.
      For example, `get_model_weights` could be used to extract model weights
      suitable for computing evaluation metrics on held-out data.
    set_model_weights: A `tff.Computation` that accepts two inputs `S` and `M`.
      The type of `S` is assignable from values with the type returned by
      `initialize_fn`, and `M` is a representation of the model weights
      stored in `S`. This updates the model weights representation within the
      state with the incoming value and returns a new value of type `S`.

  Raises:
    TypeError: If `initialize_fn`, `next_fn`, `get_model_weights` or
      `set_model_weights` are not instances of `tff.Computation`.
    TemplateInitFnParamNotEmptyError: If `initialize_fn` has any input
      arguments.
    TemplateStateNotAssignableError: If the `state` returned by either
      `initialize_fn` or `next_fn` is not assignable to the first input
      argument of `next_fn`.
    TemplateNextFnNumArgsError: If `next_fn` does not have exactly two input
      arguments.
    LearningProcessPlacementError: If the placements of `initialize_fn` and
      `next_fn` do not match the expected type placements.
    LearningProcessOutputError: If `next_fn` does not return a
      `LearningProcessOutput`.
    LearningProcessSequenceTypeError: If the second argument to `next_fn` is
      not a sequence type or a nested structure of sequence types.
    GetModelWeightsTypeSignatureError: If `get_model_weights` takes no
      argument, or its input type is not equivalent to the member type of the
      first argument of `next_fn`.
    SetModelWeightsTypeSignatureError: If `set_model_weights` does not take a
      non-empty structure of arguments, if its first argument is not
      equivalent to the member type of the first argument of `next_fn`, or if
      its result is not assignable to that type.
  """
  super().__init__(initialize_fn, next_fn)

  init_fn_result = initialize_fn.type_signature.result
  if init_fn_result.placement != placements.SERVER:
    raise LearningProcessPlacementError(
        f'The result of `initialize_fn` must be placed at `SERVER` but found '
        f'placement {init_fn_result.placement}.')

  next_result_type = next_fn.type_signature.result
  if not (isinstance(next_result_type, computation_types.StructWithPythonType)
          and next_result_type.python_container is LearningProcessOutput):
    raise LearningProcessOutputError(
        f'The `next_fn` of a `LearningProcess` must return a '
        f'`LearningProcessOutput` object, but returns {next_result_type!r}')

  # We perform a more strict type check on the inputs to `next_fn` than in the
  # base class.
  next_fn_param = next_fn.type_signature.parameter
  if not next_fn_param.is_struct() or len(next_fn_param) != 2:
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have two input arguments, but found an input '
        f'of type {next_fn_param}.')
  if next_fn_param[1].placement != placements.CLIENTS:
    raise LearningProcessPlacementError(
        f'The second input argument of `next_fn` must be placed at `CLIENTS`,'
        f' but found placement {next_fn_param[1].placement}.')

  def is_allowed_client_data_type(type_spec: computation_types.Type) -> bool:
    """Returns `True` if the type is a valid client dataset type.

    A valid type is a sequence of TensorFlow-compatible elements, or a
    (possibly nested) struct whose leaves are such sequences.
    """
    if type_spec.is_sequence():
      return type_analysis.is_tensorflow_compatible_type(type_spec.element)
    elif type_spec.is_struct():
      return all(
          is_allowed_client_data_type(element_type)
          for element_type in type_spec.children())
    else:
      return False

  if not is_allowed_client_data_type(next_fn_param[1].member):
    raise LearningProcessSequenceTypeError(
        f'The member type of the second input argument to `next_fn` must be a'
        f' `tff.SequenceType` or a nested `tff.StructType` of sequence types '
        f'but found {next_fn_param[1].member} instead.')

  next_fn_result = next_fn.type_signature.result
  if next_fn_result.metrics.placement != placements.SERVER:
    raise LearningProcessPlacementError(
        f'The result of `next_fn` must be placed at `SERVER` but found '
        f'placement {next_fn_result.metrics.placement} for `metrics`.')

  py_typecheck.check_type(get_model_weights, computation_base.Computation)
  get_model_weights_type = get_model_weights.type_signature
  get_model_weights_param = get_model_weights_type.parameter
  next_fn_state_param = next_fn.type_signature.parameter[0].member
  # A no-arg computation has a `None` parameter; reject it with the typed
  # error instead of failing with an `AttributeError`.
  if (get_model_weights_param is None or
      not get_model_weights_param.is_equivalent_to(next_fn_state_param)):
    raise GetModelWeightsTypeSignatureError(
        f'The input type of `get_model_weights` must be assignable from '
        f'the member type of the output of `initialize_fn`, but found input '
        f'type {get_model_weights_param}, which is not equivalent to '
        f'{next_fn_state_param}.')
  self._get_model_weights = get_model_weights

  py_typecheck.check_type(set_model_weights, computation_base.Computation)
  set_model_weights_type = set_model_weights.type_signature
  set_model_weights_param = set_model_weights_type.parameter
  # The first argument is indexed below, so the parameter must be a non-empty
  # struct; otherwise indexing would fail with an untyped error.
  if (set_model_weights_param is None or
      not set_model_weights_param.is_struct() or
      len(set_model_weights_param) < 1):
    raise SetModelWeightsTypeSignatureError(
        f'The `set_model_weights` computation must accept a structure of '
        f'input arguments whose first element is the state, but found input '
        f'type {set_model_weights_param}.')
  set_model_weights_state_param = set_model_weights_param[0]
  if not set_model_weights_state_param.is_equivalent_to(next_fn_state_param):
    raise SetModelWeightsTypeSignatureError(
        f'The input type of `set_model_weights` must be assignable from '
        f'the member type of the output of `initialize_fn`, but found input '
        f'type {set_model_weights_state_param}, which is not equivalent to '
        f'{next_fn_state_param}.')
  set_model_weights_result = set_model_weights_type.result
  if not next_fn_state_param.is_assignable_from(set_model_weights_result):
    raise SetModelWeightsTypeSignatureError(
        f'The output type of `set_model_weights` must be assignable to '
        f'the first parameter of `next_fn`, but found output '
        f'type {set_model_weights_result}, which is not assignable to '
        f'{next_fn_state_param}.')
  self._set_model_weights = set_model_weights
def __init__(self, initialize_fn, next_fn):
  """Validates the type signatures of a `ClientWorkProcess`.

  Args:
    initialize_fn: A no-arg `tff.Computation` whose result is the
      SERVER-placed state of the process.
    next_fn: A `tff.Computation` that accepts exactly three arguments: the
      state, a CLIENTS-placed value, and CLIENTS-placed client data. The
      client data must be a sequence, or a (possibly nested) structure of
      sequences, with TensorFlow-compatible elements. Its result must have a
      CLIENTS-placed `result` attribute with the `ClientResult` container and
      a SERVER-placed `measurements` attribute.

  Raises:
    TemplateNotFederatedError: If `initialize_fn` does not return a federated
      type, or if any flattened parameter or result type of `next_fn` is not
      federated.
    TemplatePlacementError: If the state, an argument of `next_fn`, or the
      `result` / `measurements` attributes of its result are not federated
      values placed as described above.
    TemplateNextFnNumArgsError: If `next_fn` does not take exactly three
      arguments.
    ClientDataTypeError: If the third argument of `next_fn` is not a sequence
      or a structure of sequences with TensorFlow-compatible elements.
    ClientResultTypeError: If the `result` attribute of the `next_fn` result
      does not have the `ClientResult` container.
  """
  super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

  if not initialize_fn.type_signature.result.is_federated():
    raise errors.TemplateNotFederatedError(
        f'Provided `initialize_fn` must return a federated type, but found '
        f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
        f'see a collection of federated types, try wrapping the returned '
        f'value in `tff.federated_zip` before returning.')

  next_types = (
      structure.flatten(next_fn.type_signature.parameter) +
      structure.flatten(next_fn.type_signature.result))
  non_federated_types = [t for t in next_types if not t.is_federated()]
  if non_federated_types:
    # `str.join` only accepts strings; joining the `Type` objects directly
    # would raise a `TypeError` and hide the intended error below.
    offending_types = '\n- '.join(str(t) for t in non_federated_types)
    raise errors.TemplateNotFederatedError(
        f'Provided `next_fn` must be a *federated* computation, that is, '
        f'operate on `tff.FederatedType`s, but found\n'
        f'next_fn with type signature:\n{next_fn.type_signature}\n'
        f'The non-federated types are:\n {offending_types}.')

  if initialize_fn.type_signature.result.placement != placements.SERVER:
    raise errors.TemplatePlacementError(
        f'The state controlled by a `ClientWorkProcess` must be placed at '
        f'the SERVER, but found type: {initialize_fn.type_signature.result}.')

  # Note that state of next_fn being placed at SERVER is now ensured by the
  # assertions in base class which would otherwise raise
  # TemplateStateNotAssignableError.
  next_fn_param = next_fn.type_signature.parameter
  if not next_fn_param.is_struct():
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'the following input type which is not a Struct: {next_fn_param}.')
  if len(next_fn_param) != 3:
    next_param_str = '\n- '.join([str(t) for t in next_fn_param])
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'{len(next_fn_param)} input arguments:\n{next_param_str}')

  second_next_param = next_fn_param[1]
  client_data_param = next_fn_param[2]
  # The federated check above inspects flattened types, so an argument that
  # is itself a struct of federated values can still reach this point. Such a
  # struct has no `placement`; check `is_federated()` first so it raises a
  # placement error rather than an `AttributeError`.
  if (not second_next_param.is_federated() or
      second_next_param.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The second input argument of `next_fn` must be placed at CLIENTS '
        f'but found {second_next_param}.')
  if (not client_data_param.is_federated() or
      client_data_param.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The third input argument of `next_fn` must be placed at CLIENTS '
        f'but found {client_data_param}.')

  def is_allowed_client_data_type(type_spec: computation_types.Type) -> bool:
    """Returns `True` if `type_spec` is a sequence or a struct of sequences.

    Sequence element types must be TensorFlow-compatible; structs are checked
    recursively.
    """
    if type_spec.is_sequence():
      return type_analysis.is_tensorflow_compatible_type(type_spec.element)
    elif type_spec.is_struct():
      return all(
          is_allowed_client_data_type(element_type)
          for element_type in type_spec.children())
    else:
      return False

  if not is_allowed_client_data_type(client_data_param.member):
    raise ClientDataTypeError(
        f'The third input argument of `next_fn` must be a sequence or '
        f'a structure of sequences, but found {client_data_param}.')

  next_fn_result = next_fn.type_signature.result
  if (not next_fn_result.result.is_federated() or
      next_fn_result.result.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The "result" attribute of the return type of `next_fn` must be '
        f'placed at CLIENTS, but found {next_fn_result.result}.')
  if (not next_fn_result.result.member.is_struct_with_python() or
      next_fn_result.result.member.python_container is not ClientResult):
    raise ClientResultTypeError(
        f'The "result" attribute of the return type of `next_fn` must have '
        f'the `ClientResult` container, but found {next_fn_result.result}.')
  if (not next_fn_result.measurements.is_federated() or
      next_fn_result.measurements.placement != placements.SERVER):
    raise errors.TemplatePlacementError(
        f'The "measurements" attribute of return type of `next_fn` must be '
        f'placed at SERVER, but found {next_fn_result.measurements}.')