Example #1
    def test_ppo_loss(self):
        self.rng_key, key1, key2, key3 = jax_random.split(self.rng_key, num=4)

        B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
        batch_observation_shape = (-1, -1) + OBS

        old_policy_params, _ = ppo.policy_net(
            key1, batch_observation_shape, A,
            [layers.Flatten(num_axis_to_keep=2)])

        new_policy_params, policy_apply = ppo.policy_net(
            key2, batch_observation_shape, A,
            [layers.Flatten(num_axis_to_keep=2)])

        value_params, value_apply = ppo.value_net(
            key3, batch_observation_shape, A,
            [layers.Flatten(num_axis_to_keep=2)])

        # Generate a batch of observations.

        observations = np.random.uniform(size=(B, T + 1) + OBS)
        actions = np.random.randint(0, A, size=(B, T))
        rewards = np.random.uniform(0, 1, size=(B, T))
        mask = np.ones_like(rewards)

        # Just test that this computes at all.
        _ = ppo.ppo_loss(policy_apply, new_policy_params, old_policy_params,
                         value_apply, value_params, observations, actions,
                         rewards, mask)
Example #2
def policy_and_value_net(n_actions, n_controls, vocab_size, bottom_layers_fn,
                         two_towers):
    """A policy and value net function."""

    # Layers.

    # From the bottom layers' output, one head computes action log-probabilities
    # and the other computes the value function.
    # NOTE: We use LogSoftmax instead of Softmax for numerical stability.

    @tl.layer()
    def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
        """Splits logits for actions in different controls and flattens controls."""
        return np.reshape(x, (x.shape[0], -1, n_actions))

    if vocab_size is None:
        # In continuous policies every element of the output sequence corresponds to
        # an observation.
        n_preds_per_input = n_controls
        kwargs = {}
    else:
        # In discrete policies every element of the output sequence corresponds to
        # a symbol in the discrete representation, and each control takes 1 symbol.
        n_preds_per_input = 1
        kwargs = {"vocab_size": vocab_size}

    if two_towers:
        layers = [
            tl.Dup(),
            tl.Parallel(
                [
                    bottom_layers_fn(**kwargs),
                    tl.Dense(n_preds_per_input * n_actions),
                    FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                    tl.LogSoftmax()
                ],
                [
                    bottom_layers_fn(**kwargs),
                    tl.Dense(n_preds_per_input),
                    tl.Flatten()
                ],
            )
        ]
    else:
        layers = [
            bottom_layers_fn(**kwargs),
            tl.Dup(),
            tl.Parallel(
                [
                    tl.Dense(n_preds_per_input * n_actions),
                    FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                    tl.LogSoftmax()
                ],
                [tl.Dense(n_preds_per_input),
                 tl.Flatten()],
            )
        ]
    return tl.Model(layers)
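
As a sanity check on the shapes involved, here is a minimal numpy-only sketch (not library code) of what FlattenControlsIntoTime does in the continuous case: the n_controls * n_actions logits produced per time step are reshaped so that every control gets its own position on the time axis, matching the (batch, time_steps * n_controls, n_actions) shape asserted in the tests elsewhere in this listing.

import numpy as np

B, T, n_controls, n_actions = 2, 10, 3, 4
per_step_logits = np.random.uniform(size=(B, T, n_controls * n_actions))
# Same reshape as FlattenControlsIntoTime above.
flattened = np.reshape(per_step_logits, (per_step_logits.shape[0], -1, n_actions))
assert flattened.shape == (B, T * n_controls, n_actions)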
Example #3
    def test_collect_trajectories(self):
        observation_shape = (2, 3, 4)
        num_actions = 2
        policy_params, policy_apply = ppo.policy_net(
            self.rng_key,
            (-1, -1) + observation_shape,
            num_actions,
            # Flatten everything except the batch and time-step dimensions.
            [layers.Flatten(num_axis_to_keep=2)])

        # The environment returns done at time step #5 (counting from 0), i.e. after 6 steps.
        done_time_step = 5
        env = fake_env.FakeEnv(observation_shape,
                               num_actions,
                               done_time_step=done_time_step)

        num_trajectories = 5
        trajectories = ppo.collect_trajectories(
            env,
            policy_fun=lambda obs: policy_apply(obs, policy_params),
            num_trajectories=num_trajectories,
            policy="categorical-sampling")

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are rewards or actions.
            self.assertEqual((done_time_step + 2, ) + observation_shape,
                             observations.shape)
            self.assertEqual((done_time_step + 1, ), actions.shape)
            self.assertEqual((done_time_step + 1, ), rewards.shape)

        # Test collect using a Policy and Value function.
        pnv_params, pnv_apply = ppo.policy_and_value_net(
            self.rng_key, (-1, -1) + observation_shape, num_actions,
            [layers.Flatten(num_axis_to_keep=2)])

        trajectories = ppo.collect_trajectories(
            env,
            policy_fun=lambda obs: pnv_apply(obs, pnv_params)[0],
            num_trajectories=num_trajectories,
            policy="categorical-sampling")

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are rewards or actions.
            self.assertEqual((done_time_step + 2, ) + observation_shape,
                             observations.shape)
            self.assertEqual((done_time_step + 1, ), actions.shape)
            self.assertEqual((done_time_step + 1, ), rewards.shape)
Example #4
def AtariCnn(hidden_sizes=(32, 32), output_size=128):
    """An Atari CNN."""
    # Input shape: (B, T, H, W, C)
    return tl.Serial(
        tl.Div(divisor=255.0),
        # Make 4 copies of the input, shifted right by 0, 1, 2, and 3 time steps.
        tl.Branch(
            tl.NoOp(), tl.ShiftRight(),
            tl.Serial(
                tl.ShiftRight(),
                tl.ShiftRight(),
            ), tl.Serial(
                tl.ShiftRight(),
                tl.ShiftRight(),
                tl.ShiftRight(),
            )),
        # Concatenate them on the last axis.
        tl.Concatenate(axis=-1),  # (B, T, H, W, 4C)
        tl.Rebatch(tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'), 2),
        tl.Relu(),
        tl.Rebatch(tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'), 2),
        tl.Relu(),
        tl.Flatten(num_axis_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
        # Eventually this is shaped (B, T, output_size)
    )
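
The frame-stacking branch above can be illustrated with a small numpy-only sketch (not library code), assuming ShiftRight zero-pads at the start of the time axis: four copies of the input, shifted by 0 to 3 time steps, are concatenated on the channel axis to give the (B, T, H, W, 4C) shape noted in the comment.

import numpy as np

def shift_right(x, n):
    # Shift along the time axis (axis=1), zero-padding at the front.
    pad = ((0, 0), (n, 0)) + ((0, 0),) * (x.ndim - 2)
    return np.pad(x, pad, mode='constant')[:, :x.shape[1]]

B, T, H, W, C = 2, 6, 8, 8, 3
frames = np.random.uniform(size=(B, T, H, W, C))
stacked = np.concatenate([shift_right(frames, k) for k in range(4)], axis=-1)
assert stacked.shape == (B, T, H, W, 4 * C)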
Example #5
def WideResnet(n_blocks=3, d_hidden=64, n_output_classes=10, mode='train'):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group.
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: int, number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
    del mode
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(d_hidden, (3, 3), padding='SAME'),
        WideResnetGroup(n_blocks, d_hidden),
        WideResnetGroup(n_blocks, d_hidden * 2, (2, 2)),
        WideResnetGroup(n_blocks, d_hidden * 4, (2, 2)),
        tl.BatchNorm(),
        tl.Relu(),
        tl.AvgPool(pool_size=(8, 8)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
Example #6
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10,
               mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.
    mode: Whether we are training or evaluating.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
  del mode
  return tl.Serial(
      tl.Conv(hidden_size, (3, 3), padding='SAME'),
      WideResnetGroup(num_blocks, hidden_size),
      WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
      WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
      tl.BatchNorm(),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(num_output_classes),
      tl.LogSoftmax()
  )
Example #7
def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10, mode='train'):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group. total layers = 6n + 4.
    widen_factor: int, widening factor of each group. k=1 is vanilla resnet.
    n_output_classes: int, number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(16, (3, 3), padding='SAME'),
        WideResnetGroup(n_blocks, 16 * widen_factor, mode=mode),
        WideResnetGroup(n_blocks, 32 * widen_factor, (2, 2), mode=mode),
        WideResnetGroup(n_blocks, 64 * widen_factor, (2, 2), mode=mode),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.AvgPool(pool_size=(8, 8)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
Example #8
def AtariCnn(hidden_sizes=(32, 32), output_size=128, mode='train'):
    """An Atari CNN."""
    del mode

    # TODO(jonni): Include link to paper?
    # Input shape: (B, T, H, W, C)
    # Output shape: (B, T, output_size)
    return tl.Model(
        tl.ToFloat(),
        tl.Div(divisor=255.0),

        # Set up 4 successive game frames, concatenated on the last axis.
        tl.Dup(),
        tl.Dup(),
        tl.Dup(),
        tl.Parallel(None, _shift_right(1), _shift_right(2), _shift_right(3)),
        tl.Concatenate(n_items=4, axis=-1),  # (B, T, H, W, 4C)
        tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    )
Example #9
    def test_policy_and_value_net(self):
        observation_shape = (3, 4, 5)
        batch_observation_shape = (1, 1) + observation_shape
        n_actions = 2
        n_controls = 3
        pnv_model = ppo.policy_and_value_net(
            n_controls=n_controls,
            n_actions=n_actions,
            vocab_size=None,
            bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
            two_towers=True,
        )
        _, _ = pnv_model.initialize_once(batch_observation_shape, np.float32,
                                         self.rng_key)

        batch = 2
        time_steps = 10
        batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                                  observation_shape)
        pnv_output = pnv_model(batch_of_observations)

        # The output is a pair: action log-probabilities first, then the value output.
        self.assertEqual(2, len(pnv_output))
        self.assertEqual((batch, time_steps * n_controls, n_actions),
                         pnv_output[0].shape)
        self.assertEqual((batch, time_steps * n_controls), pnv_output[1].shape)
Example #10
    def test_policy_net(self):
        observation_shape = (3, 4)
        num_actions = 2
        policy_params, policy_apply = ppo.policy_net(
            self.rng_key,
            (-1, -1) + observation_shape,
            num_actions,
            # Flatten everything except the batch and time-step dimensions.
            [layers.Flatten(num_axis_to_keep=2)])

        # Generate a batch of observations.
        batch = 2
        time_steps = 10
        batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                                  observation_shape)

        # Apply the policy net on observations
        policy_output = policy_apply(batch_of_observations, policy_params)

        # Verify certain expectations on the output.
        self.assertEqual((batch, time_steps, num_actions), policy_output.shape)

        # The exponentials along the last axis sum to 1, since these are log-probabilities.
        sum_actions = np.sum(np.exp(policy_output), axis=-1)
        self.assertAllClose(np.ones_like(sum_actions), sum_actions)
Example #11
def policy_and_value_net(n_actions, n_controls, bottom_layers_fn, two_towers):
    """A policy and value net function."""

    # Layers.

    # From the bottom layers' output, one head computes action log-probabilities
    # and the other computes the value function.
    # NOTE: We use LogSoftmax instead of Softmax for numerical stability.

    @tl.layer()
    def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
        """Splits logits for actions in different controls and flattens controls."""
        return np.reshape(x, (x.shape[0], -1, n_actions))

    n_logits = n_controls * n_actions

    if two_towers:
        layers = [
            tl.Dup(),
            tl.Parallel(
                [
                    bottom_layers_fn(),
                    tl.Dense(n_logits),
                    FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                    tl.LogSoftmax()
                ],
                [bottom_layers_fn(),
                 tl.Dense(n_controls),
                 tl.Flatten()],
            )
        ]
    else:
        layers = [
            bottom_layers_fn(),
            tl.Dup(),
            tl.Parallel(
                [
                    tl.Dense(n_logits),
                    FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                    tl.LogSoftmax()
                ],
                [tl.Dense(n_controls), tl.Flatten()],
            )
        ]
    return tl.Model(layers)
Example #12
def common_layers():
    cur_layers = []
    if FLAGS.flatten_non_batch_time_dims:
        cur_layers = [
            layers.Div(divisor=255.0),
            layers.Flatten(num_axis_to_keep=2)
        ]
    body = [layers.Dense(64), layers.Tanh(), layers.Dense(64), layers.Tanh()]
    return cur_layers + body
Example #13
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=tl.Relu,
        num_output_classes=10,
        mode="train"):
    """Multi-layer feed-forward neural network with non-linear activations."""
    del mode
    cur_layers = [tl.Flatten()]
    for _ in range(num_hidden_layers):
        cur_layers += [tl.Dense(hidden_size), activation_fn()]
    cur_layers += [tl.Dense(num_output_classes), tl.LogSoftmax()]
    return tl.Serial(*cur_layers)
Example #14
def common_layers():
    # TODO(afrozm): Refactor.
    if "NoFrameskip" in FLAGS.env_problem_name:
        return atari_layers()

    cur_layers = []
    if FLAGS.flatten_dims:
        cur_layers = [
            layers.Div(divisor=255.0),
            layers.Flatten(num_axis_to_keep=2)
        ]
    body = [layers.Dense(64), layers.Tanh(), layers.Dense(64), layers.Tanh()]
    return cur_layers + body
Example #15
def common_layers():
    cur_layers = []
    if FLAGS.env_name == "Pong-v0":
        cur_layers = [
            layers.Div(divisor=255.0),
            layers.Flatten(num_axis_to_keep=2)
        ]
    return cur_layers + [
        layers.Dense(16),
        layers.Relu(),
        layers.Dense(4),
        layers.Relu()
    ]
Example #16
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train'):
    """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
        ConvBlock(3, [d_hidden, d_hidden, 4 * d_hidden], (1, 1), mode=mode),
        IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
        IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
        ConvBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        ConvBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        ConvBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden],
                      mode=mode),
        tl.AvgPool(pool_size=(7, 7)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
Example #17
    def test_value_net(self):
        observation_shape = (3, 4, 5)
        num_actions = 2
        value_params, value_apply = ppo.value_net(
            self.rng_key, (-1, -1) + observation_shape, num_actions,
            [layers.Flatten(num_axis_to_keep=2)])
        batch = 2
        time_steps = 10
        batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                                  observation_shape)
        value_output = value_apply(batch_of_observations, value_params)

        # NOTE: The extra trailing dimension comes from the final Dense(1).
        self.assertEqual((batch, time_steps, 1), value_output.shape)
Example #18
def MLP(n_hidden_layers=2,
        d_hidden=512,
        activation_fn=tl.Relu,
        n_output_classes=10,
        mode="train"):
    """A multi-layer feedforward (perceptron) network."""
    del mode

    return tl.Model(
        tl.Flatten(),
        [[tl.Dense(d_hidden), activation_fn()]
         for _ in range(n_hidden_layers)],
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
Example #19
def MLP(n_hidden_layers=2,
        d_hidden=512,
        activation_fn=tl.Relu,
        n_output_classes=10,
        mode="train"):
    """Multi-layer feed-forward neural network with non-linear activations."""
    del mode

    return [
        tl.Flatten(),
        [[tl.Dense(d_hidden), activation_fn()]
         for _ in range(n_hidden_layers)],
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    ]
Example #20
def model(mode):
  del mode
  return layers.Serial(
      layers.Parallel(
          layers.Flatten(),  # Observation stack.
          layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
      ),
      layers.Concatenate(),
      layers.Dense(n_units=1),
      layers.Dup(),
      layers.Parallel(
          layers.Dense(n_units=obs_shape[1]),  # New observation.
          None,  # Reward.
      )
  )
Example #21
    def test_policy_and_value_net(self):
        observation_shape = (3, 4, 5)
        batch_observation_shape = (-1, -1) + observation_shape
        num_actions = 2
        pnv_params, pnv_apply = ppo.policy_and_value_net(
            self.rng_key, batch_observation_shape, num_actions,
            [layers.Flatten(num_axis_to_keep=2)])
        batch = 2
        time_steps = 10
        batch_of_observations = np.random.uniform(size=(batch, time_steps) +
                                                  observation_shape)
        pnv_output = pnv_apply(batch_of_observations, pnv_params)

        # The output is a pair: action log-probabilities first, then the value output.
        self.assertEqual(2, len(pnv_output))
        self.assertEqual((batch, time_steps, num_actions), pnv_output[0].shape)
        self.assertEqual((batch, time_steps, 1), pnv_output[1].shape)
Example #22
    def test_collect_trajectories_max_timestep(self):
        self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)
        observation_shape = (2, 3, 4)
        num_actions = 2
        pnv_params, pnv_apply = ppo.policy_and_value_net(
            key1, (-1, -1) + observation_shape, num_actions,
            lambda: [layers.Flatten(num_axis_to_keep=2)])

        def pnv_fun(obs, rng=None):
            rng, r = jax_random.split(rng)
            lp, v = pnv_apply(obs, pnv_params, rng=r)
            return lp, v, rng

        # The environment returns done at time step #5 (counting from 0), i.e. after 6 steps.
        done_time_step = 5
        env = fake_env.FakeEnv(observation_shape,
                               num_actions,
                               done_time_step=done_time_step)

        num_trajectories = 5

        # Collect trajectories only up to `max_timestep`.
        max_timestep = 3

        # We're testing early termination of the trajectory.
        assert max_timestep < done_time_step

        trajectories = ppo.collect_trajectories(
            env,
            policy_fun=pnv_fun,
            num_trajectories=num_trajectories,
            policy="categorical-sampling",
            max_timestep=max_timestep,
            rng=key2)

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are rewards or actions.
            self.assertEqual((max_timestep, ) + observation_shape,
                             observations.shape)
            self.assertEqual((max_timestep - 1, ), actions.shape)
            self.assertEqual((max_timestep - 1, ), rewards.shape)
Example #23
  def test_policy_and_value_net(self):
    observation_shape = (3, 4, 5)
    batch_observation_shape = (1, 1) + observation_shape
    n_actions = 2
    pnv_model = ppo.policy_and_value_net(
        n_actions, lambda: [layers.Flatten(n_axes_to_keep=2)], two_towers=True)
    pnv_params, pnv_state = pnv_model.initialize(
        batch_observation_shape, np.float32, self.rng_key)

    batch = 2
    time_steps = 10
    batch_of_observations = np.random.uniform(
        size=(batch, time_steps) + observation_shape)
    pnv_output, _ = pnv_model(batch_of_observations, pnv_params, pnv_state)

    # The output is a pair: action log-probabilities first, then the value output.
    self.assertEqual(2, len(pnv_output))
    self.assertEqual((batch, time_steps, n_actions), pnv_output[0].shape)
    self.assertEqual((batch, time_steps, 1), pnv_output[1].shape)
Example #24
    def test_collect_trajectories_max_timestep(self):
        observation_shape = (2, 3, 4)
        num_actions = 2
        policy_params, policy_apply = ppo.policy_net(
            self.rng_key,
            (-1, -1) + observation_shape,
            num_actions,
            # Flatten everything except the batch and time-step dimensions.
            [layers.Flatten(num_axis_to_keep=2)])

        # The environment returns done at time step #5 (counting from 0), i.e. after 6 steps.
        done_time_step = 5
        env = fake_env.FakeEnv(observation_shape,
                               num_actions,
                               done_time_step=done_time_step)

        num_trajectories = 5

        # Collect trajectories only up to `max_timestep`.
        max_timestep = 3

        # We're testing early termination of the trajectory.
        assert max_timestep < done_time_step

        trajectories = ppo.collect_trajectories(env,
                                                policy_apply,
                                                policy_params,
                                                num_trajectories,
                                                policy="categorical-sampling",
                                                max_timestep=max_timestep)

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are rewards or actions.
            self.assertEqual((max_timestep, ) + observation_shape,
                             observations.shape)
            self.assertEqual((max_timestep - 1, ), actions.shape)
            self.assertEqual((max_timestep - 1, ), rewards.shape)
Example #25
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
    """An Atari CNN."""
    del mode

    # TODO(jonni): Include link to paper?
    # Input shape: (B, T, H, W, C)
    # Output shape: (B, T, output_size)
    return tl.Model(
        tl.ToFloat(),
        tl.Div(divisor=255.0),

        # Set up n_frames successive game frames, concatenated on the last axis.
        FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
        tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    )
Example #26
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
    """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.
    mode: whether we are training or evaluating or doing inference.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
    del mode
    return tl.Serial(
        tl.Conv(hidden_size, (7, 7), (2, 2),
                'SAME'), tl.BatchNorm(), tl.Relu(),
        tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
        ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
        IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
        ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
        tl.AvgPool(pool_size=(7, 7)), tl.Flatten(),
        tl.Dense(num_output_classes), tl.LogSoftmax())
Example #27
    def test_flatten_n(self):
        input_shape = (29, 87, 10, 20, 30)

        actual_shape = check_layer(self, layers.Flatten(), input_shape)
        self.assertEqual(actual_shape, (29, 87 * 10 * 20 * 30))

        actual_shape = check_layer(self, layers.Flatten(num_axis_to_keep=2),
                                   input_shape)
        self.assertEqual(actual_shape, (29, 87, 10 * 20 * 30))

        actual_shape = check_layer(self, layers.Flatten(num_axis_to_keep=3),
                                   input_shape)
        self.assertEqual(actual_shape, (29, 87, 10, 20 * 30))

        actual_shape = check_layer(self, layers.Flatten(num_axis_to_keep=4),
                                   input_shape)
        self.assertEqual(actual_shape, (29, 87, 10, 20, 30))

        # Not enough dimensions.
        with self.assertRaises(ValueError):
            check_layer(self, layers.Flatten(num_axis_to_keep=5), input_shape)

        with self.assertRaises(ValueError):
            check_layer(self, layers.Flatten(num_axis_to_keep=6), input_shape)
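
The semantics exercised by this test reduce to a single reshape; the following numpy-only sketch (not library code) mirrors the shapes checked above: Flatten(num_axis_to_keep=n) keeps the first n axes and collapses everything after them.

import numpy as np

def flatten_keep(x, num_axis_to_keep=1):
    # Keep the first `num_axis_to_keep` axes, collapse the rest into one.
    return np.reshape(x, x.shape[:num_axis_to_keep] + (-1,))

x = np.zeros((29, 87, 10, 20, 30))
assert flatten_keep(x, 1).shape == (29, 87 * 10 * 20 * 30)
assert flatten_keep(x, 2).shape == (29, 87, 10 * 20 * 30)
assert flatten_keep(x, 3).shape == (29, 87, 10, 20 * 30)
assert flatten_keep(x, 4).shape == (29, 87, 10, 20, 30)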
Example #28
    def test_combined_loss(self):
        self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)

        B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
        batch_observation_shape = (1, 1) + OBS

        net = ppo.policy_and_value_net(
            A, lambda: [layers.Flatten(n_axes_to_keep=2)], two_towers=True)

        old_params, _ = net.initialize(batch_observation_shape, np.float32,
                                       key1)
        new_params, state = net.initialize(batch_observation_shape, np.float32,
                                           key2)

        # Generate a batch of observations.

        observations = np.random.uniform(size=(B, T + 1) + OBS)
        actions = np.random.randint(0, A, size=(B, T))
        rewards = np.random.uniform(0, 1, size=(B, T))
        mask = np.ones_like(rewards)

        # Just test that this computes at all.
        (new_log_probabs,
         value_predictions_new), _ = net(observations, new_params, state)
        (old_log_probabs,
         value_predictions_old), _ = net(observations, old_params, state)

        gamma = 0.99
        lambda_ = 0.95
        epsilon = 0.2
        c1 = 1.0
        c2 = 0.01

        (value_loss_1, _) = ppo.value_loss_given_predictions(
            value_predictions_new,
            rewards,
            mask,
            gamma=gamma,
            value_prediction_old=value_predictions_old,
            epsilon=epsilon)
        (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(new_log_probabs,
                                                         old_log_probabs,
                                                         value_predictions_old,
                                                         actions,
                                                         rewards,
                                                         mask,
                                                         gamma=gamma,
                                                         lambda_=lambda_,
                                                         epsilon=epsilon)

        (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _,
         state) = (ppo.combined_loss(new_params,
                                     old_log_probabs,
                                     value_predictions_old,
                                     net,
                                     observations,
                                     actions,
                                     rewards,
                                     mask,
                                     gamma=gamma,
                                     lambda_=lambda_,
                                     epsilon=epsilon,
                                     c1=c1,
                                     c2=c2,
                                     state=state))

        # Test that these compute at all and are self-consistent.
        self.assertGreater(entropy_bonus, 0.0)
        self.assertNear(value_loss_1, value_loss_2, 1e-6)
        self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
        self.assertNear(
            combined_loss,
            ppo_loss_2 + (c1 * value_loss_2) - (c2 * entropy_bonus), 1e-6)
Example #29
    def test_combined_loss(self):
        self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)

        B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
        batch_observation_shape = (-1, -1) + OBS

        old_params, _ = ppo.policy_and_value_net(
            key1, batch_observation_shape, A,
            [layers.Flatten(num_axis_to_keep=2)])

        new_params, net_apply = ppo.policy_and_value_net(
            key2, batch_observation_shape, A,
            [layers.Flatten(num_axis_to_keep=2)])

        # Generate a batch of observations.

        observations = np.random.uniform(size=(B, T + 1) + OBS)
        actions = np.random.randint(0, A, size=(B, T))
        rewards = np.random.uniform(0, 1, size=(B, T))
        mask = np.ones_like(rewards)

        # Just test that this computes at all.
        new_log_probabs, _ = net_apply(observations, new_params)
        old_log_probabs, value_predictions = net_apply(observations,
                                                       old_params)

        gamma = 0.99
        lambda_ = 0.95
        epsilon = 0.2
        c1 = 1.0
        c2 = 0.01

        value_loss_1 = ppo.value_loss_given_predictions(value_predictions,
                                                        rewards,
                                                        mask,
                                                        gamma=gamma)
        ppo_loss_1 = ppo.ppo_loss_given_predictions(new_log_probabs,
                                                    old_log_probabs,
                                                    value_predictions,
                                                    actions,
                                                    rewards,
                                                    mask,
                                                    gamma=gamma,
                                                    lambda_=lambda_,
                                                    epsilon=epsilon)

        (combined_loss, ppo_loss_2, value_loss_2,
         entropy_bonus) = (ppo.combined_loss(new_params,
                                             old_params,
                                             net_apply,
                                             observations,
                                             actions,
                                             rewards,
                                             mask,
                                             gamma=gamma,
                                             lambda_=lambda_,
                                             epsilon=epsilon,
                                             c1=c1,
                                             c2=c2))

        # Test that these compute at all and are self-consistent.
        self.assertEqual(0.0, entropy_bonus)
        self.assertNear(value_loss_1, value_loss_2, 1e-6)
        self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
        self.assertNear(combined_loss, ppo_loss_2 + (c1 * value_loss_2), 1e-6)
Example #30
    def test_collect_trajectories(self):
        self.rng_key, key1, key2, key3, key4 = jax_random.split(self.rng_key,
                                                                num=5)
        observation_shape = (2, 3, 4)
        num_actions = 2
        policy_params, policy_apply = ppo.policy_net(
            key1,
            (-1, -1) + observation_shape,
            num_actions,
            # Flatten everything except the batch and time-step dimensions.
            [layers.Flatten(num_axis_to_keep=2)])

        # The environment returns done at time step #5 (counting from 0), i.e. after 6 steps.
        done_time_step = 5
        env = fake_env.FakeEnv(observation_shape,
                               num_actions,
                               done_time_step=done_time_step)

        def policy_fun(obs, rng=None):
            rng, r = jax_random.split(rng)
            return policy_apply(obs, policy_params, rng=r), (), rng

        num_trajectories = 5
        trajectories = ppo.collect_trajectories(
            env,
            policy_fun=policy_fun,
            num_trajectories=num_trajectories,
            policy="categorical-sampling",
            rng=key2)

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are rewards or actions.
            self.assertEqual((done_time_step + 2, ) + observation_shape,
                             observations.shape)
            self.assertEqual((done_time_step + 1, ), actions.shape)
            self.assertEqual((done_time_step + 1, ), rewards.shape)

        # Test collect using a Policy and Value function.
        pnv_params, pnv_apply = ppo.policy_and_value_net(
            key3, (-1, -1) + observation_shape, num_actions,
            lambda: [layers.Flatten(num_axis_to_keep=2)])

        def pnv_fun(obs, rng=None):
            rng, r = jax_random.split(rng)
            lp, v = pnv_apply(obs, pnv_params, rng=r)
            return lp, v, rng

        trajectories = ppo.collect_trajectories(
            env,
            policy_fun=pnv_fun,
            num_trajectories=num_trajectories,
            policy="categorical-sampling",
            rng=key4)

        # Number of trajectories is as expected.
        self.assertEqual(num_trajectories, len(trajectories))

        # Shapes of observations, actions and rewards are as expected.
        for observations, actions, rewards in trajectories:
            # There is one more observation than there are rewards or actions.
            self.assertEqual((done_time_step + 2, ) + observation_shape,
                             observations.shape)
            self.assertEqual((done_time_step + 1, ), actions.shape)
            self.assertEqual((done_time_step + 1, ), rewards.shape)