コード例 #1
0
ファイル: model_test.py プロジェクト: chjort/elegy
    def call(self, x):
        w = elegy.get_parameter("w", [x.shape[-1], self.units], jnp.float32,
                                jnp.ones)
        b = elegy.get_parameter("b", [self.units], jnp.float32, jnp.ones)

        n = self.get_state("n", [], np.int32, jnp.zeros)

        self.set_state("n", n + 1)

        y = jnp.dot(x, w) + b

        elegy.add_loss("activation_sum", jnp.sum(y))
        elegy.add_metric("activation_mean", jnp.mean(y))

        return y
コード例 #2
0
ファイル: module_test.py プロジェクト: chjort/elegy
 def call(self, x) -> np.ndarray:
     x = self.linear(x)
     x = self.linear1(x)
     self.bias = elegy.get_parameter("bias", [x.shape[-1]], jnp.float32,
                                     jnp.ones)
     return x + self.bias * 10
コード例 #3
0
ファイル: module_test.py プロジェクト: chjort/elegy
 def call(self, x) -> np.ndarray:
     x = ModuleDynamicTest.Linear(6)(x)
     x = ModuleDynamicTest.Linear(7)(x)
     self.bias = elegy.get_parameter("bias", [x.shape[-1]],
                                     initializer=jnp.ones)
     return x + self.bias * 10