def assemble_result_from_graph(type_spec, binding, output_map):
  """Assembles a result stamped into a `tf.Graph` given type signature/binding.

  This method does roughly the opposite of `capture_result_from_graph`, in that
  whereas `capture_result_from_graph` starts with a single structured object
  made up of tensors and computes its type and bindings, this method starts
  with the type/bindings and constructs a structured object made up of tensors.

  Args:
    type_spec: The type signature of the result to assemble, an instance of
      `types.Type` or something convertible to it.
    binding: The binding that relates the type signature to names of tensors in
      the graph, an instance of `pb.TensorFlow.Binding`.
    output_map: The mapping from tensor names that appear in the binding to
      actual stamped tensors (possibly renamed during import).

  Returns:
    The assembled result, a Python object that is composed of tensors, possibly
    nested within Python structures such as anonymous tuples.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
    ValueError: If the arguments are invalid or inconsistent with each other,
      e.g., the type and binding don't match, or the tensor is not found in the
      map.
  """
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  py_typecheck.check_type(binding, pb.TensorFlow.Binding)
  py_typecheck.check_type(output_map, dict)
  for k, v in output_map.items():
    py_typecheck.check_type(k, str)
    if not tf.is_tensor(v):
      raise TypeError(
          'Element with key {} in the output map is {}, not a tensor.'.format(
              k, py_typecheck.type_string(type(v))))

  binding_oneof = binding.WhichOneof('binding')
  if type_spec.is_tensor():
    if binding_oneof != 'tensor':
      raise ValueError(
          'Expected a tensor binding, found {}.'.format(binding_oneof))
    elif binding.tensor.tensor_name not in output_map:
      raise ValueError('Tensor named {} not found in the output map.'.format(
          binding.tensor.tensor_name))
    else:
      return output_map[binding.tensor.tensor_name]
  elif type_spec.is_struct():
    if binding_oneof != 'struct':
      raise ValueError(
          'Expected a struct binding, found {}.'.format(binding_oneof))
    else:
      type_elements = structure.to_elements(type_spec)
      if len(binding.struct.element) != len(type_elements):
        raise ValueError(
            'Mismatching tuple sizes in type ({}) and binding ({}).'.format(
                len(type_elements), len(binding.struct.element)))
      result_elements = []
      for (element_name, element_type), element_binding in zip(
          type_elements, binding.struct.element):
        element_object = assemble_result_from_graph(element_type,
                                                    element_binding, output_map)
        result_elements.append((element_name, element_object))
      if type_spec.python_container is None:
        return structure.Struct(result_elements)
      container_type = type_spec.python_container
      if (py_typecheck.is_named_tuple(container_type) or
          py_typecheck.is_attrs(container_type)):
        # These constructors cannot interpret (name, value) pairs; use kwargs.
        return container_type(**dict(result_elements))
      if (isinstance(container_type, type) and
          issubclass(container_type, (list, tuple)) and
          all(name is None for name, _ in result_elements)):
        # Positional containers take bare values, not (None, value) pairs; this
        # mirrors how `capture_result_from_graph` captures lists and tuples.
        return container_type([obj for _, obj in result_elements])
      return container_type(result_elements)
  elif type_spec.is_sequence():
    if binding_oneof != 'sequence':
      raise ValueError(
          'Expected a sequence binding, found {}.'.format(binding_oneof))
    else:
      sequence_oneof = binding.sequence.WhichOneof('binding')
      if sequence_oneof == 'variant_tensor_name':
        variant_tensor_name = binding.sequence.variant_tensor_name
        if variant_tensor_name not in output_map:
          # Report a missing name as the documented `ValueError` (as the tensor
          # branch does) instead of leaking a bare `KeyError`.
          raise ValueError(
              'Tensor named {} not found in the output map.'.format(
                  variant_tensor_name))
        variant_tensor = output_map[variant_tensor_name]
        return make_dataset_from_variant_tensor(variant_tensor,
                                                type_spec.element)
      else:
        raise ValueError(
            'Unsupported sequence binding \'{}\'.'.format(sequence_oneof))
  else:
    raise ValueError('Unsupported type \'{}\'.'.format(type_spec))
def is_container_type_with_names(container_type):
  """Tells whether `container_type` identifies its elements by name.

  Named tuples, `attrs` classes and `dict` subclasses count as named
  containers.

  Args:
    container_type: The Python class to inspect.

  Returns:
    A truthy value if `container_type` is a named container, falsy otherwise.
  """
  for is_named_kind in (py_typecheck.is_named_tuple, py_typecheck.is_attrs):
    verdict = is_named_kind(container_type)
    if verdict:
      return verdict
  return issubclass(container_type, dict)
def capture_result_from_graph(result, graph):
  """Captures a result stamped into a tf.Graph as a type signature and binding.

  Args:
    result: The result to capture, a Python object that is composed of tensors,
      possibly nested within Python structures such as dictionaries, lists,
      tuples, or named tuples.
    graph: The instance of tf.Graph to use.

  Returns:
    A tuple (type_spec, binding), where 'type_spec' is an instance of
    computation_types.Type that describes the type of the result, and 'binding'
    is an instance of TensorFlow.Binding that indicates how parts of the result
    type relate to the tensors and ops that appear in the result.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
  """

  def _capture_named_members(pairs, struct_type_factory):
    """Captures every `(name, value)` pair and combines them into a struct."""
    member_types = []
    member_bindings = []
    for member_name, member_value in pairs:
      member_type, member_binding = capture_result_from_graph(
          member_value, graph)
      # A falsy name (e.g. `None`) yields an unnamed struct member.
      member_types.append(
          (member_name, member_type) if member_name else member_type)
      member_bindings.append(member_binding)
    struct_binding = pb.TensorFlow.Binding(
        struct=pb.TensorFlow.StructBinding(element=member_bindings))
    return struct_type_factory(member_types), struct_binding

  # TODO(b/113112885): The emerging extensions for serializing SavedModels may
  # end up introducing similar concepts of bindings, etc., we should look here
  # into the possibility of reusing some of that code when it's available.
  if isinstance(result, TENSOR_REPRESENTATION_TYPES):
    with graph.as_default():
      result = tf.constant(result)

  if tf.is_tensor(result):
    if hasattr(result, 'read_value'):
      # A tf.Variable-like object: read it into a plain tensor we can fetch.
      with graph.as_default():
        result = result.read_value()
    tensor_type = computation_types.TensorType(result.dtype.base_dtype,
                                               result.shape)
    tensor_binding = pb.TensorFlow.Binding(
        tensor=pb.TensorFlow.TensorBinding(tensor_name=result.name))
    return tensor_type, tensor_binding

  if py_typecheck.is_named_tuple(result):
    # collections.namedtuples share no common base class. Relying on them being
    # `tuple` subclasses would drop the member names, so go via `_asdict`.
    # pylint: disable=protected-access
    namedtuple_pairs = result._asdict().items()
    # pylint: enable=protected-access
    return _capture_named_members(
        namedtuple_pairs,
        functools.partial(
            computation_types.StructWithPythonType,
            container_type=type(result)))

  if py_typecheck.is_attrs(result):
    attrs_fields = attr.asdict(
        result, dict_factory=collections.OrderedDict, recurse=False)
    return _capture_named_members(
        attrs_fields.items(),
        functools.partial(
            computation_types.StructWithPythonType,
            container_type=type(result)))

  if isinstance(result, structure.Struct):
    return _capture_named_members(
        structure.to_elements(result), computation_types.StructType)

  if isinstance(result, collections.abc.Mapping):
    if isinstance(result, collections.OrderedDict):
      mapping_pairs = result.items()
    else:
      # Non-ordered mappings are captured in sorted item order.
      mapping_pairs = sorted(result.items())
    return _capture_named_members(
        mapping_pairs,
        functools.partial(
            computation_types.StructWithPythonType,
            container_type=type(result)))

  if isinstance(result, (list, tuple)):
    captured = [capture_result_from_graph(item, graph) for item in result]
    member_types = [member_type for member_type, _ in captured]
    member_bindings = [member_binding for _, member_binding in captured]
    return (computation_types.StructWithPythonType(member_types, type(result)),
            pb.TensorFlow.Binding(
                struct=pb.TensorFlow.StructBinding(element=member_bindings)))

  if isinstance(result, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
    variant_tensor = tf.data.experimental.to_variant(result)
    element_structure = result.element_spec
    try:
      element_type = computation_types.to_type(element_structure)
    except TypeError as e:
      raise TypeError(
          'TFF does not support Datasets that yield elements of structure {!s}'
          .format(element_structure)) from e
    return (computation_types.SequenceType(element_type),
            pb.TensorFlow.Binding(
                sequence=pb.TensorFlow.SequenceBinding(
                    variant_tensor_name=variant_tensor.name)))

  raise TypeError('Cannot capture a result of an unsupported type {}.'.format(
      py_typecheck.type_string(type(result))))
def capture_result_from_graph(result, graph):
  """Captures a result stamped into a tf.Graph as a type signature and binding.

  Args:
    result: The result to capture, a Python object that is composed of tensors,
      possibly nested within Python structures such as dictionaries, lists,
      tuples, or named tuples.
    graph: The instance of tf.Graph to use.

  Returns:
    A tuple (type_spec, binding), where 'type_spec' is an instance of
    computation_types.Type that describes the type of the result, and 'binding'
    is an instance of TensorFlow.Binding that indicates how parts of the result
    type relate to the tensors and ops that appear in the result.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
  """
  # TODO(b/113112885): The emerging extensions for serializing SavedModels may
  # end up introducing similar concepts of bindings, etc., we should look here
  # into the possibility of reusing some of that code when it's available.
  if isinstance(result, dtype_utils.TENSOR_REPRESENTATION_TYPES):
    with graph.as_default():
      result = tf.constant(result)
  if tf.contrib.framework.is_tensor(result):
    if hasattr(result, 'read_value'):
      # We have a tf.Variable-like result, get a proper tensor to fetch.
      with graph.as_default():
        result = result.read_value()
    return (computation_types.TensorType(result.dtype.base_dtype, result.shape),
            pb.TensorFlow.Binding(
                tensor=pb.TensorFlow.TensorBinding(tensor_name=result.name)))
  elif py_typecheck.is_named_tuple(result):
    # Special handling needed for collections.namedtuples since they do not have
    # anything in the way of a shared base class. Note we don't want to rely on
    # the fact that collections.namedtuples inherit from 'tuple' because we'd be
    # failing to retain the information about naming of tuple members.
    # From Python 3.8 on, `_asdict` returns a plain `dict`, which the mapping
    # branch below would sort by key; wrapping it in an `OrderedDict` keeps the
    # namedtuple's declared field order on every Python version.
    # pylint: disable=protected-access
    return capture_result_from_graph(
        collections.OrderedDict(result._asdict()), graph)
    # pylint: enable=protected-access
  elif isinstance(result, (dict, anonymous_tuple.AnonymousTuple)):
    if isinstance(result, anonymous_tuple.AnonymousTuple):
      name_value_pairs = anonymous_tuple.to_elements(result)
    elif isinstance(result, collections.OrderedDict):
      name_value_pairs = six.iteritems(result)
    else:
      # Plain dicts are captured in sorted order so the result is deterministic.
      name_value_pairs = sorted(six.iteritems(result))
    element_name_type_binding_triples = [
        ((k,) + capture_result_from_graph(v, graph))
        for k, v in name_value_pairs
    ]
    return (computation_types.NamedTupleType([
        ((e[0], e[1]) if e[0] else e[1])
        for e in element_name_type_binding_triples
    ]),
            pb.TensorFlow.Binding(
                tuple=pb.TensorFlow.NamedTupleBinding(
                    element=[e[2] for e in element_name_type_binding_triples])))
  elif isinstance(result, (list, tuple)):
    element_type_binding_pairs = [
        capture_result_from_graph(e, graph) for e in result
    ]
    return (computation_types.NamedTupleType(
        [e[0] for e in element_type_binding_pairs]),
            pb.TensorFlow.Binding(
                tuple=pb.TensorFlow.NamedTupleBinding(
                    element=[e[1] for e in element_type_binding_pairs])))
  elif isinstance(result, DATASET_REPRESENTATION_TYPES):
    element_type = type_utils.tf_dtypes_and_shapes_to_type(
        result.output_types, result.output_shapes)
    handle_name = result.make_one_shot_iterator().string_handle().name
    return (computation_types.SequenceType(element_type),
            pb.TensorFlow.Binding(
                sequence=pb.TensorFlow.SequenceBinding(
                    iterator_string_handle_name=handle_name)))
  else:
    raise TypeError('Cannot capture a result of an unsupported type {}.'.format(
        py_typecheck.type_string(type(result))))
def is_container_type_without_names(container_type):
  """Tells whether `container_type` holds its elements purely by position.

  Args:
    container_type: The Python class to inspect.

  Returns:
    `True` for `list`/`tuple` subclasses other than named tuples, else `False`.
  """
  if not issubclass(container_type, (list, tuple)):
    return False
  # Named tuples subclass `tuple` but also carry field names.
  return not py_typecheck.is_named_tuple(container_type)
def append_to_list_structure_for_element_type_spec(structure, value, type_spec):
  """Appends the tensor-level pieces of `value` to the lists in `structure`.

  Companion of `make_empty_list_structure_for_element_type_spec`: each
  tensor-typed leaf of `value` is appended to the list at the corresponding
  position of `structure`. The nesting of `value` must match the one created by
  that function and be consistent with `type_spec`.

  Args:
    structure: Output of `make_empty_list_structure_for_element_type_spec`.
    value: A value (Python object) that is a hierarchical structure of
      dictionaries, lists, and other containers holding tensor-like items, and
      that matches the hierarchy of `type_spec`. If `None`, nothing is
      appended.
    type_spec: An instance of `tff.Type` or something convertible to it, as in
      `make_empty_list_structure_for_element_type_spec`.

  Raises:
    TypeError: If the `type_spec` is not of a form described above, or the value
      is not of a type compatible with `type_spec`.
  """
  if value is None:
    return
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)

  # TODO(b/113116813): This could be made more efficient, but for now we won't
  # need to worry about it as this is an odd corner case.
  if isinstance(value, anonymous_tuple.AnonymousTuple):
    # Normalize to an `OrderedDict` (all named) or a `tuple` (all unnamed) so
    # the code below only deals with plain Python containers.
    value_items = anonymous_tuple.to_elements(value)
    if all(name is not None for name, _ in value_items):
      value = collections.OrderedDict(value_items)
    elif all(name is None for name, _ in value_items):
      value = tuple(item for _, item in value_items)
    else:
      raise TypeError(
          'Expected an anonymous tuple to either have all elements named '
          'or all unnamed, got {}.'.format(str(value)))

  if isinstance(type_spec, computation_types.TensorType):
    py_typecheck.check_type(structure, list)
    structure.append(value)
    return

  if not isinstance(type_spec, computation_types.NamedTupleType):
    raise TypeError(
        'Expected a tensor or named tuple type, found {}.'.format(
            str(type_spec)))

  type_items = anonymous_tuple.to_elements(type_spec)

  if isinstance(structure, collections.OrderedDict):
    # Named list structure: `value` may be a named tuple, a dict keyed by the
    # element names, or a positional list/tuple given in type order.
    if py_typecheck.is_named_tuple(value):
      value = value._asdict()  # pylint: disable=protected-access
    if isinstance(value, dict):
      if set(value.keys()) != set(name for name, _ in type_items):
        raise TypeError('Value {} does not match type {}.'.format(
            str(value), str(type_spec)))
      for item_name, item_type in type_items:
        append_to_list_structure_for_element_type_spec(
            structure[item_name], value[item_name], item_type)
    elif isinstance(value, (list, tuple)):
      if len(value) != len(type_items):
        raise TypeError('Value {} does not match type {}.'.format(
            str(value), str(type_spec)))
      for position, (item_name, item_type) in enumerate(type_items):
        append_to_list_structure_for_element_type_spec(
            structure[item_name], value[position], item_type)
    else:
      raise TypeError(
          'Unexpected type of value {} for TFF type {}.'.format(
              py_typecheck.type_string(type(value)), str(type_spec)))
    return

  if isinstance(structure, tuple):
    # Positional list structure: `value` must be a list/tuple of equal length.
    py_typecheck.check_type(value, (list, tuple))
    if len(value) != len(type_items):
      raise TypeError('Value {} does not match type {}.'.format(
          str(value), str(type_spec)))
    for position, (_, item_type) in enumerate(type_items):
      append_to_list_structure_for_element_type_spec(
          structure[position], value[position], item_type)
    return

  raise TypeError(
      'Invalid nested structure, unexpected container type {}.'.format(
          py_typecheck.type_string(type(structure))))
def _convert(value, recursive, must_be_container=False):
  """Converts `value` into a `Struct`; the worker behind `from_container`.

  Args:
    value: Same as in `from_container`.
    recursive: Same as in `from_container`.
    must_be_container: When set to `True`, causes an exception to be raised if
      `value` is not a container.

  Returns:
    The result of conversion.

  Raises:
    TypeError: If `value` is not a container and `must_be_container` has been
      set to `True`.
  """

  def _pairs_to_struct(pairs):
    # Builds a `Struct` from (name, member) pairs, converting every member too
    # when `recursive` is set.
    if recursive:
      return Struct((name, _convert(member, True)) for name, member in pairs)
    return Struct(pairs)

  if isinstance(value, Struct):
    if not recursive:
      return value
    return _pairs_to_struct(iter_elements(value))

  if py_typecheck.is_attrs(value):
    attrs_fields = attr.asdict(
        value, dict_factory=collections.OrderedDict, recurse=False)
    return _convert(attrs_fields, recursive, must_be_container)

  if py_typecheck.is_named_tuple(value):
    # From Python 3.8 on, `_asdict` returns a plain `dict` instead of an
    # `OrderedDict`; wrapping it keeps the declared field order.
    namedtuple_fields = collections.OrderedDict(value._asdict())
    return _convert(namedtuple_fields, recursive, must_be_container)

  if isinstance(value, collections.OrderedDict):
    return _pairs_to_struct(value.items())

  if isinstance(value, dict):
    # Plain dicts are converted in sorted item order.
    return _pairs_to_struct(sorted(list(value.items())))

  if isinstance(value, (tuple, list)):
    return _pairs_to_struct((None, member) for member in value)

  if isinstance(value, tf.RaggedTensor):
    row_splits = value.nested_row_splits
    if recursive:
      row_splits = _convert(row_splits, True)
    return Struct([('flat_values', value.flat_values),
                   ('nested_row_splits', row_splits)])

  if isinstance(value, tf.SparseTensor):
    # Each component is a tensor.
    return Struct([('indices', value.indices), ('values', value.values),
                   ('dense_shape', value.dense_shape)])

  if must_be_container:
    raise TypeError('Unable to convert a Python object of type {} into '
                    'an `Struct`. Object: {}'.format(
                        py_typecheck.type_string(type(value)), value))
  return value
def type_to_py_container(value, type_spec):
  """Rebuilds Python containers from `anonymous_tuple.AnonymousTuple` values.

  This is in some sense the inverse operation to
  `anonymous_tuple.from_container`: every `AnonymousTuple` in `value` is
  converted, recursively, into the Python container recorded on the matching
  part of `type_spec`.

  Args:
    value: A structure of anonymous tuples of values corresponding to
      `type_spec`.
    type_spec: The `tff.Type` to which value should conform, possibly including
      `computation_types.NamedTupleTypeWithPyContainerType`.

  Returns:
    The input value, with containers converted to appropriate Python containers
    as specified by the `type_spec`.

  Raises:
    TypeError: If `type_spec` is a sequence type and `value` is neither a
      `list` nor a `tf.data.Dataset`.
    ValueError: If the conversion is not possible due to a mix of named and
      unnamed values.
  """
  struct_spec = type_spec.member if type_spec.is_federated() else type_spec

  if struct_spec.is_sequence():
    member_type = struct_spec.element
    if isinstance(value, list):
      return [type_to_py_container(item, member_type) for item in value]
    if isinstance(value, tf.data.Dataset):
      # `tf.data.Dataset` does not understand `AnonymousTuple`, so a dataset
      # built by TFF already yields the proper Python containers; pass it on.
      return value
    raise TypeError('Unexpected Python type for TF type {}: {}'.format(
        struct_spec, type(value)))

  if not struct_spec.is_struct():
    return value

  if not isinstance(value, anonymous_tuple.AnonymousTuple):
    # NOTE: A non-`AnonymousTuple` value is assumed to already use the proper
    # containers, so it is returned as-is instead of being re-converted. This
    # is a possibly dangerous assumption.
    return value

  def _holds_only_positional(cls):
    return (issubclass(cls, (list, tuple)) and
            not py_typecheck.is_named_tuple(cls))

  def _requires_names(cls):
    return (py_typecheck.is_named_tuple(cls) or py_typecheck.is_attrs(cls) or
            issubclass(cls, dict))

  # TODO(b/133228705): Consider requiring NamedTupleTypeWithPyContainerType.
  container_type = (
      struct_spec.python_container or anonymous_tuple.AnonymousTuple)
  keep_anon_tuple = struct_spec.python_container is None

  if not keep_anon_tuple:
    # Refuse to project the value onto a Python container that cannot express
    # how its elements are named.
    named_count = len(dir(value))
    unnamed_count = len(value) - named_count
    if named_count > 0 and unnamed_count > 0:
      raise ValueError(
          'Cannot represent value {} with container type {}, '
          'because value contains a mix of named and unnamed '
          'elements.'.format(value, container_type))
    if named_count > 0 and _holds_only_positional(container_type):
      # Note: This could be relaxed in some cases if needed.
      raise ValueError(
          'Cannot represent value {} with named elements '
          'using container type {} which does not support names.'.format(
              value, container_type))
    if unnamed_count > 0 and _requires_names(container_type):
      # Note: This could be relaxed in some cases if needed.
      raise ValueError(
          'Cannot represent value {} with unnamed elements '
          'with container type {} which requires names.'.format(
              value, container_type))

  converted = []
  for position, (member_name, member_type) in enumerate(
      anonymous_tuple.iter_elements(struct_spec)):
    member_value = type_to_py_container(value[position], member_type)
    if member_name is None and not keep_anon_tuple:
      converted.append(member_value)
    else:
      converted.append((member_name, member_value))

  if (py_typecheck.is_named_tuple(container_type) or
      py_typecheck.is_attrs(container_type)):
    # namedtuple and attr.s constructors do not accept (name, value) pairs, so
    # pass keyword arguments; these classes define their own field order.
    return container_type(**dict(converted))
  # E.g. `tuple`/`list` built from bare values, or `dict`,
  # `collections.OrderedDict` or `anonymous_tuple.AnonymousTuple` built from
  # (name, value) pairs.
  return container_type(converted)
def type_to_py_container(value, type_spec):
  """Rebuilds Python containers from `AnonymousTuple` values.

  This is in some sense the inverse operation to
  `anonymous_tuple.from_container`.

  Args:
    value: An `AnonymousTuple`, in which case this method recurses, replacing
      all `AnonymousTuple`s with the appropriate Python containers if possible
      (and keeping AnonymousTuple otherwise); or some other value, in which
      case that value is returned unmodified immediately (terminating the
      recursion).
    type_spec: The `tff.Type` to which value should conform, possibly including
      `NamedTupleTypeWithPyContainerType`.

  Returns:
    The input value, with containers converted to appropriate Python containers
    as specified by the `type_spec`.

  Raises:
    ValueError: If the conversion is not possible due to a mix of named and
      unnamed values.
  """
  if not isinstance(value, anonymous_tuple.AnonymousTuple):
    return value

  if isinstance(type_spec, computation_types.FederatedType):
    struct_spec = type_spec.member
  else:
    struct_spec = type_spec
  py_typecheck.check_type(struct_spec, computation_types.NamedTupleType)

  def _holds_only_positional(cls):
    return (issubclass(cls, (list, tuple)) and
            not py_typecheck.is_named_tuple(cls))

  def _requires_names(cls):
    return (py_typecheck.is_named_tuple(cls) or py_typecheck.is_attrs(cls) or
            issubclass(cls, dict))

  keep_anon_tuple = not isinstance(
      struct_spec, computation_types.NamedTupleTypeWithPyContainerType)
  if keep_anon_tuple:
    # TODO(b/133228705): Consider requiring NamedTupleTypeWithPyContainerType.
    container_type = anonymous_tuple.AnonymousTuple
  else:
    container_type = (
        computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
            struct_spec))

  if not keep_anon_tuple:
    # Refuse to project the AnonymousTuple onto a Python container that cannot
    # express how its elements are named.
    named_count = len(dir(value))
    unnamed_count = len(value) - named_count
    if named_count > 0 and unnamed_count > 0:
      raise ValueError('Cannot represent value {} with container type {}, '
                       'because value contains a mix of named and unnamed '
                       'elements.'.format(value, container_type))
    if named_count > 0 and _holds_only_positional(container_type):
      # Note: This could be relaxed in some cases if needed.
      raise ValueError(
          'Cannot represent value {} with named elements '
          'using container type {} which does not support names.'.format(
              value, container_type))
    if unnamed_count > 0 and _requires_names(container_type):
      # Note: This could be relaxed in some cases if needed.
      raise ValueError('Cannot represent value {} with unnamed elements '
                       'with container type {} which requires names.'.format(
                           value, container_type))

  converted = []
  for position, (member_name, member_type) in enumerate(
      anonymous_tuple.iter_elements(struct_spec)):
    member_value = type_to_py_container(value[position], member_type)
    if member_name is None and not keep_anon_tuple:
      converted.append(member_value)
    else:
      converted.append((member_name, member_value))

  if (py_typecheck.is_named_tuple(container_type) or
      py_typecheck.is_attrs(container_type)):
    # namedtuple and attr.s constructors do not accept (name, value) pairs, so
    # pass keyword arguments; these classes define their own field order.
    return container_type(**dict(converted))
  # E.g. tuple and list built from bare values, or dict, OrderedDict or
  # AnonymousTuple built from (name, value) pairs.
  return container_type(converted)
def type_to_py_container(value, type_spec):
  """Recursively convert `structure.Struct`s to Python containers.

  This is in some sense the inverse operation to `structure.from_container`.

  Args:
    value: A structure of anonymous tuples of values corresponding to
      `type_spec`.
    type_spec: The `tff.Type` to which value should conform, possibly including
      `computation_types.StructWithPythonType`.

  Returns:
    The input value, with containers converted to appropriate Python containers
    as specified by the `type_spec`.

  Raises:
    TypeError: If `type_spec` is a non-all-equal federated type and `value` is
      not a `list`, or if `type_spec` is a sequence type and `value` is neither
      a `list` nor a `tf.data.Dataset`.
    ValueError: If the conversion is not possible due to a mix of named and
      unnamed values, or if `value` contains names that are mismatched or not
      present in the corresponding index of `type_spec`.
  """
  if type_spec.is_federated():
    if type_spec.all_equal:
      structure_type_spec = type_spec.member
    else:
      # A non-all-equal federated value is expected to be a Python `list` of
      # member values; each member is converted independently.
      if not isinstance(value, list):
        raise TypeError('Unexpected Python type for non-all-equal TFF type '
                        f'{type_spec}: expected `list`, found `{type(value)}`.')
      return [
          type_to_py_container(element, type_spec.member) for element in value
      ]
  else:
    structure_type_spec = type_spec

  if structure_type_spec.is_sequence():
    element_type = structure_type_spec.element
    if isinstance(value, list):
      return [type_to_py_container(element, element_type) for element in value]
    if isinstance(value, tf.data.Dataset):
      # `tf.data.Dataset` does not understand `Struct`, so the dataset
      # in `value` must already be yielding Python containers. This is because
      # when TFF is constructing datasets it always uses the proper Python
      # container, so we simply return `value` here without modification.
      return value
    raise TypeError('Unexpected Python type for TFF type {}: {}'.format(
        structure_type_spec, type(value)))

  if not structure_type_spec.is_struct():
    return value

  if not isinstance(value, structure.Struct):
    # NOTE: When encountering non-`structure.Struct`s, we assume that
    # this means that we're attempting to re-convert a value that
    # already has the proper containers, and we short-circuit to
    # avoid re-converting. This is a possibly dangerous assumption.
    return value

  container_type = structure_type_spec.python_container

  # Ensure that names are only added, not mismatched or removed.
  names_from_value = structure.name_list_with_nones(value)
  names_from_type_spec = structure.name_list_with_nones(structure_type_spec)
  for value_name, type_name in zip(names_from_value, names_from_type_spec):
    if value_name is not None and value_name != type_name:
      raise ValueError(
          f'Cannot convert value with field name `{value_name}` into a '
          f'type with field name `{type_name}`.')

  num_named_elements = len(dir(structure_type_spec))
  num_unnamed_elements = len(structure_type_spec) - num_named_elements
  if num_named_elements > 0 and num_unnamed_elements > 0:
    raise ValueError(
        f'Cannot represent value {value} with a Python container because it '
        'contains a mix of named and unnamed elements.\n\nNote: this was '
        'previously allowed when using the `tff.structure.Struct` container. '
        'This support has been removed: please change to use structures with '
        'either all-named or all-unnamed fields.')
  if container_type is None:
    # No container recorded on the type: pick one that matches the naming.
    if num_named_elements:
      container_type = collections.OrderedDict
    else:
      container_type = tuple

  # Avoid projecting the `structure.StructType`d TFF value into a Python
  # container that is not supported.
  if (num_named_elements > 0 and
      is_container_type_without_names(container_type)):
    raise ValueError(
        'Cannot represent value {} with named elements '
        'using container type {} which does not support names. In TFF\'s '
        'typesystem, this corresponds to an implicit downcast.'.format(
            value, container_type))
  if (is_container_type_with_names(container_type) and
      len(dir(structure_type_spec)) != len(value)):
    # If the type specifies the names, we have all the information we need.
    # Otherwise we must raise here.
    raise ValueError(
        'When packaging as a Python value which requires names, '
        'the TFF type spec must have all names specified. Found '
        '{} names in type spec {} of length {}, with requested '
        'python type {}.'.format(
            len(dir(structure_type_spec)), structure_type_spec, len(value),
            container_type))

  elements = []
  for index, (elem_name, elem_type) in enumerate(
      structure.iter_elements(structure_type_spec)):
    element = type_to_py_container(value[index], elem_type)
    if elem_name is None:
      elements.append(element)
    else:
      elements.append((elem_name, element))

  if (py_typecheck.is_named_tuple(container_type) or
      py_typecheck.is_attrs(container_type) or
      py_typecheck.is_dataclass(container_type) or
      container_type is tf.SparseTensor):
    # The namedtuple and attr.s class constructors cannot interpret a list of
    # (name, value) tuples; instead call constructor using kwargs. Note that
    # these classes already define an order of names internally, so order does
    # not matter.
    return container_type(**dict(elements))
  elif container_type is tf.RaggedTensor:
    elements = dict(elements)
    return tf.RaggedTensor.from_nested_row_splits(
        elements['flat_values'], elements['nested_row_splits'])
  else:
    # E.g., tuple and list when elements only has values, but also `dict`,
    # `collections.OrderedDict`, or `structure.Struct` when
    # elements has (name, value) tuples.
    return container_type(elements)
def infer_type(arg: Any) -> Optional[computation_types.Type]:
  """Infers the TFF type of the argument (a `computation_types.Type` instance).

  Warning: This function is only partially implemented.

  The kinds of arguments that are currently correctly recognized:
  * tensors, variables, and data sets
  * things that are convertible to tensors (including `numpy` arrays, builtin
    types, as well as `list`s and `tuple`s of any of the above, etc.)
  * nested lists, `tuple`s, `namedtuple`s, anonymous `tuple`s, `dict`,
    `OrderedDict`s, `dataclasses`, `attrs` classes, and `tff.TypedObject`s

  Args:
    arg: The argument, the TFF type of which to infer.

  Returns:
    Either an instance of `computation_types.Type`, or `None` if the argument is
    `None`.

  Raises:
    TypeError: If `arg` is an `int` outside the `tf.int64` range, or if its
      type cannot otherwise be inferred.
  """
  if arg is None:
    return None
  if isinstance(arg, typed_object.TypedObject):
    return arg.type_signature

  def _named_struct(name_value_items):
    # Infers each member's type and records `type(arg)` as the container.
    return computation_types.StructWithPythonType(
        [(name, infer_type(member)) for name, member in name_value_items],
        type(arg))

  if tf.is_tensor(arg):
    # `tf.is_tensor` returns true for some things that are not actually single
    # `tf.Tensor`s, including `tf.SparseTensor`s and `tf.RaggedTensor`s.
    if isinstance(arg, tf.RaggedTensor):
      ragged_members = (
          ('flat_values', infer_type(arg.flat_values)),
          ('nested_row_splits', infer_type(arg.nested_row_splits)),
      )
      return computation_types.StructWithPythonType(ragged_members,
                                                    tf.RaggedTensor)
    if isinstance(arg, tf.SparseTensor):
      sparse_members = (
          ('indices', infer_type(arg.indices)),
          ('values', infer_type(arg.values)),
          ('dense_shape', infer_type(arg.dense_shape)),
      )
      return computation_types.StructWithPythonType(sparse_members,
                                                    tf.SparseTensor)
    return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)

  if isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
    return computation_types.SequenceType(
        computation_types.to_type(arg.element_spec))

  if isinstance(arg, structure.Struct):
    return computation_types.StructType([
        (name, infer_type(member)) if name else infer_type(member)
        for name, member in structure.iter_elements(arg)
    ])

  if py_typecheck.is_attrs(arg):
    return _named_struct(named_containers.attrs_class_to_odict(arg).items())

  if py_typecheck.is_dataclass(arg):
    return _named_struct(named_containers.dataclass_to_odict(arg).items())

  if py_typecheck.is_named_tuple(arg):
    # In Python 3.8 and later `_asdict` no longer return OrderedDict, rather a
    # regular `dict`.
    return _named_struct(collections.OrderedDict(arg._asdict()).items())

  if isinstance(arg, dict):
    if isinstance(arg, collections.OrderedDict):
      return _named_struct(arg.items())
    return _named_struct(sorted(arg.items()))

  if isinstance(arg, (tuple, list)):
    member_types = []
    every_member_is_pair = True
    for member in arg:
      every_member_is_pair &= py_typecheck.is_name_value_pair(member)
      member_types.append(infer_type(member))
    # If this is a tuple of (name, value) pairs, the caller most likely intended
    # this to be a StructType, so we avoid storing the Python container.
    if member_types and every_member_is_pair:
      return computation_types.StructType(member_types)
    return computation_types.StructWithPythonType(member_types, type(arg))

  if isinstance(arg, str):
    return computation_types.TensorType(tf.string)

  if isinstance(arg, (np.generic, np.ndarray)):
    return computation_types.TensorType(
        tf.dtypes.as_dtype(arg.dtype), arg.shape)

  arg_type = type(arg)
  if arg_type is bool:
    return computation_types.TensorType(tf.bool)
  if arg_type is int:
    # Choose the integral type based on value.
    if not tf.int64.min <= arg <= tf.int64.max:
      raise TypeError(
          'No integral type support for values outside range '
          f'[{tf.int64.min}, {tf.int64.max}]. Got: {arg}')
    if tf.int32.min <= arg <= tf.int32.max:
      return computation_types.TensorType(tf.int32)
    return computation_types.TensorType(tf.int64)
  if arg_type is float:
    return computation_types.TensorType(tf.float32)

  # Now fall back onto the heavier-weight processing, as all else failed.
  # Use make_tensor_proto() to make sure to handle it consistently with
  # how TensorFlow is handling values (e.g., recognizing int as int32, as
  # opposed to int64 as in NumPy).
  try:
    # TODO(b/113112885): Find something more lightweight we could use here.
    tensor_proto = tf.make_tensor_proto(arg)
    return computation_types.TensorType(
        tf.dtypes.as_dtype(tensor_proto.dtype),
        tf.TensorShape(tensor_proto.tensor_shape))
  except TypeError as e:
    raise TypeError('Could not infer the TFF type of {}.'.format(
        py_typecheck.type_string(type(arg)))) from e
def is_container_type_with_names(container_type: Type[Any]) -> bool:
  """Checks whether instances of `container_type` carry named elements.

  Named tuple classes, `attrs` classes and subclasses of `dict` all associate
  a name with each element they hold.

  Args:
    container_type: The Python container class to inspect.

  Returns:
    `True` if `container_type` is a named tuple class, an `attrs` class, or a
    subclass of `dict`; `False` otherwise.
  """
  if py_typecheck.is_named_tuple(container_type):
    return True
  if py_typecheck.is_attrs(container_type):
    return True
  return issubclass(container_type, dict)
def is_container_type_without_names(container_type: Type[Any]) -> bool:
  """Checks whether instances of `container_type` hold only positional elements.

  Args:
    container_type: The Python container class to inspect.

  Returns:
    `True` for subclasses of `list` or `tuple` that are not named tuple
    classes; `False` otherwise.
  """
  if not issubclass(container_type, (list, tuple)):
    # Not a positional sequence type at all.
    return False
  # Named tuples subclass `tuple` but carry field names.
  return not py_typecheck.is_named_tuple(container_type)
def type_to_tf_structure(type_spec: computation_types.Type):
  """Returns nested `tf.data.experimental.Structure` for a given TFF type.

  Args:
    type_spec: A `computation_types.Type`, the type specification must be
      composed of only named tuples and tensors. Within every named tuple that
      appears in the type spec, the elements must be either all named or all
      unnamed.

  Returns:
    An instance of `tf.data.experimental.Structure`, possibly nested, that
    corresponds to `type_spec`. An empty struct maps to `()`.

  Raises:
    ValueError: if the `type_spec` is composed of something other than named
      tuples and tensors, or if the elements of a named tuple are
      inconsistently named (some named, some not).
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return tf.TensorSpec(type_spec.shape, type_spec.dtype)
  elif type_spec.is_struct():
    elements = structure.to_elements(type_spec)
    if not elements:
      return ()
    element_outputs = [(k, type_to_tf_structure(v)) for k, v in elements]
    named = element_outputs[0][0] is not None
    if not all((k is not None) == named for k, _ in element_outputs):
      raise ValueError('Tuple elements inconsistently named.')
    if type_spec.python_container is None:
      if named:
        return collections.OrderedDict(element_outputs)
      else:
        return tuple(v for _, v in element_outputs)
    else:
      container_type = type_spec.python_container
      if (py_typecheck.is_named_tuple(container_type) or
          py_typecheck.is_attrs(container_type)):
        return container_type(**dict(element_outputs))
      elif container_type is tf.RaggedTensor:
        flat_values = type_spec.flat_values
        nested_row_splits = type_spec.nested_row_splits
        ragged_rank = len(nested_row_splits)
        return tf.RaggedTensorSpec(
            shape=tf.TensorShape([None] * (ragged_rank + 1)),
            dtype=flat_values.dtype,
            ragged_rank=ragged_rank,
            row_splits_dtype=nested_row_splits[0].dtype,
            flat_values_spec=None)
      elif container_type is tf.SparseTensor:

        def _static_dim(shape, index):
          """Returns dimension `index` of `shape` as an int, or `None`."""
          # Unknown rank, or a rank too small to have dimension `index`, means
          # nothing can be inferred from this shape.
          if shape is None or shape.rank is None or shape.rank <= index:
            return None
          # Normalizes `Dimension` objects and plain ints to int-or-None.
          return tf.compat.dimension_value(shape[index])

        # We can't generally infer the shape from the type of the tensors, but
        # we *can* infer the rank from the second dimension of `indices` or
        # from the only dimension of `dense_shape`.
        rank = _static_dim(type_spec.indices.shape, 1)
        if rank is None:
          rank = _static_dim(type_spec.dense_shape.shape, 0)
        shape = None if rank is None else tf.TensorShape([None] * rank)
        return tf.SparseTensorSpec(shape=shape, dtype=type_spec.values.dtype)
      elif named:
        return container_type(element_outputs)
      else:
        # `named` is False here, so every element name is `None`; pass only
        # the values to the positional container.
        return container_type(v for _, v in element_outputs)
  else:
    raise ValueError('Unsupported type {}.'.format(
        py_typecheck.type_string(type(type_spec))))
def append_to_list_structure_for_element_type_spec(nested, value, type_spec):
  """Adds an element `value` to `nested` lists for `type_spec`.

  This function appends tensor-level constituents of an element `value` to the
  lists created by `make_empty_list_structure_for_element_type_spec`. The
  nested structure of `value` must match that created by the above function,
  and be consistent with `type_spec`.

  Args:
    nested: Output of `make_empty_list_structure_for_element_type_spec`. It is
      mutated in place: each leaf list has one tensor appended to it.
    value: A value (Python object) that is a hierarchical structure of
      dictionaries, lists, and other containers holding tensor-like items that
      matches the hierarchy of `type_spec`. If `None`, nothing is appended.
    type_spec: An instance of `tff.Type` or something convertible to it, as in
      `make_empty_list_structure_for_element_type_spec`.

  Raises:
    TypeError: If the `type_spec` is not of a form described above, or the
      value is not of a type compatible with `type_spec`.
  """
  if value is None:
    return
  type_spec = computation_types.to_type(type_spec)
  # TODO(b/113116813): This could be made more efficient, but for now we won't
  # need to worry about it as this is an odd corner case.
  if isinstance(value, structure.Struct):
    # Normalize a `Struct` into an `OrderedDict` (all elements named) or a
    # `tuple` (all elements unnamed) so the branches below can consume it.
    elements = structure.to_elements(value)
    if all(k is not None for k, _ in elements):
      value = collections.OrderedDict(elements)
    elif all(k is None for k, _ in elements):
      value = tuple([v for _, v in elements])
    else:
      raise TypeError(
          'Expected an anonymous tuple to either have all elements named or '
          'all unnamed, got {}.'.format(value))
  if type_spec.is_tensor():
    py_typecheck.check_type(nested, list)
    # Convert the members to tensors to ensure that they are properly
    # typed and grouped before being passed to
    # tf.data.Dataset.from_tensor_slices.
    nested.append(tf.convert_to_tensor(value, type_spec.dtype))
  elif type_spec.is_struct():
    elements = structure.to_elements(type_spec)
    if isinstance(nested, collections.OrderedDict):
      # Named struct: `nested` is keyed by element name.
      if py_typecheck.is_named_tuple(value):
        # In Python 3.8 and later `_asdict` no longer returns an OrderedDict,
        # rather a regular `dict`.
        value = collections.OrderedDict(value._asdict())
      if isinstance(value, dict):
        # The value's keys must be exactly the type's element names.
        if set(value.keys()) != set(k for k, _ in elements):
          raise TypeError('Value {} does not match type {}.'.format(
              value, type_spec))
        for elem_name, elem_type in elements:
          append_to_list_structure_for_element_type_spec(
              nested[elem_name], value[elem_name], elem_type)
      elif isinstance(value, (list, tuple)):
        # A positional value is matched to the named elements by index.
        if len(value) != len(elements):
          raise TypeError('Value {} does not match type {}.'.format(
              value, type_spec))
        for idx, (elem_name, elem_type) in enumerate(elements):
          append_to_list_structure_for_element_type_spec(
              nested[elem_name], value[idx], elem_type)
      else:
        raise TypeError('Unexpected type of value {} for TFF type {}.'.format(
            py_typecheck.type_string(type(value)), type_spec))
    elif isinstance(nested, tuple):
      # Unnamed struct: `nested` is indexed by position.
      py_typecheck.check_type(value, (list, tuple))
      if len(value) != len(elements):
        raise TypeError('Value {} does not match type {}.'.format(
            value, type_spec))
      for idx, (_, elem_type) in enumerate(elements):
        append_to_list_structure_for_element_type_spec(nested[idx], value[idx],
                                                       elem_type)
    else:
      raise TypeError(
          'Invalid nested structure, unexpected container type {}.'.format(
              py_typecheck.type_string(type(nested))))
  else:
    raise TypeError(
        'Expected a tensor or named tuple type, found {}.'.format(type_spec))
def to_value(
    arg: Any,
    type_spec,
    parameter_type_hint=None,
) -> Value:
  """Converts the argument into an instance of the abstract class `tff.Value`.

  Instances of `tff.Value` represent TFF values that appear internally in
  federated computations. This helper function can be used to wrap a variety of
  Python objects as `tff.Value` instances to allow them to be passed as
  arguments, used as functions, or otherwise manipulated within bodies of
  federated computations.

  At the moment, the supported types include:

  * Simple constants of `str`, `int`, `float`, and `bool` types, mapped to
    values of a TFF tensor type.

  * Numpy arrays (`np.ndarray` objects), also mapped to TFF tensors.

  * Dictionaries (`collections.OrderedDict` and unordered `dict`), `list`s,
    `tuple`s, `namedtuple`s, and `Struct`s, all of which are mapped to
    TFF tuple type.

  * Computations (constructed with either the `tff.tf_computation` or with the
    `tff.federated_computation` decorator), typically mapped to TFF functions.

  * Placement literals (`tff.CLIENTS`, `tff.SERVER`), mapped to values of the
    TFF placement type.

  This function is also invoked when attempting to execute a TFF computation.
  All arguments supplied in the invocation are converted into TFF values prior
  to execution. The types of Python objects that can be passed as arguments to
  computations thus matches the types listed here.

  Args:
    arg: An instance of one of the Python types that are convertible to TFF
      values (instances of `tff.Value`).
    type_spec: An optional type specifier that allows for disambiguating the
      target type (e.g., when two TFF types can be mapped to the same Python
      representations). If not specified (`None`), TFF tries to determine the
      type of the TFF value automatically.
    parameter_type_hint: An optional `tff.Type` or value convertible to it by
      `tff.to_type()` which specifies an argument type to use in the case that
      `arg` is a `function_utils.PolymorphicFunction`.

  Returns:
    An instance of `tff.Value` as described above.

  Raises:
    TypeError: if `arg` is of an unsupported type, or of a type that does not
      match `type_spec`. Raises explicit error message if TensorFlow constructs
      are encountered, as TensorFlow code should be sealed away from TFF
      federated context.
  """
  if type_spec is not None:
    type_spec = computation_types.to_type(type_spec)
  # NOTE: the order of the branches below matters, e.g. named tuples must be
  # recognized before the generic `(tuple, list)` branch.
  if isinstance(arg, Value):
    result = arg
  elif isinstance(arg, building_blocks.ComputationBuildingBlock):
    result = Value(arg)
  elif isinstance(arg, placements.PlacementLiteral):
    result = Value(building_blocks.Placement(arg))
  elif isinstance(
      arg, (computation_base.Computation, function_utils.PolymorphicFunction)):
    if isinstance(arg, function_utils.PolymorphicFunction):
      # A polymorphic function must first be specialized to a concrete
      # computation, which requires knowing its parameter type.
      if parameter_type_hint is None:
        raise TypeError(
            'Polymorphic computations cannot be converted to `tff.Value`s '
            'without a type hint. Consider explicitly specifying the '
            'argument types of a computation before passing it to a '
            'function that requires a `tff.Value` (such as a TFF intrinsic '
            'like `federated_map`). If you are a TFF developer and think '
            'this should be supported, consider providing `parameter_type_hint` '
            'as an argument to the encompassing `to_value` conversion.')
      parameter_type_hint = computation_types.to_type(parameter_type_hint)
      arg = arg.fn_for_argument_type(parameter_type_hint)
    py_typecheck.check_type(arg, computation_base.Computation)
    result = Value(arg.to_compiled_building_block())
  elif type_spec is not None and type_spec.is_sequence():
    # Checked before the container branches, so that a Python container is
    # treated as sequence data whenever a sequence type is requested.
    result = _wrap_sequence_as_value(arg, type_spec.element)
  elif isinstance(arg, structure.Struct):
    items = structure.iter_elements(arg)
    result = _dictlike_items_to_value(items, type_spec, None)
  elif py_typecheck.is_named_tuple(arg):
    items = arg._asdict().items()
    result = _dictlike_items_to_value(items, type_spec, type(arg))
  elif py_typecheck.is_attrs(arg):
    items = attr.asdict(
        arg, dict_factory=collections.OrderedDict, recurse=False).items()
    result = _dictlike_items_to_value(items, type_spec, type(arg))
  elif isinstance(arg, dict):
    if isinstance(arg, collections.OrderedDict):
      items = arg.items()
    else:
      # Unordered dicts are sorted by key to get a deterministic element order.
      items = sorted(arg.items())
    result = _dictlike_items_to_value(items, type_spec, type(arg))
  elif isinstance(arg, (tuple, list)):
    # Positional containers become unnamed (`None`-keyed) items.
    items = zip(itertools.repeat(None), arg)
    result = _dictlike_items_to_value(items, type_spec, type(arg))
  elif isinstance(arg, tensorflow_utils.TENSOR_REPRESENTATION_TYPES):
    result = _wrap_constant_as_value(arg)
  elif isinstance(arg, (tf.Tensor, tf.Variable)):
    raise TypeError(
        'TensorFlow construct {} has been encountered in a federated '
        'context. TFF does not support mixing TF and federated orchestration '
        'code. Please wrap any TensorFlow constructs with '
        '`tff.tf_computation`.'.format(arg))
  else:
    raise TypeError(
        'Unable to interpret an argument of type {} as a `tff.Value`.'.format(
            py_typecheck.type_string(type(arg))))
  py_typecheck.check_type(result, Value)
  if (type_spec is not None and
      not type_spec.is_assignable_from(result.type_signature)):
    raise TypeError(
        'The supplied argument maps to TFF type {}, which is incompatible with '
        'the requested type {}.'.format(result.type_signature, type_spec))
  return result
def capture_result_from_graph(result, graph):
  """Captures a result stamped into a tf.Graph as a type signature and binding.

  Args:
    result: The result to capture, a Python object that is composed of tensors,
      possibly nested within Python structures such as dictionaries, lists,
      tuples, or named tuples.
    graph: The instance of tf.Graph to use.

  Returns:
    A tuple (type_spec, binding), where 'type_spec' is an instance of
    computation_types.Type that describes the type of the result, and 'binding'
    is an instance of TensorFlow.Binding that indicates how parts of the result
    type relate to the tensors and ops that appear in the result.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
  """

  def _get_bindings_for_elements(name_value_pairs, graph, type_fn):
    """Build `(type_spec, binding)` tuple for name value pairs."""
    # Each triple is (name, element_type, element_binding).
    element_name_type_binding_triples = [
        ((k,) + capture_result_from_graph(v, graph))
        for k, v in name_value_pairs
    ]
    # Elements with a falsy name (e.g. `None` or '') are left unnamed.
    type_spec = type_fn([((e[0], e[1]) if e[0] else e[1])
                         for e in element_name_type_binding_triples])
    binding = pb.TensorFlow.Binding(
        tuple=pb.TensorFlow.NamedTupleBinding(
            element=[e[2] for e in element_name_type_binding_triples]))
    return type_spec, binding

  # TODO(b/113112885): The emerging extensions for serializing SavedModels may
  # end up introducing similar concepts of bindings, etc., we should look here
  # into the possibility of reusing some of that code when it's available.
  if isinstance(result, dtype_utils.TENSOR_REPRESENTATION_TYPES):
    # Python/numpy constants are first materialized as constants in `graph`.
    with graph.as_default():
      result = tf.constant(result)
  if tf.contrib.framework.is_tensor(result):
    if hasattr(result, 'read_value'):
      # We have a tf.Variable-like result, get a proper tensor to fetch.
      with graph.as_default():
        result = result.read_value()
    return (computation_types.TensorType(result.dtype.base_dtype, result.shape),
            pb.TensorFlow.Binding(
                tensor=pb.TensorFlow.TensorBinding(tensor_name=result.name)))
  elif py_typecheck.is_named_tuple(result):
    # Special handling needed for collections.namedtuples since they do not have
    # anything in the way of a shared base class. Note we don't want to rely on
    # the fact that collections.namedtuples inherit from 'tuple' because we'd be
    # failing to retain the information about naming of tuple members.
    # pylint: disable=protected-access
    name_value_pairs = six.iteritems(result._asdict())
    # pylint: enable=protected-access
    return _get_bindings_for_elements(
        name_value_pairs, graph,
        functools.partial(
            computation_types.NamedTupleTypeWithPyContainerType,
            container_type=type(result)))
  elif isinstance(result, anonymous_tuple.AnonymousTuple):
    return _get_bindings_for_elements(
        anonymous_tuple.to_elements(result), graph,
        computation_types.NamedTupleType)
  elif isinstance(result, dict):
    if isinstance(result, collections.OrderedDict):
      name_value_pairs = six.iteritems(result)
    else:
      # Unordered dicts are sorted by key to get a deterministic order.
      name_value_pairs = sorted(six.iteritems(result))
    return _get_bindings_for_elements(
        name_value_pairs, graph,
        functools.partial(
            computation_types.NamedTupleTypeWithPyContainerType,
            container_type=type(result)))
  elif isinstance(result, (list, tuple)):
    element_type_binding_pairs = [
        capture_result_from_graph(e, graph) for e in result
    ]
    return (computation_types.NamedTupleTypeWithPyContainerType(
        [e[0] for e in element_type_binding_pairs], type(result)),
            pb.TensorFlow.Binding(
                tuple=pb.TensorFlow.NamedTupleBinding(
                    element=[e[1] for e in element_type_binding_pairs])))
  elif isinstance(result,
                  (tf.compat.v1.data.Dataset, tf.compat.v2.data.Dataset)):
    variant_tensor = tf.data.experimental.to_variant(result)
    # TODO(b/130032140): Switch to TF2.0 way of doing it while cleaning up the
    # legacy structures all over the code base and replacing them with the new
    # tf.data.experimental.Structure variants.
    element_type = type_utils.tf_dtypes_and_shapes_to_type(
        tf.compat.v1.data.get_output_types(result),
        tf.compat.v1.data.get_output_shapes(result))
    return (computation_types.SequenceType(element_type),
            pb.TensorFlow.Binding(
                sequence=pb.TensorFlow.SequenceBinding(
                    variant_tensor_name=variant_tensor.name)))
  elif isinstance(result, OneShotDataset):
    # TODO(b/129956296): Eventually delete this deprecated code path.
    element_type = type_utils.tf_dtypes_and_shapes_to_type(
        tf.compat.v1.data.get_output_types(result),
        tf.compat.v1.data.get_output_shapes(result))
    handle_name = result.make_one_shot_iterator().string_handle().name
    return (computation_types.SequenceType(element_type),
            pb.TensorFlow.Binding(
                sequence=pb.TensorFlow.SequenceBinding(
                    iterator_string_handle_name=handle_name)))
  else:
    raise TypeError('Cannot capture a result of an unsupported type {}.'.format(
        py_typecheck.type_string(type(result))))
def tf_dtypes_and_shapes_to_type(dtypes, shapes):
  """Converts parallel TF `dtypes` and `shapes` structures into a TFF type.

  The `dtypes` and `shapes` arguments have the form used by `tf.data.Dataset`s
  to describe the type and shape of their elements (for example, the values
  returned by `tf.compat.v1.data.get_output_types` and
  `tf.compat.v1.data.get_output_shapes`). Both structures must be nested
  identically.

  Args:
    dtypes: A nested structure of dtypes, such as what is returned by Dataset's
      output_dtypes property.
    shapes: A nested structure of shapes, such as what is returned by Dataset's
      output_shapes property.

  Returns:
    The corresponding instance of computation_types.Type. A `tf.DType` leaf
    becomes a `computation_types.TensorType`; named tuples, `attrs` classes,
    dicts, lists and tuples become
    `computation_types.NamedTupleTypeWithPyContainerType` instances that record
    the Python container type of `dtypes`.

  Raises:
    TypeError: if the arguments are of types that weren't recognized.
  """
  tf.contrib.framework.nest.assert_same_structure(dtypes, shapes)

  def _named_elements(named_dtypes, shape_lookup):
    # Pairs every (name, dtype) with the shape stored under the same name.
    return [(name, tf_dtypes_and_shapes_to_type(dtype, shape_lookup[name]))
            for name, dtype in named_dtypes]

  if isinstance(dtypes, tf.DType):
    return computation_types.TensorType(dtypes, shapes)

  container = type(dtypes)
  if py_typecheck.is_named_tuple(dtypes):
    # Must be tested before the `(list, tuple)` case, since named tuples are
    # tuples too.
    dtype_fields = dtypes._asdict()
    shape_fields = shapes._asdict()
    return computation_types.NamedTupleTypeWithPyContainerType(
        _named_elements(six.iteritems(dtype_fields), shape_fields), container)
  if py_typecheck.is_attrs(dtypes):
    dtype_fields = attr.asdict(
        dtypes, dict_factory=collections.OrderedDict, recurse=False)
    shape_fields = attr.asdict(
        shapes, dict_factory=collections.OrderedDict, recurse=False)
    return computation_types.NamedTupleTypeWithPyContainerType(
        _named_elements(six.iteritems(dtype_fields), shape_fields), container)
  if isinstance(dtypes, dict):
    named_dtypes = six.iteritems(dtypes)
    if not isinstance(dtypes, collections.OrderedDict):
      # Plain dicts are put in key order so the resulting type is
      # deterministic.
      named_dtypes = sorted(named_dtypes)
    return computation_types.NamedTupleTypeWithPyContainerType(
        _named_elements(named_dtypes, shapes), container)
  if isinstance(dtypes, (list, tuple)):
    positional = [
        tf_dtypes_and_shapes_to_type(dtype, shapes[position])
        for position, dtype in enumerate(dtypes)
    ]
    return computation_types.NamedTupleTypeWithPyContainerType(
        positional, container)
  raise TypeError('Unrecognized: dtypes {}, shapes {}.'.format(
      str(dtypes), str(shapes)))
def to_value(
    arg: Any,
    type_spec,
    context_stack: context_stack_base.ContextStack,
    parameter_type_hint=None,
) -> ValueImpl:
  """Converts the argument into an instance of `tff.Value`.

  The types of non-`tff.Value` arguments that are currently convertible to
  `tff.Value` include the following:

  * Lists, tuples, anonymous tuples, named tuples, and dictionaries, all
    of which are converted into instances of `tff.Tuple`.
  * Placement literals, converted into instances of `tff.Placement`.
  * Computations.
  * Python constants of type `str`, `int`, `float`, `bool`
  * Numpy objects inheriting from `np.ndarray` or `np.generic` (the parent
    of numpy scalar types)

  Args:
    arg: Either an instance of `tff.Value`, or an argument convertible to
      `tff.Value`. The argument must not be `None`.
    type_spec: An optional `computation_types.Type` or value convertible to it
      by `computation_types.to_type` which specifies the desired type signature
      of the resulting value. This allows for disambiguating the target type
      (e.g., when two TFF types can be mapped to the same Python
      representations), or `None` if none available, in which case TFF tries
      to determine the type of the TFF value automatically.
    context_stack: The context stack to use.
    parameter_type_hint: An optional `computation_types.Type` or value
      convertible to it by `computation_types.to_type` which specifies an
      argument type to use in the case that `arg` is a
      `function_utils.PolymorphicFunction`.

  Returns:
    An instance of `tff.Value` corresponding to the given `arg`, and of TFF
    type matching the `type_spec` if specified (not `None`).

  Raises:
    TypeError: if `arg` is of an unsupported type, or of a type that does not
      match `type_spec`. Raises explicit error message if TensorFlow constructs
      are encountered, as TensorFlow code should be sealed away from TFF
      federated context.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if type_spec is not None:
    type_spec = computation_types.to_type(type_spec)
    type_analysis.check_well_formed(type_spec)
  if isinstance(arg, ValueImpl):
    result = arg
  elif isinstance(arg, building_blocks.ComputationBuildingBlock):
    result = ValueImpl(arg, context_stack)
  elif isinstance(arg, placement_literals.PlacementLiteral):
    result = ValueImpl(building_blocks.Placement(arg), context_stack)
  elif isinstance(
      arg, (computation_base.Computation, function_utils.PolymorphicFunction)):
    if isinstance(arg, function_utils.PolymorphicFunction):
      # A polymorphic function must first be specialized to a concrete
      # computation, which requires knowing its parameter type.
      if parameter_type_hint is None:
        raise TypeError(
            'Polymorphic computations cannot be converted to TFF values '
            'without a type hint. Consider explicitly specifying the '
            'argument types of a computation before passing it to a '
            'function that requires a TFF value (such as a TFF intrinsic '
            'like `federated_map`). If you are a TFF developer and think '
            'this should be supported, consider providing `parameter_type_hint` '
            'as an argument to the encompassing `to_value` conversion.')
      parameter_type_hint = computation_types.to_type(parameter_type_hint)
      type_analysis.check_well_formed(parameter_type_hint)
      arg = arg.fn_for_argument_type(parameter_type_hint)
    py_typecheck.check_type(arg, computation_base.Computation)
    result = ValueImpl(
        building_blocks.CompiledComputation(
            computation_impl.ComputationImpl.get_proto(arg)), context_stack)
  elif type_spec is not None and isinstance(type_spec,
                                            computation_types.SequenceType):
    result = _wrap_sequence_as_value(arg, type_spec.element, context_stack)
  elif isinstance(arg, anonymous_tuple.AnonymousTuple):
    result = ValueImpl(
        building_blocks.Tuple([
            (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
            for k, v in anonymous_tuple.iter_elements(arg)
        ]), context_stack)
  elif py_typecheck.is_named_tuple(arg):
    # Since Python 3.8 `_asdict` returns a plain `dict`, which the `dict`
    # branch below would sort by key. Wrapping it in an `OrderedDict` keeps the
    # named tuple's declared field order.
    fields = collections.OrderedDict(arg._asdict())  # pytype: disable=attribute-error
    result = to_value(fields, None, context_stack)
  elif py_typecheck.is_attrs(arg):
    result = to_value(
        attr.asdict(
            arg, dict_factory=collections.OrderedDict, recurse=False), None,
        context_stack)
  elif isinstance(arg, dict):
    if isinstance(arg, collections.OrderedDict):
      items = arg.items()
    else:
      # Unordered dicts are sorted by key to get a deterministic element order.
      items = sorted(arg.items())
    value = building_blocks.Tuple([
        (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
        for k, v in items
    ])
    result = ValueImpl(value, context_stack)
  elif isinstance(arg, (tuple, list)):
    result = ValueImpl(
        building_blocks.Tuple([
            ValueImpl.get_comp(to_value(x, None, context_stack)) for x in arg
        ]), context_stack)
  elif isinstance(arg, tensorflow_utils.TENSOR_REPRESENTATION_TYPES):
    result = _wrap_constant_as_value(arg, context_stack)
  elif isinstance(arg, (tf.Tensor, tf.Variable)):
    raise TypeError(
        'TensorFlow construct {} has been encountered in a federated '
        'context. TFF does not support mixing TF and federated orchestration '
        'code. Please wrap any TensorFlow constructs with '
        '`tff.tf_computation`.'.format(arg))
  else:
    raise TypeError(
        'Unable to interpret an argument of type {} as a TFF value.'.format(
            py_typecheck.type_string(type(arg))))
  py_typecheck.check_type(result, ValueImpl)
  if (type_spec is not None and
      not type_spec.is_assignable_from(result.type_signature)):
    raise TypeError(
        'The supplied argument maps to TFF type {}, which is incompatible with '
        'the requested type {}.'.format(result.type_signature, type_spec))
  return result
def infer_type(arg):
  """Infers the TFF type of the argument (a `computation_types.Type` instance).

  WARNING: This function is only partially implemented.

  The kinds of arguments that are currently correctly recognized:
  - tensors, variables, and data sets,
  - things that are convertible to tensors (including numpy arrays, builtin
    types, as well as lists and tuples of any of the above, etc.),
  - nested lists, tuples, namedtuples, anonymous tuples, dict, and
    OrderedDicts.

  Args:
    arg: The argument, the TFF type of which to infer.

  Returns:
    Either an instance of `computation_types.Type`, or `None` if the argument
    is `None`. A Python `int` is typed `tf.int32` when it fits in that range
    and `tf.int64` otherwise.

  Raises:
    TypeError: If the TFF type of `arg` cannot be inferred, including a Python
      `int` outside the range representable by `tf.int64`.
  """
  # TODO(b/113112885): Implement the remaining cases here on the need basis.
  if arg is None:
    return None
  elif isinstance(arg, typed_object.TypedObject):
    return arg.type_signature
  elif tf.contrib.framework.is_tensor(arg):
    return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
  elif isinstance(arg, (tf.data.Dataset, tf.compat.v1.data.Dataset,
                        tf.compat.v2.data.Dataset)):
    return computation_types.SequenceType(
        tf_dtypes_and_shapes_to_type(
            tf.compat.v1.data.get_output_types(arg),
            tf.compat.v1.data.get_output_shapes(arg)))
  elif isinstance(arg, anonymous_tuple.AnonymousTuple):
    return computation_types.NamedTupleType([
        (k, infer_type(v)) if k else infer_type(v)
        for k, v in anonymous_tuple.to_elements(arg)
    ])
  elif py_typecheck.is_attrs(arg):
    items = attr.asdict(
        arg, dict_factory=collections.OrderedDict, recurse=False)
    return computation_types.NamedTupleTypeWithPyContainerType(
        [(k, infer_type(v)) for k, v in six.iteritems(items)], type(arg))
  elif py_typecheck.is_named_tuple(arg):
    items = arg._asdict()
    return computation_types.NamedTupleTypeWithPyContainerType(
        [(k, infer_type(v)) for k, v in six.iteritems(items)], type(arg))
  elif isinstance(arg, dict):
    if isinstance(arg, collections.OrderedDict):
      items = six.iteritems(arg)
    else:
      # Unordered dicts are sorted by key to get a deterministic element order.
      items = sorted(six.iteritems(arg))
    return computation_types.NamedTupleTypeWithPyContainerType(
        [(k, infer_type(v)) for k, v in items], type(arg))
  elif isinstance(arg, (tuple, list)):
    elements = []
    all_elements_named = True
    for element in arg:
      all_elements_named &= py_typecheck.is_name_value_pair(element)
      elements.append(infer_type(element))
    # If this is a non-empty tuple of (name, value) pairs, the caller most
    # likely intended this to be a NamedTupleType, so we avoid storing the
    # Python container. An empty container is only vacuously "all named", so
    # it keeps its Python container type.
    if elements and all_elements_named:
      return computation_types.NamedTupleType(elements)
    else:
      return computation_types.NamedTupleTypeWithPyContainerType(
          elements, type(arg))
  elif isinstance(arg, six.string_types):
    return computation_types.TensorType(tf.string)
  elif isinstance(arg, (np.generic, np.ndarray)):
    return computation_types.TensorType(tf.as_dtype(arg.dtype), arg.shape)
  else:
    # Exact type checks: subclasses of these builtins fall through to the
    # generic conversion below.
    arg_type = type(arg)
    if arg_type is bool:
      return computation_types.TensorType(tf.bool)
    elif arg_type is int:
      # Choose the integral type based on the value, so that large Python ints
      # are not typed as an `int32` they cannot fit in.
      if arg > tf.int64.max or arg < tf.int64.min:
        raise TypeError(
            'No integral type support for values outside range '
            '[{}, {}]. Got: {}'.format(tf.int64.min, tf.int64.max, arg))
      elif arg > tf.int32.max or arg < tf.int32.min:
        return computation_types.TensorType(tf.int64)
      else:
        return computation_types.TensorType(tf.int32)
    elif arg_type is float:
      return computation_types.TensorType(tf.float32)
    else:
      # Now fall back onto the heavier-weight processing, as all else failed.
      # Use make_tensor_proto() to make sure to handle it consistently with
      # how TensorFlow is handling values (e.g., recognizing int as int32, as
      # opposed to int64 as in NumPy).
      try:
        # TODO(b/113112885): Find something more lightweight we could use here.
        tensor_proto = tf.make_tensor_proto(arg)
        return computation_types.TensorType(
            tf.DType(tensor_proto.dtype),
            tf.TensorShape(tensor_proto.tensor_shape))
      except TypeError as err:
        raise TypeError('Could not infer the TFF type of {}: {}.'.format(
            py_typecheck.type_string(type(arg)), str(err))) from err
def build_py_container(elements):
  """Instantiates `container_type` from `elements`.

  NOTE(review): `container_type` is not a parameter of this function; it is
  looked up in an enclosing or module scope that is not visible here. Confirm
  where it is bound.

  Args:
    elements: For named tuple and `attrs` containers, an iterable of
      `(name, value)` pairs; for any other container, whatever that
      container's constructor accepts as its single argument.

  Returns:
    A new `container_type` instance. Named tuple and `attrs` classes receive
    the pairs as keyword arguments; any other container is called with
    `elements` unchanged.
  """
  takes_keywords = (
      py_typecheck.is_named_tuple(container_type) or
      py_typecheck.is_attrs(container_type))
  if not takes_keywords:
    return container_type(elements)
  return container_type(**dict(elements))
def to_value(arg, type_spec, context_stack):
  """Converts the argument into an instance of `tff.Value`.

  The types of non-`tff.Value` arguments that are currently convertible to
  `tff.Value` include the following:

  * Lists, tuples, anonymous tuples, named tuples, and dictionaries, all
    of which are converted into instances of `tff.Tuple`.
  * Placement literals, converted into instances of `tff.Placement`.
  * Computations.
  * Python constants of type `str`, `int`, `float`, `bool`
  * Numpy objects inheriting from `np.ndarray` or `np.generic` (the parent
    of numpy scalar types)

  Args:
    arg: Either an instance of `tff.Value`, or an argument convertible to
      `tff.Value`. The argument must not be `None`.
    type_spec: A type specifier that allows for disambiguating the target type
      (e.g., when two TFF types can be mapped to the same Python
      representations), or `None` if none available, in which case TFF tries to
      determine the type of the TFF value automatically.
    context_stack: The context stack to use.

  Returns:
    An instance of `tff.Value` corresponding to the given `arg`, and of TFF type
    matching the `type_spec` if specified (not `None`).

  Raises:
    TypeError: if `arg` is of an unsupported type, or of a type that does not
      match `type_spec`. Raises explicit error message if TensorFlow constructs
      are encountered, as TensorFlow code should be sealed away from TFF
      federated context.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if type_spec is not None:
    type_spec = computation_types.to_type(type_spec)
    type_utils.check_well_formed(type_spec)
  if isinstance(arg, ValueImpl):
    result = arg
  elif isinstance(arg, computation_building_blocks.ComputationBuildingBlock):
    result = ValueImpl(arg, context_stack)
  elif isinstance(arg, placement_literals.PlacementLiteral):
    result = ValueImpl(
        computation_building_blocks.Placement(arg), context_stack)
  elif isinstance(arg, computation_base.Computation):
    result = ValueImpl(
        computation_building_blocks.CompiledComputation(
            computation_impl.ComputationImpl.get_proto(arg)), context_stack)
  elif type_spec is not None and isinstance(type_spec,
                                            computation_types.SequenceType):
    result = _wrap_sequence_as_value(arg, type_spec.element, context_stack)
  elif isinstance(arg, anonymous_tuple.AnonymousTuple):
    result = ValueImpl(
        computation_building_blocks.Tuple([
            (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
            for k, v in anonymous_tuple.to_elements(arg)
        ]), context_stack)
  elif py_typecheck.is_named_tuple(arg):
    # Since Python 3.8 `_asdict` returns a plain `dict`, which the `dict`
    # branch below would sort by key. Wrapping it in an `OrderedDict` keeps the
    # named tuple's declared field order.
    result = to_value(
        collections.OrderedDict(arg._asdict()), None, context_stack)
  elif isinstance(arg, dict):
    if isinstance(arg, collections.OrderedDict):
      items = six.iteritems(arg)
    else:
      # Unordered dicts are sorted by key to get a deterministic element order.
      items = sorted(six.iteritems(arg))
    value = computation_building_blocks.Tuple([
        (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
        for k, v in items
    ])
    result = ValueImpl(value, context_stack)
  elif isinstance(arg, (tuple, list)):
    result = ValueImpl(
        computation_building_blocks.Tuple([
            ValueImpl.get_comp(to_value(x, None, context_stack)) for x in arg
        ]), context_stack)
  elif isinstance(arg, dtype_utils.TENSOR_REPRESENTATION_TYPES):
    result = _wrap_constant_as_value(arg, context_stack)
  elif isinstance(arg, (tf.Tensor, tf.Variable)):
    raise TypeError(
        'TensorFlow construct {} has been encountered in a federated '
        'context. TFF does not support mixing TF and federated orchestration '
        'code. Please wrap any TensorFlow constructs with '
        '`tff.tf_computation`.'.format(arg))
  else:
    raise TypeError(
        'Unable to interpret an argument of type {} as a TFF value.'.format(
            py_typecheck.type_string(type(arg))))
  py_typecheck.check_type(result, ValueImpl)
  if (type_spec is not None and
      not type_utils.is_assignable_from(type_spec, result.type_signature)):
    raise TypeError(
        'The supplied argument maps to TFF type {}, which is incompatible '
        'with the requested type {}.'.format(
            str(result.type_signature), str(type_spec)))
  return result