def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wrapper function to plug Tensorflow logic in to TFF framework.

  Args:
    target_fn: The Python function to wrap as a TFF TensorFlow computation.
    parameter_type: The TFF type of the computation's parameter, or `None`
      for a no-argument computation. Must be composed only of
      `SequenceType`, `NamedTupleType` and `TensorType` constituents.
    unpack: Whether to unpack the parameter when invoking `target_fn`;
      forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
    name: Unused; accepted for interface compatibility with other wrappers.

  Returns:
    A `computation_impl.ComputationImpl` backed by the serialized TensorFlow
    computation.

  Raises:
    TypeError: If `parameter_type` is not a TensorFlow-compatible TFF type.
  """
  del name  # Unused.
  # Validate the parameter type before doing any wrapping or serialization
  # work, so an unsupported type always surfaces as this TypeError rather
  # than an incidental failure deeper in the wrapping machinery.
  if not type_utils.is_tensorflow_compatible_type(parameter_type):
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `NamedTupleType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  target_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  ctx_stack = context_stack_impl.context_stack
  comp_pb, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      target_fn, parameter_type, ctx_stack)
  return computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)
def check_and_pack_before_aggregate_type_signature(type_spec,
                                                   previously_packed_types):
  """Checks types inferred from `before_aggregate` and packs in `previously_packed_types`.

  After splitting the `after_broadcast` portion of a
  `tff.utils.IterativeProcess` into `before_aggregate` and `after_aggregate`,
  `before_aggregate` should have type signature
  `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`. This function validates
  `c1`, `s1` and `c2` against the existing entries in
  `previously_packed_types`, then packs `c5`, `zero`, `accumulate`, `merge`
  and `report`.

  Args:
    type_spec: The `type_signature` attribute of the `before_aggregate`
      portion of the `tff.utils.IterativeProcess` from which we are looking
      to extract an instance of `canonical_form.CanonicalForm`.
    previously_packed_types: Dict containing the information from `next` and
      `before_broadcast` in the iterative process we are parsing.

  Returns:
    A `dict` packing the types which can be inferred from `type_spec`.

  Raises:
    TypeError: If `type_signature` is incompatible with
      `previously_packed_types`.
  """

  def _has_expected_structure():
    # Each structural requirement is verified before the attribute access or
    # indexing that depends on it, so a malformed `type_spec` yields the
    # diagnostic TypeError below instead of an incidental AttributeError or
    # IndexError part-way through the checks.
    if not (isinstance(type_spec, computation_types.FunctionType) and
            isinstance(type_spec.parameter, computation_types.NamedTupleType)
            and len(type_spec.parameter) == 2):
      return False
    # `<s1,c1>` portion of the parameter.
    if not (isinstance(type_spec.parameter[0],
                       computation_types.NamedTupleType) and
            len(type_spec.parameter[0]) == 2 and
            type_spec.parameter[0][0] == previously_packed_types['s1_type'] and
            type_spec.parameter[0][1] == previously_packed_types['c1_type']):
      return False
    # `c2`: the clients-placed result of broadcasting `s2`.
    if not (isinstance(type_spec.parameter[1],
                       computation_types.FederatedType) and
            type_spec.parameter[1].placement == placements.CLIENTS and
            type_spec.parameter[1].member ==
            previously_packed_types['s2_type'].member):
      return False
    # Result must be `<c5,zero,accumulate,merge,report>`.
    if not (isinstance(type_spec.result, computation_types.NamedTupleType) and
            len(type_spec.result) == 5 and
            isinstance(type_spec.result[0], computation_types.FederatedType)
            and type_spec.result[0].placement == placements.CLIENTS and
            type_utils.is_tensorflow_compatible_type(type_spec.result[1])):
      return False
    # `accumulate` and `merge` signatures are fully determined by `zero`'s
    # result type and `c5`'s member type; `report` only needs a matching
    # parameter and a TF-compatible result.
    if not (type_spec.result[2] == computation_types.FunctionType(
        [type_spec.result[1], type_spec.result[0].member],
        type_spec.result[1]) and
            type_spec.result[3] == computation_types.FunctionType(
                [type_spec.result[1], type_spec.result[1]],
                type_spec.result[1]) and
            isinstance(type_spec.result[4], computation_types.FunctionType)
            and type_spec.result[4].parameter == type_spec.result[1] and
            type_utils.is_tensorflow_compatible_type(
                type_spec.result[4].result)):
      return False
    return True

  if not _has_expected_structure():
    # TODO(b/121290421): These error messages, and indeed the 'track boolean
    # and raise once' logic of these methods as well, is intended to be
    # provisional and revisited when we've seen the compilation pipeline fail
    # more clearly, or maybe preferably iteratively improved as new failure
    # modes are encountered.
    raise TypeError(
        'Encountered a type error while checking '
        '`before_aggregate`. Expected a type signature of the '
        'form `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`, '
        'where `s1` matches {}, `c1` matches {}, and `c2` matches '
        'the result of broadcasting {}, as defined in '
        '`canonical_form.CanonicalForm`. Found type signature {}.'.format(
            previously_packed_types['s1_type'],
            previously_packed_types['c1_type'],
            previously_packed_types['s2_type'], type_spec))

  newly_determined_types = {}
  c2_type = type_spec.parameter[1]
  newly_determined_types['c2_type'] = c2_type
  # `c3` zips the client state `c1` with the broadcast result `c2`.
  c3_type = computation_types.FederatedType(
      [previously_packed_types['c1_type'].member, c2_type.member],
      placements.CLIENTS)
  newly_determined_types['c3_type'] = c3_type
  c5_type = type_spec.result[0]
  # `zero` is a nullary function producing the aggregation's initial value.
  zero_type = computation_types.FunctionType(None, type_spec.result[1])
  accumulate_type = type_spec.result[2]
  merge_type = type_spec.result[3]
  report_type = type_spec.result[4]
  newly_determined_types['c5_type'] = c5_type
  newly_determined_types['zero_type'] = zero_type
  newly_determined_types['accumulate_type'] = accumulate_type
  newly_determined_types['merge_type'] = merge_type
  newly_determined_types['report_type'] = report_type
  # `s3` is the server-placed output of `report`.
  newly_determined_types['s3_type'] = computation_types.FederatedType(
      report_type.result, placements.SERVER)
  # `c4` zips the aggregation input `c5` with `c6` (packed earlier).
  c4_type = computation_types.FederatedType([
      newly_determined_types['c5_type'].member,
      previously_packed_types['c6_type'].member
  ], placements.CLIENTS)
  newly_determined_types['c4_type'] = c4_type
  newly_determined_types['work_type'] = computation_types.FunctionType(
      c3_type.member, c4_type.member)
  # Merge without mutating the caller's dict.
  packed_types = dict(previously_packed_types)
  packed_types.update(newly_determined_types)
  return packed_types