def _ensure_deserialized_types_compatible(
    previous_type: Optional[computation_types.Type],
    next_type: computation_types.Type) -> computation_types.Type:
  """Returns whichever of two types can accept values of the other.

  A `previous_type` of `None` means no type has been recorded yet, so
  `next_type` is returned as-is. Otherwise `next_type` wins whenever it is
  assignable from `previous_type`; `previous_type` is returned only when the
  reverse assignment is the one that holds.

  Args:
    previous_type: A `computation_types.Type`, or `None`.
    next_type: A `computation_types.Type`.

  Returns:
    The type (of the two) that is assignable from the other.

  Raises:
    TypeError: If neither type is assignable from the other.
  """
  if previous_type is None:
    return next_type
  # Checked first, so `next_type` is preferred when both directions hold.
  if next_type.is_assignable_from(previous_type):
    return next_type
  if previous_type.is_assignable_from(next_type):
    return previous_type
  mismatch_message = ('Type mismatch checking member assignability under a '
                      'federated value. Deserialized type {} is incompatible '
                      'with previously deserialized {}.')
  raise TypeError(mismatch_message.format(next_type, previous_type))
def _to_tensor_internal_rep(*, value: Any,
                            type_spec: computation_types.Type) -> tf.Tensor:
  """Coerces a tensor-like `value` into a `tf.Tensor` matching `type_spec`.

  Non-tensor inputs are converted with `tf.convert_to_tensor` using the dtype
  of `type_spec`. Tensor inputs that expose `read_value` (variable-like
  objects) are read to obtain a plain tensor.

  Args:
    value: The tensor-like value to normalize.
    type_spec: The expected type; its `dtype` drives conversion and it must be
      assignable from the resulting tensor's type.

  Returns:
    The normalized tensor.

  Raises:
    TypeError: If the resulting tensor's type is not assignable to
      `type_spec`.
  """
  if tf.is_tensor(value):
    if hasattr(value, 'read_value'):
      # Variable-like objects must be read to yield an ordinary tensor.
      value = value.read_value()
  else:
    value = tf.convert_to_tensor(value, dtype=type_spec.dtype)
  actual_type = computation_types.TensorType(value.dtype.base_dtype,
                                             value.shape)
  if type_spec.is_assignable_from(actual_type):
    return value
  raise TypeError(
      'The apparent type {} of a tensor {} does not match the expected '
      'type {}.'.format(actual_type, value, type_spec))
def check_type(value: Any, type_spec: computation_types.Type):
  """Checks whether `value` is of TFF type `type_spec`.

  The check passes when `type_spec` is assignable from the type that
  `type_conversions.infer_type` infers for `value`.

  Args:
    value: The object to check.
    type_spec: A `computation_types.Type`, the type that `value` is checked
      against.

  Raises:
    TypeError: If the inferred type of `value` is not assignable to
      `type_spec`. `type_spec` itself is first validated with
      `py_typecheck.check_type`, which is expected to raise if it is not a
      `computation_types.Type` (NOTE(review): confirm exception type).
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  value_type = type_conversions.infer_type(value)
  if not type_spec.is_assignable_from(value_type):
    raise TypeError(
        'Expected TFF type {}, which is not assignable from {}.'.format(
            type_spec, value_type))
def create_constant(value, type_spec: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation returning a constant `value`.

  The returned computation has the type signature `( -> T)`, where `T` is
  `type_spec`.

  `value` must be a value convertible to a tensor or a structure of values,
  such that the dtype and shapes match `type_spec`. `type_spec` must contain
  only named tuples and tensor types, but these can be arbitrarily nested.
  A non-structure `value` is broadcast to every tensor leaf of `type_spec`;
  a structured `value` is copied leaf-by-leaf following `type_spec`.

  Args:
    value: A value to embed as a constant in the tensorflow graph.
    type_spec: A `computation_types.Type` describing the result of the
      constructed computation; must contain only named tuples and tensor
      types.

  Returns:
    A `ProtoAndType` built by `_tensorflow_comp` from the serialized graph
    and the `( -> type_spec)` function type.

  Raises:
    TypeError: If the constraints of `type_spec` are violated, if a structured
      `value` is not assignable to `type_spec`, or if a non-integer tensor
      value is supplied while `type_spec` contains integer tensors.
    ValueError: If a tensor leaf of `type_spec` does not have a fully defined
      shape (raised by `TensorShape.assert_is_fully_defined`).
  """
  if not type_analysis.is_generic_op_compatible_type(type_spec):
    raise TypeError(
        'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
        'only nested tuples and tensors are permitted.'.format(type_spec))

  inferred_value_type = type_conversions.infer_type(value)
  value_is_struct = inferred_value_type.is_struct()
  if value_is_struct and not type_spec.is_assignable_from(inferred_value_type):
    raise TypeError(
        'Must pass only a tensor or structure of tensor values to '
        '`create_constant`; encountered a value {v} with inferred '
        'type {t!r}, but needed {s!r}'.format(
            v=value, t=inferred_value_type, s=type_spec))
  if value_is_struct:
    value = structure.from_container(value, recursive=True)

  tensor_dtypes_in_type_spec = []

  def _pack_dtypes(type_signature):
    """Appends dtype of `type_signature` to nonlocal variable."""
    if type_signature.is_tensor():
      tensor_dtypes_in_type_spec.append(type_signature.dtype)
    # Return the type unchanged; the traversal is used only for collection.
    return type_signature, False

  type_transformations.transform_type_postorder(type_spec, _pack_dtypes)

  if (any(x.is_integer for x in tensor_dtypes_in_type_spec) and
      (inferred_value_type.is_tensor() and
       not inferred_value_type.dtype.is_integer)):
    raise TypeError(
        'Only integers can be used as scalar values if our desired constant '
        'type spec contains any integer tensors; passed scalar {} of dtype {} '
        'for type spec {}.'.format(value, inferred_value_type.dtype,
                                   type_spec))

  def _create_result_tensor(elem_spec, elem_value):
    """Packs `elem_value` into `elem_spec` recursively."""
    if elem_spec.is_tensor():
      elem_spec.shape.assert_is_fully_defined()
      return tf.constant(
          elem_value, dtype=elem_spec.dtype, shape=elem_spec.shape)
    elements = []
    if value_is_struct:
      # Copy the leaf values according to the type_spec structure.
      for (name, child_spec), child_value in zip(
          structure.iter_elements(elem_spec), elem_value):
        elements.append((name, _create_result_tensor(child_spec, child_value)))
    else:
      # "Broadcast" the value to each level of the type_spec structure. Names
      # are omitted here; the returned type signature is built from
      # `type_spec` itself.
      for _, child_spec in structure.iter_elements(elem_spec):
        elements.append((None, _create_result_tensor(child_spec, elem_value)))
    return structure.Struct(elements)

  with tf.Graph().as_default() as graph:
    result = _create_result_tensor(type_spec, value)
  _, result_binding = tensorflow_utils.capture_result_from_graph(
      result, graph)

  type_signature = computation_types.FunctionType(None, type_spec)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)