def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Whether or not `bitwidth_type` is a valid bitwidth type for `value_type`.

  Validity rules:
    * When both types are tensors, `bitwidth_type` must be an integer tensor
      holding exactly one element. Its shape need not match `value_type`,
      because one bitwidth integer is used per tensor.
    * When both types are structs, they must have the same number of elements
      with the same names in the same order. Each bitwidth element must be
      recursively valid for the matching value element.
    * Every other combination is invalid.

  Args:
    bitwidth_type: The `computation_types.Type` proposed for the bitwidth.
    value_type: The `computation_types.Type` of the value the bitwidth applies
      to.

  Returns:
    `True` if `bitwidth_type` is valid for `value_type`, otherwise `False`.
  """
  # NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
  # `federated_secure_sum` function.
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if value_type.is_tensor() and bitwidth_type.is_tensor():
    return (bitwidth_type.dtype.is_integer and
            bitwidth_type.shape.num_elements() == 1)

  if not (value_type.is_struct() and bitwidth_type.is_struct()):
    return False

  bitwidth_elements = list(structure.iter_elements(bitwidth_type))
  value_elements = list(structure.iter_elements(value_type))
  if len(bitwidth_elements) != len(value_elements):
    return False

  for (bw_name, bw_type), (val_name, val_type) in zip(bitwidth_elements,
                                                      value_elements):
    if bw_name != val_name:
      return False
    if not is_valid_bitwidth_type_for_value_type(bw_type, val_type):
      return False
  return True
def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Whether or not `bitwidth_type` is a valid bitwidth type for `value_type`.

  Validity rules:
    * A tensor `bitwidth_type` is valid for any `value_type` (tensor or
      struct) exactly when it is an integer tensor holding one element. The
      same integer bitwidth can then serve every value in a struct.
    * A struct `bitwidth_type` is valid only for a struct `value_type` with the
      same number of elements and the same names in the same order. Each
      bitwidth element must be recursively valid for its counterpart.
    * Every other combination is invalid.

  Args:
    bitwidth_type: The `computation_types.Type` proposed for the bitwidth.
    value_type: The `computation_types.Type` of the value the bitwidth applies
      to.

  Returns:
    `True` if `bitwidth_type` is valid for `value_type`, otherwise `False`.
  """
  # NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
  # `federated_secure_sum` function.
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if bitwidth_type.is_tensor():
    return (bitwidth_type.dtype.is_integer and
            bitwidth_type.shape.num_elements() == 1)

  if not (value_type.is_struct() and bitwidth_type.is_struct()):
    return False

  bitwidth_elements = list(structure.iter_elements(bitwidth_type))
  value_elements = list(structure.iter_elements(value_type))
  if len(bitwidth_elements) != len(value_elements):
    return False

  # `all` short-circuits on the first pair whose name differs or whose
  # recursive check fails, visiting pairs in order.
  return all(
      bw_name == val_name and
      is_valid_bitwidth_type_for_value_type(bw_type, val_type)
      for (bw_name, bw_type), (val_name, val_type) in zip(
          bitwidth_elements, value_elements))
def is_sum_compatible(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a type that can be added to itself.

  Sum-compatibility by type kind:
    * A tensor is sum-compatible when its dtype passes `is_numeric_dtype` and
      its shape is fully defined.
    * A struct is sum-compatible when every one of its elements is (so an
      empty struct is sum-compatible).
    * A federated type is sum-compatible when its member type is.
    * Anything else (e.g. sequences, functions, abstract types, placements) is
      not.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is sum-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return is_numeric_dtype(
        type_spec.dtype) and type_spec.shape.is_fully_defined()
  if type_spec.is_struct():
    element_types = (t for _, t in structure.iter_elements(type_spec))
    return all(is_sum_compatible(t) for t in element_types)
  if type_spec.is_federated():
    return is_sum_compatible(type_spec.member)
  return False
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation representing an identity function.

  The returned computation has the type signature `(T -> T)`, where `T` is
  `type_signature`. How the identity is built depends on `T`:
    * tensors and sequences pass through `tf.identity`;
    * structs have `tf.identity` applied to every leaf via
      `structure.map_structure`.

  NOTE: if `T` contains `computation_types.StructType`s without an associated
  container type, they will be given the container type `tuple` by this
  function.

  Args:
    type_signature: A `computation_types.Type` to use as the parameter type and
      result type of the identity function.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings, or is `None`.
    NotImplementedError: If `type_signature` is not a tensor, sequence or
      struct type.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if type_signature is None:
    raise TypeError('TensorFlow identity cannot be created for NoneType.')

  # TF relies on feeds not-identical to fetches in certain circumstances, so
  # the result is routed through explicit `tf.identity` ops.
  if type_signature.is_tensor() or type_signature.is_sequence():
    return create_computation_for_py_fn(tf.identity, type_signature)
  if type_signature.is_struct():
    leafwise_identity = functools.partial(structure.map_structure, tf.identity)
    return create_computation_for_py_fn(leafwise_identity, type_signature)
  raise NotImplementedError(
      f'TensorFlow identity cannot be created for type {type_signature}')
def type_to_tf_structure(type_spec: computation_types.Type):
  """Returns nested `tf.data.experimental.Structure` for a given TFF type.

  Tensors map to `tf.TensorSpec`. Structs are converted element by element and
  then packed as follows:
    * with a `python_container`, the struct is rebuilt with that container;
    * otherwise, named structs become `collections.OrderedDict` and unnamed
      structs become `tuple`.

  Args:
    type_spec: A `computation_types.Type`, the type specification must be
      composed of only named tuples and tensors. Within each named tuple the
      elements must be either all named or all unnamed, and no named tuple may
      be empty.

  Returns:
    An instance of `tf.data.experimental.Structure`, possibly nested, that
    corresponds to `type_spec`.

  Raises:
    ValueError: if the `type_spec` is composed of something other than named
      tuples and tensors, if any named tuple is empty, or if any named tuple
      mixes named and unnamed elements.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return tf.TensorSpec(type_spec.shape, type_spec.dtype)
  elif type_spec.is_struct():
    elements = structure.to_elements(type_spec)
    if not elements:
      raise ValueError('Empty tuples are unsupported.')
    element_outputs = [(k, type_to_tf_structure(v)) for k, v in elements]
    # The first element decides whether the struct is named; every other
    # element must agree with it.
    named = element_outputs[0][0] is not None
    if not all((k is not None) == named for k, _ in element_outputs):
      raise ValueError('Tuple elements inconsistently named.')
    container_type = type_spec.python_container
    if container_type is None:
      if named:
        return collections.OrderedDict(element_outputs)
      return tuple(v for _, v in element_outputs)
    if (py_typecheck.is_named_tuple(container_type) or
        py_typecheck.is_attrs(container_type)):
      return container_type(**dict(element_outputs))
    if named:
      return container_type(element_outputs)
    # All names are `None` here, so only the values are handed to the container.
    return container_type(v for _, v in element_outputs)
  else:
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))
def transform_to_tff_known_type(
    type_spec: computation_types.Type) -> Tuple[computation_types.Type, bool]:
  """Transforms `StructType` to `StructWithPythonType`.

  Conversion rules:
    * A plain struct whose elements are all named (including an empty struct)
      gets the container `collections.OrderedDict`.
    * A plain struct with no named elements gets the container `tuple`.
    * Any other type, or a struct that already has a Python container, is
      returned unchanged.

  Args:
    type_spec: The `computation_types.Type` to possibly convert.

  Returns:
    A pair `(type, mutated)` where `mutated` is `True` iff a conversion
    happened.

  Raises:
    TypeError: If `type_spec` is a struct with only some elements named.
  """
  if not type_spec.is_struct() or type_spec.is_struct_with_python():
    return type_spec, False

  names_present = [
      name is not None for name, _ in structure.iter_elements(type_spec)
  ]
  if all(names_present):
    container = collections.OrderedDict
  elif not any(names_present):
    container = tuple
  else:
    raise TypeError('Cannot represent TFF type in TF because it contains '
                    f'partially named structures. Type: {type_spec}')
  converted = computation_types.StructWithPythonType(
      elements=structure.iter_elements(type_spec), container_type=container)
  return converted, True
def is_structure_of_integers(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a structure of integers.

  Rules by type kind:
    * A tensor qualifies when its dtype is an integer dtype.
    * A struct qualifies when every element does, so an empty struct qualifies.
    * A federated type qualifies when its member does.
    * Every other type kind does not qualify.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is a structure of integers, otherwise `False`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    py_typecheck.check_type(type_spec.dtype, tf.DType)
    return type_spec.dtype.is_integer
  if type_spec.is_struct():
    for _, element_type in structure.iter_elements(type_spec):
      if not is_structure_of_integers(element_type):
        return False
    return True
  if type_spec.is_federated():
    return is_structure_of_integers(type_spec.member)
  return False
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation representing an identity function.

  The returned computation has the type signature `(T -> T)`, where `T` is
  `type_signature`. A fresh graph is built as follows:
    * a parameter named `'x'` of type `T` is stamped into it;
    * tensors and struct leaves are passed through `tf.identity`;
    * the result is captured;
    * the serialized graph and its bindings are handed to `_tensorflow_comp`.

  NOTE: if `T` contains `computation_types.StructType`s without an associated
  container type, they will be given the container type `tuple` by this
  function.

  Args:
    type_signature: A `computation_types.Type` to use as the parameter type and
      result type of the identity function.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings, or is `None`.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  parameter_type = type_signature
  if parameter_type is None:
    raise TypeError('TensorFlow identity cannot be created for NoneType.')

  with tf.Graph().as_default() as graph:
    value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', parameter_type, graph)
    # TF relies on feeds not-identical to fetches in certain circumstances,
    # hence the explicit `tf.identity` ops between parameter and result.
    if parameter_type.is_tensor():
      value = tf.identity(value)
    elif parameter_type.is_struct():
      value = structure.map_structure(tf.identity, value)
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        value, graph)

  function_type = computation_types.FunctionType(parameter_type, result_type)
  serialized_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow_proto = pb.TensorFlow(
      graph_def=serialized_graph,
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow_proto, function_type)
def is_average_compatible(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` can be averaged.

  Rules by type kind:
    * A tensor is average-compatible when its dtype is floating-point or
      complex.
    * A struct is average-compatible when all of its elements are.
    * A federated type is average-compatible when its member is.
    * Every other type kind is not.

  Args:
    type_spec: a `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is average-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    dtype = type_spec.dtype
    return dtype.is_floating or dtype.is_complex
  if type_spec.is_struct():
    element_types = (t for _, t in structure.iter_elements(type_spec))
    return all(is_average_compatible(t) for t in element_types)
  if type_spec.is_federated():
    return is_average_compatible(type_spec.member)
  return False
def is_structure_of_floats(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a structure of floats.

  Rules by type kind:
    * A tensor qualifies when its dtype is floating-point.
    * A struct qualifies when every element does. An empty
      `computation_types.StructType` therefore yields `True`, since it contains
      no non-floating types.
    * A federated type qualifies when its member does.
    * Every other type kind does not qualify.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is a structure of floats, otherwise `False`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    py_typecheck.check_type(type_spec.dtype, tf.DType)
    return type_spec.dtype.is_floating
  if type_spec.is_struct():
    for _, element_type in structure.iter_elements(type_spec):
      if not is_structure_of_floats(element_type):
        return False
    return True
  if type_spec.is_federated():
    return is_structure_of_floats(type_spec.member)
  return False
def type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type):
  """Returns nested structures of tensor dtypes and shapes for a given TFF type.

  The returned dtypes and shapes match those used by `tf.data.Dataset`s to
  indicate the type and shape of their elements. They can be used, e.g., as
  arguments in constructing an iterator over a string handle.

  Args:
    type_spec: A `computation_types.Type`, the type specification must be
      composed of only named tuples and tensors. Within each named tuple the
      elements must be either all named or all unnamed.

  Returns:
    A pair of parallel nested structures with the dtypes and shapes of tensors
    defined in `type_spec`. Both structures follow the layout of the nested
    type defined by `type_spec`, with structs packed as follows:
      * a struct with a `python_container` is rebuilt with that container;
      * otherwise a struct with named elements becomes a
        `collections.OrderedDict`;
      * a struct with unnamed elements (or no elements) becomes a `tuple`.

  Raises:
    ValueError: if the `type_spec` is composed of something other than named
      tuples and tensors, or if any named tuple mixes named and unnamed
      elements.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return (type_spec.dtype, type_spec.shape)
  elif type_spec.is_struct():
    elements = structure.to_elements(type_spec)
    # The first element's name decides whether this struct is treated as named
    # (dict-like output) or unnamed (sequence-like output); an empty struct is
    # treated as unnamed.
    is_named = bool(elements) and elements[0][0] is not None
    if is_named:
      output_dtypes = collections.OrderedDict()
      output_shapes = collections.OrderedDict()
    else:
      output_dtypes = []
      output_shapes = []
    for element_name, element_spec in elements:
      if (element_name is not None) != is_named:
        raise ValueError(
            'When a sequence appears as a part of a parameter to a section '
            'of TensorFlow code, in the type signature of elements of that '
            'sequence all named tuples must have their elements explicitly '
            'named, and this does not appear to be the case in {}.'.format(
                type_spec))
      element_dtypes, element_shapes = type_to_tf_dtypes_and_shapes(
          element_spec)
      if is_named:
        output_dtypes[element_name] = element_dtypes
        output_shapes[element_name] = element_shapes
      else:
        output_dtypes.append(element_dtypes)
        output_shapes.append(element_shapes)

    if type_spec.python_container is not None:
      container_type = type_spec.python_container

      def build_py_container(values):
        if (py_typecheck.is_named_tuple(container_type) or
            py_typecheck.is_attrs(container_type)):
          return container_type(**dict(values))
        else:
          return container_type(values)

      output_dtypes = build_py_container(output_dtypes)
      output_shapes = build_py_container(output_shapes)
    elif not is_named:
      output_dtypes = tuple(output_dtypes)
      output_shapes = tuple(output_shapes)
    # Named structs without a container stay `OrderedDict`s; converting them
    # with `tuple(...)` would keep only the element names and drop the values.
    return (output_dtypes, output_shapes)
  else:
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))
def _is_compatible(t: computation_types.Type) -> bool:
  """Checks `t` with `type_analysis.contains_only` against allowed node kinds.

  The predicate accepts struct, tensor and federated types. Presumably
  `contains_only` applies it to every node in `t`'s type tree — confirm
  against `type_analysis`.
  """

  def _is_allowed_node(node):
    return node.is_struct() or node.is_tensor() or node.is_federated()

  return type_analysis.contains_only(t, _is_allowed_node)
def _is_two_tuple(t: computation_types.Type) -> bool:
  """Returns whether `t` is a struct type with exactly two elements."""
  if not t.is_struct():
    return False
  return len(t) == 2
def create_binary_operator(
    operator, operand_type: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation computing a binary operation.

  The returned computation has the type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the result of applying the `operator` to a tuple of
  type `<T,T>`.

  Note: If `operand_type` is a `computation_types.StructType`, then `operator`
  will be applied pointwise (via `structure.map_structure`). This places the
  burden on callers of this function to construct the correct values to pass
  into the returned function. For example, to divide `[2, 2]` by `2`, first `2`
  must be packed into the data structure `[x, x]`, before the division operator
  of the appropriate type is called.

  Args:
    operator: A callable taking two arguments representing the operation to
      encode. For example: `tf.math.add`, `tf.math.multiply`, and
      `tf.math.divide`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only named tuples and tensor
      types.

  Raises:
    TypeError: If the constraints of `operand_type` are violated or `operator`
      is not callable.
  """
  if not type_analysis.is_generic_op_compatible_type(operand_type):
    raise TypeError(
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.StructType`; this is disallowed in the '
        'generic operators.'.format(operand_type))
  py_typecheck.check_callable(operator)
  with tf.Graph().as_default() as graph:
    operand_1_value, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', operand_type, graph)
    # Every path that does not produce `result_value` must raise; otherwise the
    # `capture_result_from_graph` call below would read an unbound local.
    if operand_type is not None and operand_type.is_tensor():
      result_value = operator(operand_1_value, operand_2_value)
    elif operand_type is not None and operand_type.is_struct():
      result_value = structure.map_structure(operator, operand_1_value,
                                             operand_2_value)
    else:
      raise TypeError(
          'Operand type {} cannot be used in generic operations. The call to '
          '`type_analysis.is_generic_op_compatible_type` has allowed it to '
          'pass, and should be updated.'.format(operand_type))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  type_signature = computation_types.FunctionType(
      computation_types.StructType((operand_type, operand_type)), result_type)
  # The two stamped parameters are bound as a 2-element struct, matching the
  # `<T,T>` parameter type above.
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
def _concretize_abstract_types(
    abstract_type_spec: computation_types.Type,
    concrete_type_spec: computation_types.Type) -> computation_types.Type:
  """Recursive helper function to construct concrete type spec.

  Walks `abstract_type_spec` and `concrete_type_spec` in parallel. Each
  abstract type is replaced by the concrete type at the same position.
  The first concrete type seen for a given abstract label is recorded in
  `bound_abstract_types` (keyed by `str(label)`); later occurrences of that
  label reuse the recorded type. Tensor and placement types are returned from
  the abstract side unchanged.

  NOTE(review): `bound_abstract_types` and `type_error_string` are not defined
  in this function; presumably they come from an enclosing function scope —
  confirm against the caller.

  Args:
    abstract_type_spec: The `computation_types.Type` possibly containing
      abstract types.
    concrete_type_spec: The `computation_types.Type` supplying the concrete
      types at the corresponding positions.

  Returns:
    A `computation_types.Type` with abstract types substituted.

  Raises:
    TypeError: If the two types differ in structure (type kind, struct length
      or element names, presence of a function parameter), or if
      `abstract_type_spec` is of an unexpected kind.
  """
  if abstract_type_spec.is_abstract():
    label = str(abstract_type_spec.label)
    bound_type = bound_abstract_types.get(label)
    # Compare with `None` explicitly so that a bound type which happens to be
    # falsy (presumably e.g. an empty struct) is reused rather than re-bound.
    if bound_type is not None:
      return bound_type
    bound_abstract_types[label] = concrete_type_spec
    return concrete_type_spec
  elif abstract_type_spec.is_tensor():
    return abstract_type_spec
  elif abstract_type_spec.is_struct():
    if not concrete_type_spec.is_struct():
      raise TypeError(type_error_string)
    abstract_elements = structure.to_elements(abstract_type_spec)
    concrete_elements = structure.to_elements(concrete_type_spec)
    if len(abstract_elements) != len(concrete_elements):
      raise TypeError(type_error_string)
    concretized_tuple_elements = []
    for (abstract_name, abstract_element), (concrete_name,
                                            concrete_element) in zip(
                                                abstract_elements,
                                                concrete_elements):
      if abstract_name != concrete_name:
        raise TypeError(type_error_string)
      concretized_tuple_elements.append(
          (abstract_name,
           _concretize_abstract_types(abstract_element, concrete_element)))
    return computation_types.StructType(concretized_tuple_elements)
  elif abstract_type_spec.is_sequence():
    if not concrete_type_spec.is_sequence():
      raise TypeError(type_error_string)
    return computation_types.SequenceType(
        _concretize_abstract_types(abstract_type_spec.element,
                                   concrete_type_spec.element))
  elif abstract_type_spec.is_function():
    if not concrete_type_spec.is_function():
      raise TypeError(type_error_string)
    if abstract_type_spec.parameter is None:
      if concrete_type_spec.parameter is not None:
        raise TypeError(type_error_string)
      concretized_param = None
    else:
      # A missing concrete parameter is a structural mismatch; recursing into
      # `None` would fail with an unrelated error.
      if concrete_type_spec.parameter is None:
        raise TypeError(type_error_string)
      concretized_param = _concretize_abstract_types(
          abstract_type_spec.parameter, concrete_type_spec.parameter)
    concretized_result = _concretize_abstract_types(abstract_type_spec.result,
                                                    concrete_type_spec.result)
    return computation_types.FunctionType(concretized_param,
                                          concretized_result)
  elif abstract_type_spec.is_placement():
    if not concrete_type_spec.is_placement():
      raise TypeError(type_error_string)
    return abstract_type_spec
  elif abstract_type_spec.is_federated():
    if not concrete_type_spec.is_federated():
      raise TypeError(type_error_string)
    new_member = _concretize_abstract_types(abstract_type_spec.member,
                                            concrete_type_spec.member)
    return computation_types.FederatedType(new_member,
                                           abstract_type_spec.placement,
                                           abstract_type_spec.all_equal)
  else:
    raise TypeError(
        'Unexpected abstract typespec {}.'.format(abstract_type_spec))
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  The walk proceeds in three steps for each node:
    1. Transform the node's children: a federated member, a sequence element,
       a function parameter and result, or struct elements.
    2. Rebuild the node from its transformed children, but only if at least
       one child reported a mutation. Struct nodes that carry a Python
       container keep it.
    3. Apply `transform_fn` to the (possibly rebuilt) node itself.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Callable applied to each node in the type tree of
      `type_signature`. It must return a pair `(new_type, mutated)`.

  Returns:
    A pair `(transformed_type, mutated)`:
      * `transformed_type` is `type_signature` with each node in its tree
        replaced by the result of applying `transform_fn` to the
        corresponding node;
      * `mutated` is `True` if `transform_fn` reported a mutation anywhere in
        the tree.

  Raises:
    TypeError: If `type_signature` is not a `computation_types.Type`, if
      `transform_fn` is not callable, or if a node of an unrecognized type kind
      is encountered.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)
  if type_signature.is_federated():
    transformed_member, member_mutated = transform_type_postorder(
        type_signature.member, transform_fn)
    if member_mutated:
      type_signature = computation_types.FederatedType(transformed_member,
                                                       type_signature.placement,
                                                       type_signature.all_equal)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or member_mutated
  elif type_signature.is_sequence():
    transformed_element, element_mutated = transform_type_postorder(
        type_signature.element, transform_fn)
    if element_mutated:
      type_signature = computation_types.SequenceType(transformed_element)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or element_mutated
  elif type_signature.is_function():
    if type_signature.parameter is not None:
      transformed_parameter, parameter_mutated = transform_type_postorder(
          type_signature.parameter, transform_fn)
    else:
      transformed_parameter, parameter_mutated = (None, False)
    transformed_result, result_mutated = transform_type_postorder(
        type_signature.result, transform_fn)
    if parameter_mutated or result_mutated:
      type_signature = computation_types.FunctionType(transformed_parameter,
                                                      transformed_result)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, (
        type_signature_mutated or parameter_mutated or result_mutated)
  elif type_signature.is_struct():
    elements = []
    elements_mutated = False
    for element_name, element_type in structure.iter_elements(type_signature):
      transformed_element, element_mutated = transform_type_postorder(
          element_type, transform_fn)
      elements_mutated = elements_mutated or element_mutated
      elements.append((element_name, transformed_element))
    if elements_mutated:
      if type_signature.is_struct_with_python():
        type_signature = computation_types.StructWithPythonType(
            elements, type_signature.python_container)
      else:
        type_signature = computation_types.StructType(elements)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or elements_mutated
  elif (type_signature.is_abstract() or type_signature.is_placement() or
        type_signature.is_tensor()):
    return transform_fn(type_signature)
  # Without this, an unknown type kind would silently yield `None`, breaking
  # callers that unpack the `(type, mutated)` pair.
  raise TypeError(
      'Unexpected type {} encountered in `transform_type_postorder`.'.format(
          type_signature))
def _deserialize_dataset_from_graph_def(serialized_graph_def: bytes,
                                        element_type: computation_types.Type):
  """Deserializes a serialized `tf.compat.v1.GraphDef` to a `tf.data.Dataset`.

  Steps:
    1. Make `element_type` TF-friendly: plain structs are given a Python
       container.
    2. Derive a `tf.TensorSpec` element spec from that type.
    3. Rebuild the dataset from the graph via `tf.raw_ops.DatasetFromGraph`.
    4. Coerce the dataset's elements back to the TFF type.

  Args:
    serialized_graph_def: `bytes` object produced by
      `tensorflow_serialization.serialize_dataset`
    element_type: a `tff.Type` object representing the type structure of the
      elements yielded from the dataset.

  Returns:
    A `tf.data.Dataset` instance.
  """
  py_typecheck.check_type(element_type, computation_types.Type)
  type_analysis.check_tensorflow_compatible_type(element_type)

  def _attach_python_container(
      type_spec: computation_types.Type
  ) -> Tuple[computation_types.Type, bool]:
    """Converts a plain `StructType` into a `StructWithPythonType`.

    Conversion rules:
      * A plain struct with all elements named gets `collections.OrderedDict`.
      * A plain struct with no element named gets `tuple`.
      * Any other node is returned unchanged with `mutated=False`.

    Raises:
      TypeError: If a plain struct is only partially named.
    """
    if not type_spec.is_struct() or type_spec.is_struct_with_python():
      return type_spec, False
    names_present = [
        name is not None for name, _ in structure.iter_elements(type_spec)
    ]
    if all(names_present):
      container = collections.OrderedDict
    elif not any(names_present):
      container = tuple
    else:
      raise TypeError(
          'Cannot represent TFF type in TF because it contains '
          f'partially named structures. Type: {type_spec}')
    return computation_types.StructWithPythonType(
        elements=structure.iter_elements(type_spec),
        container_type=container), True

  if element_type.is_struct():
    # TF does not support `structure.Struct` types, so every `StructType` is
    # turned into a `StructWithPythonType` before it is used to build the
    # `tf.data.Dataset.element_spec` below.
    tf_compatible_type, _ = type_transformations.transform_type_postorder(
        element_type, _attach_python_container)
  else:
    # The compatibility check above leaves only structs or tensors, so this is
    # a `TensorType` and can be used as-is.
    tf_compatible_type = element_type

  def _to_tensor_spec(tensor_type: computation_types.TensorType) -> tf.TensorSpec:
    return tf.TensorSpec(shape=tensor_type.shape, dtype=tensor_type.dtype)

  element_spec = type_conversions.structure_from_tensor_type_tree(
      _to_tensor_spec, tf_compatible_type)
  dataset_variant = tf.raw_ops.DatasetFromGraph(graph_def=serialized_graph_def)
  dataset = tf.data.experimental.from_variant(
      dataset_variant, structure=element_spec)
  # A dataset serialized with nested structures of tensors (e.g. `dict`,
  # `OrderedDict`) deserializes with `dict`, `tuple` or `namedtuple` elements,
  # so any `collections.OrderedDict` is lost on the way through TF. The
  # dataset is only consumed inside TFF, so its elements are coerced back
  # to the TFF type computed above. That restores `OrderedDict` (understood
  # by both TF and TFF) with the field order recorded in the type.
  return tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      dataset, tf_compatible_type)