Example #1
 def testIdentityShape(self):
   with self.cached_session():
     shape = [2, 3]
     array_2x3 = [[1, 2, 3], [6, 5, 4]]
     tensor = constant_op.constant(array_2x3)
     self.assertEquals(shape, tensor.get_shape())
     self.assertEquals(shape, array_ops.identity_n([tensor])[0].get_shape())
     self.assertEquals(shape, array_ops.identity_n([array_2x3])[0].get_shape())
Example #2
 def testIdentityShape(self):
   with self.test_session():
     shape = [2, 3]
     array_2x3 = [[1, 2, 3], [6, 5, 4]]
     tensor = constant_op.constant(array_2x3)
     self.assertEquals(shape, tensor.get_shape())
     self.assertEquals(shape, array_ops.identity_n([tensor])[0].get_shape())
     self.assertEquals(shape, array_ops.identity_n([array_2x3])[0].get_shape())
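Both variants of this shape test rely on the same contract: identity_n returns one output per input, in order, with the same shape, dtype, and value as the corresponding input, and the inputs may have mixed dtypes. A minimal standalone sketch using the public tf.identity_n API (TF 2.x assumed; the tensor names are made up for illustration):

import tensorflow as tf

a = tf.constant([[1, 2, 3], [6, 5, 4]])   # shape (2, 3), int32
b = tf.constant([b"x", b"y"])             # shape (2,), string
out_a, out_b = tf.identity_n([a, b])      # one output per input, same order

assert out_a.shape == a.shape and out_a.dtype == a.dtype
assert out_b.shape == b.shape and out_b.dtype == b.dtype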
Example #3
 def prune(self, feeds, fetches):
     flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
     for f in flat_feeds + flat_fetches:
         if not isinstance(f, ops.Tensor):
             raise ValueError("Feeds and fetches must be tensors.")
         if f.graph is not self._func_graph:
             raise ValueError(
                 "Can only prune function whose feeds and fetches "
                 "are from this graph (%s). Tensor %s from graph %s" %
                 (self._func_graph, f, f.graph))
     with self._func_graph.as_default():
         pruned_graph = func_graph.FuncGraph("pruned")
         sink_tensor = array_ops.identity_n(flat_fetches)[0]
     lift_map = lift_to_graph.lift_to_graph(sink_tensor,
                                            pruned_graph,
                                            sources=flat_feeds +
                                            self.graph.internal_captures)
     pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches)
     for external_capture, internal_capture in self.graph.captures.items():
         pruned_graph.captures[external_capture] = lift_map[
             internal_capture]
     pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
     pruned_graph.inputs.extend(pruned_graph.captures.values())
     pruned_graph.structured_outputs = nest.map_structure(
         lambda node: lift_map[node], fetches)
     pruned_fn = WrappedFunction(pruned_graph,
                                 variable_holder=self._variable_holder)
     pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
     pruned_fn._arg_keywords = []  # pylint: disable=protected-access
     return pruned_fn
Example #4
  def testCopiesOfUnsupportedTypesFailGracefully(self):
    """Tests that copies of unsupported types don't crash."""
    test_types = set([
        np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32,
        np.int64, np.float16, np.float32, np.float64,
        dtypes.bfloat16.as_numpy_dtype
    ])
    shape = (10, 10)
    for unsupported_dtype in test_types - self.all_types:
      with self.test_session() as sess:
        with ops.device("CPU"):
          x = array_ops.placeholder(unsupported_dtype, shape)
        with self.test_scope():
          y, = array_ops.identity_n([x])
        with ops.device("CPU"):
          z = array_ops.identity(y)

          inputs = np.random.randint(-100, 100, shape)
          inputs = inputs.astype(unsupported_dtype)
          # Execution should either succeed or raise an InvalidArgumentError,
          # but not crash. Even "unsupported types" may succeed here since some
          # backends (e.g., the CPU backend) are happy to handle buffers of
          # unsupported types, even if they cannot compute with them.
          try:
            sess.run(z, {x: inputs})
          except errors.InvalidArgumentError:
            pass
Example #5
 def outer(y, shp):
     y, shp = control_flow_ops.while_loop_v2(
         lambda *_: True, inner, (y, shp), maximum_iterations=3)
     y, shp = array_ops.identity_n([y, shp])
     return control_flow_ops.while_loop_v2(lambda *_: True,
                                           inner, (y, shp),
                                           maximum_iterations=5)
Example #6
  def testCopiesOfUnsupportedTypesFailGracefully(self):
    """Tests that copies of unsupported types don't crash."""
    test_types = set([
        np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32,
        np.int64, np.float16, np.float32, np.float64,
        dtypes.bfloat16.as_numpy_dtype
    ])
    shape = (10, 10)
    for unsupported_dtype in test_types - self.all_types:
      with self.session() as sess:
        with ops.device("CPU"):
          x = array_ops.placeholder(unsupported_dtype, shape)
        with self.test_scope():
          y, = array_ops.identity_n([x])
        with ops.device("CPU"):
          z = array_ops.identity(y)

          inputs = np.random.randint(-100, 100, shape)
          inputs = inputs.astype(unsupported_dtype)
          # Execution should either succeed or raise an InvalidArgumentError,
          # but not crash. Even "unsupported types" may succeed here since some
          # backends (e.g., the CPU backend) are happy to handle buffers of
          # unsupported types, even if they cannot compute with them.
          try:
            sess.run(z, {x: inputs})
          except errors.InvalidArgumentError:
            pass
Example #7
      def Run(branch, x, fetch_by_name, use_gpu=use_gpu):
        with ops.Graph().as_default() as g:
          @function.Defun(dtypes.float32)
          def two(x):
            return -1, x * 2

          @function.Defun(dtypes.float32)
          def three(x):
            return 0, x * 3

          @function.Defun(dtypes.float32)
          def four(x):
            return 1, x * 4

          outputs = gen_functional_ops.case(branch, input=[x],
                                            Tout=[dtypes.int32, dtypes.float32],
                                            branches=[two, three, four],
                                            name="my_case")

          # `outputs` is the list of output tensors of the Case op. We
          # arbitrarily choose the 0th tensor to get the Case op and set the
          # lowering attribute on it.
          outputs[0].op._set_attr("_lower_using_switch_merge",
                                  attr_value_pb2.AttrValue(b=True))
          outputs = array_ops.identity_n(outputs)
        with self.session(graph=g, use_gpu=use_gpu) as sess:
          return sess.run("my_case:1" if fetch_by_name else outputs[1])
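The Case op exercised above is normally reached through the public control-flow API rather than gen_functional_ops. As a hedged sketch of the same integer-indexed dispatch using tf.switch_case (an illustration, not part of the original test):

import tensorflow as tf

branch = tf.constant(1)              # selects the second branch below
x = tf.constant(10.0)
result = tf.switch_case(branch, branch_fns=[
    lambda: x * 2.0,
    lambda: x * 3.0,
    lambda: x * 4.0,
])
print(result)  # tf.Tensor(30.0, shape=(), dtype=float32)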
Example #8
 def testInt32String_6(self):
   with self.test_session() as sess:
     [value0, value1] = sess.run(
         array_ops.identity_n([[1, 2, 3, 4, 5, 6],
                               [b"a", b"b", b"C", b"d", b"E", b"f", b"g"]]))
   self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), value0)
   self.assertAllEqual(
       np.array([b"a", b"b", b"C", b"d", b"E", b"f", b"g"]), value1)
Example #9
 def testInt32String_6(self):
   with self.cached_session() as sess:
     [value0, value1] = sess.run(
         array_ops.identity_n([[1, 2, 3, 4, 5, 6],
                               [b"a", b"b", b"C", b"d", b"E", b"f", b"g"]]))
   self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), value0)
   self.assertAllEqual(
       np.array([b"a", b"b", b"C", b"d", b"E", b"f", b"g"]), value1)
Example #10
    def testInt32String_6(self):
        value0, value1 = self.evaluate(
            array_ops.identity_n([[1, 2, 3, 4, 5, 6],
                                  [b"a", b"b", b"C", b"d", b"E", b"f", b"g"]]))

        self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), value0)
        self.assertAllEqual(
            np.array([b"a", b"b", b"C", b"d", b"E", b"f", b"g"]), value1)
Example #11
    def prune(self, feeds, fetches):
        flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
        for f in flat_feeds:
            if not isinstance(f, ops.Tensor):
                raise ValueError("Feeds must be tensors.")
        tensor_fetches = []
        operation_fetches = []
        for f in flat_fetches:
            if isinstance(f, ops.Tensor):
                tensor_fetches.append(f)
            elif isinstance(f, ops.Operation):
                operation_fetches.append(f)
            else:
                raise ValueError("Fetches must be tensors or operations.")
        for f in flat_feeds + flat_fetches:
            if f.graph is not self._func_graph:
                raise ValueError(
                    "Can only prune function whose feeds and fetches "
                    "are from this graph (%s). Tensor %s from graph %s" %
                    (self._func_graph, f, f.graph))
        with self._func_graph.as_default():
            pruned_graph = func_graph.FuncGraph("pruned")
            with ops.control_dependencies(operation_fetches):
                if tensor_fetches:
                    identity_fetches = array_ops.identity_n(tensor_fetches)
                    sink_tensor = identity_fetches[0]
                else:
                    identity_fetches = []
                    sink_tensor = control_flow_ops.no_op()
        lift_map = lift_to_graph.lift_to_graph(sink_tensor,
                                               pruned_graph,
                                               sources=flat_feeds +
                                               self.graph.internal_captures)
        for original_fetch, identity_fetch in zip(tensor_fetches,
                                                  identity_fetches):
            lift_map[original_fetch] = lift_map[identity_fetch]
        pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches
                                    if isinstance(x, ops.Tensor))
        for external_capture, internal_capture in self.graph.captures.items():
            pruned_graph.captures[external_capture] = lift_map[
                internal_capture]
        pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
        pruned_graph.inputs.extend(pruned_graph.captures.values())

        def _structured_output_mapping(fetched):
            lifted = lift_map[fetched]
            if isinstance(lifted, ops.Operation):
                return None
            return lifted

        pruned_graph.structured_outputs = nest.map_structure(
            _structured_output_mapping, fetches)
        pruned_fn = WrappedFunction(pruned_graph,
                                    variable_holder=self._variable_holder)
        pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
        pruned_fn._arg_keywords = []  # pylint: disable=protected-access
        return pruned_fn
Example #12
  def testFastpathExecute_IdentityNCorrectResponse(self):
    ctx = context.context()
    a_2_by_2 = random_ops.random_uniform((2, 2))
    b_2_by_2 = random_ops.random_uniform((2, 2))

    self.assertAllClose(
        array_ops.identity_n([a_2_by_2, b_2_by_2]),
        pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
                                                 "IdentityN", None, None,
                                                 [a_2_by_2, b_2_by_2]))
Example #13
  def testFastpathExecute_IdentityNCorrectResponse(self):
    ctx = context.context()
    a_2_by_2 = random_ops.random_uniform((2, 2))
    b_2_by_2 = random_ops.random_uniform((2, 2))

    self.assertAllClose(
        array_ops.identity_n([a_2_by_2, b_2_by_2]),
        pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
                                                 "IdentityN", None, None,
                                                 [a_2_by_2, b_2_by_2]))
Example #14
def fwd_gradients(ys, xs, grad_xs=None, assert_unused=False):
    """Computes forward-mode derivatives.

  This is accomplished in pure-python using tensorflow's existing (reverse-mode)
  gradients. There is additional overhead on graph construction, but runtime
  performance should be equal to a manual implementation [citation needed].

  See https://j-towns.github.io/2017/06/12/A-new-trick.html and
  https://github.com/HIPS/autograd/pull/175 for the original discussion of this
  method, and https://github.com/renmengye/tensorflow-forward-ad for a "direct"
  implementation.

  Args:
    ys: A list of tensors.
    xs: A list of tensors.
    grad_xs: An optional list of tensors. If provided, must have the same length
      and shapes compatible with xs.
    assert_unused: Add assertions that intermediate values are not computed.
  Returns:
    A list of tensors of the same shapes as ys. The directional derivatives of
    ys with respect to xs in the direction grad_xs. Leaving grad_xs unspecified
    is equivalent to passing in 1s for each x in xs.
  """
    # This version of forward-mode autodiff is based on code by Tim Cooijmans
    # and handles list arguments and certain special cases such as when the
    # ys doesn't depend on one or more of the xs, and when tf.IndexedSlices are
    # generated by the first tf.gradients call.

    us = [array_ops.zeros_like(y) + float('nan') for y in ys]

    dydxs = gradients(ys, xs, grad_ys=us)

    # deal with strange types that tf.gradients returns but can't deal with
    dydxs = [
        ops.convert_to_tensor(dydx)
        if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs
    ]

    if assert_unused:
        with ops.control_dependencies(dydxs):
            assert_unused = control_flow_ops.Assert(False, [1],
                                                    name='fwd_gradients')
        with ops.control_dependencies([assert_unused]):
            dydxs = array_ops.identity_n(dydxs)

    dydxs = [
        array_ops.zeros_like(x) if dydx is None else dydx
        for x, dydx in zip(xs, dydxs)
    ]
    for x, dydx in zip(xs, dydxs):
        dydx.set_shape(x.shape)

    dysdx = gradients(dydxs, us, grad_ys=grad_xs)

    return dysdx
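The double-gradient trick described in the docstring can be seen in isolation: the first tf.gradients call builds g = J^T u symbolically, where u is a dummy cotangent the result is linear in, and differentiating g with respect to u then yields the forward-mode product J v. A minimal sketch, assuming the TF1-style graph API via tf.compat.v1 (an illustration, not the library function above):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.constant([0.0, 1.0, 2.0])
y = tf.sin(x)                           # elementwise, so J = diag(cos(x))
u = tf.zeros_like(y)                    # dummy cotangent; only its shape matters
g = tf.gradients(y, x, grad_ys=u)[0]    # g = J^T u, built symbolically
v = tf.ones_like(x)                     # direction of the directional derivative
jvp = tf.gradients(g, u, grad_ys=v)[0]  # d(v . g)/du = J v

with tf.Session() as sess:
    print(sess.run(jvp))                # approx [cos(0), cos(1), cos(2)]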
Example #15
    def testFastpathExecute_IdentityNCorrectResponse(self):
        ctx = context.context()
        ctx.ensure_initialized()

        a_2_by_2 = random_ops.random_uniform((2, 2))
        b_2_by_2 = random_ops.random_uniform((2, 2))

        self.assertAllClose(
            array_ops.identity_n([a_2_by_2, b_2_by_2]),
            pywrap_tfe.TFE_Py_FastPathExecute(ctx, "IdentityN", None,
                                              [a_2_by_2, b_2_by_2]))
Example #16
 def testInt32_shapes(self):
   with self.cached_session() as sess:
     inp0 = constant_op.constant([10, 20, 30, 40, 50, 60], shape=[2, 3])
     inp1 = constant_op.constant([11, 21, 31, 41, 51, 61], shape=[3, 2])
     inp2 = constant_op.constant(
         [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], shape=[5, 3])
     [value0, value1,
      value2] = sess.run(array_ops.identity_n([inp0, inp1, inp2]))
   self.assertAllEqual(np.array([[10, 20, 30], [40, 50, 60]]), value0)
   self.assertAllEqual(np.array([[11, 21], [31, 41], [51, 61]]), value1)
   self.assertAllEqual(
       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]),
       value2)
Example #17
    def testInt32_shapes(self):
        inp0 = constant_op.constant([10, 20, 30, 40, 50, 60], shape=[2, 3])
        inp1 = constant_op.constant([11, 21, 31, 41, 51, 61], shape=[3, 2])
        inp2 = constant_op.constant(
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], shape=[5, 3])
        value0, value1, value2 = self.evaluate(
            array_ops.identity_n([inp0, inp1, inp2]))

        self.assertAllEqual(np.array([[10, 20, 30], [40, 50, 60]]), value0)
        self.assertAllEqual(np.array([[11, 21], [31, 41], [51, 61]]), value1)
        self.assertAllEqual(
            np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
                      [13, 14, 15]]), value2)
Example #18
def fwd_gradients(ys, xs, grad_xs=None, assert_unused=False):
  """Computes forward-mode derivatives.

  This is accomplished in pure-python using tensorflow's existing (reverse-mode)
  gradients. There is additional overhead on graph construction, but runtime
  performance should be equal to a manual implementation [citation needed].

  See https://j-towns.github.io/2017/06/12/A-new-trick.html and
  https://github.com/HIPS/autograd/pull/175 for the original discussion of this
  method, and https://github.com/renmengye/tensorflow-forward-ad for a "direct"
  implementation.

  Args:
    ys: A list of tensors.
    xs: A list of tensors.
    grad_xs: An optional list of tensors. If provided, must have the same length
      and shapes compatible with xs.
    assert_unused: Add assertions that intermediate values are not computed.
  Returns:
    A list of tensors of the same shapes as ys. The directional derivatives of
    ys with respect to xs in the direction grad_xs. Leaving grad_xs unspecified
    is equivalent to passing in 1s for each x in xs.
  """
  # This version of forward-mode autodiff is based on code by Tim Cooijmans
  # and handles list arguments and certain special cases such as when the
  # ys doesn't depend on one or more of the xs, and when tf.IndexedSlices are
  # generated by the first tf.gradients call.

  us = [array_ops.zeros_like(y) + float('nan') for y in ys]

  dydxs = gradients(ys, xs, grad_ys=us)

  # deal with strange types that tf.gradients returns but can't deal with
  dydxs = [ops.convert_to_tensor(dydx) if isinstance(dydx, ops.IndexedSlices)
           else dydx for dydx in dydxs]

  if assert_unused:
    with ops.control_dependencies(dydxs):
      assert_unused = control_flow_ops.Assert(False, [1], name='fwd_gradients')
    with ops.control_dependencies([assert_unused]):
      dydxs = array_ops.identity_n(dydxs)

  dydxs = [array_ops.zeros_like(x) if dydx is None else dydx
           for x, dydx in zip(xs, dydxs)]
  for x, dydx in zip(xs, dydxs):
    dydx.set_shape(x.shape)

  dysdx = gradients(dydxs, us, grad_ys=grad_xs)

  return dysdx
Example #19
    def decorated(*args, **kwargs):
        """Decorated function with custom gradient."""
        if context.in_graph_mode():
            if kwargs:
                raise ValueError(
                    "custom_gradient in graph mode doesn't support keyword arguments."
                )
            name = "CustomGradient-%s" % tf_ops.uid()
            args = [tf_ops.convert_to_tensor(x) for x in args]
            result, grad_fn = f(*args)
            flat_result = nest.flatten(result)
            all_tensors = flat_result + args

            @tf_ops.RegisterGradient(name)
            def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
                gradients = nest.flatten(
                    grad_fn(*result_grads[:len(flat_result)]))
                # Need to return one value per input to the IdentityN, so pad the
                # gradients of the inputs of the custom_gradient function with the
                # gradients of the outputs as well.
                return ([None] * len(flat_result)) + gradients

            with tf_ops.get_default_graph().gradient_override_map(
                {"IdentityN": name}):
                all_tensors = array_ops.identity_n(all_tensors)
            return nest.pack_sequence_as(
                structure=result, flat_sequence=all_tensors[:len(flat_result)])

        input_tensors = []
        for x in args:
            if isinstance(x, tf_ops.Tensor):
                input_tensors.append(x)
            if isinstance(x, resource_variable_ops.ResourceVariable):
                input_tensors.append(x.read_value())

        with tape.stop_recording():
            result, grad_fn = f(*args, **kwargs)

        # TODO(apassos): naive uses of custom_gradient will not get the correct
        # second derivative this way if they capture any output tensors. Change the
        # signature of custom_gradient.
        def actual_grad_fn(*outputs):
            return nest.flatten(grad_fn(*outputs))

        flat_result = nest.flatten(result)
        tape.record_operation(f.__name__, flat_result, input_tensors,
                              actual_grad_fn)
        flat_result = list(flat_result)
        return result
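From the caller's side this machinery is hidden behind the public tf.custom_gradient decorator; in graph mode the decorator re-emits the outputs and inputs through a single IdentityN op and overrides that op's gradient with the user's grad_fn, which is why the result gradients are padded with None above. The standard log1pexp example from the TensorFlow documentation shows the intended usage:

import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(upstream):
        # Numerically stable gradient of log(1 + exp(x)).
        return upstream * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad

x = tf.constant(100.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = log1pexp(x)
print(tape.gradient(y, x))  # 1.0, where the naive formula would give nan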
Example #20
  def testFastpathExecute_IdentityNTapeWrite(self):
    ctx = context.context()
    a_2_by_2 = random_ops.random_uniform((2, 2))
    b_2_by_2 = random_ops.random_uniform((2, 2))

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(a_2_by_2)
      tape.watch(b_2_by_2)
      z1 = pywrap_tensorflow.TFE_Py_FastPathExecute(
          ctx._handle, ctx.device_name, "IdentityN", None, None,
          [a_2_by_2, b_2_by_2])
      z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
    dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
    dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
    self.assertAllEqual(dz1_dy.numpy(), dz2_dy.numpy())
Example #21
  def testFastpathExecute_IdentityNTapeWrite(self):
    ctx = context.context()
    a_2_by_2 = random_ops.random_uniform((2, 2))
    b_2_by_2 = random_ops.random_uniform((2, 2))

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(a_2_by_2)
      tape.watch(b_2_by_2)
      z1 = pywrap_tensorflow.TFE_Py_FastPathExecute(
          ctx._handle, ctx.device_name, "IdentityN", None, None,
          [a_2_by_2, b_2_by_2])
      z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
    dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
    dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
    self.assertAllEqual(dz1_dy.numpy(), dz2_dy.numpy())
Example #22
  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    if context.in_graph_mode():
      if kwargs:
        raise ValueError(
            "custom_gradient in graph mode doesn't support keyword arguments.")
      name = "CustomGradient-%s" % tf_ops.uid()
      args = [tf_ops.convert_to_tensor(x) for x in args]
      result, grad_fn = f(*args)
      flat_result = nest.flatten(result)
      all_tensors = flat_result + args

      @tf_ops.RegisterGradient(name)
      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        return ([None] * len(flat_result)) + gradients

      with tf_ops.get_default_graph().gradient_override_map(
          {"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)
      return nest.pack_sequence_as(
          structure=result, flat_sequence=all_tensors[:len(flat_result)])

    input_tensors = [x for x in args
                     if isinstance(x, tf_ops.Tensor)]

    with tape.stop_recording():
      result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    tape.record_operation(
        f.__name__,
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    flat_result = list(flat_result)
    return result
Example #23
  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    if context.in_graph_mode():
      if kwargs:
        raise ValueError(
            "custom_gradient in graph mode doesn't support keyword arguments.")
      name = "CustomGradient-%s" % tf_ops.uid()
      args = [tf_ops.convert_to_tensor(x) for x in args]
      result, grad_fn = f(*args)
      flat_result = nest.flatten(result)
      all_tensors = flat_result + args

      @tf_ops.RegisterGradient(name)
      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        return ([None] * len(flat_result)) + gradients

      with tf_ops.get_default_graph().gradient_override_map(
          {"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)
      return nest.pack_sequence_as(
          structure=result, flat_sequence=all_tensors[:len(flat_result)])

    input_tensors = [tf_ops.convert_to_tensor(x) for x in args]

    result, grad_fn = f(*args, **kwargs)
    flat_result = nest.flatten(result)
    # TODO(apassos) consider removing the identity below.
    flat_result = [gen_array_ops.identity(x) for x in flat_result]

    def actual_grad_fn(*outputs):
      return nest.flatten(grad_fn(*outputs))

    tape.record_operation(
        f.__name__,
        flat_result,
        input_tensors,
        actual_grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(result, flat_result)
Example #24
    def decorated(*args, **kwargs):
        """Decorated function with custom gradient."""
        if context.in_graph_mode():
            if kwargs:
                raise ValueError(
                    "custom_gradient in graph mode doesn't support keyword arguments."
                )
            name = "CustomGradient-%s" % tf_ops.uid()
            args = [tf_ops.convert_to_tensor(x) for x in args]
            result, grad_fn = f(*args)
            flat_result = nest.flatten(result)
            all_tensors = flat_result + args

            @tf_ops.RegisterGradient(name)
            def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
                gradients = nest.flatten(
                    grad_fn(*result_grads[:len(flat_result)]))
                # Need to return one value per input to the IdentityN, so pad the
                # gradients of the inputs of the custom_gradient function with the
                # gradients of the outputs as well.
                return ([None] * len(flat_result)) + gradients

            with tf_ops.get_default_graph().gradient_override_map(
                {"IdentityN": name}):
                all_tensors = array_ops.identity_n(all_tensors)
            return nest.pack_sequence_as(
                structure=result, flat_sequence=all_tensors[:len(flat_result)])

        input_tensors = [tf_ops.convert_to_tensor(x) for x in args]

        with tape.stop_recording():
            result, grad_fn = f(*args, **kwargs)
            flat_result = nest.flatten(result)
            # TODO(apassos) consider removing the identity below.
            flat_result = [gen_array_ops.identity(x) for x in flat_result]

        def actual_grad_fn(*outputs):
            return nest.flatten(grad_fn(*outputs))

        tape.record_operation(f.__name__, flat_result, input_tensors,
                              actual_grad_fn)
        flat_result = list(flat_result)
        return nest.pack_sequence_as(result, flat_result)
Example #25
def fwd_gradients(ys, xs, grad_xs=None, assert_unused=False):
    us = [array_ops.zeros_like(y) + float('nan') for y in ys]
    dydxs = tf.gradients(ys, xs, grad_ys=us)
    dydxs = [ops.convert_to_tensor(dydx) if isinstance(dydx, ops.IndexedSlices)
             else dydx for dydx in dydxs]
    if assert_unused:
        with ops.control_dependencies(dydxs):
            assert_unused = control_flow_ops.Assert(False, [1], name='fwd_gradients')
        with ops.control_dependencies([assert_unused]):
            dydxs = array_ops.identity_n(dydxs)

    dydxs = [array_ops.zeros_like(x) if dydx is None else dydx
             for x, dydx in zip(xs, dydxs)]
    for x, dydx in zip(xs, dydxs):
        dydx.set_shape(x.shape)

    dysdx = tf.gradients(dydxs, us, grad_ys=grad_xs)

    return dysdx
Example #26
 def prune(self, feeds, fetches):
   flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
   for f in flat_feeds + flat_fetches:
     if not isinstance(f, ops.Tensor):
       raise ValueError("Feeds and fetches must be tensors.")
     if f.graph is not self._func_graph:
       raise ValueError(
           "Can only prune function whose feeds and fetches "
           "are from this graph (%s). Tensor %s from graph %s" % (
               self._func_graph, f, f.graph))
   with self._func_graph.as_default():
     pruned_graph = func_graph.FuncGraph("pruned")
     sink_tensor = array_ops.identity_n(flat_fetches)[0]
   lift_map = lift_to_graph.lift_to_graph(
       sink_tensor, pruned_graph, sources=flat_feeds)
   pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches)
   pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
   pruned_fn = WrappedFunction(
       pruned_graph, variable_holder=self._variable_holder)
   pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
   pruned_fn._arg_keywords = []  # pylint: disable=protected-access
   return pruned_fn
Example #27
def Test():

    x = tf.constant(1.0, shape=(5, 3))
    y = tf.constant(1.0, shape=(3, 5))

    s = tf.matmul(x, y)
    t = tf.matmul(y, x)
    [t, s] = array_ops.identity_n([t, s])

    tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)
    tensor_info_s = tf.compat.v1.saved_model.utils.build_tensor_info(s)
    tensor_info_t = tf.compat.v1.saved_model.utils.build_tensor_info(t)

    return {
        'key':
        (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
            inputs={
                'x': tensor_info_x,
                'y': tensor_info_y
            },
            outputs={
                's': tensor_info_s,
                't': tensor_info_t
            },
            method_name='some_function')),
        'key2':
        (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
            inputs={
                'a': tensor_info_y,
                'b': tensor_info_x,
            },
            outputs={
                'c': tensor_info_t,
                'd': tensor_info_s,
            },
            method_name='reverse_arguments'))
    }
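As a hedged follow-up (not part of the example itself), a signature map like the one returned by Test() would typically be attached to a SavedModel through the TF1 builder API; the export path below is made up for illustration, and the example's own imports are assumed to be in scope:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

signature_def_map = Test()  # builds the graph and the two signature defs
with tf.Session() as sess:
    builder = tf.saved_model.Builder("/tmp/identity_n_example")
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map=signature_def_map)
    builder.save()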
Example #28
      def Run(branch, x):
        @function.Defun(dtypes.float32)
        def two(x):
          return -1, x * 2

        @function.Defun(dtypes.float32)
        def three(x):
          return 0, x * 3

        @function.Defun(dtypes.float32)
        def four(x):
          return 1, x * 4

        outputs = gen_functional_ops.case(branch, input=[x],
                                          Tout=[dtypes.int32, dtypes.float32],
                                          branches=[two, three, four])

        # `outputs` is the list of output tensors of the Case op. We
        # arbitrarily choose the 0th tensor to get the Case op and set the
        # lowering attribute on it.
        outputs[0].op._set_attr("_lower_using_switch_merge",
                                attr_value_pb2.AttrValue(b=True))
        outputs = array_ops.identity_n(outputs)
        return outputs[1]
Example #29
def _graph_mode_decorator(f, *args, **kwargs):
    """Implement custom gradient decorator for graph mode."""
    # TODO(rsepassi): Add support for kwargs
    if kwargs:
        raise ValueError(
            "The custom_gradient decorator currently supports keywords "
            "arguments only when eager execution is enabled.")
    name = "CustomGradient-%s" % ops.uid()
    args = [ops.convert_to_tensor(x) for x in args]

    # Checking global and local variables attempts to ensure that no non-resource
    # Variables are added to the graph.
    current_var_scope = variable_scope.get_variable_scope()
    before_vars = set(current_var_scope.global_variables() +
                      current_var_scope.local_variables())
    with backprop.GradientTape() as tape:
        result, grad_fn = f(*args)
    after_vars = set(current_var_scope.global_variables() +
                     current_var_scope.local_variables())
    new_vars = after_vars - before_vars
    for v in new_vars:
        if not isinstance(v, resource_variable_ops.ResourceVariable):
            raise TypeError(
                "All variables used by a function wrapped with @custom_gradient must "
                "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
                "with `use_resource=False`.")
    # The variables that grad_fn needs to return gradients for are the set of
    # variables used that are *not* part of the inputs.
    variables = list(set(tape.watched_variables()) - set(args))
    grad_argspec = tf_inspect.getfullargspec(grad_fn)
    variables_in_signature = ("variables" in grad_argspec.args
                              or grad_argspec.varkw)
    if variables and not variables_in_signature:
        raise TypeError("If using @custom_gradient with a function that "
                        "uses variables, then grad_fn must accept a keyword "
                        "argument 'variables'.")
    if variables_in_signature and not variables:
        # User seems to intend to use variables but none were captured.
        if not variable_scope.get_variable_scope().use_resource:
            raise TypeError(
                "If using @custom_gradient with a function that "
                "uses variables, the enclosing variable scope must "
                "have use_resource=True.")
        else:
            logging.warn(
                "@custom_gradient grad_fn has 'variables' in signature, but "
                "no ResourceVariables were used on the forward pass.")
    flat_result = nest.flatten(result)
    all_tensors = flat_result + args + variables

    @ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        """Custom grad fn wrapper."""
        result_grads = result_grads[:len(flat_result)]
        if variables:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")
        else:
            input_grads = grad_fn(*result_grads)
            variable_grads = []

        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        input_grads = nest.flatten(input_grads)
        return ([None] * len(flat_result)) + input_grads + variable_grads

    original_tensors = all_tensors
    with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)
    for ot, t in zip(original_tensors, all_tensors):
        copy_handle_data(ot, t)
    return nest.pack_sequence_as(structure=result,
                                 flat_sequence=all_tensors[:len(flat_result)])
Example #30
def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)
  args = nest.flatten(args)
  after_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # It is possible for the caller to pass in an input that is from a different
  # graph. Even though this is not valid we filter these out if they are not
  # from the output graph to make it easier for some code to migrate to custom
  # gradients.
  inputs = nest.flatten(args)
  outputs = nest.flatten(result)
  graphs = {getattr(o, "graph", None) for o in outputs}
  # Not all results may be tensors. However, we want to ensure that all outputs
  # are from the same graph and use that to filter the inputs.
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError("All graph outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_inputs = []
    for i in inputs:
      if i.graph != output_graph:
        logging.warn("%s does not belong to output graph %s", i, output_graph)
      else:
        filtered_inputs.append(i)

    inputs = filtered_inputs

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset([
      v.ref() for v in variable_watcher.watched_variables()
  ]) - frozenset(v.ref() for v in inputs)
  variables_in_subgraph = frozenset([
      v.ref()
      for v in get_dependent_variables(input_ops=inputs, output_ops=outputs)
  ])
  variables = list(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                 "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)

  all_tensors = flat_result + inputs + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])
Example #31
    def prune(self, feeds, fetches, name=None):
        name = name or "pruned"
        flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
        for f in flat_feeds:
            if not isinstance(f, ops.Tensor):
                raise ValueError("Feeds must be tensors.")

        # Ignoring all feeds that are captures allows prune to be called
        # using wrapped_func.inputs even when it uses variables
        internal_captures = self.graph.internal_captures
        flat_feeds = [f for f in flat_feeds if f not in internal_captures]

        tensor_fetches = []
        operation_fetches = []
        for f in flat_fetches:
            if isinstance(f, ops.Tensor):
                tensor_fetches.append(f)
            elif isinstance(f, ops.Operation):
                operation_fetches.append(f)
            else:
                raise ValueError("Fetches must be tensors or operations.")
        for f in flat_feeds + flat_fetches:
            if f.graph is not self._func_graph:
                raise ValueError(
                    "Can only prune function whose feeds and fetches "
                    "are from this graph (%s). Tensor %s from graph %s" %
                    (self._func_graph, f, f.graph))
        with self._func_graph.as_default():
            pruned_graph = func_graph.FuncGraph(name)
            with ops.control_dependencies(operation_fetches):
                if tensor_fetches:
                    identity_fetches = array_ops.identity_n(tensor_fetches)
                    sink_tensor = identity_fetches[0]
                else:
                    identity_fetches = []
                    sink_tensor = array_ops.zeros([])
        lift_map = lift_to_graph.lift_to_graph([sink_tensor],
                                               pruned_graph,
                                               sources=flat_feeds +
                                               internal_captures)
        for original_fetch, identity_fetch in zip(tensor_fetches,
                                                  identity_fetches):
            lift_map[original_fetch] = lift_map[identity_fetch]
        pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches
                                    if isinstance(x, ops.Tensor))
        pruned_graph.control_outputs.extend(
            [lift_map[operation] for operation in operation_fetches])
        if not tensor_fetches:
            pruned_graph.outputs.append(lift_map[sink_tensor])
        for external_capture, internal_capture in self.graph.captures.items():
            pruned_graph.captures[external_capture] = lift_map[
                internal_capture]
        pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
        pruned_graph.inputs.extend(pruned_graph.captures.values())

        pruned_graph.variables = self.graph.variables

        def _structured_output_mapping(fetched):
            lifted = lift_map[fetched]
            if isinstance(lifted, ops.Operation):
                return None
            return lifted

        pruned_graph.structured_outputs = nest.map_structure(
            _structured_output_mapping, fetches)
        pruned_fn = WrappedFunction(pruned_graph,
                                    variable_holder=self._variable_holder)
        pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
        pruned_fn._arg_keywords = []  # pylint: disable=protected-access
        return pruned_fn
Example #32
  def prune(self, feeds, fetches):
    flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("Feeds must be tensors.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables
    internal_captures = self.graph.internal_captures
    flat_feeds = [f for f in flat_feeds
                  if f not in internal_captures]

    tensor_fetches = []
    operation_fetches = []
    for f in flat_fetches:
      if isinstance(f, ops.Tensor):
        tensor_fetches.append(f)
      elif isinstance(f, ops.Operation):
        operation_fetches.append(f)
      else:
        raise ValueError("Fetches must be tensors or operations.")
    for f in flat_feeds + flat_fetches:
      if f.graph is not self._func_graph:
        raise ValueError(
            "Can only prune function whose feeds and fetches "
            "are from this graph (%s). Tensor %s from graph %s" % (
                self._func_graph, f, f.graph))
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph("pruned")
      with ops.control_dependencies(operation_fetches):
        if tensor_fetches:
          identity_fetches = array_ops.identity_n(tensor_fetches)
          sink_tensor = identity_fetches[0]
        else:
          identity_fetches = []
          sink_tensor = array_ops.zeros([])
    lift_map = lift_to_graph.lift_to_graph(
        [sink_tensor], pruned_graph, sources=flat_feeds + internal_captures)
    for original_fetch, identity_fetch in zip(
        tensor_fetches, identity_fetches):
      lift_map[original_fetch] = lift_map[identity_fetch]
    pruned_graph.outputs.extend(
        lift_map[x] for x in flat_fetches if isinstance(x, ops.Tensor))
    if not tensor_fetches:
      pruned_graph.outputs.append(lift_map[sink_tensor])
    for external_capture, internal_capture in self.graph.captures.items():
      pruned_graph.captures[external_capture] = lift_map[internal_capture]
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    pruned_graph.inputs.extend(pruned_graph.captures.values())

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches)
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    pruned_fn._arg_keywords = []  # pylint: disable=protected-access
    return pruned_fn
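These prune variants are reached in practice through tf.compat.v1.wrap_function, which returns a WrappedFunction. A minimal, hedged usage sketch (the function f and the tensor name "z" are made up for illustration): wrap a TF1-style function, then prune a callable that feeds the original placeholder and fetches an intermediate tensor.

import tensorflow as tf

def f(x):
    y = x * 2.0
    return tf.identity(y + 1.0, name="z")

wrapped = tf.compat.v1.wrap_function(f, [tf.TensorSpec([], tf.float32)])
pruned = wrapped.prune(
    feeds=wrapped.inputs[0],
    fetches=wrapped.graph.get_tensor_by_name("z:0"))
print(pruned(tf.constant(3.0)))  # tf.Tensor(7.0, shape=(), dtype=float32)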
Example #33
 def testString(self):
     source = [b"A", b"b", b"C", b"d", b"E", b"f"]
     [value] = self.evaluate(array_ops.identity_n([source]))
     self.assertAllEqual(source, value)
Example #34
 def testString(self):
   source = [b"A", b"b", b"C", b"d", b"E", b"f"]
   with self.test_session() as sess:
     [value] = sess.run(array_ops.identity_n([source]))
   self.assertAllEqual(source, value)
Example #35
 def testMixedTypeListInputEagerFallbackDifferentArity(self):
     array_ops.identity_n([1, 1])
     array_ops.identity_n([1, 1, 1])
Example #36
 def testMixedTypeListInputFastPathDifferentArity(self):
     # This tests that the FunctionDef cache key contains the number of args.
     array_ops.identity_n([self._m_2_by_2, self._m_2_by_2])
     array_ops.identity_n([self._m_2_by_2, self._m_2_by_2, self._m_2_by_2])
Example #37
def _graph_mode_decorator(f, args, kwargs):
    """Implement custom gradient decorator for graph mode."""
    # TODO(rsepassi): Add support for kwargs
    if kwargs:
        raise ValueError(
            "The custom_gradient decorator currently supports keywords "
            "arguments only when eager execution is enabled.")
    name = "CustomGradient-%s" % ops.uid()

    default_graph = ops.get_default_graph()

    def convert_arg(x):
        x = ops.convert_to_tensor(x)
        # If graph building, be sure to capture all inputs
        if default_graph.building_function and x.graph != default_graph:
            x = default_graph.capture(x)
        return x

    args = nest.map_structure(convert_arg, args)

    # Checking global and local variables attempts to ensure that no non-resource
    # Variables are added to the graph.
    current_var_scope = variable_scope.get_variable_scope()
    before_vars = set([
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    ])
    with tape_lib.VariableWatcher() as variable_watcher:
        result, grad_fn = f(*args)

    args = nest.flatten(args)
    flat_result = nest.flatten(result)
    flat_result_len = len(flat_result)

    after_vars = set([
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    ])
    new_vars = after_vars - before_vars
    new_vars_list = [v.deref() for v in new_vars]
    for v in new_vars_list:
        if not resource_variable_ops.is_resource_variable(v):
            raise TypeError(
                "All variables used by a function wrapped with @custom_gradient must "
                "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
                "with `use_resource=False`.")

    # The variables that grad_fn needs to return gradients for are the set of
    # variables used that are *not* part of the inputs.
    variables_in_tape = frozenset(
        [v.ref() for v in variable_watcher.watched_variables()])
    variables_in_subgraph = frozenset([
        v.ref() for v in _get_dependent_variables(input_ops=args,
                                                  output_ops=flat_result)
    ])
    variables = list(
        [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

    grad_argspec = tf_inspect.getfullargspec(grad_fn)
    variables_in_signature = ("variables" in grad_argspec.args
                              or "variables" in grad_argspec.kwonlyargs
                              or grad_argspec.varkw)
    if variables and not variables_in_signature:
        raise TypeError(
            "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
            "since function uses variables: {}".format(variables))
    if variables_in_signature and not variables:
        # User seems to intend to use variables but none were captured.
        logging.warn(
            "@custom_gradient grad_fn has 'variables' in signature, but "
            "no ResourceVariables were used on the forward pass.")

    all_tensors = flat_result + args + variables

    def tape_grad_fn(*result_grads):
        """Custom grad fn wrapper."""
        result_grads = result_grads[:flat_result_len]
        if variables:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")
        else:
            input_grads = grad_fn(*result_grads)
            variable_grads = []

        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        input_grads = nest.flatten(input_grads)
        return ([None] * flat_result_len) + input_grads + variable_grads

    @ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        """Custom grad fn wrapper."""
        return tape_grad_fn(*result_grads)

    original_tensors = all_tensors
    with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)

    original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

    # Propagate handle data for happier shape inference for resource variables.
    for i, t in enumerate(original_tensors):
        if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
            all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
    tape_lib.record_operation(f.__name__, all_tensors, original_tensors,
                              tape_grad_fn)
    for ot, t in zip(original_tensors, all_tensors):
        copy_handle_data(ot, t)
    return nest.pack_sequence_as(structure=result,
                                 flat_sequence=all_tensors[:flat_result_len])
Example #38
 def loop_fn(i):
   return array_ops.identity_n([x, array_ops.gather(x, i)])
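This loop_fn follows the parallel_for test pattern: it receives the loop index i and returns the per-iteration values, which pfor then vectorizes across iterations. A rough sketch, assuming the internal (non-public) parallel_for API and a made-up x:

import tensorflow as tf
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_ops

x = tf.random.uniform([3, 2, 3])

def loop_fn(i):
    # The first output ignores i; the second gathers the i-th slice of x.
    return array_ops.identity_n([x, array_ops.gather(x, i)])

outputs = pfor_ops.pfor(loop_fn, 3)  # vectorized over i = 0, 1, 2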
Example #39
 def testString(self):
   source = [b"A", b"b", b"C", b"d", b"E", b"f"]
   with self.cached_session() as sess:
     [value] = sess.run(array_ops.identity_n([source]))
   self.assertAllEqual(source, value)
Example #40
 def testMixedTypeListInputFastPath(self):
     array_ops.identity_n([self._m_2_by_2, self._m_2_by_2])
Example #41
def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not isinstance(v, resource_variable_ops.ResourceVariable):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = list(set(tape.watched_variables()) - set(args))
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * len(flat_result)) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)
  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])
Example #42
 def testMixedTypeListInputEagerFallback(self):
     array_ops.identity_n([1, 1])
Example #43
 def loop_fn(i):
     return array_ops.identity_n([x, array_ops.gather(x, i)])
Example #44
 def testListInputOutput(self):
   self.skipTest("b/185403393")
   array_ops.identity_n([self._m_2_by_2, self._m_2_by_2])