def updateNetwork(self, samples):
    # organize the mini-batch so that we can request "columns" from the data
    # e.g. we can get all of the actions, or all of the states with a single call
    batch = getBatchColumns(samples)

    # compute Q(s, a) for each sample in the mini-batch
    Qs, x = self.policy_net(batch.states)
    Qsa = Qs.gather(1, batch.actions).squeeze()

    # by default Q(s', a') = 0 unless the next states are non-terminal
    Qspap = torch.zeros(batch.size, device=device)

    # if we don't have any non-terminal next states, then there is no need to bootstrap
    if batch.nterm_sp.shape[0] > 0:
        Qsp, _ = self.target_net(batch.nterm_sp)

        # the bootstrapping term is the max Q value for the next state
        # only assign to indices where the next state is non-terminal
        Qspap[batch.nterm] = Qsp.max(1).values

    # compute the empirical MSBE for this mini-batch and let torch's auto-diff optimize it
    # don't worry about detaching the bootstrapping term for semi-gradient Q-learning;
    # the target network handles that
    target = batch.rewards + batch.gamma * Qspap.detach()
    td_loss = 0.5 * f.mse_loss(target, Qsa)

    # make sure we have no gradients left over from the previous update
    self.optimizer.zero_grad()
    self.target_net.zero_grad()

    # compute the entire gradient of the network using only the td error
    td_loss.backward()

    # update the *policy network* using these gradients
    self.optimizer.step()
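# ---------------------------------------------------------------------------
# The update above relies on two pieces that are not shown in this section:
# `getBatchColumns`, which stacks a list of transitions into column tensors
# (states, actions, rewards, gamma, the non-terminal next states `nterm_sp`,
# and the boolean index mask `nterm`), and a value network whose forward pass
# returns both the Q-values and the last-layer features `x`.
# The sketch below is a minimal illustration of what those pieces might look
# like, inferred from how they are used above; the field names, shapes, and
# network architecture are assumptions, not the original implementation.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
from collections import namedtuple

device = torch.device('cpu')

# assumed container; fields mirror how the batch is accessed in updateNetwork
Batch = namedtuple('Batch', 'states actions rewards gamma nterm nterm_sp size')

def getBatchColumns(samples):
    # samples is a list of (s, a, sp, r, gamma) transitions; sp is None when
    # the episode terminated on this transition
    states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s, _, _, _, _ in samples])
    actions = torch.tensor([[a] for _, a, _, _, _ in samples], dtype=torch.long)
    rewards = torch.tensor([r for _, _, _, r, _ in samples], dtype=torch.float32)
    gamma = torch.tensor([g for _, _, _, _, g in samples], dtype=torch.float32)

    # boolean mask over the batch marking non-terminal next states
    nterm = torch.tensor([sp is not None for _, _, sp, _, _ in samples], dtype=torch.bool)
    if nterm.any():
        nterm_sp = torch.stack([torch.as_tensor(sp, dtype=torch.float32)
                                for _, _, sp, _, _ in samples if sp is not None])
    else:
        nterm_sp = torch.empty((0, states.shape[1]))

    return Batch(states, actions, rewards, gamma, nterm, nterm_sp, len(samples))

class QNetwork(nn.Module):
    # returns both Q-values and the last-layer features, matching `Qs, x = net(states)`
    def __init__(self, state_dim, actions, hidden=64):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, actions)

    def forward(self, states):
        x = self.features(states)
        return self.head(x), x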
def update(self, s, a, sp, r, gamma):
    # store the latest transition in the replay buffer
    self.buffer.add((s, a, sp, r, gamma))
    self.steps += 1

    # periodically copy the online weights into the target network
    if self.steps % self.target_refresh == 0 and self.target_refresh > 1:
        cloneNetworkWeights(self.value_net, self.target_net)

    # once the buffer has enough transitions, sample a prioritized mini-batch,
    # update the network, and refresh the priorities with the new TD errors
    if len(self.buffer) > self.batch_size + 1:
        samples, idcs = self.buffer.sample(self.batch_size)
        batch = getBatchColumns(samples)

        predictions = self.forward(batch)
        tde = self.updateNetwork(batch, predictions)

        self.buffer.update_priorities(idcs, tde)
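# ---------------------------------------------------------------------------
# `cloneNetworkWeights` is also assumed rather than shown. In PyTorch the
# natural implementation is to copy the online network's state dict into the
# target network, as in the sketch below. Note as well that the prioritized
# variant of `updateNetwork` called above takes the pre-built batch and its
# predictions, and is expected to return per-sample TD errors so that
# `update_priorities` can reweight the sampled transitions.
# ---------------------------------------------------------------------------
def cloneNetworkWeights(source_net, target_net):
    # overwrite the target network's parameters with the online network's
    target_net.load_state_dict(source_net.state_dict())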
def updateActionNet(self, samples, q_net, target_q_net, optimizer, storeList):
    batch = getBatchColumns(samples)

    # each per-action network outputs a single Q(s, a) for its own action,
    # so the prediction for this mini-batch is just the squeezed output
    Qs, x = q_net(batch.states)
    Qsa = Qs.squeeze()

    # by default Q(s', a') = 0 unless the next states are non-terminal
    Qspap = torch.zeros(batch.size, device=device)

    if batch.nterm_sp.shape[0] > 0:
        # bootstrap with the greedy value over *all* actions: evaluate every
        # per-action target network on the non-terminal next states and take
        # the max across the stacked outputs
        Qsp_back, _ = self.back_target_q_net(batch.nterm_sp)
        Qsp_stay, _ = self.stay_target_q_net(batch.nterm_sp)
        Qsp_forward, _ = self.forward_target_q_net(batch.nterm_sp)
        Qsp = torch.hstack([Qsp_back, Qsp_stay, Qsp_forward])

        # only assign to indices where the next state is non-terminal
        Qspap[batch.nterm] = Qsp.max(1).values

    # compute the empirical MSBE for this mini-batch and let torch's auto-diff optimize it
    # don't worry about detaching the bootstrapping term for semi-gradient Q-learning;
    # the target networks handle that
    target = batch.rewards + batch.gamma * Qspap.detach()
    td_loss = 0.5 * f.mse_loss(target, Qsa)

    # make sure we have no gradients left over from the previous update
    optimizer.zero_grad()
    target_q_net.zero_grad()
    self.back_target_q_net.zero_grad()
    self.stay_target_q_net.zero_grad()
    self.forward_target_q_net.zero_grad()

    # compute the entire gradient of the network using only the td error
    td_loss.backward()

    # log the mean Q value over a fixed reference set of states for monitoring
    Qs_state_array, _ = q_net(self.state_array)
    Qsa_mean_states = torch.mean(Qs_state_array, 0)
    storeList.append(Qsa_mean_states[0].detach().numpy())

    # update this action's network using the td-error gradients
    optimizer.step()
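# ---------------------------------------------------------------------------
# `updateActionNet` is written to be called once per action, each call with
# its own online network, target network, optimizer, and logging list; only
# the three target networks (back / stay / forward) are referenced by name
# inside the method, since the bootstrap max is taken over all of them.
# The driver below is a hypothetical sketch of that calling pattern; the
# online-network, optimizer, and log attribute names are assumptions made
# for illustration and do not appear in the original code.
# ---------------------------------------------------------------------------
def updateAllActionNets(self, samples):
    self.updateActionNet(samples, self.back_q_net, self.back_target_q_net,
                         self.back_optimizer, self.back_q_log)
    self.updateActionNet(samples, self.stay_q_net, self.stay_target_q_net,
                         self.stay_optimizer, self.stay_q_log)
    self.updateActionNet(samples, self.forward_q_net, self.forward_target_q_net,
                         self.forward_optimizer, self.forward_q_log)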
def updateNetwork(self, samples):
    self.updates += 1

    # organize the mini-batch so that we can request "columns" from the data
    # e.g. we can get all of the actions, or all of the states with a single call
    batch = getBatchColumns(samples)

    # compute Q(s, a) for each sample in the mini-batch
    Qs, x = self.policy_net(batch.states)
    Qsa = Qs.gather(1, batch.actions).squeeze()

    # by default Q(s', a') = 0 unless the next states are non-terminal
    Qspap = torch.zeros(len(samples), device=device)

    # if we don't have any non-terminal next states, then there is no need to bootstrap
    if batch.nterm_sp.shape[0] > 0:
        Qsp, _ = self.target_net(batch.nterm_sp)

        # the bootstrapping term is the max Q value for the next state
        # only assign to indices where the next state is non-terminal
        Qspap[batch.nterm] = Qsp.max(1).values

    # compute the empirical MSBE for this mini-batch and let torch's auto-diff optimize it
    # don't worry about detaching the bootstrapping term for semi-gradient Q-learning;
    # the target network handles that
    target = batch.rewards + batch.gamma * Qspap.detach()
    td_loss = 0.5 * f.mse_loss(target, Qsa)

    # compute E[\delta | x] ~= <h, x>
    with torch.no_grad():
        delta_hats = torch.matmul(x, self.h.t())
        delta_hat = delta_hats.gather(1, batch.actions)

    # the gradient correction term is gamma * <h, x> * \nabla_w Q(s', a')
    # to compute this gradient, we use pytorch auto-diff
    correction_loss = torch.mean(batch.gamma * delta_hat * Qspap)

    # make sure we have no gradients left over from the previous update
    self.optimizer.zero_grad()
    self.target_net.zero_grad()

    # compute the entire gradient of the network using only the td error
    td_loss.backward()

    # if we have non-terminal states in the mini-batch,
    # then compute the correction term using the gradient of the *target network*
    if batch.nterm_sp.shape[0] > 0:
        correction_loss.backward()

        # add the gradients of the target network for the correction term
        # to the gradients for the td error
        for (policy_param, target_param) in zip(self.policy_net.parameters(), self.target_net.parameters()):
            policy_param.grad.add_(target_param.grad)

    # update the *policy network* using the combined gradients
    self.optimizer.step()

    # update the secondary weights using a *fixed* feature representation generated by the policy network
    with torch.no_grad():
        delta = target - Qsa
        dh = (delta - delta_hat) * x

        # compute the update for each action independently
        # assume that there is a separate `h` vector for each individual action
        for a in range(self.actions):
            mask = (batch.actions == a).squeeze(1)

            # if this action was never taken in this mini-batch,
            # then skip the update for this action
            if mask.sum() == 0:
                continue

            # the update for `h` minus the regularizer
            h_update = dh[mask].mean(0) - self.beta * self.h[a]

            # ADAM optimizer with bias correction
            # keep a separate set of weights for each action here as well
            self.v[a] = self.beta_2 * self.v[a] + (1 - self.beta_2) * (h_update**2)
            self.m[a] = self.beta_1 * self.m[a] + (1 - self.beta_1) * h_update

            m = self.m[a] / (1 - self.beta_1**self.updates)
            v = self.v[a] / (1 - self.beta_2**self.updates)

            self.h[a] = self.h[a] + self.alpha * m / (torch.sqrt(v) + self.eps)
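# ---------------------------------------------------------------------------
# The secondary-weight update above assumes one `h` vector per action (so
# that E[delta | x, a] ~= <h[a], x>), together with matching per-action ADAM
# moment estimates `m` and `v`, all sized to the last-layer feature dimension.
# The initializer below is a minimal sketch of that state; the argument names
# and the hyper-parameter defaults are assumptions, not values from the source.
# ---------------------------------------------------------------------------
def initSecondaryWeights(self, actions, feature_dim):
    # one row of secondary weights per action: delta_hat(s, a) ~= <h[a], x(s)>
    self.h = torch.zeros((actions, feature_dim), device=device)

    # per-action ADAM moment estimates for the h update
    self.m = torch.zeros((actions, feature_dim), device=device)
    self.v = torch.zeros((actions, feature_dim), device=device)

    # assumed hyper-parameters; self.alpha (the secondary step size) is taken
    # to be the same step-size attribute used in the update above
    self.beta = 1.0      # regularization strength on h
    self.beta_1 = 0.9    # ADAM first-moment decay
    self.beta_2 = 0.999  # ADAM second-moment decay
    self.eps = 1e-8      # ADAM epsilon
    self.updates = 0     # step counter used for bias correction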