def _lambda(*args, **kwargs):
    """Compute the scalar loss, a merged log dict, and loss-metric states.

    Uses `self`, `aux_losses`, and `rng` from the enclosing scope.
    """
    # Losses produced by the configured loss modules for this call.
    module_logs = self.calculate_losses(*args, **kwargs)
    # Total scalar loss = hook-added aux losses + module losses.
    loss = sum(aux_losses.values(), 0.0) + sum(module_logs.values(), 0.0)
    merged = utils.merge_with_unique_names({"loss": loss}, aux_losses, module_logs)
    # Rename every entry so each key is unique and mentions "loss".
    seen = set()
    renamed = {}
    for key, value in merged.items():
        candidate = key if "loss" in key else f"{key}_loss"
        renamed[utils.get_unique_name(seen, candidate)] = value
    logs, _, states = self.loss_metrics.init(rng=rng)(renamed)
    return loss, logs, states
def __init__(self, losses: tp.Any):
    """Flatten `losses` into a dict of uniquely-named loss callables.

    Arguments:
        losses: an arbitrarily nested structure of loss functions; nesting
            paths become name prefixes (`path/name`).
    """
    taken: tp.Set[str] = set()
    self.losses = {}
    for path, loss_fn in utils.flatten_names(losses):
        base = utils.get_name(loss_fn)
        full_name = f"{path}/{base}" if path else base
        # De-duplicate names that collide across the flattened structure.
        self.losses[utils.get_unique_name(taken, full_name)] = loss_fn
    self.loss_metrics = LossMetrics()
def calculate_losses(self, *args, **kwargs) -> types.Logs:
    """Invoke every registered loss function and return a flat log dict.

    Each loss function may itself return a nested structure; nested entries
    are flattened under the loss's unique name.
    """
    logs: types.Logs = {}
    for name, loss_fn in self.losses.items():
        # Pass only the arguments this particular loss function accepts.
        losses = utils.inject_dependencies(loss_fn)(*args, **kwargs)
        taken = set()
        for sub_name, value in utils.flatten_names(losses):
            full = f"{name}/{sub_name}" if sub_name else name
            logs[utils.get_unique_name(taken, full)] = value
    return logs
def __init__(self, modules: tp.Any):
    """Flatten `modules` into a dict of uniquely-named generalized metrics.

    Arguments:
        modules: an arbitrarily nested structure of metric modules; nesting
            paths become name prefixes (`path/name`).
    """
    taken: tp.Set[str] = set()
    self.metrics = {}
    for path, module in utils.flatten_names(modules):
        base = utils.get_name(module)
        full_name = f"{path}/{base}" if path else base
        unique = utils.get_unique_name(taken, full_name)
        # Bare callables are wrapped as running-average metrics.
        self.metrics[unique] = generalize(
            module,
            callable_default=AvgMetric,
        )
def calculate_metrics(
    self,
    aux_metrics: types.Logs,
    callback: tp.Callable[[str, GeneralizedModule], types.OutputStates],
) -> tp.Tuple[types.Logs, tp.Any]:
    """Run `callback` for each metric module and merge its outputs.

    Arguments:
        aux_metrics: log dict that is mutated in place with each module's
            flattened outputs.
        callback: invoked as `callback(name, module)` and expected to return
            `(y_pred, _, state)` for that module.

    Returns:
        The (mutated) `aux_metrics` dict and a dict of per-module states.
    """
    states = {}
    for name, module in self.metrics.items():
        y_pred, _, module_state = callback(name, module)
        states[name] = module_state
        taken = set()
        # Flatten nested outputs under this module's unique name.
        for sub_name, sub_value in utils.flatten_names(y_pred):
            full = f"{name}/{sub_name}" if sub_name else name
            aux_metrics[utils.get_unique_name(taken, full)] = sub_value
    return aux_metrics, states
def __call__(cls: tp.Type, *args, **kwargs) -> "Module":
    """Construct a Module, or replay a previously constructed one.

    Outside of `call`, behaves like a normal constructor. Inside `call`,
    submodules are identified positionally via `LOCAL.module_index`: if a
    submodule already exists at the current index it is returned instead of
    building a new one, which is why creation order must be deterministic.
    """
    # Set unique on parent when using inside `call`
    if LOCAL.inside_call:
        assert LOCAL.module_index is not None
        assert LOCAL.parent

        index = LOCAL.module_index
        parent = LOCAL.parent

        if len(parent._dynamic_submodules) > index:
            # Replay path: a submodule was already created at this position
            # on a previous trace — reuse it.
            module = parent._dynamic_submodules[index]
            assert isinstance(module, Module)

            # Compared by class *name* rather than `isinstance(module, cls)`:
            # if not isinstance(module, cls):
            if module.__class__.__name__ != cls.__name__:
                raise types.ModuleOrderError(
                    f"Error retrieving module, expected type {cls.__name__}, got {module.__class__.__name__}. "
                    "This is probably due to control flow, you must guarantee that the same amount "
                    "of submodules will be created every time and that their order is the same."
                )
        else:
            # First trace: build the submodule and register it on the parent
            # under a collision-free name.
            # if not LOCAL.initializing:
            #     raise ValueError(
            #         f"Trying to create module of type'{cls.__name__}' outside of `init`."
            #     )
            module = construct_module(cls, *args, **kwargs)

            name = utils.get_unique_name(set(parent._submodules), module.name)
            parent._submodules[name] = module
            parent._submodule_name[module] = name
            parent._dynamic_submodules.append(module)

        # Advance the positional cursor for the next submodule creation.
        LOCAL.module_index += 1
        return module
    else:
        # Plain construction outside of a `call` context.
        return construct_module(cls, *args, **kwargs)
def add_metric(name: str, value: types.Scalar) -> None:
    """
    A hook that lets you define a metric within a [`module`][elegy.module.Module].

    ```python
    y = jax.nn.relu(x)
    elegy.hooks.add_metric("activation_mean", jnp.mean(y))
    ```

    Arguments:
        name: The name of the metric. If a metric with the same `name` already
            exists a unique identifier will be generated.
        value: The value for the metric.
    """
    # No-op when called outside of a metrics-collecting context.
    if LOCAL.metrics is None:
        return

    unique = utils.get_unique_name(set(LOCAL.metrics), name)
    LOCAL.metrics[unique] = value