Beispiel #1
0
    def test_wrap_py_func_simple(self):
        """wrap_py_func sums any mix of Python ints and int tensors to 3."""
        def test_fn(a, b, c):
            return a + b + c

        with self.test_session() as sess:
            def check(*call_args):
                # Build the py_func op for this argument tuple and run it.
                result = py_func.wrap_py_func(test_fn, dtypes.int64, call_args)
                self.assertEqual(3, sess.run(result))

            check(1, constant_op.constant(1), 1)
            check(1, 1, 1)
            check(constant_op.constant(1), 1, constant_op.constant(1))
Beispiel #2
0
    def test_wrap_py_func_complex_args(self):
        """Plain Python objects can be passed alongside tensors."""
        class TestClass(object):
            def __init__(self):
                self.foo = 5

        def test_fn(a, b):
            return a * b.foo

        with self.test_session() as sess:
            def check(*call_args):
                # 7 * TestClass().foo == 35 regardless of how 7 is supplied.
                self.assertEqual(
                    35,
                    sess.run(
                        py_func.wrap_py_func(test_fn, dtypes.int64,
                                             call_args)))

            check(7, TestClass())
            check(constant_op.constant(7), TestClass())
Beispiel #3
0
  def test_wrap_py_func_dummy_return(self):
    """A function with no result yields the dummy value 1 and still runs."""
    calls = [0]

    def test_fn(_):
      calls[0] += 1

    with self.test_session() as sess:

      def run_once(arg, expected_calls):
        # Running the dummy op must invoke test_fn exactly once more.
        dummy = py_func.wrap_py_func(
            test_fn, None, (arg,), use_dummy_return=True)
        self.assertEqual(1, sess.run(dummy))
        self.assertEqual([expected_calls], calls)

      run_once(5, 1)
      run_once(constant_op.constant(5), 2)
Beispiel #4
0
  def test_wrap_py_func_simple(self):
    """Sums Python ints and int tensors through wrap_py_func."""

    def test_fn(a, b, c):
      return a + b + c

    with self.test_session() as sess:

      def expect_three(*fn_args):
        summed = py_func.wrap_py_func(test_fn, dtypes.int64, fn_args)
        self.assertEqual(3, sess.run(summed))

      expect_three(1, constant_op.constant(1), 1)
      expect_three(1, 1, 1)
      expect_three(constant_op.constant(1), 1, constant_op.constant(1))
Beispiel #5
0
 def py_func_wrapper(*args, **kwargs):
   """Calls the enclosing `f` with `args` via `py_func.wrap_py_func`.

   Raises:
     NotImplementedError: if any keyword arguments are given; RunMode.PY_FUNC
       cannot forward them yet.
   """
   # TODO(mdan): Add support for kwargs.
   if not kwargs:
     # With no return dtypes, ask wrap_py_func for a dummy return value.
     return py_func.wrap_py_func(
         f, return_dtypes, args, use_dummy_return=not return_dtypes)
   raise NotImplementedError(
       'RunMode.PY_FUNC does not yet support kwargs')
Beispiel #6
0
  def test_wrap_py_func_complex_args(self):
    """Arbitrary Python objects can be passed as wrap_py_func arguments."""

    class TestClass(object):

      def __init__(self):
        self.foo = 5

    def test_fn(a, b):
      return a * b.foo

    with self.test_session() as sess:

      def expect_35(*fn_args):
        product = py_func.wrap_py_func(test_fn, dtypes.int64, fn_args)
        self.assertEqual(35, sess.run(product))

      expect_35(7, TestClass())
      expect_35(constant_op.constant(7), TestClass())
Beispiel #7
0
 def py_func_wrapper(*args, **kwargs):
   """Calls the enclosing `f` with `args` via `py_func.wrap_py_func`.

   Raises:
     NotImplementedError: if any keyword arguments are given; RunMode.PY_FUNC
       does not support them yet.
   """
   # TODO(mdan): Add support for kwargs.
   if not kwargs:
     # kwargs is necessarily empty on this path (non-empty kwargs raise
     # below); it is still forwarded positionally. A dummy return value is
     # requested when no return dtypes were given.
     return py_func.wrap_py_func(
         f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)
   raise NotImplementedError(
       'RunMode.PY_FUNC does not yet support kwargs')
Beispiel #8
0
    def test_wrap_py_func_dummy_return(self):
        """wrap_py_func yields a dummy 1 for a function with no result.

        `use_dummy_return` is passed by keyword. Other call sites pass a
        kwargs dict as the fourth positional argument of wrap_py_func, so a
        positional `True` there could bind to the wrong parameter.
        """

        side_counter = [0]

        def test_fn(_):
            side_counter[0] += 1

        with self.test_session() as sess:
            result = py_func.wrap_py_func(
                test_fn, None, (5, ), use_dummy_return=True)
            self.assertEqual(1, sess.run(result))
            # Running the op must have invoked test_fn exactly once.
            self.assertEqual([1], side_counter)
            result = py_func.wrap_py_func(
                test_fn, None, (constant_op.constant(5), ),
                use_dummy_return=True)
            self.assertEqual(1, sess.run(result))
            self.assertEqual([2], side_counter)
Beispiel #9
0
  def test_wrap_py_func_simple(self):
    """A single tensor can be reused across several mixed-argument calls."""

    def test_fn(a, b, c):
      return a + b + c

    with self.test_session() as sess:
      one = constant_op.constant(1)
      for fn_args in ((1, one, 1), (1, 1, 1), (one, 1, one)):
        summed = py_func.wrap_py_func(test_fn, dtypes.int64, fn_args)
        self.assertEqual(3, sess.run(summed))
Beispiel #10
0
    def test_wrap_py_func_simple(self):
        """Reuses one int tensor across several mixed-argument calls."""
        def test_fn(a, b, c):
            return a + b + c

        with self.test_session() as sess:
            one = constant_op.constant(1)
            cases = ((1, one, 1), (1, 1, 1), (one, 1, one))
            for call_args in cases:
                wrapped = py_func.wrap_py_func(test_fn, dtypes.int64,
                                               call_args)
                self.assertEqual(3, sess.run(wrapped))
Beispiel #11
0
    def test_wrap_py_func_kwargs(self):
        """Keyword arguments are forwarded to the wrapped function."""
        class TestClass(object):
            def __init__(self, foo):
                self.foo = foo

        def test_fn(a, b, c, d):
            return a * b.foo + c * d.foo

        with self.test_session() as sess:
            def check(call_args, call_kwargs):
                # 7 * 5 + 11 * 13 == 178
                result = py_func.wrap_py_func(test_fn, dtypes.int64,
                                              call_args, call_kwargs)
                self.assertEqual(178, sess.run(result))

            check((7, TestClass(5)), {'c': 11, 'd': TestClass(13)})
            check((constant_op.constant(7), TestClass(5)),
                  {'c': constant_op.constant(11), 'd': TestClass(13)})
Beispiel #12
0
  def test_wrap_py_func_kwargs(self):
    """Positional and keyword arguments both reach the wrapped function."""

    class TestClass(object):

      def __init__(self, foo):
        self.foo = foo

    def test_fn(a, b, c, d):
      return a * b.foo + c * d.foo

    with self.test_session() as sess:

      def expect_178(fn_args, fn_kwargs):
        total = py_func.wrap_py_func(test_fn, dtypes.int64, fn_args,
                                     fn_kwargs)
        self.assertEqual(178, sess.run(total))

      # 7 * 5 + 11 * 13 == 178
      expect_178((7, TestClass(5)), {'c': 11, 'd': TestClass(13)})
      expect_178((constant_op.constant(7), TestClass(5)),
                 {'c': constant_op.constant(11), 'd': TestClass(13)})
Beispiel #13
0
def dynamic_print(*values):
    """Implementation of print using dynamic dispatch.

    The function attempts to use tf.Print if all the values are compatible.
    Otherwise, it will fall back to py_func.

    Args:
      *values: values to print

    Returns:
      A dummy value indicating the print completed. If tf.Print is used, this
      is the result of `logging_ops.Print(1, values)`; otherwise it is the
      dummy return value produced by `py_func.wrap_py_func` with
      `use_dummy_return=True`.
    """

    if all(map(is_tf_print_compatible, values)):
        # 1 is the input passed through Print; it serves as the dummy result.
        return logging_ops.Print(1, values)
    # At least one value is not tf.Print-compatible: print it from Python.
    return py_func.wrap_py_func(print, None, values, use_dummy_return=True)
Beispiel #14
0
def call_print(*values):
    """Compiled counterpart of the print builtin.

    The function attempts to use tf.Print if all the values are compatible.
    Otherwise, it will fall back to py_func.

    Args:
      *values: values to print

    Returns:
      A dummy value indicating the print completed. If tf.Print is used, this
      is the result of `logging_ops.Print(1, values)`; otherwise it is the
      dummy return value produced by `py_func.wrap_py_func` with
      `use_dummy_return=True`.
    """

    if all(map(is_tf_print_compatible, values)):
        # 1 is the input passed through Print; it serves as the dummy result.
        return logging_ops.Print(1, values)
    # At least one value is not tf.Print-compatible: print it from Python.
    return py_func.wrap_py_func(print, None, values, use_dummy_return=True)
Beispiel #15
0
def call_print(*values):
  """Compiled counterpart of the print builtin.

  The function attempts to use tf.Print if all the values are compatible.
  Otherwise, it will fall back to py_func.

  Args:
    *values: values to print

  Returns:
    A dummy value indicating the print completed. If tf.Print is used, this
    is the result of `logging_ops.Print(1, values)`; otherwise it is the
    dummy return value produced by `py_func.wrap_py_func` with
    `use_dummy_return=True`.
  """

  if all(map(is_tf_print_compatible, values)):
    # 1 is the input passed through Print; it serves as the dummy result.
    return logging_ops.Print(1, values)
  # At least one value is not tf.Print-compatible: print it from Python.
  return py_func.wrap_py_func(print, None, values, use_dummy_return=True)
Beispiel #16
0
def dynamic_print(*values):
  """Implementation of print using dynamic dispatch.

  The function attempts to use tf.Print if all the values are compatible.
  Otherwise, it will fall back to py_func.

  Args:
    *values: values to print

  Returns:
    A dummy value indicating the print completed. If tf.Print is used, this
    is the result of `logging_ops.Print(1, values)`; otherwise it is the
    dummy return value produced by `py_func.wrap_py_func` with
    `use_dummy_return=True`.
  """

  if all(map(is_tf_print_compatible, values)):
    # 1 is the input passed through Print; it serves as the dummy result.
    return logging_ops.Print(1, values)
  # At least one value is not tf.Print-compatible: print it from Python.
  return py_func.wrap_py_func(print, None, values, use_dummy_return=True)