def init_fn(*args, **kwargs) -> types.OutputStates:
    """Initialize ``self.module`` and run one forward pass with the new variables.

    Every keyword argument is renamed with a ``"__"`` prefix before being
    forwarded to both ``self.module.init`` and ``self.module.apply``.

    Returns:
        ``types.OutputStates`` made of the prediction produced by ``apply``
        plus the params and states produced by ``init``.
    """
    prefixed = dict(("__" + name, value) for name, value in kwargs.items())
    init_key = rng.next()
    params, states = self.module.init(init_key, *args, **prefixed)
    # Only the prediction is kept from the forward pass; the states that
    # `apply` returns are dropped in favour of the freshly initialized ones.
    # NOTE(review): the same key is used for `init` and `apply`; confirm this
    # key reuse is intentional.
    y_pred, _discarded_states = self.module.apply(
        params, states, init_key, *args, **prefixed
    )
    return types.OutputStates(y_pred, params, states)
def _lambda(*args, **kwargs):
    """Run ``self.module.init(rng=rng)`` through dependency injection.

    ``self.module.init(rng=rng)`` is assumed to return a callable, which
    ``utils.inject_dependencies`` wraps before it is called with the
    positional/keyword arguments received here.

    Returns:
        ``types.OutputStates(y_pred, parameters, collections)`` built from
        the three values the injected init callable returns.
    """
    injected_init = utils.inject_dependencies(self.module.init(rng=rng))
    outputs = injected_init(*args, **kwargs)
    y_pred, parameters, collections = outputs
    return types.OutputStates(y_pred, parameters, collections)
def _lambda(*args, **kwargs):
    """Run ``self.module.apply`` with the enclosing params/states via dependency injection.

    ``params``, ``states``, ``training`` and ``rng`` come from the enclosing
    scope. ``self.module.apply(...)`` is assumed to return a callable, which
    is wrapped by ``utils.inject_dependencies`` and invoked with the
    arguments received here.

    Returns:
        ``types.OutputStates(y_pred, net_params, net_states)``.
    """
    injected_apply = utils.inject_dependencies(
        self.module.apply(params, states, training=training, rng=rng),
    )
    outputs = injected_apply(*args, **kwargs)
    y_pred, net_params, net_states = outputs
    return types.OutputStates(y_pred, net_params, net_states)
def _lambda(*args, **kwargs):
    """Apply ``self.module`` with the enclosing ``params``/``states``.

    Arguments are routed through ``utils.inject_dependencies`` using
    ``self.f`` as the reference signature. The wrapped function renames every
    keyword argument with a ``"__"`` prefix and draws a fresh key from
    ``rng`` on each call.

    Returns:
        ``types.OutputStates`` holding the prediction, the unchanged
        enclosing ``params`` and the states returned by ``apply``.
    """

    def apply_fn(*args, **kwargs):
        prefixed = {"__" + key: val for key, val in kwargs.items()}
        return self.module.apply(params, states, rng.next(), *args, **prefixed)

    injected = utils.inject_dependencies(apply_fn, signature_f=self.f)
    y_pred, new_states = injected(*args, **kwargs)
    return types.OutputStates(y_pred, params, new_states)
def lambda_(*args, **kwargs) -> types.OutputStates:
    """Call ``self.f`` via dependency injection and normalize its result.

    If ``self.f`` already returns a ``types.OutputStates`` it is passed
    through unchanged. Any other value is wrapped as the ``preds`` field,
    with ``params`` and ``states`` set to ``types.UNINITIALIZED``.
    """
    result = utils.inject_dependencies(self.f)(*args, **kwargs)
    if not isinstance(result, types.OutputStates):
        result = types.OutputStates(
            preds=result,
            params=types.UNINITIALIZED,
            states=types.UNINITIALIZED,
        )
    return result
def _lambda(*args, **kwargs):
    """Initialize ``self.module`` and split its collections into params and states.

    The injected init callable must return ``(y_pred, collections)`` where
    ``collections`` is a ``dict``. The ``"parameters"`` entry (default
    ``{}``) becomes the params. Every remaining entry becomes the states.

    Returns:
        ``types.OutputStates(y_pred, net_params, net_states)``.
    """
    injected_init = utils.inject_dependencies(self.module.init(rng=rng))
    y_pred, collections = injected_init(*args, **kwargs)
    assert isinstance(collections, dict)
    # `pop` removes "parameters" in place, so what is left in `collections`
    # is exactly the non-parameter state.
    net_params = collections.pop("parameters", {})
    return types.OutputStates(y_pred, net_params, collections)
def _lambda(*args, **kwargs) -> types.OutputStates:
    """Compute predictions and build the initial ``(count, total)`` state.

    If ``self.f`` (called via dependency injection) already returns a
    ``types.OutputStates``, that value is returned untouched. Otherwise the
    state is ``(0, total)``, where ``total`` is a pytree of zeros with the
    same structure/shape/dtype as ``preds``.

    Returns:
        ``types.OutputStates(preds=preds, params=None, states=(0, total))``.
    """
    preds = utils.inject_dependencies(self.f)(*args, **kwargs)
    if isinstance(preds, types.OutputStates):
        return preds

    count = 0
    # `jax.tree_map` is deprecated (JAX 0.4.26) and later removed; the
    # `jax.tree_util.tree_map` spelling works on every JAX version.
    total = jax.tree_util.tree_map(jnp.zeros_like, preds)
    return types.OutputStates(
        preds=preds,
        params=None,
        states=(count, total),
    )
def _lambda(*args, **kwargs) -> types.OutputStates:
    """Accumulate predictions into the ``(count, total)`` state and return their average.

    If ``self.f`` (called via dependency injection) already returns a
    ``types.OutputStates``, that value is returned untouched. Otherwise the
    enclosing ``states`` is unpacked as ``(n, total)``. ``n`` is incremented,
    ``preds`` is added leaf-wise into ``total``, and the returned ``preds``
    is ``total / n`` leaf-wise.

    Returns:
        ``types.OutputStates(preds=total / n, params=None, states=(n, total))``.
    """
    preds = utils.inject_dependencies(self.f)(*args, **kwargs)
    if isinstance(preds, types.OutputStates):
        return preds

    n, total = states
    n += 1
    # `jax.tree_multimap` was deprecated in JAX 0.3.5 and removed in 0.3.16;
    # `tree_util.tree_map` accepts several trees directly (structures must
    # match, exactly as `tree_multimap` required).
    total = jax.tree_util.tree_map(lambda a, b: a + b, preds, total)
    preds = jax.tree_util.tree_map(lambda leaf: leaf / n, total)
    return types.OutputStates(
        preds=preds,
        params=None,
        states=(n, total),
    )
def _lambda(*args, **kwargs):
    """Initialize ``self.module`` through dependency injection keyed on ``self.f``.

    The wrapped ``init_fn`` renames every keyword argument with a ``"__"``
    prefix, draws a key from ``rng`` and calls ``self.module.init``, which
    must return ``(y_pred, params, states)``.

    Returns:
        ``types.OutputStates(y_pred, params, states)``.
    """

    def init_fn(*args, **kwargs) -> types.OutputStates:
        renamed = {"__" + k: v for k, v in kwargs.items()}
        init_key = rng.next()
        pred, net_params, net_states = self.module.init(init_key, *args, **renamed)
        return types.OutputStates(pred, net_params, net_states)

    injected = utils.inject_dependencies(init_fn, signature_f=self.f)
    # Re-wrap the three unpacked fields into a fresh OutputStates.
    y_pred, net_params, net_states = injected(*args, **kwargs)
    return types.OutputStates(y_pred, net_params, net_states)
def _lambda(*args, **kwargs):
    """Apply ``self.module`` to a merged collections dict and split the result.

    The input collections are a shallow copy of the enclosing ``states``
    (or ``{}`` when it is ``None``). When ``params`` is not ``None`` it is
    stored under ``"parameters"``. After the injected apply runs, the
    ``"parameters"`` entry (default ``{}``) is separated from the rest.

    Returns:
        ``types.OutputStates(y_pred, net_params, net_states)``.
    """
    if states is None:
        variables = {}
    else:
        variables = states.copy()
    if params is not None:
        variables["parameters"] = params

    injected_apply = utils.inject_dependencies(
        self.module.apply(variables, training=training, rng=rng),
    )
    y_pred, collections = injected_apply(*args, **kwargs)
    assert isinstance(collections, dict)
    # `pop` removes "parameters" in place; the remainder is the new state.
    net_params = collections.pop("parameters", {})
    return types.OutputStates(y_pred, net_params, collections)
def _lambda(*args, **kwargs):
    """Initialize a Flax module with ``init_with_output`` and split its variables.

    Arguments are routed through ``utils.inject_dependencies`` using
    ``self.module.__call__`` as the reference signature. The resulting
    variables must be a ``FrozenDict``. Its ``"params"`` collection becomes
    the params, and everything else becomes the states. When ``"params"`` is
    absent, params is an empty ``FrozenDict``.

    Returns:
        ``types.OutputStates(y_pred, net_params, net_states)``.
    """

    def init_fn(*args, **kwargs):
        return self.module.init_with_output(rng.next(), *args, **kwargs)

    injected_init = utils.inject_dependencies(
        init_fn,
        signature_f=self.module.__call__,
    )
    y_pred, variables = injected_init(*args, **kwargs)
    assert isinstance(variables, FrozenDict)

    if "params" in variables:
        # FrozenDict.pop returns (remaining_variables, popped_value).
        net_states, net_params = variables.pop("params")
    else:
        net_states, net_params = variables, FrozenDict()
    return types.OutputStates(y_pred, net_params, net_states)
def _lambda(*args, **kwargs):
    """Apply a Flax module with all collections mutable and split the updated variables.

    The variables passed to ``self.module.apply`` are the enclosing ``params``
    (under ``"params"``) merged with the enclosing ``states``. A fresh key from
    ``rng`` is supplied as the ``"params"`` RNG stream. Arguments are routed
    through ``utils.inject_dependencies`` using ``self.module.__call__`` as the
    reference signature.

    Returns:
        ``types.OutputStates(y_pred, net_params, net_states)``. ``net_params``
        is the ``"params"`` collection (or an empty ``FrozenDict`` when
        absent), and ``net_states`` holds the remaining collections.

    Raises:
        TypeError: if the updated variables are not a ``FrozenDict``.
    """

    def apply_fn(*args, **kwargs):
        variables = dict(params=params, **states)
        return self.module.apply(
            variables,
            *args,
            rngs={"params": rng.next()},
            mutable=True,
            **kwargs,
        )

    y_pred, variables = utils.inject_dependencies(
        apply_fn,
        signature_f=self.module.__call__,
    )(
        *args,
        **kwargs,
    )

    # Consistent with the init path: the split below relies on
    # `FrozenDict.pop` returning `(remaining, popped)`. A plain `dict.pop`
    # would return only the params mapping, so the two-name unpack would
    # fail or silently pick up that mapping's keys.
    if not isinstance(variables, FrozenDict):
        raise TypeError(
            f"expected module.apply to return a FrozenDict of variables, "
            f"got {type(variables).__name__}"
        )

    if "params" in variables:
        net_states, net_params = variables.pop("params")
    else:
        net_states, net_params = variables, FrozenDict()
    return types.OutputStates(y_pred, net_params, net_states)