示例#1
0
    def test_wrap_py_func_simple(self):
        """wrap_py_func sums plain ints and int constant tensors to 3."""

        def test_fn(a, b, c):
            return a + b + c

        with self.test_session() as sess:

            def assert_sums_to_three(call_args):
                # Build the wrapped op and evaluate it in the shared session.
                wrapped = py_func.wrap_py_func(test_fn, dtypes.int64, call_args)
                self.assertEqual(3, sess.run(wrapped))

            assert_sums_to_three((1, constant_op.constant(1), 1))
            assert_sums_to_three((1, 1, 1))
            assert_sums_to_three(
                (constant_op.constant(1), 1, constant_op.constant(1)))
示例#2
0
    def test_wrap_py_func_complex_args(self):
        """A non-tensor Python object can be passed through wrap_py_func."""

        class TestClass(object):
            def __init__(self):
                self.foo = 5

        def test_fn(a, b):
            return a * b.foo

        with self.test_session() as sess:
            # The first argument is created lazily so that the constant op is
            # built at the same point as in a straight-line version.
            for make_scalar in (lambda: 7, lambda: constant_op.constant(7)):
                result = py_func.wrap_py_func(
                    test_fn, dtypes.int64, (make_scalar(), TestClass()))
                self.assertEqual(35, sess.run(result))
示例#3
0
  def test_wrap_py_func_dummy_return(self):
    """With use_dummy_return the op yields 1 and test_fn runs once per run."""
    calls = [0]

    def test_fn(_):
      calls[0] += 1

    with self.test_session() as sess:

      def run_and_check(arg, expected_calls):
        op = py_func.wrap_py_func(test_fn, None, (arg,), use_dummy_return=True)
        self.assertEqual(1, sess.run(op))
        # The side effect proves test_fn actually executed.
        self.assertEqual([expected_calls], calls)

      run_and_check(5, 1)
      run_and_check(constant_op.constant(5), 2)
示例#4
0
  def test_wrap_py_func_simple(self):
    """Mixed int / constant-tensor argument tuples all evaluate to 3."""

    def test_fn(a, b, c):
      return a + b + c

    # Factories defer constant creation until inside the session block.
    arg_factories = (
        lambda: (1, constant_op.constant(1), 1),
        lambda: (1, 1, 1),
        lambda: (constant_op.constant(1), 1, constant_op.constant(1)),
    )
    with self.test_session() as sess:
      for make_args in arg_factories:
        result = py_func.wrap_py_func(test_fn, dtypes.int64, make_args())
        self.assertEqual(3, sess.run(result))
示例#5
0
def dynamic_print(*values):
  """Prints `values`, choosing tf.Print or a py_func fallback.

  When every value passes `is_tf_print_compatible`, the print is emitted as a
  `logging_ops.Print` op. Otherwise the values are printed from Python through
  `py_func.wrap_py_func`.

  Args:
    *values: values to print
  Returns:
    The result of `logging_ops.Print(1, values)` when all values are
    tf.Print-compatible; otherwise the dummy value produced by
    `py_func.wrap_py_func(..., use_dummy_return=True)`.
  """

  if all(map(is_tf_print_compatible, values)):
    return logging_ops.Print(1, values)

  def as_text(v):
    # TensorFlow doesn't seem to generate Unicode when passing strings to
    # py_func; decoding avoids a "b'" wrapper in the printed output.
    return v.decode() if isinstance(v, bytes) else v

  def print_wrapper(*vals):
    if six.PY3:
      vals = tuple(map(as_text, vals))
    print(*vals)
    # The flush helps avoid garbled output in IPython.
    sys.stdout.flush()

  return py_func.wrap_py_func(
      print_wrapper, None, values, use_dummy_return=True)
示例#6
0
 def py_func_wrapper(*args, **kwargs):
   """Calls `f` via py_func.wrap_py_func; keyword arguments are rejected.

   A dummy return is requested whenever `return_dtypes` is falsy.
   """
   # TODO(mdan): Add support for kwargs.
   if not kwargs:
     return py_func.wrap_py_func(
         f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)
   raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
示例#7
0
  def test_wrap_py_func_complex_args(self):
    """A plain Python object argument reaches test_fn unchanged."""

    class TestClass(object):

      def __init__(self):
        self.foo = 5

    def test_fn(a, b):
      return a * b.foo

    with self.test_session() as sess:

      def check_product(first):
        op = py_func.wrap_py_func(test_fn, dtypes.int64, (first, TestClass()))
        self.assertEqual(35, sess.run(op))

      check_product(7)
      check_product(constant_op.constant(7))
示例#8
0
def dynamic_print(*values):
    """Print `values`, dispatching to tf.Print or to a py_func fallback.

    If every value satisfies `is_tf_print_compatible`, a `logging_ops.Print`
    op is returned. Otherwise the values are printed from Python through
    `py_func.wrap_py_func`, using a helper that decodes `bytes` on Python 3,
    prints, and flushes stdout.

    Args:
        *values: values to print
    Returns:
        The `logging_ops.Print(1, values)` result, or the dummy value returned
        by `py_func.wrap_py_func(..., use_dummy_return=True)`.
    """
    tf_print_ok = all(map(is_tf_print_compatible, values))
    if tf_print_ok:
        return logging_ops.Print(1, values)

    def print_wrapper(*vals):
        if six.PY3:
            # TensorFlow doesn't seem to generate Unicode when passing strings
            # to py_func, which would make print add a "b'" wrapper.
            converted = []
            for val in vals:
                if isinstance(val, bytes):
                    val = val.decode()
                converted.append(val)
            vals = tuple(converted)
        print(*vals)
        # The flush helps avoid garbled output in IPython.
        sys.stdout.flush()

    return py_func.wrap_py_func(
        print_wrapper, None, values, use_dummy_return=True)
示例#9
0
 def py_func_wrapper(*args, **kwargs):
   """Forwards positional args to `f` through py_func; kwargs are unsupported."""
   if kwargs:
     # TODO(mdan): Add support for kwargs.
     raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
   wants_dummy = not return_dtypes
   return py_func.wrap_py_func(f, return_dtypes, args, kwargs,
                               use_dummy_return=wants_dummy)
示例#10
0
    def test_wrap_py_func_dummy_return(self):
        """use_dummy_return makes the op yield 1; test_fn runs on each run."""
        side_counter = [0]

        def test_fn(_):
            side_counter[0] += 1

        # (argument factory, expected call count after that run)
        cases = (
            (lambda: 5, 1),
            (lambda: constant_op.constant(5), 2),
        )
        with self.test_session() as sess:
            for make_arg, expected_count in cases:
                result = py_func.wrap_py_func(
                    test_fn, None, (make_arg(),), use_dummy_return=True)
                self.assertEqual(1, sess.run(result))
                self.assertEqual([expected_count], side_counter)
示例#11
0
文件: api.py 项目: imdone/tensorflow
 def py_func_wrapper(*args, **kwargs):
     """Runs `f` through py_func.wrap_py_func; rejects keyword arguments.

     `use_dummy_return` is enabled whenever `return_dtypes` is falsy.
     """
     # TODO(mdan): Add support for kwargs. id:515
     # https://github.com/imdone/tensorflow/issues/516
     if not kwargs:
         return py_func.wrap_py_func(
             f, return_dtypes, args, kwargs,
             use_dummy_return=not return_dtypes)
     raise NotImplementedError(
         'RunMode.PY_FUNC does not yet support kwargs')
示例#12
0
  def test_wrap_py_func_kwargs(self):
    """Keyword arguments, tensor or not, are forwarded to test_fn."""

    class TestClass(object):

      def __init__(self, foo):
        self.foo = foo

    def test_fn(a, b, c, d):
      return a * b.foo + c * d.foo

    with self.test_session() as sess:

      def expect_178(call_args, call_kwargs):
        # 7 * 5 + 11 * 13 == 178
        result = py_func.wrap_py_func(test_fn, dtypes.int64, call_args,
                                      call_kwargs)
        self.assertEqual(178, sess.run(result))

      expect_178((7, TestClass(5)), {'c': 11, 'd': TestClass(13)})
      expect_178((constant_op.constant(7), TestClass(5)),
                 {'c': constant_op.constant(11), 'd': TestClass(13)})
示例#13
0
    def test_wrap_py_func_kwargs(self):
        """wrap_py_func forwards a kwargs dict (plain or tensor values)."""

        class TestClass(object):
            def __init__(self, foo):
                self.foo = foo

        def test_fn(a, b, c, d):
            return a * b.foo + c * d.foo

        # Each factory returns (positional args, keyword args); constants are
        # only created once the session block is entered.
        call_factories = (
            lambda: ((7, TestClass(5)), {'c': 11, 'd': TestClass(13)}),
            lambda: ((constant_op.constant(7), TestClass(5)),
                     {'c': constant_op.constant(11), 'd': TestClass(13)}),
        )
        with self.test_session() as sess:
            for make_call in call_factories:
                positional, keyword = make_call()
                result = py_func.wrap_py_func(test_fn, dtypes.int64,
                                              positional, keyword)
                self.assertEqual(178, sess.run(result))
示例#14
0
def dynamic_print(*values):
    """Implementation of print using dynamic dispatch.

    The function attempts to use tf.Print if all the values are compatible.
    Otherwise, it will fall back to py_func.

    Args:
      *values: values to print
    Returns:
      The result of `logging_ops.Print(1, values)` when every value is
      tf.Print-compatible; otherwise the dummy value returned by
      `py_func.wrap_py_func(..., use_dummy_return=True)`.
    """
    import sys  # Local import keeps this block self-contained.

    if all(map(is_tf_print_compatible, values)):
        return logging_ops.Print(1, values)

    def print_wrapper(*vals):
        if sys.version_info[0] >= 3:
            # String values may reach the py_func callback as `bytes`; printing
            # them directly on Python 3 adds a "b'...'" wrapper, so decode.
            vals = tuple(
                v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
        print(*vals)
        # Flushing helps avoid garbled output in IPython.
        sys.stdout.flush()

    return py_func.wrap_py_func(
        print_wrapper, None, values, use_dummy_return=True)
示例#15
0
def _tf_py_func_print(objects, kwargs):
  """Overload of print_ that performs the print inside a py_func.

  Args:
    objects: positional values handed to the py_func and printed.
    kwargs: keyword arguments for `six.print_`; entries whose value is
      UNDEFINED are dropped, and `flush` defaults to True when absent.

  Returns:
    The dummy value from `py_func.wrap_py_func(..., use_dummy_return=True)`.
  """
  override_kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED}
  # Defaulting to flushing the console in graph mode, which helps reduce
  # garbled output in IPython.
  override_kwargs.setdefault('flush', True)

  def to_text(v):
    # TensorFlow doesn't seem to generate Unicode when passing strings to
    # py_func; decoding avoids a "b'" wrapper in the printed output.
    return v.decode('utf-8') if isinstance(v, bytes) else v

  def print_wrapper(*vals):
    printable = tuple(map(to_text, vals)) if six.PY3 else vals
    six.print_(*printable, **override_kwargs)

  return py_func.wrap_py_func(
      print_wrapper, None, objects, use_dummy_return=True)
示例#16
0
def dynamic_print(*values):
  """Implementation of print using dynamic dispatch.

  The function attempts to use tf.Print if all the values are compatible.
  Otherwise, it will fall back to py_func.

  Args:
    *values: values to print
  Returns:
    The result of `logging_ops.Print(1, values)` when every value is
    tf.Print-compatible; otherwise the dummy value returned by
    `py_func.wrap_py_func(..., use_dummy_return=True)`.
  """

  if all(map(is_tf_print_compatible, values)):
    return logging_ops.Print(1, values)

  def flushed_print(*vals):
    if sys.version_info[0] >= 3:
      # String values may reach the py_func callback as `bytes`; printing
      # them directly on Python 3 adds a "b'...'" wrapper, so decode first.
      vals = tuple(
          v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
    print(*vals)
    # Flushing helps avoid garbled output in IPython.
    sys.stdout.flush()

  return py_func.wrap_py_func(
      flushed_print, None, values, use_dummy_return=True)
示例#17
0
def _tf_py_func_print(objects, kwargs):
    """Overload of print_ that defers the actual printing to a py_func.

    Args:
        objects: positional values handed to the py_func and printed.
        kwargs: keyword arguments for `six.print_`; entries whose value is
            UNDEFINED are skipped, and `flush` defaults to True.

    Returns:
        The dummy value from `py_func.wrap_py_func(..., use_dummy_return=True)`.
    """
    override_kwargs = {}
    for name, value in kwargs.items():
        if value is not UNDEFINED:
            override_kwargs[name] = value
    # Defaulting to flushing the console in graph mode, which helps reduce
    # garbled output in IPython.
    override_kwargs.setdefault('flush', True)

    def print_wrapper(*vals):
        if six.PY3:
            # TensorFlow doesn't seem to generate Unicode when passing strings
            # to py_func, which would make print add a "b'" wrapper.
            decoded = []
            for val in vals:
                if isinstance(val, bytes):
                    val = val.decode('utf-8')
                decoded.append(val)
            vals = tuple(decoded)
        six.print_(*vals, **override_kwargs)

    return py_func.wrap_py_func(
        print_wrapper, None, objects, use_dummy_return=True)