예제 #1
0
  def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
    """Checks that an exception raised inside a py_func surfaces as `tf_exp`.

    The wrapped callable raises `py_exp("blah")` from a nested helper. The
    resulting TensorFlow error must contain the message and a stack trace that
    mentions both nested frames.

    Args:
      py_exp: Python exception class raised inside the wrapped function.
      tf_exp: TensorFlow exception class expected by the caller.
      eager: If True, wrap with `eager_py_func`; otherwise with `py_func`.
    """

    def inner_exception():
      raise py_exp("blah")  # pylint: disable=not-callable

    def raise_exception():
      inner_exception()

    # The helper names above are matched literally below; do not rename them.
    expected_regexp = r": blah.*"               # Error at the top
    expected_regexp += r"in raise_exception.*"  # Stacktrace outer
    expected_regexp += r"in inner_exception.*"  # Stacktrace inner
    expected_regexp += r": blah"                # Stacktrace of raise

    def expected_error_check(exception):
      return re.search(expected_regexp, str(exception), re.DOTALL)

    if eager:
      if context.executing_eagerly():
        # When executing eagerly, the error is expected while the op is being
        # created, so there is nothing left to evaluate afterwards.
        with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
          script_ops.eager_py_func(raise_exception, [], [])
        return
      else:
        f = script_ops.eager_py_func(raise_exception, [], [])
    else:
      f = script_ops.py_func(raise_exception, [], [])

    # `test_session` is deprecated in favor of `cached_session`.
    with self.cached_session():
      with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
        self.evaluate(f)
예제 #2
0
 def make_graphs():
   """Builds 1000 throwaway graphs, each holding four py_func-style ops.

   The last two ops wrap lambdas that close over the constant `seed`, and
   `seed` references its graph, so each graph is reachable from the wrapped
   functions (see #18292). This presumably lets a caller check that such
   functions are still released; the check itself is not in this helper.
   """
   for _ in xrange(1000):
     graph = ops.Graph()
     with graph.as_default():
       seed = constant_op.constant([1.], dtypes.float32)
       _ = script_ops.py_func(lambda v: v + 1, [seed], [dtypes.float32])
       _ = script_ops.eager_py_func(lambda v: v + 1, [seed], [dtypes.float32])
       # The functions below capture `seed`, which holds a reference to the
       # graph, checking that they are freed even though the graph is
       # reachable from them (see #18292).
       _ = script_ops.py_func(
           lambda v: v + seed.shape[0], [seed], [dtypes.float32])
       _ = script_ops.eager_py_func(
           lambda v: v + seed.shape[0], [seed], [dtypes.float32])
예제 #3
0
 def testScalar(self):
   """eager_py_func on two float32 scalars matches calling np_func directly."""
   # `test_session` is deprecated in favor of `cached_session`.
   with self.cached_session():
     x = constant_op.constant(1.0, dtypes.float32)
     y = constant_op.constant(2.0, dtypes.float32)
     z = self.evaluate(
         script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
     self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))
예제 #4
0
 def testEagerSingleOutputInt32(self):
   """A single (non-list) int32 Tout yields one matmul result tensor."""
   a = array_ops.ones((3, 3), dtype=dtypes.int32)
   x = array_ops.ones((3, 1), dtype=dtypes.int32)
   output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
   # `test_session` is deprecated in favor of `cached_session`.
   with self.cached_session():
     ret = self.evaluate(output)
     self.assertAllEqual(ret, [[3], [3], [3]])
예제 #5
0
 def testEagerSingleOutputFloat32(self):
   """A single (non-list) float32 Tout yields one matmul result tensor."""
   with test_util.device(use_gpu=True):
     ones_matrix = array_ops.ones((3, 3), dtype=dtypes.float32)
     ones_column = array_ops.ones((3, 1), dtype=dtypes.float32)
     result = self.evaluate(
         script_ops.eager_py_func(
             matmul, inp=[ones_matrix, ones_column], Tout=dtypes.float32))
     expected = [[3.0]] * 3
     self.assertAllClose(result, expected)
예제 #6
0
 def testCleanup(self):
   """Registered py_funcs do not accumulate across 1000 discarded graphs."""
   for _ in xrange(1000):
     g = ops.Graph()
     with g.as_default():
       c = constant_op.constant([1.], dtypes.float32)
       _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
       _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
   # assertLess reports both operands on failure, unlike assertTrue(a < b).
   self.assertLess(script_ops._py_funcs.size(), 100)
예제 #7
0
 def testEagerArrayOutput(self):
   """A list-valued float32 Tout yields a one-element list of results."""
   with test_util.device(use_gpu=True):
     lhs = array_ops.ones((3, 3), dtype=dtypes.float32)
     rhs = array_ops.ones((3, 1), dtype=dtypes.float32)

     def wrapped_matmul(a, x):
       return [matmul(a, x)]

     output = script_ops.eager_py_func(
         wrapped_matmul, inp=[lhs, rhs], Tout=[dtypes.float32])
     self.assertAllEqual(self.evaluate(output), [[[3.0], [3.0], [3.0]]])
예제 #8
0
  def testEagerRespectsDevicePlacmentOfOp(self):
    """Chained eager_py_funcs built under an explicit CPU scope still run."""

    def f(x):
      return math_ops.square(x)

    def g(x):
      return math_ops.add(x, x)

    with ops.device("/CPU:0"):
      # Explicitly ask for the py_funcs to execute on CPU, even if
      # a GPU is available.
      x = array_ops.placeholder(dtypes.float32)
      y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32)
      z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32)

    # `test_session` is deprecated in favor of `cached_session`.
    with self.cached_session(use_gpu=True) as sess:
      output = sess.run(z, feed_dict={x: 3.0})
      # square(3) = 9, then 9 + 9 = 18.
      self.assertEqual(output, 18.0)
예제 #9
0
  def testEagerGradientGraph(self):
    """tf.gradients flows through eager_py_func: d(x^2)/dx at 3 is 6."""

    def f(x):
      return x**2

    point = constant_op.constant(3.0)
    squared = script_ops.eager_py_func(f, inp=[point], Tout=dtypes.float32)
    grads = gradients_impl.gradients(squared, point)
    self.assertEqual(self.evaluate(grads[0]), 6.0)
예제 #10
0
  def testEagerArrayOutput(self):
    """A list-valued int32 Tout yields a one-element list of results."""
    a = array_ops.ones((3, 3), dtype=dtypes.int32)
    x = array_ops.ones((3, 1), dtype=dtypes.int32)
    output = script_ops.eager_py_func(
        lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.int32])

    # `test_session` is deprecated in favor of `cached_session`.
    with self.cached_session():
      ret = self.evaluate(output)
      self.assertAllEqual(ret, [[[3], [3], [3]]])
예제 #11
0
  def testEagerReturningVariableRaisesError(self):
    """Returning a ResourceVariable from eager_py_func raises UnknownError."""
    def return_variable():
      return resource_variable_ops.ResourceVariable(0.0)

    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(errors.UnknownError,
                                "Attempting to return a variable"):
      output = script_ops.eager_py_func(
          return_variable, inp=[], Tout=dtypes.float32)
      self.evaluate(output)
예제 #12
0
  def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
    """Checks that `py_exp` raised inside a py_func surfaces as `tf_exp`.

    Args:
      py_exp: Python exception class raised inside the wrapped function.
      tf_exp: TensorFlow exception class expected by the caller; its message
        must contain "blah".
      eager: If True, wrap with `eager_py_func`; otherwise with `py_func`.
    """

    def raise_exception():
      raise py_exp("blah")  # pylint: disable=not-callable

    if eager:
      # `context.in_eager_mode` was renamed to `context.executing_eagerly`.
      if context.executing_eagerly():
        # Executed immediately: the error is raised while creating the op.
        with self.assertRaisesRegex(tf_exp, "blah"):
          script_ops.eager_py_func(raise_exception, [], [])
        return
      else:
        f = script_ops.eager_py_func(raise_exception, [], [])
    else:
      f = script_ops.py_func(raise_exception, [], [])

    # `test_session` is deprecated in favor of `cached_session`, and
    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    with self.cached_session():
      with self.assertRaisesRegex(tf_exp, "blah"):
        self.evaluate(f)
예제 #13
0
  def testEagerGradientTape(self):
    """GradientTape differentiates through eager_py_func: d(x^2)/dx = 6."""

    def f(x):
      return x**2

    point = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(point)
      squared = script_ops.eager_py_func(f, inp=[point], Tout=dtypes.float32)
    grad = tape.gradient(squared, point)
    self.assertEqual(self.evaluate(grad), 6.0)
예제 #14
0
  def testEagerReturnNone(self):
    """A function returning None with Tout=[] produces no outputs."""

    def no_return_value():
      return

    output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
    ret = self.evaluate(output)
    # `context.in_eager_mode` was renamed to `context.executing_eagerly`, and
    # `assertEquals` is a deprecated alias removed in Python 3.12.
    if context.executing_eagerly():
      self.assertEqual(len(ret), 0)
    else:
      self.assertIsNone(ret)
예제 #15
0
    def testEagerGradientGraphMultipleArgs(self):
        """tf.gradients of x^2 + y^2 through eager_py_func at (3, 4)."""
        def f(x, y):
            return x**2 + y**2

        x = constant_op.constant(3.0)
        y = constant_op.constant(4.0)
        z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

        dz_dx, dz_dy = gradients_impl.gradients(z, [x, y])
        # Partial derivatives are 2x = 6 and 2y = 8.
        for grad, expected in ((dz_dx, 6.0), (dz_dy, 8.0)):
            self.assertEqual(self.evaluate(grad), expected)
예제 #16
0
  def testEagerGradientTape(self):
    """GradientTape differentiates through eager_py_func: d(x^2)/dx = 6."""

    def f(x):
      return x**2

    point = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(point)
      squared = script_ops.eager_py_func(f, inp=[point], Tout=dtypes.float32)
    grad = tape.gradient(squared, point)
    self.assertEqual(self.evaluate(grad), 6.0)
예제 #17
0
  def testEagerReturnNone(self):
    """A function returning None with Tout=[] produces no outputs."""
    with test_util.device(use_gpu=True):
      def no_return_value():
        return

      output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
      ret = self.evaluate(output)
      if context.executing_eagerly():
        # `assertEquals` is a deprecated alias removed in Python 3.12.
        self.assertEqual(len(ret), 0)
      else:
        self.assertIsNone(ret)
예제 #18
0
  def testEagerReturnNone(self):
    """A function returning None with Tout=[] produces no outputs."""
    with test_util.device(use_gpu=True):
      def no_return_value():
        return

      output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
      ret = self.evaluate(output)
      if context.executing_eagerly():
        # `assertEquals` is a deprecated alias removed in Python 3.12.
        self.assertEqual(len(ret), 0)
      else:
        self.assertIsNone(ret)
예제 #19
0
    def testEagerReturningVariableRaisesError(self):
        """Returning a ResourceVariable from eager_py_func raises UnknownError."""
        def return_variable():
            variable = resource_variable_ops.ResourceVariable(0.0)
            return variable

        # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(errors.UnknownError,
                                    "Attempting to return a variable"):
            output = script_ops.eager_py_func(return_variable,
                                              inp=[],
                                              Tout=dtypes.float32)
            self.evaluate(output)
예제 #20
0
  def testEagerGradientTape(self):
    """GradientTape through eager_py_func for x**2, real and complex."""

    def f(x):
      return x**2

    # (input, output dtype, expected gradient). For the complex input the
    # gradient is expected to be the conjugate, 6 - 6j.
    cases = (
        (3.0, dtypes.float32, 6.0),
        (3.0 + 3.0j, dtypes.complex128, 6.0 - 6.0j),
    )
    for value, out_dtype, expected_grad in cases:
      x = constant_op.constant(value)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = script_ops.eager_py_func(f, inp=[x], Tout=out_dtype)
      self.assertAllClose(self.evaluate(tape.gradient(y, x)), expected_grad)
예제 #21
0
  def testRaggedTensorReturn(self):
    """eager_py_func returns a RaggedTensor when Tout is a RaggedTensorSpec."""

    def fn(v, l):
      return ragged_tensor.RaggedTensor.from_row_lengths(v, l)

    flat_values = [1, 2, 3, 4, 5, 6]
    row_lengths = constant_op.constant([3, 1, 2], dtypes.int64)
    spec = ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
    outputs = script_ops.eager_py_func(fn, [flat_values, row_lengths], [spec])
    (result,) = outputs
    self.assertIsInstance(result, ragged_tensor.RaggedTensor)
    self.assertAllEqual(result, [[1, 2, 3], [4], [5, 6]])
예제 #22
0
  def testEagerGradientGraphTwoOutputs(self):
    """Gradients combine correctly across two eager_py_func outputs."""

    def f(x, y):
      return x * y, x / y

    x = constant_op.constant(3.0)
    y = constant_op.constant(2.0)
    out_types = [dtypes.float32, dtypes.float32]
    product, quotient = script_ops.eager_py_func(f, inp=[x, y], Tout=out_types)
    # d(x*y + x/y)/dx = y + 1/y, which is 2.5 at y = 2.
    grads = gradients_impl.gradients(product + quotient, x)
    self.assertEqual(self.evaluate(grads[0]), 2.5)
예제 #23
0
 def _external_func_grad(*grad):
     """Gradient closure that feeds selected tensors to `self.backward`.

     In order, the inputs passed to `self.backward` are `x` (if
     `self.xEnable`), `y` (if `self.yEnable`; a non-list `y` is passed as one
     item) and the incoming `grad` tensors (if `self.dyEnable`). `self`, `x`
     and `y` come from an enclosing scope that is not visible here.
     """
     backward_inputs = []
     if self.xEnable:
         backward_inputs.extend(x)
     if self.yEnable:
         backward_inputs.extend(y if isinstance(y, (list, tuple)) else [y])
     if self.dyEnable:
         backward_inputs.extend(grad)
     return script_ops.eager_py_func(self.backward, backward_inputs, self.Tin)
예제 #24
0
 def testRaggedBadReturnTypeExpectedRaggedReturnedTensor(self):
     """A dense return value is rejected when Tout expects a ragged spec."""
     expected_errors = (ValueError, errors.InvalidArgumentError)
     message = "py_function: func=.* returned .* which did not match Tout=.*"
     with self.assertRaisesRegex(expected_errors, message):
         ragged_out = [
             ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
         ]
         dense_in = [constant_op.constant([[1, 2, 3]])]
         result = script_ops.eager_py_func(
             func=lambda x: x, inp=dense_in, Tout=ragged_out)
         self.evaluate(result)
예제 #25
0
  def testEagerGradientGraphMultipleArgs(self):
    """tf.gradients of x^2 + y^2 through eager_py_func at (3, 4)."""

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = gradients_impl.gradients(z, [x, y])
    # Partial derivatives are 2x = 6 and 2y = 8.
    for grad, expected in ((dz_dx, 6.0), (dz_dy, 8.0)):
      self.assertEqual(self.evaluate(grad), expected)
예제 #26
0
  def testRenamedDeviceInTestClusterCorrectlyIdentifiedAsLocalhost(self):
    """eager_py_func placed on a local test-cluster worker runs there."""
    if context.executing_eagerly():
      self.skipTest("b/126565353: We don't test eager's remote execution.")

    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    worker = workers[0]
    with ops.device("/job:worker/task:0/cpu:0"):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
    # Use the session as a context manager so it is closed even if run()
    # raises.
    with session_lib.Session(worker.target) as session:
      ret = session.run(output)
    self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
예제 #27
0
    def testEagerGradientTapeMultipleArgs(self):
        """GradientTape through eager_py_func of x^2 + y^2 at (3, 4)."""
        def f(x, y):
            return x**2 + y**2

        x = constant_op.constant(3.0)
        y = constant_op.constant(4.0)
        with backprop.GradientTape() as tape:
            for source in (x, y):
                tape.watch(source)
            z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

        dz_dx, dz_dy = tape.gradient(z, [x, y])
        for grad, expected in ((dz_dx, 6.0), (dz_dy, 8.0)):
            self.assertEqual(self.evaluate(grad), expected)
예제 #28
0
  def testEagerGradientTapeMultipleArgs(self):
    """GradientTape through eager_py_func of x^2 + y^2 at (3, 4)."""

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    with backprop.GradientTape() as tape:
      for source in (x, y):
        tape.watch(source)
      z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = tape.gradient(z, [x, y])
    for grad, expected in ((dz_dx, 6.0), (dz_dy, 8.0)):
      self.assertEqual(self.evaluate(grad), expected)
  def testEagerPyFuncPlacement(self):
    """A PyFunc requested on task 1 is expected to be placed on task 0."""
    if not ops.executing_eagerly_outside_functions():
      return

    def f(x):
      return math_ops.square(x)

    requested_device = "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME
    expected_device = "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME
    with ops.device(requested_device):
      const_op = constant_op.constant(3.0, dtype=dtypes.float32)
      # PyFuncOp should be placed on the localhost's address space.
      py_func_op = script_ops.eager_py_func(
          func=f, inp=[const_op], Tout=dtypes.float32)
      self.assertEqual(py_func_op.device, expected_device)
      self.assertEqual(self.evaluate(py_func_op), 9.0)
예제 #30
0
    def grab_batch(indices):
      """Grab a batch of data from the inputs."""
      # A py_function is used so that each array-like input is sliced before
      # any conversion to a Tensor; converting the whole array-like first may
      # force it into memory.
      def py_method(ind):
        return [
            training_utils.slice_arrays(source, ind.numpy(),
                                        contiguous=contiguous)
            for source in flat_inputs
        ]

      flat_out = script_ops.eager_py_func(py_method, [indices], flat_dtypes)
      for tensor, source in zip(flat_out, flat_inputs):
        tensor.set_shape(dynamic_shape_like(source))
      return nest.pack_sequence_as(inputs, flat_out)
예제 #31
0
파일: external.py 프로젝트: wujinke/MDNT
        def _external_func(*x):
            """Runs `self.forward` via eager_py_func; returns (y, grad closure).

            This (output, gradient-function) pair looks like a
            `tf.custom_gradient` body; the decorator is outside this view,
            so confirm.
            """
            y = script_ops.eager_py_func(self.forward, x, self.Tout,
                                         name='pyfunc')

            def _external_func_grad(*grad):
                # Collect, in order, the inputs, outputs and upstream grads
                # that the flags enable, then hand them to `self.backward`.
                backward_inputs = []
                if self.xEnable:
                    backward_inputs.extend(x)
                if self.yEnable:
                    backward_inputs.extend(
                        y if isinstance(y, (list, tuple)) else [y])
                if self.dyEnable:
                    backward_inputs.extend(grad)
                return script_ops.eager_py_func(self.backward,
                                                backward_inputs, self.Tin)

            return y, _external_func_grad
예제 #32
0
    def testEagerGradientGraphLogHuber(self):
        """Gradient through a data-dependent (log-Huber) py_func in graph mode."""
        def log_huber(x, m):
            if math_ops.abs(x) <= m:
                return x**2
            return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2))

        x = array_ops.placeholder(dtypes.float32)
        m = array_ops.placeholder(dtypes.float32)

        y = script_ops.eager_py_func(func=log_huber,
                                     inp=[x, m],
                                     Tout=dtypes.float32)
        dy_dx = gradients_impl.gradients(y, x)[0]

        with self.cached_session() as sess:
            # x = 1, m = 2 takes the quadratic branch: y = 1, dy/dx = 2.
            y_val, grad_val = sess.run([y, dy_dx],
                                       feed_dict={x: 1.0, m: 2.0})
            self.assertEqual(y_val, 1.0)
            self.assertEqual(grad_val, 2.0)
예제 #33
0
  def testEagerGradientGraphLogHuber(self):
    """Gradient through a data-dependent (log-Huber) py_func in graph mode."""

    def log_huber(x, m):
      if math_ops.abs(x) <= m:
        return x**2
      else:
        return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2))

    x = array_ops.placeholder(dtypes.float32)
    m = array_ops.placeholder(dtypes.float32)

    y = script_ops.eager_py_func(
        func=log_huber, inp=[x, m], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]

    # `test_session` is deprecated in favor of `cached_session`.
    with self.cached_session() as sess:
      # Takes the first branch of log_huber: y = x**2 = 1, dy/dx = 2x = 2.
      y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
      self.assertEqual(y, 1.0)
      self.assertEqual(dy_dx, 2.0)
 def wrapped_fn(*args):
     """Calls `py_function_wrapper` on `args` via eager_py_func.

     The output dtypes are the flattened tensor types of
     `self._output_structure`.
     """
     flat_types = structure.get_flat_tensor_types(self._output_structure)
     return script_ops.eager_py_func(py_function_wrapper, args, flat_types)
예제 #35
0
 def foo(x):
   """Returns `x` plus `fn(x)`, with `fn` run through eager_py_func."""
   out_spec = ragged_tensor.RaggedTensorSpec.from_value(x)
   fn_result = script_ops.eager_py_func(fn, [x], out_spec)
   return x + fn_result
예제 #36
0
def wrap_py_func(f, return_dtypes, args, kwargs=None, use_dummy_return=False):
    """Helper that wraps a callable to py_func.

  The helper passes tensor arguments through the py_func interface. Non-tensor
  arguments are allowed, and will be passed to f directly. Note that non-tensor
  arguments captured by f will not update every time the wrapper is called
  (this is consistent with its argument list, which only includes the tensor
  arguments). In general, it's safest not to reuse this wrapper.

  Args:
    f: Callable
    return_dtypes: None, or an individual or tuple/list of DType or
        MatchDType, the data type for each of f's return value(s). Set to None
        if f has no return values or use_dummy_return is True. Use MatchDType
        to define a dtype identical to that of the `i`th argument (argument 0
        is the first); an argument must be of Tensor type if it is to be used
        with MatchDType.
    args: Positional arguments for f, as list or tuple.
    kwargs: Keyword arguments for f, as dict with string keys. May be None.
    use_dummy_return: If True, the function will return a dummy value of 1
        and discard its actual return value.
  Returns:
    The return values of f converted to tensor.
  Raises:
    ValueError: if any of the arguments are incorrect, including a
        return_dtypes value that is not a DType, a MatchDType, or a list/tuple
        of those.
  """

    if return_dtypes and use_dummy_return:
        raise ValueError(
            'if use_dummy_return is True, return_dtypes must be empty')

    tensor_args = []
    tensor_args_idx = {}

    # Of the positional arguments, only grab the tensor ones to be passed through
    # the py_func.
    arg_is_tensor = tuple(map(tensor_util.is_tensor, args))
    for i, (arg, is_tensor) in enumerate(zip(args, arg_is_tensor)):
        if is_tensor:
            tensor_args_idx[i] = len(tensor_args)
            tensor_args.append(arg)

    # We essentially take the tensor kwargs, if any, and add them to the list of
    # positional arguments. The kwargs are then reconstructed inside the py_func.
    #
    # For example, if
    #
    #     args = [Tensor(1), 'foo']
    #     kwargs = {'a': Tensor(2), 'b': 'bar'}
    #
    # Then
    #
    #     tensor_args = (Tensor(1), Tensor(2))
    #     kwarg_keys = ('a', 'b')
    if kwargs:
        kwarg_keys = tuple(kwargs.keys())
        kwarg_is_tensor = {
            k: tensor_util.is_tensor(kwargs[k])
            for k in kwarg_keys
        }
        for k in kwarg_keys:
            if kwarg_is_tensor[k]:
                tensor_args_idx[k] = len(tensor_args)
                tensor_args.append(kwargs[k])
    else:
        kwarg_keys = ()
        # Bind explicitly so f_wrapper never closes over an unassigned name.
        kwarg_is_tensor = {}

    # Set up return dtypes.
    def match_arg_dtype(arg_number):
        arg = args[arg_number]
        if not arg_is_tensor[arg_number]:
            raise ValueError(
                'argument %d was used with MatchDType and must be a tf.Tensor, but '
                'was %s instead' % (arg_number, type(arg)))
        return arg.dtype

    if return_dtypes:
        if isinstance(return_dtypes, MatchDType):
            return_dtypes = match_arg_dtype(return_dtypes.arg_number)
        elif isinstance(return_dtypes, (list, tuple)):
            return_dtypes = tuple(
                match_arg_dtype(a.arg_number) if isinstance(a, MatchDType)
                else a
                for a in return_dtypes)
        elif not isinstance(return_dtypes, dtypes.DType):
            # Raise explicitly: `assert` is stripped under `python -O`, and
            # the docstring promises ValueError for incorrect arguments.
            raise ValueError(
                'return_dtypes must be a DType, a MatchDType, or a list/tuple '
                'of them; got %r' % (return_dtypes,))

    def f_wrapper(*tensor_args):
        # Rebuild the original call, substituting the tensors received through
        # the py_func for the tensor-valued positional and keyword arguments.
        f_args = tuple(
            tensor_args[tensor_args_idx[i]] if arg_is_tensor[i] else a
            for i, a in enumerate(args))
        f_kwargs = {
            k: tensor_args[tensor_args_idx[k]]
            if kwarg_is_tensor[k] else kwargs[k]
            for k in kwarg_keys
        }
        retval = f(*f_args, **f_kwargs)
        return 1 if use_dummy_return else retval

    if use_dummy_return:
        return_dtypes = dtypes.int32
    return script_ops.eager_py_func(f_wrapper, tensor_args, return_dtypes)
예제 #37
0
 def f():
   """Calls side_effect_one, then side_effect_two, via eager_py_func; returns 1."""
   for side_effect in (side_effect_one, side_effect_two):
     script_ops.eager_py_func(side_effect, [1], [dtypes.int32])
   return 1
예제 #38
0
 def testEagerSingleOutputInt32(self):
   """A single (non-list) int32 Tout yields one matmul result tensor."""
   ones_3x3 = array_ops.ones((3, 3), dtype=dtypes.int32)
   ones_3x1 = array_ops.ones((3, 1), dtype=dtypes.int32)
   result = self.evaluate(
       script_ops.eager_py_func(
           matmul, inp=[ones_3x3, ones_3x1], Tout=dtypes.int32))
   self.assertAllEqual(result, [[3]] * 3)
예제 #39
0
 def test_fn(v):
     """Applies `simple_fn` to `v` through eager_py_func (float32 output)."""
     out_dtype = dtypes.float32
     return script_ops.eager_py_func(simple_fn, [v], out_dtype)
예제 #40
0
  def testEagerPyFuncNotACallable(self):
    """Passing a non-callable as `func` raises a ValueError about callables."""
    not_callable = constant_op.constant("x", dtype=dtypes.string)

    with self.assertRaisesRegex(ValueError, "callable"):
      script_ops.eager_py_func(
          not_callable, inp=[not_callable], Tout=dtypes.string)
예제 #41
0
 def testUnsupportedToutType(self):
   """A Tout entry that cannot become a DType raises TypeError."""
   bad_tout = [{}]
   with self.assertRaisesRegex(
       TypeError, "Cannot convert .* to a TensorFlow DType."):
     script_ops.eager_py_func(lambda x: x, [1], bad_tout)
예제 #42
0
 def testRaggedTensorArg(self):
   """A RaggedTensor input can be consumed by the wrapped function."""
   rt = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
   outputs = script_ops.eager_py_func(math_ops.reduce_sum, [rt], [dtypes.int32])
   (total,) = outputs
   # 1 + 2 + 3 + 4 + 5 + 6 = 21.
   self.assertAllEqual(total, 21)
 def inner():
     """Matmuls a constant row with the list returned by a py_func."""
     lhs = constant_op.constant([[1, 2, 3]])
     rhs = script_ops.eager_py_func(lambda: [[1, 2, 3]], (), dtypes.int32)
     # NOTE(review): both operands look 1x3, which matmul would reject at run
     # time; presumably the surrounding test expects that error. Confirm.
     return math_ops.matmul(lhs, rhs)
예제 #44
0
 def test_fn(v):
     """Calls `simple_fn` on `v.handle` via eager_py_func, then returns 1.0."""
     handle = v.handle
     script_ops.eager_py_func(simple_fn, [handle], dtypes.float32)
     return 1.
예제 #45
0
def wrap_py_func(f, return_dtypes, args, kwargs=None, use_dummy_return=False):
  """Helper that wraps a callable to py_func.

  The helper passes tensor arguments through the py_func interface. Non-tensor
  arguments are allowed, and will be passed to f directly. Note that non-tensor
  arguments captured by f will not update every time the wrapper is called
  (this is consistent with its argument list, which only includes the tensor
  arguments). In general, it's safest not to reuse this wrapper.

  Args:
    f: Callable
    return_dtypes: None, or an individual or tuple/list of DType or
        MatchDType, the data type for each of f's return value(s). Set to None
        if f has no return values or use_dummy_return is True. Use MatchDType
        to define a dtype identical to that of the `i`th argument (argument 0
        is the first); an argument must be of Tensor type if it is to be used
        with MatchDType.
    args: Positional arguments for f, as list or tuple.
    kwargs: Keyword arguments for f, as dict with string keys. May be None.
    use_dummy_return: If True, the function will return a dummy value of 1
        and discard its actual return value.
  Returns:
    The return values of f converted to tensor.
  Raises:
    ValueError: if any of the arguments are incorrect, including a
        return_dtypes value that is not a DType, a MatchDType, or a list/tuple
        of those.
  """

  if return_dtypes and use_dummy_return:
    raise ValueError('if use_dummy_return is True, return_dtypes must be empty')

  tensor_args = []
  tensor_args_idx = {}

  # Of the positional arguments, only grab the tensor ones to be passed through
  # the py_func.
  arg_is_tensor = tuple(map(tensor_util.is_tensor, args))
  for i, (arg, is_tensor) in enumerate(zip(args, arg_is_tensor)):
    if is_tensor:
      tensor_args_idx[i] = len(tensor_args)
      tensor_args.append(arg)

  # We essentially take the tensor kwargs, if any, and add them to the list of
  # positional arguments. The kwargs are then reconstructed inside the py_func.
  #
  # For example, if
  #
  #     args = [Tensor(1), 'foo']
  #     kwargs = {'a': Tensor(2), 'b': 'bar'}
  #
  # Then
  #
  #     tensor_args = (Tensor(1), Tensor(2))
  #     kwarg_keys = ('a', 'b')
  if kwargs:
    kwarg_keys = tuple(kwargs.keys())
    kwarg_is_tensor = {k: tensor_util.is_tensor(kwargs[k]) for k in kwarg_keys}
    for k in kwarg_keys:
      if kwarg_is_tensor[k]:
        tensor_args_idx[k] = len(tensor_args)
        tensor_args.append(kwargs[k])
  else:
    kwarg_keys = ()
    # Bind explicitly so f_wrapper never closes over an unassigned name.
    kwarg_is_tensor = {}

  # Set up return dtypes.
  def match_arg_dtype(arg_number):
    arg = args[arg_number]
    if not arg_is_tensor[arg_number]:
      raise ValueError(
          'argument %d was used with MatchDType and must be a tf.Tensor, but '
          'was %s instead' % (arg_number, type(arg)))
    return arg.dtype

  if return_dtypes:
    if isinstance(return_dtypes, MatchDType):
      return_dtypes = match_arg_dtype(return_dtypes.arg_number)
    elif isinstance(return_dtypes, (list, tuple)):
      return_dtypes = tuple(
          match_arg_dtype(a.arg_number) if isinstance(a, MatchDType) else a
          for a in return_dtypes)
    elif not isinstance(return_dtypes, dtypes.DType):
      # Raise explicitly: `assert` is stripped under `python -O`, and the
      # docstring promises ValueError for incorrect arguments.
      raise ValueError(
          'return_dtypes must be a DType, a MatchDType, or a list/tuple of '
          'them; got %r' % (return_dtypes,))

  def f_wrapper(*tensor_args):
    # Rebuild the original call, substituting the tensors received through the
    # py_func for the tensor-valued positional and keyword arguments.
    f_args = tuple(tensor_args[tensor_args_idx[i]] if arg_is_tensor[i] else a
                   for i, a in enumerate(args))
    f_kwargs = {
        k: tensor_args[tensor_args_idx[k]] if kwarg_is_tensor[k] else kwargs[k]
        for k in kwarg_keys
    }
    retval = f(*f_args, **f_kwargs)
    return 1 if use_dummy_return else retval

  if use_dummy_return:
    return_dtypes = dtypes.int32
  return script_ops.eager_py_func(f_wrapper, tensor_args, return_dtypes)
예제 #46
0
 def wrapper():
   """Calls `matmul` on ones(3x3) and ones(3x1) through eager_py_func."""
   operands = [
       array_ops.ones((3, 3), dtype=dtypes.float32),
       array_ops.ones((3, 1), dtype=dtypes.float32),
   ]
   return script_ops.eager_py_func(matmul, inp=operands, Tout=dtypes.float32)
예제 #47
0
 def testRaggedExpectedListGotSingleValue(self):
   """An identity py_func round-trips a RaggedTensor given its TypeSpec."""
   rt = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
   rt_spec = type_spec.type_spec_from_value(rt)
   outputs = script_ops.eager_py_func(lambda v: v, [rt], [rt_spec])
   (round_tripped,) = outputs
   self.assertAllEqual(round_tripped, rt)
예제 #48
0
 def wrapper():
     """Calls `matmul` on ones(3x3) and ones(3x1) through eager_py_func."""
     lhs = array_ops.ones((3, 3), dtype=dtypes.float32)
     rhs = array_ops.ones((3, 1), dtype=dtypes.float32)
     return script_ops.eager_py_func(matmul, inp=[lhs, rhs],
                                     Tout=dtypes.float32)