def test_wrap_py_func_simple(self):
  """Checks wrap_py_func with positional args mixing Python ints and tensors.

  The wrapped function sums its three arguments; the result is declared as
  int64 and must equal 3 regardless of which arguments are tensors.
  """

  def test_fn(a, b, c):
    return a + b + c

  # Results are run via self.evaluate(), so the session object is not bound.
  with self.cached_session():
    result = py_func.wrap_py_func(test_fn, dtypes.int64,
                                  (1, constant_op.constant(1), 1))
    self.assertEqual(3, self.evaluate(result))
    result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
    self.assertEqual(3, self.evaluate(result))
    result = py_func.wrap_py_func(
        test_fn, dtypes.int64,
        (constant_op.constant(1), 1, constant_op.constant(1)))
    self.assertEqual(3, self.evaluate(result))
def test_wrap_py_func_dummy_return(self):
  """Checks use_dummy_return=True for a wrapped function that returns None.

  The op result must evaluate to 1. `side_counter` records each call of
  test_fn, so it must read [1] after the first evaluation and [2] after the
  second.
  """
  side_counter = [0]

  def test_fn(_):
    side_counter[0] += 1

  # Use self.evaluate() for consistency with the other py_func tests; the
  # session object itself is therefore not bound.
  with self.cached_session():
    result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
    self.assertEqual(1, self.evaluate(result))
    self.assertEqual([1], side_counter)
    result = py_func.wrap_py_func(
        test_fn, None, (constant_op.constant(5),), use_dummy_return=True)
    self.assertEqual(1, self.evaluate(result))
    self.assertEqual([2], side_counter)
def test_wrap_py_func_simple(self):
  """Checks wrap_py_func with positional args mixing Python ints and tensors.

  The wrapped function sums its three arguments; the result is declared as
  int64 and must equal 3 regardless of which arguments are tensors.
  """

  def test_fn(a, b, c):
    return a + b + c

  # Use self.evaluate() for consistency with the other py_func tests; the
  # session object itself is therefore not bound.
  with self.cached_session():
    result = py_func.wrap_py_func(test_fn, dtypes.int64,
                                  (1, constant_op.constant(1), 1))
    self.assertEqual(3, self.evaluate(result))
    result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
    self.assertEqual(3, self.evaluate(result))
    result = py_func.wrap_py_func(
        test_fn, dtypes.int64,
        (constant_op.constant(1), 1, constant_op.constant(1)))
    self.assertEqual(3, self.evaluate(result))
def test_wrap_py_func_complex_args(self):
  """Checks that non-tensor Python objects can be passed through wrap_py_func.

  test_fn reads the `foo` attribute of a plain Python object, so the object
  must reach the wrapped function intact; 7 * 5 == 35.
  """

  class TestClass(object):

    def __init__(self):
      self.foo = 5

  def test_fn(a, b):
    return a * b.foo

  # Results are run via self.evaluate(), so the session object is not bound.
  with self.cached_session():
    result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
    self.assertEqual(35, self.evaluate(result))
    result = py_func.wrap_py_func(
        test_fn, dtypes.int64, (constant_op.constant(7), TestClass()))
    self.assertEqual(35, self.evaluate(result))
def test_wrap_py_func_complex_args(self):
  """Checks that non-tensor Python objects can be passed through wrap_py_func.

  test_fn reads the `foo` attribute of a plain Python object, so the object
  must reach the wrapped function intact; 7 * 5 == 35.
  """

  class TestClass(object):

    def __init__(self):
      self.foo = 5

  def test_fn(a, b):
    return a * b.foo

  # Use self.evaluate() for consistency with the other py_func tests; the
  # session object itself is therefore not bound.
  with self.cached_session():
    result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
    self.assertEqual(35, self.evaluate(result))
    result = py_func.wrap_py_func(
        test_fn, dtypes.int64, (constant_op.constant(7), TestClass()))
    self.assertEqual(35, self.evaluate(result))
def py_func_wrapper(*args, **kwargs):
  """Runs `f` through py_func.wrap_py_func with the given positional args.

  `f` and `return_dtypes` are free names resolved outside this function.
  When `return_dtypes` is falsy, a dummy return value is requested.

  Raises:
    NotImplementedError: if any keyword arguments are supplied.
  """
  # TODO(mdan): Add support for kwargs.
  if not kwargs:
    # kwargs is necessarily empty on this path; it is still forwarded as-is.
    return py_func.wrap_py_func(
        f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)
  raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
def test_wrap_py_func_kwargs(self):
  """Checks that a kwargs dict passed to wrap_py_func reaches the function.

  `c` and `d` are supplied only through the kwargs dict, either as Python
  values or as a tensor plus an object; 7 * 5 + 11 * 13 == 178.
  """

  class TestClass(object):

    def __init__(self, foo):
      self.foo = foo

  def test_fn(a, b, c, d):
    return a * b.foo + c * d.foo

  # Use self.evaluate() for consistency with the other py_func tests; the
  # session object itself is therefore not bound.
  with self.cached_session():
    result = py_func.wrap_py_func(
        test_fn, dtypes.int64, (7, TestClass(5)),
        {'c': 11, 'd': TestClass(13)})
    self.assertEqual(178, self.evaluate(result))
    result = py_func.wrap_py_func(
        test_fn, dtypes.int64, (constant_op.constant(7), TestClass(5)),
        {'c': constant_op.constant(11), 'd': TestClass(13)})
    self.assertEqual(178, self.evaluate(result))
def _tf_py_func_print(objects, kwargs):
  """Overload of print_ that performs the print inside a py_func.

  Args:
    objects: values to print; passed as the positional args of the py_func.
    kwargs: print keyword arguments. Entries whose value is UNSPECIFIED are
      dropped, and `flush` defaults to True when not given.

  Returns:
    The result of py_func.wrap_py_func called with use_dummy_return=True.
  """
  print_kwargs = {}
  for name, value in kwargs.items():
    if value is not UNSPECIFIED:
      print_kwargs[name] = value
  # Defaulting to flushing the console in graph mode, which helps reduce
  # garbled output in IPython.
  print_kwargs.setdefault('flush', True)

  def print_wrapper(*vals):
    # First pass: replace TF values with their numpy form.
    as_numpy = tuple(
        val.numpy() if tensor_util.is_tf_type(val) else val for val in vals)
    # Second pass: TensorFlow doesn't seem to generate Unicode when passing
    # strings to py_func, so bytes would print with a "b'" wrapper; decode
    # them as UTF-8 instead.
    printable = tuple(
        val.decode('utf-8') if isinstance(val, bytes) else val
        for val in as_numpy)
    six.print_(*printable, **print_kwargs)

  return py_func.wrap_py_func(
      print_wrapper, None, objects, use_dummy_return=True)
def _tf_py_func_print(objects, kwargs):
  """Overload of print_ that performs the print inside a py_func.

  Args:
    objects: values to print; passed as the positional args of the py_func.
    kwargs: print keyword arguments. Entries whose value is UNDEFINED are
      dropped, and `flush` defaults to True when not given.

  Returns:
    The result of py_func.wrap_py_func called with use_dummy_return=True.
  """
  override_kwargs = dict(
      (key, val) for key, val in kwargs.items() if val is not UNDEFINED)
  # Defaulting to flushing the console in graph mode, which helps reduce
  # garbled output in IPython.
  override_kwargs.setdefault('flush', True)

  def print_wrapper(*vals):
    printable = vals
    if six.PY3:
      # TensorFlow doesn't seem to generate Unicode when passing strings to
      # py_func, so bytes would print with a "b'" wrapper; decode them.
      printable = tuple(
          item.decode('utf-8') if isinstance(item, bytes) else item
          for item in vals)
    six.print_(*printable, **override_kwargs)

  return py_func.wrap_py_func(
      print_wrapper, None, objects, use_dummy_return=True)
def py_func_wrapper(*args, **kwargs):
  """Forwards positional args to `f` via py_func.wrap_py_func.

  `f` and `return_dtypes` are free names resolved outside this function;
  `use_dummy_return` is set exactly when `return_dtypes` is falsy.

  Raises:
    NotImplementedError: if called with any keyword arguments.
  """
  if kwargs:
    # TODO(mdan): Add support for kwargs.
    raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
  wrap = py_func.wrap_py_func
  return wrap(
      f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)