def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ProtoAndType:
  """Builds a TensorFlow computation that repeats its argument `count` times.

  The resulting computation has the type signature `(T -> <T, T, T, ...>)`,
  where `T` is `type_signature` and the result holds `count` elements.

  Args:
    type_signature: A `computation_types.Type` to replicate.
    count: An integer, the number of copies of the input in the result.

  Returns:
    Whatever `_tensorflow_comp` produces for the serialized graph and the
    computed function type.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  py_typecheck.check_type(count, int)

  input_type = type_signature
  with tf.Graph().as_default() as graph:
    stamped = tensorflow_utils.stamp_parameter_in_graph('x', input_type, graph)
    input_value, input_binding = stamped
    # Every slot of the result refers to the same stamped parameter value.
    copies = [input_value for _ in range(count)]
    output_type, output_binding = tensorflow_utils.capture_result_from_graph(
        copies, graph)

  fn_type = computation_types.FunctionType(input_type, output_type)
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  tf_proto = pb.TensorFlow(
      graph_def=packed_graph, parameter=input_binding, result=output_binding)
  return _tensorflow_comp(tf_proto, fn_type)
def create_identity(
    type_signature: computation_types.Type) -> ComputationProtoAndType:
  """Builds a TensorFlow computation that returns its argument unchanged.

  The resulting computation has the type signature `(T -> T)`, where `T` is
  `type_signature`.

  NOTE: if `T` contains `computation_types.StructType`s without an associated
  container type, they will be given the container type `tuple` by this
  function.

  Args:
    type_signature: A `computation_types.Type` used as both the parameter type
      and the result type of the identity function.

  Returns:
    The result of `create_computation_for_py_fn` applied to an identity
    function and `type_signature`.

  Raises:
    TypeError: If `type_signature` is `None` or contains any types which
      cannot appear in TensorFlow bindings.
    NotImplementedError: If `type_signature` is not a tensor, sequence, or
      struct type.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  parameter_type = type_signature
  if parameter_type is None:
    raise TypeError('TensorFlow identity cannot be created for NoneType.')

  # TF relies on feeds not-identical to fetches in certain circumstances, so
  # the value is routed through `tf.identity` rather than returned directly.
  if type_signature.is_tensor() or type_signature.is_sequence():
    return create_computation_for_py_fn(tf.identity, parameter_type)

  if type_signature.is_struct():
    # Apply `tf.identity` to every leaf of the structure.
    leafwise_identity = functools.partial(structure.map_structure, tf.identity)
    return create_computation_for_py_fn(leafwise_identity, parameter_type)

  raise NotImplementedError(
      f'TensorFlow identity cannot be created for type {type_signature}')
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
  """Builds a TensorFlow computation that returns its argument unchanged.

  The resulting computation has the type signature `(T -> T)`, where `T` is
  `type_signature`.

  NOTE: if `T` contains `computation_types.StructType`s without an associated
  container type, they will be given the container type `tuple` by this
  function.

  Args:
    type_signature: A `computation_types.Type` used as both the parameter type
      and the result type of the identity function.

  Returns:
    Whatever `_tensorflow_comp` produces for the serialized graph and the
    computed function type.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)

  input_type = type_signature
  with tf.Graph().as_default() as graph:
    stamped = tensorflow_utils.stamp_parameter_in_graph('x', input_type, graph)
    input_value, input_binding = stamped
    # The stamped parameter itself is captured as the result.
    output_type, output_binding = tensorflow_utils.capture_result_from_graph(
        input_value, graph)

  fn_type = computation_types.FunctionType(input_type, output_type)
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  tf_proto = pb.TensorFlow(
      graph_def=packed_graph, parameter=input_binding, result=output_binding)
  return _tensorflow_comp(tf_proto, fn_type)
def create_replicate_input(type_spec, count: int) -> pb.Computation:
  """Builds a TensorFlow computation that returns `count` clones of its input.

  The resulting computation has the type signature `(T -> <T, T, T, ...>)`,
  where `T` is `type_spec` and the result holds `count` elements.

  Args:
    type_spec: A type convertible to instance of `computation_types.Type` via
      `computation_types.to_type`.
    count: An integer, the number of copies of the input in the result.

  Returns:
    A `pb.Computation` carrying the serialized function type and the
    TensorFlow graph with its parameter and result bindings.

  Raises:
    TypeError: If `type_spec` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  input_type = computation_types.to_type(type_spec)
  type_analysis.check_tensorflow_compatible_type(input_type)
  py_typecheck.check_type(count, int)

  with tf.Graph().as_default() as graph:
    stamped = tensorflow_utils.stamp_parameter_in_graph('x', input_type, graph)
    input_value, input_binding = stamped
    # Every slot of the result refers to the same stamped parameter value.
    copies = [input_value for _ in range(count)]
    output_type, output_binding = tensorflow_utils.capture_result_from_graph(
        copies, graph)

  fn_type = computation_types.FunctionType(input_type, output_type)
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  tf_proto = pb.TensorFlow(
      graph_def=packed_graph, parameter=input_binding, result=output_binding)
  serialized_type = type_serialization.serialize_type(fn_type)
  return pb.Computation(type=serialized_type, tensorflow=tf_proto)
def create_identity(type_spec) -> pb.Computation:
  """Builds a TensorFlow computation that returns its argument unchanged.

  The resulting computation has the type signature `(T -> T)`, where `T` is
  `type_spec`.

  Args:
    type_spec: A type convertible to instance of `computation_types.Type` via
      `computation_types.to_type`.

  Returns:
    A `pb.Computation` carrying the serialized function type and the
    TensorFlow graph with its parameter and result bindings.

  Raises:
    TypeError: If `type_spec` contains any types which cannot appear in
      TensorFlow bindings.
  """
  input_type = computation_types.to_type(type_spec)
  type_analysis.check_tensorflow_compatible_type(input_type)

  with tf.Graph().as_default() as graph:
    stamped = tensorflow_utils.stamp_parameter_in_graph('x', input_type, graph)
    input_value, input_binding = stamped
    # The stamped parameter itself is captured as the result.
    output_type, output_binding = tensorflow_utils.capture_result_from_graph(
        input_value, graph)

  fn_type = computation_types.FunctionType(input_type, output_type)
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  tf_proto = pb.TensorFlow(
      graph_def=packed_graph, parameter=input_binding, result=output_binding)
  serialized_type = type_serialization.serialize_type(fn_type)
  return pb.Computation(type=serialized_type, tensorflow=tf_proto)
def compile_local_computation_to_tensorflow(
    comp: building_blocks.ComputationBuildingBlock,
) -> building_blocks.ComputationBuildingBlock:
  """Compiles a fully specified local computation down to TensorFlow.

  Args:
    comp: A `building_blocks.ComputationBuildingBlock` which can be compiled
      to TensorFlow. To be compilable it must not contain:
      1. References to values defined outside of `comp`,
      2. `Data`, `Intrinsic`, or `Placement` blocks, or
      3. Calls to intrinsics or non-TensorFlow computations.

  Returns:
    A `building_blocks.ComputationBuildingBlock` containing a TensorFlow-only
    representation of `comp`. If `comp` is of functional type, this will be a
    `building_blocks.CompiledComputation`. Otherwise, it will be a
    `building_blocks.Call` which wraps a
    `building_blocks.CompiledComputation`.
  """
  comp_type = comp.type_signature
  if not comp_type.is_function():
    # Non-functional values are wrapped in a no-arg lambda, compiled as a
    # function, and then called with no argument.
    thunk = building_blocks.Lambda(None, None, comp)
    compiled_thunk = compile_local_computation_to_tensorflow(thunk)
    return building_blocks.Call(compiled_thunk, None)

  parameter_type = comp_type.parameter
  type_analysis.check_tensorflow_compatible_type(parameter_type)
  type_analysis.check_tensorflow_compatible_type(comp_type.result)

  already_tensorflow = (
      comp.is_compiled_computation() and
      comp.proto.WhichOneof('computation') == 'tensorflow')
  if already_tensorflow:
    return comp

  # Ensure that unused values are removed and that reference bindings have
  # unique names.
  comp = unpack_compiled_computations(comp)
  comp = transformations.to_call_dominant(comp)

  if parameter_type is not None:
    # Bind the argument under a name not already used inside `comp`.
    fresh_names = building_block_factory.unique_name_generator(comp)
    parameter_name = next(fresh_names)
    parameter_ref = building_blocks.Reference(parameter_name, parameter_type)
    to_evaluate = building_blocks.Call(comp, parameter_ref)

    @tensorflow_computation.tf_computation(parameter_type)
    def result_computation(arg):
      if parameter_type.is_struct():
        arg = structure.from_container(arg, recursive=True)
      bindings = {parameter_name: arg}
      return _evaluate_to_tensorflow(to_evaluate, bindings)

  else:
    to_evaluate = building_blocks.Call(comp)

    @tensorflow_computation.tf_computation
    def result_computation():
      return _evaluate_to_tensorflow(to_evaluate, {})

  return result_computation.to_compiled_building_block()
def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ComputationProtoAndType:
  """Builds a TensorFlow computation that repeats its argument `count` times.

  The resulting computation has the type signature `(T -> <T, T, T, ...>)`,
  where `T` is `type_signature` and the result holds `count` elements.

  Args:
    type_signature: A `computation_types.Type` to replicate.
    count: An integer, the number of copies of the input in the result.

  Returns:
    The result of `create_computation_for_py_fn` applied to the replicating
    function and `type_signature`.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  py_typecheck.check_type(count, int)

  def _replicate(value):
    """Returns a list holding `count` references to `value`."""
    return [value] * count

  return create_computation_for_py_fn(_replicate, type_signature)
def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ComputationProtoAndType:
  """Builds a TensorFlow computation that repeats its argument `count` times.

  The resulting computation has the type signature `(T -> <T, T, T, ...>)`,
  where `T` is `type_signature` and the result holds `count` elements. It is
  derived from the identity computation for `T` by repeating that
  computation's result binding `count` times.

  Args:
    type_signature: A `computation_types.Type` to replicate.
    count: An integer, the number of copies of the input in the result.

  Returns:
    Whatever `_tensorflow_comp` produces for the rewritten TensorFlow proto
    and the replicated function type.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  py_typecheck.check_type(count, int)
  parameter_type = type_signature

  identity_comp, _ = create_identity(parameter_type)
  # This manual proto manipulation is significantly faster than using TFF's
  # GraphDef serialization for large `count` arguments.
  identity_tf = identity_comp.tensorflow
  single_result_binding = identity_tf.result
  replicated_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[single_result_binding for _ in range(count)]))

  # Only forward the parameter binding when the identity computation has one.
  proto_kwargs = dict(
      graph_def=identity_tf.graph_def, result=replicated_binding)
  if identity_tf.parameter:
    proto_kwargs['parameter'] = identity_tf.parameter
  new_tf_pb = pb.TensorFlow(**proto_kwargs)

  replicated_type = computation_types.StructType(
      [(None, parameter_type) for _ in range(count)])
  fn_type = computation_types.FunctionType(parameter_type, replicated_type)
  return _tensorflow_comp(new_tf_pb, fn_type)
def create_binary_operator_with_upcast(
    type_signature: computation_types.StructType,
    operator: Callable[[Any, Any], Any]) -> ProtoAndType:
  """Creates a TF computation that upcasts its argument and applies `operator`.

  Args:
    type_signature: A `computation_types.StructType` with two elements, both
      of the same type or the second able to be upcast to the first, as
      explained in `apply_binary_operator_with_upcast`, and both containing
      only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    Whatever `_tensorflow_comp` produces for a graph that upcasts the second
    element of its argument (when the two element types are not equivalent)
    and applies the binary operator.

  Raises:
    TypeError: If `type_signature` is not a two-element `StructType`, contains
      a sequence type, or its first element is neither a tensor nor a struct.
  """
  py_typecheck.check_type(type_signature, computation_types.StructType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if not (type_signature.is_struct() and len(type_signature) == 2):
    raise TypeError('To apply a binary operator, we must by definition have an '
                    'argument which is a `StructType` with 2 elements; '
                    'asked to create a binary operator for type: {t}'.format(
                        t=type_signature))

  def _is_sequence(t):
    return t.is_sequence()

  if type_analysis.contains(type_signature, _is_sequence):
    raise TypeError(
        'Applying binary operators in TensorFlow is only '
        'supported on Tensors and StructTypes; you '
        'passed {t} which contains a SequenceType.'.format(t=type_signature))

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if type_spec.is_struct():
      packed = []
      for elem_name, elem_type in structure.iter_elements(type_spec):
        packed.append((elem_name, _pack_into_type(to_pack, elem_type)))
      return structure.Struct(packed)
    if type_spec.is_tensor():
      return tf.broadcast_to(to_pack, type_spec.shape)
    return None

  first_type = type_signature[0]
  second_type = type_signature[1]

  with tf.Graph().as_default() as graph:
    first_arg, first_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', first_type, graph)
    second_value, second_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', second_type, graph)

    # Equivalent operand types need no upcast; otherwise the second operand
    # is broadcast into the shape of the first.
    if first_type.is_equivalent_to(second_type):
      second_arg = second_value
    else:
      second_arg = _pack_into_type(second_value, first_type)

    if first_type.is_tensor():
      result_value = operator(first_arg, second_arg)
    elif first_type.is_struct():
      result_value = structure.map_structure(operator, first_arg, second_arg)
    else:
      raise TypeError('Encountered unexpected type {t}; can only handle Tensor '
                      'and StructTypes.'.format(t=first_type))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  fn_type = computation_types.FunctionType(type_signature, result_type)
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[first_binding, second_binding]))
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  tf_proto = pb.TensorFlow(
      graph_def=packed_graph,
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tf_proto, fn_type)
def _get_type_info(initialize_tree, before_broadcast, after_broadcast,
                   before_aggregate, after_aggregate):
  """Returns type information for an `tff.templates.IterativeProcess`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create the expected type
  signatures when compiling a given `tff.templates.IterativeProcess` into a
  `tff.backends.mapreduce.CanonicalForm` and returns a
  `collections.OrderedDict` whose keys and order match the explicit and
  intermediate components of `tff.backends.mapreduce.CanonicalForm` defined
  here:

  ```
  s1 = arg[0]
  c1 = arg[1]
  s2 = intrinsics.federated_map(cf.prepare, s1)
  c2 = intrinsics.federated_broadcast(s2)
  c3 = intrinsics.federated_zip([c1, c2])
  c4 = intrinsics.federated_map(cf.work, c3)
  c5 = c4[0]
  c6 = c4[1]
  s3 = intrinsics.federated_aggregate(c5,
                                      cf.zero(),
                                      cf.accumulate,
                                      cf.merge,
                                      cf.report)
  s4 = intrinsics.federated_secure_sum(c6, cf.bitwidth())
  s5 = intrinsics.federated_zip([s3, s4])
  s6 = intrinsics.federated_zip([s1, s5])
  s7 = intrinsics.federated_map(cf.update, s6)
  s8 = s7[0]
  s9 = s7[1]
  ```

  Note that the type signatures for the `initialize` and `next` components of
  an `tff.templates.IterativeProcess` are:

  initialize:  `( -> s1)`
  next:        `(<s1,c1> -> <s8,s9>)`

  However, the `next` component of an `tff.templates.IterativeProcess` has
  been split into a before and after broadcast and a before and after
  aggregate with the given semantics:

  ```
  (arg -> after(<arg, intrinsic(before(arg))>))
  ```

  as a result, the type signatures for the components split from the `next`
  component of an `tff.templates.IterativeProcess` are:

  before_broadcast:  `(<s1,c1> -> s2)`
  after_broadcast:   `(<<s1,c1>,c2> -> <s8,s9>)`
  before_aggregate:
    `(<<s1,c1>,c2> -> <<c5,zero,accumulate,merge,report>,<c6,bitwidth>>)`
  after_aggregate:   `(<<<s1,c1>,c2>,<s3,s4>> -> <s8,s9>)`

  Args:
    initialize_tree: An instance of `building_blocks.ComputationBuildingBlock`
      representing the `initialize` component of an
      `tff.templates.IterativeProcess`.
    before_broadcast: The first result of splitting `next` component of an
      `tff.templates.IterativeProcess` on broadcast.
    after_broadcast: The second result of splitting `next` component of an
      `tff.templates.IterativeProcess` on broadcast.
    before_aggregate: The first result of splitting `next` component of an
      `tff.templates.IterativeProcess` on aggregate.
    after_aggregate: The second result of splitting `next` component of an
      `tff.templates.IterativeProcess` on aggregate.

  Returns:
    A `collections.OrderedDict` mapping names such as `initialize_type`,
    `s1_type`, `prepare_type`, ..., `s9_type` to the corresponding types.

  Raises:
    transformations.CanonicalFormCompilationError: If the arguments are of the
      wrong types.
  """
  function_type = computation_types.FunctionType
  struct_type = computation_types.StructType
  federated_type = computation_types.FederatedType

  # The type signature of `initialize` is: `( -> s1)`.
  init_sig = initialize_tree.type_signature
  _check_type_is_no_arg_fn(init_sig)
  _check_type(init_sig.result, federated_type)
  _check_placement(init_sig.result, placements.SERVER)
  # The named components of canonical form have no placement, so we must
  # remove the placement on the return type of initialize_tree
  initialize_type = function_type(init_sig.parameter, init_sig.result.member)

  # The type signature of `before_broadcast` is: `(<s1,c1> -> s2)`.
  bb_sig = before_broadcast.type_signature
  _check_type(bb_sig, function_type)
  bb_param = bb_sig.parameter
  _check_type(bb_param, struct_type)
  _check_len(bb_param, 2)
  s1_type = bb_param[0]
  _check_type(s1_type, federated_type)
  _check_placement(s1_type, placements.SERVER)
  c1_type = bb_param[1]
  _check_type(c1_type, federated_type)
  _check_placement(c1_type, placements.CLIENTS)
  s2_type = bb_sig.result
  _check_type(s2_type, federated_type)
  _check_placement(s2_type, placements.SERVER)

  prepare_type = function_type(s1_type.member, s2_type.member)

  # The type signature of `after_broadcast` is: `(<<s1,c1>,c2> -> <s8,s9>)'.
  ab_sig = after_broadcast.type_signature
  _check_type(ab_sig, function_type)
  ab_param = ab_sig.parameter
  _check_type(ab_param, struct_type)
  _check_len(ab_param, 2)
  _check_type(ab_param[0], struct_type)
  _check_len(ab_param[0], 2)
  _check_type_equal(ab_param[0][0], s1_type)
  _check_type_equal(ab_param[0][1], c1_type)
  c2_type = ab_param[1]
  _check_type(c2_type, federated_type)
  _check_placement(c2_type, placements.CLIENTS)
  ab_result = ab_sig.result
  _check_type(ab_result, struct_type)
  _check_len(ab_result, 2)
  s8_type = ab_result[0]
  _check_type(s8_type, federated_type)
  _check_placement(s8_type, placements.SERVER)
  s9_type = ab_result[1]
  _check_type(s9_type, federated_type)
  _check_placement(s9_type, placements.SERVER)

  # The type signature of `before_aggregate` is:
  # `(<<s1,c1>,c2> -> <<c5,zero,accumulate,merge,report>,<c6,bitwidth>>)`.
  ba_sig = before_aggregate.type_signature
  _check_type(ba_sig, function_type)
  ba_param = ba_sig.parameter
  _check_type(ba_param, struct_type)
  _check_len(ba_param, 2)
  _check_type(ba_param[0], struct_type)
  _check_len(ba_param[0], 2)
  _check_type_equal(ba_param[0][0], s1_type)
  _check_type_equal(ba_param[0][1], c1_type)
  _check_type_equal(ba_param[1], c2_type)
  ba_result = ba_sig.result
  _check_type(ba_result, struct_type)
  _check_len(ba_result, 2)
  aggregate_args = ba_result[0]
  # Verify this is a struct before measuring it, as is done for every other
  # struct in this function.
  _check_type(aggregate_args, struct_type)
  _check_len(aggregate_args, 5)
  c5_type = aggregate_args[0]
  _check_type(c5_type, federated_type)
  _check_placement(c5_type, placements.CLIENTS)
  zero_type = function_type(None, aggregate_args[1])
  type_analysis.check_tensorflow_compatible_type(zero_type.result)
  accumulate_type = aggregate_args[2]
  _check_type(accumulate_type, function_type)
  merge_type = aggregate_args[3]
  _check_type(merge_type, function_type)
  report_type = aggregate_args[4]
  _check_type(report_type, function_type)
  secure_sum_args = ba_result[1]
  _check_type(secure_sum_args, struct_type)
  _check_len(secure_sum_args, 2)
  c6_type = secure_sum_args[0]
  _check_type(c6_type, federated_type)
  _check_placement(c6_type, placements.CLIENTS)
  bitwidth_type = function_type(None, secure_sum_args[1])
  type_analysis.check_tensorflow_compatible_type(bitwidth_type.result)

  c3_type = federated_type([c1_type.member, c2_type.member],
                           placements.CLIENTS)
  c4_type = federated_type([c5_type.member, c6_type.member],
                           placements.CLIENTS)

  # The type signature of `after_aggregate` is:
  # `(<<<s1,c1>,c2>,<s3,s4>> -> <s8,s9>)'.
  aa_sig = after_aggregate.type_signature
  _check_type(aa_sig, function_type)
  aa_param = aa_sig.parameter
  _check_type(aa_param, struct_type)
  _check_len(aa_param, 2)
  _check_type(aa_param[0], struct_type)
  _check_len(aa_param[0], 2)
  _check_type(aa_param[0][0], struct_type)
  _check_len(aa_param[0][0], 2)
  _check_type_equal(aa_param[0][0][0], s1_type)
  _check_type_equal(aa_param[0][0][1], c1_type)
  _check_type_equal(aa_param[0][1], c2_type)
  aggregate_results = aa_param[1]
  # Verify this is a struct before measuring and indexing it.
  _check_type(aggregate_results, struct_type)
  _check_len(aggregate_results, 2)
  s3_type = aggregate_results[0]
  _check_type(s3_type, federated_type)
  _check_placement(s3_type, placements.SERVER)
  s4_type = aggregate_results[1]
  _check_type(s4_type, federated_type)
  _check_placement(s4_type, placements.SERVER)
  aa_result = aa_sig.result
  # Verify this is a struct before measuring and indexing it.
  _check_type(aa_result, struct_type)
  _check_len(aa_result, 2)
  _check_type_equal(aa_result[0], s8_type)
  _check_type_equal(aa_result[1], s9_type)

  work_type = function_type(c3_type.member, c4_type.member)
  s5_type = federated_type([s3_type.member, s4_type.member],
                           placements.SERVER)
  s6_type = federated_type([s1_type.member, s5_type.member],
                           placements.SERVER)
  s7_type = federated_type([s8_type.member, s9_type.member],
                           placements.SERVER)
  update_type = function_type(s6_type.member, s7_type.member)

  return collections.OrderedDict(
      initialize_type=initialize_type,
      s1_type=s1_type,
      c1_type=c1_type,
      prepare_type=prepare_type,
      s2_type=s2_type,
      c2_type=c2_type,
      c3_type=c3_type,
      work_type=work_type,
      c4_type=c4_type,
      c5_type=c5_type,
      c6_type=c6_type,
      zero_type=zero_type,
      accumulate_type=accumulate_type,
      merge_type=merge_type,
      report_type=report_type,
      s3_type=s3_type,
      bitwidth_type=bitwidth_type,
      s4_type=s4_type,
      s5_type=s5_type,
      s6_type=s6_type,
      update_type=update_type,
      s7_type=s7_type,
      s8_type=s8_type,
      s9_type=s9_type,
  )
def _deserialize_dataset_from_graph_def(serialized_graph_def: bytes,
                                        element_type: computation_types.Type):
  """Rebuilds a `tf.data.Dataset` from a serialized `tf.compat.v1.GraphDef`.

  Args:
    serialized_graph_def: `bytes` object produced by
      `tensorflow_serialization.serialize_dataset`
    element_type: a `tff.Type` object representing the type structure of the
      elements yielded from the dataset.

  Returns:
    A `tf.data.Dataset` instance.

  Raises:
    TypeError: If a struct inside `element_type` has only some of its fields
      named.
  """
  py_typecheck.check_type(element_type, computation_types.Type)
  type_analysis.check_tensorflow_compatible_type(element_type)

  def transform_to_tff_known_type(
      type_spec: computation_types.Type
  ) -> Tuple[computation_types.Type, bool]:
    """Transforms `StructType` to `StructWithPythonType`."""
    if not type_spec.is_struct() or type_spec.is_struct_with_python():
      return type_spec, False

    named_flags = [
        name is not None for name, _ in structure.iter_elements(type_spec)
    ]
    if all(named_flags):
      container = collections.OrderedDict
    elif not any(named_flags):
      container = tuple
    else:
      raise TypeError(
          'Cannot represent TFF type in TF because it contains '
          f'partially named structures. Type: {type_spec}')
    converted = computation_types.StructWithPythonType(
        elements=structure.iter_elements(type_spec), container_type=container)
    return converted, True

  if element_type.is_struct():
    # TF doesn't support `structure.Struct` types, so we must transform the
    # `StructType` into a `StructWithPythonType` for use as the
    # `tf.data.Dataset.element_spec` later.
    tf_compatible_type, _ = type_transformations.transform_type_postorder(
        element_type, transform_to_tff_known_type)
  else:
    # We've checked this is only a struct or tensors, so we know this is a
    # `TensorType` here and will use as-is.
    tf_compatible_type = element_type

  def type_to_tensorspec(t: computation_types.TensorType) -> tf.TensorSpec:
    return tf.TensorSpec(shape=t.shape, dtype=t.dtype)

  element_spec = type_conversions.structure_from_tensor_type_tree(
      type_to_tensorspec, tf_compatible_type)
  variant = tf.raw_ops.DatasetFromGraph(graph_def=serialized_graph_def)
  dataset = tf.data.experimental.from_variant(variant, structure=element_spec)
  # If a serialized dataset had elements of nested structures of tensors (e.g.
  # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
  # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
  #
  # Since the dataset will only be used inside TFF, we wrap the dictionary
  # coming from TF in an `OrderedDict` when necessary (a type that both TF and
  # TFF understand), using the field order stored in the TFF type stored during
  # serialization.
  return tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      dataset, tf_compatible_type)
def create_binary_operator_with_upcast(
    type_signature: computation_types.StructType,
    operator: Callable[[Any, Any], Any]) -> ComputationProtoAndType:
  """Creates a TF computation that upcasts its argument and applies `operator`.

  Args:
    type_signature: A `computation_types.StructType` with two elements, both
      only containing structs or tensors in their type tree. The first and
      second element must match in structure, or the second element may be a
      single tensor type that is broadcasted (upcast) to the leaves of the
      structure of the first type.
    operator: Callable defining the operator.

  Returns:
    Same as `create_binary_operator()`.

  Raises:
    TypeError: If `type_signature` is not a two-element `StructType`, contains
      a sequence type, has two struct elements of differing structure, or its
      first element is neither a tensor nor a struct.
  """
  py_typecheck.check_type(type_signature, computation_types.StructType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if not (type_signature.is_struct() and len(type_signature) == 2):
    raise TypeError(
        'To apply a binary operator, we must by definition have an '
        'argument which is a `StructType` with 2 elements; '
        'asked to create a binary operator for type: {t}'.format(
            t=type_signature))

  def _is_sequence(t):
    return t.is_sequence()

  if type_analysis.contains(type_signature, _is_sequence):
    raise TypeError('Applying binary operators in TensorFlow is only '
                    'supported on Tensors and StructTypes; you '
                    'passed {t} which contains a SequenceType.'.format(
                        t=type_signature))

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if type_spec.is_struct():
      packed = []
      for elem_name, elem_type in structure.iter_elements(type_spec):
        packed.append((elem_name, _pack_into_type(to_pack, elem_type)))
      return structure.Struct(packed)
    if type_spec.is_tensor():
      return tf.broadcast_to(to_pack, type_spec.shape)
    return None

  first_type = type_signature[0]
  second_type = type_signature[1]

  with tf.Graph().as_default() as graph:
    first_arg, first_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', first_type, graph)
    second_value, second_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', second_type, graph)

    if first_type.is_struct() and second_type.is_struct():
      # When both operands are structs they must share the same structure;
      # the second operand is then used as-is, and `map_structure` below
      # applies the binary operator pointwise to the leaves.
      if not structure.is_same_structure(first_type, second_type):
        raise TypeError(
            'Cannot upcast one structure to a different structure. '
            '{x} -> {y}'.format(x=type_signature[1], y=type_signature[0]))
      second_arg = second_value
    elif first_type.is_equivalent_to(second_type):
      second_arg = second_value
    else:
      # Broadcast the second operand into the shape of the first.
      second_arg = _pack_into_type(second_value, first_type)

    if first_type.is_tensor():
      result_value = operator(first_arg, second_arg)
    elif first_type.is_struct():
      result_value = structure.map_structure(operator, first_arg, second_arg)
    else:
      raise TypeError(
          'Encountered unexpected type {t}; can only handle Tensor '
          'and StructTypes.'.format(t=type_signature[0]))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  fn_type = computation_types.FunctionType(type_signature, result_type)
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[first_binding, second_binding]))
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  tf_proto = pb.TensorFlow(
      graph_def=packed_graph,
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tf_proto, fn_type)