コード例 #1
0
ファイル: custom_vjp.py プロジェクト: wdevazelhes/flax
 def bwd(features, scope_fn, params, res, g):
   x = res
   fn = lambda params, x: nn.dense(scope_fn(params), x, features)
   _, pullback = jax.vjp(fn, params, x)
   g_param, g_x = pullback(g)
   g_param = jax.tree_map(jnp.sign, g_param)
   return g_param, g_x
コード例 #2
0
ファイル: lift_test.py プロジェクト: voicedm/flax
 def f(scope, x):
     nonlocal compiles
     compiles += 1
     if scope.is_mutable_collection(
             'intermediates'
     ) and not scope.is_mutable_collection('params'):
         scope.put_variable('intermediates', 'x', x + 1)
     return nn.dense(scope, x, 1)
コード例 #3
0
 def encode(self, scope, x):
     return nn.dense(scope, x, self.latents, bias=False)
コード例 #4
0
 def f(scope):
     nn.dense(scope.push('dense'), np.ones((1, 2)), 2)
コード例 #5
0
ファイル: custom_vjp.py プロジェクト: wdevazelhes/flax
 def fwd(scope, x, features):
   y = nn.dense(scope, x, features)
   return y, x