def test_dynamic_range_all_python(self):
  """dynamic_builtin(range, ...) with only Python ints yields the expected lists."""
  expectations = (
      ((3,), [0, 1, 2]),
      ((1, 3), [1, 2]),
      ((2, 0, -1), [2, 1]),
  )
  for range_args, expected in expectations:
    self.assertListEqual(
        list(builtins.dynamic_builtin(range, *range_args)), expected)
def test_dynamic_range_tf(self):
  """dynamic_builtin(range, ...) accepts bounds/step given as tensors."""
  with self.test_session() as sess:

    def evaluate_range(*range_args):
      # Builds the range op and evaluates it in the enclosing session.
      return sess.run(builtins.dynamic_builtin(range, *range_args))

    self.assertAllEqual(evaluate_range(constant_op.constant(3)), [0, 1, 2])
    self.assertAllEqual(evaluate_range(1, constant_op.constant(3)), [1, 2])
    self.assertAllEqual(
        evaluate_range(2, 0, constant_op.constant(-1)), [2, 1])
def test_dynamic_range_tf(self):
  """dynamic_builtin(range, ...) accepts bounds/step given as tensors."""
  with self.test_session() as sess:

    def evaluate_range(*range_args):
      # Builds the range op and evaluates it in the enclosing session.
      return sess.run(builtins.dynamic_builtin(range, *range_args))

    self.assertAllEqual(evaluate_range(constant_op.constant(3)), [0, 1, 2])
    self.assertAllEqual(evaluate_range(1, constant_op.constant(3)), [1, 2])
    self.assertAllEqual(
        evaluate_range(2, 0, constant_op.constant(-1)), [2, 1])
def test_dynamic_range_detection(self):
  """Only the genuine range builtins are treated as range."""

  def range(x):  # pylint:disable=redefined-builtin
    return x

  # Functions that just have the names of builtins are ignored.
  passthrough_result = builtins.dynamic_builtin(range, 1)
  self.assertEqual(passthrough_result, 1)

  # xrange only exists as a builtin on Python 2.
  if six.PY2:
    self.assertListEqual(list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2])
  for range_fn in (six.moves.range, six.moves.xrange):
    self.assertListEqual(
        list(builtins.dynamic_builtin(range_fn, 3)), [0, 1, 2])
def test_dynamic_range_detection(self):
  """Only the genuine range builtins are treated as range.

  A user function that merely shadows the name `range` must be rejected
  rather than dispatched as the builtin.
  """

  def range(x):  # pylint:disable=redefined-builtin
    return x

  # Functions that just have the names of builtins are rejected. The call
  # itself is expected to raise, so there is no return value to compare.
  with self.assertRaises(NotImplementedError):
    builtins.dynamic_builtin(range, 1)

  # xrange only exists as a builtin on Python 2.
  if six.PY2:
    self.assertListEqual(
        list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2])
  self.assertListEqual(
      list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2])
  self.assertListEqual(
      list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2])
def test_dynamic_range_detection(self):
  """Only the genuine range builtins are treated as range.

  A user function that merely shadows the name `range` must be rejected
  rather than dispatched as the builtin.
  """

  def range(x):  # pylint:disable=redefined-builtin
    return x

  # Functions that just have the names of builtins are rejected. The call
  # itself is expected to raise, so there is no return value to compare.
  with self.assertRaises(NotImplementedError):
    builtins.dynamic_builtin(range, 1)

  # xrange only exists as a builtin on Python 2.
  if six.PY2:
    self.assertListEqual(list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2])
  self.assertListEqual(
      list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2])
  self.assertListEqual(
      list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2])
def test_dynamic_len_tf_scalar(self):
  """len() of a scalar tensor raises ValueError with a descriptive message."""
  a = constant_op.constant(1)
  # six.assertRaisesRegex resolves to assertRaisesRegexp on Python 2 and to
  # assertRaisesRegex on Python 3, where the "Regexp" alias is deprecated
  # (and removed in Python 3.12).
  with six.assertRaisesRegex(self, ValueError,
                             'len requires non-zero rank for tensor.*'):
    with self.test_session() as sess:
      sess.run(builtins.dynamic_builtin(len, a))
def test_dynamic_len_tf_scalar(self):
  """len() of a scalar tensor raises ValueError with a descriptive message."""
  a = constant_op.constant(1)
  # six.assertRaisesRegex resolves to assertRaisesRegexp on Python 2 and to
  # assertRaisesRegex on Python 3, where the "Regexp" alias is deprecated
  # (and removed in Python 3.12).
  with six.assertRaisesRegex(
      self, ValueError, 'len requires non-zero rank for tensor.*'):
    with self.test_session() as sess:
      sess.run(builtins.dynamic_builtin(len, a))
def test_dynamic_abs_tf_array(self):
  """abs() over a tensor is applied element-wise."""
  values = constant_op.constant([-1, 2, -3])
  with self.test_session() as sess:
    absolute = sess.run(builtins.dynamic_builtin(abs, values))
    self.assertListEqual([1, 2, 3], list(absolute))
def test_casts(self):
  """int()/float() casts on tensors set the dtype; on bools they act as Python."""
  int_tensor = constant_op.constant(2, dtype=dtypes.int32)
  float_tensor = constant_op.constant(1.0, dtype=dtypes.float32)

  # Tensor inputs: the result dtype is determined by the cast target.
  tensor_cases = ((int, dtypes.int32), (float, dtypes.float32))
  for cast_fn, expected_dtype in tensor_cases:
    for source in (int_tensor, float_tensor):
      self.assertEqual(
          builtins.dynamic_builtin(cast_fn, source).dtype, expected_dtype)

  # Plain Python inputs produce plain Python values.
  python_cases = (
      (int, True, 1),
      (int, False, 0),
      (float, True, 1.0),
      (float, False, 0.0),
  )
  for cast_fn, value, expected in python_cases:
    self.assertEqual(builtins.dynamic_builtin(cast_fn, value), expected)
def test_casts(self):
  """int()/float() casts on tensors set the dtype; on bools they act as Python."""
  int_tensor = constant_op.constant(2, dtype=dtypes.int32)
  float_tensor = constant_op.constant(1.0, dtype=dtypes.float32)

  # Tensor inputs: the result dtype is determined by the cast target.
  tensor_cases = ((int, dtypes.int32), (float, dtypes.float32))
  for cast_fn, expected_dtype in tensor_cases:
    for source in (int_tensor, float_tensor):
      self.assertEqual(
          builtins.dynamic_builtin(cast_fn, source).dtype, expected_dtype)

  # Plain Python inputs produce plain Python values.
  python_cases = (
      (int, True, 1),
      (int, False, 0),
      (float, True, 1.0),
      (float, False, 0.0),
  )
  for cast_fn, value, expected in python_cases:
    self.assertEqual(builtins.dynamic_builtin(cast_fn, value), expected)
def test_dynamic_abs_py_scalar(self):
  """abs() of a plain Python int returns its absolute value."""
  self.assertEqual(1, builtins.dynamic_builtin(abs, -1))
def test_dynamic_range_all_python(self):
  """dynamic_builtin(range, ...) with only Python ints yields the expected lists."""
  expectations = (
      ((3,), [0, 1, 2]),
      ((1, 3), [1, 2]),
      ((2, 0, -1), [2, 1]),
  )
  for range_args, expected in expectations:
    self.assertListEqual(
        list(builtins.dynamic_builtin(range, *range_args)), expected)
def test_dynamic_len_py_list(self):
  """len() of a plain Python list returns its element count."""
  five_threes = [3, 3, 3, 3, 3]
  self.assertEqual(5, builtins.dynamic_builtin(len, five_threes))
def test_dynamic_len_tf_matrix(self):
  """len() of a 2x2 tensor evaluates to the size of its first dimension."""
  matrix = constant_op.constant([[1, 2], [3, 4]])
  with self.test_session() as sess:
    len_op = builtins.dynamic_builtin(len, matrix)
    self.assertEqual(2, sess.run(len_op))
def test_dynamic_abs_py_scalar(self):
  """abs() of a plain Python int returns its absolute value."""
  self.assertEqual(1, builtins.dynamic_builtin(abs, -1))
def test_dynamic_abs_tf_array(self):
  """abs() over a tensor is applied element-wise."""
  values = constant_op.constant([-1, 2, -3])
  with self.test_session() as sess:
    absolute = sess.run(builtins.dynamic_builtin(abs, values))
    self.assertListEqual([1, 2, 3], list(absolute))
def test_dynamic_abs_tf_scalar(self):
  """abs() of a scalar tensor evaluates to its absolute value."""
  negative_one = constant_op.constant(-1)
  with self.test_session() as sess:
    abs_op = builtins.dynamic_builtin(abs, negative_one)
    self.assertEqual(1, sess.run(abs_op))
def test_dynamic_len_tf_array(self):
  """len() of a 1-D tensor evaluates to its element count."""
  vector = constant_op.constant([1, 2, 3])
  with self.test_session() as sess:
    len_op = builtins.dynamic_builtin(len, vector)
    self.assertEqual(3, sess.run(len_op))
def test_dynamic_len_tf_scalar(self):
  """len() of a scalar tensor raises ValueError."""
  scalar = constant_op.constant(1)
  with self.assertRaises(ValueError), self.test_session() as sess:
    len_op = builtins.dynamic_builtin(len, scalar)
    sess.run(len_op)
def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
  """Compiles a function call inline.

  Whitelisted callables are called directly and builtins are dispatched to
  `builtins.dynamic_builtin`. Any other callable is converted with `to_graph`
  and the converted entity is called with the original arguments.

  Args:
    f: The callable being invoked: a function, method, class, or an object
      that has a `__call__` attribute.
    recursive: Forwarded to `to_graph`.
    verbose: Forwarded to `to_graph`.
    arg_types: Dict mapping argument names to `(type_name, type)` tuples.
      Names not already present are filled in from the runtime argument
      values; this dict is updated in place.
    *args: Positional arguments for `f`.
    **kwargs: Keyword arguments for `f`.

  Returns:
    Whatever `f` (or its converted counterpart) returns for these arguments.

  Raises:
    NotImplementedError: If `f` is not a function, method, class, or
      callable object.
  """
  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.
  if conversion.is_whitelisted_for_graph(f):
    return f(*args, **kwargs)

  if inspect_utils.isbuiltin(f):
    return builtins.dynamic_builtin(f, *args, **kwargs)

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    effective_args = args
    f_class = inspect_utils.getmethodclass(f)

    if f_class is not None:
      partial_types = (f_class,)
    else:
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects
    target_entity = f.__call__
    arg_map_target = f.__call__
    # The instance is passed explicitly as the first argument; presumably the
    # converted __call__ is not bound to `f`.
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    # Must raise here: falling through would reference the unassigned
    # `arg_map_target` below and fail with an unrelated UnboundLocalError.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
  for name, arg in arg_values.items():
    arg_class = arg.__class__
    # Entries already present in arg_types take precedence over the type
    # inferred from the runtime value.
    if name not in arg_types:
      arg_types[name] = (arg_class.__name__, arg_class)

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=recursive,
      verbose=verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types)
  return converted_f(*effective_args, **kwargs)
def test_dynamic_len_py_list(self):
  """len() of a plain Python list returns its element count."""
  five_threes = [3, 3, 3, 3, 3]
  self.assertEqual(5, builtins.dynamic_builtin(len, five_threes))
def test_dynamic_len_tf_matrix(self):
  """len() of a 2x2 tensor evaluates to the size of its first dimension."""
  matrix = constant_op.constant([[1, 2], [3, 4]])
  with self.test_session() as sess:
    len_op = builtins.dynamic_builtin(len, matrix)
    self.assertEqual(2, sess.run(len_op))
def test_dynamic_len_tf_array(self):
  """len() of a 1-D tensor evaluates to its element count."""
  vector = constant_op.constant([1, 2, 3])
  with self.test_session() as sess:
    len_op = builtins.dynamic_builtin(len, vector)
    self.assertEqual(3, sess.run(len_op))
def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
  """Compiles a function call inline.

  Whitelisted callables are called directly and builtins are dispatched to
  `builtins.dynamic_builtin`. Any other callable is converted with `to_graph`
  and the converted entity is called with the original arguments.

  Args:
    f: The callable being invoked: a function, method, class, or an object
      that has a `__call__` attribute.
    recursive: Forwarded to `to_graph`.
    verbose: Forwarded to `to_graph`.
    arg_types: Dict mapping argument names to `(type_name, type)` tuples.
      Names not already present are filled in from the runtime argument
      values; this dict is updated in place.
    *args: Positional arguments for `f`.
    **kwargs: Keyword arguments for `f`.

  Returns:
    Whatever `f` (or its converted counterpart) returns for these arguments.

  Raises:
    NotImplementedError: If `f` is not a function, method, class, or
      callable object.
  """
  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.
  if conversion.is_whitelisted_for_graph(f):
    return f(*args, **kwargs)

  if inspect_utils.isbuiltin(f):
    return builtins.dynamic_builtin(f, *args, **kwargs)

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    effective_args = args
    f_class = inspect_utils.getmethodclass(f)

    if f_class is not None:
      partial_types = (f_class,)
    else:
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects
    target_entity = f.__call__
    arg_map_target = f.__call__
    # The instance is passed explicitly as the first argument; presumably the
    # converted __call__ is not bound to `f`.
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    # Must raise here: falling through would reference the unassigned
    # `arg_map_target` below and fail with an unrelated UnboundLocalError.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
  for name, arg in arg_values.items():
    arg_class = arg.__class__
    # Entries already present in arg_types take precedence over the type
    # inferred from the runtime value.
    if name not in arg_types:
      arg_types[name] = (arg_class.__name__, arg_class)

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=recursive,
      verbose=verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types)
  return converted_f(*effective_args, **kwargs)
def test_dynamic_abs_tf_scalar(self):
  """abs() of a scalar tensor evaluates to its absolute value."""
  negative_one = constant_op.constant(-1)
  with self.test_session() as sess:
    abs_op = builtins.dynamic_builtin(abs, negative_one)
    self.assertEqual(1, sess.run(abs_op))