Beispiel #1
0
    def test_wrap_py_func_simple(self):
        # Checks that wrap_py_func yields the sum 3 for every tested mix of
        # plain Python ints and constant tensors passed as arguments.
        def test_fn(a, b, c):
            return a + b + c

        # Results are fetched through self.evaluate, so the session returned
        # by cached_session() does not need to be bound to a name.
        with self.cached_session():
            # One tensor between two Python ints.
            result = py_func.wrap_py_func(test_fn, dtypes.int64,
                                          (1, constant_op.constant(1), 1))
            self.assertEqual(3, self.evaluate(result))
            # Python ints only.
            result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
            self.assertEqual(3, self.evaluate(result))
            # One Python int between two tensors.
            result = py_func.wrap_py_func(
                test_fn, dtypes.int64,
                (constant_op.constant(1), 1, constant_op.constant(1)))
            self.assertEqual(3, self.evaluate(result))
  def test_wrap_py_func_dummy_return(self):
    # Wraps a function that returns nothing and only bumps a counter. With
    # use_dummy_return=True each run is expected to evaluate to 1 and to
    # invoke the wrapped function exactly once.
    call_counter = [0]

    def test_fn(unused_arg):
      call_counter[0] += 1

    with self.cached_session() as sess:
      # Argument supplied as a plain Python value.
      python_args = (5,)
      wrapped = py_func.wrap_py_func(
          test_fn, None, python_args, use_dummy_return=True)
      self.assertEqual(1, sess.run(wrapped))
      self.assertEqual([1], call_counter)

      # Argument supplied as a constant tensor.
      tensor_args = (constant_op.constant(5),)
      wrapped = py_func.wrap_py_func(
          test_fn, None, tensor_args, use_dummy_return=True)
      self.assertEqual(1, sess.run(wrapped))
      self.assertEqual([2], call_counter)
Beispiel #3
0
  def test_wrap_py_func_simple(self):
    # Adds three arguments through wrap_py_func and expects 3 for every
    # tested mix of plain Python ints and constant tensors.

    def test_fn(x, y, z):
      return x + y + z

    with self.cached_session() as sess:
      # One tensor between two Python ints.
      mixed_args = (1, constant_op.constant(1), 1)
      total = py_func.wrap_py_func(test_fn, dtypes.int64, mixed_args)
      self.assertEqual(3, sess.run(total))

      # Python ints only.
      python_args = (1, 1, 1)
      total = py_func.wrap_py_func(test_fn, dtypes.int64, python_args)
      self.assertEqual(3, sess.run(total))

      # One Python int between two tensors.
      mostly_tensor_args = (
          constant_op.constant(1), 1, constant_op.constant(1))
      total = py_func.wrap_py_func(test_fn, dtypes.int64, mostly_tensor_args)
      self.assertEqual(3, sess.run(total))
Beispiel #4
0
    def test_wrap_py_func_complex_args(self):
        # Checks that wrap_py_func accepts an arbitrary Python object as an
        # argument alongside either a Python int or a constant tensor.
        class TestClass(object):
            def __init__(self):
                self.foo = 5

        def test_fn(a, b):
            return a * b.foo

        # Results are fetched through self.evaluate, so the session returned
        # by cached_session() does not need to be bound to a name.
        with self.cached_session():
            # Python int multiplied by the object's attribute: 7 * 5.
            result = py_func.wrap_py_func(test_fn, dtypes.int64,
                                          (7, TestClass()))
            self.assertEqual(35, self.evaluate(result))
            # Same computation with the int supplied as a constant tensor.
            result = py_func.wrap_py_func(
                test_fn, dtypes.int64, (constant_op.constant(7), TestClass()))
            self.assertEqual(35, self.evaluate(result))
Beispiel #5
0
  def test_wrap_py_func_complex_args(self):
    # Passes an arbitrary Python object next to a number (first as a Python
    # int, then as a constant tensor) and expects 7 * 5 == 35 both times.

    class TestClass(object):

      def __init__(self):
        self.foo = 5

    def test_fn(factor, holder):
      return factor * holder.foo

    with self.cached_session() as sess:
      python_args = (7, TestClass())
      product = py_func.wrap_py_func(test_fn, dtypes.int64, python_args)
      self.assertEqual(35, sess.run(product))

      tensor_args = (constant_op.constant(7), TestClass())
      product = py_func.wrap_py_func(test_fn, dtypes.int64, tensor_args)
      self.assertEqual(35, sess.run(product))
Beispiel #6
0
 def py_func_wrapper(*args, **kwargs):
     # Forwards the call to py_func.wrap_py_func using the enclosing scope's
     # `f` and `return_dtypes`. Keyword arguments are rejected up front.
     # TODO(mdan): Add support for kwargs.
     if kwargs:
         raise NotImplementedError(
             'RunMode.PY_FUNC does not yet support kwargs')
     # A dummy return is requested whenever no return dtypes were given.
     needs_dummy_return = not return_dtypes
     return py_func.wrap_py_func(
         f, return_dtypes, args, kwargs,
         use_dummy_return=needs_dummy_return)
  def test_wrap_py_func_kwargs(self):
    # Supplies two arguments positionally and two by keyword, each pair made
    # of a number and an arbitrary Python object. Expected result:
    # 7 * 5 + 11 * 13 == 178.

    class TestClass(object):

      def __init__(self, foo):
        self.foo = foo

    def test_fn(a, b, c, d):
      return a * b.foo + c * d.foo

    with self.cached_session() as sess:
      # Numbers supplied as plain Python ints.
      positional = (7, TestClass(5))
      keyword = {'c': 11, 'd': TestClass(13)}
      outcome = py_func.wrap_py_func(
          test_fn, dtypes.int64, positional, keyword)
      self.assertEqual(178, sess.run(outcome))

      # Numbers supplied as constant tensors.
      positional = (constant_op.constant(7), TestClass(5))
      keyword = {'c': constant_op.constant(11), 'd': TestClass(13)}
      outcome = py_func.wrap_py_func(
          test_fn, dtypes.int64, positional, keyword)
      self.assertEqual(178, sess.run(outcome))
Beispiel #8
0
def _tf_py_func_print(objects, kwargs):
  """Prints `objects` through a Python function wrapped by `py_func`.

  Args:
    objects: Values forwarded to `py_func.wrap_py_func` as the arguments of
      the wrapped printer.
    kwargs: Mapping of print keyword arguments. Entries whose value is the
      `UNSPECIFIED` sentinel are dropped; `flush` is set to True when it is
      not otherwise supplied.

  Returns:
    The result of `py_func.wrap_py_func` called with `use_dummy_return=True`.
  """
  print_kwargs = {}
  for name, value in kwargs.items():
    if value is not UNSPECIFIED:
      print_kwargs[name] = value
  # Defaulting to flushing the console in graph mode, which helps reduce
  # garbled output in IPython.
  print_kwargs.setdefault('flush', True)

  def print_wrapper(*vals):
    # First pass: replace TF-typed values with the result of `.numpy()`.
    unwrapped = []
    for val in vals:
      if tensor_util.is_tf_type(val):
        val = val.numpy()
      unwrapped.append(val)
    # Second pass: TensorFlow doesn't seem to generate Unicode when passing
    # strings to py_func, which would make print emit a "b'" wrapper, so
    # bytes values are decoded first.
    printable = []
    for val in unwrapped:
      if isinstance(val, bytes):
        val = val.decode('utf-8')
      printable.append(val)
    six.print_(*printable, **print_kwargs)

  return py_func.wrap_py_func(
      print_wrapper, None, objects, use_dummy_return=True)
Beispiel #9
0
def _tf_py_func_print(objects, kwargs):
  """Prints `objects` through a Python function wrapped by `py_func`.

  Args:
    objects: Values forwarded to `py_func.wrap_py_func` as the arguments of
      the wrapped printer.
    kwargs: Mapping of print keyword arguments. Entries whose value is the
      `UNDEFINED` sentinel are dropped; `flush` is set to True when it is
      not otherwise supplied.

  Returns:
    The result of `py_func.wrap_py_func` called with `use_dummy_return=True`.
  """
  print_kwargs = {}
  for name, value in kwargs.items():
    if value is not UNDEFINED:
      print_kwargs[name] = value
  # Defaulting to flushing the console in graph mode, which helps reduce
  # garbled output in IPython.
  print_kwargs.setdefault('flush', True)

  def print_wrapper(*vals):
    if six.PY3:
      # TensorFlow doesn't seem to generate Unicode when passing strings to
      # py_func, which would make print emit a "b'" wrapper, so bytes
      # values are decoded first.
      decoded = []
      for val in vals:
        if isinstance(val, bytes):
          val = val.decode('utf-8')
        decoded.append(val)
      vals = tuple(decoded)
    six.print_(*vals, **print_kwargs)

  return py_func.wrap_py_func(
      print_wrapper, None, objects, use_dummy_return=True)
Beispiel #10
0
 def py_func_wrapper(*args, **kwargs):
   # Forwards the call to py_func.wrap_py_func using the enclosing scope's
   # `f` and `return_dtypes`. Keyword arguments are rejected up front.
   # TODO(mdan): Add support for kwargs.
   if kwargs:
     raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
   # A dummy return is requested whenever no return dtypes were given.
   needs_dummy_return = not return_dtypes
   return py_func.wrap_py_func(
       f, return_dtypes, args, kwargs, use_dummy_return=needs_dummy_return)