def test_env_manager_state(self):
    """state() should expose one entry per stacked frame along its first axis."""
    stack_depth = 5
    manager = FrameStackEnvManager(
        'Breakout-v0', preprocess=preprocess, frame_stack_size=stack_depth
    )
    manager.reset()
    stacked = manager.state()
    self.assertEqual(stack_depth, stacked.shape[0])
def test_env_manager_reset(self):
    """After reset(), every slot of the frame stack should equal the current screen."""
    stack_depth = 4
    manager = FrameStackEnvManager(
        'Pong-v0', preprocess=preprocess, frame_stack_size=stack_depth
    )
    initial_state = manager.reset()
    expected_frame = torch.tensor(preprocess(manager.render('rgb_array'))).float()
    # All frames in the freshly reset stack are expected to be identical to
    # the rendered screen, so check each slot in order.
    for slot in range(stack_depth):
        self.assertTrue(torch.equal(initial_state[slot], expected_frame))
def test_env_manager_init(self):
    """The frame buffer is empty before reset() and stack-sized afterwards."""
    stack_depth = 3
    manager = FrameStackEnvManager(
        'Breakout-v0', preprocess=preprocess, frame_stack_size=stack_depth
    )
    # Nothing should be buffered until the first reset().
    self.assertEqual(0, len(manager.frames))
    manager.reset()
    first_frame_shape = manager.state()[0].shape
    self.assertEqual(first_frame_shape, torch.Size(img_size))
    self.assertEqual(manager.state().shape[0], stack_depth)
def test_env_manager_replay_memory(self):
    """Stepping the manager N times should leave N entries in the replay memory."""
    ray.init()
    # Register shutdown as a cleanup so Ray is torn down even when a step or
    # the assertion below raises. Previously ray.shutdown() only ran on
    # success, so a failing run left Ray initialized for later tests (where a
    # second ray.init() without ignore_reinit_error raises).
    self.addCleanup(ray.shutdown)
    memory = ReplayMemory()
    env = FrameStackEnvManager('Pong-v0', preprocess=preprocess, replay_memory=memory)
    env.reset()
    num_steps = 10
    for _ in range(num_steps):
        env.step(env.action_space.sample())
    self.assertEqual(len(memory.memory), num_steps)
def test_env_manager_step(self):
    """step() should shift the frame stack left by one and append the newest screen."""

    def latest_screen(manager):
        # Rebuild the newest frame straight from the wrapped env using the
        # grayscale + nearest-neighbour resize pipeline, as a float tensor.
        rgb = manager.env.render('rgb_array')
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(gray, img_size, interpolation=cv2.INTER_NEAREST)
        return torch.tensor(resized).float()

    env = FrameStackEnvManager('Pong-v0', preprocess=preprocess)
    start_state = env.reset()

    # First step: slots 0..2 take the previous slots 1..3, slot 3 is new.
    next_state, _, _, _ = env.step(env.action_space.sample())
    newest = latest_screen(env)
    for slot in range(3):
        self.assertTrue(torch.equal(next_state[slot], start_state[slot + 1]))
    self.assertTrue(torch.equal(next_state[3], newest))

    # Second step: the same one-slot shift relative to the previous state.
    next_state2, _, _, _ = env.step(env.action_space.sample())
    newest = latest_screen(env)
    for slot in range(3):
        self.assertTrue(torch.equal(next_state2[slot], next_state[slot + 1]))
    self.assertTrue(torch.equal(next_state2[3], newest))

    # Two more steps: the stack should now hold the newest frame of each of
    # the three preceding states, followed by the current screen.
    next_state3, _, _, _ = env.step(env.action_space.sample())
    next_state4, _, _, _ = env.step(env.action_space.sample())
    newest = latest_screen(env)
    for slot, earlier in enumerate((next_state, next_state2, next_state3)):
        self.assertTrue(torch.equal(next_state4[slot], earlier[3]))
    self.assertTrue(torch.equal(next_state4[3], newest))
f'./resources/{env_id.split(":")[1]}_crowding_heatmap.png') # Last gen position step = max(1, int((num_generations - 1) / 4)) for i in range(0, num_generations, step): bg_img = f'{env_id.split(":")[1].split("-")[1]}' last_gen_pos = positions_log[i * (num_populations * 2):(i + 1) * (num_populations * 2)] last_gen_heatmap_plot, ax = plot_final_pos( cv.resize(cv.imread(f'./resources/{bg_img}.png', 0), (48, 48)), last_gen_pos) last_gen_heatmap_plot.savefig( f'./resources/{env_id.split(":")[1]}_crowding_{int(i/step)}_pos.png') controller = solution.to_network() torch.save(controller.state_dict(), f'./resources/{env_id.split(":")[1]}_crowding.params') env = FrameStackEnvManager(env_id, frame_stack_size=1, preprocess=preprocess) for i in range(2): total_reward = 0 state = env.reset() done = False while not done: env.render() action = controller(state.unsqueeze(0)).argmax().item() state, reward, done, _ = env.step(action) total_reward += reward.item() print('Total reward : ', total_reward)
import numpy as np
from environment.framestack_env_manager import FrameStackEnvManager
from environment.classic_control_env_manager import ClassicControlEnvManager

# NOTE(review): cv2 and plt are used below but their imports are not visible
# in this span — confirm they are imported earlier in the file.

img_size = (64, 64)


#%% Test EnvironmentManager
def preprocess(screen, size=img_size):
    """Convert an RGB screen to a contiguous grayscale image resized to ``size``.

    Args:
        screen: RGB image array, e.g. as returned by ``render('rgb_array')``.
        size: ``dsize`` passed to ``cv2.resize`` (OpenCV orders it as
            (width, height)). Defaults to the module-level ``img_size`` so the
            preprocessing size and ``img_size`` cannot drift apart.

    Returns:
        A C-contiguous numpy array holding the resized grayscale image.
    """
    gray = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
    resized = cv2.resize(gray, size, interpolation=cv2.INTER_NEAREST)
    # Force a C-contiguous buffer for downstream consumers.
    return np.ascontiguousarray(resized)


env = FrameStackEnvManager('Breakout-v0', preprocess=preprocess)
raw_screen = env.get_raw_screen()
plt.figure()
plt.title('Raw Screen')
plt.axis('off')
plt.imshow(raw_screen)  # .transpose(1,2,0))
plt.show()

env.reset()
state = env.state()
plt.suptitle('Initial State')
plt.subplot(141)
plt.axis('off')
plt.imshow(state[0], cmap=plt.get_cmap('gray'))
plt.subplot(142)