示例#1
0
 def test_view(self):
     """Reshaping a (3, 2) StateArray with ``view((2, 3))`` updates both the
     array's shape and the leading dims of its stacked observation tensor."""
     # Two States, each holding a (3, 4) observation -> batch shape (2,).
     pair = StateArray.array([
         State(torch.randn((3, 4))),
         State(torch.randn((3, 4))),
     ])
     self.assertEqual(pair.shape, (2, ))
     # Stacking the pair three times adds a leading dim -> (3, 2).
     grid = StateArray.array([pair] * 3)
     self.assertEqual(grid.shape, (3, 2))
     # Reshape the (3, 2) batch to (2, 3).
     reshaped = grid.view((2, 3))
     self.assertEqual(reshaped.shape, (2, 3))
     # The per-state (3, 4) observation dims follow the batch dims.
     self.assertEqual(reshaped.observation.shape, (2, 3, 3, 4))
示例#2
0
 def test_multi_dim(self):
     """Nesting StateArrays three levels deep yields a (5, 3, 2) batch whose
     mask, done and reward tensors share that batch shape."""
     inner = StateArray.array([
         State(torch.randn((3, 4))),
         State(torch.randn((3, 4))),
     ])
     self.assertEqual(inner.shape, (2, ))
     middle = StateArray.array([inner] * 3)
     self.assertEqual(middle.shape, (3, 2))
     outer = StateArray.array([middle] * 5)
     batch_shape = (5, 3, 2)
     self.assertEqual(outer.shape, batch_shape)
     # The States were built from an observation only; the test expects
     # mask=1, done=False and reward=0 for every entry of the batch.
     tt.assert_equal(outer.mask, torch.ones(batch_shape))
     tt.assert_equal(outer.done, torch.zeros(batch_shape).bool())
     tt.assert_equal(outer.reward, torch.zeros(batch_shape))

 def test_key_error(self):
     """Combining States whose keys differ does not raise. Instead it emits
     exactly one warning naming the key present in only some States; the
     warning text says that key is omitted."""
     expected = 'KeyError while creating StateArray for key "other_key", omitting.'
     with warnings.catch_warnings(record=True) as caught:
         # Record every warning, including ones that would normally be
         # suppressed as duplicates.
         warnings.simplefilter("always")
         # Only the first State carries 'other_key'.
         StateArray.array([
             State({'observation': torch.tensor([1, 2]), 'other_key': True}),
             State({'observation': torch.tensor([1, 2])}),
         ])
         self.assertEqual(len(caught), 1)
         self.assertEqual(caught[0].message.args[0], expected)