class TestValueAgent(TestCase):
    """Unit tests for ``ValueAgent`` driven by a stubbed value network.

    The stub network always returns ``[[0.0, 100.0]]``, so action index 1
    carries the larger value.
    """

    def setUp(self) -> None:
        self.env = gym.make("CartPole-v0")
        # Stub network: every call yields the same values, making index 1
        # the higher-valued action.
        self.net = Mock(return_value=torch.Tensor([[0.0, 100.0]]))
        self.state = torch.tensor(self.env.reset())
        # Use whatever device the initial state tensor was created on.
        self.device = self.state.device
        self.value_agent = ValueAgent(self.net, self.env.action_space.n)

    def tearDown(self) -> None:
        # Release the environment created in setUp so one is not leaked per
        # test method.
        self.env.close()

    def test_value_agent(self):
        """Calling the agent directly returns a plain ``int`` action."""
        action = self.value_agent(self.state, self.device)
        self.assertIsInstance(action, int)

    def test_value_agent_GET_ACTION(self):
        """With the stubbed values, ``get_action`` is expected to pick index 1."""
        action = self.value_agent.get_action(self.state, self.device)
        self.assertIsInstance(action, int)
        self.assertEqual(action, 1)

    def test_value_agent_RANDOM(self):
        """``get_random_action`` returns an ``int`` within the action space."""
        action = self.value_agent.get_random_action()
        self.assertIsInstance(action, int)
        # The agent was built with ``action_space.n`` actions, so any valid
        # action must be an index in [0, n).
        self.assertIn(action, range(self.env.action_space.n))
# Example #2
class TestValueAgent(TestCase):
    """Unit tests for ``ValueAgent`` when states are passed as a list.

    The stub network always returns ``[[0.0, 100.0]]``, so action index 1
    carries the larger value.

    NOTE(review): this class reuses the name ``TestValueAgent``. If it shares
    a module with another class of that name, the later definition replaces
    the earlier one and only one set of tests is collected. Confirm, and
    rename if both are meant to run.
    """

    def setUp(self) -> None:
        self.env = gym.make("CartPole-v0")
        # Stub network: every call yields the same values, making index 1
        # the higher-valued action.
        self.net = Mock(return_value=Tensor([[0.0, 100.0]]))
        # The state is handed to the agent as a one-element list of
        # observations.
        self.state = [self.env.reset()]
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.value_agent = ValueAgent(self.net, self.env.action_space.n)

    def tearDown(self) -> None:
        # Release the environment created in setUp so one is not leaked per
        # test method.
        self.env.close()

    def test_value_agent(self):
        """Calling the agent directly returns a list of ``int`` actions."""
        action = self.value_agent(self.state, self.device)
        self.assertIsInstance(action, list)
        self.assertIsInstance(action[0], int)

    def test_value_agent_get_action(self):
        """With the stubbed values, ``get_action`` is expected to pick index 1."""
        action = self.value_agent.get_action(self.state, self.device)
        self.assertIsInstance(action, np.ndarray)
        self.assertEqual(action[0], 1)

    def test_value_agent_random(self):
        """``get_random_action`` returns ``int`` actions within the action space."""
        action = self.value_agent.get_random_action(self.state)
        self.assertIsInstance(action[0], int)
        # The agent was built with ``action_space.n`` actions, so any valid
        # action must be an index in [0, n).
        self.assertIn(action[0], range(self.env.action_space.n))