def test_function_of_function(self):
    """A Function built from another Function still passes checkfn."""
    # no args, no default
    def fn():
        return np.ones((3, 4)) + 2
    f = Function(fn)
    # wrap the already-wrapped Function a second time
    f2 = Function(f)
    self.assertTrue(checkfn(f2))
def test_return_none(self):
    """Functions whose bodies return None can be wrapped in Function."""
    def f1():
        pass

    def f2(x):
        # the expression result is discarded, so f2 also returns None
        x + 1

    # NOTE(review): both Functions are only constructed, never called or
    # asserted on; this test fails only if construction itself raises.
    F1, F2 = Function(f1), Function(f2)
def function(fn=None, **kwargs):
    """
    Wrap a function in an AutoDiff Function instance, converting it to a
    symbolic representation. The function is compiled the first time it is
    called.

    Use as a bare decorator:

        @function
        def python_function(...):
            return do_something()

        python_function(...)  # calls compiled Function

    Or pass keywords through to Function:

        @function(force_floatX=True)
        def python_function(x=1, y=2):
            return do_something()

    Parameters
    ----------
    fn : callable, optional
        The Python function to wrap. If ``fn`` is not callable (for example
        when it is omitted and stays ``None``), a decorator is returned
        instead of a Function.
    **kwargs
        Forwarded unchanged to ``Function``.

    Returns
    -------
    Function or callable
        ``Function(fn, **kwargs)`` when ``fn`` is callable; otherwise a
        one-argument decorator that builds ``Function(pyfn, **kwargs)``.
    """
    if callable(fn):
        return Function(fn, **kwargs)

    # Called as @function(...): return a decorator that closes over kwargs.
    def function_wrapper(pyfn):
        return Function(pyfn, **kwargs)
    return function_wrapper
def test_nested_fn_kwargs_def(self):
    """Trace a helper that is defined inside the function and called by keyword."""
    def g(x):
        def f(x, y):
            return x + y
        return f(y=x, x=x+1) - x ** 2
    self.assertTrue(checkfn(Function(g), 10))
def test_sig_var_args(self):
    """Signature handling for two required args followed by *varargs."""
    # var args, no default
    def fn(x, y, *z):
        return x * y * sum(z)

    compiled = Function(fn)
    # too few positional args, or unknown keywords, must be rejected
    self.assertRaises(TypeError, compiled)
    self.assertRaises(TypeError, compiled, 2)
    self.assertRaises(TypeError, compiled, a=2, b=2)
    for call_args in [(2, 3), (2, 3, 4), (2, 3, 4, 5)]:
        self.assertTrue(checkfn(compiled, *call_args))

    # make sure function recompiles for different numbers of varargs
    compiled = Function(fn)
    for call_args in [(2, 3, 4, 5, 6), (2, 3, 4), (2, 3, 4, 5)]:
        self.assertTrue(checkfn(compiled, *call_args))
def test_function_of_nested_def_vargs_kwargs(self):
    """Forward *args/**kwargs into a nested helper; the extra kw2 goes unused."""
    def fn2(*args, **kwargs):
        def fn(*args, **kwargs):
            return args[1] + kwargs['kw']
        return fn(*args, **kwargs)
    self.assertTrue(checkfn(Function(fn2), 1.0, 2.0, kw=3.0, kw2=4.0))
def test_dict_arg(self):
    """A dict argument is indexed and its value handed to a helper."""
    def f(x):
        return x + 1

    def g(x):
        return f(x[1])

    self.assertTrue(checkfn(Function(g), {1.0: 5.0}))
def test_nested_fn_call(self):
    """Trace a call to a helper defined outside the wrapped function."""
    def f(x, y):
        return x + y

    def g(x):
        return f(x, x + 1) - x ** 2

    self.assertTrue(checkfn(Function(g), 10))
def checkfn(fn, var_ndim, *args):
    """Compare a Python function's output with that of ``Function(fn)``.

    For each entry ``nd`` in ``var_ndim`` a random array of shape
    ``[4] * nd`` is generated. Those arrays, followed by ``args``, are passed
    to both ``fn`` and the wrapped ``Function(fn)``.

    Returns the ``np.allclose`` comparison of the two results.
    """
    shapes = [[4] * nd for nd in var_ndim]
    random_inputs = tuple(np.random.random(shape) for shape in shapes)
    call_args = random_inputs + args
    symbolic = Function(fn)
    expected = fn(*call_args)
    actual = symbolic(*call_args)
    return np.allclose(expected, actual)
def test_sig_one_arg(self):
    """Signature handling for a single required argument."""
    # single arg, no default
    def fn(x):
        return x
    compiled = Function(fn)
    # missing argument, or a wrongly named keyword, must be rejected
    self.assertRaises(TypeError, compiled)
    self.assertRaises(TypeError, compiled, a=2)
    # positional and keyword forms are both accepted
    self.assertTrue(checkfn(compiled, 2))
    self.assertTrue(checkfn(compiled, x=2))
def test_sig_default_var_args(self):
    """Signature handling for defaulted args followed by *varargs."""
    # multiple var args, all default
    def fn(x=1, y=2, *z):
        return x * y * sum(z)
    compiled = Function(fn)
    for call_args in [(), (1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
        self.assertTrue(checkfn(compiled, *call_args))
def test_sig_mult_args(self):
    """Signature handling for two required positional arguments."""
    # multiple args, no default
    def fn(x, y):
        return x * y
    compiled = Function(fn)
    # missing or misnamed arguments must be rejected
    self.assertRaises(TypeError, compiled)
    self.assertRaises(TypeError, compiled, 2)
    self.assertRaises(TypeError, compiled, a=2, b=2)
    # positional and (reordered) keyword calls are both accepted
    self.assertTrue(checkfn(compiled, 2, 3))
    self.assertTrue(checkfn(compiled, y=4, x=5))
def test_tag(self):
    """A tag created inside the function appears in F.tags only after a call."""
    def f(x):
        y = tag(x + 2, 'y')
        z = y * 3
        return z
    F = Function(f)
    self.assertNotIn('y', F.tags)
    F(10)
    self.assertIn('y', F.tags)
def test_caching(self):
    """Compare results with use_cache=True and use_cache=False.

    With caching on, the second call (switch=-1) still yields the value of
    the switch > 0 branch. With caching off, it follows Python's branching.
    """
    def fn(x, switch):
        if switch > 0:
            return x * 1
        else:
            return x * 0

    for use_cache, expected_negative in ((True, 1), (False, 0)):
        F = Function(fn, use_cache=use_cache)
        positive = F(1, 1)
        negative = F(1, -1)
        self.assertTrue(np.allclose(positive, 1))
        self.assertTrue(np.allclose(negative, expected_negative))
def test_sig_kwargs(self):
    """Signature handling for a function taking only **kwargs."""
    # kwargs
    def fn(**kwargs):
        x = kwargs['x']
        y = kwargs['y']
        z = kwargs['z']
        return x * y * z
    compiled = Function(fn)
    # no kwargs -> lookup fails inside fn; a positional arg -> rejected
    self.assertRaises(KeyError, compiled)
    self.assertRaises(TypeError, compiled, 1)
    self.assertTrue(checkfn(compiled, x=1, y=2, z=3))
def test_randomstreams(self):
    """
    Check that random numbers differ between successive calls.

    Earlier, the use of clone_get_equiv broke this, because of a Theano bug
    when cloning randomstreams.
    """
    def f():
        return np.random.random((10, 10))
    sampler = Function(f)
    first, second = sampler(), sampler()
    self.assertFalse(np.allclose(first, second))
def test_tag_arg(self):
    """A tag named like an argument is stored separately from that argument."""
    def f(x):
        y = tag(x + 2, 'x')
        z = y * 3
        return z
    F = Function(f)
    # 'x' is absent from both mappings before the first call
    self.assertNotIn('x', F.sym_vars)
    self.assertNotIn('x', F.tags)
    F(10)
    self.assertIn('x', F.sym_vars)
    self.assertIn('x', F.tags)
    # the argument variable and the tagged value must be distinct objects
    self.assertIsNot(F.sym_vars['x'], F.tags['x'])
def test_sig_default_args(self):
    """Signature handling when some or all positional args have defaults."""
    # multiple args, one default
    def fn(x, y=2):
        return x * y
    compiled = Function(fn)
    # x has no default, so omitting it must fail
    self.assertRaises(TypeError, compiled)
    self.assertRaises(TypeError, compiled, y=3)
    one_default_calls = [
        ((2,), {}),
        ((2, 3), {}),
        ((), {'y': 4, 'x': 5}),
        ((), {'x': 5}),
    ]
    for positional, keywords in one_default_calls:
        self.assertTrue(checkfn(compiled, *positional, **keywords))

    # multiple args, all default
    def fn(x=1, y=2):
        return x * y
    compiled = Function(fn)
    all_default_calls = [
        ((), {}),
        ((1,), {}),
        ((1, 2), {}),
        ((), {'y': 2, 'x': 1}),
        ((), {'x': 5}),
        ((), {'y': 5}),
    ]
    for positional, keywords in all_default_calls:
        self.assertTrue(checkfn(compiled, *positional, **keywords))
def test_sig_varargs_kwargs(self):
    """Signature handling for a required arg plus *varargs and **kwargs."""
    # varargs and kwargs
    def fn(a, *b, **kwargs):
        x = kwargs['x']
        y = kwargs['y']
        z = kwargs['z']
        return x * y * z
    compiled = Function(fn)
    self.assertRaises(TypeError, compiled)             # required `a` missing
    self.assertRaises(KeyError, compiled, 1)           # kwargs lacks 'x'
    self.assertRaises(TypeError, compiled, x=1, y=2, z=3)
    xyz = dict(x=1, y=2, z=3)
    self.assertTrue(checkfn(compiled, 1, **xyz))
    self.assertTrue(checkfn(compiled, 1, 2, 3, **xyz))

    # varargs and kwargs, use varargs
    def fn(a, *b, **kwargs):
        x = kwargs['x']
        y = kwargs['y']
        z = kwargs['z']
        return x * y * z * b[0]
    compiled = Function(fn)
    self.assertTrue(checkfn(compiled, 1, 2, **xyz))
    self.assertTrue(checkfn(compiled, 1, 2, 3, **xyz))
def check(fn, *args, **kwargs):
    """Return whether ``fn`` and ``Function(fn)`` give close results.

    Both are called with the same ``*args`` / ``**kwargs``, and the outputs
    are compared with ``np.allclose``.
    """
    compiled = Function(fn)
    expected = fn(*args, **kwargs)
    obtained = compiled(*args, **kwargs)
    return np.allclose(expected, obtained)
def test_sig_no_arg(self):
    """Signature handling for a function that takes no arguments."""
    # no args, no default
    def fn():
        return np.ones((3, 4)) + 2
    f = Function(fn)
    self.assertTrue(checkfn(f))
def function_wrapper(pyfn):
    """Build and return ``Function(pyfn, **kwargs)``."""
    # NOTE(review): ``kwargs`` is a free variable here, so an enclosing
    # scope must define it (as it would if this is nested inside a
    # decorator factory). At module level this raises NameError; confirm.
    wrapped = Function(pyfn, **kwargs)
    return wrapped
def test_fn_constants(self):
    """A constant array built inside the function is handled correctly."""
    # access constant array
    def fn(x):
        return np.dot(x, np.ones((3, 4)))
    self.assertTrue(checkfn(Function(fn), np.ones((2, 3))))