def test_list(self):
    """A list of three increment functions is expected to compose in order."""
    increment = lambda x: x + 1
    layers = [increment] * 3
    scalar_spec = api.Shape(())
    result_spec = api.spec(layers)(scalar_spec)
    composed = api.init(layers)(self._seed, scalar_spec)
    # x + 1 keeps the scalar shape, and three increments starting at 1. give 4.
    self.assertEqual(result_spec, scalar_spec)
    self.assertEqual(api.call(composed, 1.), 4.)
 def test_tuple(self):
     """A tuple of three identities is expected to yield three copies of the input."""
     identity = lambda x: x
     branches = (identity,) * 3
     scalar_spec = api.Shape(())
     result_spec = api.spec(branches)(scalar_spec)
     fanned_out = api.init(branches)(self._seed, scalar_spec)
     self.assertTupleEqual(result_spec, (scalar_spec,) * 3)
     self.assertTupleEqual(api.call(fanned_out, 1.), (1., 1., 1.))
 def test_func(self):
     """A bare identity function should pass spec, init and call unchanged."""
     identity = lambda x: x
     scalar_spec = api.Shape(())
     result_spec = api.spec(identity)(scalar_spec)
     initialized = api.init(identity)(self._seed, scalar_spec)
     self.assertEqual(result_spec, scalar_spec)
     self.assertEqual(api.call(initialized, 1.), 1.)
예제 #4
0
    def test_spec_works_for_identity_function(self):
        """spec() of the identity is expected to keep shape (5,) and report float32."""
        def identity(x):
            return x

        ones = np.ones(5)
        result = api.spec(identity)(ones)
        self.assertTupleEqual(result.shape, (5, ))
        # np.ones defaults to float64, yet float32 is expected here —
        # presumably because 64-bit mode is disabled in the backend. TODO confirm.
        self.assertEqual(result.dtype, np.float32)
 def test_list_multiple_args(self):
     """A list should compose even when stages take and return several values.

     Starting from 1., stage one yields (1., 2.), stage two yields (2., 2.),
     and stage three yields (3., 4., 4.).
     """
     first = lambda x: (x, x + 1)
     second = lambda x, y: (y, x + 1)
     third = lambda x, y: (x + 1, y + 2, x + y)
     pipeline = [first, second, third]
     scalar_spec = api.Shape(())
     result_spec = api.spec(pipeline)(scalar_spec)
     composed = api.init(pipeline)(self._seed, scalar_spec)
     self.assertEqual(result_spec, (api.Shape(()), ) * 3)
     self.assertTupleEqual(api.call(composed, 1.), (3., 4., 4.))
 def wrapped(init_key, *args, **kwargs):
   """Initializes each element of `inits` in order, threading specs along.

   Args:
     init_key: key split into one sub-key per element via `random.split`
       (presumably jax.random — confirm), or None to give every element a
       None key.
     *args: input specs for the first element.
     **kwargs: forwarded unchanged to every element's init and spec.

   Returns:
     A container of the same class as `inits` holding the initialized elements.
   """
   count = len(inits)
   if init_key is None:
     element_keys = (None,) * count
   else:
     element_keys = random.split(init_key, count)
   current_specs = args
   built = []
   for index, (element_init, element_key) in enumerate(zip(inits, element_keys)):
     # After the first element, the previous spec may be a single ArraySpec;
     # wrap it so it can be splatted as positional arguments.
     if isinstance(current_specs, api.ArraySpec):
       current_specs = (current_specs,)
     element_name = None if name is None else '{}_{}'.format(name, index)
     module = api.init(element_init, name=element_name)(element_key, *current_specs, **kwargs)  # pylint: disable=assignment-from-no-return
     # The next element consumes this element's output spec.
     current_specs = api.spec(element_init)(*current_specs, **kwargs)  # pylint: disable=assignment-from-no-return
     built.append(module)
   return inits.__class__(built)
예제 #7
0
 def wrapped(*specs, **kwargs):
     """Feeds specs through each element of `inits` in order.

     Each element's output spec becomes the next element's input. If `inits`
     is empty, the input specs tuple is returned as-is.
     """
     current = specs
     for element_init in inits:
         # A lone ArraySpec from the previous element is wrapped for splatting.
         if isinstance(current, api.ArraySpec):
             current = (current, )
         current = api.spec(element_init)(*current, **kwargs)  # pylint: disable=assignment-from-no-return
     return current
예제 #8
0
 def wrapped(*args, **kwargs):
     """Computes the spec of every element of `tupl` on the same inputs.

     The results are packed into an instance of `tupl`'s own class. This uses
     `tuple.__new__` directly, presumably so tuple subclasses whose `__new__`
     takes per-field arguments (e.g. namedtuples) can be rebuilt from an
     iterable.
     """
     element_specs = []
     for element in tupl:
         element_specs.append(api.spec(element)(*args, **kwargs))
     return tuple.__new__(tupl.__class__, tuple(element_specs))