예제 #1
0
    def __init__(
        self,
        model,
        reward_fn,
        done_fn,
        vocab_size,
        max_trajectory_length,
        observation_space,
        action_space,
        significance_decay=1.0,
        **kwargs
    ):
        """Sets up a simulated env that operates on serialized sequences.

        Args:
          model: TRAX model used for simulation. It is assumed to accept the
            keyword arguments vocab_size and mode, where vocab_size is the
            number of symbols in the vocabulary and mode is either "train"
            or "eval".
          reward_fn: Function (previous_observation, current_observation)
            -> reward.
          done_fn: Function (previous_observation, current_observation)
            -> done.
          vocab_size: (int) Number of symbols in the vocabulary.
          max_trajectory_length: (int) Maximum length of a trajectory
            unrolled from the model.
          observation_space: (gym.Space) Observation space.
          action_space: (gym.Space) Action space.
          significance_decay: (float) Decay for training weights of
            progressively less significant symbols in the representation.
          **kwargs: (dict) Keyword arguments forwarded to the base class.
        """
        # Caller-supplied callbacks and hyperparameters.
        self._reward_fn = reward_fn
        self._done_fn = done_fn
        self._vocab_size = vocab_size
        self._max_trajectory_length = max_trajectory_length
        self._significance_decay = significance_decay

        # Placeholders; nothing is allocated for them in the constructor.
        self._history = None
        self._steps = None
        self._observation_space = None
        self._action_space = None
        self._last_observations = None

        # Serializers for both spaces, sharing one vocabulary size.
        obs_serializer = space_serializer.create(observation_space,
                                                 self._vocab_size)
        self._obs_serializer = obs_serializer
        act_serializer = space_serializer.create(action_space,
                                                 self._vocab_size)
        self._action_serializer = act_serializer

        # One step is an observation followed by an action, so its length in
        # symbols is the sum of the two representation lengths.
        obs_length = obs_serializer.representation_length
        self._obs_repr_length = obs_length
        act_length = act_serializer.representation_length
        self._action_repr_length = act_length
        self._step_repr_length = obs_length + act_length

        # The model is assumed to take vocab_size as an argument (e.g.
        # TransformerLM), so bind it before handing the model to the base.
        vocab_bound_model = functools.partial(model, vocab_size=vocab_size)
        super(SerializedSequenceSimulatedEnvProblem, self).__init__(
            model=vocab_bound_model,
            observation_space=observation_space,
            action_space=action_space,
            **kwargs)
 def _init_serialization(self, vocab_size):
     """Builds the serializers and the total representation length.

     Args:
       vocab_size: Vocabulary size passed to both space serializers.

     Returns:
       Dict with keys "observation_serializer", "action_serializer" and
       "representation_length". The length is the combined observation and
       action representation length multiplied by (self._max_timestep + 1).
     """
     observation_serializer = space_serializer.create(
         self.train_env.observation_space, vocab_size=vocab_size)
     action_serializer = space_serializer.create(
         self.train_env.action_space, vocab_size=vocab_size)

     symbols_per_step = (observation_serializer.representation_length +
                         action_serializer.representation_length)
     n_steps = self._max_timestep + 1

     return {
         "observation_serializer": observation_serializer,
         "action_serializer": action_serializer,
         "representation_length": symbols_per_step * n_steps,
     }
 def test_significance_map(self):
   """Checks significance_map for a Box observation and MultiDiscrete action."""
   gin.bind_parameter("BoxSpaceSerializer.precision", 3)

   obs_space = gym.spaces.Box(low=0, high=1, shape=(2,))
   obs_serializer = space_serializer.create(obs_space, vocab_size=2)
   act_space = gym.spaces.MultiDiscrete(nvec=[2, 2])
   act_serializer = space_serializer.create(act_space, vocab_size=2)

   actual = serialization_utils.significance_map(
       observation_serializer=obs_serializer,
       action_serializer=act_serializer,
       representation_length=20,
   )

   # Layout: obs1, act1, obs2, act2, then obs3 cut after its 4th symbol.
   expected = [
       0, 1, 2, 0, 1, 2, 0, 0,
       0, 1, 2, 0, 1, 2, 0, 0,
       0, 1, 2, 0,
   ]
   np.testing.assert_array_equal(actual, expected)
예제 #4
0
 def _make_space_and_serializer(self, low=-10, high=10, shape=(2,)):
   """Creates a Box space and a serializer for it.

   Args:
     low: Lower bound passed to gym.spaces.Box.
     high: Upper bound passed to gym.spaces.Box.
     shape: Shape passed to gym.spaces.Box.

   Returns:
     Tuple (space, serializer).
   """
   # A precision of 4 is intended to represent float32s accurately.
   gin.bind_parameter("BoxSpaceSerializer.precision", 4)
   box_space = gym.spaces.Box(low=low, high=high, shape=shape)
   # 257 is deliberately not a power of 2, to check that the serializer
   # does not only work with power-of-2 vocabularies.
   box_serializer = space_serializer.create(box_space, vocab_size=257)
   return box_space, box_serializer
예제 #5
0
 def setUp(self):
     """Prepares a Box space and its serializer for each test."""
     super(BoxSpaceSerializerTest, self).setUp()
     # A precision of 4 is intended to represent float32s accurately.
     gin.bind_parameter("BoxSpaceSerializer.precision", 4)
     box_space = gym.spaces.Box(low=-10, high=10, shape=(2, ))
     self._space = box_space
     # 257 is deliberately not a power of 2, to check that the serializer
     # does not only work with power-of-2 vocabularies.
     self._serializer = space_serializer.create(box_space, vocab_size=257)
 def test_rewards_to_actions_map(self):
   """Checks that rewards are broadcast through rewards_to_actions_map."""
   rewards = np.array([1, 2, 3])

   obs_space = gym.spaces.MultiDiscrete(nvec=[2, 2, 2])
   obs_serializer = space_serializer.create(obs_space, vocab_size=2)
   act_space = gym.spaces.MultiDiscrete(nvec=[2, 2])
   act_serializer = space_serializer.create(act_space, vocab_size=2)

   r2a_map = serialization_utils.rewards_to_actions_map(
       observation_serializer=obs_serializer,
       action_serializer=act_serializer,
       n_timesteps=len(rewards),
       representation_length=16,
   )

   # Expected layout, per reward: three zeros followed by the reward repeated
   # twice. The 16th and last position stays at zero.
   expected = [
       0, 0, 0, 1, 1,
       0, 0, 0, 2, 2,
       0, 0, 0, 3, 3,
       0,
   ]
   np.testing.assert_array_equal(np.dot(rewards, r2a_map), expected)
 def initialize_environments(self, batch_size=1, **kwargs):
   """Rebuilds serializers and per-batch buffers, then defers to the base.

   Args:
     batch_size: Number of environments in the batch; sizes the leading
       dimension of the history, step and last-observation arrays.
     **kwargs: Keyword arguments forwarded to the base class method.

   Returns:
     Whatever the base class initialize_environments returns.
   """
   vocab_size = self._vocab_size
   obs_serializer = space_serializer.create(self.observation_space, vocab_size)
   self._obs_serializer = obs_serializer
   act_serializer = space_serializer.create(self.action_space, self._vocab_size)
   self._action_serializer = act_serializer

   obs_length = obs_serializer.representation_length
   self._obs_repr_length = obs_length
   act_length = act_serializer.representation_length
   self._action_repr_length = act_length
   step_length = obs_length + act_length
   self._step_repr_length = step_length

   # One row of symbols per environment, long enough for a full trajectory.
   history_shape = (batch_size, self._max_trajectory_length * step_length)
   self._history = np.zeros(history_shape, dtype=np.int32)
   self._steps = np.zeros(batch_size, dtype=np.int32)

   # Last observations start out filled with NaN.
   last_obs_shape = (batch_size,) + self._observation_space.shape
   self._last_observations = np.full(last_obs_shape, np.nan)

   base = super(SerializedSequenceSimulatedEnvProblem, self)
   return base.initialize_environments(batch_size=batch_size, **kwargs)
 def setUp(self):
   """Prepares a shared serializer and the kwargs used by the tests."""
   super(SerializationTest, self).setUp()
   binary_space = gym.spaces.Discrete(2)
   self._serializer = space_serializer.create(binary_space, vocab_size=2)
   self._repr_length = 100
   # The same serializer is used for both observations and actions.
   self._serialization_utils_kwargs = dict(
       observation_serializer=self._serializer,
       action_serializer=self._serializer,
       representation_length=self._repr_length,
   )
예제 #9
0
 def setUp(self):
   """Prepares a two-valued Discrete space and its serializer."""
   super(DiscreteSpaceSerializerTest, self).setUp()
   discrete_space = gym.spaces.Discrete(n=2)
   self._space = discrete_space
   self._serializer = space_serializer.create(discrete_space, vocab_size=2)
 def setUp(self):
     """Prepares a MultiDiscrete space with nvec [2, 2] and its serializer."""
     super(MultiDiscreteSpaceSerializerTest, self).setUp()
     multi_space = gym.spaces.MultiDiscrete(nvec=[2, 2])
     self._space = multi_space
     self._serializer = space_serializer.create(multi_space, vocab_size=2)