def test_view(self):
    """Check that ``view`` rearranges the batch dimensions of a StateArray.

    Two states with (3, 4) observations are stacked, that stack is nested
    three times to give a (3, 2) array, and the result is viewed as (2, 3).
    The observation is expected to have shape (2, 3, 3, 4) afterwards.
    """
    pair = StateArray.array([State(torch.randn((3, 4))) for _ in range(2)])
    self.assertEqual(pair.shape, (2, ))

    # The same StateArray object is reused for every entry of the outer list.
    nested = StateArray.array([pair, pair, pair])
    self.assertEqual(nested.shape, (3, 2))

    viewed = nested.view((2, 3))
    self.assertEqual(viewed.shape, (2, 3))
    self.assertEqual(viewed.observation.shape, (2, 3, 3, 4))
def test_multi_dim(self):
    """Check shape and default fields of a StateArray nested several times.

    Starting from a (2,) array of states, wrapping it in lists of length 3
    and then 5 is expected to give shapes (3, 2) and (5, 3, 2). The final
    array is expected to have an all-ones mask, an all-False ``done`` and
    an all-zero reward, each of shape (5, 3, 2).
    """
    batch = StateArray.array([State(torch.randn((3, 4))) for _ in range(2)])
    expected_shape = (2, )
    self.assertEqual(batch.shape, expected_shape)

    # Each wrap prepends one leading dimension equal to the list length.
    for copies in (3, 5):
        batch = StateArray.array([batch] * copies)
        expected_shape = (copies, ) + expected_shape
        self.assertEqual(batch.shape, expected_shape)

    dims = (5, 3, 2)
    tt.assert_equal(batch.mask, torch.ones(dims))
    tt.assert_equal(batch.done, torch.zeros(dims, dtype=torch.bool))
    tt.assert_equal(batch.reward, torch.zeros(dims))
def test_key_error(self):
    """Check the warning emitted when states disagree on their keys.

    One state carries an extra ``other_key`` entry and the other does not.
    Building a StateArray from them is expected to record exactly one
    warning whose first argument is the "omitting" message for that key.
    """
    expected_message = 'KeyError while creating StateArray for key "other_key", omitting.'
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including ones that would normally be
        # suppressed as duplicates.
        warnings.simplefilter("always")
        with_extra_key = State({
            'observation': torch.tensor([1, 2]),
            'other_key': True
        })
        without_extra_key = State({
            'observation': torch.tensor([1, 2]),
        })
        StateArray.array([with_extra_key, without_extra_key])
        self.assertEqual(len(caught), 1)
        first_warning = caught[0]
        self.assertEqual(first_warning.message.args[0], expected_message)