Пример #1
0
 def template(x, init_key=None):
   """Builds two independently keyed Dense(20) layers and sums their outputs.

   Both layers are initialized through `state.init` with the same name
   'dense'. NOTE(review): the shared name looks deliberate (e.g. to exercise
   duplicate-name handling) -- confirm against the caller before changing it.

   Args:
     x: input used both to initialize each layer and as each layer's input.
     init_key: PRNG key, split into one subkey per layer. Although it defaults
       to None, it is handed straight to `random.split`, so a real key is
       required.

   Returns:
     The sum (`+`) of the two initialized layers applied to `x`.
   """
   first_key, second_key = random.split(init_key)
   first_layer = state.init(nn.Dense(20), name='dense')(first_key, x)
   second_layer = state.init(nn.Dense(20), name='dense')(second_key, x)
   return first_layer(x) + second_layer(x)
Пример #2
0
 def dense(x, init_key=None):
   """Applies Dense(50) and then Dense(20), each with its own init subkey.

   Args:
     x: input to the first ('dense1') layer.
     init_key: PRNG key, split into one subkey per layer. The None default is
       passed straight to `random.split`, so a real key is required.

   Returns:
     The output of the 'dense2' layer applied to the output of 'dense1'.
   """
   first_key, second_key = random.split(init_key)
   hidden = nn.Dense(50)(x, init_key=first_key, name='dense1')
   output = nn.Dense(20)(hidden, init_key=second_key, name='dense2')
   return output
Пример #3
0
 def dense(x, init_key=None):
   """Applies the composed template `Dense(50) >> Dense(20)` under one name.

   NOTE(review): `>>` presumably chains the two layers serially; the whole
   composition is called once with name='dense'.

   Args:
     x: input to the composed network.
     init_key: PRNG key forwarded unchanged to the composed network call.

   Returns:
     The composed network's output for `x`.
   """
   network = nn.Dense(50) >> nn.Dense(20)
   return network(x, init_key=init_key, name='dense')
Пример #4
0
 def dense(x, init_key=None):
   """Applies a single Dense(20) layer, named 'dense', to `x`.

   Args:
     x: layer input.
     init_key: PRNG key forwarded unchanged to the layer call.

   Returns:
     The layer's output for `x`.
   """
   layer = nn.Dense(20)
   return layer(x, init_key=init_key, name='dense')
Пример #5
0
 def dense_no_rng(x):
   """Applies a single Dense(20) layer, named 'dense', without a PRNG key.

   No `init_key` is forwarded to the layer call. NOTE(review): presumably
   this exercises the key-less code path -- confirm against the caller.

   Args:
     x: layer input.

   Returns:
     The layer's output for `x`.
   """
   layer = nn.Dense(20)
   return layer(x, name='dense')
Пример #6
0
"""Tests for tensorflow_probability.spinoffs.oryx.core.serialize."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import random
import jax.numpy as np
import numpy as onp
from oryx.core import state
from oryx.core.serialize import deserialize
from oryx.core.serialize import serialize
from oryx.experimental import nn


# Layer templates for the parameterized serialization test: each key is used
# as a test-case name and each value is the template handed to the test.
templates = {
    'dense': nn.Dense(200),
    'dense_serial': nn.Dense(200) >> nn.Dense(200),
    'relu': nn.Relu(),
    'dense_relu': nn.Dense(200) >> nn.Relu(),
    # Reshape to a 28x28x1 image before convolving and pooling
    # (presumably the flat 784-element test input -- confirm).
    'convnet': (
        nn.Reshape((28, 28, 1))
        >> nn.Conv(32, (5, 5))
        >> nn.MaxPooling((2, 2), (2, 2))
    ),
    'dropout': nn.Dropout(0.5),
}


class SerializeTest(parameterized.TestCase):

  @parameterized.named_parameters(templates.items())
  def test_serialize(self, template):
    network = state.init(template)(random.PRNGKey(0), state.Shape(784))