def _call_and_update_batched(self, *args, has_rng=False, **kwargs):
  """Runs and returns the `Layer`'s `_call_batched` and `_update_batched`.

  Args:
    *args: Positional arguments forwarded to both functions. If `has_rng` is
      true, the leading element is removed from them and forwarded by
      keyword as `rng=` instead.
    has_rng: Whether `args[0]` is an RNG that should be passed by keyword.
    **kwargs: Candidate keyword arguments; each function receives only the
      subset that `kwargs_util.filter_kwargs` selects for its signature.

  Returns:
    A pair `(call_output, update_output)`. Both functions are evaluated on a
    copy of the layer whose `state` has been wrapped in `lax.stop_gradient`.
  """
  if has_rng:
    kwargs = {**kwargs, 'rng': args[0]}
    args = args[1:]
  call_kwargs = kwargs_util.filter_kwargs(self._call_batched, kwargs)
  update_kwargs = kwargs_util.filter_kwargs(self._update_batched, kwargs)
  # Work on a copy whose state blocks gradients.
  detached = self.replace(state=lax.stop_gradient(self.state))
  # pylint: disable=protected-access
  call_output = detached._call_batched(*args, **call_kwargs)
  update_output = detached._update_batched(*args, **update_kwargs)
  # pylint: enable=protected-access
  return call_output, update_output
def _call_and_update(self, *args, has_rng=False, **kwargs):
  """Evaluates the `Layer`'s `_call` and `_update` functions on shared inputs.

  Args:
    *args: Positional inputs shared by `_call` and `_update`. When `has_rng`
      is true, the first element is treated as an RNG: it is stripped from
      the positional inputs and supplied by keyword as `rng` instead.
    has_rng: Whether the leading positional argument is an RNG.
    **kwargs: Candidate keyword arguments; each function is given only the
      entries that `kwargs_util.filter_kwargs` keeps for it.

  Returns:
    A tuple `(call_result, update_result)`, both computed on a copy of the
    layer (made with `self.replace`) whose `state` has been passed through
    `lax.stop_gradient`.
  """
  if has_rng:
    kwargs = {**kwargs, 'rng': args[0]}
    args = args[1:]
  call_kwargs = kwargs_util.filter_kwargs(self._call, kwargs)
  update_kwargs = kwargs_util.filter_kwargs(self._update, kwargs)
  stopped = self.replace(state=lax.stop_gradient(self.state))
  call_result = stopped._call(*args, **call_kwargs)  # pylint: disable=protected-access
  update_result = stopped._update(*args, **update_kwargs)  # pylint: disable=protected-access
  return call_result, update_result
def test_filter_kwargs_accepts_all(self):
  """Functions with a `**kwargs` catch-all keep every keyword argument."""
  kwargs = {'rng': 1, 'training': True}
  expected = {'rng': 1, 'training': True}

  def foo1(x, y, **kwargs):
    del kwargs
    return x + y

  def foo2(x, y, training=False, **kwargs):
    del training, kwargs
    return x + y

  # The catch-all accepts every key, whether or not one is named explicitly.
  for fn in (foo1, foo2):
    self.assertDictEqual(kwargs_util.filter_kwargs(fn, kwargs), expected)
def _layer_cau_batched(layer, *args, **kwargs):
  """Invokes `_call_and_update_batched` on a gradient-stopped copy of `layer`.

  Args:
    layer: The layer to run. A copy with `state` wrapped in
      `lax.stop_gradient` is used; `layer` itself is not modified here.
    *args: Positional arguments. If the `has_rng` keyword is true, `args[0]`
      is removed and forwarded as the `rng` keyword argument.
    **kwargs: Keyword arguments. `has_rng` (default False) is consumed here.
      The remaining entries (plus `rng`, if any) are passed through
      `kwargs_util.filter_kwargs` for the copied layer's
      `_call_and_update_batched`.

  Returns:
    Whatever the copied layer's `_call_and_update_batched` returns.
  """
  options = dict(kwargs)
  has_rng = options.pop('has_rng', False)
  detached = layer.replace(state=lax.stop_gradient(layer.state))
  if has_rng:
    options['rng'] = args[0]
    args = args[1:]
  method = detached._call_and_update_batched  # pylint: disable=protected-access
  return method(*args, **kwargs_util.filter_kwargs(method, options))
def test_filter_incomplete_kwargs(self):
  """Only keyword arguments the function can accept survive filtering."""
  kwargs = {'rng': 1}

  def foo1(x, y):
    return x + y

  def foo2(x, y, rng=None):
    del rng
    return x + y

  def foo3(x, y, rng=None, training=False):
    del rng, training
    return x + y

  # foo1 accepts no `rng`; foo2/foo3 accept it. foo3's extra `training`
  # parameter must not be synthesized in the result.
  cases = (
      (foo1, {}),
      (foo2, {'rng': 1}),
      (foo3, {'rng': 1}),
  )
  for fn, expected in cases:
    self.assertDictEqual(kwargs_util.filter_kwargs(fn, kwargs), expected)
def wrapped(*args, **kwargs):
  """Maps input specs through `template._spec` and returns output array specs.

  Every leaf of `args` is converted with `state.make_array_spec`. The result
  is passed to `template._spec` together with the keyword arguments that
  `kwargs_util.filter_kwargs` keeps for it. The leaves of the returned
  structure are converted with `state.make_array_spec` again.
  """
  input_specs = tree_util.tree_map(state.make_array_spec, args)
  spec_fn = template._spec  # pylint: disable=protected-access
  spec_kwargs = kwargs_util.filter_kwargs(spec_fn, kwargs)
  output_specs = spec_fn(*input_specs, **spec_kwargs)
  return tree_util.tree_map(state.make_array_spec, output_specs)