Esempio n. 1
0
        def jax_fn(flat_states, inputs):
            """Run ``self.pred_step`` in inference mode on rebuilt states.

            Args:
                flat_states: Leaves that are unflattened with ``states_def``
                    (taken from the enclosing scope) into the ``states`` pytree.
                inputs: Forwarded to ``self.pred_step`` as ``x``.

            Returns:
                The first element of the pair returned by ``self.pred_step``;
                the second element is discarded.
            """
            # `jax.tree_util.tree_unflatten` is the stable spelling; the
            # top-level `jax.tree_unflatten` alias is deprecated and removed in
            # recent JAX releases.
            states = jax.tree_util.tree_unflatten(states_def, flat_states)

            # Inference only: no initialization, no training behaviour.
            y_pred, _ = utils.inject_dependencies(self.pred_step)(
                x=inputs,
                states=states,
                initializing=False,
                training=False,
            )

            return y_pred
Esempio n. 2
0
 def call_train_step(
     self,
     x: tp.Any,
     y_true: tp.Any,
     mode: types.Mode,
     sample_weight: tp.Optional[np.ndarray],
     class_weight: tp.Optional[np.ndarray],
     states: types.States,
     initializing: bool,
 ) -> TrainStep:
     """Call ``self.train_step`` through ``utils.inject_dependencies``.

     The individual fields of ``states`` (``net_params``, ``net_states``,
     ``metrics_states``, ``optimizer_states``, ``rng``) are offered as
     separate keyword arguments alongside ``states`` itself, and
     ``training`` is ``True`` exactly when ``mode`` is ``types.Mode.train``.

     Returns:
         Whatever ``self.train_step`` returns.
     """
     train_step = utils.inject_dependencies(self.train_step)
     is_training = mode == types.Mode.train
     return train_step(
         x=x,
         y_true=y_true,
         mode=mode,
         net_params=states.net_params,
         net_states=states.net_states,
         metrics_states=states.metrics_states,
         optimizer_states=states.optimizer_states,
         sample_weight=sample_weight,
         class_weight=class_weight,
         rng=states.rng,
         states=states,
         initializing=initializing,
         training=is_training,
     )
Esempio n. 3
0
def apply_recursive(context: tp.Tuple[str, ...], metrics, **kwargs):
    """Recursively evaluate a (possibly nested) structure of metric callables.

    Args:
        context: Name path accumulated so far; every yielded path extends it.
        metrics: A callable, a tuple/list of such structures, or a dict mapping
            names to such structures.
        **kwargs: Forwarded to each callable through
            ``utils.inject_dependencies``.

    Yields:
        ``(name_path, value)`` pairs. A callable contributes its name to the
        path; when it returns a dict, one pair is yielded per entry with the
        entry key appended to the path.

    Raises:
        TypeError: If ``metrics`` (or any nested element) has another type.
    """
    if callable(metrics):
        if isinstance(metrics, module.Module):
            name = metrics.name
        else:
            # Callables such as ``functools.partial`` objects or plain callable
            # instances have no ``__name__``; use their class name instead of
            # raising AttributeError.
            name = utils.lower_snake_case(
                getattr(metrics, "__name__", type(metrics).__name__)
            )
        context += (name,)
        value = utils.inject_dependencies(metrics)(**kwargs)

        if isinstance(value, dict):
            for inner_name, inner_value in value.items():
                yield context + (inner_name,), inner_value
        else:
            yield context, value

    elif isinstance(metrics, (tuple, list)):
        for metric in metrics:
            yield from apply_recursive(context, metric, **kwargs)
    elif isinstance(metrics, dict):
        for key, metric in metrics.items():
            yield from apply_recursive(context + (key,), metric, **kwargs)
    else:
        raise TypeError(f"Invalid type {type(metrics)}")
Esempio n. 4
0
    def apply_recursive(self, context: tp.Tuple[str, ...], losses, **kwargs):
        """Recursively evaluate a (possibly nested) structure of loss callables.

        Args:
            context: Name path accumulated so far; every yielded path extends it.
            losses: A callable, a tuple/list of such structures, or a dict
                mapping names to such structures.
            **kwargs: Forwarded to each callable through
                ``utils.inject_dependencies``.

        Yields:
            ``(name_path, value)`` pairs. A callable contributes its name to the
            path; when it returns a dict, one pair is yielded per entry with the
            entry key appended to the path.

        Raises:
            TypeError: If ``losses`` (or any nested element) has another type.
        """
        if callable(losses):
            if isinstance(losses, Loss):
                name = losses.name
            else:
                # Callables such as ``functools.partial`` objects or plain
                # callable instances have no ``__name__``; use their class name
                # instead of raising AttributeError.
                name = utils.lower_snake_case(
                    getattr(losses, "__name__", type(losses).__name__)
                )
            context += (name,)
            val = utils.inject_dependencies(losses)(**kwargs)

            if isinstance(val, dict):
                for inner_name, inner_val in val.items():
                    yield context + (inner_name,), inner_val
            else:
                yield context, val

        elif isinstance(losses, (tuple, list)):
            for loss in losses:
                yield from self.apply_recursive(context, loss, **kwargs)
        elif isinstance(losses, dict):
            for key, loss in losses.items():
                yield from self.apply_recursive(context + (key,), loss, **kwargs)
        else:
            raise TypeError(f"Invalid type {type(losses)}")
Esempio n. 5
0
    def call_pred_step(
        self,
        x: tp.Any,
        mode: types.Mode,
        states: types.States,
        initializing: bool,
    ) -> PredStep:
        """Call ``self.pred_step`` inside a ``hooks.context`` configured for ``mode``.

        ``hooks.context`` receives ``losses``/``metrics`` set to ``True`` only for
        ``types.Mode.test`` and ``types.Mode.train``, and ``summaries`` set to
        ``True`` only for ``types.Mode.summary``. ``self.pred_step`` is called
        through ``utils.inject_dependencies`` with the fields of ``states`` as
        separate keyword arguments; ``training`` is ``True`` only in
        ``types.Mode.train``.

        Returns:
            Whatever ``self.pred_step`` returns.
        """
        wants_losses_and_metrics = mode in (types.Mode.test, types.Mode.train)
        wants_summaries = mode == types.Mode.summary
        hook_flags = dict(
            losses=wants_losses_and_metrics,
            metrics=wants_losses_and_metrics,
            summaries=wants_summaries,
        )

        with hooks.context(**hook_flags):
            pred_step = utils.inject_dependencies(self.pred_step)
            return pred_step(
                x=x,
                mode=mode,
                net_params=states.net_params,
                net_states=states.net_states,
                rng=states.rng,
                states=states,
                initializing=initializing,
                training=mode == types.Mode.train,
            )
Esempio n. 6
0
 def __init__(
     self,
     reduction: tp.Optional[Reduction] = None,
     name: tp.Optional[str] = None,
     weight: tp.Optional[float] = None,
     on: tp.Optional[types.IndexLike] = None,
 ):
     """
     Initializes `Loss` class.

     Arguments:
         reduction: (Optional) Type of `elegy.losses.Reduction` to apply to
             the loss. Defaults to `Reduction.SUM_OVER_BATCH_SIZE`.
         name: Optional name for the loss. Defaults to the class name
             converted to snake_case (e.g. `MeanSquaredError` becomes
             `mean_squared_error`).
         weight: Optional weight contribution for the total loss. Defaults
             to `1.0`.
         on: Optional index (or indexes) stored as `self._labels_filter`. A
             single `str` or `int` is wrapped into a one-element tuple; any
             other value is stored unchanged.
     """
     # Insert "_" before every uppercase letter except the first one, then
     # lowercase the result.
     self.name = (
         name
         if name is not None
         else re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()
     )
     self.weight = weight if weight is not None else 1.0
     self._reduction = (
         reduction if reduction is not None else Reduction.SUM_OVER_BATCH_SIZE
     )
     self._labels_filter = (on,) if isinstance(on, (str, int)) else on
     # Replace the bound `call` with its dependency-injected wrapper.
     self.call = utils.inject_dependencies(self.call)
Esempio n. 7
0
    def test_positional_error_remaining(self):
        """More positional arguments than `f` accepts must raise TypeError."""

        def f(a, b, c):
            return a + b + c

        injected = utils.inject_dependencies(f)
        too_many = ("a", "b", "c", "d")

        with pytest.raises(TypeError):
            injected(*too_many)
Esempio n. 8
0
    def test_keyword_error_missing(self):
        """Leaving out a required keyword (``a``) must raise TypeError."""

        def f(a, b, c):
            return a + b + c

        injected = utils.inject_dependencies(f)
        incomplete = {"b": "b", "c": "c"}

        with pytest.raises(TypeError):
            injected(**incomplete)
Esempio n. 9
0
    def test_positional_error_missing(self):
        """Too few positional arguments must raise TypeError."""

        def f(a, b, c):
            return a + b + c

        injected = inject_dependencies(f)
        too_few = ("a", "b")

        with pytest.raises(TypeError):
            injected(*too_few)
Esempio n. 10
0
    def test_keyword(self):
        """Keyword arguments are matched to parameters by name, not by order."""

        def f(a, b, c):
            return a + b + c

        shuffled = {"b": "b", "c": "c", "a": "a"}
        result = inject_dependencies(f)(**shuffled)

        assert result == "abc"
Esempio n. 11
0
    def test_positional(self):
        """Positional arguments are forwarded in order."""

        def f(a, b, c):
            return a + b + c

        ordered = ("a", "b", "c")
        result = utils.inject_dependencies(f)(*ordered)

        assert result == "abc"
Esempio n. 12
0
 def call_summary_step(
     self,
     x: tp.Any,
     states: types.States,
 ) -> tp.List[types.SummaryTableEntry]:
     """Call ``self.summary_step`` on ``x`` with ``states`` via dependency injection.

     Returns:
         Whatever ``self.summary_step`` returns.
     """
     summary_step = utils.inject_dependencies(self.summary_step)
     return summary_step(x=x, states=states)
Esempio n. 13
0
        def _lambda(*args, **kwargs):
            """Initialise ``self.module`` on the given inputs.

            ``self.module.init(rng=rng)`` yields a callable. That callable is
            wrapped with ``utils.inject_dependencies`` and called with the
            inputs. Its three outputs are repackaged as ``types.OutputStates``.
            """
            init_fn = utils.inject_dependencies(self.module.init(rng=rng))
            y_pred, parameters, collections = init_fn(*args, **kwargs)
            return types.OutputStates(y_pred, parameters, collections)
Esempio n. 14
0
    def test_keyword_extras_ok(self):
        """Keywords that `f` does not declare (``d``) are ignored."""

        def f(a, b, c):
            return a + b + c

        with_extra = {"b": "b", "c": "c", "a": "a", "d": "d"}
        result = utils.inject_dependencies(f)(**with_extra)

        assert result == "abc"
Esempio n. 15
0
    def test_positional_extras_ok(self):
        """An undeclared keyword next to full positional args is ignored."""

        def f(a, b, c):
            return a + b + c

        injected = inject_dependencies(f)
        result = injected(*("a", "b", "c"), **{"d": "d"})

        assert result == "abc"
Esempio n. 16
0
    def test_mixed(self):
        """A positional ``a`` combines with out-of-order keywords ``c`` and ``b``."""

        def f(a, b, c):
            return a + b + c

        injected = utils.inject_dependencies(f)
        result = injected("a", **{"c": "c", "b": "b"})

        assert result == "abc"
Esempio n. 17
0
    def test_mixed_ignore_duplicated_kwarg_in_arg(self):
        """When ``a`` is given positionally, a keyword ``a`` is ignored."""

        def f(a, b, c):
            return a + b + c

        injected = utils.inject_dependencies(f)
        keywords = {"c": "c", "b": "b", "a": "f"}
        result = injected("a", **keywords)

        assert result == "abc"
Esempio n. 18
0
    def test_override_defaults(self):
        """An explicit keyword overrides the parameter's default value."""

        def f(a, b, c="x"):
            return a + b + c

        injected = utils.inject_dependencies(f)
        result = injected("a", **{"c": "c", "b": "b"})

        assert result == "abc"
Esempio n. 19
0
        def _lambda(*args, **kwargs):
            """Apply ``self.module`` with the captured params, states and rng.

            ``self.module.apply(...)`` yields a callable. That callable is
            wrapped with ``utils.inject_dependencies`` and called with the
            inputs. Its three outputs are repackaged as ``types.OutputStates``.
            """
            apply_fn = self.module.apply(params, states, training=training, rng=rng)
            injected_apply = utils.inject_dependencies(apply_fn)
            y_pred, net_params, net_states = injected_apply(*args, **kwargs)
            return types.OutputStates(y_pred, net_params, net_states)
Esempio n. 20
0
    def __init__(
        self,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
        on: tp.Optional[types.IndexLike] = None,
    ):
        """
        Arguments:
            name: Optional name, forwarded to the parent constructor.
            dtype: Optional dtype stored as `self._dtype`; defaults to
                `jnp.float32`.
            on: Optional index (or indexes) stored as `self._labels_filter`.
                A single `str` or `int` is wrapped into a one-element tuple;
                any other value is stored unchanged.
        """
        super().__init__(name=name)

        self._dtype = dtype if dtype is not None else jnp.float32
        self._labels_filter = (on,) if isinstance(on, (str, int)) else on
        # Replace the bound `call` with its dependency-injected wrapper.
        self.call = utils.inject_dependencies(self.call)
Esempio n. 21
0
        def _lambda(*args, **kwargs):
            """Apply ``self.module`` using ``self.f``'s signature for injection.

            Every keyword that reaches the module is renamed with a ``__``
            prefix. Each call draws a fresh key with ``rng.next()``. The
            captured ``params`` are returned unchanged alongside the
            predictions and the new states.
            """

            def apply_fn(*args, **kwargs):
                prefixed_kwargs = {
                    f"__{key}": val for key, val in kwargs.items()
                }
                return self.module.apply(
                    params, states, rng.next(), *args, **prefixed_kwargs
                )

            injected_apply = utils.inject_dependencies(apply_fn, signature_f=self.f)
            y_pred, new_states = injected_apply(*args, **kwargs)
            return types.OutputStates(y_pred, params, new_states)
Esempio n. 22
0
        def lambda_(*args, **kwargs) -> types.OutputStates:
            """Call ``self.f`` and normalise its result to ``types.OutputStates``.

            An ``OutputStates`` result is returned as-is. Anything else is
            treated as predictions, with params and states set to
            ``types.UNINITIALIZED``.
            """
            result = utils.inject_dependencies(self.f)(*args, **kwargs)

            if not isinstance(result, types.OutputStates):
                result = types.OutputStates(
                    preds=result,
                    params=types.UNINITIALIZED,
                    states=types.UNINITIALIZED,
                )

            return result
Esempio n. 23
0
 def call_pred_step(
     self,
     x: tp.Any,
     states: types.States,
     initializing: bool,
     training: bool,
 ) -> PredStep:
     """Call ``self.pred_step`` via ``utils.inject_dependencies``.

     Returns:
         Whatever ``self.pred_step`` returns.
     """
     step_kwargs = dict(
         x=x,
         states=states,
         initializing=initializing,
         training=training,
     )
     return utils.inject_dependencies(self.pred_step)(**step_kwargs)
Esempio n. 24
0
        def _lambda(*args, **kwargs):
            """Initialise ``self.module`` and split its collections.

            The ``"parameters"`` collection (or ``{}`` if absent) becomes the
            params. Every remaining collection is returned as the states.
            """
            init_fn = utils.inject_dependencies(self.module.init(rng=rng))
            y_pred, collections = init_fn(*args, **kwargs)
            assert isinstance(collections, dict)

            # `pop` removes "parameters" in place; the rest of the dict is state.
            net_params = collections.pop("parameters", {})
            return types.OutputStates(y_pred, net_params, collections)
Esempio n. 25
0
    def calculate_losses(self, *args, **kwargs) -> types.Logs:
        """Evaluate every loss in ``self.losses`` and collect the results.

        Each loss function is called through ``utils.inject_dependencies`` with
        ``*args``/``**kwargs``. Its output is flattened with
        ``utils.flatten_names``. Each entry is logged under
        ``"<name>/<inner_name>"``, or just ``"<name>"`` when the inner name is
        empty, after being made unique with ``utils.get_unique_name``.

        Returns:
            A dict mapping log names to loss values.
        """
        logs: types.Logs = {}
        # Shared across *all* loss functions. With a per-function set, keys
        # from different losses can collide and the later one overwrites the
        # earlier one in `logs`. Example: loss "a" with inner name "b" and
        # loss "a/b" with an empty inner name both map to "a/b".
        names = set()

        for name, loss_fn in self.losses.items():
            losses = utils.inject_dependencies(loss_fn)(*args, **kwargs)

            for inner_name, loss in utils.flatten_names(losses):
                inner_name = f"{name}/{inner_name}" if inner_name else name
                inner_name = utils.get_unique_name(names, inner_name)
                # Record explicitly so uniqueness holds whether or not
                # `get_unique_name` adds the name to `names` itself.
                names.add(inner_name)

                logs[inner_name] = loss

        return logs
Esempio n. 26
0
        def _lambda(*args, **kwargs) -> types.OutputStates:
            """Call ``self.f`` and initialise the ``(count, total)`` state.

            An ``OutputStates`` result is returned unchanged. Otherwise the
            predictions are returned with ``params=None`` and
            ``states=(0, zeros_like(preds))``, mapped leaf-wise over the
            pytree.
            """
            preds = utils.inject_dependencies(self.f)(*args, **kwargs)

            if isinstance(preds, types.OutputStates):
                return preds

            # `jax.tree_util.tree_map` is the stable spelling; the top-level
            # `jax.tree_map` alias is deprecated and removed in recent JAX.
            total = jax.tree_util.tree_map(jnp.zeros_like, preds)
            return types.OutputStates(
                preds=preds,
                params=None,
                states=(0, total),
            )
Esempio n. 27
0
 def call_init_step(
     self,
     x: tp.Any,
     y_true: tp.Any,
     sample_weight: tp.Optional[np.ndarray],
     class_weight: tp.Optional[np.ndarray],
     states: types.States,
 ) -> types.States:
     """Call ``self.init_step`` via ``utils.inject_dependencies``.

     Returns:
         Whatever ``self.init_step`` returns (annotated as ``types.States``).
     """
     init_step = utils.inject_dependencies(self.init_step)
     step_kwargs = {
         "x": x,
         "y_true": y_true,
         "sample_weight": sample_weight,
         "class_weight": class_weight,
         "states": states,
     }
     return init_step(**step_kwargs)
Esempio n. 28
0
        def _lambda(*args, **kwargs):
            """Apply ``self.module`` and split the returned collections.

            The input collections are built as follows:

            * a ``.copy()`` of the captured ``states``, or ``{}`` when it is
              ``None``;
            * plus ``params`` under ``"parameters"`` when ``params`` is not
              ``None``.

            On return, ``"parameters"`` (or ``{}``) becomes the params and the
            remaining collections become the states.
            """
            if states is None:
                collections = {}
            else:
                collections = states.copy()
            if params is not None:
                collections["parameters"] = params

            apply_fn = self.module.apply(collections, training=training, rng=rng)
            y_pred, updated = utils.inject_dependencies(apply_fn)(*args, **kwargs)
            assert isinstance(updated, dict)

            # `pop` removes "parameters" in place; the rest of the dict is state.
            net_params = updated.pop("parameters", {})
            return types.OutputStates(y_pred, net_params, updated)
Esempio n. 29
0
        def _lambda(*args, **kwargs):
            """Initialise ``self.module`` using ``self.f``'s signature for injection.

            Every keyword that reaches ``self.module.init`` is renamed with a
            ``__`` prefix. Each call draws a fresh key with ``rng.next()``.
            """

            def init_fn(*args, **kwargs) -> types.OutputStates:
                prefixed_kwargs = {
                    f"__{key}": val for key, val in kwargs.items()
                }
                init_key = rng.next()
                preds, new_params, new_states = self.module.init(
                    init_key, *args, **prefixed_kwargs
                )
                return types.OutputStates(preds, new_params, new_states)

            injected_init = utils.inject_dependencies(init_fn, signature_f=self.f)
            y_pred, params, states = injected_init(*args, **kwargs)
            return types.OutputStates(y_pred, params, states)
Esempio n. 30
0
        def _lambda(*args, **kwargs) -> types.OutputStates:
            """Call ``self.f`` and update the running ``(count, total)`` state.

            An ``OutputStates`` result is returned unchanged. Otherwise the
            captured ``states`` is unpacked as ``(n, total)``, ``n`` is
            incremented, and the predictions are added leaf-wise to ``total``.
            The returned predictions are ``total / n`` (leaf-wise), together
            with ``params=None`` and the updated ``(n, total)`` state.
            """
            preds = utils.inject_dependencies(self.f)(*args, **kwargs)

            if isinstance(preds, types.OutputStates):
                return preds

            n, total = states
            n += 1
            # `jax.tree_multimap` has been removed from JAX;
            # `jax.tree_util.tree_map` accepts several trees of equal structure.
            total = jax.tree_util.tree_map(lambda a, b: a + b, preds, total)
            mean = jax.tree_util.tree_map(lambda t: t / n, total)
            return types.OutputStates(
                preds=mean,
                params=None,
                states=(n, total),
            )