Example #1
0
        def _lambda(*args, **kwargs):
            """Compute the total loss together with uniquely named loss logs.

            All positional and keyword arguments are forwarded unchanged to
            `self.calculate_losses`.

            Returns:
                A `(loss, logs, states)` tuple: `loss` is the sum of every
                value in `aux_losses` and in the calculated module losses;
                `logs` and `states` are what `self.loss_metrics.init(rng=rng)`
                produces when applied to the renamed logs.
            """
            module_logs = self.calculate_losses(*args, **kwargs)

            # A float start value makes each partial total `0.0` (not the
            # int `0`) when the corresponding mapping is empty.
            aux_total = sum(aux_losses.values(), 0.0)
            module_total = sum(module_logs.values(), 0.0)
            loss = aux_total + module_total

            merged = utils.merge_with_unique_names(
                dict(loss=loss), aux_losses, module_logs
            )

            # Keys that do not already contain "loss" get a "_loss" suffix;
            # each resulting key is then made unique within `names`.
            names = set()
            logs = {}
            for name, value in merged.items():
                candidate = name if "loss" in name else f"{name}_loss"
                logs[utils.get_unique_name(names, candidate)] = value

            logs, _, states = self.loss_metrics.init(rng=rng)(logs)

            return loss, logs, states
Example #2
0
    def __init__(self, losses: tp.Any):
        """Register every loss found in `losses` under a unique name.

        Arguments:
            losses: A structure of losses. It is flattened with
                `utils.flatten_names` into `(path, loss_fn)` pairs.
        """
        taken: tp.Set[str] = set()
        registry = {}

        for path, loss_fn in utils.flatten_names(losses):
            base = utils.get_name(loss_fn)
            # Prefix the name with its path inside `losses` when there is one.
            qualified = f"{path}/{base}" if path else base
            registry[utils.get_unique_name(taken, qualified)] = loss_fn

        # Assign the attribute once, after the mapping is fully built.
        self.losses = registry
        self.loss_metrics = LossMetrics()
Example #3
0
    def calculate_losses(self, *args, **kwargs) -> types.Logs:
        """Evaluate every registered loss and gather the results in flat logs.

        All arguments are passed to each loss function after it is wrapped
        with `utils.inject_dependencies`.

        Returns:
            A dict whose keys are the registered loss name, extended with
            `/<inner name>` whenever the flattened output of the loss
            carries a non-empty inner name.
        """
        logs: types.Logs = {}

        for name, loss_fn in self.losses.items():
            outputs = utils.inject_dependencies(loss_fn)(*args, **kwargs)

            # Uniqueness is tracked per loss function: the set of used
            # names starts out empty for each entry of `self.losses`.
            seen = set()
            for path, value in utils.flatten_names(outputs):
                qualified = f"{name}/{path}" if path else name
                key = utils.get_unique_name(seen, qualified)
                logs[key] = value

        return logs
Example #4
0
    def __init__(self, modules: tp.Any):
        """Register every metric found in `modules` under a unique name.

        Arguments:
            modules: A structure of metric modules or callables. It is
                flattened with `utils.flatten_names` into `(path, module)`
                pairs, and each entry is wrapped with `generalize` using
                `AvgMetric` as the `callable_default`.
        """
        taken: tp.Set[str] = set()
        registry = {}

        for path, module in utils.flatten_names(modules):
            base = utils.get_name(module)
            # Prefix the name with its path inside `modules` when there is one.
            qualified = f"{path}/{base}" if path else base
            # Resolve the key before wrapping the module, matching the
            # key-then-value order of a dict comprehension.
            key = utils.get_unique_name(taken, qualified)
            registry[key] = generalize(module, callable_default=AvgMetric)

        # Assign the attribute once, after the mapping is fully built.
        self.metrics = registry
Example #5
0
    def calculate_metrics(
        self,
        aux_metrics: types.Logs,
        callback: tp.Callable[[str, GeneralizedModule], types.OutputStates],
    ) -> tp.Tuple[types.Logs, tp.Any]:
        """Run every registered metric and merge its outputs into `aux_metrics`.

        Arguments:
            aux_metrics: Logs to extend. This dict is updated in place and is
                also the first element of the returned tuple.
            callback: Called as `callback(name, module)` for each registered
                metric. Its result is unpacked into three items: the metric
                output, an ignored value, and the state stored for the metric.

        Returns:
            A tuple `(aux_metrics, states)` where `states` maps each metric
            name to the state returned by `callback`.
        """
        states = {}

        for name, module in self.metrics.items():
            y_pred, _, state = callback(name, module)
            states[name] = state

            # Uniqueness is tracked per metric: the set of used names
            # starts out empty for each entry of `self.metrics`.
            seen = set()
            for path, value in utils.flatten_names(y_pred):
                qualified = f"{name}/{path}" if path else name
                key = utils.get_unique_name(seen, qualified)
                aux_metrics[key] = value

        return aux_metrics, states
Example #6
0
    def __call__(cls: tp.Type, *args, **kwargs) -> "Module":
        """Construct a module, or reuse one already created by the parent.

        Outside of a call (`LOCAL.inside_call` is falsy) this simply
        constructs a new module. Inside a call, modules are matched by
        creation order against `LOCAL.parent._dynamic_submodules`: the entry
        at `LOCAL.module_index` is reused when it exists, otherwise a new
        module is constructed and registered on the parent under a unique
        name. `LOCAL.module_index` is advanced by one on success.

        Raises:
            types.ModuleOrderError: If the stored module at the current index
                has a different class name than `cls`.
        """
        if not LOCAL.inside_call:
            return construct_module(cls, *args, **kwargs)

        assert LOCAL.module_index is not None
        assert LOCAL.parent

        index = LOCAL.module_index
        parent = LOCAL.parent

        if index < len(parent._dynamic_submodules):
            # A module was already created at this position: reuse it.
            module = parent._dynamic_submodules[index]

            assert isinstance(module, Module)

            # Types are compared by class name rather than with `isinstance`.
            if cls.__name__ != module.__class__.__name__:
                raise types.ModuleOrderError(
                    f"Error retrieving module, expected type {cls.__name__}, got {module.__class__.__name__}. "
                    "This is probably due to control flow, you must guarantee that the same amount "
                    "of submodules will be created every time and that their order is the same."
                )
        else:
            # First time this position is reached: build the module and
            # register it on the parent. No check is made here on whether
            # creation happens during `init`.
            module = construct_module(cls, *args, **kwargs)

            # A copy of the parent's current submodule names is used for the
            # uniqueness lookup.
            taken = set(parent._submodules)
            name = utils.get_unique_name(taken, module.name)

            parent._submodules[name] = module
            parent._submodule_name[module] = name
            parent._dynamic_submodules.append(module)

        LOCAL.module_index += 1

        return module
Example #7
0
def add_metric(name: str, value: types.Scalar) -> None:
    """
    A hook that lets you define a metric within a [`module`][elegy.module.Module].

    ```python
    y = jax.nn.relu(x)
    elegy.hooks.add_metric("activation_mean", jnp.mean(y))
    ```

    Does nothing when `LOCAL.metrics` is `None`.

    Arguments:
        name: The name of the metric. If a metric with the same
            `name` already exists a unique identifier will be generated.
        value: The value for the metric.
    """
    if LOCAL.metrics is not None:
        # A copy of the current metric names is used for the uniqueness
        # lookup.
        unique_name = utils.get_unique_name(set(LOCAL.metrics), name)
        LOCAL.metrics[unique_name] = value