def test_jit_auto_init(self):
    """Verify that elegy.jit re-traces only when the hook/training context
    changes, while parameter updates (``n``) persist across every call."""
    with elegy.training_context(True):
        total_called = 0

        class SomeModule(elegy.Module):
            n: jnp.ndarray

            def call(self, x):
                # Count real (traced) executions so the asserts below can
                # distinguish cache hits from re-traces.
                nonlocal total_called
                total_called += 1
                n = self.add_parameter("n", initializer=jnp.array(0))
                self.update_parameter("n", n + 1)
                if elegy.is_training():
                    return x + 1
                else:
                    return x - 1

        m = SomeModule()
        assert total_called == 0

        # First jit call triggers auto-initialization (one trace).
        m.jit(0)
        assert total_called == 1

        m_jit = elegy.jit(m)
        assert m.n == 1

        # New jit wrapper: one more trace, state keeps advancing.
        y = m_jit(0)
        assert y == 1
        assert m.n == 2
        assert total_called == 2

        # Cached: no re-trace, but the parameter still updates.
        y = m_jit(0)
        assert m.n == 3
        assert total_called == 2

        # Flipping the training flag forces a re-trace and the eval branch.
        elegy.set_training(False)
        y = m_jit(0)
        assert y == -1
        assert total_called == 3
        assert m.n == 4

        # Same training context as an earlier trace: cache hit.
        with elegy.training_context(training=True):
            y = m_jit(0)
            assert y == 1
            assert total_called == 3
            assert m.n == 5

        # Adding a hooks context changes the trace signature once...
        with elegy.training_context(training=True), elegy.hooks_context():
            y = m_jit(0)
            assert y == 1
            assert total_called == 4
            assert m.n == 6

        # ...and is cached on the second identical combination.
        with elegy.training_context(training=True), elegy.hooks_context():
            y = m_jit(0)
            assert y == 1
            assert total_called == 4
            assert m.n == 7

        # Enabling summaries is another distinct trace signature.
        with elegy.training_context(training=True), elegy.hooks_context(
            summaries=True
        ):
            y = m_jit(0)
            assert y == 1
            assert total_called == 5
            assert m.n == 8

        # Summaries + eval mode: yet another signature, eval branch taken.
        with elegy.training_context(training=False), elegy.hooks_context(
            summaries=True
        ):
            y = m_jit(0)
            assert y == -1
            assert total_called == 6
            assert m.n == 9
def test_get_parameters(self):
    """Exercise get/set/reset of parameters on an explicitly initialized
    module, plus the losses/metrics/summaries collected under hooks."""
    x = np.random.uniform(-1, 1, size=(4, 5))
    m = ModuleTest.MyModule()
    m.init(x)

    parameters = m.get_parameters()
    state = m.get_parameters(trainable=False)

    # After init (no forward pass yet) the counters are still zero.
    assert "bias" in parameters
    assert "linear" in parameters
    assert "w" in parameters["linear"]
    assert "b" in parameters["linear"]
    assert parameters["linear"]["n"] == 0
    assert parameters["linear1"]["n"] == 0
    assert "linear1" in parameters

    # Run a forward pass with summary hooks enabled and capture the hooks'
    # output while the context is still active.
    with elegy.hooks_context(summaries=True):
        y: jnp.ndarray = m(x)
        losses = elegy.get_losses()
        metrics = elegy.get_metrics()
        summaries = elegy.get_summaries()

    assert losses
    assert metrics
    assert summaries

    parameters = m.get_parameters()
    assert y.shape == (4, 7)

    # One forward pass: the step counters advanced to 1.
    assert "bias" in parameters
    assert "linear" in parameters
    assert "w" in parameters["linear"]
    assert "b" in parameters["linear"]
    assert m.linear.get_parameters()["n"] == 1
    assert parameters["linear"]["n"] == 1
    assert "linear1" in parameters

    # Hook outputs are keyed by module path.
    assert "activation_sum_loss" in losses
    assert "my_module/linear/activation_mean" in metrics
    assert "my_module/linear_1/activation_mean" in metrics
    assert summaries[0][:2] == (m.linear, "my_module/linear")
    assert summaries[0][2].shape == (4, 6)
    assert summaries[1][:2] == (m.linear1, "my_module/linear_1")
    assert summaries[1][2].shape == (4, 7)
    assert summaries[2][:2] == (m, "my_module")
    assert summaries[2][2].shape == (4, 7)

    # Negate every leaf and confirm set_parameters propagated the change
    # down to the submodules.
    m.set_parameters(jax.tree_map(lambda x: -x, parameters))
    parameters = m.get_parameters()
    assert parameters["bias"][0] == -1
    assert m.linear.get_parameters()["w"][0, 0] == -1
    assert m.linear.get_parameters()["b"][0] == -1
    assert m.linear1.get_parameters()["w"][0, 0] == -1
    assert m.linear1.get_parameters()["b"][0] == -1

    assert m.parameters_size(include_submodules=False) == 7

    # reset() empties the tree; restoring the saved copy round-trips.
    current_parameters = m.get_parameters()
    m.reset()
    parameters = m.get_parameters()
    assert jax.tree_leaves(parameters) == []
    assert m.parameters_size() == 0

    m.set_parameters(current_parameters)
    assert m.get_parameters()["bias"][0] == -1
    assert m.linear.get_parameters()["w"][0, 0] == -1
    assert m.linear.get_parameters()["b"][0] == -1
    assert m.linear1.get_parameters()["w"][0, 0] == -1
    assert m.linear1.get_parameters()["b"][0] == -1
def test_auto_init(self):
    """Same parameter lifecycle as test_get_parameters, but relying on
    auto-initialization from the first call instead of an explicit init."""
    x = np.random.uniform(-1, 1, size=(4, 5))
    initial_key = elegy.get_rng().key
    m = ModuleDynamicTest.MyModule()

    # First call auto-initializes parameters and consumes RNG state.
    m(x)
    assert m.linear.get_parameters()["n"] == 1
    assert m.get_parameters()["linear"]["n"] == 1
    assert not jnp.allclose(initial_key, elegy.get_rng().key)
    assert "bias" in m.get_parameters()
    assert "linear" in m.get_parameters()
    assert "w" in m.get_parameters()["linear"]
    assert "b" in m.get_parameters()["linear"]
    assert "linear_1" in m.get_parameters()

    # Second (jitted) pass, with summary hooks captured inside the context.
    with elegy.hooks_context(summaries=True):
        y: jnp.ndarray = m.jit(x)
        losses = elegy.get_losses()
        metrics = elegy.get_metrics()
        summaries = elegy.get_summaries()

    assert losses
    assert metrics
    assert summaries
    assert y.shape == (4, 7)

    # Counters advanced to 2 after the second forward pass.
    assert "bias" in m.get_parameters()
    assert "linear" in m.get_parameters()
    assert "w" in m.get_parameters()["linear"]
    assert "b" in m.get_parameters()["linear"]
    assert m.linear.get_parameters()["n"] == 2
    assert m.get_parameters()["linear"]["n"] == 2
    assert "linear_1" in m.get_parameters()

    # Hook outputs are keyed by module path.
    assert "activation_sum_loss" in losses
    assert "my_module/linear/activation_mean" in metrics
    assert "my_module/linear_1/activation_mean" in metrics
    assert summaries[0][:2] == (m.linear, "my_module/linear")
    assert summaries[0][2].shape == (4, 6)
    assert summaries[1][:2] == (m.linear_1, "my_module/linear_1")
    assert summaries[1][2].shape == (4, 7)
    assert summaries[2][:2] == (m, "my_module")
    assert summaries[2][2].shape == (4, 7)

    # Negate every leaf and confirm the change reached the submodules.
    m.set_parameters(jax.tree_map(lambda x: -x, m.get_parameters()))
    assert m.get_parameters()["bias"][0] == -1
    assert m.linear.get_parameters()["w"][0, 0] == -1
    assert m.linear.get_parameters()["b"][0] == -1
    assert m.linear_1.get_parameters()["w"][0, 0] == -1
    assert m.linear_1.get_parameters()["b"][0] == -1

    assert m.parameters_size(include_submodules=False) == 7

    # reset() empties the tree; restoring the saved copy round-trips.
    current_parameters = m.get_parameters()
    m.reset()
    assert jax.tree_leaves(m.get_parameters()) == []
    assert m.parameters_size() == 0

    m.set_parameters(current_parameters)
    assert m.get_parameters()["bias"][0] == -1
    assert m.linear.get_parameters()["w"][0, 0] == -1
    assert m.linear.get_parameters()["b"][0] == -1
    assert m.linear_1.get_parameters()["w"][0, 0] == -1
    assert m.linear_1.get_parameters()["b"][0] == -1