Example #1
0
    def call(inputs, *args, **kwargs):
        """Run ``inputs`` through ``layers`` in order and return the final output.

        Only the first layer receives the extra ``*args``/``**kwargs``; every
        later layer is called with the previous layer's output alone. After
        each layer that is not a ``Module``, its output is passed to
        ``hooks.add_summary`` under the layer's name.

        Raises:
            ValueError: if there is no active context (``LOCAL.contexts`` is
                empty), or the innermost context has no active module
                (``context.module_c`` is empty).
        """
        # Guard clauses keep the happy path un-nested.
        if not LOCAL.contexts:
            raise ValueError(
                "Cannot execute `sequential` outside of an `elegy.context`"
            )

        context: Context = LOCAL.contexts[-1]

        if not context.module_c:
            raise ValueError(
                "Cannot execute `sequential` outside of a module's `call` or `init`."
            )

        out = inputs
        for i, layer in enumerate(layers):
            if i == 0:
                out = layer(out, *args, **kwargs)
            else:
                out = layer(out)

            # Only non-Module layers are summarised here; Module layers are
            # presumably summarised through their own call path — TODO confirm.
            if not isinstance(layer, Module):
                # Plain functions expose ``__name__``; other callables fall
                # back to their class name.
                name = getattr(layer, "__name__", layer.__class__.__name__)
                hooks.add_summary(name, out)
        return out
Example #2
0
def haiku_summary(
    name: str,
    f: tp.Any,
    value: types.Scalar,
):
    """Record ``value`` as a summary called ``name`` under the current Haiku scope.

    The summary path is ``current_bundle_name()`` split on ``"/"`` with
    ``name`` appended; ``f`` is forwarded to ``hooks.add_summary`` as-is.
    Does nothing when summaries are not active.
    """
    if not hooks.summaries_active():
        return

    scope_parts = current_bundle_name().split("/")
    hooks.add_summary((*scope_parts, name), f, value)
Example #3
0
def flax_summary(
    flax_module: linen.Module,
    name: str,
    f: tp.Any,
    value: types.Scalar,
):
    """Record ``value`` as a summary called ``name`` under ``flax_module``'s scope.

    The summary path is ``flax_module.scope.path`` extended with ``name``;
    ``f`` is forwarded to ``hooks.add_summary`` as-is. Does nothing when
    summaries are not active.
    """
    if not hooks.summaries_active():
        return

    summary_path = flax_module.scope.path + (name,)
    hooks.add_summary(summary_path, f, value)
Example #4
0
    def wrapper(self: linen.Module, *args, **kwargs):
        """Call the wrapped ``f`` and, if summaries are active, record its outputs.

        The summary is keyed by ``self.scope.path`` and references ``self``
        as the producer. The outputs of ``f`` are returned unchanged.
        """
        result = f(self, *args, **kwargs)

        if not hooks.summaries_active():
            return result

        hooks.add_summary(self.scope.path, self, result)
        return result
Example #5
0
    def wrapper(self: haiku.Module, *args, **kwargs):
        """Call the wrapped ``f`` and, if summaries are active, record its outputs.

        The summary is keyed by the current bundle name split on ``"/"``
        (as a tuple) and references ``self`` as the producer. The outputs of
        ``f`` are returned unchanged.
        """
        result = f(self, *args, **kwargs)

        if not hooks.summaries_active():
            return result

        bundle_path = tuple(current_bundle_name().split("/"))
        hooks.add_summary(bundle_path, self, result)
        return result
Example #6
0
    def call(inputs, *args, **kwargs):
        """Feed ``inputs`` through ``layers`` sequentially and return the result.

        The first layer also receives ``*args``/``**kwargs``; later layers get
        only the previous output. When summaries are active, the output of
        every layer that is not a ``module.Module`` is recorded at the current
        module path (or ``()`` if there is none) extended by the layer's name.
        """
        x = inputs
        for idx, fn in enumerate(layers):
            # Extra arguments are forwarded to the first layer only.
            call_args, call_kwargs = (args, kwargs) if idx == 0 else ((), {})
            x = fn(x, *call_args, **call_kwargs)

            if isinstance(fn, module.Module) or not hooks.summaries_active():
                continue

            layer_name = utils.get_name(fn)
            base_path = module.get_module_path()
            if base_path is None:
                base_path = ()
            hooks.add_summary(base_path + (layer_name,), fn, x)

        return x
Example #7
0
    def __call__(self, *args, **kwargs) -> tp.Any:
        """
        Forward all arguments to this Module's `call` method and return its
        outputs.

        The call runs inside `call_context(self)`. When summaries are active,
        the outputs are also passed to `elegy.hooks.add_summary` under this
        module's path, which is asserted to be non-None.
        """
        with call_context(self):
            outputs = self.call(*args, **kwargs)

            if not hooks.summaries_active():
                return outputs

            module_path = get_module_path(self)
            assert module_path is not None
            hooks.add_summary(module_path, self, outputs)
            return outputs
Example #8
0
 def add_summary(self, name: str, f: tp.Any, value: tp.Any):
     """Record ``value`` as a summary called ``name`` under this module's path.

     The summary path is ``get_module_path(self)`` with ``name`` appended;
     ``f`` is forwarded to ``hooks.add_summary`` unchanged. Nothing is
     recorded when summaries are not active.

     Raises:
         ValueError: if summaries are active but ``get_module_path(self)``
             returns ``None`` (presumably because this module is not
             currently being called — TODO confirm).
     """
     if not hooks.summaries_active():
         return

     # Check for None *before* concatenating: ``None + (name,)`` would raise
     # an opaque TypeError, and a check placed after the concatenation can
     # never fire because a tuple concatenation is never None.
     module_path = get_module_path(self)
     if module_path is None:
         raise ValueError(
             f"Cannot add summary {name!r}: the module path is unavailable."
         )

     hooks.add_summary(module_path + (name,), f, value)