Example #1
0
 def template(x, init_key=None):
   """Initialize two Dense(20) layers on `x` and return the sum of their outputs.

   Both layers are registered with `state.init` under the same name 'dense'.
   NOTE(review): the shared name looks intentional (presumably exercising
   duplicate-name handling in `state.init`) -- confirm against the caller.
   """
   first_key, second_key = random.split(init_key)
   first = state.init(nn.Dense(20), name='dense')(first_key, x)
   second = state.init(nn.Dense(20), name='dense')(second_key, x)
   out_first = first(x)
   return out_first + second(x)
Example #2
0
 def dense(x, init_key=None):
   """Apply Dense(50) then Dense(20) to `x`, each layer with its own init key.

   `init_key` is split in two: the first half initializes 'dense1' and the
   second initializes 'dense2'.
   """
   key_a, key_b = random.split(init_key)
   hidden = nn.Dense(50)(x, init_key=key_a, name='dense1')
   output = nn.Dense(20)(hidden, init_key=key_b, name='dense2')
   return output
Example #3
0
 def dense(x, init_key=None):
   """Run `x` through a Dense(50) >> Dense(20) stack registered as 'dense'."""
   stack = nn.Dense(50) >> nn.Dense(20)
   return stack(x, init_key=init_key, name='dense')
Example #4
0
 def dense(x, init_key=None):
   """Apply a single Dense(20) layer, named 'dense', to `x` using `init_key`."""
   layer = nn.Dense(20)
   return layer(x, init_key=init_key, name='dense')
Example #5
0
 def dense_no_rng(x):
   """Apply a Dense(20) layer named 'dense' to `x` without passing an init key."""
   layer = nn.Dense(20)
   return layer(x, name='dense')
Example #6
0
"""Tests for tensorflow_probability.spinoffs.oryx.core.serialize."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import random
import jax.numpy as np
import numpy as onp
from oryx.core import state
from oryx.core.serialize import deserialize
from oryx.core.serialize import serialize
from oryx.experimental import nn


# Network templates exercised by the serialization test, keyed by test-case
# name. Each one is initialized in SerializeTest.test_serialize with a
# `state.Shape(784)` input; the convnet reshapes that to 28x28x1 first.
templates = dict(
    dense=nn.Dense(200),
    dense_serial=nn.Dense(200) >> nn.Dense(200),
    relu=nn.Relu(),
    dense_relu=nn.Dense(200) >> nn.Relu(),
    convnet=(
        nn.Reshape((28, 28, 1))
        >> nn.Conv(32, (5, 5))
        >> nn.MaxPooling((2, 2), (2, 2))
    ),
    dropout=nn.Dropout(0.5),
)


class SerializeTest(parameterized.TestCase):

  @parameterized.named_parameters(templates.items())
  def test_serialize(self, template):
    network = state.init(template)(random.PRNGKey(0), state.Shape(784))