Exemplo n.º 1
0
    def test_onehot(self):
        """One-hot encoding: a scalar index yields a (1, DIM) row, and
        encoding the full index range 0..DIM-1 reproduces the identity."""
        # Scalar input: expect a single row with a 1 at position `scalar`.
        scalar = 3
        encoded = ch.onehot(scalar, dim=DIM)
        self.assertTrue(encoded.size(0) == 1)
        self.assertTrue(encoded.size(1) == DIM)
        expected = th.zeros(1, DIM)
        expected[0, scalar] += 1
        self.assertTrue((encoded - expected).pow(2).sum().item() == 0)

        # Batched input: one-hot of arange(DIM) must equal the identity.
        batch = ch.onehot(th.arange(DIM), dim=DIM)
        self.assertTrue((batch - th.eye(DIM)).pow(2).sum().item() == 0)
Exemplo n.º 2
0
 def forward(self, state):
     """Sample a discrete action from a categorical policy.

     The integer state is one-hot encoded, mapped to logits by the
     `mean` head, and an action is drawn from the resulting
     Categorical distribution. Returns the action plus an info dict
     with the distribution and a detached (1, 1) log-probability.
     """
     encoded = ch.onehot(state, dim=self.input_size)
     logits = self.mean(encoded)
     policy = Categorical(logits=logits)
     sampled = policy.sample()
     log_prob = policy.log_prob(sampled).mean().view(-1, 1).detach()
     return sampled, {'density': policy, 'log_prob': log_prob}
Exemplo n.º 3
0
 def forward(self, x):
     """Pick an action epsilon-greedily from the Q-network's outputs.

     The state is one-hot encoded, pushed through the Q-function, and
     the chosen action is returned together with its Q-value estimate.
     """
     one_hot = ch.onehot(x, self.env.state_size)
     qs = self.qf(one_hot)
     chosen = self.e_greedy(qs)
     return chosen, {'q_action': qs[:, chosen]}
Exemplo n.º 4
0
    def step(self, action, *args, **kwargs):
        """Step the wrapped environment, then push logging data to Visdom.

        Every `self.interval` steps the step/episode statistic plots are
        refreshed. For non-vectorized envs, actions (and, when enabled,
        rendered frames) are buffered on every `self.ep_interval`-th
        episode and published once that episode finishes.
        Returns the underlying env's (state, reward, done, info).
        """
        state, reward, done, info = super(VisdomLogger, self).step(action, *args, **kwargs)

        # Refresh the statistic plots on the configured step interval.
        if self.interval > 0 and self.num_steps % self.interval == 0:
            # A tuple `info` carries per-env dicts; plot from the first one.
            if isinstance(info, tuple):
                self.update_steps_plots(info[0]['logger_steps_stats'])
                self.update_ep_plots(info[0]['logger_ep_stats'])
            else:
                self.update_steps_plots(info['logger_steps_stats'])
                self.update_ep_plots(info['logger_ep_stats'])

            if len(self.full_ep_actions) > 0:
                self.update_ribbon_plot(self.full_ep_actions,
                                        self.ep_actions_win)
            if len(self.full_ep_renders) > 0:
                try:
                    # TODO: Remove try clause when merged:
                    # https://github.com/facebookresearch/visdom/pull/595
                    frames = np.stack(self.full_ep_renders)
                    self.update_video(frames, self.ep_renders_win)
                    self.full_ep_renders = []
                except Exception:
                    # Deliberate best-effort: video upload can fail on
                    # older visdom versions (see TODO above).
                    pass

        if not self.is_vectorized:
            # Should record ?
            if self.num_episodes % self.ep_interval == 0:
                if self.discrete_actions:
                    # One-hot so every recorded action has a fixed width
                    # for the ribbon plot; [0] drops the batch dimension.
                    action = ch.onehot(action, dim=self.action_size)[0]
                self.ep_actions.append(action)
                if self.render and self.can_record:
                    frame = self.env.render(mode='rgb_array')
                    self.ep_renders.append(frame)

            # Done recording ?
            if done and (self.num_episodes - 1) % self.ep_interval == 0:
                # Publish the finished episode's buffers and start fresh.
                self.full_ep_actions = self.ep_actions
                self.ep_actions = []
                self.full_ep_renders = self.ep_renders
                self.ep_renders = []

        return state, reward, done, info
Exemplo n.º 5
0
def main(env='CliffWalking-v0'):
    """Train a Q-function on a discrete gym env with one-step TD(0).

    Each iteration runs the agent for a single step, computes the TD
    error against the greedy next-state value, and minimizes half the
    squared error with SGD.
    """
    # Build the wrapped environment pipeline: logging -> torch tensors
    # -> step runner.
    base = gym.make(env)
    env = envs.Runner(envs.Torch(envs.Logger(base, interval=1000)))
    agent = Agent(env)
    gamma = 1.00  # undiscounted
    optimizer = optim.SGD(agent.parameters(), lr=0.5, momentum=0.0)
    for step in range(1, 10000):
        sars = env.run(agent, steps=1)[0]

        q_pred = sars.q_action
        encoded_next = ch.onehot(sars.next_state, dim=env.state_size)
        q_next = agent.qf(encoded_next).max().detach()
        td_error = ch.temporal_difference(gamma, sars.reward,
                                          sars.done, q_pred, q_next)

        # Minimize 0.5 * TD-error^2.
        optimizer.zero_grad()
        (td_error.pow(2) * 0.5).backward()
        optimizer.step()
Exemplo n.º 6
0
    def reset(self):
        """Reset the grid world and return the initial observation.

        Regenerates the grid, places the agent at its start position and
        direction, clears episode state, and returns a one-hot encoding
        of the agent's flattened interior cell index.
        """
        self._gen_grid(self.width, self.height)

        # These fields should be defined by _gen_grid
        assert self.start_pos is not None
        assert self.start_dir is not None

        # Check that the agent doesn't overlap with an object
        start_cell = self.grid.get(*self.start_pos)
        assert start_cell is None or start_cell.can_overlap()

        # Place agent
        self.agent_pos = self.start_pos
        # BUG FIX: was `self.aget_dir` (typo), which left `agent_dir` unset
        # while silently creating a dead attribute.
        self.agent_dir = self.start_dir

        self.carrying = None
        self.step_count = 0

        # Return first observation: flatten the agent's (x, y) position on
        # the grid interior (borders excluded) into a single cell index.
        # NOTE(review): the one-hot size 81 is hard-coded and presumably
        # assumes a 9x9 interior — confirm against the generated grid size.
        obs = ((self.width - 2) *
               (self.agent_pos[1] - 1) + self.agent_pos[0]) - 1
        return ch.onehot(obs, 81)
Exemplo n.º 7
0
 def _state(self, state):
     """One-hot encode a grid position into a flat state vector.

     NOTE(review): `idx += self.mdp.width` adds a constant offset rather
     than the position's x-coordinate; the usual row-major flattening
     would be `(y - 1) * width + x`, and height/width may also be
     swapped here — confirm against the MDP's indexing convention.
     """
     idx = (state.y - 1) * self.mdp.height
     idx += self.mdp.width
     return ch.onehot(idx, dim=self.state_size)