def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Reports whether `bitwidth_type` can serve as the bitwidth for `value_type`.

  Mainly a helper for `federated_secure_sum` in `intrinsic_factory.py`.

  A tensor `bitwidth_type` is accepted whenever it has an integer dtype and
  exactly one element, regardless of `value_type`. A single bitwidth may then
  be shared by every value in a structure. A struct `bitwidth_type` is accepted
  only when `value_type` is also a struct with the same number of elements,
  the same element names in the same order, and a valid bitwidth type for each
  pair of corresponding members. Anything else is rejected.

  Args:
    bitwidth_type: The candidate bitwidth `computation_types.Type`.
    value_type: The `computation_types.Type` of the values to be summed.

  Returns:
    `True` if `bitwidth_type` is a valid bitwidth type for `value_type`,
    otherwise `False`.
  """
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if bitwidth_type.is_tensor():
    # One integer bitwidth can be applied uniformly to a whole structure.
    if not bitwidth_type.dtype.is_integer:
      return False
    return bitwidth_type.shape.num_elements() == 1

  if not (value_type.is_struct() and bitwidth_type.is_struct()):
    return False

  bitwidth_members = list(structure.iter_elements(bitwidth_type))
  value_members = list(structure.iter_elements(value_type))
  if len(bitwidth_members) != len(value_members):
    return False
  return all(
      bw_name == val_name and
      is_valid_bitwidth_type_for_value_type(bw_type, val_type)
      for (bw_name, bw_type), (val_name, val_type) in zip(
          bitwidth_members, value_members))
def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Reports whether `bitwidth_type` can serve as the bitwidth for `value_type`.

  Mainly a helper for `federated_secure_sum` in `intrinsic_factory.py`.

  The two types must mirror each other:

  * tensor value / tensor bitwidth: the bitwidth must have an integer dtype
    and exactly one element. Its shape need not match the value, because a
    single bitwidth integer is used per tensor.
  * struct value / struct bitwidth: both must have the same number of
    elements with the same names in the same order, and each pair of members
    must itself be valid.

  Every other combination is rejected.

  Args:
    bitwidth_type: The candidate bitwidth `computation_types.Type`.
    value_type: The `computation_types.Type` of the values to be summed.

  Returns:
    `True` if `bitwidth_type` is a valid bitwidth type for `value_type`,
    otherwise `False`.
  """
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if value_type.is_tensor() and bitwidth_type.is_tensor():
    if not bitwidth_type.dtype.is_integer:
      return False
    return bitwidth_type.shape.num_elements() == 1

  if not (value_type.is_struct() and bitwidth_type.is_struct()):
    return False

  bitwidth_members = list(structure.iter_elements(bitwidth_type))
  value_members = list(structure.iter_elements(value_type))
  if len(bitwidth_members) != len(value_members):
    return False
  return all(
      bw_name == val_name and
      is_valid_bitwidth_type_for_value_type(bw_type, val_type)
      for (bw_name, bw_type), (val_name, val_type) in zip(
          bitwidth_members, value_members))
def is_sum_compatible(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a type that can be added to itself.

  A sum-compatible type is a tensor whose dtype satisfies `is_numeric_dtype`
  and whose shape is fully defined. Such tensors may be nested inside
  structs, and the whole thing may be wrapped in a federated type. Every
  other kind of type is sum-incompatible, including sequences, functions,
  abstract types and placements.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is sum-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    dtype_ok = is_numeric_dtype(type_spec.dtype)
    return dtype_ok and type_spec.shape.is_fully_defined()
  if type_spec.is_struct():
    members = (member for _, member in structure.iter_elements(type_spec))
    return all(is_sum_compatible(member) for member in members)
  if type_spec.is_federated():
    return is_sum_compatible(type_spec.member)
  return False
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation representing an identity function.

  The returned computation has the type signature `(T -> T)`, where `T` is
  `type_signature`.

  NOTE: if `T` contains `computation_types.StructType`s without an associated
  container type, they will be given the container type `tuple` by this
  function.

  Args:
    type_signature: A `computation_types.Type` to use as the parameter type and
      result type of the identity function.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings, or is `None`.
    NotImplementedError: If `type_signature` is not a tensor, sequence or
      struct type.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if type_signature is None:
    raise TypeError('TensorFlow identity cannot be created for NoneType.')

  # TF relies on feeds not-identical to fetches in certain circumstances, so
  # values are routed through `tf.identity` instead of being returned as-is.
  if type_signature.is_tensor() or type_signature.is_sequence():
    return create_computation_for_py_fn(tf.identity, type_signature)
  if type_signature.is_struct():
    struct_identity = functools.partial(structure.map_structure, tf.identity)
    return create_computation_for_py_fn(struct_identity, type_signature)
  raise NotImplementedError(
      f'TensorFlow identity cannot be created for type {type_signature}')
def ingest_value(
    self, value: Any, type_signature: computation_types.Type
) -> executor_value_base.ExecutorValue:
  """Wraps `value` as a `FederatedResolvingStrategyValue`.

  If `type_signature` is federated, or is a function whose result is
  federated, that placement is first passed to
  `self._check_strategy_compatible_with_placement`. That method presumably
  raises on an unsupported placement; confirm in its definition.

  Args:
    value: The value to ingest.
    type_signature: The TFF type of `value`, or `None`.

  Returns:
    A `FederatedResolvingStrategyValue` holding `value` and `type_signature`.
  """
  federated_type = None
  if type_signature is not None:
    if type_signature.is_federated():
      federated_type = type_signature
    elif (type_signature.is_function() and
          type_signature.result.is_federated()):
      federated_type = type_signature.result
  if federated_type is not None:
    self._check_strategy_compatible_with_placement(federated_type.placement)
  return FederatedResolvingStrategyValue(value, type_signature)
def _to_sequence_internal_rep(
    *, value: Any, type_spec: computation_types.Type) -> tf.data.Dataset:
  """Ingests `value`, converting to an eager dataset.

  A Python `list` is first turned into a dataset whose elements have type
  `type_spec.element`. The result must be one of
  `type_conversions.TF_DATASET_REPRESENTATION_TYPES`, and the sequence type
  inferred from its `element_spec` must be assignable to `type_spec`.

  Args:
    value: A list of elements or a dataset-like object.
    type_spec: The expected sequence type.

  Returns:
    The (possibly newly constructed) dataset.
  """
  dataset = value
  if isinstance(dataset, list):
    dataset = tensorflow_utils.make_data_set_from_elements(
        None, dataset, type_spec.element)
  py_typecheck.check_type(dataset,
                          type_conversions.TF_DATASET_REPRESENTATION_TYPES)
  inferred_element_type = computation_types.to_type(dataset.element_spec)
  type_spec.check_assignable_from(
      computation_types.SequenceType(inferred_element_type))
  return dataset
def _ensure_deserialized_types_compatible(
    previous_type: Optional[computation_types.Type],
    next_type: computation_types.Type) -> computation_types.Type:
  """Returns whichever of the two types is assignable from the other.

  `next_type` is preferred when each is assignable from the other. When
  `previous_type` is `None`, `next_type` is returned unchecked.

  Args:
    previous_type: Instance of `computation_types.Type` or `None`.
    next_type: Instance of `computation_types.Type`.

  Returns:
    The supertype of `previous_type` and `next_type`.

  Raises:
    TypeError: if neither type is assignable from the other.
  """
  if previous_type is None:
    return next_type
  for supertype, subtype in ((next_type, previous_type),
                             (previous_type, next_type)):
    if supertype.is_assignable_from(subtype):
      return supertype
  raise TypeError('Type mismatch checking member assignability under a '
                  'federated value. Deserialized type {} is incompatible '
                  'with previously deserialized {}.'.format(
                      next_type, previous_type))
def type_to_tf_structure(type_spec: computation_types.Type):
  """Returns nested `tf.data.experimental.Structure` for a given TFF type.

  Args:
    type_spec: A `computation_types.Type` composed only of named tuples and
      tensors. Every tuple must be non-empty, and its elements must be either
      all named or all unnamed.

  Returns:
    A `tf.TensorSpec`, or a nested structure of them, mirroring `type_spec`.
    Tuples with a Python container are rebuilt with that container. Other
    tuples become an `OrderedDict` when named and a `tuple` when unnamed.

  Raises:
    ValueError: if `type_spec` contains something other than named tuples and
      tensors, contains an empty tuple, or contains a tuple whose elements are
      inconsistently named.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return tf.TensorSpec(type_spec.shape, type_spec.dtype)
  if not type_spec.is_tuple():
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))

  elements = anonymous_tuple.to_elements(type_spec)
  if not elements:
    raise ValueError('Empty tuples are unsupported.')
  converted = [(name, type_to_tf_structure(member))
               for name, member in elements]
  is_named = converted[0][0] is not None
  if any((name is not None) != is_named for name, _ in converted):
    raise ValueError('Tuple elements inconsistently named.')

  if not type_spec.is_tuple_with_py_container():
    if is_named:
      return collections.OrderedDict(converted)
    return tuple(spec for _, spec in converted)

  container_type = (
      computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
          type_spec))
  if (py_typecheck.is_named_tuple(container_type) or
      py_typecheck.is_attrs(container_type)):
    return container_type(**dict(converted))
  if is_named:
    return container_type(converted)
  return container_type(
      pair if pair[0] is not None else pair[1] for pair in converted)
def _visit_preorder(
    type_spec: computation_types.Type,
    fn: Callable[[computation_types.Type, T], T],
    context: T,
):
  """Recursively calls `fn` on `type_spec` and its descendants in preorder.

  `fn(node, context)` is called on each node before its children. Its return
  value becomes the context passed down to that node's children only. Sibling
  subtrees all receive the same parent context.

  Args:
    type_spec: The `computation_types.Type` to walk.
    fn: Called as `fn(type, context)`; must return the context to hand to the
      children of `type`.
    context: Initial context for `type_spec`.
  """
  context = fn(type_spec, context)
  for child_type in type_spec.children():
    # Recurse into this helper itself so it does not depend on a separately
    # defined public `visit_preorder`.
    _visit_preorder(child_type, fn, context)
def contains(type_signature: computation_types.Type,
             predicate: _TypePredicate) -> bool:
  """Checks if `type_signature` contains any types that pass `predicate`.

  The root is tested first, then each child subtree in order. The search
  stops at the first match.
  """
  if predicate(type_signature):
    return True
  return any(
      contains(child, predicate) for child in type_signature.children())
def _check_type_is_fn(
    target: computation_types.Type,
    name: str,
    err_fn: Callable[
        [str], Exception] = transformations.CanonicalFormCompilationError,
):
  """Raises `err_fn(message)` unless `target` is a function type.

  Args:
    target: The type to check.
    name: Human-readable name of the checked entity, used in the message.
    err_fn: Factory mapping the error message to the exception to raise.
  """
  if target.is_function():
    return
  message = (f'Expected {name} to be a function, but {name} had type '
             f'{target}.')
  raise err_fn(message)
def transform_to_tff_known_type(
    type_spec: computation_types.Type) -> Tuple[computation_types.Type, bool]:
  """Transforms `StructType` to `StructWithPythonType`.

  A struct with no Python container gets `collections.OrderedDict` when every
  element is named (including the empty struct), or `tuple` when no element
  is named. Other types are returned unchanged.

  Returns:
    A `(type, modified)` pair as expected by type-transformation callbacks.

  Raises:
    TypeError: if a struct mixes named and unnamed elements.
  """
  if not type_spec.is_struct() or type_spec.is_struct_with_python():
    return type_spec, False

  names = [name for name, _ in structure.iter_elements(type_spec)]
  if all(name is not None for name in names):
    container = collections.OrderedDict
  elif all(name is None for name in names):
    container = tuple
  else:
    raise TypeError('Cannot represent TFF type in TF because it contains '
                    f'partially named structures. Type: {type_spec}')
  converted = computation_types.StructWithPythonType(
      elements=structure.iter_elements(type_spec), container_type=container)
  return converted, True
def is_structure_of_integers(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a structure of integers.

  Integer-dtype tensors qualify, as do structs whose members all qualify and
  federated types whose member qualifies. An empty struct therefore qualifies.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is a structure of integers, otherwise `False`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    dtype = type_spec.dtype
    py_typecheck.check_type(dtype, tf.DType)
    return dtype.is_integer
  if type_spec.is_struct():
    members = (member for _, member in structure.iter_elements(type_spec))
    return all(is_structure_of_integers(member) for member in members)
  if type_spec.is_federated():
    return is_structure_of_integers(type_spec.member)
  return False
def count(type_signature: computation_types.Type,
          predicate: _TypePredicate) -> int:
  """Returns the number of types in `type_signature` matching `predicate`.

  The root itself and every descendant are each tested once.

  Args:
    type_signature: A tree of `computation_type.Type`s to count.
    predicate: A Python function that takes a type as a parameter and returns a
      boolean value.
  """
  total = 1 if predicate(type_signature) else 0
  for child in type_signature.children():
    total += count(child, predicate)
  return total
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation representing an identity function.

  The returned computation has the type signature `(T -> T)`, where `T` is
  `type_signature`.

  NOTE: if `T` contains `computation_types.StructType`s without an associated
  container type, they will be given the container type `tuple` by this
  function.

  Args:
    type_signature: A `computation_types.Type` to use as the parameter type and
      result type of the identity function.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings, or is `None`.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  parameter_type = type_signature
  if parameter_type is None:
    raise TypeError('TensorFlow identity cannot be created for NoneType.')

  with tf.Graph().as_default() as graph:
    value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', parameter_type, graph)
    # TF relies on feeds not-identical to fetches in certain circumstances, so
    # tensors are copied through `tf.identity` before being captured. Other
    # kinds are captured as stamped.
    if parameter_type.is_tensor():
      value = tf.identity(value)
    elif parameter_type.is_struct():
      value = structure.map_structure(tf.identity, value)
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        value, graph)

  function_type = computation_types.FunctionType(parameter_type, result_type)
  proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(proto, function_type)
def is_binary_op_with_upcast_compatible_pair(
    possibly_nested_type: Optional[computation_types.Type],
    type_to_upcast: Optional[computation_types.Type]) -> bool:
  """Checks unambiguity in applying `type_to_upcast` to `possibly_nested_type`.

  There are two acceptable cases:

  * The two types are equivalent and contain only tuples and tensors.
  * `possibly_nested_type` is a (perhaps nested) structure whose tensor leaves
    all have the `dtype` of `type_to_upcast`, and `type_to_upcast` is a scalar
    tensor type.

  This relationship is not symmetric, since binary operators need not respect
  this symmetry in general. For example, it makes perfect sense to divide a
  nested structure of tensors by a scalar, but not the other way around.

  Args:
    possibly_nested_type: A `computation_types.Type`, or `None`.
    type_to_upcast: A `computation_types.Type`, or `None`.

  Returns:
    Boolean indicating whether `type_to_upcast` can be upcast to
    `possibly_nested_type` in the manner described above. Two `None`s are
    compatible; a `None` paired with a non-`None` type is not.

  Raises:
    TypeError: if an argument is neither `None` nor a
      `computation_types.Type`.
  """
  if possibly_nested_type is not None:
    py_typecheck.check_type(possibly_nested_type, computation_types.Type)
  if type_to_upcast is not None:
    py_typecheck.check_type(type_to_upcast, computation_types.Type)
  if not (is_generic_op_compatible_type(possibly_nested_type) and
          is_generic_op_compatible_type(type_to_upcast)):
    return False
  if possibly_nested_type is None:
    return type_to_upcast is None
  if type_to_upcast is None:
    # A concrete type cannot be paired with a missing one. Answer explicitly
    # rather than handing `None` to `is_equivalent_to`.
    return False
  if possibly_nested_type.is_equivalent_to(type_to_upcast):
    return True
  if not (type_to_upcast.is_tensor() and
          type_to_upcast.shape == tf.TensorShape(())):
    return False

  only_allowed_dtype = type_to_upcast.dtype
  types_are_ok = True

  def _check_tensor_types(type_spec):
    nonlocal types_are_ok
    if type_spec.is_tensor() and type_spec.dtype != only_allowed_dtype:
      types_are_ok = False
    # Never modify the type; the transform is used only as a visitor.
    return type_spec, False

  type_transformations.transform_type_postorder(possibly_nested_type,
                                                _check_tensor_types)
  return types_are_ok
def is_average_compatible(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` can be averaged.

  Average-compatible types are tensors with a floating-point or complex
  dtype. Such tensors may be nested inside structs, and the whole thing may be
  wrapped in a federated type.

  Args:
    type_spec: a `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is average-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    dtype = type_spec.dtype
    return dtype.is_floating or dtype.is_complex
  if type_spec.is_struct():
    members = (member for _, member in structure.iter_elements(type_spec))
    return all(is_average_compatible(member) for member in members)
  if type_spec.is_federated():
    return is_average_compatible(type_spec.member)
  return False
def _to_tensor_internal_rep(*, value: Any,
                            type_spec: computation_types.Type) -> tf.Tensor:
  """Normalizes tensor-like value to a tf.Tensor.

  Non-tensors are converted with `tf.convert_to_tensor` using
  `type_spec.dtype`. Tensor-likes exposing `read_value` (e.g. variables) are
  read.

  Raises:
    TypeError: if the resulting tensor's type is not assignable to
      `type_spec`.
  """
  if tf.is_tensor(value):
    if hasattr(value, 'read_value'):
      # A tf.Variable-like result; read it to get a proper tensor.
      value = value.read_value()
  else:
    value = tf.convert_to_tensor(value, dtype=type_spec.dtype)

  apparent_type = computation_types.TensorType(value.dtype.base_dtype,
                                               value.shape)
  if type_spec.is_assignable_from(apparent_type):
    return value
  raise TypeError(
      'The apparent type {} of a tensor {} does not match the expected '
      'type {}.'.format(apparent_type, value, type_spec))
def is_structure_of_floats(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a structure of floats.

  Floating-point tensors qualify, as do structs whose members all qualify and
  federated types whose member qualifies. An empty
  `computation_types.StructType` therefore returns `True`, since it contains
  no non-floating types.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is a structure of floats, otherwise `False`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    dtype = type_spec.dtype
    py_typecheck.check_type(dtype, tf.DType)
    return dtype.is_floating
  if type_spec.is_struct():
    members = (member for _, member in structure.iter_elements(type_spec))
    return all(is_structure_of_floats(member) for member in members)
  if type_spec.is_federated():
    return is_structure_of_floats(type_spec.member)
  return False
def check_type(value: Any, type_spec: computation_types.Type):
  """Checks whether `value` is of TFF type `type_spec`.

  Args:
    value: The object to check.
    type_spec: A `computation_types.Type`, the type that `value` is checked
      against.

  Raises:
    TypeError: If `type_spec` is not a `computation_types.Type`, or if
      `type_spec` is not assignable from the type inferred for `value`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  inferred_type = type_conversions.infer_type(value)
  if type_spec.is_assignable_from(inferred_type):
    return
  raise TypeError(
      'Expected TFF type {}, which is not assignable from {}.'.format(
          type_spec, inferred_type))
def visit_preorder(type_signature: computation_types.Type,
                   fn: Callable[[computation_types.Type, T], T], context: T):
  """Calls `fn` on every type in `type_signature`, parents before children.

  Each node gets the context returned by `fn` for its parent, or the initial
  `context` for the root. Sibling subtrees all receive the same parent context
  and never see each other's updates.

  Args:
    type_signature: A `computation_types.Type`.
    fn: Called as `fn(type, context)`; must return the context to pass to the
      children of `type`.
    context: Initial state of information to be passed down the tree.
  """
  # Explicit stack of (node, context-from-parent). Children are pushed in
  # reverse so they are popped, and visited, in their natural order.
  pending = [(type_signature, context)]
  while pending:
    node, inherited = pending.pop()
    updated = fn(node, inherited)
    children = list(node.children())
    pending.extend((child, updated) for child in reversed(children))
def reconcile_value_type_with_type_spec(
    value_type: computation_types.Type,
    type_spec: Optional[computation_types.Type]) -> computation_types.Type:
  """Reconciles a pair of types.

  Args:
    value_type: An instance of `tff.Type`.
    type_spec: An instance of `tff.Type`, or `None`.

  Returns:
    `value_type` if `type_spec` is `None`. Otherwise `type_spec`, which must
    be equivalent to `value_type`.

  Raises:
    TypeError: If `value_type`, or a non-`None` `type_spec`, is not a
      `tff.Type`, or if the two types are not equivalent.
  """
  py_typecheck.check_type(value_type, computation_types.Type)
  if type_spec is None:
    return value_type
  # Validate `type_spec` itself, not just `value_type`.
  py_typecheck.check_type(type_spec, computation_types.Type)
  if not value_type.is_equivalent_to(type_spec):
    raise TypeError('Expected a value of type {}, found {}.'.format(
        type_spec, value_type))
  return type_spec
def type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type):
  """Returns nested structures of tensor dtypes and shapes for a given TFF type.

  The returned dtypes and shapes match those used by `tf.data.Dataset`s to
  indicate the type and shape of their elements. They can be used, e.g., as
  arguments in constructing an iterator over a string handle.

  Args:
    type_spec: A `computation_types.Type` composed only of named tuples and
      tensors. Within each named tuple, the elements must be either all named
      or all unnamed.

  Returns:
    A pair of parallel nested structures with the dtypes and shapes of tensors
    defined in `type_spec`. Both structures have the same layout as the nested
    type defined by `type_spec`:

    * Tuples with a Python container are rebuilt with that container.
    * Other named tuples are represented as `OrderedDict`s.
    * Other unnamed (or empty) tuples are represented as `tuple`s.

  Raises:
    ValueError: if the `type_spec` is composed of something other than named
      tuples and tensors, or if a tuple mixes named and unnamed elements.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return (type_spec.dtype, type_spec.shape)
  elif type_spec.is_struct():
    elements = structure.to_elements(type_spec)
    if not elements:
      output_dtypes = []
      output_shapes = []
    elif elements[0][0] is not None:
      # The first element is named, so every element must be named.
      output_dtypes = collections.OrderedDict()
      output_shapes = collections.OrderedDict()
      for element_name, element_spec in elements:
        if element_name is None:
          raise ValueError(
              'When a sequence appears as a part of a parameter to a section '
              'of TensorFlow code, in the type signature of elements of that '
              'sequence all named tuples must have their elements explicitly '
              'named, and this does not appear to be the case in {}.'
              .format(type_spec))
        element_dtypes, element_shapes = type_to_tf_dtypes_and_shapes(
            element_spec)
        output_dtypes[element_name] = element_dtypes
        output_shapes[element_name] = element_shapes
    else:
      # The first element is unnamed, so no element may be named.
      output_dtypes = []
      output_shapes = []
      for element_name, element_spec in elements:
        if element_name is not None:
          raise ValueError(
              'When a sequence appears as a part of a parameter to a section '
              'of TensorFlow code, in the type signature of elements of that '
              'sequence all named tuples must have either all or none of '
              'their elements named, and this does not appear to be the case '
              'in {}.'.format(type_spec))
        element_dtypes, element_shapes = type_to_tf_dtypes_and_shapes(
            element_spec)
        output_dtypes.append(element_dtypes)
        output_shapes.append(element_shapes)
    if type_spec.python_container is not None:
      container_type = type_spec.python_container

      def build_py_container(elements):
        if (py_typecheck.is_named_tuple(container_type) or
            py_typecheck.is_attrs(container_type)):
          return container_type(**dict(elements))
        else:
          return container_type(elements)

      output_dtypes = build_py_container(output_dtypes)
      output_shapes = build_py_container(output_shapes)
    elif not isinstance(output_dtypes, collections.OrderedDict):
      output_dtypes = tuple(output_dtypes)
      output_shapes = tuple(output_shapes)
    # A named tuple without a container keeps its `OrderedDict`s: calling
    # `tuple()` on a dict would keep only the element names and drop the
    # dtypes/shapes.
    return (output_dtypes, output_shapes)
  else:
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))
def _is_two_tuple(t: computation_types.Type) -> bool:
  """Returns whether `t` is a struct type with exactly two elements."""
  if not t.is_struct():
    return False
  return len(t) == 2
def _concretize_abstract_types(
    abstract_type_spec: computation_types.Type,
    concrete_type_spec: computation_types.Type) -> computation_types.Type:
  """Recursive helper function to construct concrete type spec.

  Walks `abstract_type_spec` and `concrete_type_spec` in parallel, replacing
  each abstract type with the concrete type at the same position. An abstract
  label bound earlier in the walk keeps its first binding.

  Uses two names from the enclosing scope that are not shown here:
  `bound_abstract_types`, a mutable mapping from `str(label)` to its bound
  concrete type, and `type_error_string`, the message for every mismatch.

  Args:
    abstract_type_spec: The possibly-abstract type to concretize.
    concrete_type_spec: The concrete type matched against it.

  Returns:
    `abstract_type_spec` with abstract types replaced by concrete ones.

  Raises:
    TypeError: On a structural mismatch between the two types (kind, length,
      element names, or presence of a function parameter), or if
      `abstract_type_spec` is of an unexpected kind.
  """
  if abstract_type_spec.is_abstract():
    label = str(abstract_type_spec.label)
    bound_type = bound_abstract_types.get(label)
    # Compare with `None` rather than testing truthiness. Struct types support
    # `len()`, so an empty struct binding could be falsy and would otherwise
    # be silently rebound.
    if bound_type is not None:
      return bound_type
    bound_abstract_types[label] = concrete_type_spec
    return concrete_type_spec
  elif abstract_type_spec.is_tensor():
    return abstract_type_spec
  elif abstract_type_spec.is_struct():
    if not concrete_type_spec.is_struct():
      raise TypeError(type_error_string)
    abstract_elements = structure.to_elements(abstract_type_spec)
    concrete_elements = structure.to_elements(concrete_type_spec)
    if len(abstract_elements) != len(concrete_elements):
      raise TypeError(type_error_string)
    concretized_tuple_elements = []
    for (abstract_name, abstract_member), (concrete_name,
                                           concrete_member) in zip(
                                               abstract_elements,
                                               concrete_elements):
      if abstract_name != concrete_name:
        raise TypeError(type_error_string)
      concretized_tuple_elements.append(
          (abstract_name,
           _concretize_abstract_types(abstract_member, concrete_member)))
    return computation_types.StructType(concretized_tuple_elements)
  elif abstract_type_spec.is_sequence():
    if not concrete_type_spec.is_sequence():
      raise TypeError(type_error_string)
    return computation_types.SequenceType(
        _concretize_abstract_types(abstract_type_spec.element,
                                   concrete_type_spec.element))
  elif abstract_type_spec.is_function():
    if not concrete_type_spec.is_function():
      raise TypeError(type_error_string)
    if abstract_type_spec.parameter is None:
      if concrete_type_spec.parameter is not None:
        # Raise rather than return: returning would hand a `TypeError`
        # instance back to the caller in place of a type.
        raise TypeError(type_error_string)
      concretized_param = None
    else:
      concretized_param = _concretize_abstract_types(
          abstract_type_spec.parameter, concrete_type_spec.parameter)
    concretized_result = _concretize_abstract_types(abstract_type_spec.result,
                                                    concrete_type_spec.result)
    return computation_types.FunctionType(concretized_param,
                                          concretized_result)
  elif abstract_type_spec.is_placement():
    if not concrete_type_spec.is_placement():
      raise TypeError(type_error_string)
    return abstract_type_spec
  elif abstract_type_spec.is_federated():
    if not concrete_type_spec.is_federated():
      raise TypeError(type_error_string)
    new_member = _concretize_abstract_types(abstract_type_spec.member,
                                            concrete_type_spec.member)
    return computation_types.FederatedType(new_member,
                                           abstract_type_spec.placement,
                                           abstract_type_spec.all_equal)
  else:
    raise TypeError(
        'Unexpected abstract typespec {}.'.format(abstract_type_spec))
def create_binary_operator(
    operator, operand_type: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation computing a binary operation.

  The returned computation has the type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the result of applying the `operator` to a tuple of
  type `<T,T>`.

  Note: If `operand_type` is a `computation_types.StructType`, then
  `operator` will be applied pointwise. This places the burden on callers of
  this function to construct the correct values to pass into the returned
  function. For example, to divide `[2, 2]` by `2`, first `2` must be packed
  into the data structure `[x, x]`, before the division operator of the
  appropriate type is called.

  Args:
    operator: A callable taking two arguments representing the operation to
      encode. For example: `tf.math.add`, `tf.math.multiply`, and
      `tf.math.divide`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only named tuples and tensor
      types.

  Raises:
    TypeError: If the constraints of `operand_type` are violated (including
      `operand_type` being `None`) or `operator` is not callable.
  """
  if not type_analysis.is_generic_op_compatible_type(operand_type):
    raise TypeError(
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.StructType`; this is disallowed in the '
        'generic operators.'.format(operand_type))
  py_typecheck.check_callable(operator)
  if operand_type is None:
    # Without this check, a `None` operand type would leave `result_value`
    # unassigned below and fail with an `UnboundLocalError`.
    raise TypeError(
        'Binary operators cannot be created for a `None` operand type.')
  with tf.Graph().as_default() as graph:
    operand_1_value, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', operand_type, graph)
    if operand_type.is_tensor():
      result_value = operator(operand_1_value, operand_2_value)
    elif operand_type.is_struct():
      result_value = structure.map_structure(operator, operand_1_value,
                                             operand_2_value)
    else:
      raise TypeError(
          'Operand type {} cannot be used in generic operations. The call to '
          '`type_analysis.is_generic_op_compatible_type` has allowed it to '
          'pass, and should be updated.'.format(operand_type))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  type_signature = computation_types.FunctionType(
      computation_types.StructType((operand_type, operand_type)), result_type)
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
def _federated_same_placement(x: computation_types.Type,
                              y: computation_types.Type) -> bool:
  """Returns whether `x` and `y` are both federated with equal placements."""
  if not x.is_federated():
    return False
  if not y.is_federated():
    return False
  return x.placement == y.placement
def _check_helper(generic_type_member: computation_types.Type,
                  concrete_type_member: computation_types.Type,
                  defining: bool):
  """Recursive helper function.

  Walks `generic_type_member` and `concrete_type_member` in parallel. It
  records a concrete type for each abstract label and raises on any
  structural mismatch.

  It relies on names from the enclosing scope that are not visible here:

  * `generic_type` and `concrete_type`: the top-level types, used in error
    reports.
  * `type_bindings`: label -> concrete type bound at a defining position.
  * `non_defining_usages`: label -> list of concrete types seen at
    non-defining positions. Presumably a `defaultdict(list)`; confirm in the
    enclosing function.

  Args:
    generic_type_member: Current sub-type of the generic type.
    concrete_type_member: Sub-type of the concrete type at the same position.
    defining: Whether abstract types at this position define a binding. It is
      flipped when descending into a function parameter.

  Raises:
    MismatchedStructureError: If the two types differ structurally.
    MismatchedConcreteTypesError: If an abstract label is bound at defining
      positions to concrete types that are not equivalent.
    TypeError: If `generic_type_member` is of an unexpected kind.
  """

  def _raise_structural(mismatch):
    raise MismatchedStructureError(concrete_type, generic_type,
                                   concrete_type_member, generic_type_member,
                                   mismatch)

  def _both_are(predicate):
    # True if both members satisfy `predicate`, False if the generic one does
    # not. A kind mismatch (generic does, concrete does not) raises.
    if predicate(generic_type_member):
      if predicate(concrete_type_member):
        return True
      else:
        _raise_structural('kind')
    else:
      return False

  if generic_type_member.is_abstract():
    label = str(generic_type_member.label)
    if not defining:
      non_defining_usages[label].append(concrete_type_member)
    else:
      bound_type = type_bindings.get(label)
      if bound_type is not None:
        if not concrete_type_member.is_equivalent_to(bound_type):
          raise MismatchedConcreteTypesError(concrete_type, generic_type,
                                             label, bound_type,
                                             concrete_type_member)
      else:
        type_bindings[label] = concrete_type_member
  elif _both_are(lambda t: t.is_tensor()):
    if generic_type_member != concrete_type_member:
      _raise_structural('tensor types')
  elif _both_are(lambda t: t.is_placement()):
    if generic_type_member != concrete_type_member:
      _raise_structural('placements')
  elif _both_are(lambda t: t.is_struct()):
    generic_elements = structure.to_elements(generic_type_member)
    concrete_elements = structure.to_elements(concrete_type_member)
    if len(generic_elements) != len(concrete_elements):
      _raise_structural('length')
    for k in range(len(generic_elements)):
      if generic_elements[k][0] != concrete_elements[k][0]:
        _raise_structural('element names')
      _check_helper(generic_elements[k][1], concrete_elements[k][1], defining)
  elif _both_are(lambda t: t.is_sequence()):
    _check_helper(generic_type_member.element, concrete_type_member.element,
                  defining)
  elif _both_are(lambda t: t.is_function()):
    if generic_type_member.parameter is None:
      if concrete_type_member.parameter is not None:
        _raise_structural('parameter')
    elif concrete_type_member.parameter is None:
      # The generic function takes a parameter but the concrete one does not.
      # Report it as a mismatch instead of recursing into `None`.
      _raise_structural('parameter')
    else:
      _check_helper(generic_type_member.parameter,
                    concrete_type_member.parameter, not defining)
    _check_helper(generic_type_member.result, concrete_type_member.result,
                  defining)
  elif _both_are(lambda t: t.is_federated()):
    if generic_type_member.placement != concrete_type_member.placement:
      _raise_structural('placement')
    if generic_type_member.all_equal != concrete_type_member.all_equal:
      _raise_structural('all equal')
    _check_helper(generic_type_member.member, concrete_type_member.member,
                  defining)
  else:
    # Report the offending sub-type, not the whole top-level generic type.
    raise TypeError(f'Unexpected type kind {generic_type_member}.')
def _deserialize_dataset_from_graph_def(serialized_graph_def: bytes,
                                        element_type: computation_types.Type):
  """Deserializes a serialized `tf.compat.v1.GraphDef` to a `tf.data.Dataset`.

  Args:
    serialized_graph_def: `bytes` object produced by
      `tensorflow_serialization.serialize_dataset`
    element_type: a `tff.Type` object representing the type structure of the
      elements yielded from the dataset.

  Returns:
    A `tf.data.Dataset` instance.
  """
  py_typecheck.check_type(element_type, computation_types.Type)
  type_analysis.check_tensorflow_compatible_type(element_type)

  def transform_to_tff_known_type(
      type_spec: computation_types.Type
  ) -> Tuple[computation_types.Type, bool]:
    """Transforms `StructType` to `StructWithPythonType`."""
    if not type_spec.is_struct() or type_spec.is_struct_with_python():
      return type_spec, False
    names = [name for name, _ in structure.iter_elements(type_spec)]
    if all(name is not None for name in names):
      container = collections.OrderedDict
    elif all(name is None for name in names):
      container = tuple
    else:
      raise TypeError(
          'Cannot represent TFF type in TF because it contains '
          f'partially named structures. Type: {type_spec}')
    converted = computation_types.StructWithPythonType(
        elements=structure.iter_elements(type_spec), container_type=container)
    return converted, True

  if element_type.is_struct():
    # TF does not support `structure.Struct` types, so each `StructType` is
    # transformed into a `StructWithPythonType` for later use as the
    # `tf.data.Dataset.element_spec`.
    tf_compatible_type, _ = type_transformations.transform_type_postorder(
        element_type, transform_to_tff_known_type)
  else:
    # Only structs or tensors pass the compatibility check above, so this is
    # a `TensorType` and is used as-is.
    tf_compatible_type = element_type

  element_spec = type_conversions.structure_from_tensor_type_tree(
      lambda t: tf.TensorSpec(shape=t.shape, dtype=t.dtype),
      tf_compatible_type)
  variant = tf.raw_ops.DatasetFromGraph(graph_def=serialized_graph_def)
  dataset = tf.data.experimental.from_variant(variant, structure=element_spec)
  # A serialized dataset with nested structures of tensors (e.g. `dict`,
  # `OrderedDict`) deserializes into `dict`, `tuple` or `namedtuple` elements;
  # `collections.OrderedDict` is lost in the conversion. The dataset is only
  # used inside TFF, so its elements are coerced back to `OrderedDict` where
  # needed (a type both TF and TFF understand). The field order comes from
  # the TFF type recorded during serialization.
  return tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      dataset, tf_compatible_type)
def _preorder_types(type_signature: computation_types.Type):
  """Lazily yields `type_signature` and all of its descendants in preorder."""
  # Children are pushed in reverse so the first child is popped, and yielded,
  # first.
  stack = [type_signature]
  while stack:
    current = stack.pop()
    yield current
    stack.extend(reversed(list(current.children())))