def _create_structure_of_coro_references(
    coro: Coroutine[Any, Any, Any],
    type_signature: computation_types.Type) -> Any:
  """Returns a structure of `tff.program.CoroValueReference`s.

  Struct types are expanded element by element: the coroutine is awaited
  through a single `async_utils.SharedAwaitable`, and every element gets its
  own coroutine that indexes into that shared result. Server-placed federated
  types are unwrapped to their member type. Sequence and tensor types become a
  single `CoroValueReference`.

  Args:
    coro: The coroutine producing the value described by `type_signature`.
    type_signature: The `computation_types.Type` of the value.

  Raises:
    NotImplementedError: If `type_signature` is none of the kinds above.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)

  if type_signature.is_struct():

    async def _as_struct(
        pending: Coroutine[Any, Any, Any]) -> structure.Struct:
      return structure.from_container(await pending)

    async def _element_at(source: Awaitable[structure.Struct],
                          position: int) -> Any:
      resolved = await source
      return resolved[position]

    # All element coroutines await the same shared awaitable.
    shared = async_utils.SharedAwaitable(_as_struct(coro))
    return structure.Struct([
        (name,
         _create_structure_of_coro_references(
             _element_at(shared, position), element_type))
        for position, (name, element_type) in enumerate(
            structure.iter_elements(type_signature))
    ])

  if (type_signature.is_federated() and
      type_signature.placement == placements.SERVER):
    return _create_structure_of_coro_references(coro, type_signature.member)

  if type_signature.is_sequence() or type_signature.is_tensor():
    return CoroValueReference(coro, type_signature)

  raise NotImplementedError(f'Unexpected type found: {type_signature}.')
def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Whether or not `bitwidth_type` is a valid bitwidth type for `value_type`.

  A tensor `bitwidth_type` is valid when it is an integer tensor with exactly
  one element. A struct `bitwidth_type` is valid only against a struct
  `value_type` with the same length and element names, and with every element
  pair valid recursively. Anything else is invalid.

  Args:
    bitwidth_type: The `computation_types.Type` of the bitwidth.
    value_type: The `computation_types.Type` of the value.

  Returns:
    `True` if the pairing is valid, `False` otherwise.
  """
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if bitwidth_type.is_tensor():
    # This holds whether `value_type` is a tensor or a structure, as the same
    # integer bitwidth can be used for all values in the structure.
    return bitwidth_type.dtype.is_integer and (
        bitwidth_type.shape.num_elements() == 1)

  if not (value_type.is_struct() and bitwidth_type.is_struct()):
    return False

  bitwidth_elements = list(structure.iter_elements(bitwidth_type))
  value_elements = list(structure.iter_elements(value_type))
  if len(bitwidth_elements) != len(value_elements):
    return False

  for (bitwidth_name, bitwidth_elem), (value_name, value_elem) in zip(
      bitwidth_elements, value_elements):
    if bitwidth_name != value_name:
      return False
    if not is_valid_bitwidth_type_for_value_type(bitwidth_elem, value_elem):
      return False
  return True
def is_single_integer_or_matches_structure(
    type_sig: computation_types.Type,
    shape_type: computation_types.Type) -> bool:
  """If `type_sig` is an integer or integer structure matching `shape_type`.

  A tensor `type_sig` matches when it is an integer tensor with exactly one
  element. A struct `type_sig` matches only a struct `shape_type` with the
  same length and element names, and with every element pair matching
  recursively. Anything else does not match.

  Args:
    type_sig: The `computation_types.Type` to check.
    shape_type: The `computation_types.Type` whose structure `type_sig` is
      compared against.

  Returns:
    `True` if `type_sig` matches, `False` otherwise.
  """
  py_typecheck.check_type(type_sig, computation_types.Type)
  py_typecheck.check_type(shape_type, computation_types.Type)

  if type_sig.is_tensor():
    # This holds whether `shape_type` is a tensor or a structure, as the same
    # single integer can be used for all values in the structure.
    return type_sig.dtype.is_integer and (type_sig.shape.num_elements() == 1)
  elif shape_type.is_struct() and type_sig.is_struct():
    type_sig_elements = list(structure.iter_elements(type_sig))
    shape_elements = list(structure.iter_elements(shape_type))
    # Compare the two lists that are zipped below, so the length check and
    # the iteration cannot disagree.
    if len(type_sig_elements) != len(shape_elements):
      return False
    # The loop targets use distinct names so the `type_sig` parameter is not
    # rebound while iterating.
    for (inner_name, inner_type_sig), (inner_shape_name,
                                       inner_shape_type) in zip(
                                           type_sig_elements, shape_elements):
      if inner_name != inner_shape_name:
        return False
      if not is_single_integer_or_matches_structure(inner_type_sig,
                                                    inner_shape_type):
        return False
    return True
  else:
    return False
async def _materialize_structure_of_value_references(
    value: Any, type_signature: computation_types.Type) -> Any:
  """Returns a structure of materialized values.

  Struct types are materialized element by element, concurrently via
  `asyncio.gather`, and reassembled with the element names taken from
  `type_signature`. Server-placed federated types are unwrapped to their
  member type. For sequence and tensor types, a
  `value_reference.MaterializableValueReference` is resolved with
  `get_value()`; any other value is returned unchanged.

  Args:
    value: The value, or structure of values, to materialize.
    type_signature: The `computation_types.Type` of `value`.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)

  async def _resolve_leaf(leaf: Any) -> Any:
    if isinstance(leaf, value_reference.MaterializableValueReference):
      return await leaf.get_value()
    return leaf

  if type_signature.is_struct():
    struct_value = structure.from_container(value)
    named_types = list(structure.iter_elements(type_signature))
    pending = [
        _materialize_structure_of_value_references(item, item_type)
        for item, (_, item_type) in zip(struct_value, named_types)
    ]
    materialized = await asyncio.gather(*pending)
    return structure.Struct([
        (name, item) for item, (name, _) in zip(materialized, named_types)
    ])

  if (type_signature.is_federated() and
      type_signature.placement == placements.SERVER):
    return await _materialize_structure_of_value_references(
        value, type_signature.member)

  if type_signature.is_sequence() or type_signature.is_tensor():
    return await _resolve_leaf(value)

  return value
def _repackage_partitioned_values(after_merge_results,
                                  result_type_spec: computation_types.Type):
  """Inverts `_split_value_into_subrounds` above.

  Args:
    after_merge_results: A `list` with one result per subround.
    result_type_spec: The `computation_types.Type` of the repackaged result.

  Returns:
    For struct types, a `structure.Struct` whose elements are repackaged
    recursively. For clients-placed federated types that are not `all_equal`,
    the per-subround values concatenated with `+`. Otherwise the first entry
    of `after_merge_results`.
  """
  py_typecheck.check_type(after_merge_results, list)

  if result_type_spec.is_struct():
    as_structs = [structure.from_container(r) for r in after_merge_results]
    return structure.Struct([
        (name,
         _repackage_partitioned_values([s[position] for s in as_structs],
                                       elem_type))
        for position, (name, elem_type) in enumerate(
            structure.iter_elements(result_type_spec))
    ])

  clients_placed = (
      result_type_spec.is_federated() and
      result_type_spec.placement.is_clients())
  if not clients_placed or result_type_spec.all_equal:
    return after_merge_results[0]

  for partial_result in after_merge_results:
    py_typecheck.check_type(partial_result, (list, tuple))
  # Merges all clients-placed values back together.
  return functools.reduce(lambda merged, part: merged + part,
                          after_merge_results)
def create_identity(
    type_signature: computation_types.Type) -> ComputationProtoAndType:
  """Returns a tensorflow computation representing an identity function.

  The returned computation has the type signature `(T -> T)`, where `T` is
  `type_signature`. NOTE: if `T` contains `computation_types.StructType`s
  without an associated container type, they will be given the container type
  `tuple` by this function.

  Args:
    type_signature: A `computation_types.Type` to use as the parameter type and
      result type of the identity function.

  Returns:
    The result of `create_computation_for_py_fn` applied to an identity
    function and `type_signature`.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings.
    NotImplementedError: If `type_signature` is not a tensor, sequence or
      struct type.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  parameter_type = type_signature
  if parameter_type is None:
    raise TypeError('TensorFlow identity cannot be created for NoneType.')

  # TF relies on feeds not-identical to fetches in certain circumstances.
  if type_signature.is_tensor() or type_signature.is_sequence():
    return create_computation_for_py_fn(tf.identity, parameter_type)
  if type_signature.is_struct():
    map_identity = functools.partial(structure.map_structure, tf.identity)
    return create_computation_for_py_fn(map_identity, parameter_type)
  raise NotImplementedError(
      f'TensorFlow identity cannot be created for type {type_signature}')
def is_allowed_client_data_type(
    type_spec: computation_types.Type) -> bool:
  """Returns whether `type_spec` is an allowed client data type.

  A sequence type is allowed when its element type satisfies
  `type_analysis.is_tensorflow_compatible_type`. A struct type is allowed when
  all of its children are allowed. Every other type is rejected.

  Args:
    type_spec: The `computation_types.Type` to check.
  """
  if type_spec.is_sequence():
    return type_analysis.is_tensorflow_compatible_type(type_spec.element)
  if type_spec.is_struct():
    for element_type in type_spec.children():
      if not is_allowed_client_data_type(element_type):
        return False
    return True
  return False
def ingest_value(
    self, value: Any, type_signature: computation_types.Type
) -> executor_value_base.ExecutorValue:
  """Wraps `value` in a `FederatedResolvingStrategyValue`.

  When `type_signature` is federated, or is a function type with a federated
  result, its placement is first passed to
  `_check_strategy_compatible_with_placement`.

  Args:
    value: The value to ingest.
    type_signature: The `computation_types.Type` of `value`, or `None`.
  """
  if type_signature is not None:
    federated_type = None
    if type_signature.is_federated():
      federated_type = type_signature
    elif (type_signature.is_function() and
          type_signature.result.is_federated()):
      federated_type = type_signature.result
    if federated_type is not None:
      self._check_strategy_compatible_with_placement(
          federated_type.placement)
  return FederatedResolvingStrategyValue(value, type_signature)
def _to_sequence_internal_rep(
    *, value: Any, type_spec: computation_types.Type) -> tf.data.Dataset:
  """Ingests `value`, converting to an eager dataset.

  Args:
    value: A `list` of elements, one of
      `type_conversions.TF_DATASET_REPRESENTATION_TYPES`, or any other object
      that is handed to `tf.data.Dataset.from_generator`.
    type_spec: The `computation_types.Type` the resulting dataset must have.

  Returns:
    The dataset representation of `value`.
  """
  if isinstance(value, list):
    value = tensorflow_utils.make_data_set_from_elements(
        None, value, type_spec.element)

  if not isinstance(value, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
    # Not a dataset (and not a list): build the dataset from a generator.
    py_typecheck.check_type(type_spec, computation_types.SequenceType)
    element_specs = type_conversions.type_to_tf_tensor_specs(
        type_spec.element)
    return tf.data.Dataset.from_generator(
        value, output_signature=element_specs)

  dataset_type = computation_types.SequenceType(
      computation_types.to_type(value.element_spec))
  type_spec.check_assignable_from(dataset_type)
  return value
def _ensure_deserialized_types_compatible(
    previous_type: Optional[computation_types.Type],
    next_type: computation_types.Type) -> computation_types.Type:
  """Ensures one of `previous_type` or `next_type` is assignable to the other.

  Args:
    previous_type: Instance of `computation_types.Type` or `None`.
    next_type: Instance of `computation_types.Type`.

  Returns:
    `next_type` if `previous_type` is `None`; otherwise whichever of the two
    types is assignable from the other, with `next_type` tried first.

  Raises:
    TypeError: If neither type is assignable from the other.
  """
  if previous_type is None:
    return next_type

  candidates = ((next_type, previous_type), (previous_type, next_type))
  for supertype, subtype in candidates:
    if supertype.is_assignable_from(subtype):
      return supertype

  raise TypeError(
      'Type mismatch checking member assignability under a '
      'federated value. Deserialized type {} is incompatible '
      'with previously deserialized {}.'.format(next_type, previous_type))
def _stamp_value_into_graph(value: Any,
                            type_signature: computation_types.Type,
                            graph: tf.Graph) -> Any:
  """Stamps `value` in `graph` as an object of type `type_signature`.

  Args:
    value: A value to stamp.
    type_signature: The type of the value to stamp.
    graph: The graph to stamp in.

  Returns:
    A Python object made of tensors stamped into `graph`, `tf.data.Dataset`s,
    or `structure.Struct`s that structurally corresponds to the value passed at
    input. `None` is returned for a `None` value.

  Raises:
    NotImplementedError: If `type_signature` is not a tensor, struct or
      sequence type.
  """
  if value is None:
    return None

  if type_signature.is_tensor():
    if isinstance(value, np.ndarray) or tf.is_tensor(value):
      # Array-like values carry their own dtype and shape; verify them against
      # the requested type and stamp the value as-is.
      actual_type = computation_types.TensorType(
          tf.dtypes.as_dtype(value.dtype), tf.TensorShape(value.shape))
      type_signature.check_assignable_from(actual_type)
      constant_kwargs = {}
    else:
      constant_kwargs = {
          'dtype': type_signature.dtype,
          'shape': type_signature.shape,
      }
    with graph.as_default():
      return tf.constant(value, **constant_kwargs)

  if type_signature.is_struct():
    if isinstance(value, (list, dict)):
      value = structure.from_container(value)
    named_element_types = structure.to_elements(type_signature)
    return structure.Struct([
        (name, _stamp_value_into_graph(element, element_type, graph))
        for (name, element_type), element in zip(named_element_types, value)
    ])

  if type_signature.is_sequence():
    return tensorflow_utils.make_data_set_from_elements(
        graph, value, type_signature.element)

  raise NotImplementedError(
      'Unable to stamp a value of type {} in graph.'.format(
          type_signature))
def type_to_tf_structure(type_spec: computation_types.Type):
  """Returns nested `tf.data.experimental.Structure` for a given TFF type.

  Args:
    type_spec: A `computation_types.Type` composed only of structs and
      tensors. Within each struct the elements must be either all named or all
      unnamed.

  Returns:
    A `tf.TensorSpec` for a tensor type. For a struct type, the converted
    elements packed into the struct's Python container if it has one, or
    otherwise into a `collections.OrderedDict` (named elements) or a `tuple`
    (unnamed elements).

  Raises:
    ValueError: If `type_spec` is neither a tensor nor a struct, if a struct is
      empty, or if the elements of a struct are inconsistently named.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    return tf.TensorSpec(type_spec.shape, type_spec.dtype)
  if not type_spec.is_struct():
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))

  elements = structure.to_elements(type_spec)
  if not elements:
    raise ValueError('Empty tuples are unsupported.')

  converted = [(name, type_to_tf_structure(elem)) for name, elem in elements]
  # The first element decides whether the struct counts as named; every other
  # element must agree with it.
  first_is_named = converted[0][0] is not None
  if any((entry[0] is not None) != first_is_named for entry in converted):
    raise ValueError('Tuple elements inconsistently named.')

  container_type = type_spec.python_container
  if container_type is None:
    if first_is_named:
      return collections.OrderedDict(converted)
    return tuple(spec for _, spec in converted)

  if (py_typecheck.is_named_tuple(container_type) or
      py_typecheck.is_attrs(container_type)):
    return container_type(**dict(converted))
  if first_is_named:
    return container_type(converted)
  return container_type(
      entry if entry[0] is not None else entry[1] for entry in converted)
def _pack_into_type(to_pack: tf.Tensor, type_spec: computation_types.Type):
  """Pack Tensor value `to_pack` into the nested structure `type_spec`.

  Args:
    to_pack: The tensor to pack.
    type_spec: The `computation_types.Type` to pack `to_pack` into.

  Returns:
    For a struct type, a `structure.Struct` with `to_pack` packed into every
    element. For a tensor type, `to_pack` itself when its type is assignable
    to `type_spec`, otherwise `to_pack` broadcast to the shape of `type_spec`
    and cast to its dtype. `None` for any other kind of type.

  Raises:
    TypeError: If a tensor type is not assignable from the type of `to_pack`
      and its shape is not fully defined.
  """
  if type_spec.is_struct():
    packed_elements = []
    for elem_name, elem_type in structure.iter_elements(type_spec):
      packed_elements.append((elem_name, _pack_into_type(to_pack, elem_type)))
    return structure.Struct(packed_elements)

  if not type_spec.is_tensor():
    # Neither struct nor tensor: nothing is produced for this type.
    return None

  value_tensor_type = type_conversions.type_from_tensors(to_pack)
  if type_spec.is_assignable_from(value_tensor_type):
    return to_pack
  if not type_spec.shape.is_fully_defined():
    raise TypeError('Cannot generate TensorFlow creating binary operator '
                    'with first type not assignable from second, and '
                    'first type without fully defined shapes. First '
                    f'type contains an element of type: {type_spec}.\n'
                    f'Packing value {to_pack} into this type is '
                    'undefined.')
  broadcasted = tf.broadcast_to(to_pack, type_spec.shape)
  return tf.cast(broadcasted, type_spec.dtype)
def _check_type_is_fn(
    target: computation_types.Type,
    name: str,
    err_fn: Callable[[str],
                     Exception] = compiler.MapReduceFormCompilationError,
):
  """Raises an error unless `target` is a function type.

  Args:
    target: The `computation_types.Type` to check.
    name: The name used for `target` in the error message.
    err_fn: Called with the error message to build the exception to raise.
  """
  if target.is_function():
    return
  raise err_fn(f'Expected {name} to be a function, but {name} had type '
               f'{target}.')
def _partition_value(
    val: _PartitioningValue,
    type_signature: computation_types.Type) -> _PartitioningValue:
  """Partitions value as specified in _split_value_into_subrounds.

  Args:
    val: A `_PartitioningValue` carrying the payload to partition together
      with the counters `num_remaining_clients`, `num_remaining_partitions`
      and `last_client_index`.
    type_signature: The `computation_types.Type` of `val.payload`.

  Returns:
    A `_PartitioningValue` holding the payload for the current subround and
    the updated counters. For clients-placed values that are not `all_equal`,
    the payload is the next slice of clients. Struct payloads are partitioned
    element by element. Every other value is returned unchanged.
  """
  if type_signature.is_struct():
    struct_val = structure.from_container(val.payload)
    result_container = []
    # Seed with the incoming counters so that an empty struct, where the loop
    # below never runs, returns them unchanged instead of reading an unbound
    # local.
    partition_result = val
    for (_, val_elem), (name, type_elem) in zip(
        structure.iter_elements(struct_val),
        structure.iter_elements(type_signature)):
      # Every element is partitioned from the same incoming counters.
      partitioning_val_elem = _PartitioningValue(
          val_elem, val.num_remaining_clients, val.num_remaining_partitions,
          val.last_client_index)
      partition_result = _partition_value(partitioning_val_elem, type_elem)
      result_container.append((name, partition_result.payload))
    # NOTE(review): the counters returned are those of the last element only;
    # confirm this is intended when the last element is not clients-placed.
    return _PartitioningValue(
        structure.Struct(result_container),
        partition_result.num_remaining_clients,
        partition_result.num_remaining_partitions,
        partition_result.last_client_index)
  elif (type_signature.is_federated() and
        type_signature.placement.is_clients()):
    if type_signature.all_equal:
      # In this case we simply replicate the argument for every subround.
      return val

    py_typecheck.check_type(val.payload, Sequence)
    # Spread the remaining clients over the remaining partitions, rounding up.
    num_clients_for_subround = math.ceil(val.num_remaining_clients /
                                         val.num_remaining_partitions)
    num_remaining_clients = val.num_remaining_clients - num_clients_for_subround
    num_remaining_partitions = val.num_remaining_partitions - 1
    last_client_index = val.last_client_index + num_clients_for_subround
    values_to_return = val.payload[val.last_client_index:last_client_index]
    return _PartitioningValue(
        payload=values_to_return,
        num_remaining_clients=num_remaining_clients,
        num_remaining_partitions=num_remaining_partitions,
        last_client_index=last_client_index)
  else:
    return val
def transform_to_tff_known_type(
    type_spec: computation_types.Type
) -> Tuple[computation_types.Type, bool]:
  """Transforms `StructType` to `StructWithPythonType`.

  Args:
    type_spec: The `computation_types.Type` to transform.

  Returns:
    A `(type, mutated)` pair. A struct without a Python container becomes a
    `StructWithPythonType` with container `collections.OrderedDict` when all
    of its elements are named, or `tuple` when none are, and `mutated` is
    `True`. Any other type is returned unchanged with `mutated` `False`.

  Raises:
    TypeError: If only some of the struct's elements are named.
  """
  if not type_spec.is_struct() or type_spec.is_struct_with_python():
    return type_spec, False

  names_present = [
      name is not None for name, _ in structure.iter_elements(type_spec)
  ]
  if all(names_present):
    container_type = collections.OrderedDict
  elif not any(names_present):
    container_type = tuple
  else:
    raise TypeError(
        'Cannot represent TFF type in TF because it contains '
        f'partially named structures. Type: {type_spec}')

  return computation_types.StructWithPythonType(
      elements=structure.iter_elements(type_spec),
      container_type=container_type), True
def _unique_dtypes_in_structure(
    type_spec: computation_types.Type) -> Set[tf.DType]:
  """Returns a set of unique dtypes in `type_spec`.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    A `set` containing unique dtypes found in `type_spec`. Federated types are
    unwrapped to their member. The set is empty for types that are not tensor,
    struct or federated.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    py_typecheck.check_type(type_spec.dtype, tf.DType)
    return {type_spec.dtype}
  if type_spec.is_struct():
    dtype_tree = type_conversions.structure_from_tensor_type_tree(
        lambda t: t.dtype, type_spec)
    return set(tf.nest.flatten(dtype_tree))
  if type_spec.is_federated():
    return _unique_dtypes_in_structure(type_spec.member)
  return set()
def is_binary_op_with_upcast_compatible_pair(
    possibly_nested_type: Optional[computation_types.Type],
    type_to_upcast: computation_types.Type) -> bool:
  """Checks unambiguity in applying `type_to_upcast` to `possibly_nested_type`.

  That is, checks that either these types are equivalent and contain only
  tuples and tensors, or that `possibly_nested_type` is perhaps a nested
  structure containing only tensors with `dtype` of `type_to_upcast` at the
  leaves, where `type_to_upcast` must be a scalar tensor type. Notice that
  this relationship is not symmetric, since binary operators need not respect
  this symmetry in general. For example, it makes perfect sense to divide a
  nested structure of tensors by a scalar, but not the other way around.

  Args:
    possibly_nested_type: A `computation_types.Type`, or `None`.
    type_to_upcast: A `computation_types.Type`, or `None`.

  Returns:
    Boolean indicating whether `type_to_upcast` can be upcast to
    `possibly_nested_type` in the manner described above.
  """
  if possibly_nested_type is not None:
    py_typecheck.check_type(possibly_nested_type, computation_types.Type)
  if type_to_upcast is not None:
    py_typecheck.check_type(type_to_upcast, computation_types.Type)

  both_generic_op_compatible = (
      is_generic_op_compatible_type(possibly_nested_type) and
      is_generic_op_compatible_type(type_to_upcast))
  if not both_generic_op_compatible:
    return False
  if possibly_nested_type is None:
    return type_to_upcast is None
  if possibly_nested_type.is_equivalent_to(type_to_upcast):
    return True

  is_scalar_tensor = (
      type_to_upcast.is_tensor() and
      type_to_upcast.shape == tf.TensorShape(()))
  if not is_scalar_tensor:
    return False

  required_dtype = type_to_upcast.dtype
  all_leaves_match = True

  def _flag_mismatched_tensor(type_spec):
    # Used only for its side effect while walking the type tree; the type is
    # returned unchanged and reported as not mutated.
    nonlocal all_leaves_match
    if type_spec.is_tensor() and type_spec.dtype != required_dtype:
      all_leaves_match = False
    return type_spec, False

  type_transformations.transform_type_postorder(possibly_nested_type,
                                                _flag_mismatched_tensor)
  return all_leaves_match
def is_average_compatible(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` can be averaged.

  Types that are average-compatible are composed of numeric tensor types,
  either floating-point or complex, possibly packaged into nested named
  tuples, and possibly federated.

  Args:
    type_spec: a `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is average-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    return type_spec.dtype.is_floating or type_spec.dtype.is_complex
  if type_spec.is_struct():
    for _, element_type in structure.iter_elements(type_spec):
      if not is_average_compatible(element_type):
        return False
    return True
  if type_spec.is_federated():
    return is_average_compatible(type_spec.member)
  return False
def is_structure_of_integers(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a structure of integers.

  Note that an empty `computation_types.StructType` will return `True`, as it
  does not contain any non-integer types.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is a structure of integers, otherwise `False`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    py_typecheck.check_type(type_spec.dtype, tf.DType)
    return type_spec.dtype.is_integer
  if type_spec.is_struct():
    for _, element_type in structure.iter_elements(type_spec):
      if not is_structure_of_integers(element_type):
        return False
    return True
  if type_spec.is_federated():
    return is_structure_of_integers(type_spec.member)
  return False
def _to_tensor_internal_rep(*, value: Any,
                            type_spec: computation_types.Type) -> tf.Tensor:
  """Normalizes tensor-like value to a tf.Tensor.

  Args:
    value: The value to normalize. Non-tensor values are converted with
      `tf.convert_to_tensor` using the dtype of `type_spec`.
    type_spec: The `computation_types.Type` the result must be assignable to.

  Returns:
    The normalized tensor.

  Raises:
    TypeError: If `type_spec` is not assignable from the type of the
      normalized tensor.
  """
  if tf.is_tensor(value):
    if hasattr(value, 'read_value'):
      # A tf.Variable-like result, get a proper tensor.
      value = value.read_value()
  else:
    value = tf.convert_to_tensor(value, dtype=type_spec.dtype)

  value_type = computation_types.TensorType(value.dtype.base_dtype,
                                            value.shape)
  if type_spec.is_assignable_from(value_type):
    return value
  raise TypeError(
      'The apparent type {} of a tensor {} does not match the expected '
      'type {}.'.format(value_type, value, type_spec))
def check_type(value: Any, type_spec: computation_types.Type):
  """Checks whether `value` is of TFF type `type_spec`.

  Args:
    value: The object to check.
    type_spec: A `computation_types.Type`, the type that `value` is checked
      against.

  Raises:
    TypeError: If `type_spec` is not assignable from the inferred type of
      `value`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  inferred_type = type_conversions.infer_type(value)
  if type_spec.is_assignable_from(inferred_type):
    return
  raise TypeError(
      'Expected TFF type {}, which is not assignable from {}.'.format(
          type_spec, inferred_type))
def visit_preorder(type_signature: computation_types.Type,
                   fn: Callable[[computation_types.Type, T], T], context: T):
  """Recursively calls `fn` on the possibly nested structure `type_signature`.

  Walks the tree in a preorder manner. `fn` is called on a node before its
  children, and the context it returns is the one passed to each child.

  Args:
    type_signature: A `computation_types.Type`.
    fn: A function to apply to each of the constituent elements of
      `type_signature` with the argument `context`. Must return an updated
      version of `context` which incorporates the information we'd like to
      track as we move down the type tree.
    context: Initial state of information to be passed down the tree.
  """
  child_context = fn(type_signature, context)
  for child_type in type_signature.children():
    visit_preorder(child_type, fn, child_context)
def check_type(value: Any, type_spec: computation_types.Type):
  """Checks whether `value` is of TFF type `type_spec`.

  Args:
    value: The object to check.
    type_spec: A `computation_types.Type`, the type that `value` is checked
      against.

  Raises:
    TypeError: If the inferred type of `value` is not assignable to
      `type_spec`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  inferred_type = type_conversions.infer_type(value)
  if type_spec.is_assignable_from(inferred_type):
    return
  message = computation_types.type_mismatch_error_message(
      inferred_type,
      type_spec,
      computation_types.TypeRelation.ASSIGNABLE,
      second_is_expected=True)
  raise TypeError(message)
def reconcile_value_type_with_type_spec(
    value_type: computation_types.Type,
    type_spec: Optional[computation_types.Type]) -> computation_types.Type:
  """Reconciles a pair of types.

  Args:
    value_type: An instance of `tff.Type`.
    type_spec: An instance of `tff.Type`, or `None`.

  Returns:
    Either `value_type` if `type_spec` is `None`, or `type_spec` if `type_spec`
    is not `None` and equivalent with `value_type`.

  Raises:
    TypeError: If arguments are of incompatible types.
  """
  py_typecheck.check_type(value_type, computation_types.Type)
  if type_spec is None:
    return value_type

  # Validate `type_spec` itself before comparing against it; `value_type` was
  # already validated above.
  py_typecheck.check_type(type_spec, computation_types.Type)
  if not value_type.is_equivalent_to(type_spec):
    raise TypeError('Expected a value of type {}, found {}.'.format(
        type_spec, value_type))
  return type_spec
def _is_two_tuple(t: computation_types.Type) -> bool:
  """Returns whether `t` is a struct type with exactly two elements."""
  if not t.is_struct():
    return False
  return len(t) == 2
def federated_empty_struct(type_spec: computation_types.Type) -> bool:
  """Returns whether `type_spec` is a struct type or a federated type."""
  # NOTE(review): despite the name, emptiness is not inspected here; confirm
  # against callers that this is the intended predicate.
  if type_spec.is_struct():
    return True
  return type_spec.is_federated()
def create_constant(
    value, type_spec: computation_types.Type) -> ComputationProtoAndType:
  """Returns a tensorflow computation returning a constant `value`.

  The returned computation has the type signature `( -> T)`, where `T` is
  `type_spec`. `value` must be a value convertible to a tensor or a structure
  of values, such that the dtype and shapes match `type_spec`. `type_spec`
  must contain only named tuples and tensor types, but these can be
  arbitrarily nested.

  Args:
    value: A value to embed as a constant in the tensorflow graph. A
      non-struct value is broadcast to every tensor leaf of `type_spec`.
    type_spec: A `computation_types.Type` describing the constant to create;
      must contain only named tuples and tensor types.

  Raises:
    TypeError: If the constraints of `type_spec` are violated.
  """
  if not type_analysis.is_generic_op_compatible_type(type_spec):
    raise TypeError(
        'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
        ' only nested tuples and tensors are permitted.'.format(type_spec))

  inferred_value_type = type_conversions.infer_type(value)
  value_is_struct = inferred_value_type.is_struct()
  if value_is_struct:
    if not type_spec.is_assignable_from(inferred_value_type):
      raise TypeError(
          'Must pass a only tensor or structure of tensor values to '
          '`create_tensorflow_constant`; encountered a value {v} with inferred '
          'type {t!r}, but needed {s!r}'.format(
              v=value, t=inferred_value_type, s=type_spec))
    value = structure.from_container(value, recursive=True)

  # Collect the dtype of every tensor leaf in `type_spec`.
  leaf_dtypes = []

  def _collect_dtype(type_signature):
    """Appends dtype of `type_signature` to the enclosing list."""
    if type_signature.is_tensor():
      leaf_dtypes.append(type_signature.dtype)
    return type_signature, False

  type_transformations.transform_type_postorder(type_spec, _collect_dtype)

  wants_integer_leaf = any(dtype.is_integer for dtype in leaf_dtypes)
  if (wants_integer_leaf and inferred_value_type.is_tensor() and
      not inferred_value_type.dtype.is_integer):
    raise TypeError(
        'Only integers can be used as scalar values if our desired constant '
        'type spec contains any integer tensors; passed scalar {} of dtype {} '
        'for type spec {}.'.format(value, inferred_value_type.dtype,
                                   type_spec))

  result_type = type_spec

  def _create_result_tensor(type_spec, value):
    """Packs `value` into `type_spec` recursively."""
    if type_spec.is_tensor():
      type_spec.shape.assert_is_fully_defined()
      return tf.constant(value, dtype=type_spec.dtype, shape=type_spec.shape)
    if value_is_struct:
      # Copy the leaf values according to the type_spec structure.
      return structure.Struct([
          (name, _create_result_tensor(elem_type, elem_value))
          for (name, elem_type), elem_value in zip(
              structure.iter_elements(type_spec), value)
      ])
    # "Broadcast" the value to each level of the type_spec structure.
    return structure.Struct([
        (None, _create_result_tensor(elem_type, value))
        for _, elem_type in structure.iter_elements(type_spec)
    ])

  with tf.Graph().as_default() as graph:
    result = _create_result_tensor(result_type, value)
    _, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(None, result_type)
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=packed_graph_def, parameter=None, result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
def type_to_tf_structure(type_spec: computation_types.Type):
  """Returns nested `tf.data.experimental.Structure` for a given TFF type.

  Args:
    type_spec: A `computation_types.Type` composed only of structs and
      tensors. Within each struct the elements must be either all named or all
      unnamed.

  Returns:
    A `tf.TensorSpec` for a tensor type and `()` for an empty struct. For a
    non-empty struct, the converted elements packed into the struct's Python
    container if it has one, or otherwise into a `collections.OrderedDict`
    (named elements) or a `tuple` (unnamed elements). Structs whose container
    is `tf.RaggedTensor` or `tf.SparseTensor` are converted to a
    `tf.RaggedTensorSpec` or `tf.SparseTensorSpec` respectively.

  Raises:
    ValueError: If `type_spec` is neither a tensor nor a struct, or if the
      elements of a struct are inconsistently named.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    return tf.TensorSpec(type_spec.shape, type_spec.dtype)
  if not type_spec.is_struct():
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))

  elements = structure.to_elements(type_spec)
  if not elements:
    return ()

  converted = [(name, type_to_tf_structure(elem)) for name, elem in elements]
  # The first element decides whether the struct counts as named; every other
  # element must agree with it.
  first_is_named = converted[0][0] is not None
  if any((entry[0] is not None) != first_is_named for entry in converted):
    raise ValueError('Tuple elements inconsistently named.')

  container_type = type_spec.python_container
  if container_type is None:
    if first_is_named:
      return collections.OrderedDict(converted)
    return tuple(spec for _, spec in converted)

  if (py_typecheck.is_named_tuple(container_type) or
      py_typecheck.is_attrs(container_type)):
    return container_type(**dict(converted))

  if container_type is tf.RaggedTensor:
    flat_values = type_spec.flat_values
    nested_row_splits = type_spec.nested_row_splits
    ragged_rank = len(nested_row_splits)
    return tf.RaggedTensorSpec(
        shape=tf.TensorShape([None] * (ragged_rank + 1)),
        dtype=flat_values.dtype,
        ragged_rank=ragged_rank,
        row_splits_dtype=nested_row_splits[0].dtype,
        flat_values_spec=None)

  if container_type is tf.SparseTensor:
    # We can't generally infer the shape from the type of the tensors, but
    # we *can* infer the rank based on the shapes of `indices` or
    # `dense_shape`.
    if (type_spec.indices.shape is not None and
        type_spec.indices.shape.dims[1] is not None):
      rank = type_spec.indices.shape.dims[1]
      sparse_shape = tf.TensorShape([None] * rank)
    elif (type_spec.dense_shape.shape is not None and
          type_spec.dense_shape.shape.dims[0] is not None):
      rank = type_spec.dense_shape.shape.dims[0]
      sparse_shape = tf.TensorShape([None] * rank)
    else:
      sparse_shape = None
    return tf.SparseTensorSpec(
        shape=sparse_shape, dtype=type_spec.values.dtype)

  if first_is_named:
    return container_type(converted)
  return container_type(
      entry if entry[0] is not None else entry[1] for entry in converted)
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]):
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children are transformed before their parent. When any child was mutated the
  parent is rebuilt from the transformed children before `transform_fn` is
  applied to it.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Transformation function to apply to each node in the type
      tree of `type_signature`. Must be callable and return a pair of the
      (possibly new) type and a flag telling whether it was mutated.

  Returns:
    A pair of the possibly transformed version of `type_signature` and a flag
    that is truthy if `transform_fn` reported a mutation for the node itself
    or for any of its descendants.

  Raises:
    TypeError: If the types don't match the specification above.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)

  if type_signature.is_federated():
    new_member, member_mutated = transform_type_postorder(
        type_signature.member, transform_fn)
    if member_mutated:
      type_signature = computation_types.FederatedType(
          new_member, type_signature.placement, type_signature.all_equal)
    type_signature, node_mutated = transform_fn(type_signature)
    return type_signature, node_mutated or member_mutated

  elif type_signature.is_sequence():
    new_element, element_mutated = transform_type_postorder(
        type_signature.element, transform_fn)
    if element_mutated:
      type_signature = computation_types.SequenceType(new_element)
    type_signature, node_mutated = transform_fn(type_signature)
    return type_signature, node_mutated or element_mutated

  elif type_signature.is_function():
    # A function type may have no parameter; only the result is then walked.
    new_parameter, parameter_mutated = (None, False)
    if type_signature.parameter is not None:
      new_parameter, parameter_mutated = transform_type_postorder(
          type_signature.parameter, transform_fn)
    new_result, result_mutated = transform_type_postorder(
        type_signature.result, transform_fn)
    if parameter_mutated or result_mutated:
      type_signature = computation_types.FunctionType(
          new_parameter, new_result)
    type_signature, node_mutated = transform_fn(type_signature)
    return type_signature, (node_mutated or parameter_mutated or
                            result_mutated)

  elif type_signature.is_struct():
    new_elements = []
    any_element_mutated = False
    for name, element_type in structure.iter_elements(type_signature):
      new_element, element_mutated = transform_type_postorder(
          element_type, transform_fn)
      any_element_mutated = any_element_mutated or element_mutated
      new_elements.append((name, new_element))
    if any_element_mutated:
      # Keep the Python container of the original struct, if it had one.
      if type_signature.is_struct_with_python():
        type_signature = computation_types.StructWithPythonType(
            new_elements, type_signature.python_container)
      else:
        type_signature = computation_types.StructType(new_elements)
    type_signature, node_mutated = transform_fn(type_signature)
    return type_signature, node_mutated or any_element_mutated

  elif (type_signature.is_abstract() or type_signature.is_placement() or
        type_signature.is_tensor()):
    # Leaf types have no children to walk.
    return transform_fn(type_signature)
  # Any other kind of type falls through and yields `None`.