def type_to_tf_structure(type_spec: computation_types.Type):
  """Returns nested `tf.data.experimental.Structure` for a given TFF type.

  Args:
    type_spec: A `computation_types.Type`. The type specification must be
      composed only of named tuples and tensors. Within each tuple that
      appears in the type spec, the elements must be either all named or all
      unnamed; empty tuples are not supported.

  Returns:
    An instance of `tf.data.experimental.Structure`, possibly nested, that
    corresponds to `type_spec`. Tensors become `tf.TensorSpec`s. Tuples that
    do not carry a Python container become a `collections.OrderedDict` when
    their elements are named and a plain `tuple` otherwise; tuples that carry
    a Python container are rebuilt using that container type.

  Raises:
    ValueError: if the `type_spec` is composed of something other than named
      tuples and tensors, if it contains an empty tuple, or if the elements of
      any tuple are inconsistently named (some named and some unnamed).
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return tf.TensorSpec(type_spec.shape, type_spec.dtype)
  elif type_spec.is_tuple():
    elements = anonymous_tuple.to_elements(type_spec)
    if not elements:
      raise ValueError('Empty tuples are unsupported.')
    element_outputs = [(k, type_to_tf_structure(v)) for k, v in elements]
    # The first element decides whether this tuple is named; all others must
    # agree with it.
    named = element_outputs[0][0] is not None
    if not all((k is not None) == named for k, _ in element_outputs):
      raise ValueError('Tuple elements inconsistently named.')
    if not type_spec.is_tuple_with_py_container():
      if named:
        return collections.OrderedDict(element_outputs)
      return tuple(v for _, v in element_outputs)
    container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
        type_spec)
    if (py_typecheck.is_named_tuple(container_type) or
        py_typecheck.is_attrs(container_type)):
      # namedtuple / attrs classes are constructed from keyword arguments.
      return container_type(**dict(element_outputs))
    elif named:
      # The container is constructed from the list of (name, value) pairs.
      return container_type(element_outputs)
    else:
      # Every element is unnamed here, so only the values are passed on.
      return container_type(v for _, v in element_outputs)
  else:
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children are transformed first. If any child reports a mutation, the parent
  type is rebuilt from the transformed children before `transform_fn` is
  applied to the parent itself.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Transformation function to apply to each node in the type
      tree of `type_signature`. Must be callable; it receives a
      `computation_types.Type` and must return a pair
      `(transformed_type, mutated)`.

  Returns:
    A pair `(transformed_type, mutated)`. `transformed_type` is a possibly
    transformed version of `type_signature`, with each node in its tree the
    result of applying `transform_fn` to the corresponding node in
    `type_signature`. `mutated` is truthy if `transform_fn` reported a
    mutation for any node in the tree.

  Raises:
    TypeError: If the types don't match the specification above, or if
      `type_signature` is not a federated, sequence, function, tuple,
      abstract, placement or tensor type.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)
  if type_signature.is_federated():
    transformed_member, member_mutated = transform_type_postorder(
        type_signature.member, transform_fn)
    if member_mutated:
      type_signature = computation_types.FederatedType(
          transformed_member, type_signature.placement,
          type_signature.all_equal)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or member_mutated
  elif type_signature.is_sequence():
    transformed_element, element_mutated = transform_type_postorder(
        type_signature.element, transform_fn)
    if element_mutated:
      type_signature = computation_types.SequenceType(transformed_element)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or element_mutated
  elif type_signature.is_function():
    # A function may have no parameter; treat that as "not mutated".
    if type_signature.parameter is not None:
      transformed_parameter, parameter_mutated = transform_type_postorder(
          type_signature.parameter, transform_fn)
    else:
      transformed_parameter, parameter_mutated = (None, False)
    transformed_result, result_mutated = transform_type_postorder(
        type_signature.result, transform_fn)
    if parameter_mutated or result_mutated:
      type_signature = computation_types.FunctionType(transformed_parameter,
                                                      transformed_result)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, (type_signature_mutated or parameter_mutated or
                            result_mutated)
  elif type_signature.is_tuple():
    elements = []
    elements_mutated = False
    for name, element_type in anonymous_tuple.iter_elements(type_signature):
      transformed_element, element_mutated = transform_type_postorder(
          element_type, transform_fn)
      elements_mutated = elements_mutated or element_mutated
      elements.append((name, transformed_element))
    if elements_mutated:
      # Preserve the Python container (if any) when rebuilding the tuple.
      if type_signature.is_tuple_with_py_container():
        container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
            type_signature)
        type_signature = computation_types.NamedTupleTypeWithPyContainerType(
            elements, container_type)
      else:
        type_signature = computation_types.NamedTupleType(elements)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or elements_mutated
  elif (type_signature.is_abstract() or type_signature.is_placement() or
        type_signature.is_tensor()):
    # These types are not recursed into; transform_fn is applied directly.
    return transform_fn(type_signature)
  else:
    # Previously this case fell through and returned None, which only failed
    # later when a caller tried to unpack the result.
    raise TypeError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_signature))))
def type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type):
  """Returns nested structures of tensor dtypes and shapes for a given TFF type.

  The returned dtypes and shapes match those used by `tf.data.Dataset`s to
  indicate the type and shape of their elements. They can be used, e.g., as
  arguments in constructing an iterator over a string handle.

  Args:
    type_spec: A `computation_types.Type`. The type specification must be
      composed of only named tuples and tensors. Within each tuple that
      appears in the type spec, the elements must be either all named or all
      unnamed.

  Returns:
    A pair of parallel nested structures with the dtypes and shapes of tensors
    defined in `type_spec`. The layout of the two structures returned is the
    same as the layout of the nested type defined by `type_spec`. Tuples that
    carry a Python container are rebuilt with that container. Otherwise, named
    tuples are represented as `collections.OrderedDict`s and unnamed (or
    empty) tuples as Python `tuple`s.

  Raises:
    ValueError: if the `type_spec` is composed of something other than named
      tuples and tensors, or if the elements of any tuple are inconsistently
      named.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return (type_spec.dtype, type_spec.shape)
  elif type_spec.is_tuple():
    elements = anonymous_tuple.to_elements(type_spec)
    # The first element decides whether the tuple is treated as named
    # (OrderedDict accumulators) or unnamed (list accumulators); every
    # remaining element must agree with it.
    named = bool(elements) and elements[0][0] is not None
    if named:
      output_dtypes = collections.OrderedDict()
      output_shapes = collections.OrderedDict()
    else:
      output_dtypes = []
      output_shapes = []
    for element_name, element_spec in elements:
      if (element_name is not None) != named:
        raise ValueError(
            'When a sequence appears as a part of a parameter to a section '
            'of TensorFlow code, in the type signature of elements of that '
            'sequence all named tuples must have their elements explicitly '
            'named, and this does not appear to be the case in {}.'
            .format(type_spec))
      element_dtypes, element_shapes = type_to_tf_dtypes_and_shapes(
          element_spec)
      if named:
        output_dtypes[element_name] = element_dtypes
        output_shapes[element_name] = element_shapes
      else:
        output_dtypes.append(element_dtypes)
        output_shapes.append(element_shapes)
    if type_spec.is_tuple_with_py_container():
      container_type = computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
          type_spec)

      def build_py_container(items):
        # namedtuple / attrs classes are constructed from keyword arguments.
        if (py_typecheck.is_named_tuple(container_type) or
            py_typecheck.is_attrs(container_type)):
          return container_type(**dict(items))
        else:
          return container_type(items)

      output_dtypes = build_py_container(output_dtypes)
      output_shapes = build_py_container(output_shapes)
    elif not named:
      output_dtypes = tuple(output_dtypes)
      output_shapes = tuple(output_shapes)
    # A named tuple without a Python container keeps its OrderedDicts. Calling
    # tuple() on them would keep only the element names and lose the
    # dtypes/shapes.
    return (output_dtypes, output_shapes)
  else:
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))