def backprop(self, states, locs, R, actions):
    # Unpack (x, y) grid locations and move every input onto the model's device.
    x, y = zip(*locs)
    Qsa = self.forward(totorch(states, self.device),
                       torch.tensor(x).to(self.device),
                       torch.tensor(y).to(self.device))
    # Select Q(s, a) for the action actually taken via a one-hot mask.
    actions_onehot = totorch(one_hot(actions, self.action_size), self.device)
    Qvalue = torch.sum(Qsa * actions_onehot, dim=1)
    # Mean-squared error between the bootstrapped return R and Q(s, a).
    loss = torch.mean(torch.square(totorch(R, self.device).float() - Qvalue))
    loss.backward()
    self.optim.step()
    self.optim.zero_grad()
    return loss.detach().cpu().numpy()
def train(global_model, model, env, nsteps, num_episodes, ID):
    opt = torch.optim.RMSprop(global_model.parameters(), lr=1e-3)
    episode = 0
    episode_steps = 0
    episode_score = 0
    T = 0
    state = env.reset()
    start = time.time()
    while episode < num_episodes:
        rollout = []
        for t in range(nsteps):
            # Sample an action from the local model's policy without tracking gradients.
            with torch.no_grad():
                policy, value = model(totorch(state[None], device='cpu'))
            policy, value = tonumpy(policy), tonumpy(value)
            action = np.random.choice(policy.shape[1], p=policy[0])
            next_state, reward, done, info = env.step(action)
            episode_score += reward
            rollout.append((state, action, reward, value, done))
            state = next_state
            T += 1
            episode_steps += 1
            if done or t == nsteps - 1:
                # End of rollout: bootstrap from the value of the last observed state
                # and compute truncated lambda returns over the collected steps.
                states, actions, rewards, values, dones = stack_many(*zip(*rollout))
                with torch.no_grad():
                    _, last_values = model.forward(totorch(next_state[None], device='cpu'))
                    last_values = last_values.cpu().numpy()
                R = lambda_return(rewards, values, last_values, dones, gamma=0.9, lambda_=0.95, clip=False)
                # Push local gradients to the global model and sync local weights back.
                loss = update_params(model, global_model, opt, states, actions, R)
                if done:
                    episode += 1
                    state = env.reset()
                    # Log progress every episode.
                    time_taken = time.time() - start
                    print(f'worker {ID}, total worker steps {T:,} local episode {episode}, '
                          f'episode score {episode_score} episode steps {episode_steps}, '
                          f'time taken {time_taken:,.1f}s, fps {episode_steps/time_taken:.2f}')
                    episode_steps = 0
                    episode_score = 0
                    start = time.time()
                break
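# The training loop above relies on a `lambda_return` helper that is not shown in this
# section. The sketch below is an assumption of what it computes: truncated TD(lambda)
# returns bootstrapped from `last_values`, with episode boundaries masked via `dones`.
# Inputs are assumed to be flat numpy arrays of length nsteps (and `last_values` the
# value estimate of the state following the final step); the `clip` flag is unused in
# this sketch. The repo's actual helper may differ.
def lambda_return(rewards, values, last_values, dones, gamma=0.9, lambda_=0.95, clip=False):
    T = len(rewards)
    R = np.zeros_like(rewards, dtype=np.float64)
    next_return = last_values
    for t in reversed(range(T)):
        next_value = last_values if t == T - 1 else values[t + 1]
        # G_t = r_t + gamma * [(1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}], cut at episode ends.
        R[t] = rewards[t] + gamma * (1.0 - dones[t]) * ((1.0 - lambda_) * next_value + lambda_ * next_return)
        next_return = R[t]
    return R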
def eval_state(self, state, loc):
    with torch.no_grad():
        x, y = zip(*loc)
        x, y = torch.tensor(x).to(self.device), torch.tensor(y).to(self.device)
        state_torch = totorch(state, self.device)
        Qsa = self.model(state_torch, x, y)
        return tonumpy(Qsa)
def evaluate(self, state: np.ndarray, hidden=None, done=None):
    # `hidden` is the recurrent (LSTM) state tuple; its elements are moved to the device individually.
    state = totorch(state, self.device)
    hidden = totorch_many(*hidden, device=self.device) if hidden is not None else None
    with torch.no_grad():
        policy, value, hidden = self.forward(state, hidden, done)
    return tonumpy(policy), tonumpy(value), tonumpy_many(*hidden)
def update_params(lm, gm, gopt, states, actions, R):
    # lm: local (worker) model, gm: shared global model, gopt: optimiser over the global parameters.
    states, R, actions = totorch(states, 'cpu'), totorch(R, 'cpu'), totorch(actions, 'cpu')
    actions_onehot = F.one_hot(actions.long(), num_classes=lm.action_size)
    policies, values = lm.forward(states)
    loss = lm.loss(policies, R, values, actions_onehot)
    loss.backward()
    if lm.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(lm.parameters(), lm.grad_clip)
    # Copy the local gradients onto the global parameters, step the shared optimiser,
    # then sync the local model back to the updated global weights.
    for local_param, global_param in zip(lm.parameters(), gm.parameters()):
        global_param._grad = local_param.grad
    gopt.step()
    gopt.zero_grad()
    lm.load_state_dict(gm.state_dict())
    return loss.detach().cpu().numpy()
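# `train` and `update_params` follow the A3C pattern of a shared global model updated by
# gradients from per-worker local copies. The launcher below is a minimal sketch of how
# the workers might be spawned; `ACModel` and `make_env` are hypothetical placeholders
# for whatever model and environment constructors the repo actually uses.
import torch.multiprocessing as mp

def launch(num_workers=4, nsteps=20, num_episodes=1000):
    global_model = ACModel()            # hypothetical model constructor
    global_model.share_memory()         # expose parameters to all worker processes
    workers = []
    for ID in range(num_workers):
        local_model = ACModel()         # each worker keeps its own local copy
        local_model.load_state_dict(global_model.state_dict())
        env = make_env()                # hypothetical environment factory
        p = mp.Process(target=train, args=(global_model, local_model, env, nsteps, num_episodes, ID))
        p.start()
        workers.append(p)
    for p in workers:
        p.join()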
def get_pixel_control(self, state: np.ndarray):
    # Encode the state with the policy's encoder, then predict the auxiliary pixel-control Q-values.
    with torch.no_grad():
        enc_state = self.policy.model(totorch(state, self.device))
        Qaux = self.Qaux(enc_state)
    return tonumpy(Qaux)
def evaluate(self, state):
    with torch.no_grad():
        policy, value_extr, value_intr = self.forward(totorch(state, self.device))
    return tonumpy(policy), tonumpy(value_extr), tonumpy(value_intr)
def evaluate(self, state: np.ndarray):
    with torch.no_grad():
        policy, _ = self.policy.forward(totorch(state, self.policy.device))
        value = self.value.forward(totorch(state, self.value.device))
    return tonumpy(policy), tonumpy(value)
def get_value(self, state: np.ndarray):
    with torch.no_grad():
        value = self.value.forward(totorch(state, self.value.device))
    return tonumpy(value)
def get_policy(self, state: np.ndarray):
    with torch.no_grad():
        policy, Adv = self.policy.forward(totorch(state, self.policy.device))
    return tonumpy(policy), tonumpy(Adv)
def evaluate(self, state: np.ndarray):
    state = totorch(state, self.device)
    with torch.no_grad():
        policy, value = self.forward(state)
    return tonumpy(policy), tonumpy(value)
def evaluate(self, state):
    with torch.no_grad():
        Qsa = self.forward(totorch(state, self.device))
    return tonumpy(Qsa)
def get_pixel_control(self, state: np.ndarray, action_reward, hidden):
    # Run the recurrent policy on (state, previous action/reward, hidden state) and
    # predict the auxiliary pixel-control Q-values from the LSTM output.
    state, action_reward, hidden = (totorch(state, self.device),
                                    totorch(action_reward, self.device),
                                    totorch_many(*hidden, device=self.device))
    with torch.no_grad():
        lstm_state, _ = self.policy.lstm_forward(state, action_reward, hidden, done=None)
        Qaux = self.Qaux(lstm_state)
    return tonumpy(Qaux)
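# All of the methods above lean on a small set of tensor-conversion helpers
# (`totorch`, `tonumpy`, `totorch_many`, `tonumpy_many`, `one_hot`, `stack_many`)
# whose definitions are not included in this section. The versions below are a
# minimal sketch matching how they are called here, not necessarily the repo's
# own implementations.
import numpy as np
import torch

def totorch(x, device='cpu'):
    # Convert a numpy array (or nested sequence) to a float32 tensor on the given device.
    return torch.as_tensor(np.asarray(x), dtype=torch.float32, device=device)

def totorch_many(*xs, device='cpu'):
    return tuple(totorch(x, device) for x in xs)

def tonumpy(x):
    return x.detach().cpu().numpy()

def tonumpy_many(*xs):
    return tuple(tonumpy(x) for x in xs)

def one_hot(actions, num_classes):
    # Integer action indices -> one-hot rows.
    return np.eye(num_classes, dtype=np.float32)[np.asarray(actions, dtype=np.int64)]

def stack_many(*seqs):
    # Stack each per-step sequence from a rollout into a single numpy array.
    return tuple(np.stack(s) for s in seqs)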