def test_ppo_loss(self):
  self.rng_key, key1, key2, key3 = jax_random.split(self.rng_key, num=4)
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (-1, -1) + OBS

  old_policy_params, _ = ppo.policy_net(
      key1, batch_observation_shape, A, [layers.Flatten(num_axis_to_keep=2)])
  new_policy_params, policy_apply = ppo.policy_net(
      key2, batch_observation_shape, A, [layers.Flatten(num_axis_to_keep=2)])
  value_params, value_apply = ppo.value_net(
      key3, batch_observation_shape, A, [layers.Flatten(num_axis_to_keep=2)])

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  _ = ppo.ppo_loss(policy_apply, new_policy_params, old_policy_params,
                   value_apply, value_params, observations, actions, rewards,
                   mask)
def policy_and_value_net(n_actions, n_controls, vocab_size, bottom_layers_fn,
                         two_towers):
  """A policy and value net function."""

  # On top of the bottom layers, one head computes action log-probabilities
  # and the other computes the value function.
  # NOTE: We use LogSoftmax instead of Softmax for numerical stability.

  @tl.layer()
  def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
    """Splits logits for actions in different controls and flattens controls."""
    return np.reshape(x, (x.shape[0], -1, n_actions))

  if vocab_size is None:
    # In continuous policies every element of the output sequence corresponds
    # to an observation.
    n_preds_per_input = n_controls
    kwargs = {}
  else:
    # In discrete policies every element of the output sequence corresponds to
    # a symbol in the discrete representation, and each control takes 1 symbol.
    n_preds_per_input = 1
    kwargs = {"vocab_size": vocab_size}

  if two_towers:
    layers = [
        tl.Dup(),
        tl.Parallel(
            [
                bottom_layers_fn(**kwargs),
                tl.Dense(n_preds_per_input * n_actions),
                FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                tl.LogSoftmax()
            ],
            [
                bottom_layers_fn(**kwargs),
                tl.Dense(n_preds_per_input),
                tl.Flatten()
            ],
        )
    ]
  else:
    layers = [
        bottom_layers_fn(**kwargs),
        tl.Dup(),
        tl.Parallel(
            [
                tl.Dense(n_preds_per_input * n_actions),
                FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                tl.LogSoftmax()
            ],
            [tl.Dense(n_preds_per_input), tl.Flatten()],
        )
    ]
  return tl.Model(layers)
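# Illustrative only: a pure-NumPy sketch (not library code) of the reshape
# that FlattenControlsIntoTime performs above. With hypothetical sizes
# B=2, T=5, n_controls=3, n_actions=4, the Dense head emits logits of shape
# (B, T, n_controls * n_actions) and the reshape regroups them so that each
# control gets its own "time step" of n_actions logits.
import numpy as np

B, T, n_controls, n_actions = 2, 5, 3, 4
logits = np.zeros((B, T, n_controls * n_actions))
flattened = np.reshape(logits, (logits.shape[0], -1, n_actions))
assert flattened.shape == (B, T * n_controls, n_actions)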
def test_collect_trajectories(self):
  observation_shape = (2, 3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      # Flatten everything except the batch and time-step dimensions.
      [layers.Flatten(num_axis_to_keep=2)])

  # We'll get done at time-step #5, starting from 0, therefore in 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(observation_shape, num_actions,
                         done_time_step=done_time_step)

  num_trajectories = 5
  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=lambda obs: policy_apply(obs, policy_params),
      num_trajectories=num_trajectories,
      policy="categorical-sampling")

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # Observations are one more in number than rewards or actions.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)

  # Test collect using a policy and value function.
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      [layers.Flatten(num_axis_to_keep=2)])
  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=lambda obs: pnv_apply(obs, pnv_params)[0],
      num_trajectories=num_trajectories,
      policy="categorical-sampling")

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # Observations are one more in number than rewards or actions.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)
def AtariCnn(hidden_sizes=(32, 32), output_size=128):
  # Input shape: (B, T, H, W, C).
  return tl.Serial(
      tl.Div(divisor=255.0),
      # Have 4 copies of the input, each one shifted to the right by one.
      tl.Branch(
          tl.NoOp(),
          tl.ShiftRight(),
          tl.Serial(
              tl.ShiftRight(),
              tl.ShiftRight(),
          ),
          tl.Serial(
              tl.ShiftRight(),
              tl.ShiftRight(),
              tl.ShiftRight(),
          )),
      # Concatenated on the last axis.
      tl.Concatenate(axis=-1),  # (B, T, H, W, 4C)
      tl.Rebatch(tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'), 2),
      tl.Relu(),
      tl.Rebatch(tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'), 2),
      tl.Relu(),
      tl.Flatten(num_axis_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
      # Eventually this is shaped (B, T, output_size).
  )
def WideResnet(n_blocks=3, d_hidden=64, n_output_classes=10, mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group.
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: int, number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
  del mode
  return tl.Model(
      tl.ToFloat(),
      tl.Conv(d_hidden, (3, 3), padding='SAME'),
      WideResnetGroup(n_blocks, d_hidden),
      WideResnetGroup(n_blocks, d_hidden * 2, (2, 2)),
      WideResnetGroup(n_blocks, d_hidden * 4, (2, 2)),
      tl.BatchNorm(),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10,
               mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.
    mode: 'train' or 'eval'.

  Returns:
    The WideResnet model with the given layer and output sizes.
  """
  del mode
  return tl.Serial(
      tl.Conv(hidden_size, (3, 3), padding='SAME'),
      WideResnetGroup(num_blocks, hidden_size),
      WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
      WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
      tl.BatchNorm(),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(num_output_classes),
      tl.LogSoftmax()
  )
def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10, mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group. Total layers = 6n + 4.
    widen_factor: int, widening factor of each group. k=1 is vanilla resnet.
    n_output_classes: int, number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
  return tl.Model(
      tl.ToFloat(),
      tl.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(n_blocks, 16 * widen_factor, mode=mode),
      WideResnetGroup(n_blocks, 32 * widen_factor, (2, 2), mode=mode),
      WideResnetGroup(n_blocks, 64 * widen_factor, (2, 2), mode=mode),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def AtariCnn(hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN."""
  del mode

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C)
  # Output shape: (B, T, output_size)
  return tl.Model(
      tl.ToFloat(),
      tl.Div(divisor=255.0),

      # Set up 4 successive game frames, concatenated on the last axis.
      tl.Dup(), tl.Dup(), tl.Dup(),
      tl.Parallel(None, _shift_right(1), _shift_right(2), _shift_right(3)),
      tl.Concatenate(n_items=4, axis=-1),  # (B, T, H, W, 4C)

      tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
def test_policy_and_value_net(self):
  observation_shape = (3, 4, 5)
  batch_observation_shape = (1, 1) + observation_shape
  n_actions = 2
  n_controls = 3
  pnv_model = ppo.policy_and_value_net(
      n_controls=n_controls,
      n_actions=n_actions,
      vocab_size=None,
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      two_towers=True,
  )
  _, _ = pnv_model.initialize_once(batch_observation_shape, np.float32,
                                   self.rng_key)

  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                            observation_shape)
  pnv_output = pnv_model(batch_of_observations)

  # The output is a list of two: action log-probabilities and value
  # predictions.
  self.assertEqual(2, len(pnv_output))
  self.assertEqual((batch, time_steps * n_controls, n_actions),
                   pnv_output[0].shape)
  self.assertEqual((batch, time_steps * n_controls), pnv_output[1].shape)
def test_policy_net(self):
  observation_shape = (3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      # Flatten everything except the batch and time-step dimensions.
      [layers.Flatten(num_axis_to_keep=2)])

  # Generate a batch of observations.
  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                            observation_shape)

  # Apply the policy net to the observations.
  policy_output = policy_apply(batch_of_observations, policy_params)

  # Verify certain expectations on the output.
  self.assertEqual((batch, time_steps, num_actions), policy_output.shape)

  # The exponentiated outputs sum to 1 over the last axis, since these are
  # log-probabilities.
  sum_actions = np.sum(np.exp(policy_output), axis=-1)
  self.assertAllClose(np.ones_like(sum_actions), sum_actions)
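# A minimal NumPy sketch (illustrative, not library code) of the invariant
# the test above checks: exponentiating log-softmax outputs and summing over
# the last axis yields 1 at every (batch, time) position.
import numpy as np

logits = np.random.uniform(size=(2, 10, 5))
log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
sums = np.sum(np.exp(log_probs), axis=-1)
assert np.allclose(sums, np.ones_like(sums))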
def policy_and_value_net(n_actions, n_controls, bottom_layers_fn, two_towers):
  """A policy and value net function."""

  # On top of the bottom layers, one head computes action log-probabilities
  # and the other computes the value function.
  # NOTE: We use LogSoftmax instead of Softmax for numerical stability.

  @tl.layer()
  def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
    """Splits logits for actions in different controls and flattens controls."""
    return np.reshape(x, (x.shape[0], -1, n_actions))

  n_logits = n_controls * n_actions

  if two_towers:
    layers = [
        tl.Dup(),
        tl.Parallel(
            [
                bottom_layers_fn(),
                tl.Dense(n_logits),
                FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                tl.LogSoftmax()
            ],
            [bottom_layers_fn(), tl.Dense(n_controls), tl.Flatten()],
        )
    ]
  else:
    layers = [
        bottom_layers_fn(),
        tl.Dup(),
        tl.Parallel(
            [
                tl.Dense(n_logits),
                FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                tl.LogSoftmax()
            ],
            [tl.Dense(n_controls), tl.Flatten()],
        )
    ]
  return tl.Model(layers)
def common_layers():
  cur_layers = []
  if FLAGS.flatten_non_batch_time_dims:
    cur_layers = [
        layers.Div(divisor=255.0),
        layers.Flatten(num_axis_to_keep=2)
    ]
  body = [layers.Dense(64), layers.Tanh(), layers.Dense(64), layers.Tanh()]
  return cur_layers + body
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=tl.Relu,
        num_output_classes=10,
        mode="train"):
  """Multi-layer feed-forward neural network with non-linear activations."""
  del mode
  cur_layers = [tl.Flatten()]
  for _ in range(num_hidden_layers):
    cur_layers += [tl.Dense(hidden_size), activation_fn()]
  cur_layers += [tl.Dense(num_output_classes), tl.LogSoftmax()]
  return tl.Serial(*cur_layers)
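# Illustrative sketch (plain NumPy, assumed MNIST-like shapes, biases omitted,
# not library code) of the computation the MLP above performs: flatten to
# (batch, features), apply dense/ReLU blocks, then project and log-softmax.
import numpy as np

x = np.random.uniform(size=(32, 28, 28, 1))   # (batch, H, W, C)
x = x.reshape(x.shape[0], -1)                 # Flatten -> (32, 784)
w1 = np.random.normal(size=(784, 512)) * 0.01
w2 = np.random.normal(size=(512, 512)) * 0.01
w3 = np.random.normal(size=(512, 10)) * 0.01
h = np.maximum(x @ w1, 0.0)                   # Dense + Relu
h = np.maximum(h @ w2, 0.0)                   # Dense + Relu
logits = h @ w3                               # Dense(num_output_classes)
log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
assert log_probs.shape == (32, 10)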
def common_layers():
  # TODO(afrozm): Refactor.
  if "NoFrameskip" in FLAGS.env_problem_name:
    return atari_layers()

  cur_layers = []
  if FLAGS.flatten_dims:
    cur_layers = [
        layers.Div(divisor=255.0),
        layers.Flatten(num_axis_to_keep=2)
    ]
  body = [layers.Dense(64), layers.Tanh(), layers.Dense(64), layers.Tanh()]
  return cur_layers + body
def common_layers():
  cur_layers = []
  if FLAGS.env_name == "Pong-v0":
    cur_layers = [
        layers.Div(divisor=255.0),
        layers.Flatten(num_axis_to_keep=2)
    ]
  return cur_layers + [
      layers.Dense(16),
      layers.Relu(),
      layers.Dense(4),
      layers.Relu()
  ]
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train'):
  """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """
  return tl.Model(
      tl.ToFloat(),
      tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      ConvBlock(3, [d_hidden, d_hidden, 4 * d_hidden], (1, 1), mode=mode),
      IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
      IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
      ConvBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], (2, 2),
                mode=mode),
      IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], mode=mode),
      IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], mode=mode),
      IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], mode=mode),
      ConvBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], (2, 2),
                mode=mode),
      IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode),
      IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode),
      IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode),
      IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode),
      IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode),
      ConvBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], (2, 2),
                mode=mode),
      IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], mode=mode),
      IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], mode=mode),
      tl.AvgPool(pool_size=(7, 7)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
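# Conceptual sketch (plain NumPy, not the library's ConvBlock/IdentityBlock)
# of the residual connection both block types are built on: the main path's
# output is added back onto an identity shortcut before the final ReLU.
import numpy as np

def residual(x, f):
  # f is the main path; the identity shortcut is added before the ReLU.
  return np.maximum(f(x) + x, 0.0)

x = np.random.normal(size=(8, 16))
w = np.random.normal(size=(16, 16)) * 0.01
y = residual(x, lambda t: t @ w)
assert y.shape == x.shape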
def test_value_net(self):
  observation_shape = (3, 4, 5)
  num_actions = 2
  value_params, value_apply = ppo.value_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      [layers.Flatten(num_axis_to_keep=2)])
  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                            observation_shape)
  value_output = value_apply(batch_of_observations, value_params)

  # NOTE: The extra trailing dimension comes from Dense(1).
  self.assertEqual((batch, time_steps, 1), value_output.shape)
def MLP(n_hidden_layers=2,
        d_hidden=512,
        activation_fn=tl.Relu,
        n_output_classes=10,
        mode="train"):
  """A multi-layer feedforward (perceptron) network."""
  del mode
  return tl.Model(
      tl.Flatten(),
      [[tl.Dense(d_hidden), activation_fn()] for _ in range(n_hidden_layers)],
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def MLP(n_hidden_layers=2,
        d_hidden=512,
        activation_fn=tl.Relu,
        n_output_classes=10,
        mode="train"):
  """Multi-layer feed-forward neural network with non-linear activations."""
  del mode
  return [
      tl.Flatten(),
      [[tl.Dense(d_hidden), activation_fn()] for _ in range(n_hidden_layers)],
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  ]
def model(mode):
  del mode
  return layers.Serial(
      layers.Parallel(
          layers.Flatten(),  # Observation stack.
          layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
      ),
      layers.Concatenate(),
      layers.Dense(n_units=1),
      layers.Dup(),
      layers.Parallel(
          layers.Dense(n_units=obs_shape[1]),  # New observation.
          None,  # Reward.
      )
  )
def test_policy_and_value_net(self):
  observation_shape = (3, 4, 5)
  batch_observation_shape = (-1, -1) + observation_shape
  num_actions = 2
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      self.rng_key, batch_observation_shape, num_actions,
      [layers.Flatten(num_axis_to_keep=2)])
  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                            observation_shape)
  pnv_output = pnv_apply(batch_of_observations, pnv_params)

  # The output is a list of two: action log-probabilities and value
  # predictions.
  self.assertEqual(2, len(pnv_output))
  self.assertEqual((batch, time_steps, num_actions), pnv_output[0].shape)
  self.assertEqual((batch, time_steps, 1), pnv_output[1].shape)
def test_collect_trajectories_max_timestep(self):
  self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)
  observation_shape = (2, 3, 4)
  num_actions = 2
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      key1, (-1, -1) + observation_shape, num_actions,
      lambda: [layers.Flatten(num_axis_to_keep=2)])

  def pnv_fun(obs, rng=None):
    rng, r = jax_random.split(rng)
    lp, v = pnv_apply(obs, pnv_params, rng=r)
    return lp, v, rng

  # We'll get done at time-step #5, starting from 0, therefore in 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(observation_shape, num_actions,
                         done_time_step=done_time_step)

  num_trajectories = 5

  # Let's collect trajectories only till `max_timestep`.
  max_timestep = 3

  # We're testing early termination of the trajectory.
  assert max_timestep < done_time_step

  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=pnv_fun,
      num_trajectories=num_trajectories,
      policy="categorical-sampling",
      max_timestep=max_timestep,
      rng=key2)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # Observations are one more in number than rewards or actions.
    self.assertEqual((max_timestep,) + observation_shape, observations.shape)
    self.assertEqual((max_timestep - 1,), actions.shape)
    self.assertEqual((max_timestep - 1,), rewards.shape)
def test_policy_and_value_net(self):
  observation_shape = (3, 4, 5)
  batch_observation_shape = (1, 1) + observation_shape
  n_actions = 2
  pnv_model = ppo.policy_and_value_net(
      n_actions, lambda: [layers.Flatten(n_axes_to_keep=2)], two_towers=True)
  pnv_params, pnv_state = pnv_model.initialize(batch_observation_shape,
                                               np.float32, self.rng_key)

  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                            observation_shape)
  pnv_output, _ = pnv_model(batch_of_observations, pnv_params, pnv_state)

  # The output is a list of two: action log-probabilities and value
  # predictions.
  self.assertEqual(2, len(pnv_output))
  self.assertEqual((batch, time_steps, n_actions), pnv_output[0].shape)
  self.assertEqual((batch, time_steps, 1), pnv_output[1].shape)
def test_collect_trajectories_max_timestep(self):
  observation_shape = (2, 3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      self.rng_key,
      (-1, -1) + observation_shape,
      num_actions,
      # Flatten everything except the batch and time-step dimensions.
      [layers.Flatten(num_axis_to_keep=2)])

  # We'll get done at time-step #5, starting from 0, therefore in 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(observation_shape, num_actions,
                         done_time_step=done_time_step)

  num_trajectories = 5

  # Let's collect trajectories only till `max_timestep`.
  max_timestep = 3

  # We're testing early termination of the trajectory.
  assert max_timestep < done_time_step

  trajectories = ppo.collect_trajectories(
      env,
      policy_apply,
      policy_params,
      num_trajectories,
      policy="categorical-sampling",
      max_timestep=max_timestep)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # Observations are one more in number than rewards or actions.
    self.assertEqual((max_timestep,) + observation_shape, observations.shape)
    self.assertEqual((max_timestep - 1,), actions.shape)
    self.assertEqual((max_timestep - 1,), rewards.shape)
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN."""
  del mode

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C)
  # Output shape: (B, T, output_size)
  return tl.Model(
      tl.ToFloat(),
      tl.Div(divisor=255.0),

      # Set up n_frames successive game frames, concatenated on the last axis.
      FrameStack(n_frames=n_frames),  # (B, T, H, W, n_frames * C)

      tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
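# Illustrative NumPy sketch (assumed semantics, not the library's FrameStack)
# of stacking n_frames successive frames: each copy of the input is shifted
# right along the time axis by 0..n_frames-1 steps (zero-padded at the front)
# and the copies are concatenated on the channel axis.
import numpy as np

def shift_right(x, n):
  # Shift right along the time axis by n steps, zero-padding at the front.
  if n == 0:
    return x
  pad = np.zeros_like(x[:, :n])
  return np.concatenate([pad, x[:, :-n]], axis=1)

frames = np.random.uniform(size=(2, 10, 4, 4, 3))  # (B, T, H, W, C)
stacked = np.concatenate([shift_right(frames, k) for k in range(4)], axis=-1)
assert stacked.shape == (2, 10, 4, 4, 12)          # (B, T, H, W, n_frames*C)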
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.
    mode: whether we are training or evaluating or doing inference.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  del mode
  return tl.Serial(
      tl.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      tl.BatchNorm(),
      tl.Relu(),
      tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
      IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
      ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
      ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
      tl.AvgPool(pool_size=(7, 7)),
      tl.Flatten(),
      tl.Dense(num_output_classes),
      tl.LogSoftmax())
def test_flatten_n(self):
  input_shape = (29, 87, 10, 20, 30)

  actual_shape = check_layer(self, layers.Flatten(), input_shape)
  self.assertEqual(actual_shape, (29, 87 * 10 * 20 * 30))

  actual_shape = check_layer(self, layers.Flatten(num_axis_to_keep=2),
                             input_shape)
  self.assertEqual(actual_shape, (29, 87, 10 * 20 * 30))

  actual_shape = check_layer(self, layers.Flatten(num_axis_to_keep=3),
                             input_shape)
  self.assertEqual(actual_shape, (29, 87, 10, 20 * 30))

  actual_shape = check_layer(self, layers.Flatten(num_axis_to_keep=4),
                             input_shape)
  self.assertEqual(actual_shape, (29, 87, 10, 20, 30))

  # Not enough dimensions.
  with self.assertRaises(ValueError):
    check_layer(self, layers.Flatten(num_axis_to_keep=5), input_shape)

  with self.assertRaises(ValueError):
    check_layer(self, layers.Flatten(num_axis_to_keep=6), input_shape)
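# A short NumPy sketch (illustrative, not the library's Flatten layer) of the
# reshape the test above exercises: the first num_axis_to_keep axes are
# preserved and all remaining axes collapse into one.
import numpy as np

def flatten(x, num_axis_to_keep=1):
  if num_axis_to_keep >= x.ndim:
    raise ValueError('num_axis_to_keep must be less than the input rank.')
  return np.reshape(x, x.shape[:num_axis_to_keep] + (-1,))

x = np.zeros((29, 87, 10, 20, 30))
assert flatten(x).shape == (29, 87 * 10 * 20 * 30)
assert flatten(x, num_axis_to_keep=2).shape == (29, 87, 10 * 20 * 30)
assert flatten(x, num_axis_to_keep=4).shape == (29, 87, 10, 20, 30)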
def test_combined_loss(self):
  self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (1, 1) + OBS

  net = ppo.policy_and_value_net(
      A, lambda: [layers.Flatten(n_axes_to_keep=2)], two_towers=True)

  old_params, _ = net.initialize(batch_observation_shape, np.float32, key1)
  new_params, state = net.initialize(batch_observation_shape, np.float32,
                                     key2)

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  (new_log_probabs, value_predictions_new), _ = net(observations, new_params,
                                                    state)
  (old_log_probabs, value_predictions_old), _ = net(observations, old_params,
                                                    state)

  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  c1 = 1.0
  c2 = 0.01

  (value_loss_1, _) = ppo.value_loss_given_predictions(
      value_predictions_new,
      rewards,
      mask,
      gamma=gamma,
      value_prediction_old=value_predictions_old,
      epsilon=epsilon)
  (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
      new_log_probabs,
      old_log_probabs,
      value_predictions_old,
      actions,
      rewards,
      mask,
      gamma=gamma,
      lambda_=lambda_,
      epsilon=epsilon)

  (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _,
   state) = ppo.combined_loss(
       new_params,
       old_log_probabs,
       value_predictions_old,
       net,
       observations,
       actions,
       rewards,
       mask,
       gamma=gamma,
       lambda_=lambda_,
       epsilon=epsilon,
       c1=c1,
       c2=c2,
       state=state)

  # Test that these compute at all and are self consistent.
  self.assertGreater(entropy_bonus, 0.0)
  self.assertNear(value_loss_1, value_loss_2, 1e-6)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
  self.assertNear(combined_loss,
                  ppo_loss_2 + (c1 * value_loss_2) - (c2 * entropy_bonus),
                  1e-6)
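# Illustrative NumPy sketch of the standard clipped PPO surrogate that
# ppo_loss_given_predictions is built around (a generic textbook form, not a
# copy of the library's implementation): the probability ratio between new
# and old policies is clipped to [1 - epsilon, 1 + epsilon] before being
# multiplied by the advantage, and the elementwise minimum is averaged.
import numpy as np

def clipped_ppo_objective(new_log_probs, old_log_probs, advantages,
                          epsilon=0.2):
  ratios = np.exp(new_log_probs - old_log_probs)   # pi_new / pi_old
  clipped = np.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  # Negated mean: we minimize the loss but maximize the surrogate objective.
  return -np.mean(np.minimum(ratios * advantages, clipped * advantages))

new_lp = np.log(np.array([0.3, 0.6, 0.1]))
old_lp = np.log(np.array([0.2, 0.5, 0.3]))
advantages = np.array([1.0, -0.5, 0.2])
loss = clipped_ppo_objective(new_lp, old_lp, advantages)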
def test_combined_loss(self):
  self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (-1, -1) + OBS

  old_params, _ = ppo.policy_and_value_net(
      key1, batch_observation_shape, A, [layers.Flatten(num_axis_to_keep=2)])
  new_params, net_apply = ppo.policy_and_value_net(
      key2, batch_observation_shape, A, [layers.Flatten(num_axis_to_keep=2)])

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  new_log_probabs, _ = net_apply(observations, new_params)
  old_log_probabs, value_predictions = net_apply(observations, old_params)

  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  c1 = 1.0
  c2 = 0.01

  value_loss_1 = ppo.value_loss_given_predictions(
      value_predictions, rewards, mask, gamma=gamma)
  ppo_loss_1 = ppo.ppo_loss_given_predictions(
      new_log_probabs,
      old_log_probabs,
      value_predictions,
      actions,
      rewards,
      mask,
      gamma=gamma,
      lambda_=lambda_,
      epsilon=epsilon)

  (combined_loss, ppo_loss_2, value_loss_2, entropy_bonus) = ppo.combined_loss(
      new_params,
      old_params,
      net_apply,
      observations,
      actions,
      rewards,
      mask,
      gamma=gamma,
      lambda_=lambda_,
      epsilon=epsilon,
      c1=c1,
      c2=c2)

  # Test that these compute at all and are self consistent.
  self.assertEqual(0.0, entropy_bonus)
  self.assertNear(value_loss_1, value_loss_2, 1e-6)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
  self.assertNear(combined_loss, ppo_loss_2 + (c1 * value_loss_2), 1e-6)
def test_collect_trajectories(self):
  self.rng_key, key1, key2, key3, key4 = jax_random.split(self.rng_key, num=5)
  observation_shape = (2, 3, 4)
  num_actions = 2
  policy_params, policy_apply = ppo.policy_net(
      key1,
      (-1, -1) + observation_shape,
      num_actions,
      # Flatten everything except the batch and time-step dimensions.
      [layers.Flatten(num_axis_to_keep=2)])

  # We'll get done at time-step #5, starting from 0, therefore in 6 steps.
  done_time_step = 5
  env = fake_env.FakeEnv(observation_shape, num_actions,
                         done_time_step=done_time_step)

  def policy_fun(obs, rng=None):
    rng, r = jax_random.split(rng)
    return policy_apply(obs, policy_params, rng=r), (), rng

  num_trajectories = 5
  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=policy_fun,
      num_trajectories=num_trajectories,
      policy="categorical-sampling",
      rng=key2)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # Observations are one more in number than rewards or actions.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)

  # Test collect using a policy and value function.
  pnv_params, pnv_apply = ppo.policy_and_value_net(
      key3, (-1, -1) + observation_shape, num_actions,
      lambda: [layers.Flatten(num_axis_to_keep=2)])

  def pnv_fun(obs, rng=None):
    rng, r = jax_random.split(rng)
    lp, v = pnv_apply(obs, pnv_params, rng=r)
    return lp, v, rng

  trajectories = ppo.collect_trajectories(
      env,
      policy_fun=pnv_fun,
      num_trajectories=num_trajectories,
      policy="categorical-sampling",
      rng=key4)

  # Number of trajectories is as expected.
  self.assertEqual(num_trajectories, len(trajectories))

  # Shapes of observations, actions and rewards are as expected.
  for observations, actions, rewards in trajectories:
    # Observations are one more in number than rewards or actions.
    self.assertEqual((done_time_step + 2,) + observation_shape,
                     observations.shape)
    self.assertEqual((done_time_step + 1,), actions.shape)
    self.assertEqual((done_time_step + 1,), rewards.shape)