def _value_func_estimate(self):
    # Skip the update until the replay buffer holds at least one batch.
    if len(self.buffer) < 32:
        return
    self.target_usage += 1
    transitions = self.buffer.sample(32)
    batch = Transition(*zip(*transitions))
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    next_state_batch = torch.cat(batch.next_state)

    # Q(s, a) for the actions actually taken, and the bootstrapped target
    # r + gamma * max_a' Q_target(s', a').
    state_action_values = self.dqn(state_batch).gather(1, action_batch)
    next_state_values = self.target(next_state_batch).max(1)[0].detach()
    expected_state_action_values = (next_state_values * self.gamma) + reward_batch

    loss = self.loss_fn(
        state_action_values,
        expected_state_action_values.unsqueeze(1),
    )

    self.optimizer.zero_grad()
    loss.backward()
    # Clip gradients element-wise to [-1, 1] before the optimizer step.
    for param in self.dqn.parameters():
        param.grad.data.clamp_(-1, 1)
    self.optimizer.step()

    # Sync the target network every 10 updates.
    if self.target_usage == 10:
        self.target_usage = 0
        self.target.load_state_dict(self.dqn.state_dict())
        self.target.eval()
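
# All of the update routines in this file assume a `Transition` record and a
# replay buffer exposing `push`, `sample`, and `__len__`. The sketch below is a
# minimal, assumed implementation for illustration; the actual definitions in
# the original code base (field order, capacity, class name) may differ.
import random
from collections import deque, namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        # Store one transition; the oldest entries are dropped once full.
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        # Uniform sampling without replacement.
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
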
def learn_from_buffer(self):
    if len(self.training_data) < self.batch_size:
        return
    try:
        loss_ensemble = 0.
        for _ in range(10):
            transitions = self.training_data.sample(self.batch_size)
            batch = Transition(*zip(*transitions))
            non_final_mask = torch.tensor(tuple(
                map(lambda s: s is not None, batch.next_state)),
                                          device=device,
                                          dtype=torch.bool)
            non_final_next_states = torch.cat(
                [s for s in batch.next_state if s is not None])
            state_batch = torch.cat(batch.state)
            action_batch = torch.cat(batch.action)
            reward_batch = torch.cat(batch.reward)

            state_action_values = self.model(state_batch).gather(
                1, action_batch)
            # Bootstrap only from non-terminal next states.
            next_state_values = torch.zeros(self.batch_size,
                                            device=device,
                                            dtype=torch.double)
            next_state_values[non_final_mask] = self.target_net(
                non_final_next_states).max(1)[0].detach()
            expected_state_action_values = self.gamma * next_state_values + reward_batch

            loss = self.loss_fn(state_action_values,
                                expected_state_action_values.unsqueeze(1))
            loss_ensemble += loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            # for param in self.model.parameters():
            #     param.grad.data.clamp_(-1, 1)
            self.optimizer.step()

        # Track an exponential moving average of the summed loss and decay
        # the exploration rate.
        self.running_loss = 0.8 * self.running_loss + 0.2 * loss_ensemble
        self.epsilon = 0.999 * self.epsilon
    except RuntimeError:
        # torch.cat raises a RuntimeError when every sampled next state is
        # terminal (the non-final list is empty).
        print('{}: no non-terminal state'.format(self.agent_name))
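
# `learn_from_buffer` decays `self.epsilon` but the corresponding action
# selection is not shown above. The method below is one plausible
# epsilon-greedy policy consistent with that decay, given as an assumed sketch;
# `self.n_actions` is a hypothetical attribute name, and the module-level
# `random`, `torch`, and `device` are reused.
def select_action(self, state):
    if random.random() < self.epsilon:
        # Explore: uniform random action.
        return torch.tensor([[random.randrange(self.n_actions)]],
                            device=device, dtype=torch.long)
    with torch.no_grad():
        # Exploit: greedy action under the current Q-network.
        return self.model(state).max(1)[1].view(1, 1)
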
def learn_from_buffer(self):
    # Every ensemble member needs at least one full batch before training.
    for i in range(self.ensemble_size):
        if len(self.training_datas[i]) < self.batch_size:
            return
    loss_ensemble = 0.0
    for _ in range(10):
        for i in range(self.ensemble_size):
            transitions = self.training_datas[i].sample(self.batch_size)
            batch = Transition(*zip(*transitions))
            non_final_mask = torch.tensor(tuple(
                map(lambda s: s is not None, batch.next_state)),
                                          device=device,
                                          dtype=torch.bool)
            state_batch = torch.cat(batch.state)
            action_batch = torch.cat(batch.action)
            reward_batch = torch.cat(batch.reward)

            # Q-values of the actions actually taken by ensemble member i.
            state_action_values = self.models[i](state_batch).gather(
                1, action_batch)

            # Bootstrap only from non-terminal next states; if the whole batch
            # is terminal, the target reduces to the rewards.
            next_state_values = torch.zeros(self.batch_size,
                                            device=device,
                                            dtype=torch.double)
            if any(s is not None for s in batch.next_state):
                non_final_next_states = torch.cat(
                    [s for s in batch.next_state if s is not None])
                next_state_values[non_final_mask] = self.target_nets[i](
                    non_final_next_states).max(1)[0].reshape((-1)).detach()
            expected_state_action_values = self.gamma * next_state_values + reward_batch

            loss = self.loss_fn(state_action_values,
                                expected_state_action_values.unsqueeze(1))
            loss_ensemble += loss.item()

            self.optimizers[i].zero_grad()
            loss.backward()
            for param in self.models[i].parameters():
                param.grad.data.clamp_(-1, 1)
            self.optimizers[i].step()

    self.running_loss = 0.8 * self.running_loss + 0.2 * loss_ensemble
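
# The ensemble update above indexes parallel lists of models, target networks,
# optimizers, and replay buffers. The constructor fragment below is an assumed
# sketch of how those lists could be wired up; `QNetwork`, `state_dim`,
# `n_actions`, and `lr` are hypothetical names not taken from the original
# code, and `ReplayBuffer` refers to the sketch earlier in this file.
def _build_ensemble(self, state_dim, n_actions, lr=1e-3):
    self.models, self.target_nets = [], []
    self.optimizers, self.training_datas = [], []
    for _ in range(self.ensemble_size):
        model = QNetwork(state_dim, n_actions).double().to(device)
        target = QNetwork(state_dim, n_actions).double().to(device)
        # Start each target as a copy of its online network; targets are
        # inference-only.
        target.load_state_dict(model.state_dict())
        target.eval()
        self.models.append(model)
        self.target_nets.append(target)
        self.optimizers.append(torch.optim.Adam(model.parameters(), lr=lr))
        self.training_datas.append(ReplayBuffer())
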
def _value_func_estimate(self):
    if len(self.buffer) < 32:
        return
    self.target_usage += 1
    transitions = self.buffer.sample(32)
    batch = Transition(*zip(*transitions))
    state_batch = torch.cat(batch.state)
    reward_batch = torch.cat(batch.reward)

    # Build lower/upper reward bounds from the GP posterior: mean -/+ beta * std.
    reward_l_batch = []
    reward_u_batch = []
    for state in state_batch:
        cur_state = state.tolist()
        reward, std = self.reward_gp.predict(np.array([cur_state]), True)
        reward_l_batch.append(
            torch.tensor([reward[0] - self.beta * std[0]], dtype=torch.double))
        reward_u_batch.append(
            torch.tensor([reward[0] + self.beta * std[0]], dtype=torch.double))
    reward_l_batch = torch.cat(reward_l_batch)
    reward_u_batch = torch.cat(reward_u_batch)

    action_batch = torch.cat(batch.action)
    next_state_batch = torch.cat(batch.next_state)

    # Nominal Q-network trained on the observed rewards.
    state_action_values = self.dqn(state_batch).gather(1, action_batch)
    next_state_values = self.target(next_state_batch).max(1)[0].detach()
    expected_state_action_values = (next_state_values * self.gamma) + reward_batch
    loss = self.loss_fn(
        state_action_values,
        expected_state_action_values.unsqueeze(1),
    )
    self.optimizer.zero_grad()
    loss.backward()
    for param in self.dqn.parameters():
        param.grad.data.clamp_(-1, 1)
    self.optimizer.step()

    # Upper-bound Q-network trained on the optimistic rewards.
    state_action_values = self.dqn_u(state_batch).gather(1, action_batch)
    next_state_values = self.target_u(next_state_batch).max(1)[0].detach()
    expected_state_action_values = (next_state_values * self.gamma) + reward_u_batch
    loss = self.loss_fn(
        state_action_values,
        expected_state_action_values.unsqueeze(1),
    )
    self.optimizer_u.zero_grad()
    loss.backward()
    for param in self.dqn_u.parameters():
        param.grad.data.clamp_(-1, 1)
    self.optimizer_u.step()

    # Lower-bound Q-network trained on the pessimistic rewards.
    state_action_values = self.dqn_l(state_batch).gather(1, action_batch)
    next_state_values = self.target_l(next_state_batch).max(1)[0].detach()
    expected_state_action_values = (next_state_values * self.gamma) + reward_l_batch
    loss = self.loss_fn(
        state_action_values,
        expected_state_action_values.unsqueeze(1),
    )
    self.optimizer_l.zero_grad()
    loss.backward()
    for param in self.dqn_l.parameters():
        param.grad.data.clamp_(-1, 1)
    self.optimizer_l.step()

    # Sync all three target networks every 10 updates.
    if self.target_usage == 10:
        self.target_usage = 0
        self.target.load_state_dict(self.dqn.state_dict())
        self.target.eval()
        self.target_u.load_state_dict(self.dqn_u.state_dict())
        self.target_u.eval()
        self.target_l.load_state_dict(self.dqn_l.state_dict())
        self.target_l.eval()
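
# `self.reward_gp` above is queried as `predict(x, True)`, i.e. a regressor
# returning both a mean and a standard deviation. scikit-learn's
# GaussianProcessRegressor has exactly that signature (its second positional
# argument is `return_std`), so it is a plausible backing model. The fitting
# code below is an assumed sketch: the kernel choice and the `states`/`rewards`
# arrays are illustrative, not taken from the original code.
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel


def fit_reward_gp(states, rewards):
    # states: (n_samples, state_dim) array; rewards: (n_samples,) array.
    kernel = ConstantKernel(1.0) * RBF(length_scale=1.0)
    gp = GaussianProcessRegressor(kernel=kernel, normalize_y=True)
    gp.fit(np.asarray(states), np.asarray(rewards))
    return gp


# Usage matching the call in _value_func_estimate:
#     mean, std = gp.predict(np.array([state_vector]), True)
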
def optimize_model(self):
    """
    optimize model for 1 time step by sampling transitions, calculating
    TD errors, loss and updating the memory transitions' probabilities
    """
    if len(self.memory) < BATCH_SIZE:
        return
    transitions = self.memory.sample(BATCH_SIZE)
    weights = [t[1] for t in transitions]
    positions = [t[2] for t in transitions]
    transitions = [t[0] for t in transitions]
    max_weight = self.memory.max_weight
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(
        map(lambda s: s is not None, batch.next_state)),
                                  device=device,
                                  dtype=torch.bool)
    non_final_next_states = torch.cat(
        [s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    state_action_values = self.policy_net(state_batch).gather(
        1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # DDQN: separate action selection and evaluation
    with torch.no_grad():
        new_actions = self.policy_net(non_final_next_states).max(
            1)[1].view(-1, 1)
        next_state_values[non_final_mask] = self.target_net(
            non_final_next_states).gather(1, new_actions).detach().view(-1)

    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # calculate TD error
    TD_error = (expected_state_action_values.unsqueeze(1) -
                state_action_values).detach().view(-1).cpu().numpy()
    TD_error = np.abs(TD_error)

    # calculate importance sampling weights
    IS_weights = np.divide(weights, max_weight)

    # multiply the weights to the loss function to inject into the weight
    # update: sqrt(w) where the Huber loss is quadratic (|TD error| <= 1) so
    # the squared term is scaled by w, and w directly where the loss is linear
    square_indices = np.where(TD_error <= 1)[0]
    for i, w in enumerate(IS_weights):
        if i in square_indices:
            factor = torch.tensor(np.sqrt(w), device=device, dtype=torch.float32)
        else:
            factor = torch.tensor(w, device=device, dtype=torch.float32)
        state_action_values[i] *= factor
        expected_state_action_values[i] *= factor

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))

    # Optimize the model
    self.optimizer.zero_grad()
    loss.backward()

    # update priorities of the replayed transitions
    self.memory.update(positions, TD_error)

    for param in self.policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    self.optimizer.step()
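
# `optimize_model` expects a prioritized replay memory whose `sample` returns
# (transition, importance_weight, position) triples and which exposes
# `max_weight` and `update(positions, td_errors)`. The class below is a simple
# proportional-prioritization sketch written to match that interface (plain
# lists instead of a sum tree); the `alpha`/`beta` hyperparameter names and the
# capacity default are assumptions, not taken from the original code. It reuses
# the module-level `Transition` and `np`.
class PrioritizedReplayMemory:
    def __init__(self, capacity=10000, alpha=0.6, beta=0.4, eps=1e-5):
        self.capacity = capacity
        self.alpha, self.beta, self.eps = alpha, beta, eps
        self.data = []
        self.priorities = []
        self.pos = 0
        self.max_weight = 1.0

    def push(self, *args):
        # New transitions get the current maximum priority so they are
        # replayed soon after being stored.
        priority = max(self.priorities, default=1.0)
        if len(self.data) < self.capacity:
            self.data.append(Transition(*args))
            self.priorities.append(priority)
        else:
            self.data[self.pos] = Transition(*args)
            self.priorities[self.pos] = priority
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        # Sample indices proportionally to priority^alpha.
        scaled = np.array(self.priorities) ** self.alpha
        probs = scaled / scaled.sum()
        indices = np.random.choice(len(self.data), batch_size, p=probs)
        # Importance-sampling weights w_i = (N * P(i))^(-beta).
        weights = (len(self.data) * probs[indices]) ** (-self.beta)
        self.max_weight = max(self.max_weight, weights.max())
        return [(self.data[i], w, i) for i, w in zip(indices, weights)]

    def update(self, positions, td_errors):
        # New priority is proportional to the absolute TD error.
        for pos, err in zip(positions, td_errors):
            self.priorities[pos] = abs(float(err)) + self.eps

    def __len__(self):
        return len(self.data)
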