def Resnet50(hidden_size=64, num_output_classes=1001):
    """Assemble a ResNet as a serial stack of stax layers.

    The network is a strided 7x7 convolution stem followed by four stages.
    Every stage is one ConvBlock followed by a run of IdentityBlocks
    (2, 3, 5 and 2 of them respectively), and the stack ends with average
    pooling, flattening, a dense layer and a log-softmax.

    Args:
      hidden_size: the size of the first hidden layer; later stages use
        multiples of it.
      num_output_classes: how many classes to distinguish.

    Returns:
      The ResNet model with the given layer and output sizes.
    """
    # (width multiplier, strides of the ConvBlock, number of IdentityBlocks)
    stage_specs = (
        (1, (1, 1), 2),
        (2, (2, 2), 3),
        (4, (2, 2), 5),
        (8, (2, 2), 2),
    )

    stack = [
        stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(),
        stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
    ]

    for multiplier, strides, num_identity in stage_specs:
        width = multiplier * hidden_size
        stack.append(ConvBlock(3, [width, width, 4 * width], strides))
        for _ in range(num_identity):
            stack.append(IdentityBlock(3, [width, width]))

    stack.append(stax.AvgPool((7, 7)))
    stack.append(stax.Flatten())
    stack.append(stax.Dense(num_output_classes))
    stack.append(stax.LogSoftmax)
    return stax.serial(*stack)
def main(argv):
    """Run the PPO training loop configured from command-line flags.

    Optionally turns on the "jax_debug_nans" config option, builds the
    bottom layers shared by the policy and value networks, and hands
    everything to ppo.training_loop.

    Args:
      argv: command-line arguments; unused.
    """
    del argv

    if FLAGS.jax_debug_nans:
        config.update("jax_debug_nans", True)

    shared_layers = common_stax_layers()
    if FLAGS.env_name == "Pong-v0":
        # Pong gets two extra layers in front of the common ones.
        pong_prefix = [stax.Div(255.0), stax.Flatten(2)]
        shared_layers = pong_prefix + shared_layers

    # Policy and value optimizers use the same learning rate.
    optimizer_fun = functools.partial(
        ppo.optimizer_fun, step_size=FLAGS.learning_rate)

    loop_kwargs = dict(
        env_name=FLAGS.env_name,
        epochs=FLAGS.epochs,
        policy_net_fun=functools.partial(
            ppo.policy_net, bottom_layers=shared_layers),
        value_net_fun=functools.partial(
            ppo.value_net, bottom_layers=shared_layers),
        policy_optimizer_fun=optimizer_fun,
        value_optimizer_fun=optimizer_fun,
        batch_size=FLAGS.batch_size,
        num_optimizer_steps=FLAGS.num_optimizer_steps,
        boundary=FLAGS.boundary,
        random_seed=FLAGS.random_seed,
    )
    ppo.training_loop(**loop_kwargs)
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10):
    """Build a multi-layer perceptron as a serial stack of stax layers.

    Args:
      num_hidden_layers: how many (dense, activation) pairs to stack.
      hidden_size: output size of every hidden dense layer.
      activation_fn: activation layer placed after every hidden dense layer.
      num_output_classes: output size of the final dense layer.

    Returns:
      Flatten, the hidden pairs, a final dense layer and a log-softmax,
      combined with stax.serial.
    """
    input_layer = stax.Flatten()
    # A single (dense, activation) pair is created and then repeated.
    hidden_pair = [stax.Dense(hidden_size), activation_fn]
    output_pair = [stax.Dense(num_output_classes), stax.LogSoftmax]
    stack = [input_layer] + hidden_pair * num_hidden_layers + output_pair
    return stax.serial(*stack)
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10,
        mode="train"):
    """Build a feed-forward network of dense layers and activations.

    Args:
      num_hidden_layers: how many (dense, activation) pairs to stack.
      hidden_size: output size of every hidden dense layer.
      activation_fn: callable returning the activation layer placed after
        every hidden dense layer.
      num_output_classes: output size of the final dense layer.
      mode: accepted but not used by this model.

    Returns:
      Flatten, the hidden pairs, a final dense layer and a log-softmax,
      combined with stax.Serial.
    """
    del mode  # Unused.

    stack = [stax.Flatten()]
    for _ in range(num_hidden_layers):
        # Fresh dense and activation instances for each hidden layer.
        stack.append(stax.Dense(hidden_size))
        stack.append(activation_fn())
    stack.extend((stax.Dense(num_output_classes), stax.LogSoftmax()))
    return stax.Serial(*stack)
def test_flatten_n(self):
    """Check Flatten output shapes for several num_axis_to_keep values."""
    input_shape = (29, 87, 10, 20, 30)

    # Default arguments: everything after the first axis is merged.
    default_shape = check_staxlayer(self, stax.Flatten(), input_shape)
    self.assertEqual(default_shape, (29, 87 * 10 * 20 * 30))

    # num_axis_to_keep -> expected output shape.
    expected_shapes = (
        (2, (29, 87, 10 * 20 * 30)),
        (3, (29, 87, 10, 20 * 30)),
        (4, (29, 87, 10, 20, 30)),
    )
    for num_axes, expected in expected_shapes:
        layer = stax.Flatten(num_axis_to_keep=num_axes)
        result = check_staxlayer(self, layer, input_shape)
        self.assertEqual(result, expected)

    # Not enough dimensions.
    for num_axes in (5, 6):
        with self.assertRaises(ValueError):
            check_staxlayer(
                self, stax.Flatten(num_axis_to_keep=num_axes), input_shape)
def main(argv):
    """Run the PPO training loop configured from command-line flags.

    Sets the logging verbosity, builds the bottom layers shared by the
    policy and value networks, and hands everything to ppo.training_loop.

    Args:
      argv: command-line arguments; unused.
    """
    del argv

    logging.set_verbosity(FLAGS.log_level)

    shared_layers = common_stax_layers()
    if FLAGS.env_name == "Pong-v0":
        # Pong gets two extra layers in front of the common ones.
        pong_prefix = [stax.Div(255.0), stax.Flatten(2)]
        shared_layers = pong_prefix + shared_layers

    loop_kwargs = dict(
        env_name=FLAGS.env_name,
        epochs=FLAGS.epochs,
        policy_net_fun=functools.partial(
            ppo.policy_net, bottom_layers=shared_layers),
        value_net_fun=functools.partial(
            ppo.value_net, bottom_layers=shared_layers),
        batch_size=FLAGS.batch_size,
        boundary=FLAGS.boundary,
        random_seed=FLAGS.random_seed,
    )
    ppo.training_loop(**loop_kwargs)
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

    The model is a 3x3 convolution, three WideResnetGroups whose sizes are
    hidden_size, 2 * hidden_size and 4 * hidden_size, then batch norm,
    ReLU, average pooling, flattening, a dense layer and a log-softmax.

    Args:
      num_blocks: int, number of blocks in a group.
      hidden_size: the size of the first hidden layer (multiplied later).
      num_output_classes: int, number of classes to distinguish.

    Returns:
      The WideResnet model with given layer and output sizes.
    """
    stack = [
        stax.Conv(hidden_size, (3, 3), padding='SAME'),
        # The first group is created without an explicit strides argument.
        WideResnetGroup(num_blocks, hidden_size),
    ]
    for factor in (2, 4):
        stack.append(WideResnetGroup(num_blocks, hidden_size * factor, (2, 2)))

    stack.extend([
        stax.BatchNorm(),
        stax.Relu,
        stax.AvgPool((8, 8)),
        stax.Flatten(),
        stax.Dense(num_output_classes),
        stax.LogSoftmax,
    ])
    return stax.serial(*stack)
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
    """Assemble a ResNet as a serial stack of stax layers.

    The network is a strided 7x7 convolution stem followed by four stages.
    Every stage is one ConvBlock followed by a run of IdentityBlocks
    (2, 3, 5 and 2 of them respectively) that use the same three filter
    sizes, and the stack ends with average pooling, flattening, a dense
    layer and a log-softmax.

    Args:
      hidden_size: the size of the first hidden layer; later stages use
        multiples of it.
      num_output_classes: how many classes to distinguish.
      mode: whether we are training or evaluating or doing inference;
        accepted but not used by this model.

    Returns:
      The ResNet model with the given layer and output sizes.
    """
    del mode  # Unused.

    # (width multiplier, strides of the ConvBlock, number of IdentityBlocks)
    stage_specs = (
        (1, (1, 1), 2),
        (2, (2, 2), 3),
        (4, (2, 2), 5),
        (8, (2, 2), 2),
    )

    stack = [
        stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(),
        stax.Relu(),
        stax.MaxPool(pool_size=(3, 3), strides=(2, 2)),
    ]

    for multiplier, strides, num_identity in stage_specs:
        width = multiplier * hidden_size
        stack.append(ConvBlock(3, [width, width, 4 * width], strides))
        for _ in range(num_identity):
            stack.append(IdentityBlock(3, [width, width, 4 * width]))

    stack.append(stax.AvgPool(pool_size=(7, 7)))
    stack.append(stax.Flatten())
    stack.append(stax.Dense(num_output_classes))
    stack.append(stax.LogSoftmax())
    return stax.Serial(*stack)
def common_stax_layers():
    """Return the list of bottom layers for the current environment.

    When FLAGS.env_name is "Pong-v0", the list starts with a Div layer
    (divisor 255.0) and a Flatten layer keeping 2 axes. In every case it
    ends with Dense(16), Relu, Dense(4), Relu.

    Returns:
      A new list of stax layers.
    """
    is_pong = FLAGS.env_name == "Pong-v0"
    prefix = (
        [stax.Div(divisor=255.0), stax.Flatten(num_axis_to_keep=2)]
        if is_pong else []
    )
    trunk = [stax.Dense(16), stax.Relu(), stax.Dense(4), stax.Relu()]
    return prefix + trunk