def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma,
                     sess):
    """Emulates SAC loss functions for tf and torch."""
    # ks:
    # 0=log_alpha
    # 1=target log-alpha (not used)

    # 2=action hidden bias
    # 3=action hidden kernel
    # 4=action out bias
    # 5=action out kernel

    # 6=Q hidden bias
    # 7=Q hidden kernel
    # 8=Q out bias
    # 9=Q out kernel

    # 14=target Q hidden bias
    # 15=target Q hidden kernel
    # 16=target Q out bias
    # 17=target Q out kernel

    alpha = np.exp(log_alpha)
    cls = TorchSquashedGaussian if fw == "torch" else SquashedGaussian
    model_out_t = train_batch[SampleBatch.CUR_OBS]
    model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]
    target_model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]

    # Get policy output for t (get_policy_output).
    action_dist_t = cls(
        fc(
            relu(
                fc(model_out_t, weights[ks[3]], weights[ks[2]],
                   framework=fw)),
            weights[ks[5]], weights[ks[4]]), None)
    policy_t = action_dist_t.deterministic_sample()
    log_pis_t = action_dist_t.logp(policy_t)
    if sess:
        log_pis_t = sess.run(log_pis_t)
        policy_t = sess.run(policy_t)
    log_pis_t = np.expand_dims(log_pis_t, -1)

    # Get policy output for t+1.
    action_dist_tp1 = cls(
        fc(
            relu(
                fc(model_out_tp1, weights[ks[3]], weights[ks[2]],
                   framework=fw)),
            weights[ks[5]], weights[ks[4]]), None)
    policy_tp1 = action_dist_tp1.deterministic_sample()
    log_pis_tp1 = action_dist_tp1.logp(policy_tp1)
    if sess:
        log_pis_tp1 = sess.run(log_pis_tp1)
        policy_tp1 = sess.run(policy_tp1)
    log_pis_tp1 = np.expand_dims(log_pis_tp1, -1)

    # Q-values for the actually selected actions (get_q_values).
    q_t = fc(
        relu(
            fc(np.concatenate(
                [model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
               weights[ks[7]], weights[ks[6]], framework=fw)),
        weights[ks[9]], weights[ks[8]], framework=fw)

    # Q-values for current policy in given current state (get_q_values).
    q_t_det_policy = fc(
        relu(
            fc(np.concatenate([model_out_t, policy_t], -1),
               weights[ks[7]], weights[ks[6]], framework=fw)),
        weights[ks[9]], weights[ks[8]], framework=fw)

    # Target Q network evaluation (target_model.get_q_values).
    q_tp1 = fc(
        relu(
            fc(np.concatenate([target_model_out_tp1, policy_tp1], -1),
               weights[ks[15]], weights[ks[14]], framework=fw)),
        weights[ks[17]], weights[ks[16]], framework=fw)

    q_t_selected = np.squeeze(q_t, axis=-1)
    q_tp1 -= alpha * log_pis_tp1
    q_tp1_best = np.squeeze(q_tp1, axis=-1)

    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    if fw == "torch":
        dones = dones.float().numpy()
        rewards = rewards.numpy()

    q_tp1_best_masked = (1.0 - dones) * q_tp1_best
    q_t_selected_target = rewards + gamma * q_tp1_best_masked

    base_td_error = np.abs(q_t_selected - q_t_selected_target)
    td_error = base_td_error
    critic_loss = [
        0.5 * np.mean(np.power(q_t_selected_target - q_t_selected, 2.0))
    ]
    target_entropy = -np.prod((1, ))
    alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy))
    actor_loss = np.mean(alpha * log_pis_t - q_t_det_policy)

    return critic_loss, actor_loss, alpha_loss, td_error
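

# The helper above mirrors, step by step, what SAC's tf/torch loss functions
# compute. As a compact reference, the sketch below (an illustrative
# assumption, not part of the original helper) collapses those steps into the
# plain-numpy formulas being emulated: the entropy-regularized Bellman target
# plus the resulting critic, actor, and alpha losses. It assumes `numpy as np`
# and pre-computed Q-/policy outputs shaped like those produced above; the
# function name and signature are hypothetical.
def _sac_losses_from_quantities(rewards, dones, gamma, alpha, log_alpha,
                                q_t_selected, q_tp1, log_pis_t, log_pis_tp1,
                                q_t_det_policy, target_entropy=-1.0):
    """Hypothetical sketch: SAC losses from pre-computed quantities."""
    # Entropy-regularized target:
    # r + gamma * (1 - done) * (Q_target(s', a') - alpha * log pi(a'|s')).
    q_tp1_best = np.squeeze(q_tp1 - alpha * log_pis_tp1, axis=-1)
    q_t_selected_target = rewards + gamma * (1.0 - dones) * q_tp1_best
    td_error = np.abs(q_t_selected - q_t_selected_target)
    # Critic: 0.5 * MSE between the selected Q and the Bellman target.
    critic_loss = 0.5 * np.mean(
        np.power(q_t_selected_target - q_t_selected, 2.0))
    # Temperature: push log_pi towards -target_entropy.
    alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy))
    # Actor: maximize Q of the policy's own actions minus the entropy penalty.
    actor_loss = np.mean(alpha * log_pis_t - q_t_det_policy)
    return critic_loss, actor_loss, alpha_loss, td_error
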
def _ddpg_loss_helper(self, train_batch, weights, ks, fw, gamma,
                      huber_threshold, l2_reg, sess):
    """Emulates DDPG loss functions for tf and torch."""
    model_out_t = train_batch[SampleBatch.CUR_OBS]
    target_model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]

    # Get policy output for t (get_policy_output).
    policy_t = sigmoid(2.0 * fc(
        relu(
            fc(model_out_t, weights[ks[1]], weights[ks[0]], framework=fw)),
        weights[ks[5]], weights[ks[4]]))

    # Get policy output for t+1 (target model).
    policy_tp1 = sigmoid(2.0 * fc(
        relu(
            fc(target_model_out_tp1, weights[ks[3]], weights[ks[2]],
               framework=fw)),
        weights[ks[7]], weights[ks[6]]))

    # Assume no smooth target policy.
    policy_tp1_smoothed = policy_tp1

    # Q-values for the actually selected actions (get_q_values).
    q_t = fc(
        relu(
            fc(np.concatenate(
                [model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
               weights[ks[9]], weights[ks[8]], framework=fw)),
        weights[ks[11]], weights[ks[10]], framework=fw)
    twin_q_t = fc(
        relu(
            fc(np.concatenate(
                [model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
               weights[ks[13]], weights[ks[12]], framework=fw)),
        weights[ks[15]], weights[ks[14]], framework=fw)

    # Q-values for current policy in given current state (get_q_values).
    q_t_det_policy = fc(
        relu(
            fc(np.concatenate([model_out_t, policy_t], -1),
               weights[ks[9]], weights[ks[8]], framework=fw)),
        weights[ks[11]], weights[ks[10]], framework=fw)

    # Target Q network evaluation (target_model.get_q_values).
    q_tp1 = fc(
        relu(
            fc(np.concatenate([target_model_out_tp1, policy_tp1_smoothed],
                              -1),
               weights[ks[17]], weights[ks[16]], framework=fw)),
        weights[ks[19]], weights[ks[18]], framework=fw)
    twin_q_tp1 = fc(
        relu(
            fc(np.concatenate([target_model_out_tp1, policy_tp1_smoothed],
                              -1),
               weights[ks[21]], weights[ks[20]], framework=fw)),
        weights[ks[23]], weights[ks[22]], framework=fw)

    q_t_selected = np.squeeze(q_t, axis=-1)
    twin_q_t_selected = np.squeeze(twin_q_t, axis=-1)
    q_tp1 = np.minimum(q_tp1, twin_q_tp1)
    q_tp1_best = np.squeeze(q_tp1, axis=-1)

    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    if fw == "torch":
        dones = dones.float().numpy()
        rewards = rewards.numpy()

    q_tp1_best_masked = (1.0 - dones) * q_tp1_best
    q_t_selected_target = rewards + gamma * q_tp1_best_masked

    td_error = q_t_selected - q_t_selected_target
    twin_td_error = twin_q_t_selected - q_t_selected_target
    td_error = td_error + twin_td_error
    errors = huber_loss(td_error, huber_threshold) + \
        huber_loss(twin_td_error, huber_threshold)

    critic_loss = np.mean(errors)
    actor_loss = -np.mean(q_t_det_policy)

    # Add l2-regularization if required.
    for name, var in weights.items():
        if re.match("default_policy/actor_(hidden_0|out)/kernel", name):
            actor_loss += (l2_reg * l2_loss(var))
        elif re.match("default_policy/sequential(_1)?/\\w+/kernel", name):
            critic_loss += (l2_reg * l2_loss(var))

    return critic_loss, actor_loss, td_error
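

# As with the SAC helper, the DDPG/TD3-style computation above can be
# summarized by a short plain-numpy sketch (an illustrative assumption, not
# part of the original helper): a clipped double-Q Bellman target, Huber
# losses over the two TD errors, and a deterministic-policy-gradient actor
# loss. It assumes `numpy as np`, the same `huber_loss` helper used above,
# and pre-computed Q-values; the function name and signature are hypothetical.
def _ddpg_losses_from_quantities(rewards, dones, gamma, huber_threshold,
                                 q_t_selected, twin_q_t_selected, q_tp1,
                                 twin_q_tp1, q_t_det_policy):
    """Hypothetical sketch: DDPG/TD3 losses from pre-computed Q-values."""
    # Clipped double-Q target: r + gamma * (1 - done) * min(Q1', Q2').
    q_tp1_best = np.squeeze(np.minimum(q_tp1, twin_q_tp1), axis=-1)
    q_t_selected_target = rewards + gamma * (1.0 - dones) * q_tp1_best
    # TD errors for both critics; the combined error is what gets reported.
    td_error = q_t_selected - q_t_selected_target
    twin_td_error = twin_q_t_selected - q_t_selected_target
    combined_td_error = td_error + twin_td_error
    # Huber-smoothed critic loss over both error terms, as in the helper above
    # (excluding the optional l2 regularization).
    critic_loss = np.mean(
        huber_loss(combined_td_error, huber_threshold) +
        huber_loss(twin_td_error, huber_threshold))
    # Actor ascends the first critic's estimate of its own actions.
    actor_loss = -np.mean(q_t_det_policy)
    return critic_loss, actor_loss, combined_td_error
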