def compute_actor_loss(self, batch): """Compute loss for actor. Preconditions: q_function must have seen up to s_{t-1} and s_{t-1}. policy must have seen up to s_{t-1}. Preconditions: q_function must have seen up to s_t and s_t. policy must have seen up to s_t. """ batch_state = batch['state'] batch_action = batch['action'] batch_size = len(batch_action) # Estimated policy observes s_t onpolicy_actions = self.policy(batch_state, test=False).sample() # Q(s_t, mu(s_t)) is evaluated. # This should not affect the internal state of Q. with state_kept(self.q_function): q = self.q_function(batch_state, onpolicy_actions, test=False) # Estimated Q-function observes s_t and a_t if isinstance(self.q_function, Recurrent): self.q_function.update_state(batch_state, batch_action, test=False) # Since we want to maximize Q, loss is negation of Q loss = -F.sum(q) / batch_size # Update stats self.average_actor_loss *= self.average_loss_decay self.average_actor_loss += ((1 - self.average_loss_decay) * float(loss.data)) return loss
def _compute_y_and_t(self, exp_batch):
    batch_state = exp_batch['state']
    batch_size = len(exp_batch['reward'])

    qout = self.q_function(batch_state)

    batch_actions = exp_batch['action']
    batch_q = qout.evaluate_actions(batch_actions)

    # Compute target values
    with chainer.no_backprop_mode():
        target_qout = self.target_q_function(batch_state)

        batch_next_state = exp_batch['next_state']

        with state_kept(self.target_q_function):
            target_next_qout = self.target_q_function(batch_next_state)
        next_q_max = F.reshape(target_next_qout.max, (batch_size,))

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']

        # T Q: Bellman operator
        t_q = batch_rewards + exp_batch['discount'] * \
            (1.0 - batch_terminal) * next_q_max

        # T_AL Q: advantage learning operator
        cur_advantage = F.reshape(
            target_qout.compute_advantage(batch_actions), (batch_size,))
        tal_q = t_q + self.alpha * cur_advantage

    return batch_q, tal_q

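# For reference, the advantage-learning (AL) target above can be written
# without any Chainer machinery. This is a minimal NumPy sketch, not the
# library code: `q_target` and `q_target_next` are hypothetical arrays
# holding Q'(s, .) and Q'(s', .) for a batch, shape (batch_size, n_actions).
import numpy as np


def al_target_sketch(q_target, q_target_next, actions, rewards, terminal,
                     discount, alpha):
    """T_AL Q = r + discount * (1 - terminal) * max_a Q'(s', a)
                + alpha * (Q'(s, a) - max_a Q'(s, a))."""
    batch = np.arange(len(actions))
    next_q_max = q_target_next.max(axis=1)
    # T Q: Bellman operator
    t_q = rewards + discount * (1.0 - terminal) * next_q_max
    # Advantage of the taken action under the target network
    cur_advantage = q_target[batch, actions] - q_target.max(axis=1)
    return t_q + alpha * cur_advantage
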
def update_on_policy(self, statevar):
    assert self.t_start < self.t

    if not self.disable_online_update:
        next_values = {}
        for t in range(self.t_start + 1, self.t):
            next_values[t - 1] = self.past_values[t]
        if statevar is None:
            next_values[self.t - 1] = chainer.Variable(
                self.xp.zeros_like(self.past_values[self.t - 1].array))
        else:
            with state_kept(self.model):
                _, v = self.model(statevar)
            next_values[self.t - 1] = v
        log_probs = {
            t: self.past_action_distrib[t].log_prob(
                self.xp.asarray(self.xp.expand_dims(a, 0)))
            for t, a in self.past_actions.items()}

        self.online_batch_losses.append(self.compute_loss(
            t_start=self.t_start, t_stop=self.t,
            rewards=self.past_rewards,
            values=self.past_values,
            next_values=next_values,
            log_probs=log_probs))
        if len(self.online_batch_losses) == self.batchsize:
            loss = chainerrl.functions.sum_arrays(
                self.online_batch_losses) / self.batchsize
            self.update(loss)
            self.online_batch_losses = []

    self.init_history_data_for_online_update()

def update_on_policy(self, statevar):
    assert self.t_start < self.t

    if not self.disable_online_update:
        if statevar is None:
            R = 0
        else:
            with chainer.no_backprop_mode():
                with state_kept(self.model):
                    action_distrib, action_value = self.model(statevar)
                    v = compute_state_value_as_expected_action_value(
                        action_value, action_distrib)
                    R = v
        self.update(
            t_start=self.t_start, t_stop=self.t, R=R,
            states=self.past_states,
            actions=self.past_actions,
            rewards=self.past_rewards,
            values=self.past_values,
            action_values=self.past_action_values,
            action_log_probs=self.past_action_log_prob,
            action_distribs=self.past_action_distrib,
            avg_action_distribs=self.past_avg_action_distrib)

    self.init_history_data_for_online_update()

def _compute_target_values(self, exp_batch):
    """Compute a batch of target return distributions."""
    batch_next_state = exp_batch['next_state']
    batch_rewards = exp_batch['reward']
    batch_terminal = exp_batch['is_state_terminal']

    with chainer.using_config('train', False), state_kept(self.q_function):
        next_qout = self.q_function(batch_next_state)

    target_next_qout = self.target_q_function(batch_next_state)

    next_q_max = target_next_qout.evaluate_actions(
        next_qout.greedy_actions)

    batch_size = batch_rewards.shape[0]
    z_values = target_next_qout.z_values
    n_atoms = z_values.size

    # next_q_max: (batch_size, n_atoms)
    next_q_max = target_next_qout.max_as_distribution.array
    assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

    # Tz: (batch_size, n_atoms)
    Tz = (batch_rewards[..., None]
          + (1.0 - batch_terminal[..., None])
          * self.xp.expand_dims(exp_batch['discount'], 1)
          * z_values[None])
    return _apply_categorical_projection(Tz, next_q_max, z_values)

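# The projection step above is delegated to _apply_categorical_projection.
# Below is a rough NumPy sketch of that categorical (C51) projection, offered
# only as an illustration of the idea; it assumes an evenly spaced support
# z_values and is not the function the agent actually calls.
import numpy as np


def categorical_projection_sketch(Tz, probs, z_values):
    """Project the distribution (Tz, probs) onto the fixed support z_values.

    Tz, probs: (batch_size, n_atoms); z_values: (n_atoms,), evenly spaced.
    """
    v_min, v_max = z_values[0], z_values[-1]
    n_atoms = z_values.size
    delta_z = (v_max - v_min) / (n_atoms - 1)

    # Clip the shifted atoms to the support and find their fractional index.
    Tz = np.clip(Tz, v_min, v_max)
    b = (Tz - v_min) / delta_z
    l = np.floor(b).astype(np.int64)
    u = np.ceil(b).astype(np.int64)

    # Split each atom's mass between its two neighbouring bins.
    w_l = u - b
    w_u = b - l
    # If Tz lands exactly on an atom (l == u), give that atom all of the mass.
    w_l = np.where(l == u, 1.0, w_l)

    projected = np.zeros_like(probs)
    for i in range(probs.shape[0]):
        np.add.at(projected[i], l[i], probs[i] * w_l[i])
        np.add.at(projected[i], u[i], probs[i] * w_u[i])
    return projected
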
def _compute_y_and_t(self, exp_batch, gamma):
    batch_state = exp_batch['state']
    batch_size = len(exp_batch['reward'])

    qout = self.q_function(batch_state, test=False)

    batch_actions = exp_batch['action']
    batch_q = qout.evaluate_actions(batch_actions)

    # Compute target values
    with chainer.no_backprop_mode():
        target_qout = self.target_q_function(batch_state, test=True)

        batch_next_state = exp_batch['next_state']

        with state_kept(self.q_function):
            next_qout = self.q_function(batch_next_state, test=False)

        with state_kept(self.target_q_function):
            target_next_qout = self.target_q_function(
                batch_next_state, test=True)

        next_q_max = F.reshape(
            target_next_qout.evaluate_actions(next_qout.greedy_actions),
            (batch_size,))

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']

        # T Q: Bellman operator
        t_q = batch_rewards + self.gamma * \
            (1.0 - batch_terminal) * next_q_max

        # T_PAL Q: persistent advantage learning operator
        cur_advantage = F.reshape(
            target_qout.compute_advantage(batch_actions), (batch_size,))
        next_advantage = F.reshape(
            target_next_qout.compute_advantage(batch_actions), (batch_size,))
        tpal_q = t_q + self.alpha * \
            F.maximum(cur_advantage, next_advantage)

    return batch_q, tpal_q

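# The persistent advantage learning (PAL) target only differs from the AL
# sketch above in taking the larger of the current and next-state advantages.
# Again a NumPy illustration with hypothetical array arguments, not the
# agent's code: `q_next_online` is the online network's Q(s', .), used here
# to pick the greedy next action as in the method above.
import numpy as np


def pal_target_sketch(q_target, q_target_next, q_next_online, actions,
                      rewards, terminal, discount, alpha):
    """T_PAL Q = T Q + alpha * max(A'(s, a), A'(s', a)),
    where A'(x, a) = Q'(x, a) - max_b Q'(x, b)."""
    batch = np.arange(len(actions))
    greedy_next = q_next_online.argmax(axis=1)
    next_q_max = q_target_next[batch, greedy_next]
    # T Q: Bellman operator
    t_q = rewards + discount * (1.0 - terminal) * next_q_max
    # Advantages of the taken action in s and s' under the target network
    cur_advantage = q_target[batch, actions] - q_target.max(axis=1)
    next_advantage = q_target_next[batch, actions] - q_target_next.max(axis=1)
    return t_q + alpha * np.maximum(cur_advantage, next_advantage)
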
def compute_critic_loss(self, batch):
    """Compute loss for critic.

    Preconditions:
      target_q_function must have seen up to s_t and a_t.
      target_policy must have seen up to s_t.
      q_function must have seen up to s_{t-1}.
    Postconditions:
      target_q_function must have seen up to s_{t+1} and a_{t+1}.
      target_policy must have seen up to s_{t+1}.
      q_function must have seen up to s_t.
    """
    batch_next_state = batch['next_state']
    batch_rewards = batch['reward']
    batch_terminal = batch['is_state_terminal']
    batch_state = batch['state']
    batch_actions = batch['action']
    batch_next_actions = batch['next_action']
    batchsize = len(batch_rewards)

    with chainer.no_backprop_mode():
        # Target policy observes s_{t+1}
        next_actions = self.target_policy(
            batch_next_state, test=True).sample()

        # Q(s_{t+1}, mu(s_{t+1})) is evaluated.
        # This should not affect the internal state of Q.
        with state_kept(self.target_q_function):
            next_q = self.target_q_function(
                batch_next_state, next_actions, test=True)

        # Target Q-function observes s_{t+1} and a_{t+1}
        if isinstance(self.target_q_function, Recurrent):
            self.target_q_function.update_state(
                batch_next_state, batch_next_actions, test=True)

        target_q = batch_rewards + self.gamma * \
            (1.0 - batch_terminal) * F.reshape(next_q, (batchsize,))

    # Estimated Q-function observes s_t and a_t
    predict_q = F.reshape(
        self.q_function(batch_state, batch_actions, test=False),
        (batchsize,))

    loss = F.mean_squared_error(target_q, predict_q)

    # Update stats
    self.average_critic_loss *= self.average_loss_decay
    self.average_critic_loss += ((1 - self.average_loss_decay) *
                                 float(loss.data))
    return loss

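# The critic target above is the usual DDPG one-step target. As a minimal
# NumPy sketch of the loss (illustrative only; array names are hypothetical):
# `q` is Q(s_t, a_t) from the online critic and `next_q` is
# Q'(s_{t+1}, mu'(s_{t+1})) from the target critic, both of shape (batch,).
import numpy as np


def ddpg_critic_loss_sketch(q, next_q, rewards, terminal, gamma):
    """Mean squared Bellman error used as the critic loss."""
    target_q = rewards + gamma * (1.0 - terminal) * next_q
    return np.mean((target_q - q) ** 2)
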
def _compute_target_values(self, exp_batch, gamma):
    batch_next_state = exp_batch['next_state']

    with state_kept(self.q_function):
        next_qout = self.q_function(batch_next_state, test=True)

    target_next_qout = self.target_q_function(batch_next_state, test=True)

    next_q_max = target_next_qout.evaluate_actions(
        next_qout.greedy_actions)

    batch_rewards = exp_batch['reward']
    batch_terminal = exp_batch['is_state_terminal']

    return batch_rewards + self.gamma * (1.0 - batch_terminal) * next_q_max

def _compute_target_values(self, exp_batch, gamma):
    batch_next_state = exp_batch['next_state']

    with chainer.using_config('train', False), state_kept(self.q_function):
        next_qout = self.q_function(batch_next_state)

    target_next_qout = self.target_q_function(batch_next_state)

    next_q_max = target_next_qout.evaluate_actions(
        next_qout.greedy_actions)

    batch_rewards = exp_batch['reward']
    batch_terminal = exp_batch['is_state_terminal']

    return batch_rewards + self.gamma * (1.0 - batch_terminal) * next_q_max

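# Both _compute_target_values variants above implement the Double DQN target:
# the online network selects the greedy next action and the target network
# evaluates it. A minimal NumPy sketch (illustration only, hypothetical names):
import numpy as np


def double_dqn_target_sketch(q_next_online, q_next_target, rewards, terminal,
                             gamma):
    """y = r + gamma * (1 - terminal) * Q'(s', argmax_a Q(s', a))."""
    batch = np.arange(len(rewards))
    greedy_next = q_next_online.argmax(axis=1)      # chosen by the online net
    next_q_max = q_next_target[batch, greedy_next]  # evaluated by the target net
    return rewards + gamma * (1.0 - terminal) * next_q_max
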
def update(self, statevar):
    assert self.t_start < self.t

    # Update
    if statevar is None:
        R = 0
    else:
        with state_kept(self.target_q_function):
            R = float(self.target_q_function(statevar).max.data)

    loss = 0
    for i in reversed(range(self.t_start, self.t)):
        R *= self.gamma
        R += self.past_rewards[i]
        q = F.reshape(self.past_action_values[i], (1, 1))
        # Accumulate gradients of Q-function
        loss += F.sum(F.huber_loss(
            q,
            chainer.Variable(np.asarray([[R]], dtype=np.float32)),
            delta=1.0))

    # Do we need to normalize losses by (self.t - self.t_start)?
    # Otherwise, loss scales can be different in case of self.t_max
    # and in case of termination.

    # I'm not sure but if we need to normalize losses...
    # loss /= self.t - self.t_start

    # Compute gradients using thread-specific model
    self.q_function.zerograds()
    loss.backward()
    # Copy the gradients to the globally shared model
    self.shared_q_function.zerograds()
    copy_param.copy_grad(self.shared_q_function, self.q_function)
    # Update the globally shared model
    self.optimizer.update()

    self.sync_parameters()
    if isinstance(self.q_function, Recurrent):
        self.q_function.unchain_backward()

    self.past_action_values = {}
    self.past_states = {}
    self.past_rewards = {}

    self.t_start = self.t

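# The backward loop above unrolls the n-step return R_i = r_i + gamma * R_{i+1},
# bootstrapped from the target network's max Q at the final state. A plain
# Python sketch of that recursion (illustrative; uses a simple list instead of
# the agent's history dicts):
def nstep_returns_sketch(rewards, gamma, bootstrap_value=0.0):
    """Return [R_0, ..., R_{n-1}] for a rollout of n rewards."""
    R = bootstrap_value
    returns = []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.append(R)
    return list(reversed(returns))
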
def compute_actor_loss(self, batch):
    """Compute loss for actor.

    Preconditions:
      q_function must have seen up to s_{t-1} and a_{t-1}.
      policy must have seen up to s_{t-1}.
    Postconditions:
      q_function must have seen up to s_t and a_t.
      policy must have seen up to s_t.
    """
    batch_state = batch['state']
    batch_action = batch['action']
    batch_size = len(batch_action)

    # Estimated policy observes s_t
    onpolicy_actions = self.policy(batch_state).sample()

    # Q(s_t, mu(s_t)) is evaluated.
    # This should not affect the internal state of Q.
    with state_kept(self.q_function):
        q = self.q_function(batch_state, onpolicy_actions)

    # Estimated Q-function observes s_t and a_t
    if isinstance(self.q_function, Recurrent):
        self.q_function.update_state(batch_state, batch_action)

    # Avoid the numpy #9165 bug (see also: chainer #2744)
    q = q[:, :]

    # Since we want to maximize Q, loss is negation of Q
    loss = -F.sum(q) / batch_size

    if self.l2_action_penalty:
        # Penalize large actions; sum so the penalty stays a scalar
        loss += self.l2_action_penalty \
            * F.sum(F.square(onpolicy_actions)) / batch_size

    # Update stats
    self.average_actor_loss *= self.average_loss_decay
    self.average_actor_loss += ((1 - self.average_loss_decay) *
                                float(loss.array))
    return loss

def update_on_policy(self, statevar):
    assert self.t_start < self.t

    if not self.disable_online_update:
        if statevar is None:
            R = 0
        else:
            with chainer.no_backprop_mode():
                with state_kept(self.model):
                    action_distrib, action_value, v = self.model(statevar)
            R = float(v.data)
        self.update(
            t_start=self.t_start, t_stop=self.t, R=R,
            states=self.past_states,
            actions=self.past_actions,
            rewards=self.past_rewards,
            values=self.past_values,
            action_values=self.past_action_values,
            action_distribs=self.past_action_distrib,
            action_distribs_mu=None,
            avg_action_distribs=self.past_avg_action_distrib)

    self.init_history_data_for_online_update()

def update(self, statevar):
    assert self.t_start < self.t

    if statevar is None:
        R = 0
    else:
        with state_kept(self.model):
            _, vout, __ = self.model.pi_and_v(statevar)
        # Keep R as a float32 array so batched state values broadcast through
        # the rollout (original scalar form: R = float(vout.data)).
        R = F.cast(vout.data, 'float32')

    pi_loss = 0
    v_loss = 0
    for i in reversed(range(self.t_start, self.t)):
        R *= self.gamma
        R += self.past_rewards[i]
        if self.use_average_reward:
            R -= self.average_reward
        v = self.past_values[i]
        advantage = R - v
        if self.use_average_reward:
            self.average_reward += self.average_reward_tau * \
                float(advantage.data)
        # Accumulate gradients of policy
        log_prob = self.past_action_log_prob[i]
        entropy = self.past_action_entropy[i]

        # Log probability is increased proportionally to advantage;
        # the advantage is detached so no gradient flows through it
        # (original scalar form: pi_loss -= log_prob * float(advantage.data)).
        pi_loss -= log_prob * F.cast(advantage.data, 'float32')
        # Entropy is maximized
        pi_loss -= self.beta * entropy
        # Accumulate gradients of value function
        v_loss += (v - R) ** 2 / 2

    if self.pi_loss_coef != 1.0:
        pi_loss *= self.pi_loss_coef

    if self.v_loss_coef != 1.0:
        v_loss *= self.v_loss_coef

    # Normalize the loss of sequences truncated by terminal states
    if self.keep_loss_scale_same and \
            self.t - self.t_start < self.t_max:
        factor = self.t_max / (self.t - self.t_start)
        pi_loss *= factor
        v_loss *= factor

    if self.normalize_grad_by_t_max:
        pi_loss /= self.t - self.t_start
        v_loss /= self.t - self.t_start

    if self.process_idx == 0:
        logger.debug('pi_loss:%s v_loss:%s', pi_loss.data, v_loss.data)

    # Reduce to a scalar loss so backward() needs no explicit output gradient
    # (original form: total_loss = pi_loss + F.reshape(v_loss, pi_loss.data.shape)).
    total_loss = F.mean(pi_loss + F.reshape(v_loss, pi_loss.data.shape))

    # Compute gradients using thread-specific model
    self.model.zerograds()
    total_loss.backward()
    # Copy the gradients to the globally shared model
    self.shared_model.zerograds()
    copy_param.copy_grad(
        target_link=self.shared_model, source_link=self.model)
    # Update the globally shared model
    if self.process_idx == 0:
        norm = sum(np.sum(np.square(param.grad))
                   for param in self.optimizer.target.params())
        logger.debug('grad norm:%s', norm)
    self.optimizer.update()
    if self.process_idx == 0:
        logger.debug('update')

    self.sync_parameters()
    if isinstance(self.model, Recurrent):
        self.model.unchain_backward()

    self.past_action_log_prob = {}
    self.past_action_entropy = {}
    self.past_states = {}
    self.past_rewards = {}
    self.past_values = {}

    self.t_start = self.t

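# The loop above is the standard advantage actor-critic update: returns are
# bootstrapped backwards through the rollout, the policy loss weights
# log-probabilities by the detached advantage plus an entropy bonus, and the
# value loss is a halved squared error. A compact plain-Python sketch of the
# same loss for a single rollout (hedged illustration, not the agent code):
def a3c_loss_sketch(rewards, values, log_probs, entropies, gamma, beta,
                    bootstrap_value=0.0):
    """pi_loss = -sum_i (log pi(a_i|s_i) * A_i + beta * H_i),
    v_loss = sum_i (V(s_i) - R_i)^2 / 2, with R_i = r_i + gamma * R_{i+1}."""
    R = bootstrap_value
    pi_loss, v_loss = 0.0, 0.0
    for r, v, logp, ent in zip(reversed(rewards), reversed(values),
                               reversed(log_probs), reversed(entropies)):
        R = r + gamma * R
        advantage = R - v  # treated as a constant in the policy term
        pi_loss -= logp * advantage + beta * ent
        v_loss += (v - R) ** 2 / 2
    return pi_loss + v_loss
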