Пример #1
0
 def test_preprocess_observation_box_bounded(self):
     '''A bounded uint8 Box is preprocessed exactly like utils.normalize
     mapping [space.low, space.high] onto [0, 1], for a single sample and
     for a stacked batch.'''
     image_space = gym.spaces.Box(
         low=np.zeros((64, 64, 3)),
         high=np.full((64, 64, 3), 255.0),
         dtype=np.uint8,
     )

     def check(observation):
         # reference result: explicit normalization into the [0, 1] range
         expected = utils.normalize(
             observation, image_space.low, image_space.high, 0., 1.)
         actual = utils.preprocess_observation(observation, image_space)
         self.assertArrayClose(expected, actual)

     # single observation
     check(image_space.sample())

     # stacked batch of observations
     n_batch = 8
     batch = utils.stack_obs(
         [image_space.sample() for _ in range(n_batch)], image_space)
     check(batch)
Пример #2
0
    def _forward(self, inputs, training=True):
        '''Run the shared, policy and value networks on a batch of observations.

        Args:
            inputs (tf.Tensor): batch of observations, shape
                (batch, obs_space.shape). tf.uint8 for image observations and
                tf.float32 for non-image observations.
            training (bool, optional): whether the networks run in training
                mode. Defaults to True.

        Returns:
            tf.Tensor or tuple: predicted latent distributions (policy output).
                A tf.Tensor of shape (batch, act_space.n) for a discrete
                action space (categorical), or a (mean, std) tuple whose
                shapes are both (batch, act_space.shape) for a box action
                space (multinormal).
            tf.Tensor: predicted state values, shape (batch,).
        '''
        # Cast and normalize non-float32 inputs (e.g. uint8 images).
        # NOTICE: float32 images are assumed to be normalized already.
        obs = preprocess_observation(inputs, self.observation_space)

        # Shared feature extractor feeding the policy head.
        shared_latent = self.net(obs, training=training)
        logits = self.policy_net(shared_latent, training=training)

        # The value head uses its own feature extractor when one is
        # configured; otherwise it reuses the policy's features.
        if self.net2 is None:
            value_latent = shared_latent
        else:
            value_latent = self.net2(obs, training=training)
        raw_values = self.value_net(value_latent, training=training)  # (batch, 1)

        return logits, tf.squeeze(raw_values, axis=-1)  # values: (batch,)
Пример #3
0
 def test_preprocess_observation_box_unbounded(self):
     '''An unbounded float32 Box has no finite range to normalize into, so
     preprocessing must return observations numerically unchanged -- both
     for a single sample and for a stacked batch (mirroring the batch
     coverage of the other preprocess_observation tests).'''
     space = gym.spaces.Box(low=np.full((64,), -np.inf, dtype=np.float32),
                            high=np.full((64,), np.inf, dtype=np.float32),
                            dtype=np.float32)
     obs = space.sample()
     res_obs = utils.preprocess_observation(obs, space)
     self.assertArrayClose(obs, res_obs)
     # batch
     batch_size = 8
     obses = [space.sample() for _ in range(batch_size)]
     obs = utils.stack_obs(obses, space)
     res_obs = utils.preprocess_observation(obs, space)
     self.assertArrayClose(obs, res_obs)
Пример #4
0
 def test_preprocess_observation_discrete(self):
     '''Discrete observations are preprocessed into float32 one-hot vectors:
     shape (n,) for a single sample, (batch, n) for a stacked batch.'''
     n_categories = 5
     space = gym.spaces.Discrete(n_categories)
     # row i of the identity matrix is the one-hot encoding of category i
     one_hot_table = np.eye(n_categories, dtype=np.float32)

     # single observation
     sample = space.sample()
     self.assertArrayClose(one_hot_table[sample],
                           utils.preprocess_observation(sample, space))

     # stacked batch: fancy indexing picks one one-hot row per element
     n_batch = 8
     batch = utils.stack_obs(
         [space.sample() for _ in range(n_batch)], space)
     self.assertArrayClose(one_hot_table[batch],
                           utils.preprocess_observation(batch, space))
Пример #5
0
 def test_preprocess_observation_multidiscrete(self):
     '''MultiDiscrete observations are preprocessed into the concatenation of
     one one-hot vector per sub-space, for a single sample and a batch.'''
     dims = [4, 7]
     space = gym.spaces.MultiDiscrete(dims)
     total_dim = np.sum(dims)
     # first index of each sub-space inside the concatenated one-hot vector
     starts = np.concatenate(([0], np.cumsum(dims)[:-1]))

     # single observation
     sample = space.sample()
     expected = np.zeros((total_dim,), dtype=np.float32)
     expected[sample + starts] = 1.0
     self.assertArrayClose(expected,
                           utils.preprocess_observation(sample, space))

     # stacked batch: row indices (n_batch, 1) broadcast against the
     # shifted column indices (n_batch, len(dims))
     n_batch = 8
     batch = utils.stack_obs(
         [space.sample() for _ in range(n_batch)], space)
     expected = np.zeros((n_batch, total_dim), dtype=np.float32)
     expected[np.arange(n_batch)[:, None], batch + starts] = 1.0
     self.assertArrayClose(expected,
                           utils.preprocess_observation(batch, space))
Пример #6
0
 def test_preprocess_observation_multibinary(self):
     '''MultiBinary observations must pass through preprocessing numerically
     unchanged, for a single sample and for a stacked batch (mirroring the
     batch coverage of the other preprocess_observation tests).'''
     space_dim = 6
     space = gym.spaces.MultiBinary(space_dim)
     obs = space.sample()
     norm_obs = utils.preprocess_observation(obs, space)
     self.assertArrayClose(obs, norm_obs)
     # batch
     batch_size = 8
     obses = [space.sample() for _ in range(batch_size)]
     obs = utils.stack_obs(obses, space)
     norm_obs = utils.preprocess_observation(obs, space)
     self.assertArrayClose(obs, norm_obs)