Example #1
 def step(self, action):
     # step the wrapped gym environment and convert its (observation, reward,
     # done, info) tuple into a State on the configured device
     self._state = State.from_gym(
         self._env.step(self._convert(action)),
         dtype=self._env.observation_space.dtype,
         device=self._device
     )
     return self._state
Example #2

 def test_apply_done(self):
     # applying a model to a done state yields an output of the correct
     # shape with all entries zeroed out
     observation = torch.randn(3, 4)
     state = State.from_gym((observation, 0., True, {}))
     model = torch.nn.Conv1d(3, 5, 2)
     output = state.apply(model, 'observation')
     self.assertEqual(output.shape, (5, 3))
     self.assertEqual(output.sum().item(), 0)
Example #3
def test_single_episode(env,
                        _agent,
                        generate_gif_callback=None,
                        side="first_0"):
    # initialize the episode
    observation = env.reset()
    returns = 0
    num_steps = 0

    # loop until the episode is finished
    done = False
    while not done:
        # `device` is assumed to be defined at module level
        action = _agent.act(
            side,
            State.from_gym(observation.reshape((1, 84, 84)),
                           device=device,
                           dtype=np.uint8))
        observation, reward, done, info = env.step(action)
        returns += reward
        num_steps += 1

    return returns, num_steps
Example #4
def generate_episode_gifs(env, _agent, max_frames, save_dir, side="first_0"):
    # initialize the episode
    observation = env.reset()
    frame_idx = 0
    prev_obs = None

    # loop until the episode is finished
    done = False
    while not done:
        action = _agent.act(
            side,
            State.from_gym(observation.reshape((1, 84, 84)),
                           device=device,
                           dtype=np.uint8))
        observation, reward, done, info = env.step(action)
        if reward != 0.0:
            print(reward)

        # save a frame only when it differs from the previously rendered frame
        obs = env.render(mode='rgb_array')
        if prev_obs is None or not np.array_equal(obs, prev_obs):
            im = Image.fromarray(obs)
            im.save(f"{save_dir}{str(frame_idx).zfill(4)}.png")
            prev_obs = obs

            frame_idx += 1
            if frame_idx >= max_frames:
                break
Example #5

 def reset(self):
     self._env.reset()
     observation, _, _, _ = self._env.last()
     state = State.from_gym(observation.reshape((1, 84, 84)),
                            device=self.device,
                            dtype=np.uint8)
     return state
Example #6

 def test_from_gym_reset(self):
     # a bare observation is treated as a reset state: mask 1, not done, zero reward
     observation = np.array([1, 2, 3])
     state = State.from_gym(observation)
     tt.assert_equal(state.observation, torch.from_numpy(observation))
     self.assertEqual(state.mask, 1.)
     self.assertEqual(state.done, False)
     self.assertEqual(state.reward, 0.)
     self.assertEqual(state.shape, ())
Example #7

 def test_from_gym_step(self):
     # a (observation, reward, done, info) tuple carries reward and done through,
     # and info keys become entries on the resulting State
     observation = np.array([1, 2, 3])
     state = State.from_gym((observation, 2., True, {'coolInfo': 3.}))
     tt.assert_equal(state.observation, torch.from_numpy(observation))
     self.assertEqual(state.mask, 0.)
     self.assertEqual(state.done, True)
     self.assertEqual(state.reward, 2.)
     self.assertEqual(state['coolInfo'], 3.)
     self.assertEqual(state.shape, ())
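
The two tests above show the two input forms State.from_gym accepts. A minimal sketch contrasting them, assuming the State class exercised here comes from the autonomous-learning-library (the import path all.core below is an assumption):

import numpy as np
from all.core import State  # assumed import path

# reset form: a bare observation yields a non-terminal state with zero reward
reset_state = State.from_gym(np.array([1, 2, 3]))
assert not reset_state.done
assert reset_state.reward == 0.

# step form: a (observation, reward, done, info) tuple is unpacked,
# and any info keys become entries on the resulting State
step_state = State.from_gym((np.array([1, 2, 3]), 2., True, {'coolInfo': 3.}))
assert step_state.done
assert step_state.reward == 2.
assert step_state['coolInfo'] == 3.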
Example #8
    def _reset(self):
        state = State.from_gym(self._reset_orig(),
                               dtype=self.observation_space.dtype,
                               device="cuda")
        obs = state.observation.cpu().numpy()

        # append the pilot policy's one-hot encoded action to the observation
        pilot_action = onehot_encode(self.pilot_policy(state))
        obs = np.concatenate((obs, pilot_action))
        return obs
Example #9
    def _step(self, action):
        state = State.from_gym(self._step_orig(_convert(disc_to_cont(action))),
                               dtype=self.observation_space.dtype,
                               device="cuda")

        obs = state.observation.cpu().numpy()
        r = state.reward
        done = state.done
        info = {}

        pilot_action = onehot_encode(self.pilot_policy(state))
        obs = np.concatenate((obs, pilot_action))
        return obs, r, done, info
Example #10
 def step(self, action):
     # a finished agent must be stepped with a None action (PettingZoo API)
     if self._env.dones[self._env.agent_selection]:
         action = None
     if torch.is_tensor(action):
         self._env.step(action.item())
     else:
         self._env.step(action)
     observation, reward, done, info = self._env.last()
     return State.from_gym((observation.reshape((1, 84, 84)), reward, done, info),
                           device=self.device,
                           dtype=np.uint8)
Example #11
 def reset(self):
     state = (self._env.reset(), 0., False, None)
     self._state = State.from_gym(state,
                                  dtype=self._env.observation_space.dtype,
                                  device=self._device)
     return self._state
Example #12

 def test_apply_mask(self):
     # apply_mask zeroes the tensor because the state is done (mask is 0)
     observation = torch.randn(3, 4)
     state = State.from_gym((observation, 0., True, {}))
     tt.assert_equal(state.apply_mask(observation), torch.zeros(3, 4))
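
Taken together, the wrapper examples above all follow the same pattern: wrap the output of env.reset() / env.step() in a State and hand that to the agent. A minimal sketch of that loop, assuming the classic 4-tuple gym step API and the same assumed all.core import path; the random action is a stand-in for a real agent:

import gym
from all.core import State  # assumed import path

env = gym.make('CartPole-v1')
dtype = env.observation_space.dtype

# reset returns a bare observation; from_gym builds a non-terminal State from it
state = State.from_gym(env.reset(), dtype=dtype)
returns = 0.
while not state.done:
    action = env.action_space.sample()  # stand-in for agent.act(state)
    # step returns (observation, reward, done, info); from_gym unpacks the tuple
    state = State.from_gym(env.step(action), dtype=dtype)
    returns += state.reward
print(returns)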