def test_policy_net_categorical_wo_net(self):
    """PolicyNet built from a Discrete action space yields a Categorical.

    Constructs PolicyNet from the action space alone, feeds a zero batch of
    observations, and checks the distribution type, the shape of a sampled
    action batch, and the number of trainable variables.
    """
    batch_size = 5
    act_dim = 16
    obs_dim = 60
    action_space = gym.spaces.Discrete(act_dim)
    pi = ub_nets.PolicyNet(action_space)
    obs = tf.zeros((batch_size, obs_dim), dtype=tf.float32)
    dist = pi(obs)
    # assertIsInstance reports the offending type on failure, whereas
    # assertTrue(isinstance(...)) only reports "False is not true".
    self.assertIsInstance(dist, ub_prob.Categorical)
    act = dist.sample()
    # A single scalar action per batch row.
    self.assertArrayEqual((batch_size,), act.shape)
    # NOTE(review): 2 presumably corresponds to the kernel and bias of one
    # logits layer inside PolicyNet — confirm against its implementation.
    self.assertEqual(2, len(pi.trainable_variables))
def test_policy_net_diag_gaussian_wo_net(self):
    """PolicyNet built from a Box action space (squash=False) yields a MultiNormal.

    Constructs PolicyNet from a [-1, 1]^act_dim Box space alone, feeds a zero
    batch of observations, and checks the distribution type, the shape of a
    sampled action batch, and the number of trainable variables.
    """
    batch_size = 5
    act_dim = 16
    obs_dim = 60
    # Symmetric unit bounds, float32 to match the Box's float dtype.
    low = np.full((act_dim,), -1.0, dtype=np.float32)
    high = np.full((act_dim,), 1.0, dtype=np.float32)
    action_space = gym.spaces.Box(low=low, high=high)
    # NOTE(review): squash=False presumably disables bounding the sample
    # (e.g. tanh) so the raw MultiNormal is returned — confirm in PolicyNet.
    pi = ub_nets.PolicyNet(action_space, squash=False)
    obs = tf.zeros((batch_size, obs_dim), dtype=tf.float32)
    dist = pi(obs)
    # assertIsInstance reports the offending type on failure, whereas
    # assertTrue(isinstance(...)) only reports "False is not true".
    self.assertIsInstance(dist, ub_prob.MultiNormal)
    act = dist.sample()
    # One act_dim-dimensional action vector per batch row.
    self.assertArrayEqual((batch_size, act_dim), act.shape)
    # NOTE(review): 4 presumably covers the mean and scale parameters of the
    # Gaussian head (kernel + bias each) — confirm against PolicyNet.
    self.assertEqual(4, len(pi.trainable_variables))
def test_policy_net_exception(self):
    """Constructing PolicyNet with a MultiDiscrete action space raises ValueError."""
    unsupported_space = gym.spaces.MultiDiscrete([3, 4])
    self.assertRaises(ValueError, ub_nets.PolicyNet, unsupported_space)