def test_env_manager_state(self):
    """state() exposes one entry per stacked frame along dimension 0."""
    stack_len = 5
    manager = FrameStackEnvManager('Breakout-v0',
                                   preprocess=preprocess,
                                   frame_stack_size=stack_len)
    manager.reset()
    observed = manager.state()
    # The leading dimension must match the requested stack size.
    self.assertEqual(stack_len, observed.shape[0])
 def test_env_manager_reset(self):
     """reset() must fill every slot of the frame stack with the current screen.

     The returned state is expected to have ``stack_size`` entries along
     dimension 0, each exactly equal to the preprocessed rendered screen.
     """
     stack_size = 4
     env = FrameStackEnvManager('Pong-v0',
                                preprocess=preprocess,
                                frame_stack_size=stack_size)
     state = env.reset()
     screen = torch.tensor(preprocess(env.render('rgb_array'))).float()
     self.assertEqual(state.shape[0], stack_size)
     # Right after reset all frames in the stack are identical to the screen;
     # subTest reports which slot differs if one does.
     for idx in range(stack_size):
         with self.subTest(frame=idx):
             self.assertTrue(torch.equal(state[idx], screen))
 def test_env_manager_init(self):
     """No frames are buffered before reset; after it the state is well-shaped."""
     n_frames = 3
     manager = FrameStackEnvManager('Breakout-v0',
                                    preprocess=preprocess,
                                    frame_stack_size=n_frames)
     # Nothing should be buffered until the first reset().
     self.assertEqual(0, len(manager.frames))
     manager.reset()
     # Each stacked frame has the preprocessed image size ...
     first_frame_shape = manager.state()[0].shape
     self.assertEqual(first_frame_shape, torch.Size(img_size))
     # ... and there are exactly n_frames of them.
     self.assertEqual(manager.state().shape[0], n_frames)
 def test_env_manager_replay_memory(self):
     """Each env.step() call adds one entry to the attached ReplayMemory."""
     ray.init()
     # Registered as a cleanup so Ray is shut down even when stepping raises
     # or the assertion fails; otherwise it would leak into later tests.
     self.addCleanup(ray.shutdown)
     num_steps = 10
     memory = ReplayMemory()
     env = FrameStackEnvManager('Pong-v0',
                                preprocess=preprocess,
                                replay_memory=memory)
     env.reset()
     for _ in range(num_steps):
         env.step(env.action_space.sample())
     self.assertEqual(len(memory.memory), num_steps)
    def test_env_manager_step(self):
        """step() shifts the frame stack by one slot and appends the new screen.

        After each step, slot ``k`` must equal slot ``k + 1`` of the previous
        state, and the last slot must equal the screen rendered right after
        the step (grayscaled and nearest-neighbour resized to ``img_size``).
        After four steps the stack must consist of the newest frame produced
        by each of those steps.
        """
        env = FrameStackEnvManager('Pong-v0', preprocess=preprocess)

        def take_step():
            # Advance with a random action; only the new state is checked.
            next_state, _, _, _ = env.step(env.action_space.sample())
            return next_state

        def rendered_frame():
            # Expected newest frame, rebuilt from the current raw render.
            frame = env.env.render('rgb_array')
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
            frame = cv2.resize(frame,
                               img_size,
                               interpolation=cv2.INTER_NEAREST)
            return torch.tensor(frame).float()

        def assert_shifted(new_state, old_state, newest_frame):
            # All slots but the last come from the next slot of old_state.
            last = new_state.shape[0] - 1
            for k in range(last):
                self.assertTrue(torch.equal(new_state[k], old_state[k + 1]))
            self.assertTrue(torch.equal(new_state[last], newest_frame))

        start_state = env.reset()

        # rendered_frame() is evaluated before the assertions run, i.e.
        # immediately after the step, as in the original sequence.
        next_state = take_step()
        assert_shifted(next_state, start_state, rendered_frame())

        next_state2 = take_step()
        assert_shifted(next_state2, next_state, rendered_frame())

        next_state3 = take_step()
        next_state4 = take_step()
        new_frame = rendered_frame()
        # Indices 0-3 assume the manager's default stack size is 4
        # (as the original hard-coded indices did) — TODO confirm.
        self.assertTrue(torch.equal(next_state4[0], next_state[3]))
        self.assertTrue(torch.equal(next_state4[1], next_state2[3]))
        self.assertTrue(torch.equal(next_state4[2], next_state3[3]))
        self.assertTrue(torch.equal(next_state4[3], new_frame))
Esempio n. 6
0
    f'./resources/{env_id.split(":")[1]}_crowding_heatmap.png')

# Plot agent positions for a handful (about five) of evenly spaced
# generations, each drawn over the environment's background image.
step = max(1, (num_generations - 1) // 4)
env_name = env_id.split(":")[1]
bg_img = env_name.split("-")[1]
bg_path = f'./resources/{bg_img}.png'
# The background is loop-invariant: read (flag 0 = grayscale) and resize it
# once. cv.imread returns None instead of raising when it cannot read a file.
bg = cv.imread(bg_path, 0)
if bg is None:
    raise FileNotFoundError(f'Cannot read background image: {bg_path}')
bg = cv.resize(bg, (48, 48))
gen_len = num_populations * 2  # entries per generation in positions_log
for i in range(0, num_generations, step):
    last_gen_pos = positions_log[i * gen_len:(i + 1) * gen_len]
    # Pass a copy in case plot_final_pos draws onto the image in place.
    last_gen_heatmap_plot, ax = plot_final_pos(bg.copy(), last_gen_pos)
    last_gen_heatmap_plot.savefig(
        f'./resources/{env_name}_crowding_{i // step}_pos.png')

# Save the evolved controller's weights, then run it greedily (argmax action)
# for a few rendered episodes and print each episode's summed reward.
controller = solution.to_network()
torch.save(controller.state_dict(),
           f'./resources/{env_id.split(":")[1]}_crowding.params')
env = FrameStackEnvManager(env_id, frame_stack_size=1, preprocess=preprocess)

num_eval_episodes = 2
for i in range(num_eval_episodes):
    total_reward = 0
    state = env.reset()
    done = False

    while not done:
        env.render()
        # Inference only: no_grad avoids recording an autograd graph for
        # every forward pass of the episode.
        with torch.no_grad():
            action = controller(state.unsqueeze(0)).argmax().item()
        state, reward, done, _ = env.step(action)
        total_reward += reward.item()

    print('Total reward : ', total_reward)
Esempio n. 7
0
import numpy as np
from environment.framestack_env_manager import FrameStackEnvManager
from environment.classic_control_env_manager import ClassicControlEnvManager

img_size = (64, 64)


#%% Test EnvironmentManager
def preprocess(screen, size=img_size):
    """Convert an RGB screen into a contiguous grayscale image of ``size``.

    Args:
        screen: RGB image array, as accepted by ``cv2.cvtColor`` with
            ``COLOR_RGB2GRAY``.
        size: Target ``(width, height)`` passed to ``cv2.resize``. Defaults
            to the module-level ``img_size`` rather than repeating the
            literal, so the two cannot drift apart.

    Returns:
        A C-contiguous single-channel array resized with nearest-neighbour
        interpolation.
    """
    gray = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
    small = cv2.resize(gray, size, interpolation=cv2.INTER_NEAREST)
    # Force a C-contiguous buffer (presumably for zero-copy tensor
    # conversion downstream — TODO confirm).
    return np.ascontiguousarray(small)


# Breakout environment manager configured with the `preprocess` function above.
env = FrameStackEnvManager('Breakout-v0', preprocess=preprocess)

# Display one unprocessed screen for visual comparison with the
# preprocessed state plotted afterwards.
raw_screen = env.get_raw_screen()
plt.figure()
plt.title('Raw Screen')
plt.axis('off')
# Shown as returned by get_raw_screen(); presumably already in
# height x width x channel order as imshow expects — confirm.
plt.imshow(raw_screen)
plt.show()

env.reset()
state = env.state()
plt.suptitle('Initial State')
plt.subplot(141)
plt.axis('off')
plt.imshow(state[0], cmap=plt.get_cmap('gray'))
plt.subplot(142)