def get_target_network_action(self, obs, random_action):
    with torch.no_grad():
        dist_param = self.target_network(H.Variable(obs))
        action = self.target_network.sample(
            dist_param,
            deterministic=(not random_action) or self.require_deterministic)
    return action
def update(self, obs_batch, action_batch, reward_batch, next_obs_batch,
           done_mask):
    obs_batch = self.preprocess_state(
        H.Variable(torch.from_numpy(obs_batch).type(H.float_tensor)))
    next_obs_batch = self.preprocess_state(
        H.Variable(torch.from_numpy(next_obs_batch).type(H.float_tensor)))
    action_batch = H.Variable(torch.from_numpy(action_batch).long())
    reward_batch = H.Variable(torch.from_numpy(reward_batch))
    neg_done_mask = H.Variable(
        torch.from_numpy(1.0 - done_mask).type(H.float_tensor))

    if H.use_cuda:
        action_batch = action_batch.cuda()
        reward_batch = reward_batch.cuda()

    # minimize (Q(s, a; w) - (r + gamma * max Q(s', a'; w')))^2
    q_values = self.q_network(obs_batch).gather(
        1, action_batch.unsqueeze(1))  # Q(s, a; w)
    if self.double_q_learning:
        # online network selects the next action, target network evaluates it
        _, next_state_actions = self.q_network(next_obs_batch).max(
            1, keepdim=True)
        next_max_q_values = self.target_network(next_obs_batch).gather(
            1, next_state_actions).squeeze().detach()
    else:
        next_max_q_values = self.target_network(
            next_obs_batch).detach().max(1)[0]  # max Q(s', a'; w')

    td_error = self.calculate_td_error(q_values, next_max_q_values,
                                       reward_batch, neg_done_mask)
    clipped_td_error = td_error.clamp(-self.clip_error, self.clip_error)
    grad = clipped_td_error * -1.0
    self.optimizer.step(q_values, grad=grad.data)

    self.num_updates += 1
    # target network <- online network
    if self.num_updates % self.update_target_freq == 0:
        self.target_network.load_state_dict(self.q_network.state_dict())
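# For reference, a minimal sketch of the double-DQN target that the
# `double_q_learning` branch above computes, written in plain PyTorch.
# The standalone-function shape and the `discount_factor` argument are
# assumptions for illustration; in the agent this computation is folded
# into `calculate_td_error`.
def double_dqn_target(q_network, target_network, next_obs_batch,
                      reward_batch, neg_done_mask, discount_factor):
    """r + gamma * Q(s', argmax_a Q(s', a; w); w'), zeroed at terminal states."""
    with torch.no_grad():
        # online network selects the greedy next action...
        next_actions = q_network(next_obs_batch).argmax(dim=1, keepdim=True)
        # ...and the target network evaluates it
        next_q = target_network(next_obs_batch).gather(
            1, next_actions).squeeze(1)
    return reward_batch + discount_factor * neg_done_mask * next_q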
def update(self, obs_batch, action_batch, reward_batch, next_obs_batch,
           done_mask):
    obs_batch = self.preprocess_state(
        H.Variable(torch.from_numpy(obs_batch).type(H.float_tensor)))
    next_obs_batch = self.preprocess_state(
        H.Variable(torch.from_numpy(next_obs_batch).type(H.float_tensor)))
    if self.action_spec['type'] == 'int':
        action_batch = H.Variable(torch.from_numpy(action_batch).long())
    else:
        action_batch = H.Variable(torch.from_numpy(action_batch))
    reward_batch = H.Variable(torch.from_numpy(reward_batch))
    neg_done_mask = H.Variable(
        torch.from_numpy(1.0 - done_mask).type(H.float_tensor))

    if H.use_cuda:
        action_batch = action_batch.cuda()
        reward_batch = reward_batch.cuda()

    estimated_rewards = self.estimate_rewards(obs_batch, reward_batch,
                                              neg_done_mask)
    if self.baseline_optimizer is not None:
        estimated_rewards = estimated_rewards.detach()

    # optimize the actor model
    loss_arguments = dict(
        obs_batch=obs_batch,
        action_batch=action_batch,
        reward_batch=estimated_rewards,
        next_obs_batch=next_obs_batch,
        neg_done_mask=neg_done_mask,
    )
    self.optimizer.step(self.total_loss, loss_arguments,
                        fn_reference=self.calculate_reference)

    # optimize the critic model
    if self.baseline_optimizer is not None:
        cumulative_rewards = self.calculate_cumulative_rewards(
            reward_batch, neg_done_mask, self.discount_factor)
        baseline_loss_arguments = dict(
            obs_batch=obs_batch,
            reward=cumulative_rewards,
        )
        self.baseline_optimizer.step(self.baseline.loss,
                                     baseline_loss_arguments,
                                     fn_reference=self.baseline.reference)

    self.num_updates += 1
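# `calculate_cumulative_rewards` is defined elsewhere; a plausible sketch,
# assuming transitions are ordered along rollouts and `neg_done_mask` is 0
# at terminal steps, is the standard backward discounted sum:
def discounted_returns(rewards, neg_done_mask, discount_factor):
    """G_t = r_t + gamma * G_{t+1}, reset across episode boundaries."""
    returns = torch.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(rewards.size(0))):
        # a zero mask cuts the recursion at episode ends
        running = rewards[t] + discount_factor * neg_done_mask[t] * running
        returns[t] = running
    return returns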
def update(self, obs_batch, action_batch, reward_batch, next_obs_batch,
           done_mask):
    obs_batch = self.preprocess_state(
        H.Variable(torch.from_numpy(obs_batch).type(H.float_tensor)))
    next_obs_batch = self.preprocess_state(
        H.Variable(torch.from_numpy(next_obs_batch).type(H.float_tensor)))
    if self.action_spec['type'] == 'int':
        action_batch = H.Variable(torch.from_numpy(action_batch).long())
    else:
        action_batch = H.Variable(torch.from_numpy(action_batch))
    reward_batch = H.Variable(torch.from_numpy(reward_batch))
    neg_done_mask = H.Variable(
        torch.from_numpy(1.0 - done_mask).type(H.float_tensor))

    if H.use_cuda:
        action_batch = action_batch.cuda()
        reward_batch = reward_batch.cuda()

    # predict next actions using the target actor network
    next_target_actions = self.get_target_network_action(
        next_obs_batch, random_action=False)
    # TD targets for the next states from the target critic network
    next_q_values = self.predict_target_q(
        next_obs_batch, next_target_actions, reward_batch,
        neg_done_mask).detach()
    q_values = self.critic_network(obs_batch, action_batch)
    # MSE alternative: critic_loss = (q_values - next_q_values).pow(2).mean()
    critic_loss = F.smooth_l1_loss(q_values, next_q_values)

    # update critic
    self.critic_optimizer.step(critic_loss)

    # update actor
    predicted_actions = self.get_action(
        obs_batch, random_action=False, update=True)
    actor_loss = -self.critic_network(obs_batch, predicted_actions).mean()
    self.optimizer.step(actor_loss)

    self.num_updates += 1
    # target networks <- online networks
    if self.num_updates % self.update_target_freq == 0:
        self.update_target_model(self.target_network, self.network)
        self.update_target_model(self.target_critic_network,
                                 self.critic_network)
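# `update_target_model` is not shown here; in DDPG it is conventionally
# either a hard copy or a Polyak (soft) update. A minimal sketch of both,
# with the `tau` mixing coefficient as an assumption:
def update_target_model(target, source, tau=None):
    """Hard-copy (tau=None) or soft-update target parameters toward source."""
    if tau is None:
        target.load_state_dict(source.state_dict())  # target <- source
        return
    # Polyak averaging: target <- tau * source + (1 - tau) * target
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)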
def get_action(self, obs, random_action, update):
    with torch.no_grad():
        # greedy action: argmax of the Q-values for a single observation
        return self.q_network(H.Variable(obs)).data.max(1)[1].cpu()[0]
def predict(self, obs):
    obs = self.preprocess_state(
        torch.from_numpy(obs).type(H.float_tensor).unsqueeze(0))  # add batch dim
    return self.q_network(H.Variable(obs)).data
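# Usage sketch: `predict` returns raw Q-values for one observation, so greedy
# evaluation reduces to an argmax. The `agent` handle and the Gym-style `env`
# below are assumptions for illustration.
obs = env.reset()                               # np.ndarray observation
q_values = agent.predict(obs)                   # shape: (1, num_actions)
greedy_action = int(q_values.max(1)[1].item())  # argmax over actions
obs, reward, done, info = env.step(greedy_action)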