def make_dataset_from_variant_tensor(variant_tensor, type_spec):
    """Builds a `tf.data.Dataset` from a variant tensor and an element type.

    Args:
      variant_tensor: The variant tensor that represents the dataset.
      type_spec: The type spec of elements of the data set, either an instance
        of `types.Type` or something convertible to it.

    Returns:
      A corresponding instance of `tf.data.Dataset`.

    Raises:
      TypeError: If `variant_tensor` is not a tensor, or its dtype is not
        `tf.variant`.
    """
    if not tf.is_tensor(variant_tensor):
        found_type = py_typecheck.type_string(type(variant_tensor))
        raise TypeError(
            'Expected `variant_tensor` to be a tensor, found {}.'.format(
                found_type))
    found_dtype = variant_tensor.dtype
    if found_dtype != tf.variant:
        raise TypeError(
            'Expected `variant_tensor` to be of a variant type, found {}.'.format(
                variant_tensor.dtype))
    # Convert the element type to the nested TF structure `from_variant` needs.
    element_type = computation_types.to_type(type_spec)
    element_structure = type_conversions.type_to_tf_structure(element_type)
    return tf.data.experimental.from_variant(
        variant_tensor, structure=element_structure)
def _deserialize_type_spec(serialize_type_variable, python_container=None):
    """Turns a serialized `tff.Type` proto held in a variable into a TF structure.

    Args:
      serialize_type_variable: An object whose `read_value().numpy()` yields
        the serialized bytes of a `computation_pb2.Type` message.
      python_container: Optional Python container. When it is given and the
        deserialized type is a struct, the struct is rebuilt as a
        `computation_types.StructWithPythonType` using this container.

    Returns:
      The result of `type_conversions.type_to_tf_structure` for the
      deserialized (and possibly re-wrapped) type.
    """
    serialized_bytes = serialize_type_variable.read_value().numpy()
    type_proto = computation_pb2.Type.FromString(serialized_bytes)
    type_spec = type_serialization.deserialize_type(type_proto)
    if type_spec.is_struct():
        if python_container is not None:
            elements = structure.iter_elements(type_spec)
            type_spec = computation_types.StructWithPythonType(
                elements, python_container)
    return type_conversions.type_to_tf_structure(type_spec)
def test_without_names(self):
    """An unnamed tuple of tensor specs survives conversion to a TF structure."""
    bool_spec = tf.TensorSpec(shape=(), dtype=tf.bool)
    int_spec = tf.TensorSpec(shape=(), dtype=tf.int32)
    expected_structure = (bool_spec, int_spec)
    tf_structure = type_conversions.type_to_tf_structure(
        computation_types.to_type(expected_structure))
    with tf.Graph().as_default():
        variant_placeholder = tf.compat.v1.placeholder(tf.variant, shape=[])
        ds = tf.data.experimental.from_variant(
            variant_placeholder, structure=tf_structure)
        # The dataset's element spec must match what we started from.
        self.assertEqual(expected_structure, ds.element_spec)
def test_with_names(self):
    """A nested, named structure survives conversion to a TF structure."""
    nested = collections.OrderedDict([
        ('c', tf.TensorSpec(shape=(), dtype=tf.float32)),
        ('d', tf.TensorSpec(shape=(20,), dtype=tf.int32)),
    ])
    expected_structure = collections.OrderedDict([
        ('a', tf.TensorSpec(shape=(), dtype=tf.bool)),
        ('b', nested),
    ])
    tf_structure = type_conversions.type_to_tf_structure(
        computation_types.to_type(expected_structure))
    with tf.Graph().as_default():
        variant_placeholder = tf.compat.v1.placeholder(tf.variant, shape=[])
        ds = tf.data.experimental.from_variant(
            variant_placeholder, structure=tf_structure)
        # The dataset's element spec must match what we started from.
        self.assertEqual(expected_structure, ds.element_spec)
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

    Args:
      comp: An instance of `pb.Computation`.
      type_spec: An optional `tff.Type` instance or something convertible to it.
      device: An optional `tf.config.LogicalDevice`.

    Returns:
      Either a one-argument or a zero-argument callable that executes the
      computation in eager mode.

    Raises:
      TypeError: If arguments are of the wrong types, e.g., `comp` is not a
        TensorFlow computation, or `type_spec` is not equivalent to the type
        recorded in `comp`.
    """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is None:
        type_spec = comp_type
    elif not type_spec.is_equivalent_to(comp_type):
        raise TypeError('Expected a computation of type {}, got {}.'.format(
            type_spec, comp_type))

    # Validate the computation kind before reading any attribute that only
    # function types carry, so that a non-TensorFlow `comp` always surfaces as
    # the documented `TypeError`.
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        unexpected_building_block = (
            building_blocks.ComputationBuildingBlock.from_proto(comp))
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            unexpected_building_block))

    if type_spec.is_function():
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    # TODO(b/156302055): Currently, TF will raise on any function returning a
    # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
    # this gating when we can.
    # Uses `result_type` (not `type_spec.result`) so non-functional types work.
    must_pin_function_to_cpu = type_analysis.contains(
        result_type, lambda t: t.is_sequence())

    wrapped_fn = _get_wrapped_function_from_comp(
        comp, must_pin_function_to_cpu, param_type, device)

    # One converter per flattened parameter leaf: tensors pass through,
    # sequences are turned into variant tensors.
    param_fns = []
    if param_type is not None:
        for spec in structure.flatten(param_type):
            if spec.is_tensor():
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                param_fns.append(tf.data.experimental.to_variant)

    # One converter per flattened result leaf: tensors pass through, variant
    # tensors are turned back into datasets.
    result_fns = []
    for spec in structure.flatten(result_type):
        if spec.is_tensor():
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            tf_structure = type_conversions.type_to_tf_structure(spec.element)

            # Default argument binds this iteration's structure to the closure.
            def fn(x, tf_structure=tf_structure):
                return tf.data.experimental.from_variant(x, tf_structure)

            result_fns.append(fn)

    ops = wrapped_fn.graph.get_operations()

    def _resource_handles(op_type):
        """Returns a tuple of evaluated outputs of all ops of type `op_type`."""
        fetches = [
            output for op in ops if op.type == op_type for output in op.outputs
        ]
        if not fetches:
            return ()
        return tuple(wrapped_fn.prune(feeds={}, fetches=fetches)())

    # Handed to `_call_embedded_tf` as the resources to destroy before and
    # after invocation, respectively.
    destroy_before_invocation = _resource_handles('HashTableV2')
    destroy_after_invocation = _resource_handles('VarHandleOp')

    def fn_to_return(arg,
                     param_fns=tuple(param_fns),
                     result_fns=tuple(result_fns),
                     result_type=result_type,
                     wrapped_fn=wrapped_fn,
                     destroy_before=destroy_before_invocation,
                     destroy_after=destroy_after_invocation):
        # This double-function pattern works around python late binding, forcing
        # the variables to bind eagerly.
        return _call_embedded_tf(
            arg=arg,
            param_fns=param_fns,
            result_fns=result_fns,
            result_type=result_type,
            wrapped_fn=wrapped_fn,
            destroy_before_invocation=destroy_before,
            destroy_after_invocation=destroy_after)

    # pylint: disable=function-redefined
    if must_pin_function_to_cpu:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device('cpu'):
                return old_fn_to_return(x)

    elif device is not None:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device(device.name):
                return old_fn_to_return(x)

    # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
def test_with_no_elements(self):
    """Converting an empty struct type must raise `ValueError`."""
    with self.assertRaises(ValueError):
        empty_struct = computation_types.StructType([])
        type_conversions.type_to_tf_structure(empty_struct)
def test_with_inconsistently_named_elements(self):
    """A struct mixing named and unnamed elements must raise `ValueError`."""
    with self.assertRaises(ValueError):
        mixed_struct = computation_types.StructType(
            [('a', tf.int32), tf.bool])
        type_conversions.type_to_tf_structure(mixed_struct)
def test_with_sequence_type(self):
    """Converting a sequence type must raise `ValueError`."""
    with self.assertRaises(ValueError):
        sequence_type = computation_types.SequenceType(tf.int32)
        type_conversions.type_to_tf_structure(sequence_type)
def test_with_none(self):
    """Passing `None` instead of a type must raise `TypeError`."""
    self.assertRaises(
        TypeError, type_conversions.type_to_tf_structure, None)
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Wraps a serialized TensorFlow computation as an eager-mode callable.

    Args:
      comp: An instance of `pb.Computation`.
      type_spec: An optional `tff.Type` instance or something convertible to it.
      device: An optional device name.

    Returns:
      A one-argument callable when the computation has a parameter, otherwise
      a zero-argument callable; either one executes the computation in eager
      mode.

    Raises:
      TypeError: If arguments are of the wrong types, e.g., `comp` is not a
        TensorFlow computation.
    """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is None:
        type_spec = comp_type
    elif not type_analysis.are_equivalent_types(type_spec, comp_type):
        raise TypeError('Expected a computation of type {}, got {}.'.format(
            type_spec, comp_type))

    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    is_functional = isinstance(type_spec, computation_types.FunctionType)
    param_type = type_spec.parameter if is_functional else None
    result_type = type_spec.result if is_functional else type_spec

    input_tensor_names = []
    if param_type is not None:
        input_tensor_names = (
            tensorflow_utils.extract_tensor_names_from_binding(
                comp.tensorflow.parameter))
    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap(*args):  # pylint: disable=missing-docstring
        expected_count = len(input_tensor_names)
        if len(args) != expected_count:
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                expected_count, len(args)))
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        if init_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            # Feed the wrapped function's arguments in place of the graph's
            # input tensors and fetch its declared outputs.
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def),
                input_map=dict(zip(input_tensor_names, args)),
                return_elements=output_tensor_names)

        if device is None:
            return _import_fn()
        with tf.device(device):
            return _import_fn()

    # Build the wrapped function's signature together with one converter per
    # flattened parameter leaf.
    signature = []
    param_fns = []
    if param_type is None:
        flat_param_specs = []
    else:
        flat_param_specs = anonymous_tuple.flatten(param_type)
    for spec in flat_param_specs:
        if not isinstance(spec, computation_types.TensorType):
            py_typecheck.check_type(spec, computation_types.SequenceType)
            signature.append(tf.TensorSpec([], tf.variant))
            param_fns.append(tf.data.experimental.to_variant)
            continue
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(lambda x: x)

    wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

    # One converter per flattened result leaf.
    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if not isinstance(spec, computation_types.TensorType):
            py_typecheck.check_type(spec, computation_types.SequenceType)
            element_structure = type_conversions.type_to_tf_structure(
                spec.element)

            def fn(x, structure=element_structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)
            continue
        result_fns.append(lambda x: x)

    def _invoke(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            param_elements = [
                convert(part) for part, convert in zip(arg_parts, param_fns)
            ]
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue b/144127474 that variables
        # created from tf.import_graph_def(...) inside tf.wrap_function(...) is
        # not destroyed. So get all the variables from `wrapped_fn` and destroy
        # manually.
        # TODO(b/144127474): Remove this manual cleanup once
        # tf.wrap_function(...) is fixed.
        resources = [
            output
            for op in wrapped_fn.graph.get_operations()
            if op.type == 'VarHandleOp'
            for output in op.outputs
        ]
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        result_elements = [
            convert(part) for part, convert in zip(result_parts, result_fns)
        ]
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    # Defaults bind the current `param_fns` and `wrapped_fn` objects eagerly.
    def run_unplaced(arg, p=param_fns, w=wrapped_fn):
        return _invoke(arg, p, w)

    if device is None:
        fn_to_return = run_unplaced
    else:

        def fn_to_return(x):
            with tf.device(device):
                return run_unplaced(x)

    if param_type is None:
        return lambda: fn_to_return(None)
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

    Args:
      comp: An instance of `pb.Computation`.
      type_spec: An optional `tff.Type` instance or something convertible to it.
      device: An optional `tf.config.LogicalDevice`.

    Returns:
      Either a one-argument or a zero-argument callable that executes the
      computation in eager mode.

    Raises:
      TypeError: If arguments are of the wrong types, e.g., `comp` is not a
        TensorFlow computation, or `type_spec` is not equivalent to the type
        recorded in `comp`.
    """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is None:
        type_spec = comp_type
    elif not type_spec.is_equivalent_to(comp_type):
        raise TypeError('Expected a computation of type {}, got {}.'.format(
            type_spec, comp_type))

    # Validate the computation kind before reading any attribute that only
    # function types carry, so that a non-TensorFlow `comp` always surfaces as
    # the documented `TypeError`.
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        unexpected_building_block = (
            building_blocks.ComputationBuildingBlock.from_proto(comp))
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            unexpected_building_block))

    if type_spec.is_function():
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    # TODO(b/156302055): Currently, TF will raise on any function returning a
    # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
    # this gating when we can.
    # Uses `result_type` (not `type_spec.result`) so non-functional types work.
    must_pin_function_to_cpu = type_analysis.contains(
        result_type, lambda t: t.is_sequence())

    wrapped_fn = _get_wrapped_function_from_comp(
        comp, must_pin_function_to_cpu, param_type, device)

    # One converter per flattened parameter leaf: tensors pass through,
    # sequences are turned into variant tensors.
    param_fns = []
    if param_type is not None:
        for spec in structure.flatten(param_type):
            if spec.is_tensor():
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                param_fns.append(tf.data.experimental.to_variant)

    # One converter per flattened result leaf: tensors pass through, variant
    # tensors are turned back into datasets.
    result_fns = []
    for spec in structure.flatten(result_type):
        if spec.is_tensor():
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            tf_structure = type_conversions.type_to_tf_structure(spec.element)

            # Default argument binds this iteration's structure to the closure.
            def fn(x, tf_structure=tf_structure):
                return tf.data.experimental.from_variant(x, tf_structure)

            result_fns.append(fn)

    def _destroy_resources(wrapped_fn, op_type):
        """Destroys the resources produced by all `op_type` ops in the graph."""
        handles = [
            output
            for op in wrapped_fn.graph.get_operations()
            if op.type == op_type
            for output in op.outputs
        ]
        if handles:
            for resource in wrapped_fn.prune(feeds={}, fetches=handles)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        # TODO(b/166479382): This cleanup-before-invocation pattern is a
        # workaround to square the circle of TF data expecting to lazily
        # reference this resource on iteration, as well as usages that expect to
        # reinitialize a table with new data. Revisit the semantics implied by
        # this cleanup pattern.
        _destroy_resources(wrapped_fn, 'HashTableV2')

        param_elements = []
        if arg is not None:
            arg_parts = structure.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            param_elements = [
                param_fn(arg_part)
                for arg_part, param_fn in zip(arg_parts, param_fns)
            ]
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue b/144127474 that variables
        # created from tf.import_graph_def(...) inside tf.wrap_function(...) is
        # not destroyed. So get all the variables from `wrapped_fn` and destroy
        # manually.
        # TODO(b/144127474): Remove this manual cleanup once
        # tf.wrap_function(...) is fixed.
        _destroy_resources(wrapped_fn, 'VarHandleOp')

        result_elements = [
            result_fn(result_part)
            for result_part, result_fn in zip(result_parts, result_fns)
        ]
        return structure.pack_sequence_as(result_type, result_elements)

    # Defaults bind the current `param_fns` and `wrapped_fn` objects eagerly.
    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)

    # pylint: disable=function-redefined
    if must_pin_function_to_cpu:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device('cpu'):
                return old_fn_to_return(x)

    elif device is not None:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device(device.name):
                return old_fn_to_return(x)

    # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
def check_round_trip(self, first_spec):
    """Asserts that a spec survives type -> TF structure -> type unchanged.

    Args:
      first_spec: Something convertible to a type via
        `computation_types.to_type`.
    """
    original_type = computation_types.to_type(first_spec)
    # Convert to a TF structure and straight back to a type.
    recovered_type = computation_types.to_type(
        type_conversions.type_to_tf_structure(original_type))
    self.assert_types_identical(original_type, recovered_type)
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

    Args:
      comp: An instance of `pb.Computation`.
      type_spec: An optional `tff.Type` instance or something convertible to it.
      device: An optional `tf.config.LogicalDevice`.

    Returns:
      Either a one-argument or a zero-argument callable that executes the
      computation in eager mode.

    Raises:
      TypeError: If arguments are of the wrong types, e.g., `comp` is not a
        TensorFlow computation, `type_spec` is not equivalent to the type
        recorded in `comp`, or the binding names cannot be resolved in the
        imported graph.
    """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is None:
        type_spec = comp_type
    elif not type_spec.is_equivalent_to(comp_type):
        raise TypeError('Expected a computation of type {}, got {}.'.format(
            type_spec, comp_type))

    # Validate the computation kind before reading any attribute that only
    # function types carry, so that a non-TensorFlow `comp` always surfaces as
    # the documented `TypeError`.
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    # TODO(b/155198591): Currently, TF will raise on any function returning a
    # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
    # this gating when we can.
    # Uses `result_type` (not `type_spec.result`) so non-functional types work.
    must_pin_function_to_cpu = type_analysis.contains_types(
        result_type, computation_types.SequenceType)

    if param_type is not None:
        input_tensor_names = (
            tensorflow_utils.extract_tensor_names_from_binding(
                comp.tensorflow.parameter))
    else:
        input_tensor_names = []
    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap():
        """No-arg function to import graph def.

        We pass a no-arg function to `tf.compat.v1.wrap_function` to avoid
        the leftover placeholders that can result from binding arguments to
        the imported graphdef via `input_map`. The correct signature will be
        added to this function later, via the `prune` call below.

        Returns:
          Result of importing graphdef backing `comp`.
        """
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        if init_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def), name='')

        if must_pin_function_to_cpu:
            with tf.device('cpu'):
                return _import_fn()
        elif device is not None:
            with tf.device(device.name):
                return _import_fn()
        else:
            return _import_fn()

    # One converter per flattened parameter leaf: tensors pass through,
    # sequences are turned into variant tensors.
    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(param_type):
            if isinstance(spec, computation_types.TensorType):
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_noarg_fn = tf.compat.v1.wrap_function(
        function_to_wrap, signature=[])
    import_graph = wrapped_noarg_fn.graph
    try:
        # Re-signature the imported graph: the parameter binding's tensors
        # become feeds and the result binding's tensors become fetches.
        wrapped_fn = wrapped_noarg_fn.prune(
            feeds=tf.nest.map_structure(import_graph.as_graph_element,
                                        input_tensor_names),
            fetches=tf.nest.map_structure(import_graph.as_graph_element,
                                          output_tensor_names),
        )
    except KeyError as e:
        raise TypeError(
            'Caught exception trying to prune graph `{g}` with '
            'feeds {feeds} and fetches {fetches}. This indicates that these '
            'names may not refer to tensors in the graph. .\nException: {e}'.
            format(
                g=import_graph,
                feeds=input_tensor_names,
                fetches=output_tensor_names,
                e=e))

    # One converter per flattened result leaf: tensors pass through, variant
    # tensors are turned back into datasets.
    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            tf_structure = type_conversions.type_to_tf_structure(spec.element)

            # Default argument binds this iteration's structure to the closure.
            def fn(x, structure=tf_structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            param_elements = [
                param_fn(arg_part)
                for arg_part, param_fn in zip(arg_parts, param_fns)
            ]
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue b/144127474 that variables
        # created from tf.import_graph_def(...) inside tf.wrap_function(...) is
        # not destroyed. So get all the variables from `wrapped_fn` and destroy
        # manually.
        # TODO(b/144127474): Remove this manual cleanup once
        # tf.wrap_function(...) is fixed.
        resources = [
            output
            for op in wrapped_fn.graph.get_operations()
            if op.type == 'VarHandleOp'
            for output in op.outputs
        ]
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        result_elements = [
            result_fn(result_part)
            for result_part, result_fn in zip(result_parts, result_fns)
        ]
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    # Defaults bind the current `param_fns` and `wrapped_fn` objects eagerly.
    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)

    # pylint: disable=function-redefined
    if must_pin_function_to_cpu:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device('cpu'):
                return old_fn_to_return(x)

    elif device is not None:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device(device.name):
                return old_fn_to_return(x)

    # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)