def __init__(
    self,
    model,
    reward_fn,
    done_fn,
    vocab_size,
    max_trajectory_length,
    observation_space,
    action_space,
    significance_decay=1.0,
    **kwargs
):
    """Sets up the serialized-sequence simulated environment.

    Args:
      model: TRAX model used for simulation. It is expected to accept the
        keyword arguments vocab_size (number of symbols in the vocabulary)
        and mode (either "train" or "eval").
      reward_fn: Callable (previous_observation, current_observation) ->
        reward.
      done_fn: Callable (previous_observation, current_observation) -> done.
      vocab_size: (int) Number of symbols in the vocabulary.
      max_trajectory_length: (int) Maximum length of a trajectory unrolled
        from the model.
      observation_space: (gym.Space) Observation space.
      action_space: (gym.Space) Action space.
      significance_decay: (float) Decay for training weights of
        progressively less significant symbols in the representation.
      **kwargs: (dict) Extra keyword arguments forwarded to the base class.
    """
    # Configuration taken verbatim from the arguments.
    self._reward_fn = reward_fn
    self._done_fn = done_fn
    self._vocab_size = vocab_size
    self._max_trajectory_length = max_trajectory_length
    self._significance_decay = significance_decay

    # State placeholders; nothing is allocated in the constructor.
    self._history = None
    self._steps = None
    self._observation_space = None
    self._action_space = None
    self._last_observations = None

    # One serializer per space, both sharing the same vocabulary size.
    self._obs_serializer = space_serializer.create(
        observation_space, self._vocab_size
    )
    self._action_serializer = space_serializer.create(
        action_space, self._vocab_size
    )

    # A single step is an observation followed by an action, so its
    # representation length is the sum of the two.
    self._obs_repr_length = self._obs_serializer.representation_length
    self._action_repr_length = self._action_serializer.representation_length
    self._step_repr_length = (
        self._obs_repr_length + self._action_repr_length
    )

    # The model is assumed to take vocab_size as an argument (e.g.
    # TransformerLM), so bind it before handing the model to the base class.
    model_with_vocab = functools.partial(model, vocab_size=vocab_size)
    super(SerializedSequenceSimulatedEnvProblem, self).__init__(
        model=model_with_vocab,
        observation_space=observation_space,
        action_space=action_space,
        **kwargs
    )
def _init_serialization(self, vocab_size):
    """Builds the serializers and the total representation length.

    Args:
      vocab_size: Vocabulary size passed to both space serializers.

    Returns:
      A dict with the keys "observation_serializer", "action_serializer" and
      "representation_length". The length is the combined per-step
      representation length (observation + action) multiplied by
      `self._max_timestep + 1`.
    """
    observation_serializer = space_serializer.create(
        self.train_env.observation_space, vocab_size=vocab_size
    )
    action_serializer = space_serializer.create(
        self.train_env.action_space, vocab_size=vocab_size
    )

    # Symbols needed to encode one (observation, action) pair.
    symbols_per_step = (
        observation_serializer.representation_length
        + action_serializer.representation_length
    )
    total_length = symbols_per_step * (self._max_timestep + 1)

    return dict(
        observation_serializer=observation_serializer,
        action_serializer=action_serializer,
        representation_length=total_length,
    )
def test_significance_map(self):
    """Checks the significance map for a Box observation and MultiDiscrete action."""
    gin.bind_parameter("BoxSpaceSerializer.precision", 3)

    obs_serializer = space_serializer.create(
        gym.spaces.Box(low=0, high=1, shape=(2,)), vocab_size=2
    )
    act_serializer = space_serializer.create(
        gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2
    )
    actual = serialization_utils.significance_map(
        observation_serializer=obs_serializer,
        action_serializer=act_serializer,
        representation_length=20,
    )

    # Layout: obs1, act1, obs2, act2, then obs3 cut after its 4th symbol.
    expected = (
        [0, 1, 2, 0, 1, 2] + [0, 0]
        + [0, 1, 2, 0, 1, 2] + [0, 0]
        + [0, 1, 2, 0]
    )
    np.testing.assert_array_equal(actual, expected)
def _make_space_and_serializer(self, low=-10, high=10, shape=(2,)):
    """Creates a Box space and a serializer for it.

    Args:
      low: Lower bound passed to `gym.spaces.Box`.
      high: Upper bound passed to `gym.spaces.Box`.
      shape: Shape passed to `gym.spaces.Box`.

    Returns:
      A `(space, serializer)` tuple.
    """
    # Precision high enough to represent float32s accurately.
    gin.bind_parameter("BoxSpaceSerializer.precision", 4)

    box = gym.spaces.Box(low=low, high=high, shape=shape)
    # Deliberately odd vocab_size, to check that serialization is not limited
    # to powers of 2.
    box_serializer = space_serializer.create(box, vocab_size=257)
    return box, box_serializer
def setUp(self):
    """Prepares a Box space and its serializer for each test."""
    super(BoxSpaceSerializerTest, self).setUp()

    # Precision high enough to represent float32s accurately.
    gin.bind_parameter("BoxSpaceSerializer.precision", 4)

    box = gym.spaces.Box(low=-10, high=10, shape=(2,))
    self._space = box
    # Deliberately odd vocab_size, to check that serialization is not limited
    # to powers of 2.
    self._serializer = space_serializer.create(box, vocab_size=257)
def test_rewards_to_actions_map(self):
    """Checks that rewards are broadcast onto the action symbols of each step."""
    rewards = np.array([1, 2, 3])

    obs_serializer = space_serializer.create(
        gym.spaces.MultiDiscrete(nvec=[2, 2, 2]), vocab_size=2
    )
    act_serializer = space_serializer.create(
        gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2
    )
    mapping = serialization_utils.rewards_to_actions_map(
        observation_serializer=obs_serializer,
        action_serializer=act_serializer,
        n_timesteps=len(rewards),
        representation_length=16,
    )

    actual = np.dot(rewards, mapping)

    # Layout: obs1, act1, obs2, act2, obs3, act3, then a 4th observation cut
    # after its 1st symbol.
    expected = (
        [0, 0, 0] + [1, 1]
        + [0, 0, 0] + [2, 2]
        + [0, 0, 0] + [3, 3]
        + [0]
    )
    np.testing.assert_array_equal(actual, expected)
def initialize_environments(self, batch_size=1, **kwargs):
    """Rebuilds serializers and per-batch buffers, then defers to the base class.

    Args:
      batch_size: Number of environments in the batch.
      **kwargs: Extra keyword arguments forwarded to the base class.

    Returns:
      Whatever the base class `initialize_environments` returns.
    """
    # Serializers are created from the current spaces and vocabulary size.
    self._obs_serializer = space_serializer.create(
        self.observation_space, self._vocab_size
    )
    self._action_serializer = space_serializer.create(
        self.action_space, self._vocab_size
    )

    # One step is an observation followed by an action.
    self._obs_repr_length = self._obs_serializer.representation_length
    self._action_repr_length = self._action_serializer.representation_length
    self._step_repr_length = (
        self._obs_repr_length + self._action_repr_length
    )

    # Zero-filled symbol history: one row per environment, wide enough for a
    # full-length trajectory.
    history_shape = (
        batch_size,
        self._max_trajectory_length * self._step_repr_length,
    )
    self._history = np.zeros(history_shape, dtype=np.int32)

    # Per-environment step counters, starting at zero.
    self._steps = np.zeros(batch_size, dtype=np.int32)

    # Last observations start out as NaN.
    last_obs_shape = (batch_size,) + self._observation_space.shape
    self._last_observations = np.full(last_obs_shape, np.nan)

    parent = super(SerializedSequenceSimulatedEnvProblem, self)
    return parent.initialize_environments(batch_size=batch_size, **kwargs)
def setUp(self):
    """Prepares a shared Discrete(2) serializer and default utility kwargs."""
    super(SerializationTest, self).setUp()

    serializer = space_serializer.create(gym.spaces.Discrete(2), vocab_size=2)
    repr_length = 100

    self._serializer = serializer
    self._repr_length = repr_length
    # The same serializer is used for both observations and actions.
    self._serialization_utils_kwargs = dict(
        observation_serializer=serializer,
        action_serializer=serializer,
        representation_length=repr_length,
    )
def setUp(self):
    """Prepares a two-valued Discrete space and its serializer."""
    super(DiscreteSpaceSerializerTest, self).setUp()

    discrete_space = gym.spaces.Discrete(n=2)
    self._space = discrete_space
    self._serializer = space_serializer.create(discrete_space, vocab_size=2)
def setUp(self):
    """Prepares a MultiDiscrete([2, 2]) space and its serializer."""
    super(MultiDiscreteSpaceSerializerTest, self).setUp()

    multi_discrete_space = gym.spaces.MultiDiscrete(nvec=[2, 2])
    self._space = multi_discrete_space
    self._serializer = space_serializer.create(
        multi_discrete_space, vocab_size=2
    )