Ejemplo n.º 1
0
    def test_jit_auto_init(self):
        """Track when a jitted module re-runs its Python ``call`` body.

        Two counters are observed throughout:

        * ``total_called`` is bumped each time the Python body of
          ``SomeModule.call`` executes.
        * the ``n`` parameter is created with an initial value of 0 and updated
          to ``n + 1`` on every call.

        The assertions expect ``n`` to advance on every invocation, while
        ``total_called`` only advances on some of them (first use, and after the
        training flag or the hooks context changes).
        NOTE(review): presumably the jitted function is cached per
        training/hooks configuration and only re-traced when that changes --
        confirm against the ``elegy.jit`` implementation.
        """
        with elegy.training_context(True):
            # Number of times the Python body of `call` has been executed.
            total_called = 0

            class SomeModule(elegy.Module):
                # Call counter stored as a module parameter (see `call`).
                n: jnp.ndarray

                def call(self, x):
                    nonlocal total_called

                    total_called += 1

                    # Create `n` with initial value 0, then store n + 1.
                    n = self.add_parameter("n", initializer=jnp.array(0))
                    self.update_parameter("n", n + 1)

                    # Output depends on the training flag: x + 1 when training,
                    # x - 1 otherwise.
                    if elegy.is_training():
                        return x + 1
                    else:
                        return x - 1

            m = SomeModule()

            # Constructing the module must not execute `call`.
            assert total_called == 0

            # First call through `m.jit` executes the body exactly once.
            m.jit(0)
            assert total_called == 1

            m_jit = elegy.jit(m)

            # After wrapping, `n` still reflects only the single call above.
            assert m.n == 1

            # First call through the new wrapper: training branch (0 + 1),
            # `n` advances and the body is executed once more.
            y = m_jit(0)

            assert y == 1
            assert m.n == 2
            assert total_called == 2

            # Same configuration again: `n` advances, but the Python body is
            # expected NOT to run again.
            y = m_jit(0)
            assert m.n == 3
            assert total_called == 2

            # Switching the training flag off: non-training branch (0 - 1) and
            # the body is expected to run again.
            elegy.set_training(False)
            y = m_jit(0)
            assert y == -1
            assert total_called == 3
            assert m.n == 4

            # Back to training=True via a context: no additional body execution
            # is expected.
            with elegy.training_context(training=True):
                y = m_jit(0)
                assert y == 1
                assert total_called == 3
                assert m.n == 5

            # Adding a hooks context: the body is expected to run once more.
            with elegy.training_context(training=True), elegy.hooks_context():
                y = m_jit(0)
                assert y == 1
                assert total_called == 4
                assert m.n == 6

            # Identical contexts a second time: no additional body execution.
            with elegy.training_context(training=True), elegy.hooks_context():
                y = m_jit(0)
                assert y == 1
                assert total_called == 4
                assert m.n == 7

            # hooks_context(summaries=True): the body is expected to run again.
            with elegy.training_context(training=True), elegy.hooks_context(
                    summaries=True):
                y = m_jit(0)
                assert y == 1
                assert total_called == 5
                assert m.n == 8

            # summaries=True combined with training=False: another body
            # execution, and the non-training branch result.
            with elegy.training_context(training=False), elegy.hooks_context(
                    summaries=True):
                y = m_jit(0)
                assert y == -1
                assert total_called == 6
                assert m.n == 9
Ejemplo n.º 2
0
    def test_get_parameters(self):
        """Check ``get_parameters`` / ``set_parameters`` / ``reset`` on ``ModuleTest.MyModule``.

        Phases, in order:

        1. After ``init`` the parameter tree holds ``bias``, ``linear`` (with
           ``w`` and ``b``) and ``linear1``, and both ``n`` counters are 0.
        2. One forward pass under ``hooks_context(summaries=True)`` produces
           non-empty losses, metrics and summaries, an output of shape
           ``(4, 7)``, and moves ``linear``'s ``n`` counter to 1.
        3. Writing the negated tree back with ``set_parameters`` is visible
           through both the root module and the sub-modules.
        4. ``reset`` leaves an empty tree; ``set_parameters`` restores the
           previously captured values.
        """

        def assert_layout(tree):
            # Membership is asserted before any subscript so that a missing
            # entry fails as an assertion instead of a KeyError.
            assert "bias" in tree
            assert "linear" in tree
            assert "w" in tree["linear"]
            assert "b" in tree["linear"]
            assert "linear1" in tree

        def assert_negated():
            # First element of every weight/bias is expected to be -1 once the
            # negated tree has been written back.
            assert m.get_parameters()["bias"][0] == -1
            assert m.linear.get_parameters()["w"][0, 0] == -1
            assert m.linear.get_parameters()["b"][0] == -1
            assert m.linear1.get_parameters()["w"][0, 0] == -1
            assert m.linear1.get_parameters()["b"][0] == -1

        x = np.random.uniform(-1, 1, size=(4, 5))

        m = ModuleTest.MyModule()

        m.init(x)

        parameters = m.get_parameters()
        # The non-trainable view is only requested, not inspected: the call
        # just has to succeed.
        m.get_parameters(trainable=False)

        assert_layout(parameters)
        assert parameters["linear"]["n"] == 0
        assert parameters["linear1"]["n"] == 0

        with elegy.hooks_context(summaries=True):
            y: jnp.ndarray = m(x)

            # Hook results are read while the hooks context is still active.
            losses = elegy.get_losses()
            metrics = elegy.get_metrics()
            summaries = elegy.get_summaries()

        assert losses
        assert metrics
        assert summaries

        parameters = m.get_parameters()

        assert y.shape == (4, 7)
        assert_layout(parameters)
        assert m.linear.get_parameters()["n"] == 1
        assert parameters["linear"]["n"] == 1

        assert "activation_sum_loss" in losses
        assert "my_module/linear/activation_mean" in metrics
        assert "my_module/linear_1/activation_mean" in metrics

        # Each summary entry starts with (module, path) followed by a value;
        # the entries are expected in the order linear, linear1, root module.
        assert summaries[0][:2] == (m.linear, "my_module/linear")
        assert summaries[0][2].shape == (4, 6)
        assert summaries[1][:2] == (m.linear1, "my_module/linear_1")
        assert summaries[1][2].shape == (4, 7)
        assert summaries[2][:2] == (m, "my_module")
        assert summaries[2][2].shape == (4, 7)

        m.set_parameters(jax.tree_map(lambda x: -x, parameters))

        assert_negated()

        assert m.parameters_size(include_submodules=False) == 7

        # Snapshot taken before the reset so it can be restored afterwards.
        current_parameters = m.get_parameters()

        m.reset()

        parameters = m.get_parameters()

        assert jax.tree_leaves(parameters) == []
        assert m.parameters_size() == 0

        m.set_parameters(current_parameters)

        assert_negated()
Ejemplo n.º 3
0
    def test_auto_init(self):
        """Check that ``ModuleDynamicTest.MyModule`` works without an explicit ``init``.

        The module is called directly on fresh input. The test then verifies
        the parameter tree, the RNG key having changed, the hook outputs of a
        ``jit`` call, and a negate / reset / restore round trip of the
        parameters.
        """
        inputs = np.random.uniform(-1, 1, size=(4, 5))
        key_before = elegy.get_rng().key
        module = ModuleDynamicTest.MyModule()

        def check_core_layout():
            # `bias` and `linear` at the top level, `w` and `b` inside `linear`.
            for top_level in ("bias", "linear"):
                assert top_level in module.get_parameters()
            for leaf_name in ("w", "b"):
                assert leaf_name in module.get_parameters()["linear"]

        def check_all_negated():
            # First element of every weight/bias is expected to be -1.
            assert module.get_parameters()["bias"][0] == -1
            for layer_name in ("linear", "linear_1"):
                layer = getattr(module, layer_name)
                assert layer.get_parameters()["w"][0, 0] == -1
                assert layer.get_parameters()["b"][0] == -1

        # --- first call, no explicit init ---------------------------------
        module(inputs)

        # The call counter of `linear` is 1 from both access paths.
        assert module.linear.get_parameters()["n"] == 1
        assert module.get_parameters()["linear"]["n"] == 1

        # The global RNG key is expected to differ from the one read earlier.
        assert not jnp.allclose(key_before, elegy.get_rng().key)
        check_core_layout()
        assert "linear_1" in module.get_parameters()

        # --- second call, through jit, with summaries enabled ---------------
        with elegy.hooks_context(summaries=True):
            y = module.jit(inputs)

            losses = elegy.get_losses()
            metrics = elegy.get_metrics()
            summaries = elegy.get_summaries()

        for collected in (losses, metrics, summaries):
            assert collected

        assert y.shape == (4, 7)
        check_core_layout()
        assert module.linear.get_parameters()["n"] == 2
        assert module.get_parameters()["linear"]["n"] == 2
        assert "linear_1" in module.get_parameters()

        assert "activation_sum_loss" in losses
        for metric_name in (
            "my_module/linear/activation_mean",
            "my_module/linear_1/activation_mean",
        ):
            assert metric_name in metrics

        # Expected summaries, in order: (owner module, path, value shape).
        expected_summaries = (
            (module.linear, "my_module/linear", (4, 6)),
            (module.linear_1, "my_module/linear_1", (4, 7)),
            (module, "my_module", (4, 7)),
        )
        for index, (owner, path, shape) in enumerate(expected_summaries):
            assert summaries[index][:2] == (owner, path)
            assert summaries[index][2].shape == shape

        # --- negate, reset, restore -----------------------------------------
        negated = jax.tree_map(lambda leaf: -leaf, module.get_parameters())
        module.set_parameters(negated)

        check_all_negated()

        assert module.parameters_size(include_submodules=False) == 7

        # Snapshot taken before the reset so it can be restored afterwards.
        snapshot = module.get_parameters()

        module.reset()

        assert jax.tree_leaves(module.get_parameters()) == []
        assert module.parameters_size() == 0

        module.set_parameters(snapshot)

        check_all_negated()