コード例 #1
0
 def test_func(self):
     id_ = lambda x: x
     in_spec = api.Shape(())
     out_spec = api.spec(id_)(in_spec)
     func = api.init(id_)(self._seed, in_spec)
     self.assertEqual(out_spec, in_spec)
     self.assertEqual(api.call(func, 1.), 1.)
コード例 #2
0
 def test_tuple(self):
     id_ = lambda x: x
     in_spec = api.Shape(())
     out_spec = api.spec((id_, id_, id_))(in_spec)
     tupl = api.init((id_, id_, id_))(self._seed, in_spec)
     self.assertTupleEqual(out_spec, (in_spec, ) * 3)
     self.assertTupleEqual(api.call(tupl, 1.), (1., 1., 1.))
コード例 #3
0
 def test_list(self):
     func = lambda x: x + 1
     in_spec = api.Shape(())
     out_spec = api.spec([func, func, func])(in_spec)
     lst = api.init([func, func, func])(self._seed, in_spec)
     self.assertEqual(out_spec, in_spec)
     self.assertEqual(api.call(lst, 1.), 4.)
コード例 #4
0
ファイル: function_test.py プロジェクト: seanmb/probability
    def test_vmap_of_init_should_return_ensemble(self):
        def f(x, init_key=None):
            w = module.variable(random.normal(init_key, x.shape), name='w')
            return np.dot(w, x)

        ensemble = jax.vmap(api.init(f))(random.split(random.PRNGKey(0)),
                                         np.ones([2, 5]))
        self.assertTupleEqual(ensemble.w.shape, (2, 5))
        onp.testing.assert_allclose(jax.vmap(api.call,
                                             in_axes=(0, None))(ensemble,
                                                                np.ones(5)),
                                    jax.vmap(lambda key, x: f(x, init_key=key),
                                             in_axes=(0, None))(random.split(
                                                 random.PRNGKey(0)),
                                                                np.ones(5)),
                                    rtol=1e-5,
                                    atol=1e-5)
        onp.testing.assert_allclose(
            jax.vmap(api.call, in_axes=(0, 0))(ensemble,
                                               np.arange(10.).reshape((2, 5))),
            jax.vmap(lambda key, x: f(x, init_key=key),
                     in_axes=(0, 0))(random.split(random.PRNGKey(0)),
                                     np.arange(10.).reshape((2, 5))),
            rtol=1e-5,
            atol=1e-5)
コード例 #5
0
ファイル: function_test.py プロジェクト: seanmb/probability
    def test_should_pass_kwarg_into_primitive(self):
        def f(x):
            return training_add_p.bind(x)

        m = api.init(f)(random.PRNGKey(0), 1.)
        self.assertEqual(m(1.), 2.)
        self.assertEqual(m(1., training=True), 2.)
        self.assertEqual(m(1., training=False), 1.)
コード例 #6
0
ファイル: function_test.py プロジェクト: seanmb/probability
    def test_init_nonstateful_function(self):
        def f(x):
            return x

        m = api.init(f)(random.PRNGKey(0), 1.)
        self.assertIsInstance(m, function.FunctionModule)
        self.assertDictEqual(m.variables(), {})
        self.assertEqual(m.call(1.), 1.)
        self.assertEqual(m.update(1.).variables(), {})
コード例 #7
0
 def test_list_multiple_args(self):
     func1 = lambda x: (x, x + 1)
     func2 = lambda x, y: (y, x + 1)
     func3 = lambda x, y: (x + 1, y + 2, x + y)
     in_spec = api.Shape(())
     out_spec = api.spec([func1, func2, func3])(in_spec)
     lst = api.init([func1, func2, func3])(self._seed, in_spec)
     self.assertEqual(out_spec, (api.Shape(()), ) * 3)
     self.assertTupleEqual(api.call(lst, 1.), (3., 4., 4.))
コード例 #8
0
ファイル: function_test.py プロジェクト: seanmb/probability
    def test_assign_with_no_matching_variable_should_error(self):
        def f(x, init_key=None):
            y = module.variable(np.zeros(x.shape), name='y', key=init_key)
            next_y = module.assign(y + 1., name='z')
            return x + next_y

        m = api.init(f)(random.PRNGKey(0), 1.)
        with self.assertRaisesRegex(ValueError,
                                    'No variable declared for assign: z'):
            m(1.)
コード例 #9
0
ファイル: function_test.py プロジェクト: seanmb/probability
    def test_init_stateful_function(self):
        def f(x, init_key=None):
            y = module.variable(np.ones(x.shape), name='y', key=init_key)
            return x + y

        m = api.init(f)(random.PRNGKey(0), 1.)
        self.assertIsInstance(m, function.FunctionModule)
        self.assertDictEqual(m.variables(), {'y': 1.})
        self.assertEqual(m.call(1.), 2.)
        self.assertEqual(m.update(1.).variables(), {'y': 1})
コード例 #10
0
 def wrapped(init_key, *args, **kwargs):
   if init_key is None:
     init_keys = (None,) * len(tupl)
   else:
     init_keys = random.split(init_key, len(tupl))
   names = [None if name is None else f'{name}_{i}' for i
            in range(len(tupl))]
   modules = tuple(api.init(t, name=name)(init_key, *args, **kwargs)
                   for t, name, init_key in zip(tupl, names, init_keys))
   return tuple.__new__(tupl.__class__, modules)
コード例 #11
0
ファイル: function_test.py プロジェクト: seanmb/probability
    def test_init_stateful_function_with_tied_in_assign(self):
        def f(x, init_key=None):
            y = module.variable(np.zeros(x.shape), name='y', key=init_key)
            next_y = module.assign(y + 1., name='y')
            return primitive.tie_in(next_y, x) + y

        m = api.init(f)(random.PRNGKey(0), 1.)
        self.assertIsInstance(m, function.FunctionModule)
        self.assertDictEqual(m.variables(), {'y': 0.})
        self.assertEqual(m.call(1.), 1.)
        self.assertDictEqual(m.update(1.).variables(), {'y': 1.})
コード例 #12
0
 def wrapped(init_key, *args, **kwargs):
   if init_key is not None:
     init_keys = random.split(init_key, len(inits))
   else:
     init_keys = (None,) * len(inits)
   modules = []
   for i, (elem_init, init_key) in enumerate(zip(inits, init_keys)):
     if isinstance(args, api.ArraySpec):
       args = (args,)
     elem_name = None if name is None else '{}_{}'.format(name, i)
     elem = api.init(elem_init, name=elem_name)(init_key, *args, **kwargs)  # pylint: disable=assignment-from-no-return
     args = api.spec(elem_init)(*args, **kwargs)  # pylint: disable=assignment-from-no-return
     modules.append(elem)
   return inits.__class__(modules)
コード例 #13
0
ファイル: function_test.py プロジェクト: seanmb/probability
    def test_init_of_nested_init_without_name_should_have_flat_params(self):
        def f(x, init_key=None):
            y = module.variable(np.zeros(x.shape), name='y', key=init_key)
            next_y = module.assign(y + 1., name='y')
            return primitive.tie_in(next_y, x) + y

        def g(x, init_key=None):
            return api.init(f)(init_key, x)(x)

        m = api.init(g)(random.PRNGKey(0), 1.)
        self.assertIsInstance(m, function.FunctionModule)
        self.assertDictEqual(m.variables(), {'y': 0.})
        self.assertEqual(m.call(1.), 1.)
        self.assertDictEqual(m.update(1.).variables(), {'y': 1.})
コード例 #14
0
ファイル: function_test.py プロジェクト: seanmb/probability
 def g(x, init_key=None):
     return api.init(f)(init_key, x)(x)
コード例 #15
0
ファイル: function_test.py プロジェクト: seanmb/probability
 def g(x, init_key=None):
     return api.init(f, name='f')(init_key, x)(x)
コード例 #16
0
 def test_empty_tuple_init(self):
     with self.assertRaises(ValueError):
         api.init(())(self._seed, api.Shape((50, )))
コード例 #17
0
ファイル: function_test.py プロジェクト: seanmb/probability
 def f_(init_key, x):
     return api.init(f)(init_key, x)