def clip_by_global_norm(delta, clip_norm):
  """Clips `delta` so that its global norm does not exceed `clip_norm`.

  Args:
    delta: A structure of tensors convertible to an `AnonymousTuple`.
    clip_norm: The maximum allowed global norm.

  Returns:
    A tuple of the clipped structure (packed like `delta`) and the global
    norm of the original values, as computed by `tf.clip_by_global_norm`.
  """
  # TODO(b/123092620): Replace anonymous_tuple with tf.nest.
  delta_tuple = anonymous_tuple.from_container(delta)
  flat_clipped, global_norm = tf.clip_by_global_norm(
      anonymous_tuple.flatten(delta_tuple), clip_norm)
  clipped = anonymous_tuple.pack_sequence_as(delta_tuple, flat_clipped)
  return clipped, global_norm
def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
  # Map each flattened element of `arg` through its matching parameter
  # conversion function before invoking the wrapped function.
  if arg is None:
    param_elements = []
  else:
    arg_parts = anonymous_tuple.flatten(arg)
    if len(arg_parts) != len(param_fns):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          len(param_fns), len(arg_parts)))
    param_elements = [fn(part) for part, fn in zip(arg_parts, param_fns)]
  result_parts = wrapped_fn(*param_elements)
  # There is a tf.wrap_function(...) issue b/144127474 that variables created
  # from tf.import_graph_def(...) inside tf.wrap_function(...) is not
  # destroyed. So get all the variables from `wrapped_fn` and destroy
  # manually.
  # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
  # is fixed.
  var_handles = []
  for op in wrapped_fn.graph.get_operations():
    if op.type == 'VarHandleOp':
      var_handles.extend(op.outputs)
  if var_handles:
    for handle in wrapped_fn.prune(feeds={}, fetches=var_handles)():
      tf.raw_ops.DestroyResourceOp(resource=handle)
  # Convert each raw result through its matching result conversion function
  # and repack into the expected result structure.
  result_elements = [
      fn(part) for part, fn in zip(result_parts, result_fns)
  ]
  return anonymous_tuple.pack_sequence_as(result_type, result_elements)
def fetch_value_in_session(sess, value):
  """Fetches `value` in `session`.

  Args:
    sess: The session in which to perform the fetch (as a single run).
    value: A Python object of a form analogous to that constructed by the
      function `assemble_result_from_graph`, made of tensors and anonymous
      tuples, or a `tf.data.Dataset`.

  Returns:
    A Python object with structure similar to `value`, but with tensors
    replaced with their values, and data sets replaced with lists of their
    elements, all fetched with a single call `session.run()`.

  Raises:
    ValueError: If `value` is not a `tf.data.Dataset` or not a structure of
      tensors and anonymous tuples.
  """
  py_typecheck.check_type(sess, tf.Session)
  # TODO(b/113123634): Investigate handling `list`s and `tuple`s of
  # `tf.data.Dataset`s and what the API would look like to support this.
  if isinstance(value, DATASET_REPRESENTATION_TYPES):
    # Drain the dataset into a Python list by running its one-shot iterator
    # until it raises `OutOfRangeError`.
    with sess.graph.as_default():
      iterator = tf.compat.v1.data.make_one_shot_iterator(value)
      next_element = iterator.get_next()
    elements = []
    while True:
      try:
        elements.append(sess.run(next_element))
      except tf.errors.OutOfRangeError:
        break
    return elements
  else:
    flattened_value = anonymous_tuple.flatten(value)
    # Datasets inside the structure are fetched recursively; their indices
    # are remembered so the results can be interleaved back with the tensor
    # values afterwards.
    dataset_results = {}
    flat_tensors = []
    for idx, v in enumerate(flattened_value):
      if isinstance(v, DATASET_REPRESENTATION_TYPES):
        dataset_results[idx] = fetch_value_in_session(sess, v)
      elif tf.is_tensor(v):
        flat_tensors.append(v)
      else:
        raise ValueError('Unsupported value type {}.'.format(str(v)))
    # All plain tensors are fetched with a single `session.run` call.
    flat_computed_tensors = sess.run(flat_tensors)
    flattened_results = _interleave_dataset_results_and_tensors(
        dataset_results, flat_computed_tensors)

    def _to_unicode(v):
      # On Python 3 `sess.run` yields `bytes` for string tensors; decode
      # them back to `str` for callers.
      if six.PY3 and isinstance(v, bytes):
        return v.decode('utf-8')
      return v

    if tf.is_tensor(value) and value.dtype == tf.string:
      flattened_results = [
          _to_unicode(result) for result in flattened_results
      ]
    return anonymous_tuple.pack_sequence_as(value, flattened_results)
def test_pack_sequence_as_fails_non_anonymous_tuple(self):
  # Packing a plain `list` of values against an `AnonymousTuple` structure
  # must be rejected with a `TypeError`.
  structure = anonymous_tuple.AnonymousTuple([
      ('a', 10),
      ('b', {'d': 20}),
      ('c', 30),
  ])
  flat_values = [10, 20, 30]
  with self.assertRaisesRegex(TypeError, 'Cannot pack sequence'):
    _ = anonymous_tuple.pack_sequence_as(structure, flat_values)
def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
  # Convert each flattened element of `arg` with the corresponding parameter
  # conversion function, then call the wrapped function with those values.
  if arg is None:
    param_elements = []
  else:
    arg_parts = anonymous_tuple.flatten(arg)
    if len(arg_parts) != len(param_fns):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          str(len(param_fns)), str(len(arg_parts))))
    param_elements = [fn(part) for part, fn in zip(arg_parts, param_fns)]
  result_parts = wrapped_fn(*param_elements)
  # Convert each raw result through its matching result conversion function
  # and repack into the expected result structure.
  result_elements = [
      fn(part) for part, fn in zip(result_parts, result_fns)
  ]
  return anonymous_tuple.pack_sequence_as(result_type, result_elements)
def test_flatten_and_pack_sequence_as(self):
  # `flatten` walks the nested structure depth-first; `pack_sequence_as`
  # must reconstruct the identical structure from the flat list.
  nested = anonymous_tuple.AnonymousTuple([
      ('a', 10),
      ('b', anonymous_tuple.AnonymousTuple([
          ('x', anonymous_tuple.AnonymousTuple([('p', 40)])),
          ('y', 30),
          ('z', anonymous_tuple.AnonymousTuple([('q', 50), ('r', 60)])),
      ])),
      ('c', 20),
  ])
  flat = anonymous_tuple.flatten(nested)
  self.assertEqual(flat, [10, 40, 30, 50, 60, 20])
  repacked = anonymous_tuple.pack_sequence_as(nested, flat)
  self.assertEqual(
      str(repacked), '<a=10,b=<x=<p=40>,y=30,z=<q=50,r=60>>,c=20>')