def test_onehot(self):
    single = 3
    oh_single = ch.onehot(single, dim=DIM)
    self.assertTrue(oh_single.size(0) == 1)
    self.assertTrue(oh_single.size(1) == DIM)
    ref = th.zeros(1, DIM)
    ref[0, single] += 1
    self.assertTrue((oh_single - ref).pow(2).sum().item() == 0)

    multi = th.arange(DIM)
    multi = ch.onehot(multi, dim=DIM)
    ref = th.eye(DIM)
    self.assertTrue((multi - ref).pow(2).sum().item() == 0)
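# For reference, a minimal sketch of the behaviour the test above encodes
# (an illustration under assumed semantics, not cherry's actual
# implementation): onehot accepts an int or a 1-D tensor of indices and
# returns an (N, dim) float tensor of one-hot rows.
import torch as th

def onehot_sketch(x, dim):
    if isinstance(x, int):
        x = th.tensor([x])
    encoded = th.zeros(x.numel(), dim)
    encoded[th.arange(x.numel()), x.view(-1).long()] = 1.0
    return encoded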
def forward(self, state):
    state = ch.onehot(state, dim=self.input_size)
    loc = self.mean(state)
    density = Categorical(logits=loc)
    action = density.sample()
    log_prob = density.log_prob(action).mean().view(-1, 1).detach()
    return action, {'density': density, 'log_prob': log_prob}
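# The policy above is a standard torch.distributions.Categorical over
# logits; a self-contained usage example of that API:
import torch as th
from torch.distributions import Categorical

density = Categorical(logits=th.zeros(1, 4))  # uniform over 4 actions
action = density.sample()                     # LongTensor of shape (1,)
log_prob = density.log_prob(action)           # log-probability, shape (1,)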
def forward(self, x):
    x = ch.onehot(x, self.env.state_size)
    q_values = self.qf(x)
    action = self.e_greedy(q_values)
    info = {
        'q_action': q_values[:, action],
    }
    return action, info
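# A minimal sketch of an epsilon-greedy action selector like the e_greedy
# call above (an assumption for illustration; the actual helper may differ):
import random
import torch as th

def e_greedy_sketch(q_values, epsilon=0.1):
    # q_values: (1, num_actions). Explore with probability epsilon,
    # otherwise pick the greedy action.
    if random.random() < epsilon:
        return random.randrange(q_values.size(1))
    return int(q_values.argmax(dim=1).item())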
def step(self, action, *args, **kwargs):
    state, reward, done, info = super(VisdomLogger, self).step(action, *args, **kwargs)
    if self.interval > 0 and self.num_steps % self.interval == 0:
        if isinstance(info, tuple):
            self.update_steps_plots(info[0]['logger_steps_stats'])
            self.update_ep_plots(info[0]['logger_ep_stats'])
        else:
            self.update_steps_plots(info['logger_steps_stats'])
            self.update_ep_plots(info['logger_ep_stats'])
        if len(self.full_ep_actions) > 0:
            self.update_ribbon_plot(self.full_ep_actions, self.ep_actions_win)
        if len(self.full_ep_renders) > 0:
            try:
                # TODO: Remove try clause when merged:
                # https://github.com/facebookresearch/visdom/pull/595
                frames = np.stack(self.full_ep_renders)
                self.update_video(frames, self.ep_renders_win)
                self.full_ep_renders = []
            except Exception:
                pass
    if not self.is_vectorized:
        # Should we record?
        if self.num_episodes % self.ep_interval == 0:
            if self.discrete_actions:
                action = ch.onehot(action, dim=self.action_size)[0]
            self.ep_actions.append(action)
            if self.render and self.can_record:
                frame = self.env.render(mode='rgb_array')
                self.ep_renders.append(frame)
        # Done recording?
        if done and (self.num_episodes - 1) % self.ep_interval == 0:
            self.full_ep_actions = self.ep_actions
            self.ep_actions = []
            self.full_ep_renders = self.ep_renders
            self.ep_renders = []
    return state, reward, done, info
def main(env='CliffWalking-v0'):
    env = gym.make(env)
    env = envs.Logger(env, interval=1000)
    env = envs.Torch(env)
    env = envs.Runner(env)
    agent = Agent(env)
    discount = 1.00
    optimizer = optim.SGD(agent.parameters(), lr=0.5, momentum=0.0)

    for t in range(1, 10000):
        transition = env.run(agent, steps=1)[0]
        curr_q = transition.q_action
        next_state = ch.onehot(transition.next_state, dim=env.state_size)
        next_q = agent.qf(next_state).max().detach()
        td_error = ch.temporal_difference(discount,
                                          transition.reward,
                                          transition.done,
                                          curr_q,
                                          next_q)

        optimizer.zero_grad()
        loss = td_error.pow(2).mul(0.5)
        loss.backward()
        optimizer.step()
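# The update above follows the standard one-step TD(0) definition; a sketch
# of that computation, assuming cherry masks the bootstrap term on terminal
# transitions (check cherry's docs for the exact signature and semantics):
def temporal_difference_sketch(gamma, reward, done, curr_q, next_q):
    # TD target: r + gamma * max_a' Q(s', a'), zeroed past episode ends.
    td_target = reward + gamma * (1.0 - done) * next_q
    return td_target - curr_q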
def reset(self):
    self._gen_grid(self.width, self.height)

    # These fields should be defined by _gen_grid
    assert self.start_pos is not None
    assert self.start_dir is not None

    # Check that the agent doesn't overlap with an object
    start_cell = self.grid.get(*self.start_pos)
    assert start_cell is None or start_cell.can_overlap()

    # Place agent
    self.agent_pos = self.start_pos
    self.agent_dir = self.start_dir
    self.carrying = None
    self.step_count = 0

    # Return first observation
    obs = ((self.width - 2) * (self.agent_pos[1] - 1) + self.agent_pos[0]) - 1
    return ch.onehot(obs, 81)
def _state(self, state):
    # Row-major flattening of the 1-indexed (x, y) grid position.
    idx = (state.y - 1) * self.mdp.width
    idx += state.x - 1
    return ch.onehot(idx, dim=self.state_size)
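# Quick sanity check of the row-major flattening above, using a hypothetical
# stand-in for the state class (GridState and the coordinate convention are
# assumptions for illustration):
from collections import namedtuple

GridState = namedtuple('GridState', ['x', 'y'])

# With width=4 and 1-indexed coordinates:
#   (1, 1) -> 0, (4, 1) -> 3, (1, 2) -> 4, (4, 2) -> 7
width = 4
assert (GridState(1, 1).y - 1) * width + (GridState(1, 1).x - 1) == 0
assert (GridState(4, 2).y - 1) * width + (GridState(4, 2).x - 1) == 7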