def template(x, init_key=None):
  """Builds two `nn.Dense(20)` layers, both registered under the name 'dense'.

  Args:
    x: Input used both to initialize each layer and to apply it.
    init_key: PRNG key (from `jax.random`); split into one key per layer.

  Returns:
    The sum of the two initialized layers applied to `x`.
  """
  # NOTE(review): reusing the name 'dense' for both layers looks deliberate --
  # presumably this exercises duplicate-name handling in `state.init`; confirm
  # against the enclosing test before changing it.
  def build(layer_key):
    return state.init(nn.Dense(20), name='dense')(layer_key, x)

  first_key, second_key = random.split(init_key)
  first_layer = build(first_key)
  second_layer = build(second_key)
  return first_layer(x) + second_layer(x)
def dense(x, init_key=None):
  """Applies `nn.Dense(50)` then `nn.Dense(20)`, named 'dense1' and 'dense2'.

  Args:
    x: Input to the first layer.
    init_key: PRNG key (from `jax.random`); split so each layer gets its own
      initialization key.

  Returns:
    The output of the second layer.
  """
  first_key, second_key = random.split(init_key)
  hidden = nn.Dense(50)(x, init_key=first_key, name='dense1')
  return nn.Dense(20)(hidden, init_key=second_key, name='dense2')
def dense(x, init_key=None):
  """Applies `nn.Dense(50) >> nn.Dense(20)` as a single template named 'dense'.

  Args:
    x: Input to the composed template.
    init_key: PRNG key forwarded unchanged to the composed template.

  Returns:
    The output of the composed template.
  """
  stacked = nn.Dense(50) >> nn.Dense(20)
  return stacked(x, init_key=init_key, name='dense')
def dense(x, init_key=None):
  """Applies a single `nn.Dense(20)` template named 'dense'.

  Args:
    x: Input to the layer.
    init_key: PRNG key forwarded unchanged to the layer.

  Returns:
    The layer's output.
  """
  layer = nn.Dense(20)
  return layer(x, init_key=init_key, name='dense')
def dense_no_rng(x):
  """Applies a single `nn.Dense(20)` template named 'dense' without a key.

  Unlike the keyed variants, no `init_key` is passed to the layer call.

  Args:
    x: Input to the layer.

  Returns:
    The layer's output.
  """
  layer = nn.Dense(20)
  return layer(x, name='dense')
"""Tests for tensorflow_probability.spinoffs.oryx.core.serialize.""" from absl.testing import absltest from absl.testing import parameterized import jax from jax import random import jax.numpy as np import numpy as onp from oryx.core import state from oryx.core.serialize import deserialize from oryx.core.serialize import serialize from oryx.experimental import nn templates = { 'dense': nn.Dense(200), 'dense_serial': nn.Dense(200) >> nn.Dense(200), 'relu': nn.Relu(), 'dense_relu': nn.Dense(200) >> nn.Relu(), 'convnet': ( nn.Reshape((28, 28, 1)) >> nn.Conv(32, (5, 5)) >> nn.MaxPooling((2, 2), (2, 2))), 'dropout': nn.Dropout(0.5) } class SerializeTest(parameterized.TestCase): @parameterized.named_parameters(templates.items()) def test_serialize(self, template): network = state.init(template)(random.PRNGKey(0), state.Shape(784))