Example #1
    def test_extract_inner_model(self):
        vocab_size = 3

        inner_model = transformer.TransformerLM(vocab_size=vocab_size,
                                                d_model=2,
                                                d_ff=2,
                                                n_layers=0)
        obs_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            inner_model,
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        obs_sig = shapes.ShapeDtype((1, 2))
        act_sig = shapes.ShapeDtype((1, 1))
        (weights, state) = serialized_model.init(
            input_signature=(obs_sig, act_sig, obs_sig, obs_sig))
        (inner_weights,
         inner_state) = map(serialization_utils.extract_inner_model,
                            (weights, state))
        inner_model(np.array([[0]]), weights=inner_weights, state=inner_state)
Example #2
    def test_training_loop_cartpole_serialized_init_from_world_model(
            self, two_towers):
        gin.bind_parameter('BoxSpaceSerializer.precision', 1)

        transformer_kwargs = {
            'd_model': 1,
            'd_ff': 1,
            'n_layers': 1,
            'n_heads': 1,
            'max_len': 128,
        }
        obs_serializer = space_serializer.create(gym.spaces.MultiDiscrete(
            [2, 2]),
                                                 vocab_size=4)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=4)
        model_fn = lambda mode: serialization_utils.SerializedModel(  # pylint: disable=g-long-lambda
            seq_model=models.TransformerLM(
                mode=mode, vocab_size=4, **transformer_kwargs),
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )
        with self.tmp_dir() as output_dir:
            model_dir = os.path.join(output_dir, 'model')

            def dummy_stream(_):
                while True:
                    obs = np.zeros((1, 2, 2), dtype=np.int32)
                    act = np.zeros((1, 1), dtype=np.int32)
                    mask = np.ones_like(obs)
                    yield (obs, act, obs, mask)

            inputs = trax_inputs.Inputs(train_stream=dummy_stream,
                                        eval_stream=dummy_stream)
            inputs._input_shape = ((2, 2), (1, ))  # pylint: disable=protected-access
            inputs._input_dtype = (np.int32, np.int32)  # pylint: disable=protected-access

            # Initialize a world model checkpoint by running the trainer.
            trainer_lib.train(
                model_dir,
                model=model_fn,
                inputs=inputs,
                steps=1,
                eval_steps=1,
                has_weights=True,
            )

            policy_dir = os.path.join(output_dir, 'policy')
            trainer = self._make_trainer(
                train_env=self.get_wrapped_env('CartPole-v0', 2),
                eval_env=self.get_wrapped_env('CartPole-v0', 2),
                output_dir=policy_dir,
                model=functools.partial(models.TransformerDecoder,
                                        **transformer_kwargs),
                policy_and_value_vocab_size=4,
                init_policy_from_world_model_output_dir=model_dir,
                policy_and_value_two_towers=two_towers,
            )
            trainer.training_loop(n_epochs=2)
Example #3
    def test_serialized_model_continuous(self):
        precision = 3
        gin.bind_parameter('BoxSpaceSerializer.precision', precision)

        vocab_size = 32
        obs = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
        act = np.array([[0, 1, 0]])
        mask = np.array([[1, 1, 1, 0]])

        obs_serializer = space_serializer.create(gym.spaces.Box(shape=(2, ),
                                                                low=-2,
                                                                high=2),
                                                 vocab_size=vocab_size)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        example = (obs, act, obs, mask)
        serialized_model.init(shapes.signature(example))

        (obs_logits, obs_repr, weights) = serialized_model(example)
        self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size, ))
        self.assertEqual(obs_repr.shape,
                         (1, obs.shape[1], obs.shape[2] * precision))
        self.assertEqual(obs_repr.shape, weights.shape)
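The shape assertions above follow directly from the serializer configuration; the following plain-Python sketch (values taken from this test, no trax dependencies) spells out the arithmetic.

# Shape arithmetic behind the assertions above (values from this test).
batch_size, n_timesteps, obs_dim = 1, 4, 2   # obs has shape (1, 4, 2)
precision, vocab_size = 3, 32

# Each Box component is encoded with `precision` symbols:
obs_repr_shape = (batch_size, n_timesteps, obs_dim * precision)   # (1, 4, 6)
# Logits add a trailing vocabulary dimension:
obs_logits_shape = obs_repr_shape + (vocab_size,)                 # (1, 4, 6, 32)
assert obs_repr_shape == (1, 4, 6)
assert obs_logits_shape == (1, 4, 6, 32)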
Example #4
    def test_serialized_model_extracts_seq_model_weights_and_state(self):
        vocab_size = 3

        seq_model_fn = functools.partial(
            transformer.TransformerLM,
            vocab_size=vocab_size,
            d_model=2,
            d_ff=2,
            n_layers=0,
        )
        seq_model = seq_model_fn(mode='eval')
        obs_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            seq_model_fn,
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        obs_sig = shapes.ShapeDtype((1, 2))
        act_sig = shapes.ShapeDtype((1, 1))
        serialized_model.init(input_signature=(obs_sig, act_sig, obs_sig,
                                               obs_sig))
        seq_model.weights = serialized_model.seq_model_weights
        seq_model.state = serialized_model.seq_model_state
        # Run the model to check if the weights and state have correct structure.
        seq_model(jnp.array([[0]]))
Example #5
    def test_serialized_model_discrete(self):
        vocab_size = 3
        obs = np.array([[[0, 1], [1, 1], [1, 0], [0, 0]]])
        act = np.array([[1, 0, 0]])
        mask = np.array([[1, 1, 1, 0]])

        test_model_inputs = []

        # pylint: disable=invalid-name
        def TestModelSavingInputs():
            def f(inputs):
                # Save the inputs for a later check.
                test_model_inputs.append(inputs)
                # Change type to np.float32 and add the logit dimension.
                return jnp.broadcast_to(
                    inputs.astype(np.float32)[:, :, None],
                    inputs.shape + (vocab_size, ))

            return layers_base.Fn('TestModelSavingInputs', f)
        # pylint: enable=invalid-name

        obs_serializer = space_serializer.create(gym.spaces.MultiDiscrete(
            [2, 2]),
                                                 vocab_size=vocab_size)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            TestModelSavingInputs(),  # pylint: disable=no-value-for-parameter
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        example = (obs, act, obs, mask)
        serialized_model.init(shapes.signature(example))

        (obs_logits, obs_repr, weights) = serialized_model(example)
        # Check that the model has been called with the correct input.
        np.testing.assert_array_equal(
            # The model is called multiple times for determining shapes etc.
            # Check the last saved input - that should be the actual concrete array
            # calculated during the forward pass.
            test_model_inputs[-1],
            # Should be serialized observations and actions interleaved.
            [[0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]],
        )
        # Check the output shape.
        self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size, ))
        # Check that obs_logits are the same as obs_repr, just broadcasted over the
        # logit dimension.
        np.testing.assert_array_equal(np.min(obs_logits, axis=-1), obs_repr)
        np.testing.assert_array_equal(np.max(obs_logits, axis=-1), obs_repr)
        # Check that the observations are correct.
        np.testing.assert_array_equal(obs_repr, obs)
        # Check weights.
        np.testing.assert_array_equal(
            weights,
            [[[1., 1.], [1., 1.], [1., 1.], [0., 0.]]],
        )
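The interleaving checked above can be reproduced without running the model. In this test vocab_size=3 is large enough that each MultiDiscrete([2, 2]) component and each Discrete(2) action serializes to a single symbol (the expected array has 4 * 2 + 3 = 11 entries), so a plain-Python sketch of the weave is:

obs = [[0, 1], [1, 1], [1, 0], [0, 0]]
act = [1, 0, 0]
interleaved = []
for timestep, action in enumerate(act):
    interleaved += obs[timestep] + [action]   # obs symbols, then action symbol
interleaved += obs[-1]                        # last observation has no action
assert interleaved == [0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]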
Example #6
    def __init__(self,
                 model,
                 reward_fn,
                 done_fn,
                 vocab_size,
                 max_trajectory_length,
                 observation_space,
                 action_space,
                 significance_decay=1.0,
                 **kwargs):
        """Initializes the env.

    Args:
      model: trax model to use for simulation. It's assumed to take keyword
        arguments vocab_size and mode, where vocab_size is the number of symbols
        in the vocabulary and mode is either 'train' or 'eval'.
      reward_fn: Function (previous_observation, current_observation) -> reward.
      done_fn: Function (previous_observation, current_observation) -> done.
      vocab_size: (int) Number of symbols in the vocabulary.
      max_trajectory_length: (int) Maximum length of a trajectory unrolled from
        the model.
      observation_space: (gym.Space) Observation space.
      action_space: (gym.Space) Action space.
      significance_decay: (float) Decay for training weights of progressively
        less significant symbols in the representation.
      **kwargs: (dict) Keyword arguments passed to the base class.
    """
        self._reward_fn = reward_fn
        self._done_fn = done_fn
        self._vocab_size = vocab_size
        self._max_trajectory_length = max_trajectory_length
        self._significance_decay = significance_decay
        self._steps = None
        self._observation_space = None
        self._action_space = None
        self._last_observations = None

        self._obs_serializer = space_serializer.create(observation_space,
                                                       self._vocab_size)
        self._action_serializer = space_serializer.create(
            action_space, self._vocab_size)
        self._obs_repr_length = self._obs_serializer.representation_length
        self._act_repr_length = self._action_serializer.representation_length
        self._step_repr_length = self._obs_repr_length + self._act_repr_length

        # We assume that the model takes vocab_size as an argument (e.g.
        # TransformerLM).
        model = functools.partial(model, vocab_size=vocab_size)
        super(SerializedSequenceSimulatedEnvProblem,
              self).__init__(model=model,
                             observation_space=observation_space,
                             action_space=action_space,
                             **kwargs)
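A rough usage sketch for the constructor documented above. The reward_fn/done_fn lambdas are illustrative placeholders, and whatever arguments the base class requires (forwarded through **kwargs) are deliberately omitted, so treat this as the shape of the call rather than a runnable configuration.

import gym
import numpy as np
from trax.models import transformer

# Sketch only; assumes SerializedSequenceSimulatedEnvProblem is imported from
# its defining module and that the base-class kwargs are supplied separately.
env = SerializedSequenceSimulatedEnvProblem(
    model=transformer.TransformerLM,    # must accept vocab_size and mode
    reward_fn=lambda prev_obs, obs: np.zeros(len(obs)),        # placeholder
    done_fn=lambda prev_obs, obs: np.full(len(obs), False),    # placeholder
    vocab_size=4,
    max_trajectory_length=8,
    observation_space=gym.spaces.Discrete(2),
    action_space=gym.spaces.Discrete(2),
    significance_decay=0.9,
)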
Example #7
 def _init_serialization(self, vocab_size):
     obs_serializer = space_serializer.create(
         self.train_env.observation_space, vocab_size=vocab_size)
     act_serializer = space_serializer.create(self.train_env.action_space,
                                              vocab_size=vocab_size)
     repr_length = (obs_serializer.representation_length +
                    act_serializer.representation_length) * (
                        self._max_timestep + 1)
     return {
         "observation_serializer": obs_serializer,
         "action_serializer": act_serializer,
         "representation_length": repr_length,
     }
Example #8
File: ppo.py Project: koz4k2/trax
def init_serialization(vocab_size, observation_space, action_space,
                       n_timesteps):
    """Initializes serialization keyword arguments."""
    obs_serializer = space_serializer.create(observation_space,
                                             vocab_size=vocab_size)
    act_serializer = space_serializer.create(action_space,
                                             vocab_size=vocab_size)
    repr_length = (obs_serializer.representation_length +
                   act_serializer.representation_length) * n_timesteps
    return {
        'observation_serializer': obs_serializer,
        'action_serializer': act_serializer,
        'representation_length': repr_length,
    }
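The returned dict is meant to be splatted into serialization_utils helpers that take a serializer pair plus a representation length, e.g. significance_map in Example #9 below. A short usage sketch with illustrative space and size choices (assumes the surrounding imports of gym and serialization_utils; Box precision is configured via gin elsewhere):

serialization_kwargs = init_serialization(
    vocab_size=4,
    observation_space=gym.spaces.Box(low=-1, high=1, shape=(2,)),
    action_space=gym.spaces.Discrete(2),
    n_timesteps=16,
)
significance = serialization_utils.significance_map(**serialization_kwargs)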
Example #9
 def test_significance_map(self):
     gin.bind_parameter('BoxSpaceSerializer.precision', 3)
     significance_map = serialization_utils.significance_map(
         observation_serializer=space_serializer.create(gym.spaces.Box(
             low=0, high=1, shape=(2, )),
                                                        vocab_size=2),
         action_serializer=space_serializer.create(
             gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2),
         representation_length=20,
     )
     np.testing.assert_array_equal(
         significance_map,
         # obs1, act1, obs2, act2, obs3 cut after 4th symbol.
         [0, 1, 2, 0, 1, 2, 0, 0, 0, 1, 2, 0, 1, 2, 0, 0, 0, 1, 2, 0],
     )
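The expected array can be reconstructed by hand: with precision 3 each of the two Box components contributes symbols of significance 0, 1, 2, each MultiDiscrete action component contributes one symbol of significance 0, and the per-timestep pattern is tiled and cut at representation_length. A plain-Python sketch:

obs_significances = [0, 1, 2] * 2   # 2 Box components, precision 3 each
act_significances = [0] * 2         # 2 MultiDiscrete components, 1 symbol each
timestep = obs_significances + act_significances   # 8 symbols per timestep
expected = (timestep * 3)[:20]                     # tile, cut to length 20
assert expected == [0, 1, 2, 0, 1, 2, 0, 0,
                    0, 1, 2, 0, 1, 2, 0, 0,
                    0, 1, 2, 0]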
Example #10
 def test_rewards_to_actions_map(self):
     rewards = onp.array([1, 2, 3])
     r2a_map = serialization_utils.rewards_to_actions_map(
         observation_serializer=space_serializer.create(
             gym.spaces.MultiDiscrete(nvec=[2, 2, 2]), vocab_size=2),
         action_serializer=space_serializer.create(
             gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2),
         n_timesteps=len(rewards),
         representation_length=16,
     )
     broadcast_rewards = onp.dot(rewards, r2a_map)
     onp.testing.assert_array_equal(
         broadcast_rewards,
         # obs1, act1, obs2, act2, obs3 cut after 1st symbol.
         [0, 0, 0, 1, 1, 0, 0, 0, 2, 2, 0, 0, 0, 3, 3, 0],
     )
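The broadcast checked above has a simple structure: each reward is copied onto the action symbols of its own timestep and zeroed over observation symbols, with the truncated final observation receiving nothing. A plain-Python sketch reproducing the expected array:

rewards = [1, 2, 3]
obs_len, act_len = 3, 2   # symbols per observation / per action at vocab_size=2
expected = []
for reward in rewards:
    expected += [0] * obs_len + [reward] * act_len
expected = (expected + [0] * 16)[:16]   # pad for the cut-off final observation
assert expected == [0, 0, 0, 1, 1, 0, 0, 0, 2, 2, 0, 0, 0, 3, 3, 0]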
Example #11
def make_serialized_model(seq_model, space, vocab_size):
    srl = space_serializer.create(space, vocab_size)
    return serialization_utils.SerializedModel(
        functools.partial(seq_model, vocab_size=vocab_size),
        observation_serializer=srl,
        action_serializer=srl,
        significance_decay=0.7,
    )
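A hypothetical call to the helper above, mirroring the TransformerLM settings used elsewhere on this page; the hyperparameters are illustrative, not prescribed by the helper.

import functools
import gym
from trax import models

serialized_model = make_serialized_model(
    seq_model=functools.partial(models.TransformerLM,
                                d_model=2, d_ff=2, n_layers=0),
    space=gym.spaces.Discrete(2),
    vocab_size=4,
)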
Example #12
 def setUp(self):
     super(SerializationTest, self).setUp()
     self._serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                vocab_size=2)
     self._repr_length = 100
     self._serialization_utils_kwargs = {
         'observation_serializer': self._serializer,
         'action_serializer': self._serializer,
         'representation_length': self._repr_length,
     }
Example #13
 def setUp(self):
     super().setUp()
     self._serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                vocab_size=2)
     self._repr_length = 100
     self._serialization_utils_kwargs = {
         'observation_serializer': self._serializer,
         'action_serializer': self._serializer,
         'representation_length': self._repr_length,
     }
     test_utils.ensure_flag('test_tmpdir')
Example #14
def wrap_policy(seq_model, observation_space, action_space, vocab_size):  # pylint: disable=invalid-name
  """Wraps a sequence model in either RawPolicy or SerializedPolicy.

  Args:
    seq_model: Trax sequence model.
    observation_space: Gym observation space.
    action_space: Gym action space.
    vocab_size: Either the number of symbols for a serialized policy, or None.

  Returns:
    RawPolicy if vocab_size is None, else SerializedPolicy.
  """
  (n_controls, n_actions) = analyze_action_space(action_space)
  if vocab_size is None:
    policy_wrapper = RawPolicy
  else:
    obs_serializer = space_serializer.create(observation_space, vocab_size)
    act_serializer = space_serializer.create(action_space, vocab_size)
    policy_wrapper = functools.partial(SerializedPolicy,
                                       observation_serializer=obs_serializer,
                                       action_serializer=act_serializer)
  return policy_wrapper(seq_model, n_controls, n_actions)
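A hedged illustration of how wrap_policy might be invoked. Whether seq_model should be a constructed layer or a model factory depends on the trax version this snippet comes from, so the TransformerDecoder factory below (hyperparameters borrowed from Example #2) is only a plausible stand-in.

import functools
import gym
from trax import models

# vocab_size=4 selects the SerializedPolicy branch; vocab_size=None would
# select RawPolicy instead.
policy = wrap_policy(
    seq_model=functools.partial(
        models.TransformerDecoder, d_model=1, d_ff=1, n_layers=1, n_heads=1),
    observation_space=gym.spaces.Box(low=-1, high=1, shape=(4,)),
    action_space=gym.spaces.Discrete(2),
    vocab_size=4,
)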
Example #15
 def _make_space_and_serializer(
     self,
     low=-10,
     high=10,
     shape=(2, ),
     # Weird vocab_size to test that it doesn't only work with powers of 2.
     vocab_size=257,
     # Enough precision to represent float32s accurately.
     precision=4,
 ):
     gin.bind_parameter("BoxSpaceSerializer.precision", precision)
     space = gym.spaces.Box(low=low, high=high, shape=shape)
     serializer = space_serializer.create(space, vocab_size=vocab_size)
     return (space, serializer)
Example #16
    def test_serialized_model_continuous(self):
        precision = 3
        gin.bind_parameter('BoxSpaceSerializer.precision', precision)

        vocab_size = 32
        obs = onp.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
        act = onp.array([[0, 1, 0]])
        mask = onp.array([[1, 1, 1, 0]])

        @layers_base.layer()
        def TestModel(inputs, **unused_kwargs):
            # Change type to onp.float32 and add the logit dimension.
            return np.broadcast_to(
                inputs.astype(onp.float32)[:, :, None],
                inputs.shape + (vocab_size, ))

        obs_serializer = space_serializer.create(gym.spaces.Box(shape=(2, ),
                                                                low=-2,
                                                                high=2),
                                                 vocab_size=vocab_size)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            TestModel(),  # pylint: disable=no-value-for-parameter
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        example = (obs, act, obs, mask)
        serialized_model.init(shapes.signature(example))
        (obs_logits, obs_repr, weights) = serialized_model(example)
        self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size, ))
        self.assertEqual(obs_repr.shape,
                         (1, obs.shape[1], obs.shape[2] * precision))
        self.assertEqual(obs_repr.shape, weights.shape)
Example #17
 def setUp(self):
     super(MultiDiscreteSpaceSerializerTest, self).setUp()
     self._space = gym.spaces.MultiDiscrete(nvec=[2, 2])
     self._serializer = space_serializer.create(self._space, vocab_size=2)
Example #18
 def setUp(self):
     super(DiscreteSpaceSerializerTest, self).setUp()
     self._space = gym.spaces.Discrete(n=2)
     self._serializer = space_serializer.create(self._space, vocab_size=2)