Esempio n. 1
0
 def init_fn(*args, **kwargs) -> types.OutputStates:
     """Initialize ``self.module`` and run one forward pass with the result.

     Every keyword argument is forwarded under a ``__``-prefixed name
     (``x=...`` becomes ``__x=...``). NOTE(review): the prefix presumably
     matches the naming ``self.module`` expects for injected kwargs; confirm.

     Returns:
         ``OutputStates(y_pred, params, states)``. ``params`` and ``states``
         come from ``self.module.init``. ``y_pred`` comes from
         ``self.module.apply``; the states returned by ``apply`` are discarded.
     """
     kwargs = {f"__{name}": value for name, value in kwargs.items()}
     # Draw a separate key for each consumer. Passing one PRNG key to both
     # init and apply would correlate their random streams (JAX's
     # "never reuse a key" rule). Assumes rng.next() yields a fresh key per
     # call, as its other call sites rely on.
     init_key = rng.next()
     apply_key = rng.next()
     params, states = self.module.init(init_key, *args, **kwargs)
     y_pred, _ = self.module.apply(params, states, apply_key, *args,
                                   **kwargs)
     return types.OutputStates(y_pred, params, states)
Esempio n. 2
0
        def _lambda(*args, **kwargs):
            """Initialize ``self.module`` and package the result as ``OutputStates``.

            ``self.module.init(rng=rng)`` yields a callable. It is wrapped with
            ``utils.inject_dependencies`` (presumably to forward only the
            arguments that callable accepts; confirm) and invoked with the
            caller's positional and keyword arguments.
            """
            initializer = utils.inject_dependencies(self.module.init(rng=rng))
            outputs, net_params, net_collections = initializer(*args, **kwargs)
            return types.OutputStates(outputs, net_params, net_collections)
Esempio n. 3
0
        def _lambda(*args, **kwargs):
            """Run ``self.module`` forward and package the result as ``OutputStates``.

            ``self.module.apply`` is bound to the captured ``params``,
            ``states``, ``training`` flag and ``rng``. The resulting callable is
            wrapped with ``utils.inject_dependencies`` and called with the
            caller's arguments. It must return ``(y_pred, params, states)``.
            """
            forward = utils.inject_dependencies(
                self.module.apply(params, states, training=training, rng=rng),
            )
            outputs, new_params, new_states = forward(*args, **kwargs)
            return types.OutputStates(outputs, new_params, new_states)
Esempio n. 4
0
        def _lambda(*args, **kwargs):
            """Apply ``self.module`` with dependency injection keyed on ``self.f``.

            Returns ``OutputStates(y_pred, params, new_states)``. ``params`` is
            the captured value, returned unchanged; ``new_states`` is whatever
            ``self.module.apply`` produced.
            """
            def apply_fn(*call_args, **call_kwargs):
                # Each keyword is forwarded to the module under a "__" prefix.
                prefixed = {
                    f"__{key}": val for key, val in call_kwargs.items()
                }
                return self.module.apply(
                    params, states, rng.next(), *call_args, **prefixed
                )

            injected = utils.inject_dependencies(apply_fn, signature_f=self.f)
            outputs, new_states = injected(*args, **kwargs)

            return types.OutputStates(outputs, params, new_states)
Esempio n. 5
0
        def lambda_(*args, **kwargs) -> types.OutputStates:
            """Call ``self.f`` and make sure the result is an ``OutputStates``.

            If ``self.f`` already returns an ``OutputStates``, it is passed
            through as-is. Otherwise its return value becomes ``preds``, and
            ``params``/``states`` are set to ``types.UNINITIALIZED``.
            """
            result = utils.inject_dependencies(self.f)(*args, **kwargs)

            if not isinstance(result, types.OutputStates):
                result = types.OutputStates(
                    preds=result,
                    params=types.UNINITIALIZED,
                    states=types.UNINITIALIZED,
                )

            return result
Esempio n. 6
0
        def _lambda(*args, **kwargs):
            """Initialize ``self.module`` and split its collections.

            The ``"parameters"`` collection (``{}`` if absent) becomes
            ``params``. All remaining collections become ``states``.
            """
            initializer = utils.inject_dependencies(self.module.init(rng=rng))
            outputs, variables = initializer(*args, **kwargs)
            assert isinstance(variables, dict)

            # pop() mutates the dict returned by init; what remains is state.
            net_params = variables.pop("parameters", {})
            return types.OutputStates(outputs, net_params, variables)
Esempio n. 7
0
        def _lambda(*args, **kwargs) -> types.OutputStates:
            """Compute predictions and create an empty ``(n, total)`` accumulator.

            Calls ``self.f`` through ``utils.inject_dependencies``. If it
            already returns an ``OutputStates``, that value is passed through.
            Otherwise the predictions are returned together with
            ``states=(0, zeros)``, where ``zeros`` mirrors the pytree structure,
            shapes and dtypes of the predictions. ``params`` is ``None``.
            """
            preds = utils.inject_dependencies(self.f)(*args, **kwargs)

            if isinstance(preds, types.OutputStates):
                return preds

            n = 0
            # `jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is
            # the stable spelling. `jnp.zeros_like` is mapped directly, with no
            # lambda wrapper.
            total = jax.tree_util.tree_map(jnp.zeros_like, preds)
            return types.OutputStates(
                preds=preds,
                params=None,
                states=(n, total),
            )
Esempio n. 8
0
        def _lambda(*args, **kwargs) -> types.OutputStates:
            """Fold new predictions into the captured ``(n, total)`` accumulator.

            Calls ``self.f`` through ``utils.inject_dependencies``. If it
            already returns an ``OutputStates``, that value is passed through.
            Otherwise ``n`` is incremented, the predictions are added leaf-wise
            to ``total``, and ``total / n`` is returned as ``preds``. That is the
            running mean when the accumulator started at ``(0, zeros)``.

            Returns:
                ``OutputStates(preds=total / n, params=None, states=(n, total))``.
            """
            preds = utils.inject_dependencies(self.f)(*args, **kwargs)

            if isinstance(preds, types.OutputStates):
                return preds

            n, total = states
            n += 1
            # `jax.tree_multimap` was deprecated and then removed from JAX;
            # `jax.tree_util.tree_map` accepts several trees and replaces it.
            total = jax.tree_util.tree_map(lambda a, b: a + b, preds, total)
            mean = jax.tree_util.tree_map(lambda leaf: leaf / n, total)
            return types.OutputStates(
                preds=mean,
                params=None,
                states=(n, total),
            )
Esempio n. 9
0
        def _lambda(*args, **kwargs):
            """Initialize ``self.module`` with dependency injection keyed on ``self.f``.

            Keywords are forwarded to ``self.module.init`` under a ``__``
            prefix, with a fresh key from ``rng.next()``. The
            ``(y_pred, params, states)`` triple is returned as ``OutputStates``.
            """
            def init_fn(*call_args, **call_kwargs) -> types.OutputStates:
                prefixed = {
                    f"__{key}": val for key, val in call_kwargs.items()
                }
                outputs, net_params, net_states = self.module.init(
                    rng.next(), *call_args, **prefixed
                )
                return types.OutputStates(outputs, net_params, net_states)

            injected = utils.inject_dependencies(init_fn, signature_f=self.f)
            outputs, net_params, net_states = injected(*args, **kwargs)

            return types.OutputStates(outputs, net_params, net_states)
Esempio n. 10
0
        def _lambda(*args, **kwargs):
            """Apply ``self.module`` to the merged params/states collections.

            A shallow copy of ``states`` (or ``{}``) receives ``params`` under
            ``"parameters"`` when ``params`` is not ``None``. The updated
            collections returned by the module are split back into
            ``params`` and ``states``.
            """
            variables_in = {} if states is None else states.copy()
            if params is not None:
                variables_in["parameters"] = params

            forward = utils.inject_dependencies(
                self.module.apply(variables_in, training=training, rng=rng),
            )
            outputs, variables_out = forward(*args, **kwargs)
            assert isinstance(variables_out, dict)

            # pop() removes "parameters"; every remaining collection is state.
            net_params = variables_out.pop("parameters", {})
            return types.OutputStates(outputs, net_params, variables_out)
Esempio n. 11
0
        def _lambda(*args, **kwargs):
            """Initialize the Flax module and split ``"params"`` from other collections.

            ``self.module.init_with_output`` is called with a fresh key from
            ``rng.next()``, using ``self.module.__call__``'s signature for
            dependency injection.

            Returns:
                ``OutputStates(y_pred, net_params, net_states)``. ``net_params``
                is the ``"params"`` collection (an empty ``FrozenDict`` if
                absent); ``net_states`` holds all remaining collections.
            """
            def init_fn(*args, **kwargs):
                return self.module.init_with_output(rng.next(), *args,
                                                    **kwargs)

            y_pred, variables = utils.inject_dependencies(
                init_fn,
                signature_f=self.module.__call__,
            )(
                *args,
                **kwargs,
            )
            # Newer Flax releases return a plain dict instead of a FrozenDict.
            # Normalize it so the assert holds and pop() below returns the
            # (rest, value) pair that is unpacked.
            if isinstance(variables, dict):
                variables = FrozenDict(variables)
            assert isinstance(variables, FrozenDict)

            net_states, net_params = (variables.pop("params")
                                      if "params" in variables else
                                      (variables, FrozenDict()))

            return types.OutputStates(y_pred, net_params, net_states)
Esempio n. 12
0
        def _lambda(*args, **kwargs):
            """Apply the Flax module with all collections mutable.

            The captured ``params`` and ``states`` are merged into one
            variables mapping. The module is applied with a fresh ``"params"``
            RNG, and the returned variables are split into ``"params"`` and the
            remaining collections.

            Returns:
                ``OutputStates(y_pred, net_params, net_states)``. ``net_params``
                is the ``"params"`` collection (an empty ``FrozenDict`` if
                absent); ``net_states`` holds all remaining collections.
            """
            def apply_fn(*args, **kwargs):
                # A "params" key inside `states` would make dict() raise
                # TypeError (multiple values for keyword argument 'params').
                variables = dict(params=params, **states)
                return self.module.apply(
                    variables,
                    *args,
                    rngs={"params": rng.next()},
                    mutable=True,
                    **kwargs,
                )

            y_pred, variables = utils.inject_dependencies(
                apply_fn,
                signature_f=self.module.__call__,
            )(
                *args,
                **kwargs,
            )
            # Newer Flax releases return a plain dict instead of a FrozenDict.
            # dict.pop("params") would return only the value, breaking the
            # (rest, value) unpacking below, so normalize first.
            if isinstance(variables, dict):
                variables = FrozenDict(variables)

            net_states, net_params = (variables.pop("params")
                                      if "params" in variables else
                                      (variables, FrozenDict()))

            return types.OutputStates(y_pred, net_params, net_states)