コード例 #1
0
 def test_policy_net_categorical_wo_net(self):
     """PolicyNet over a Discrete action space yields a Categorical distribution.

     The policy is constructed from the action space alone (no network
     argument is passed), called on a batch of all-zero observations, and
     the sampled actions are expected to have shape ``(batch_size,)``.
     """
     batch_size = 5
     act_dim = 16
     obs_dim = 60
     action_space = gym.spaces.Discrete(act_dim)
     pi = ub_nets.PolicyNet(action_space)
     obs = tf.zeros((batch_size, obs_dim), dtype=tf.float32)
     dist = pi(obs)
     # assertIsInstance reports the actual type on failure, whereas
     # assertTrue(isinstance(...)) only reports "False is not true".
     self.assertIsInstance(dist, ub_prob.Categorical)
     act = dist.sample()
     # One scalar action per batch element.
     self.assertArrayEqual((batch_size, ), act.shape)
     # NOTE(review): presumably the kernel and bias of a single output layer
     # producing act_dim logits; confirm against PolicyNet's construction.
     self.assertEqual(2, len(pi.trainable_variables))
コード例 #2
0
 def test_policy_net_diag_gaussian_wo_net(self):
     """PolicyNet over a Box action space with squash=False yields a MultiNormal.

     The policy is constructed from the action space alone (no network
     argument is passed) with ``squash=False``, called on a batch of
     all-zero observations, and the sampled actions are expected to have
     shape ``(batch_size, act_dim)``.
     """
     batch_size = 5
     act_dim = 16
     obs_dim = 60
     # Symmetric [-1, 1] bounds on every action dimension.
     low = np.ones((act_dim, ), dtype=np.float32) * -1.0
     high = np.ones((act_dim, ), dtype=np.float32) * 1.0
     action_space = gym.spaces.Box(low=low, high=high)
     pi = ub_nets.PolicyNet(action_space, squash=False)
     obs = tf.zeros((batch_size, obs_dim), dtype=tf.float32)
     dist = pi(obs)
     # assertIsInstance reports the actual type on failure, whereas
     # assertTrue(isinstance(...)) only reports "False is not true".
     self.assertIsInstance(dist, ub_prob.MultiNormal)
     act = dist.sample()
     # One act_dim-dimensional action vector per batch element.
     self.assertArrayEqual((batch_size, act_dim), act.shape)
     # NOTE(review): presumably two output layers (mean and scale), each
     # with a kernel and bias; confirm against PolicyNet's construction.
     self.assertEqual(4, len(pi.trainable_variables))
コード例 #3
0
 def test_policy_net_exception(self):
     """Constructing PolicyNet on a MultiDiscrete action space raises ValueError."""
     unsupported_space = gym.spaces.MultiDiscrete([3, 4])
     self.assertRaises(ValueError, ub_nets.PolicyNet, unsupported_space)