示例#1
0
 def clip_by_global_norm(delta, clip_norm):
     """Applies `tf.clip_by_global_norm` jointly across every leaf of `delta`.

     Args:
       delta: A container accepted by `anonymous_tuple.from_container`.
       clip_norm: The clipping threshold forwarded to `tf.clip_by_global_norm`.

     Returns:
       A pair `(clipped, global_norm)`: `clipped` has the anonymous-tuple
       structure produced from `delta`, with its leaves replaced by the clipped
       tensors; `global_norm` is the second value returned by
       `tf.clip_by_global_norm`.
     """
     # TODO(b/123092620): Replace anonymous_tuple with tf.nest.
     structured_delta = anonymous_tuple.from_container(delta)
     leaves = anonymous_tuple.flatten(structured_delta)
     clipped_leaves, global_norm = tf.clip_by_global_norm(leaves, clip_norm)
     clipped_delta = anonymous_tuple.pack_sequence_as(structured_delta,
                                                      clipped_leaves)
     return clipped_delta, global_norm
示例#2
0
    def _fn_to_return(arg, param_fns, wrapped_fn):
        """Invokes `wrapped_fn` on the converted parts of `arg` and repacks the outputs.

        Each flattened element of `arg` is passed through the matching entry of
        `param_fns` before the call. Each output of `wrapped_fn` is passed
        through the matching entry of `result_fns`. The converted outputs are
        packed into the structure of `result_type`. Both `result_fns` and
        `result_type` are free variables resolved from an enclosing scope that
        is not part of this block.

        Raises:
          RuntimeError: If `arg` flattens to a different number of elements
            than there are `param_fns`.
        """
        if arg is None:
            converted_args = []
        else:
            flat_args = anonymous_tuple.flatten(arg)
            if len(flat_args) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(flat_args)))
            converted_args = [
                convert(part) for part, convert in zip(flat_args, param_fns)
            ]
        raw_outputs = wrapped_fn(*converted_args)

        # A tf.wrap_function(...) issue (b/144127474) means variables created by
        # tf.import_graph_def(...) inside tf.wrap_function(...) are not
        # destroyed. Collect every variable handle in `wrapped_fn`'s graph and
        # destroy it manually. This happens after `wrapped_fn` has already run.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        var_handles = [
            output
            for op in wrapped_fn.graph.get_operations()
            if op.type == 'VarHandleOp'
            for output in op.outputs
        ]
        if var_handles:
            handle_fetcher = wrapped_fn.prune(feeds={}, fetches=var_handles)
            for live_handle in handle_fetcher():
                tf.raw_ops.DestroyResourceOp(resource=live_handle)

        converted_outputs = [
            convert(part) for part, convert in zip(raw_outputs, result_fns)
        ]
        return anonymous_tuple.pack_sequence_as(result_type, converted_outputs)
示例#3
0
def fetch_value_in_session(sess, value):
    """Fetches `value` in `sess`.

    Args:
      sess: The `tf.compat.v1.Session` in which to perform the fetch.
      value: A Python object of a form analogous to that constructed by the
        function `assemble_result_from_graph`, made of tensors and anonymous
        tuples, or a `tf.data.Dataset`.

    Returns:
      A Python object with structure similar to `value`, but with tensors
      replaced with their values, and data sets replaced with lists of their
      elements. All tensors in the structure are fetched together in a single
      `sess.run()` call. Each data set is drained separately, with one
      `sess.run()` call per element. If `value` is itself a single `tf.string`
      tensor, a `bytes` result is decoded as UTF-8 on Python 3. String tensors
      nested inside a larger structure are returned as fetched.

    Raises:
      ValueError: If `value` is not a `tf.data.Dataset` or not a structure of
        tensors and anonymous tuples.
    """
    # `tf.compat.v1.Session` is the same class as `tf.Session` on TF 1.x.
    # Unlike `tf.Session`, it still exists on TF 2.x, which matches the
    # `tf.compat.v1.data` call below.
    py_typecheck.check_type(sess, tf.compat.v1.Session)
    # TODO(b/113123634): Investigate handling `list`s and `tuple`s of
    # `tf.data.Dataset`s and what the API would look like to support this.
    if isinstance(value, DATASET_REPRESENTATION_TYPES):
        # Build the iterator inside the session's graph so `sess.run` can
        # evaluate `next_element`.
        with sess.graph.as_default():
            iterator = tf.compat.v1.data.make_one_shot_iterator(value)
            next_element = iterator.get_next()
        elements = []
        while True:
            try:
                element = sess.run(next_element)
            except tf.errors.OutOfRangeError:
                # The one-shot iterator signals exhaustion this way.
                break
            elements.append(element)
        return elements

    flattened_value = anonymous_tuple.flatten(value)
    # Datasets are drained recursively and keyed by their position in the
    # flattened structure. Tensors are batched into one `sess.run` call.
    # Presumably `_interleave_dataset_results_and_tensors` merges the two
    # back by position (verify against its definition).
    dataset_results = {}
    flat_tensors = []
    for idx, v in enumerate(flattened_value):
        if isinstance(v, DATASET_REPRESENTATION_TYPES):
            dataset_results[idx] = fetch_value_in_session(sess, v)
        elif tf.is_tensor(v):
            flat_tensors.append(v)
        else:
            raise ValueError('Unsupported value type {}.'.format(str(v)))
    flat_computed_tensors = sess.run(flat_tensors)
    flattened_results = _interleave_dataset_results_and_tensors(
        dataset_results, flat_computed_tensors)

    def _to_unicode(v):
        # On Python 2 `bytes` is `str`, so decoding is limited to Python 3.
        if six.PY3 and isinstance(v, bytes):
            return v.decode('utf-8')
        return v

    # NOTE(review): only a top-level string tensor is decoded; string tensors
    # nested in a structure keep their fetched `bytes` values.
    if tf.is_tensor(value) and value.dtype == tf.string:
        flattened_results = [
            _to_unicode(result) for result in flattened_results
        ]
    return anonymous_tuple.pack_sequence_as(value, flattened_results)
示例#4
0
 def test_pack_sequence_as_fails_non_anonymous_tuple(self):
   """pack_sequence_as is expected to raise for a structure nesting a dict.

   The structure nests a plain `dict` (not an AnonymousTuple) under 'b'.
   Packing a flat list into it must raise a TypeError whose message matches
   'Cannot pack sequence'.
   """
   named_elements = [('a', 10), ('b', {'d': 20}), ('c', 30)]
   structure = anonymous_tuple.AnonymousTuple(named_elements)
   flat_values = [10, 20, 30]
   with self.assertRaisesRegex(TypeError, 'Cannot pack sequence'):
     anonymous_tuple.pack_sequence_as(structure, flat_values)
示例#5
0
 def _fn_to_return(arg, param_fns, wrapped_fn):
     """Invokes `wrapped_fn` on the converted parts of `arg` and repacks the outputs.

     Each flattened element of `arg` is passed through the matching entry of
     `param_fns`. Each output of `wrapped_fn` is passed through the matching
     entry of `result_fns`. The converted outputs are packed into the structure
     of `result_type`. `result_fns` and `result_type` are free variables
     resolved from an enclosing scope that is not part of this block.

     Raises:
       RuntimeError: If `arg` flattens to a different number of elements than
         there are `param_fns`.
     """
     if arg is None:
         converted_args = []
     else:
         flat_args = anonymous_tuple.flatten(arg)
         expected_count = len(param_fns)
         found_count = len(flat_args)
         if expected_count != found_count:
             raise RuntimeError('Expected {} arguments, found {}.'.format(
                 expected_count, found_count))
         converted_args = [
             convert(part) for part, convert in zip(flat_args, param_fns)
         ]
     raw_outputs = wrapped_fn(*converted_args)
     converted_outputs = [
         convert(part) for part, convert in zip(raw_outputs, result_fns)
     ]
     return anonymous_tuple.pack_sequence_as(result_type, converted_outputs)
示例#6
0
 def test_flatten_and_pack_sequence_as(self):
     """Round-trips a nested AnonymousTuple through flatten and pack_sequence_as.

     Flattening yields the leaves in depth-first order. Packing them back into
     the same structure reproduces the original nesting and names.
     """
     leaf_x = anonymous_tuple.AnonymousTuple([('p', 40)])
     leaf_z = anonymous_tuple.AnonymousTuple([('q', 50), ('r', 60)])
     middle = anonymous_tuple.AnonymousTuple([
         ('x', leaf_x),
         ('y', 30),
         ('z', leaf_z),
     ])
     structure = anonymous_tuple.AnonymousTuple([
         ('a', 10),
         ('b', middle),
         ('c', 20),
     ])

     flat = anonymous_tuple.flatten(structure)
     self.assertEqual(flat, [10, 40, 30, 50, 60, 20])

     repacked = anonymous_tuple.pack_sequence_as(structure, flat)
     self.assertEqual(str(repacked), '<a=10,b=<x=<p=40>,y=30,z=<q=50,r=60>>,c=20>')