def test_connects(self):
    transformer_model = elegy.nn.Transformer(
        head_size=64,
        num_heads=4,
        num_encoder_layers=2,
        num_decoder_layers=2,
    )

    src = np.random.uniform(size=(5, 32, 64))
    tgt = np.random.uniform(size=(5, 32, 64))

    _, params, collections = transformer_model.init(rng=elegy.RNGSeq(42))(src, tgt)
    out, params, collections = transformer_model.apply(
        params, collections, rng=elegy.RNGSeq(420)
    )(src, tgt)
def test_basic(self):
    class M(linen.Module):
        @linen.compact
        def __call__(self, x):
            initialized = self.has_variable("batch_stats", "n")
            vn = self.variable("batch_stats", "n", lambda: 0)
            w = self.param("w", lambda key: 2.0)

            if initialized:
                vn.value += 1

            return x * w

    gm = generalize(M())
    rng = elegy.RNGSeq(42)

    y_true, params, states = gm.init(rng)(x=3.0, y=1)
    assert y_true == 6
    assert params["w"] == 2
    assert states["batch_stats"]["n"] == 0

    params = params.copy(dict(w=10.0))

    y_true, params, states = gm.apply(params, states, training=True, rng=rng)(
        x=3.0, y=1
    )
    assert y_true == 30
    assert params["w"] == 10
    assert states["batch_stats"]["n"] == 1
def test_basic(self):
    class M(elegy.Module):
        def call(self, x):
            n = self.add_parameter("n", lambda: 0, trainable=False)
            w = self.add_parameter("w", lambda: 2.0)
            self.update_parameter("n", n + 1)
            key = self.next_key()  # exercises the RNG hook; the key itself is unused
            return x * w

    gm = generalize(M())
    rng = elegy.RNGSeq(42)

    y_true, params, states = gm.init(rng)(x=3.0, y=1)
    assert y_true == 6
    assert params["w"] == 2
    assert states["states"]["n"] == 0

    params["w"] = 10.0

    y_true, params, states = gm.apply(params, states, training=True, rng=rng)(
        x=3.0, y=1
    )
    assert y_true == 30
    assert params["w"] == 10
    assert states["states"]["n"] == 1
def test_basic(self):
    class M(haiku.Module):
        def __call__(self, x):
            n = haiku.get_state(
                "n", shape=[], dtype=jnp.int32, init=lambda *args: np.array(0)
            )
            w = haiku.get_parameter("w", [], init=lambda *args: np.array(2.0))
            haiku.set_state("n", n + 1)
            return x * w

    def f(x, initializing, rng):
        return M()(x)

    gm = elegy.HaikuModule(f)
    rng = elegy.RNGSeq(42)

    y_true, params, states = gm.init(rng)(x=3.0, y=1, rng=None, initializing=True)
    assert y_true == 6
    assert params["m"]["w"] == 2
    assert states["m"]["n"] == 0

    params = haiku.data_structures.to_mutable_dict(params)
    params["m"]["w"] = np.array(10.0)

    y_true, params, states = gm.apply(params, states, training=True, rng=rng)(
        x=3.0, y=1, rng=None, initializing=True
    )
    assert y_true == 30
    assert params["m"]["w"] == 10
    assert states["m"]["n"] == 1
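# The three tests above exercise one contract from three frameworks: after
# generalize (or elegy.HaikuModule), a module exposes curried init/apply. A
# minimal sketch of that shared flow as the linen and elegy tests use it; the
# haiku wrapper additionally threads rng/initializing kwargs through. The
# helper name sketch_run_module is hypothetical, not elegy API.
def sketch_run_module(gm, new_params):
    rng = elegy.RNGSeq(42)
    y, params, states = gm.init(rng)(x=3.0, y=1)  # build params and states
    y, params, states = gm.apply(new_params, states, training=True, rng=rng)(
        x=3.0, y=1
    )
    return y, params, states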
def test_losses(self):
    def loss_fn():
        return 3.0

    losses = elegy.model.model.Losses(
        dict(a=dict(b=[loss_fn, loss_fn], c=loss_fn))
    )
    rng = elegy.RNGSeq(42)

    with elegy.hooks.context(losses=True):
        elegy.hooks.add_loss("d", 1.0)
        aux_losses = elegy.hooks.get_losses()
        loss, logs, states = losses.init(aux_losses, rng)()

    with elegy.hooks.context(losses=True):
        elegy.hooks.add_loss("d", 1.0)
        aux_losses = elegy.hooks.get_losses()
        loss, logs, states = losses.apply(aux_losses, states)()

    # three loss_fn entries at 3.0 each plus the hook loss "d" at 1.0
    assert loss == 10

    assert len(losses.losses) == 3
    assert "a/b/loss_fn" in losses.losses
    assert "a/b/loss_fn_1" in losses.losses
    assert "a/c/loss_fn" in losses.losses

    assert len(logs) == 5
    assert "loss" in logs
    assert "a/b/loss_fn" in logs
    assert "a/b/loss_fn_1" in logs
    assert "a/c/loss_fn" in logs
    assert "d_loss" in logs
def test_optimizer_epoch(self):
    optax_op = optax.adam(1e-3)
    lr_schedule = lambda step, epoch: epoch
    optimizer = elegy.Optimizer(optax_op, lr_schedule=lr_schedule, steps_per_epoch=2)

    params = np.random.uniform(size=(3, 4))
    grads = np.random.uniform(size=(3, 4))
    rng = elegy.RNGSeq(42)

    optimizer_states = optimizer.init(
        rng=rng,
        net_params=params,
    )

    # with steps_per_epoch=2 the epoch is step // 2, and the schedule
    # returns the epoch, so the lr goes 0, 0, 1, 1, ...
    assert jnp.allclose(optimizer.current_lr(optimizer_states), 0)
    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)

    assert jnp.allclose(optimizer.current_lr(optimizer_states), 0)
    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)

    assert jnp.allclose(optimizer.current_lr(optimizer_states), 1)
    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)

    assert jnp.allclose(optimizer.current_lr(optimizer_states), 1)
    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)
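# A minimal sketch of the step-to-epoch bookkeeping the test above relies on,
# assuming elegy derives the epoch as step // steps_per_epoch (which is what
# the assertions imply). Pure Python, no elegy API involved.
def sketch_epoch_schedule():
    steps_per_epoch = 2
    lr_schedule = lambda step, epoch: epoch
    lrs = [lr_schedule(step, step // steps_per_epoch) for step in range(4)]
    assert lrs == [0, 0, 1, 1]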
def test_connects(self):
    x = np.random.uniform(-1, 1, size=(4, 3))

    linear = elegy.nn.Linear(5)
    y_pred = linear.call_with_defaults(rng=elegy.RNGSeq(42))(x)

    assert y_pred.shape == (4, 5)
def test_module_system_docs(self):
    class Linear(elegy.Module):
        def __init__(self, n_out):
            super().__init__()
            self.n_out = n_out

        def call(self, x):
            w = self.add_parameter(
                "w",
                lambda: elegy.initializers.RandomUniform()(
                    shape=[x.shape[-1], self.n_out],
                    dtype=jnp.float32,
                ),
            )
            b = self.add_parameter("b", lambda: jnp.zeros(shape=[self.n_out]))
            return jnp.dot(x, w) + b

    class MLP(elegy.Module):
        def call(self, x):
            x = Linear(64)(x)
            x = jax.nn.relu(x)
            x = Linear(32)(x)
            x = jax.nn.relu(x)
            x = Linear(1)(x)
            return x

    def loss_fn(parameters, x, y):
        y_pred, _ = mlp.apply(dict(parameters=parameters))(x)
        return jnp.mean(jnp.square(y - y_pred))

    def update(parameters, x, y):
        loss, gradients = jax.value_and_grad(loss_fn)(parameters, x, y)
        parameters = jax.tree_multimap(
            lambda p, g: p - 0.01 * g, parameters, gradients
        )
        return loss, parameters

    x = np.random.uniform(size=(15, 3))
    y = np.random.uniform(size=(15, 1))

    mlp = MLP()
    y_pred, collections = mlp.init(rng=elegy.RNGSeq(42))(x)
    parameters = collections["parameters"]

    update_jit = jax.jit(update)

    for step in range(1):
        loss, parameters = update_jit(parameters, x, y)

    mlp.set_default_parameters(dict(parameters=parameters))
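# A hedged follow-up to the loop above: once set_default_parameters has stored
# the trained weights, inference should go through call_with_defaults, the
# entry point the other tests in this file use. This usage is an assumption,
# not part of the original docs test.
def sketch_mlp_inference(mlp, x):
    y_pred = mlp.call_with_defaults(rng=elegy.RNGSeq(42))(x)
    assert y_pred.shape == (x.shape[0], 1)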
def test_basic():
    w = 2.0
    grads = 1.5
    lr = 1.0
    rng = elegy.RNGSeq(42)

    go = generalize_optimizer(optax.sgd(lr))

    states = go.init(rng, w)
    w, states = go.apply(w, grads, states, rng)

    # plain SGD: w - lr * grads = 2.0 - 1.0 * 1.5
    assert w == 0.5
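# The same update written directly against optax, as a sanity check of the
# arithmetic the test above encodes; this bypasses generalize_optimizer
# entirely and uses only public optax API.
def sketch_sgd_directly():
    tx = optax.sgd(1.0)
    params = jnp.array(2.0)
    state = tx.init(params)
    updates, state = tx.update(jnp.array(1.5), state)
    params = optax.apply_updates(params, updates)
    assert params == 0.5  # 2.0 - 1.0 * 1.5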
def test_connects(self):
    y = elegy.nn.Sequential(
        lambda: [
            elegy.nn.Flatten(),
            elegy.nn.Linear(5),
            jax.nn.relu,
            elegy.nn.Linear(2),
        ]
    ).call_with_defaults(rng=elegy.RNGSeq(42))(jnp.ones([10, 3]))

    assert y.shape == (10, 2)

    y = elegy.nn.Sequential(
        lambda: [
            elegy.nn.Flatten(),
            elegy.nn.Linear(5),
            jax.nn.relu,
            elegy.nn.Linear(2),
        ]
    ).call_with_defaults(rng=elegy.RNGSeq(42), training=False)(jnp.ones([10, 3]))

    assert y.shape == (10, 2)
def init_step(self, x):
    rng = elegy.RNGSeq(0)

    # initialize the generator on real inputs, then the discriminator
    # on the generator's output
    gx, g_params, g_states = self.generator.init(rng=rng)(x)
    dx, d_params, d_states = self.discriminator.init(rng=rng)(gx)

    g_optimizer_states = self.g_optimizer.init(g_params)
    d_optimizer_states = self.d_optimizer.init(d_params)

    return elegy.States(
        g_states=g_states,
        d_states=d_states,
        g_params=g_params,
        d_params=d_params,
        g_opt_states=g_optimizer_states,
        d_opt_states=d_optimizer_states,
        rng=rng,
        step=0,
    )
def test_di(self):
    m = elegy.nn.Sequential(
        lambda: [
            elegy.nn.Flatten(),
            elegy.nn.Linear(2),
        ]
    )

    # inject_dependencies drops the extra keyword arguments (a, b) that the
    # wrapped callable's signature does not accept
    y = elegy.inject_dependencies(
        m.call_with_defaults(rng=elegy.RNGSeq(42), training=False),
        signature_f=m,
    )(
        jnp.ones([5, 3]),
        a=1,
        b=2,
    )

    assert y.shape == (5, 2)
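# A hedged pure-Python analogue of the behaviour this test relies on: call a
# function with only the keyword arguments its signature accepts and drop the
# rest. The helper name sketch_call_matching is hypothetical, not elegy API.
import inspect

def sketch_call_matching(f, *args, **kwargs):
    accepted = set(inspect.signature(f).parameters)
    return f(*args, **{k: v for k, v in kwargs.items() if k in accepted})

# extra kwargs a and b are silently discarded, mirroring test_di above
assert sketch_call_matching(lambda x: x * 2, x=3, a=1, b=2) == 6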
def test_optimizer_chain(self):
    optimizer = elegy.Optimizer(
        optax.sgd(0.1),
        optax.clip(0.5),
    )

    params = np.zeros(shape=(3, 4))
    grads = np.ones(shape=(3, 4)) * 100_000
    rng = elegy.RNGSeq(42)

    optimizer_states = optimizer.init(
        rng=rng,
        net_params=params,
    )

    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)

    # the huge gradient is scaled by sgd and then clipped, so every parameter
    # moves at most 0.5 away from zero
    assert np.all(-0.5 <= params) and np.all(params <= 0.5)
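# The chained update above, written directly with optax.chain, on the
# assumption that elegy.Optimizer composes its positional transformations the
# same way:
def sketch_chained_update():
    tx = optax.chain(optax.sgd(0.1), optax.clip(0.5))
    params = jnp.zeros((3, 4))
    grads = jnp.ones((3, 4)) * 100_000

    state = tx.init(params)
    updates, state = tx.update(grads, state)
    params = optax.apply_updates(params, updates)

    # sgd scales (and negates) the huge gradient, then clip bounds each entry
    assert jnp.all(params == -0.5)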
def test_metrics(self):
    class M(elegy.Module):
        def call(self, x):
            n = self.add_parameter("n", lambda: 0, trainable=False)
            self.update_parameter("n", n + 1)
            return x

    metrics = elegy.model.model.Metrics(dict(a=dict(b=[M(), M()], c=M())))
    rng = elegy.RNGSeq(42)
    x = np.random.uniform(size=(5, 7, 7))

    with elegy.hooks.context(metrics=True):
        elegy.hooks.add_metric("d", 10)
        aux_metrics = elegy.hooks.get_metrics()
        logs, states = metrics.init(aux_metrics, rng)(x, training=True)

    with elegy.hooks.context(metrics=True):
        elegy.hooks.add_metric("d", 10)
        aux_metrics = elegy.hooks.get_metrics()
        logs, states = metrics.apply(aux_metrics, rng, states)(x, training=True)

    assert len(metrics.metrics) == 3
    assert "a/b/m" in metrics.metrics
    assert "a/b/m_1" in metrics.metrics
    assert "a/c/m" in metrics.metrics

    assert len(logs) == 4
    assert "a/b/m" in logs
    assert "a/b/m_1" in logs
    assert "a/c/m" in logs
    assert "d" in logs

    assert len(states) == 3
    assert "a/b/m" in states
    assert "a/b/m_1" in states
    assert "a/c/m" in states
def test_optimizer(self):
    optax_op = optax.adam(1e-3)
    lr_schedule = lambda step, epoch: step / 3
    optimizer = elegy.Optimizer(optax_op, lr_schedule=lr_schedule)

    params = np.random.uniform(size=(3, 4))
    grads = np.random.uniform(size=(3, 4))
    rng = elegy.RNGSeq(42)

    optimizer_states = optimizer.init(rng, params)
    # the schedule is evaluated at the current step count, which starts at 0
    assert jnp.allclose(optimizer.current_lr(optimizer_states), 0 / 3)

    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)
    assert jnp.allclose(optimizer.current_lr(optimizer_states), 1 / 3)

    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)
    assert jnp.allclose(optimizer.current_lr(optimizer_states), 2 / 3)

    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)
    assert jnp.allclose(optimizer.current_lr(optimizer_states), 3 / 3)
def test_dropout_connects(self):
    elegy.nn.Dropout(0.25).call_with_defaults(rng=elegy.RNGSeq(42))(
        jnp.ones([3, 3]), training=True
    )
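# A hedged companion check, assuming elegy's Dropout follows the standard
# convention of acting as the identity when training=False (the original test
# only checks that the training=True call connects):
def sketch_dropout_inference():
    x = jnp.ones([3, 3])
    y = elegy.nn.Dropout(0.25).call_with_defaults(rng=elegy.RNGSeq(42))(
        x, training=False
    )
    assert jnp.all(y == x)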