コード例 #1
0
ファイル: linen_module_test.py プロジェクト: Dave0995/elegy
            def __call__(self, x):
                c1 = self.param("c1", lambda _: jnp.ones([5]))
                c2 = self.variable("states", "c2", lambda: jnp.ones([6]))

                x = jax.nn.relu(x)
                elegy.flax_summary(self, "relu", jax.nn.relu, x)

                return x
コード例 #2
0
ファイル: linen_module_test.py プロジェクト: Dave0995/elegy
            def __call__(self, x):
                a1 = self.param("a1", lambda _: jnp.ones([1]))
                a2 = self.variable("states", "a2", lambda: jnp.ones([2]))

                x = ModuleB()(x)

                x = jax.nn.relu(x)
                elegy.flax_summary(self, "relu", jax.nn.relu, x)

                return x
コード例 #3
0
ファイル: linen_module_test.py プロジェクト: Dave0995/elegy
            def __call__(self, x):
                b1 = self.param("b1", lambda _: jnp.ones([3]))
                b2 = self.variable("states", "b2", lambda: jnp.ones([4]))

                x = ModuleC()(x)

                x = jax.nn.relu(x)
                elegy.flax_summary(self, "relu", jax.nn.relu, x)

                return x