コード例 #1
0
 def template(x, init_key=None):
   """Initialize two `nn.Dense(20)` layers from one key and add their outputs.

   Both layers are created through `state.init` with the same name, 'dense'.
   NOTE(review): if the shared name is not deliberate (e.g. to exercise a
   duplicate-name check), give each layer a distinct name.

   Args:
     x: value used both to initialize each layer and as each layer's input.
     init_key: key that is split into one key per layer.

   Returns:
     The result of `first_layer(x) + second_layer(x)`.
   """
   first_key, second_key = random.split(init_key)
   first_layer = state.init(nn.Dense(20), name='dense')(first_key, x)
   second_layer = state.init(nn.Dense(20), name='dense')(second_key, x)
   # Apply in the same order as before: first layer, then second, then add.
   first_out = first_layer(x)
   second_out = second_layer(x)
   return first_out + second_out
コード例 #2
0
 def dense(x, init_key=None):
   """Apply `nn.Dense(50)` ('dense1') and then `nn.Dense(20)` ('dense2').

   Args:
     x: input to the first layer.
     init_key: key that is split so each layer receives its own key.

   Returns:
     The output of the second layer.
   """
   first_key, second_key = random.split(init_key)
   hidden = nn.Dense(50)(x, init_key=first_key, name='dense1')
   return nn.Dense(20)(hidden, init_key=second_key, name='dense2')
コード例 #3
0
 def dense(x, init_key=None):
   """Apply `nn.Dense(50) >> nn.Dense(20)` to `x` under the name 'dense'.

   Args:
     x: input to the composed template.
     init_key: key forwarded unchanged to the composed template.

   Returns:
     The output of the composed template.
   """
   stack = nn.Dense(50) >> nn.Dense(20)
   return stack(x, init_key=init_key, name='dense')
コード例 #4
0
 def dense(x, init_key=None):
   """Apply `nn.Dense(20)` to `x` under the name 'dense'.

   Args:
     x: input to the layer.
     init_key: key forwarded unchanged to the layer call.

   Returns:
     The layer's output.
   """
   layer = nn.Dense(20)
   return layer(x, init_key=init_key, name='dense')
コード例 #5
0
 def dense_no_rng(x):
   """Apply `nn.Dense(20)` to `x` under the name 'dense' without an init key.

   Args:
     x: input to the layer.

   Returns:
     The layer's output.
   """
   layer = nn.Dense(20)
   return layer(x, name='dense')
コード例 #6
0
"""Tests for tensorflow_probability.spinoffs.oryx.core.serialize."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import random
import jax.numpy as np
import numpy as onp
from oryx.core import state
from oryx.core.serialize import deserialize
from oryx.core.serialize import serialize
from oryx.experimental import nn


# Network templates that parameterize the serialization test. Each key is
# used by `parameterized.named_parameters` as the test-case name, and each
# value is the template under test.
templates = {
    'dense': nn.Dense(200),
    'dense_serial': nn.Dense(200) >> nn.Dense(200),
    'relu': nn.Relu(),
    'dense_relu': nn.Dense(200) >> nn.Relu(),
    'convnet': (
        nn.Reshape((28, 28, 1))
        >> nn.Conv(32, (5, 5))
        >> nn.MaxPooling((2, 2), (2, 2))
    ),
    'dropout': nn.Dropout(0.5),
}


class SerializeTest(parameterized.TestCase):

  @parameterized.named_parameters(templates.items())
  def test_serialize(self, template):
    network = state.init(template)(random.PRNGKey(0), state.Shape(784))