示例#1
0
 def test_multi_env(self):
     """Time feature is tracked independently for each of two parallel envs.

     The agent under test appears to forward features with an extra column
     to ``self.test_agent``. That column grows by 1e-3 per ``act`` call and
     restarts at 0 for an env whose mask entry was 0 on the previous step.
     """
     state = State(torch.randn(2, 2))
     # Presumably pinned by a seeded RNG in setUp -- TODO confirm.
     base_rows = [[0.3923, -0.2236], [-0.3195, -1.2050]]

     def check(times):
         # Expected features: each original row followed by its env's time value.
         expected = torch.tensor(
             [row + [t] for row, t in zip(base_rows, times)])
         tt.assert_allclose(self.test_agent.last_state.features, expected,
                            atol=1e-04)

     self.agent.act(state, 0)
     check([0., 0.])
     self.agent.act(state, 0)
     check([1e-3, 1e-3])
     # Second env is flagged done (mask 0); its counter restarts on the next act.
     self.agent.act(State(state.features, torch.tensor([1., 0.])), 0)
     check([2e-3, 2e-3])
     self.agent.act(state, 0)
     check([3e-3, 0.])
     self.agent.act(state, 0)
     check([4e-3, 1e-3])
示例#2
0
 def test_reset(self):
     """Time feature restarts after the agent has seen a DONE state.

     The extra last column of ``self.test_agent.last_state.features`` counts
     up by 1e-3 per ``act`` call. On the call after a DONE state it is back
     at 0.
     """
     state = State(torch.randn(1, 4))
     # Presumably pinned by a seeded RNG in setUp -- TODO confirm.
     base_row = [0.3923, -0.2236, -0.3195, -1.2050]

     def check(elapsed):
         # Expected features: the single original row plus the time column.
         expected = torch.tensor([base_row + [elapsed]])
         tt.assert_allclose(self.test_agent.last_state.features, expected,
                            atol=1e-04)

     self.agent.act(state, 0)
     check(0.0000)
     self.agent.act(state, 0)
     check(1e-3)
     # The DONE step itself still advances the counter ...
     self.agent.act(State(state.features, DONE), 0)
     check(2e-3)
     # ... and the following state starts again from zero.
     self.agent.act(State(state.features), 0)
     check(0.0000)
     self.agent.act(state, 0)
     check(1e-3)
示例#3
0
 def test_from_list(self):
     """``State.from_list`` concatenates raw tensors, masks and infos in order.

     The last input is built without ``mask``/``info``. Its entries are
     expected to come out as mask 1 and info ``None``.
     """
     parts = [
         State(torch.randn(1, 4), mask=DONE, info=['a']),
         State(torch.randn(1, 4), mask=NOT_DONE, info=['b']),
         State(torch.randn(1, 4)),
     ]
     merged = State.from_list(parts)
     tt.assert_equal(merged.raw, torch.cat([part.raw for part in parts]))
     tt.assert_equal(merged.mask, torch.tensor([0, 1, 1]))
     self.assertEqual(merged.info, ['a', 'b', None])
示例#4
0
 def test_constructor_defaults(self):
     """A State built from a raw tensor alone gets default mask and info.

     Checked defaults:
     - features and raw both equal the input;
     - mask is all ones;
     - info is one ``None`` per row.
     """
     batch = 3
     raw = torch.randn(batch, 4)
     state = State(raw)
     tt.assert_equal(state.features, raw)
     tt.assert_equal(state.mask, torch.ones(batch))
     tt.assert_equal(state.raw, raw)
     self.assertEqual(state.info, [None for _ in range(batch)])
示例#5
0
 def test_get_item(self):
     """Integer indexing of a batched State yields a batch-of-one State."""
     raw = torch.randn(3, 4)
     last = State(raw)[2]
     # A length-1 slice keeps the leading batch dimension, same as unsqueeze(0).
     tt.assert_equal(last.raw, raw[2:3])
     tt.assert_equal(last.mask, NOT_DONE)
     self.assertEqual(last.info, [None])
示例#6
0
 def test_from_gym(self):
     """``State.from_gym`` wraps one (observation, done, info) step as a batch of one."""
     state = State.from_gym(np.array([1, 2, 3]), True, 'a')
     tt.assert_equal(state.raw, torch.tensor([[1, 2, 3]]))
     # done=True is expected to map to the DONE mask.
     tt.assert_equal(state.mask, DONE)
     # The scalar info is wrapped in a one-element list.
     self.assertEqual(state.info, ['a'])
示例#7
0
 def test_custom_constructor_args(self):
     """Explicit ``mask`` and ``info`` arguments are stored as given.

     Also checks that ``features`` and ``raw`` still equal the input tensor
     when the optional arguments are supplied. This mirrors the raw check in
     the defaults test.
     """
     raw = torch.randn(3, 4)
     mask = torch.zeros(3)
     info = ['a', 'b', 'c']
     state = State(raw, mask=mask, info=info)
     tt.assert_equal(state.features, raw)
     # Previously unchecked here: raw must pass through unchanged too.
     tt.assert_equal(state.raw, raw)
     # Compare against a fresh tensor so the expectation is independent of `mask`.
     tt.assert_equal(state.mask, torch.zeros(3))
     self.assertEqual(state.info, info)
示例#8
0
 def test_len(self):
     """``len`` of a State equals the size of the raw tensor's first dimension."""
     batch_size = 3
     state = State(torch.randn(batch_size, 4))
     self.assertEqual(len(state), batch_size)
示例#9
0
 def test_done(self):
     """A single-row State built with the DONE mask reports ``done``."""
     self.assertTrue(State(torch.randn(1, 4), mask=DONE).done)
示例#10
0
 def test_not_done(self):
     """A State built without a mask does not report ``done``."""
     fresh = State(torch.randn(1, 4))
     self.assertFalse(fresh.done)