def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Checks whether `bitwidth_type` can describe bitwidths for `value_type`.

  A tensor `value_type` is matched by a tensor `bitwidth_type` that has an
  integer dtype and exactly one element. A tuple `value_type` is matched by a
  tuple `bitwidth_type` with the same number of elements and the same element
  names, whose element types match recursively. Every other pairing is
  rejected.

  Args:
    bitwidth_type: The `computation_types.Type` proposed for the bitwidths.
    value_type: The `computation_types.Type` of the values.

  Returns:
    Whether `bitwidth_type` is a valid bitwidth type for `value_type`.
  """
  # NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
  # `federated_secure_sum` function.
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if value_type.is_tensor() and bitwidth_type.is_tensor():
    # One bitwidth integer is wanted per value tensor, so the bitwidth tensor
    # is not compared against the value tensor itself; it only has to be a
    # single integer.
    return bitwidth_type.dtype.is_integer and (
        bitwidth_type.shape.num_elements() == 1)

  if not (value_type.is_tuple() and bitwidth_type.is_tuple()):
    return False

  bitwidth_fields = list(anonymous_tuple.iter_elements(bitwidth_type))
  value_fields = list(anonymous_tuple.iter_elements(value_type))
  if len(bitwidth_fields) != len(value_fields):
    return False

  # Elements are compared pairwise, in order: names must agree and the element
  # types must themselves be a valid bitwidth/value pairing.
  return all(
      bitwidth_name == value_name and
      is_valid_bitwidth_type_for_value_type(bitwidth_elem, value_elem)
      for (bitwidth_name, bitwidth_elem), (value_name, value_elem) in zip(
          bitwidth_fields, value_fields))
def is_sum_compatible(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a type that can be added to itself.

  A tensor type is sum-compatible when `is_numeric_dtype` accepts its dtype. A
  named tuple is sum-compatible when every one of its elements is, and a
  federated type is sum-compatible when its member is. All other types (for
  example sequences, functions, abstract types, and placements) are not.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is sum-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    return is_numeric_dtype(type_spec.dtype)

  if type_spec.is_tuple():
    # Stop at the first element that cannot be summed.
    for _, element_type in anonymous_tuple.iter_elements(type_spec):
      if not is_sum_compatible(element_type):
        return False
    return True

  if type_spec.is_federated():
    return is_sum_compatible(type_spec.member)

  return False
def type_to_tf_structure(type_spec: computation_types.Type):
  """Converts a TFF type into a nested structure of `tf.TensorSpec`s.

  Args:
    type_spec: A `computation_types.Type` composed only of tensors and
      non-empty named tuples. Within a single tuple the elements must be
      either all named or all unnamed.

  Returns:
    A `tf.TensorSpec` for a tensor type. For a tuple type, the converted
    elements packaged into the tuple's Python container if it carries one;
    otherwise a `collections.OrderedDict` when the elements are named, or a
    `tuple` when they are unnamed.

  Raises:
    ValueError: If `type_spec` contains an empty tuple, a tuple whose elements
      are inconsistently named, or anything other than tensors and tuples.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    return tf.TensorSpec(type_spec.shape, type_spec.dtype)
  if not type_spec.is_tuple():
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))

  fields = anonymous_tuple.to_elements(type_spec)
  if not fields:
    raise ValueError('Empty tuples are unsupported.')

  # Children are converted before the naming check, so an unsupported nested
  # type is reported ahead of inconsistent naming at this level.
  converted = [(name, type_to_tf_structure(field)) for name, field in fields]
  has_names = converted[0][0] is not None
  if any((name is not None) != has_names for name, _ in converted):
    raise ValueError('Tuple elements inconsistently named.')

  if type_spec.is_tuple_with_py_container():
    container_type = (
        computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
            type_spec))
    if (py_typecheck.is_named_tuple(container_type) or
        py_typecheck.is_attrs(container_type)):
      # Named tuples and attrs classes are constructed from keyword arguments.
      return container_type(**dict(converted))
    if has_names:
      return container_type(converted)
    # Every name is `None` at this point, so only the specs are forwarded.
    return container_type(spec for _, spec in converted)

  if has_names:
    return collections.OrderedDict(converted)
  return tuple(spec for _, spec in converted)
def is_structure_of_integers(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a structure of integers.

  Tensors qualify when their dtype is an integer dtype; named tuples qualify
  when all of their elements do; federated types qualify when their member
  does. Every other kind of type is rejected.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is a structure of integers, otherwise `False`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    tensor_dtype = type_spec.dtype
    py_typecheck.check_type(tensor_dtype, tf.DType)
    return tensor_dtype.is_integer

  if type_spec.is_tuple():
    # Stop at the first element that is not a structure of integers.
    for _, element_type in anonymous_tuple.iter_elements(type_spec):
      if not is_structure_of_integers(element_type):
        return False
    return True

  if type_spec.is_federated():
    return is_structure_of_integers(type_spec.member)

  return False
def is_average_compatible(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` can be averaged.

  A tensor type is average-compatible when its dtype is floating-point or
  complex. A named tuple is average-compatible when each of its elements is,
  and a federated type is average-compatible when its member is. Nothing else
  is average-compatible.

  Args:
    type_spec: a `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is average-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    tensor_dtype = type_spec.dtype
    return tensor_dtype.is_floating or tensor_dtype.is_complex

  if type_spec.is_tuple():
    # Stop at the first element that cannot be averaged.
    for _, element_type in anonymous_tuple.iter_elements(type_spec):
      if not is_average_compatible(element_type):
        return False
    return True

  if type_spec.is_federated():
    return is_average_compatible(type_spec.member)

  return False
def type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type):
  """Returns nested structures of tensor dtypes and shapes for a given TFF type.

  The returned dtypes and shapes match those used by `tf.data.Dataset`s to
  indicate the type and shape of their elements. They can be used, e.g., as
  arguments in constructing an iterator over a string handle.

  Args:
    type_spec: A `computation_types.Type`, the type specification must be
      composed of only named tuples and tensors. Within a single named tuple
      the elements must be either all named or all unnamed.

  Returns:
    A pair of parallel nested structures with the dtypes and shapes of tensors
    defined in `type_spec`. The layout of the two structures returned is the
    same as the layout of the nested type defined by `type_spec`. A tuple that
    carries a Python container is rebuilt in that container. Without a
    container, tuples with named elements are represented as
    `collections.OrderedDict`s and tuples with unnamed (or no) elements as
    Python `tuple`s.

  Raises:
    ValueError: if the `type_spec` is composed of something other than named
      tuples and tensors, or if a named tuple mixes named and unnamed
      elements.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return (type_spec.dtype, type_spec.shape)
  elif type_spec.is_tuple():
    elements = anonymous_tuple.to_elements(type_spec)
    if not elements:
      output_dtypes = []
      output_shapes = []
    elif elements[0][0] is not None:
      # The first element is named, so every element must be named; results
      # are collected into ordered mappings keyed by element name.
      output_dtypes = collections.OrderedDict()
      output_shapes = collections.OrderedDict()
      for element_name, element_spec in elements:
        if element_name is None:
          raise ValueError(
              'When a sequence appears as a part of a parameter to a section '
              'of TensorFlow code, in the type signature of elements of that '
              'sequence all named tuples must have their elements explicitly '
              'named, and this does not appear to be the case in {}.'.format(
                  type_spec))
        element_output = type_to_tf_dtypes_and_shapes(element_spec)
        output_dtypes[element_name] = element_output[0]
        output_shapes[element_name] = element_output[1]
    else:
      # The first element is unnamed, so every element must be unnamed;
      # results are collected positionally.
      output_dtypes = []
      output_shapes = []
      for element_name, element_spec in elements:
        if element_name is not None:
          raise ValueError(
              'When a sequence appears as a part of a parameter to a section '
              'of TensorFlow code, in the type signature of elements of that '
              'sequence all named tuples must have their elements explicitly '
              'named, and this does not appear to be the case in {}.'.format(
                  type_spec))
        element_output = type_to_tf_dtypes_and_shapes(element_spec)
        output_dtypes.append(element_output[0])
        output_shapes.append(element_output[1])

    if type_spec.python_container is not None:
      container_type = type_spec.python_container

      def build_py_container(elements):
        # Named tuples and attrs classes are constructed from keyword
        # arguments; other containers take the collected structure directly.
        if (py_typecheck.is_named_tuple(container_type) or
            py_typecheck.is_attrs(container_type)):
          return container_type(**dict(elements))
        else:
          return container_type(elements)

      output_dtypes = build_py_container(output_dtypes)
      output_shapes = build_py_container(output_shapes)
    elif not isinstance(output_dtypes, collections.OrderedDict):
      # Only positionally collected results are frozen into tuples. Calling
      # `tuple()` on the ordered mappings built for named elements would keep
      # just the element names and drop the dtypes and shapes, so named
      # results are returned as the mappings themselves.
      output_dtypes = tuple(output_dtypes)
      output_shapes = tuple(output_shapes)
    return (output_dtypes, output_shapes)
  else:
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))
def create_binary_operator(
    operator, operand_type: computation_types.Type) -> pb.Computation:
  """Returns a tensorflow computation computing a binary operation.

  The returned computation has the type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the result of applying the `operator` to a tuple of
  type `<T,T>`

  Note: If `operand_type` is a `computation_types.NamedTupleType`, then
  `operator` will be applied pointwise. This places the burden on callers of
  this function to construct the correct values to pass into the returned
  function. For example, to divide `[2, 2]` by `2`, first `2` must be packed
  into the data structure `[x, x]`, before the division operator of the
  appropriate type is called.

  Args:
    operator: A callable taking two arguments representing the operation to
      encode For example: `tf.math.add`, `tf.math.multiply`, and
      `tf.math.divide`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only named tuples and tensor
      types.

  Returns:
    A `pb.Computation` holding the serialized function type and the TensorFlow
    graph, parameter binding and result binding of the binary operation.

  Raises:
    TypeError: If the constraints of `operand_type` are violated, if
      `operand_type` is `None`, or if `operator` is not callable.
  """
  if not type_analysis.is_generic_op_compatible_type(operand_type):
    raise TypeError(
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.NamedTupleType`; this is disallowed in the '
        'generic operators.'.format(operand_type))
  py_typecheck.check_callable(operator)
  with tf.Graph().as_default() as graph:
    operand_1_value, operand_1_binding = (
        tensorflow_utils.stamp_parameter_in_graph('x', operand_type, graph))
    operand_2_value, operand_2_binding = (
        tensorflow_utils.stamp_parameter_in_graph('y', operand_type, graph))

    # A `None` operand type has nothing to apply the operator to. Fail with a
    # clear error instead of reaching the result capture below with
    # `result_value` never assigned.
    if operand_type is None:
      raise TypeError(
          'Operand type must not be `None`; a binary operator needs a tensor '
          'or named tuple operand type.')

    if operand_type.is_tensor():
      result_value = operator(operand_1_value, operand_2_value)
    elif operand_type.is_tuple():
      # The operator is applied pointwise across the two structures.
      result_value = anonymous_tuple.map_structure(operator, operand_1_value,
                                                   operand_2_value)
    else:
      raise TypeError(
          'Operand type {} cannot be used in generic operations. The call to '
          '`type_analysis.is_generic_op_compatible_type` has allowed it to '
          'pass, and should be updated.'.format(operand_type))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  type_signature = computation_types.FunctionType(
      computation_types.NamedTupleType((operand_type, operand_type)),
      result_type)
  parameter_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)
def _concretize_abstract_types(
    abstract_type_spec: computation_types.Type,
    concrete_type_spec: computation_types.Type
) -> computation_types.Type:
  """Recursive helper function to construct concrete type spec.

  Walks `abstract_type_spec` and `concrete_type_spec` in parallel. The first
  time an abstract type label is met, it is bound to the corresponding concrete
  type; later occurrences of the same label resolve to that bound type.

  This helper reads and updates `bound_abstract_types` and reads
  `type_error_string`; neither is defined in this function, so both must be
  supplied by the enclosing scope.

  Args:
    abstract_type_spec: The `computation_types.Type` that may contain abstract
      types.
    concrete_type_spec: The `computation_types.Type` supplying the concrete
      types to bind.

  Returns:
    The `computation_types.Type` built from `abstract_type_spec` with its
    abstract types replaced by their bound concrete types.

  Raises:
    TypeError: If the structures of the two type specs do not line up (kind,
      tuple length, tuple element names, or presence of a function parameter),
      or if `abstract_type_spec` is of an unexpected kind.
  """
  if abstract_type_spec.is_abstract():
    label = str(abstract_type_spec.label)
    bound_type = bound_abstract_types.get(label)
    # Compare against `None` rather than relying on truthiness: tuple types
    # support `len()`, so an empty one could evaluate as falsy and would then
    # be silently re-bound.
    if bound_type is not None:
      return bound_type
    bound_abstract_types[label] = concrete_type_spec
    return concrete_type_spec
  elif abstract_type_spec.is_tensor():
    return abstract_type_spec
  elif abstract_type_spec.is_tuple():
    if not concrete_type_spec.is_tuple():
      raise TypeError(type_error_string)
    abstract_elements = anonymous_tuple.to_elements(abstract_type_spec)
    concrete_elements = anonymous_tuple.to_elements(concrete_type_spec)
    if len(abstract_elements) != len(concrete_elements):
      raise TypeError(type_error_string)
    concretized_tuple_elements = []
    for (abstract_name, abstract_elem), (concrete_name, concrete_elem) in zip(
        abstract_elements, concrete_elements):
      if abstract_name != concrete_name:
        raise TypeError(type_error_string)
      concretized_tuple_elements.append(
          (abstract_name,
           _concretize_abstract_types(abstract_elem, concrete_elem)))
    return computation_types.NamedTupleType(concretized_tuple_elements)
  elif abstract_type_spec.is_sequence():
    if not concrete_type_spec.is_sequence():
      raise TypeError(type_error_string)
    return computation_types.SequenceType(
        _concretize_abstract_types(abstract_type_spec.element,
                                   concrete_type_spec.element))
  elif abstract_type_spec.is_function():
    if not concrete_type_spec.is_function():
      raise TypeError(type_error_string)
    if abstract_type_spec.parameter is None:
      if concrete_type_spec.parameter is not None:
        # This must be raised, not returned: returning the exception instance
        # would hand the caller an exception object in place of a type.
        raise TypeError(type_error_string)
      concretized_param = None
    else:
      concretized_param = _concretize_abstract_types(
          abstract_type_spec.parameter, concrete_type_spec.parameter)
    concretized_result = _concretize_abstract_types(
        abstract_type_spec.result, concrete_type_spec.result)
    return computation_types.FunctionType(concretized_param,
                                          concretized_result)
  elif abstract_type_spec.is_placement():
    if not concrete_type_spec.is_placement():
      raise TypeError(type_error_string)
    return abstract_type_spec
  elif abstract_type_spec.is_federated():
    if not concrete_type_spec.is_federated():
      raise TypeError(type_error_string)
    new_member = _concretize_abstract_types(abstract_type_spec.member,
                                            concrete_type_spec.member)
    # Placement and all-equal bit are taken from the abstract side.
    return computation_types.FederatedType(
        new_member, abstract_type_spec.placement, abstract_type_spec.all_equal)
  else:
    raise TypeError(
        'Unexpected abstract typespec {}.'.format(abstract_type_spec))
def create_binary_operator_with_upcast(
    type_signature: computation_types.Type,
    operator: Callable[[Any, Any], Any]) -> pb.Computation:
  """Creates TF computation upcasting its argument and applying `operator`.

  Args:
    type_signature: Value convertible to `computation_types.NamedTupleType`,
      with two elements, both of the same type or the second able to be upcast
      to the first, as explained in `apply_binary_operator_with_upcast`, and
      both containing only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `pb.Computation` whose TensorFlow graph takes the two-element argument,
    packs the second element into the structure of the first when their types
    are not equivalent, and applies `operator` to the pair.

  Raises:
    TypeError: If `type_signature` is not a two-element tuple type, contains a
      sequence type, or its first element is neither a tensor nor a tuple.
  """
  py_typecheck.check_callable(operator)
  type_signature = computation_types.to_type(type_signature)
  type_analysis.check_tensorflow_compatible_type(type_signature)

  is_pair = type_signature.is_tuple() and len(type_signature) == 2
  if not is_pair:
    raise TypeError(
        'To apply a binary operator, we must by definition have an '
        'argument which is a `NamedTupleType` with 2 elements; '
        'asked to create a binary operator for type: {t}'.format(
            t=type_signature))

  def _is_sequence(t):
    return t.is_sequence()

  if type_analysis.contains(type_signature, _is_sequence):
    raise TypeError('Applying binary operators in TensorFlow is only '
                    'supported on Tensors and NamedTupleTypes; you '
                    'passed {t} which contains a SequenceType.'.format(
                        t=type_signature))

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if type_spec.is_tuple():
      packed_fields = []
      for field_name, field_type in anonymous_tuple.iter_elements(type_spec):
        packed_fields.append((field_name, _pack_into_type(to_pack,
                                                          field_type)))
      return anonymous_tuple.AnonymousTuple(packed_fields)
    elif type_spec.is_tensor():
      return tf.broadcast_to(to_pack, type_spec.shape)

  left_type = type_signature[0]
  right_type = type_signature[1]

  with tf.Graph().as_default() as graph:
    left_value, left_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', left_type, graph)
    right_value, right_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', right_type, graph)

    # Only pack the second operand when its type differs from the first.
    if not left_type.is_equivalent_to(right_type):
      right_value = _pack_into_type(right_value, left_type)

    if left_type.is_tensor():
      result_value = operator(left_value, right_value)
    elif left_type.is_tuple():
      result_value = anonymous_tuple.map_structure(operator, left_value,
                                                   right_value)
    else:
      raise TypeError(
          'Encountered unexpected type {t}; can only handle Tensor '
          'and NamedTupleTypes.'.format(t=left_type))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  function_type = computation_types.FunctionType(type_signature, result_type)
  argument_bindings = pb.TensorFlow.NamedTupleBinding(
      element=[left_binding, right_binding])
  parameter_binding = pb.TensorFlow.Binding(tuple=argument_bindings)
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=packed_graph_def,
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(function_type),
      tensorflow=tensorflow)
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]):
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children are transformed before their parent. When any child was mutated,
  the parent is rebuilt from the transformed children before `transform_fn` is
  applied to it; otherwise `transform_fn` receives the original parent.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Callable applied to each node in the type tree of
      `type_signature`. It must return a pair of the (possibly replaced) type
      and a flag telling whether it changed the node.

  Returns:
    A pair `(transformed_type, mutated)`: the possibly transformed version of
    `type_signature`, and a flag that is truthy when `transform_fn` reported a
    mutation on this node or on any node beneath it. For abstract, placement
    and tensor types the result of `transform_fn` is returned as is.

  Raises:
    TypeError: If the types don't match the specification above.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)

  if type_signature.is_federated():
    new_member, member_changed = transform_type_postorder(
        type_signature.member, transform_fn)
    rebuilt = type_signature
    if member_changed:
      rebuilt = computation_types.FederatedType(new_member,
                                                type_signature.placement,
                                                type_signature.all_equal)
    result, changed_here = transform_fn(rebuilt)
    return result, changed_here or member_changed

  if type_signature.is_sequence():
    new_element, element_changed = transform_type_postorder(
        type_signature.element, transform_fn)
    rebuilt = type_signature
    if element_changed:
      rebuilt = computation_types.SequenceType(new_element)
    result, changed_here = transform_fn(rebuilt)
    return result, changed_here or element_changed

  if type_signature.is_function():
    # A function type may have no parameter; only the result is walked then.
    if type_signature.parameter is None:
      new_parameter, parameter_changed = None, False
    else:
      new_parameter, parameter_changed = transform_type_postorder(
          type_signature.parameter, transform_fn)
    new_result, result_changed = transform_type_postorder(
        type_signature.result, transform_fn)
    rebuilt = type_signature
    if parameter_changed or result_changed:
      rebuilt = computation_types.FunctionType(new_parameter, new_result)
    result, changed_here = transform_fn(rebuilt)
    return result, (changed_here or parameter_changed or result_changed)

  if type_signature.is_tuple():
    new_fields = []
    any_field_changed = False
    for field_name, field_type in anonymous_tuple.iter_elements(
        type_signature):
      new_field, field_changed = transform_type_postorder(
          field_type, transform_fn)
      any_field_changed = any_field_changed or field_changed
      new_fields.append((field_name, new_field))
    rebuilt = type_signature
    if any_field_changed:
      if type_signature.is_tuple_with_py_container():
        # Keep the Python container of the original tuple on the rebuilt one.
        container_type = (
            computation_types.NamedTupleTypeWithPyContainerType
            .get_container_type(type_signature))
        rebuilt = computation_types.NamedTupleTypeWithPyContainerType(
            new_fields, container_type)
      else:
        rebuilt = computation_types.NamedTupleType(new_fields)
    result, changed_here = transform_fn(rebuilt)
    return result, changed_here or any_field_changed

  if (type_signature.is_abstract() or type_signature.is_placement() or
      type_signature.is_tensor()):
    # Leaf kinds have no children to walk.
    return transform_fn(type_signature)

  # Any kind not handled above falls through without a transformed result.
  return None
def _is_compatible(t: computation_types.Type) -> bool:
  """Checks `t` with a predicate accepting tuple, tensor and federated types.

  Args:
    t: The `computation_types.Type` to check.

  Returns:
    The result of `type_analysis.contains_only` for `t` and that predicate.
  """

  def _is_allowed_kind(node):
    return node.is_tuple() or node.is_tensor() or node.is_federated()

  return type_analysis.contains_only(t, _is_allowed_kind)