Beispiel #1
0
 def test_dynamic_range_all_python(self):
     """dynamic_builtin(range, ...) with plain Python ints matches range()."""
     cases = (
         ((3,), [0, 1, 2]),
         ((1, 3), [1, 2]),
         ((2, 0, -1), [2, 1]),
     )
     for range_args, expected in cases:
         self.assertListEqual(
             list(builtins.dynamic_builtin(range, *range_args)), expected)
 def test_dynamic_range_tf(self):
   """dynamic_builtin(range, ...) evaluates correctly with a Tensor argument."""
   with self.test_session() as sess:
     stop_only = builtins.dynamic_builtin(range, constant_op.constant(3))
     self.assertAllEqual(sess.run(stop_only), [0, 1, 2])

     start_and_stop = builtins.dynamic_builtin(
         range, 1, constant_op.constant(3))
     self.assertAllEqual(sess.run(start_and_stop), [1, 2])

     negative_step = builtins.dynamic_builtin(
         range, 2, 0, constant_op.constant(-1))
     self.assertAllEqual(sess.run(negative_step), [2, 1])
Beispiel #3
0
 def test_dynamic_range_tf(self):
   """Checks dynamic_builtin(range, ...) when the stop or step is a Tensor."""
   with self.test_session() as sess:
     three_elements = builtins.dynamic_builtin(range, constant_op.constant(3))
     self.assertAllEqual(sess.run(three_elements), [0, 1, 2])

     from_one = builtins.dynamic_builtin(range, 1, constant_op.constant(3))
     self.assertAllEqual(sess.run(from_one), [1, 2])

     counting_down = builtins.dynamic_builtin(
         range, 2, 0, constant_op.constant(-1))
     self.assertAllEqual(sess.run(counting_down), [2, 1])
  def test_dynamic_range_detection(self):
    """A user function named `range` is called as-is; real ranges still work."""

    def range(x):  # pylint:disable=redefined-builtin
      return x

    # Functions that just have the names of builtins are ignored.
    self.assertEqual(builtins.dynamic_builtin(range, 1), 1)

    expected = [0, 1, 2]
    if six.PY2:
      self.assertListEqual(list(builtins.dynamic_builtin(xrange, 3)), expected)
    for range_impl in (six.moves.range, six.moves.xrange):
      self.assertListEqual(
          list(builtins.dynamic_builtin(range_impl, 3)), expected)
  def test_dynamic_range_detection(self):
    """Rejects look-alike functions but accepts the real range builtins."""

    def range(x):  # pylint:disable=redefined-builtin
      return x

    # Functions that just have the names of builtins are rejected. Only the
    # call goes inside assertRaises: an assertion wrapped around it would run
    # only when the expected exception is missing, which is misleading.
    with self.assertRaises(NotImplementedError):
      builtins.dynamic_builtin(range, 1)
    if six.PY2:
      self.assertListEqual(
          list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2])
    self.assertListEqual(
        list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2])
    self.assertListEqual(
        list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2])
Beispiel #6
0
    def test_dynamic_range_detection(self):
        """Rejects look-alike functions but accepts the real range builtins."""

        def range(x):  # pylint:disable=redefined-builtin
            return x

        # Functions that just have the names of builtins are rejected. Only the
        # call goes inside assertRaises: an assertion wrapped around it would
        # run only when the expected exception is missing, which is misleading.
        with self.assertRaises(NotImplementedError):
            builtins.dynamic_builtin(range, 1)
        if six.PY2:
            self.assertListEqual(list(builtins.dynamic_builtin(xrange, 3)),
                                 [0, 1, 2])
        self.assertListEqual(
            list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2])
        self.assertListEqual(
            list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2])
  def test_dynamic_len_tf_scalar(self):
    """len() on a rank-0 Tensor fails with a ValueError about its rank."""
    scalar = constant_op.constant(1)
    expected_message = 'len requires non-zero rank for tensor.*'

    with self.assertRaisesRegexp(ValueError, expected_message), \
        self.test_session() as sess:
      sess.run(builtins.dynamic_builtin(len, scalar))
Beispiel #8
0
    def test_dynamic_len_tf_scalar(self):
        """len() of a scalar Tensor raises ValueError mentioning the rank."""
        rank_zero = constant_op.constant(1)
        pattern = 'len requires non-zero rank for tensor.*'

        with self.assertRaisesRegexp(ValueError, pattern):
            with self.test_session() as session:
                session.run(builtins.dynamic_builtin(len, rank_zero))
Beispiel #9
0
    def test_dynamic_abs_tf_array(self):
        """abs() on a 1-D int Tensor gives element-wise absolute values."""
        values = constant_op.constant([-1, 2, -3])

        with self.test_session() as sess:
            absolute = sess.run(builtins.dynamic_builtin(abs, values))
            self.assertListEqual([1, 2, 3], list(absolute))
Beispiel #10
0
  def test_casts(self):
    """int()/float() give int32/float32 Tensors and convert Python bools."""
    int_tensor = constant_op.constant(2, dtype=dtypes.int32)
    float_tensor = constant_op.constant(1.0, dtype=dtypes.float32)

    # Cast first, operand second, so ops are built in the original order.
    for cast, expected_dtype in ((int, dtypes.int32), (float, dtypes.float32)):
      for operand in (int_tensor, float_tensor):
        self.assertEqual(
            builtins.dynamic_builtin(cast, operand).dtype, expected_dtype)

    for cast, flag, expected in ((int, True, 1), (int, False, 0),
                                 (float, True, 1.0), (float, False, 0.0)):
      self.assertEqual(builtins.dynamic_builtin(cast, flag), expected)
Beispiel #11
0
  def test_casts(self):
    """Casting builtins produce the expected Tensor dtypes and Python values."""
    i32 = constant_op.constant(2, dtype=dtypes.int32)
    f32 = constant_op.constant(1.0, dtype=dtypes.float32)

    tensor_cases = [
        (int, i32, dtypes.int32),
        (int, f32, dtypes.int32),
        (float, i32, dtypes.float32),
        (float, f32, dtypes.float32),
    ]
    for cast, tensor, expected_dtype in tensor_cases:
      converted = builtins.dynamic_builtin(cast, tensor)
      self.assertEqual(converted.dtype, expected_dtype)

    bool_cases = [
        (int, True, 1),
        (int, False, 0),
        (float, True, 1.0),
        (float, False, 0.0),
    ]
    for cast, flag, expected in bool_cases:
      self.assertEqual(builtins.dynamic_builtin(cast, flag), expected)
Beispiel #12
0
 def test_dynamic_abs_py_scalar(self):
     """abs() on a plain negative Python int returns its magnitude."""
     self.assertEqual(1, builtins.dynamic_builtin(abs, -1))
Beispiel #13
0
 def test_dynamic_range_all_python(self):
   """With only Python int arguments, dynamic_builtin(range) matches range."""
   for range_args, expected in [((3,), [0, 1, 2]), ((1, 3), [1, 2]),
                                ((2, 0, -1), [2, 1])]:
     actual = list(builtins.dynamic_builtin(range, *range_args))
     self.assertListEqual(actual, expected)
Beispiel #14
0
  def test_dynamic_len_py_list(self):
    """len() on a five-element Python list returns 5."""
    five_threes = [3] * 5
    self.assertEqual(5, builtins.dynamic_builtin(len, five_threes))
Beispiel #15
0
  def test_dynamic_len_tf_matrix(self):
    """len() of a 2x2 Tensor evaluates to 2."""
    matrix = constant_op.constant([[1, 2], [3, 4]])

    with self.test_session() as sess:
      length = sess.run(builtins.dynamic_builtin(len, matrix))
      self.assertEqual(2, length)
Beispiel #16
0
 def test_dynamic_abs_py_scalar(self):
   """abs() on a negative Python int behaves like the builtin."""
   result = builtins.dynamic_builtin(abs, -1)
   self.assertEqual(1, result)
Beispiel #17
0
  def test_dynamic_abs_tf_array(self):
    """abs() on a 1-D Tensor yields element-wise absolute values."""
    signed = constant_op.constant([-1, 2, -3])

    with self.test_session() as sess:
      magnitudes = sess.run(builtins.dynamic_builtin(abs, signed))
      self.assertListEqual([1, 2, 3], list(magnitudes))
Beispiel #18
0
  def test_dynamic_abs_tf_scalar(self):
    """abs() on a scalar Tensor evaluates to its absolute value."""
    negative_one = constant_op.constant(-1)

    with self.test_session() as sess:
      evaluated = sess.run(builtins.dynamic_builtin(abs, negative_one))
      self.assertEqual(1, evaluated)
Beispiel #19
0
  def test_dynamic_len_tf_array(self):
    """len() of a three-element 1-D Tensor evaluates to 3."""
    vector = constant_op.constant([1, 2, 3])

    with self.test_session() as sess:
      length = sess.run(builtins.dynamic_builtin(len, vector))
      self.assertEqual(3, length)
Beispiel #20
0
  def test_dynamic_len_tf_scalar(self):
    """len() of a rank-0 Tensor raises ValueError."""
    scalar = constant_op.constant(1)

    with self.assertRaises(ValueError), self.test_session() as sess:
      sess.run(builtins.dynamic_builtin(len, scalar))
Beispiel #21
0
def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
  """Compiles a function call inline.

  Whitelisted callables are invoked directly, and builtins are routed through
  `builtins.dynamic_builtin`. Any other callable is converted with `to_graph`,
  and the converted version is called with the original arguments.

  Args:
    f: The callable to invoke: a function, method, class, or object that
      defines `__call__`.
    recursive: Passed through to `to_graph`.
    verbose: Passed through to `to_graph`.
    arg_types: Dict mapping argument names to `(type_name, type)` tuples.
      Entries already present take precedence. Missing entries are filled in,
      in place, from the runtime class of each argument.
    *args: Positional arguments for `f`.
    **kwargs: Keyword arguments for `f`.

  Returns:
    The result of calling `f`, or its converted version, with the arguments.

  Raises:
    NotImplementedError: If `f` is not a recognized kind of callable.
  """
  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.

  if conversion.is_whitelisted_for_graph(f):
    return f(*args, **kwargs)

  # Sentinel for arguments of unknown value.
  # NOTE(review): nothing in this function produces this sentinel, so the
  # `is unknown_arg_value` check below never skips an argument yet.
  unknown_arg_value = object()

  if inspect_utils.isbuiltin(f):
    return builtins.dynamic_builtin(f, *args, **kwargs)

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    effective_args = args
    f_class = inspect_utils.getmethodclass(f)

    if f_class is not None:
      partial_types = (f_class,)
    else:
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors: arguments are matched against the __init__ signature.
    # NOTE(review): `f.__init__` is read from the class, so it may be unbound
    # here; confirm `tf_inspect.getcallargs` accounts for `self`.
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects: convert `__call__` and pass the object itself as the
    # first positional argument of the converted function.
    target_entity = f.__call__
    arg_map_target = f.__call__
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    # Raise explicitly. Otherwise execution continues with the locals above
    # unbound and fails later with an unrelated UnboundLocalError.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
  for name, arg in arg_values.items():
    if arg is unknown_arg_value:
      continue
    arg_class = arg.__class__
    # Types already provided in arg_types take precedence over inferred ones.
    if name not in arg_types:
      arg_types[name] = (arg_class.__name__, arg_class)

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=recursive,
      verbose=verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types)
  return converted_f(*effective_args, **kwargs)
Beispiel #22
0
    def test_dynamic_len_py_list(self):
        """len() on a Python list of five items returns 5."""
        items = [3] * 5
        length = builtins.dynamic_builtin(len, items)
        self.assertEqual(5, length)
Beispiel #23
0
    def test_dynamic_len_tf_matrix(self):
        """len() of a 2x2 Tensor evaluates to 2."""
        square = constant_op.constant([[1, 2], [3, 4]])

        with self.test_session() as session:
            row_count = session.run(builtins.dynamic_builtin(len, square))
            self.assertEqual(2, row_count)
Beispiel #24
0
    def test_dynamic_len_tf_array(self):
        """len() of a three-element Tensor evaluates to 3."""
        triple = constant_op.constant([1, 2, 3])

        with self.test_session() as session:
            count = session.run(builtins.dynamic_builtin(len, triple))
            self.assertEqual(3, count)
Beispiel #25
0
def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
  """Compiles a function call inline.

  Whitelisted callables are invoked directly, and builtins are routed through
  `builtins.dynamic_builtin`. Any other callable is converted with `to_graph`,
  and the converted version is called with the original arguments.

  Args:
    f: The callable to invoke: a function, method, class, or object that
      defines `__call__`.
    recursive: Passed through to `to_graph`.
    verbose: Passed through to `to_graph`.
    arg_types: Dict mapping argument names to `(type_name, type)` tuples.
      Entries already present take precedence. Missing entries are filled in,
      in place, from the runtime class of each argument.
    *args: Positional arguments for `f`.
    **kwargs: Keyword arguments for `f`.

  Returns:
    The result of calling `f`, or its converted version, with the arguments.

  Raises:
    NotImplementedError: If `f` is not a recognized kind of callable.
  """
  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.

  if conversion.is_whitelisted_for_graph(f):
    return f(*args, **kwargs)

  # Sentinel for arguments of unknown value.
  # NOTE(review): nothing in this function produces this sentinel, so the
  # `is unknown_arg_value` check below never skips an argument yet.
  unknown_arg_value = object()

  if inspect_utils.isbuiltin(f):
    return builtins.dynamic_builtin(f, *args, **kwargs)

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    effective_args = args
    f_class = inspect_utils.getmethodclass(f)

    if f_class is not None:
      partial_types = (f_class,)
    else:
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors: arguments are matched against the __init__ signature.
    # NOTE(review): `f.__init__` is read from the class, so it may be unbound
    # here; confirm `tf_inspect.getcallargs` accounts for `self`.
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects: convert `__call__` and pass the object itself as the
    # first positional argument of the converted function.
    target_entity = f.__call__
    arg_map_target = f.__call__
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    # Raise explicitly. Otherwise execution continues with the locals above
    # unbound and fails later with an unrelated UnboundLocalError.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
  for name, arg in arg_values.items():
    if arg is unknown_arg_value:
      continue
    arg_class = arg.__class__
    # Types already provided in arg_types take precedence over inferred ones.
    if name not in arg_types:
      arg_types[name] = (arg_class.__name__, arg_class)

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=recursive,
      verbose=verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types)
  return converted_f(*effective_args, **kwargs)
Beispiel #26
0
    def test_dynamic_abs_tf_scalar(self):
        """abs() on a negative scalar Tensor evaluates to 1."""
        minus_one = constant_op.constant(-1)

        with self.test_session() as session:
            magnitude = session.run(builtins.dynamic_builtin(abs, minus_one))
            self.assertEqual(1, magnitude)