Example #1
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutational block."""
    main = stax.Serial(stax.BatchNorm(), stax.Relu(),
                       stax.Conv(channels, (3, 3), strides, padding='SAME'),
                       stax.BatchNorm(), stax.Relu(),
                       stax.Conv(channels, (3, 3), padding='SAME'))
    shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.Serial(stax.FanOut(), stax.Parallel(main, shortcut),
                       stax.FanInSum())
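A minimal usage sketch (not part of the original snippet; the argument values and NHWC input shape are assumptions), following the init convention that Example #4 below uses for Dense models:

# Hypothetical usage; assumes the same `stax` and `random` modules as the
# examples in this section.
block_init, block_apply = WideResnetBlock(64, strides=(2, 2), channel_mismatch=True)
rng = random.get_prng(0)
out_shape, params = block_init(rng, (-1, 32, 32, 16))
# Both branches stride by 2 and emit 64 channels, so their outputs agree and
# FanInSum can add them: out_shape == (-1, 16, 16, 64).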
Example #2
def IdentityBlock(kernel_size, filters):
    """ResNet identical size block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = stax.Serial(stax.Conv(filters1, (1, 1)), stax.BatchNorm(),
                       stax.Relu(),
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu(),
                       stax.Conv(filters3, (1, 1)), stax.BatchNorm())
    return stax.Serial(stax.FanOut(), stax.Parallel(main, stax.Identity()),
                       stax.FanInSum(), stax.Relu())
Example #3
def ConvBlock(kernel_size, filters, strides):
    """ResNet convolutional striding block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = stax.Serial(stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(),
                       stax.Relu(),
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu(),
                       stax.Conv(filters3, (1, 1)), stax.BatchNorm())
    shortcut = stax.Serial(stax.Conv(filters3, (1, 1), strides),
                           stax.BatchNorm())
    return stax.Serial(stax.FanOut(), stax.Parallel(main, shortcut),
                       stax.FanInSum(), stax.Relu())
Example #4
def test_dense_param_sharing(self):
    model1 = stax.Serial(stax.Dense(32), stax.Dense(32))
    layer = stax.Dense(32)
    model2 = stax.Serial(layer, layer)
    init_fun1, _ = model1
    init_fun2, _ = model2
    rng = random.get_prng(0)
    _, params1 = init_fun1(rng, [-1, 32])
    _, params2 = init_fun2(rng, [-1, 32])
    # The first model has two separate kernels of shape (32, 32).
    self.assertEqual((32, 32), params1[0][0].shape)
    self.assertEqual((32, 32), params1[1][0].shape)
    # The second model has one kernel of shape (32, 32) and an empty tuple,
    # because the two Serial positions share the same layer's parameters.
    self.assertEqual((32, 32), params2[0][0].shape)
    self.assertEqual((), params2[1])
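A hypothetical continuation of the test (the apply call, the input, and the numpy import are assumptions, using the standard stax apply_fun(params, inputs) convention): since the second entry of params2 is empty, applying model2 reuses the single shared kernel.

_, apply_fun2 = model2
x = onp.ones((8, 32))         # assumes `import numpy as onp`
y = apply_fun2(params2, x)    # the same Dense(32) parameters are applied twice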
Example #5
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10,
        mode="train"):
    """Multi-layer feed-forward neural network with non-linear activations."""
    del mode
    layers = [stax.Flatten()]
    for _ in range(num_hidden_layers):
        layers += [stax.Dense(hidden_size), activation_fn()]
    layers += [stax.Dense(num_output_classes), stax.LogSoftmax()]
    return stax.Serial(*layers)
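A hypothetical usage sketch for the MLP (the MNIST-like input shape is an assumption), following the init convention from Example #4:

init_fun, apply_fun = MLP(num_hidden_layers=2, hidden_size=512)
rng = random.get_prng(0)
out_shape, params = init_fun(rng, (-1, 28, 28, 1))
# stax.Flatten() collapses each image to a 784-vector; the LogSoftmax head
# emits log-probabilities over 10 classes, so out_shape == (-1, 10).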
Example #6
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         num_actions,
                         bottom_layers=None):
  """A policy and value net function."""

  # Layers.
  layers = []
  if bottom_layers is not None:
    layers.extend(bottom_layers)

  # Now, with the current logits, one head computes action probabilities and the
  # other computes the value function.
  layers.extend([stax.FanOut(), stax.Parallel(
      stax.Serial(stax.Dense(num_actions), stax.Softmax()),
      stax.Dense(1)
  )])

  net_init, net_apply = stax.Serial(*layers)

  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
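A hypothetical call (the observation shape and the Dense/Relu bottom layers are assumptions): the returned apply function emits one output per branch of the Parallel combinator.

rng = random.get_prng(0)
params, apply_fun = policy_and_value_net(
    rng,
    batch_observations_shape=(-1, 4),   # e.g. a 4-dimensional observation
    num_actions=2,
    bottom_layers=[stax.Dense(64), stax.Relu()])
# apply_fun(params, observations) returns a pair: action probabilities of
# shape (batch, 2) and state values of shape (batch, 1).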
Example #7
def policy_net(rng_key,
               batch_observations_shape,
               num_actions,
               bottom_layers=None):
  """A policy net function."""
  # Use the bottom_layers as the bottom part of the network and just add the
  # required layers on top of it.
  # Copy the list so the caller's bottom_layers is not mutated.
  layers = list(bottom_layers) if bottom_layers is not None else []
  layers.extend([stax.Dense(num_actions), stax.Softmax()])

  net_init, net_apply = stax.Serial(*layers)

  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
Example #8
def value_net(rng_key,
              batch_observations_shape,
              num_actions,
              bottom_layers=None):
  """A value net function."""
  del num_actions

  # Copy the list so the caller's bottom_layers is not mutated.
  layers = list(bottom_layers) if bottom_layers is not None else []
  layers.append(stax.Dense(1))

  net_init, net_apply = stax.Serial(*layers)

  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
Example #9
def WideResnet(num_blocks=3,
               hidden_size=64,
               num_output_classes=10,
               mode='train'):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.
    mode: is it training or eval.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
    del mode
    return stax.Serial(stax.Conv(hidden_size, (3, 3), padding='SAME'),
                       WideResnetGroup(num_blocks, hidden_size),
                       WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
                       WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
                       stax.BatchNorm(), stax.Relu(),
                       stax.AvgPool(pool_size=(8, 8)), stax.Flatten(),
                       stax.Dense(num_output_classes), stax.LogSoftmax())
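A hypothetical initialization (the CIFAR-10-shaped input is an assumption, chosen to match the two strided groups and the final (8, 8) average pool); note this model is assembled from WideResnetGroup (Example #11) and WideResnetBlock (Example #1):

init_fun, apply_fun = WideResnet(num_blocks=3, hidden_size=64)
rng = random.get_prng(0)
out_shape, params = init_fun(rng, (-1, 32, 32, 3))
# The strided groups reduce 32x32 to 8x8, the pool to 1x1, and the Dense
# head gives out_shape == (-1, 10).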
Example #10
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
    """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.
    mode: whether we are training or evaluating or doing inference.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
    del mode
    return stax.Serial(
        stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'), stax.BatchNorm(),
        stax.Relu(), stax.MaxPool(pool_size=(3, 3), strides=(2, 2)),
        ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
        IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
        ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
        stax.AvgPool(pool_size=(7, 7)), stax.Flatten(),
        stax.Dense(num_output_classes), stax.LogSoftmax())
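A hypothetical initialization (the ImageNet-like input shape is an assumption, matching the final (7, 7) average pool):

init_fun, apply_fun = Resnet50(hidden_size=64, num_output_classes=1001)
rng = random.get_prng(0)
out_shape, params = init_fun(rng, (-1, 224, 224, 3))
# The strided stem and the three strided ConvBlocks reduce the spatial dims
# to 7x7 before pooling, so out_shape == (-1, 1001).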
Example #11
def WideResnetGroup(n, channels, strides=(1, 1)):
    """A group of `n` WideResnet blocks with the given channels and strides."""
    blocks = []
    # Only the first block can change the stride and channel count, so it is
    # the only one that needs a convolutional shortcut.
    blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
    for _ in range(n - 1):
        blocks += [WideResnetBlock(channels, (1, 1))]
    return stax.Serial(*blocks)