Пример #1
0
 def _call_and_update_batched(self, *args, has_rng=False, **kwargs):
   """Runs the batched `_call` and `_update` functions on a frozen state.

   Args:
     *args: Positional inputs forwarded to both batched functions. When
       `has_rng` is true, the first element is taken out and forwarded as
       the `rng` keyword instead.
     has_rng: Whether `args[0]` should be moved into the keywords as `rng`.
     **kwargs: Keyword inputs. Each batched function receives only the
       keywords that `kwargs_util.filter_kwargs` selects for its signature.

   Returns:
     A tuple `(call_output, update_output)` from `_call_batched` and
     `_update_batched`. Both are invoked on the layer returned by
     `self.replace` with the state wrapped in `lax.stop_gradient`.
   """
   if has_rng:
     kwargs = {**kwargs, 'rng': args[0]}
     args = args[1:]
   filtered_call = kwargs_util.filter_kwargs(self._call_batched, kwargs)
   filtered_update = kwargs_util.filter_kwargs(self._update_batched, kwargs)
   # Block gradients from flowing into the state through either function.
   frozen = self.replace(state=lax.stop_gradient(self.state))
   call_out = frozen._call_batched(*args, **filtered_call)  # pylint: disable=protected-access
   update_out = frozen._update_batched(*args, **filtered_update)  # pylint: disable=protected-access
   return call_out, update_out
Пример #2
0
 def _call_and_update(self, *args, has_rng=False, **kwargs):
   """Runs the `Layer`'s `_call` and `_update` functions on a frozen state.

   Args:
     *args: Positional inputs forwarded to both functions. When `has_rng` is
       true, the first element is taken out and forwarded as the `rng`
       keyword instead.
     has_rng: Whether `args[0]` should be moved into the keywords as `rng`.
     **kwargs: Keyword inputs. Each function receives only the keywords that
       `kwargs_util.filter_kwargs` selects for its signature.

   Returns:
     A tuple `(call_output, update_output)` from `_call` and `_update`.
     Both are invoked on the layer returned by `self.replace` with the state
     wrapped in `lax.stop_gradient`.
   """
   if has_rng:
     kwargs = {**kwargs, 'rng': args[0]}
     args = args[1:]
   filtered_call = kwargs_util.filter_kwargs(self._call, kwargs)
   filtered_update = kwargs_util.filter_kwargs(self._update, kwargs)
   # Block gradients from flowing into the state through either function.
   frozen = self.replace(state=lax.stop_gradient(self.state))
   call_out = frozen._call(*args, **filtered_call)  # pylint: disable=protected-access
   update_out = frozen._update(*args, **filtered_update)  # pylint: disable=protected-access
   return call_out, update_out
Пример #3
0
 def test_filter_kwargs_accepts_all(self):
   """A function with a `**kwargs` parameter keeps every supplied keyword."""
   kwargs = {'rng': 1, 'training': True}
   expected = {'rng': 1, 'training': True}

   def var_keywords_only(x, y, **kwargs):
     del kwargs
     return x + y

   def named_and_var_keywords(x, y, training=False, **kwargs):
     del training, kwargs
     return x + y

   for fn in (var_keywords_only, named_and_var_keywords):
     self.assertDictEqual(kwargs_util.filter_kwargs(fn, kwargs), expected)
Пример #4
0
def _layer_cau_batched(layer, *args, **kwargs):
    """Invokes `layer._call_and_update_batched` with a gradient-stopped state.

    Args:
      layer: Object providing `replace`, `state` and
        `_call_and_update_batched`.
      *args: Positional inputs. When `has_rng` is true, `args[0]` is removed
        and passed as the `rng` keyword instead.
      **kwargs: Keyword inputs. `has_rng` (default False) is consumed here.
        The remaining keywords, plus `rng` if it was extracted, are reduced to
        those `kwargs_util.filter_kwargs` selects for the target method.

    Returns:
      Whatever `_call_and_update_batched` returns when called on
      `layer.replace(state=lax.stop_gradient(layer.state))`.
    """
    options = dict(kwargs)
    rng_enabled = options.pop('has_rng', False)
    frozen = layer.replace(state=lax.stop_gradient(layer.state))
    if rng_enabled:
        options['rng'] = args[0]
        args = args[1:]
    method = frozen._call_and_update_batched  # pylint: disable=protected-access
    return method(*args, **kwargs_util.filter_kwargs(method, options))
    def test_filter_incomplete_kwargs(self):
        """Keywords a function cannot accept are dropped; none are invented."""
        kwargs = {'rng': 1}

        def no_keywords(x, y):
            return x + y

        def rng_only(x, y, rng=None):
            del rng
            return x + y

        def rng_and_training(x, y, rng=None, training=False):
            del rng, training
            return x + y

        cases = (
            (no_keywords, {}),
            (rng_only, {'rng': 1}),
            (rng_and_training, {'rng': 1}),
        )
        for fn, expected in cases:
            self.assertDictEqual(
                kwargs_util.filter_kwargs(fn, kwargs), expected)
Пример #6
0
 def wrapped(*args, **kwargs):
     """Propagates array specs through `template._spec`.

     Each positional input is converted leaf-wise with
     `state.make_array_spec`. The keywords are reduced to those
     `kwargs_util.filter_kwargs` selects for `template._spec`. Every leaf of
     the spec function's result is converted with `state.make_array_spec`
     again.
     """
     arg_specs = tree_util.tree_map(state.make_array_spec, args)
     spec_fn = template._spec  # pylint: disable=protected-access
     accepted = kwargs_util.filter_kwargs(spec_fn, kwargs)
     result = spec_fn(*arg_specs, **accepted)
     return tree_util.tree_map(state.make_array_spec, result)