Example #1
def Resnet50(hidden_size=64, num_output_classes=1001):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  return stax.serial(
      stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      stax.BatchNorm(),
      stax.Relu,
      stax.MaxPool((3, 3), strides=(2, 2)),
      ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
      IdentityBlock(3, [hidden_size, hidden_size]),
      IdentityBlock(3, [hidden_size, hidden_size]),
      ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
      stax.AvgPool((7, 7)),
      stax.Flatten(),
      stax.Dense(num_output_classes),
      stax.LogSoftmax)
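
The ConvBlock and IdentityBlock helpers are not shown in this example. The sketch below, modeled on the ResNet-50 example that ships with jax.example_libraries.stax (formerly jax.experimental.stax), illustrates how such residual blocks are typically assembled from stax combinators; the helpers used by the snippet above may differ in detail.

# Sketch only, not necessarily the helpers used above.
from jax.example_libraries import stax
from jax.example_libraries.stax import (BatchNorm, Conv, FanInSum, FanOut,
                                        Identity, Relu)


def ConvBlock(kernel_size, filters, strides=(2, 2)):
  """Residual block whose shortcut is a strided 1x1 convolution."""
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = stax.serial(
      Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
      Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
      Conv(filters3, (1, 1)), BatchNorm())
  shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
  return stax.serial(FanOut(2), stax.parallel(main, shortcut), FanInSum, Relu)


def IdentityBlock(kernel_size, filters):
  """Residual block whose shortcut is the identity."""
  ks = kernel_size
  filters1, filters2 = filters

  def make_main(input_shape):
    # The last 1x1 convolution must restore the input channel count, which is
    # only known once the input shape is seen, hence shape_dependent.
    return stax.serial(
        Conv(filters1, (1, 1)), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(input_shape[3], (1, 1)), BatchNorm())

  return stax.serial(FanOut(2),
                     stax.parallel(stax.shape_dependent(make_main), Identity),
                     FanInSum, Relu)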
Example #2
def main(argv):
  del argv

  if FLAGS.jax_debug_nans:
    config.update("jax_debug_nans", True)

  bottom_layers = common_stax_layers()

  if FLAGS.env_name == "Pong-v0":
    bottom_layers = [stax.Div(255.0), stax.Flatten(2)] + bottom_layers

  optimizer_fun = functools.partial(ppo.optimizer_fun,
                                    step_size=FLAGS.learning_rate)

  ppo.training_loop(
      env_name=FLAGS.env_name,
      epochs=FLAGS.epochs,
      policy_net_fun=functools.partial(
          ppo.policy_net, bottom_layers=bottom_layers),
      value_net_fun=functools.partial(
          ppo.value_net, bottom_layers=bottom_layers),
      policy_optimizer_fun=optimizer_fun,
      value_optimizer_fun=optimizer_fun,
      batch_size=FLAGS.batch_size,
      num_optimizer_steps=FLAGS.num_optimizer_steps,
      boundary=FLAGS.boundary,
      random_seed=FLAGS.random_seed)
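
main() reads a number of absl flags that are defined elsewhere in the script. For context, a minimal sketch of how those flags could be declared; the flag names come from the code above, while the defaults and help strings here are only illustrative.

# Illustrative flag declarations; defaults and help strings are placeholders.
from absl import app, flags

flags.DEFINE_string("env_name", "Pong-v0", "Gym environment to train on.")
flags.DEFINE_integer("epochs", 100, "Number of training epochs.")
flags.DEFINE_float("learning_rate", 1e-3, "Optimizer step size.")
flags.DEFINE_integer("batch_size", 32, "Number of trajectories per batch.")
flags.DEFINE_integer("num_optimizer_steps", 100, "Optimizer steps per epoch.")
flags.DEFINE_integer("boundary", 20, "Padding boundary for trajectories.")
flags.DEFINE_integer("random_seed", 0, "Random seed.")
flags.DEFINE_bool("jax_debug_nans", False,
                  "If set, enable JAX's NaN debugging mode.")

FLAGS = flags.FLAGS

if __name__ == "__main__":
  app.run(main)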
Example #3
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10):
    layers = [stax.Flatten()]
    layers += [stax.Dense(hidden_size), activation_fn] * num_hidden_layers
    layers += [stax.Dense(num_output_classes), stax.LogSoftmax]
    return stax.serial(*layers)
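
A model built with stax.serial is an (init_fun, apply_fun) pair. Below is a minimal usage sketch, assuming the convention of jax.example_libraries.stax (formerly jax.experimental.stax), where Flatten, Relu, and LogSoftmax are ready-made layers rather than calls.

# Usage sketch assuming the upstream jax stax API; the MLP above uses a stax
# variant in which Flatten is called as stax.Flatten().
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

init_fun, apply_fun = stax.serial(
    stax.Flatten,
    stax.Dense(512), stax.Relu,
    stax.Dense(512), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

# Initialize parameters for 28x28 single-channel inputs; -1 is the batch dim.
output_shape, params = init_fun(random.PRNGKey(0), (-1, 28, 28, 1))
print(output_shape)  # (-1, 10)

# Apply the network to a dummy batch of 8 images.
log_probs = apply_fun(params, jnp.zeros((8, 28, 28, 1)))
print(log_probs.shape)  # (8, 10)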
Example #4
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10,
        mode="train"):
    """Multi-layer feed-forward neural network with non-linear activations."""
    del mode
    layers = [stax.Flatten()]
    for _ in range(num_hidden_layers):
        layers += [stax.Dense(hidden_size), activation_fn()]
    layers += [stax.Dense(num_output_classes), stax.LogSoftmax()]
    return stax.Serial(*layers)
Example #5
    def test_flatten_n(self):
        input_shape = (29, 87, 10, 20, 30)

        actual_shape = check_staxlayer(self, stax.Flatten(), input_shape)
        self.assertEqual(actual_shape, (29, 87 * 10 * 20 * 30))

        actual_shape = check_staxlayer(self, stax.Flatten(num_axis_to_keep=2),
                                       input_shape)
        self.assertEqual(actual_shape, (29, 87, 10 * 20 * 30))

        actual_shape = check_staxlayer(self, stax.Flatten(num_axis_to_keep=3),
                                       input_shape)
        self.assertEqual(actual_shape, (29, 87, 10, 20 * 30))

        actual_shape = check_staxlayer(self, stax.Flatten(num_axis_to_keep=4),
                                       input_shape)
        self.assertEqual(actual_shape, (29, 87, 10, 20, 30))

        # Not enough dimensions.
        with self.assertRaises(ValueError):
            check_staxlayer(self, stax.Flatten(num_axis_to_keep=5),
                            input_shape)

        with self.assertRaises(ValueError):
            check_staxlayer(self, stax.Flatten(num_axis_to_keep=6),
                            input_shape)
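
check_staxlayer is a test helper that is not shown here. A plausible sketch, assuming the (init_fun, apply_fun) layer convention: it initializes the layer for the given input shape, applies it to dummy data, and cross-checks the declared output shape against the actual one. The implementation below is hypothetical.

import numpy as np
from jax import random


def check_staxlayer(test_case, layer, input_shape):
  # Hypothetical helper: initialize the layer, run it on zeros, and make sure
  # the shape reported by init_fun matches the shape actually produced.
  init_fun, apply_fun = layer
  declared_shape, params = init_fun(random.PRNGKey(0), input_shape)
  outputs = apply_fun(params, np.zeros(input_shape, dtype=np.float32))
  test_case.assertEqual(outputs.shape, declared_shape)
  return declared_shape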
Example #6
def main(argv):
  del argv
  logging.set_verbosity(FLAGS.log_level)
  bottom_layers = common_stax_layers()

  if FLAGS.env_name == "Pong-v0":
    bottom_layers = [stax.Div(255.0), stax.Flatten(2)] + bottom_layers

  ppo.training_loop(
      env_name=FLAGS.env_name,
      epochs=FLAGS.epochs,
      policy_net_fun=functools.partial(
          ppo.policy_net, bottom_layers=bottom_layers),
      value_net_fun=functools.partial(
          ppo.value_net, bottom_layers=bottom_layers),
      batch_size=FLAGS.batch_size,
      boundary=FLAGS.boundary,
      random_seed=FLAGS.random_seed)
Example #7
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.

  Returns:
    The WideResnet model with the given layer and output sizes.
  """
  return stax.serial(
      stax.Conv(hidden_size, (3, 3), padding='SAME'),
      WideResnetGroup(num_blocks, hidden_size),
      WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
      WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
      stax.BatchNorm(),
      stax.Relu,
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_output_classes),
      stax.LogSoftmax)
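
WideResnetGroup is not defined in this snippet. The sketch below shows how such a group of pre-activation residual blocks could be built from stax combinators; it is an approximation, not necessarily the implementation behind this example.

from jax.example_libraries import stax
from jax.example_libraries.stax import (BatchNorm, Conv, FanInSum, FanOut,
                                        Identity, Relu)


def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  # Pre-activation residual block: (BN -> ReLU -> 3x3 conv) twice. When the
  # channel count or stride changes, the shortcut needs a convolution too.
  main = stax.serial(
      BatchNorm(), Relu, Conv(channels, (3, 3), strides, padding='SAME'),
      BatchNorm(), Relu, Conv(channels, (3, 3), padding='SAME'))
  shortcut = (Identity if not channel_mismatch else
              Conv(channels, (3, 3), strides, padding='SAME'))
  return stax.serial(FanOut(2), stax.parallel(main, shortcut), FanInSum)


def WideResnetGroup(n, channels, strides=(1, 1)):
  # The first block adapts channels/stride; the remaining ones keep shapes.
  blocks = [WideResnetBlock(channels, strides, channel_mismatch=True)]
  blocks += [WideResnetBlock(channels) for _ in range(n - 1)]
  return stax.serial(*blocks)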
Example #8
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.
    mode: whether we are training, evaluating, or doing inference.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  del mode
  return stax.Serial(
      stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      stax.BatchNorm(),
      stax.Relu(),
      stax.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
      IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
      ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
      ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
      stax.AvgPool(pool_size=(7, 7)),
      stax.Flatten(),
      stax.Dense(num_output_classes),
      stax.LogSoftmax())
Example #9
def common_stax_layers():
    layers = []
    if FLAGS.env_name == "Pong-v0":
        layers = [stax.Div(divisor=255.0), stax.Flatten(num_axis_to_keep=2)]
    return layers + [stax.Dense(16), stax.Relu(), stax.Dense(4), stax.Relu()]
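
The list returned here is meant to be passed as bottom_layers to the policy and value network constructors used in the main() examples above. A hypothetical sketch of how such a list could be composed into a complete policy head; policy_net_sketch and num_actions are illustrative names, assuming the Serial-style combinator used in this snippet.

def policy_net_sketch(num_actions, bottom_layers):
  # Hypothetical composition: shared bottom layers followed by a log-softmax
  # head over the action space.
  return stax.Serial(*(bottom_layers +
                       [stax.Dense(num_actions), stax.LogSoftmax()]))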