def test_basic(self):
    """Exercise init/apply of a generalized flax `linen.Module`.

    Checks that params and the "batch_stats" collection are created on the
    init pass, and that the counter is bumped on a subsequent apply pass.
    """

    class M(linen.Module):
        @linen.compact
        def __call__(self, x):
            # True only after the first (init) pass has created the collection.
            already_init = self.has_variable("batch_stats", "n")
            counter = self.variable("batch_stats", "n", lambda: 0)
            weight = self.param("w", lambda key: 2.0)
            if already_init:
                counter.value += 1
            return x * weight

    generalized = generalize(M())
    key_seq = elegy.RNGSeq(42)

    y_true, params, states = generalized.init(key_seq)(x=3.0, y=1)
    assert y_true == 6
    assert params["w"] == 2
    assert states["batch_stats"]["n"] == 0

    # Params behave as an immutable mapping here: build an updated copy
    # instead of assigning in place.
    params = params.copy(dict(w=10.0))
    y_true, params, states = generalized.apply(
        params, states, training=True, rng=key_seq
    )(x=3.0, y=1)
    assert y_true == 30
    assert params["w"] == 10
    assert states["batch_stats"]["n"] == 1
def test_basic(self):
    """Exercise init/apply of a generalized `elegy.Module`.

    Mirrors the flax test: params hold "w", module states hold the
    non-trainable counter "n", which is incremented on each apply pass.
    """

    class M(elegy.Module):
        def call(self, x):
            n = self.add_parameter("n", lambda: 0, trainable=False)
            w = self.add_parameter("w", lambda: 2.0)
            self.update_parameter("n", n + 1)
            # Consume a key from the RNG sequence; the call is kept for its
            # side effect on RNG state, but the previously unused local
            # binding (`key = ...`) has been removed.
            self.next_key()
            return x * w

    gm = generalize(M())
    rng = elegy.RNGSeq(42)

    y_true, params, states = gm.init(rng)(x=3.0, y=1)
    assert y_true == 6
    assert params["w"] == 2
    assert states["states"]["n"] == 0

    # Plain dict params: in-place mutation is fine here.
    params["w"] = 10.0
    y_true, params, states = gm.apply(params, states, training=True, rng=rng)(
        x=3.0, y=1
    )
    assert y_true == 30
    assert params["w"] == 10
    assert states["states"]["n"] == 1
def __init__(self, modules: tp.Any):
    """Build a name -> generalized-metric mapping from `modules`.

    Each module found by `utils.flatten_names` is wrapped via `generalize`
    (defaulting plain callables to `AvgMetric`) and stored under a
    collision-free, path-qualified name.
    """
    taken: tp.Set[str] = set()

    def qualified_name(module, path):
        base = utils.get_name(module)
        return f"{path}/{base}" if path else base

    self.metrics = {}
    for path, module in utils.flatten_names(modules):
        unique = utils.get_unique_name(taken, qualified_name(module, path))
        self.metrics[unique] = generalize(
            module,
            callable_default=AvgMetric,
        )
def __init__(
    self,
    module: tp.Any = None,
    loss: tp.Any = None,
    metrics: tp.Any = None,
    optimizer: tp.Any = None,
    seed: int = 42,
    **kwargs,
):
    """Construct the model from a module plus optional loss/metrics/optimizer.

    Arguments:
        module: A `Module` instance.
        loss: A `elegy.Loss` or `Callable` instance representing the loss
            function of the network. You can define more loss terms by
            simply passing a possibly nested structure of lists and
            dictionaries of `elegy.Loss` or `Callable`s. Usually a plain
            list of losses is enough but using dictionaries will create
            namescopes for the names of the losses which might be useful
            e.g. to group things in tensorboard. Contrary to Keras
            convention, in Elegy there is no relation between the structure
            of `loss` with the structure of the labels and outputs of the
            network. Elegy's loss system is more flexible than the one
            provided by Keras, for more information on how to mimic Keras
            behavior checkout the
            [Losses and Metrics Guide](https://poets-ai.github.io/elegy/guides/losses-and-metrics)`.
        metrics: A `elegy.Metric` or `Callable` instance representing the
            metrics of the network. You can define more metrics terms by
            simply passing a possibly nested structure of lists and
            dictionaries of `elegy.Metric` or `Callable`s. Usually a plain
            list of metrics is enough but using dictionaries will create
            namescopes for the names of the metrics which might be useful
            e.g. to group things in tensorboard. Contrary to Keras
            convention, in Elegy there is no relation between the structure
            of `metrics` with the structure of the labels and outputs of
            the network. Elegy's metrics system is more flexible than the
            one provided by Keras, for more information on how to mimic
            Keras behavior checkout the
            [Losses and Metrics Guide](https://poets-ai.github.io/elegy/guides/losses-and-metrics)`.
        optimizer: An `optax` optimizer instance. Optax is a very flexible
            library for defining optimization pipelines with things like
            learning rate schedules, this means that there is no need for a
            `LearningRateScheduler` callback in Elegy.
        seed: Seed used to build the fallback `RNGSeq` when no rng is
            present in the restored states.
        **kwargs: Forwarded to the parent constructor. May include
            `run_eagerly`: settable attribute indicating whether the model
            should run eagerly. Running eagerly means that your model will
            be run step by step, like Python code, instead of using Jax's
            `jit` to. Your model might run slower, but it should become
            easier for you to debug it by stepping into individual layer
            calls.

    Raises:
        ValueError: If `kwargs["rng"]` is given but is neither an `int`
            nor a `types.RNGSeq`.
    """
    # Validate the optional rng kwarg before handing kwargs to the parent.
    if "rng" in kwargs and not isinstance(kwargs["rng"], (int, types.RNGSeq)):
        raise ValueError(
            f"rng must be one of the following types: int, types.RNGSeq. Got {kwargs['rng']}"
        )

    super().__init__(**kwargs)

    # maybe add rng if initialized: restored states (e.g. from a checkpoint)
    # may lack an rng; seed a fresh sequence so apply/step calls can draw keys.
    if self.initialized and (
        not hasattr(self.states, "rng") or self.states.rng is None
    ):
        self.states = self.states.update(rng=types.RNGSeq(seed))

    # Keep the raw user-supplied objects (self.loss/self.metrics stay as
    # passed — possibly None); only the api_* wrappers get the {} defaults.
    self.module = module
    self.loss = loss
    self.metrics = metrics
    self.optimizer = optimizer

    if loss is None:
        loss = {}

    if metrics is None:
        metrics = {}

    # Generalized wrappers used internally by the training/eval machinery.
    self.api_module = generalize(module) if module is not None else None
    self.api_loss = Losses(loss)
    self.api_metrics = Metrics(metrics)
    self.api_optimizer = (
        generalize_optimizer(optimizer) if optimizer is not None else None
    )
    self.seed = seed