def ddpg_step(policy_net, policy_net_target, value_net, value_net_target,
              optimizer_policy, optimizer_value, states, actions, rewards,
              next_states, masks, gamma, polyak):
    masks = masks.unsqueeze(-1)
    rewards = rewards.unsqueeze(-1)

    """update critic"""
    values = value_net(states, actions)
    with torch.no_grad():
        target_next_values = value_net_target(next_states,
                                              policy_net_target(next_states))
        target_values = rewards + gamma * masks * target_next_values
    value_loss = nn.MSELoss()(values, target_values)

    optimizer_value.zero_grad()
    value_loss.backward()
    optimizer_value.step()

    """update actor"""
    policy_loss = -value_net(states, policy_net(states)).mean()

    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()

    """soft update target nets"""
    policy_net_flat_params = get_flat_params(policy_net)
    policy_net_target_flat_params = get_flat_params(policy_net_target)
    set_flat_params(policy_net_target,
                    polyak * policy_net_target_flat_params +
                    (1 - polyak) * policy_net_flat_params)

    value_net_flat_params = get_flat_params(value_net)
    value_net_target_flat_params = get_flat_params(value_net_target)
    set_flat_params(value_net_target,
                    polyak * value_net_target_flat_params +
                    (1 - polyak) * value_net_flat_params)

    return value_loss, policy_loss
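# Hedged sketch (not part of the original code): ddpg_step and the other update
# functions in this section rely on flat-parameter helpers (get_flat_params,
# set_flat_params, get_flat_grad_params) from the repo's utilities. The snippet
# below is a minimal illustration of what such helpers usually look like; the
# actual implementations may differ in device and dtype handling.
import torch


def get_flat_params(model):
    # concatenate every parameter tensor into one 1-D vector
    return torch.cat([param.data.view(-1) for param in model.parameters()])


def set_flat_params(model, flat_params):
    # copy slices of a 1-D vector back into each parameter tensor, in order
    offset = 0
    for param in model.parameters():
        numel = param.numel()
        param.data.copy_(flat_params[offset:offset + numel].view_as(param))
        offset += numel


def get_flat_grad_params(model):
    # concatenate every parameter gradient into one 1-D vector (zeros if missing)
    return torch.cat([
        param.grad.view(-1) if param.grad is not None
        else torch.zeros(param.numel())
        for param in model.parameters()
    ])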
def line_search(model, f, x, step_dir, expected_improve, max_backtracks=10,
                accept_ratio=0.1):
    """
    Backtracking line search for choosing the step size.
    We maximize f(x) (equivalently, minimize -f(x)); a trial step is accepted
    when the actual improvement is at least `accept_ratio` times the expected
    first-order improvement:
        f(x + coeff * step_dir) - f(x) > accept_ratio * coeff * expected_improve
    :param model: network whose flat parameters are x
    :param f: objective function; called as f(False) to evaluate without grad
    :param x: current flat parameters
    :param step_dir: direction to update model parameters
    :param expected_improve: first-order estimate of the improvement for a full step
    :param max_backtracks: maximum number of halvings of the step coefficient
    :param accept_ratio: acceptance threshold for actual / expected improvement
    :return: (success flag, new flat parameters)
    """
    f_val = f(False).item()

    for step_coefficient in [.5 ** k for k in range(max_backtracks)]:
        x_new = x + step_coefficient * step_dir
        set_flat_params(model, x_new)
        f_val_new = f(False).item()
        actual_improve = f_val_new - f_val
        improve = expected_improve * step_coefficient
        ratio = actual_improve / improve

        if ratio > accept_ratio:
            return True, x_new

    return False, x
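# Hedged toy check of line_search (nothing in this snippet is from the original
# code; it assumes the flat-parameter helpers above). We maximize
# f(w) = -(w - 3)^2 for a single scalar parameter: the full step overshoots, so
# one backtracking halving is needed before the step is accepted.
import torch
import torch.nn as nn

_toy = nn.Linear(1, 1, bias=False)                 # a single scalar parameter w
set_flat_params(_toy, torch.tensor([0.0]))         # start at w = 0


def _toy_f(grad=True):
    w = get_flat_params(_toy)
    return -(w - 3.0).pow(2).sum()                 # maximized at w = 3


_x = get_flat_params(_toy)
_step_dir = torch.tensor([6.0])                    # gradient of f at w = 0
_expected_improve = 36.0                           # first-order estimate g.T @ step_dir
_success, _x_new = line_search(_toy, _toy_f, _x, _step_dir, _expected_improve)
# _success is True and _x_new == 3.0 after one halving of the step coefficient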
def update_policy(policy_net: nn.Module, states, actions, old_log_probs,
                  advantages, max_kl, damping):
    def get_loss(grad=True):
        log_probs = policy_net.get_log_prob(states, actions)
        if not grad:
            log_probs = log_probs.detach()
        ratio = torch.exp(log_probs - old_log_probs)
        loss = (ratio * advantages).mean()
        return loss

    def Hvp(v):
        """
        Compute the product of the Hessian of the KL divergence and a vector v
        using the double-backward trick (no explicit Hessian is formed).
        :param v: vector
        :return: (\\nabla^2 KL) @ v + damping * v
        """
        # compute kl divergence between current policy and old policy
        kl = policy_net.get_kl(states)
        kl = kl.mean()

        # first order gradient of kl
        grads = torch.autograd.grad(kl, policy_net.parameters(),
                                    create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
        kl_v = (flat_grad_kl * v).sum()  # flat_grad_kl.T @ v

        # second order gradient of kl
        grads = torch.autograd.grad(kl_v, policy_net.parameters())
        flat_grad_grad_kl = torch.cat(
            [grad.contiguous().view(-1) for grad in grads]).detach()

        return flat_grad_grad_kl + v * damping

    # compute first order approximation to the loss
    loss = get_loss()
    loss_grads = torch.autograd.grad(loss, policy_net.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in loss_grads]).detach()  # flat policy gradient g

    # conjugate gradient solve: Hx = g
    # the Hessian-vector-product strategy (`Hvp`) is used to compute Hx,
    # giving an approximate solution x = H^(-1) g
    step_dir = conjugate_gradient(Hvp, loss_grad)

    # shs = step_dir.T H step_dir (approximately g.T H^(-1) g)
    shs = Hvp(step_dir).t() @ step_dir
    lm = torch.sqrt(2 * max_kl / shs)
    step = lm * step_dir  # update direction for the policy net

    expected_improve = loss_grad.t() @ step

    """ line search for step size """
    current_flat_parameters = get_flat_params(policy_net)  # theta
    success, new_flat_parameters = line_search(policy_net, get_loss,
                                               current_flat_parameters, step,
                                               expected_improve, 10)
    set_flat_params(policy_net, new_flat_parameters)
    # success indicates whether TRPO worked as expected
    return success
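# Hedged sketch (not from the original code): update_policy assumes a
# conjugate_gradient helper that solves H x = g using only Hessian-vector
# products. A minimal version of such a solver is shown below; the repo's own
# helper may use different iteration counts and stopping tolerances.
import torch


def conjugate_gradient(Avp_func, b, max_iter=10, residual_tol=1e-10):
    x = torch.zeros_like(b)        # initial guess x_0 = 0
    r = b.clone()                  # residual r = b - A x = b
    p = b.clone()                  # search direction
    rdotr = r @ r
    for _ in range(max_iter):
        Avp = Avp_func(p)
        alpha = rdotr / (p @ Avp)
        x += alpha * p
        r -= alpha * Avp
        new_rdotr = r @ r
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x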
def trpo_step(policy_net, value_net, states, actions, returns, advantages,
              old_log_probs, max_kl, damping, l2_reg, optimizer_value=None):
    """
    Update by TRPO algorithm
    """

    """update critic"""

    def value_objective_func(value_net_flat_params):
        """
        get value_net loss
        :param value_net_flat_params: flat value net parameters (numpy array)
        :return: scalar loss value
        """
        set_flat_params(value_net, FLOAT(value_net_flat_params))
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg
        objective_value_loss = value_loss.item()
        # print("Current value loss: ", objective_value_loss)
        return objective_value_loss

    def value_objective_grad_func(value_net_flat_params):
        """
        gradient of the objective function, for scipy optimization
        """
        set_flat_params(value_net, FLOAT(value_net_flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg
        value_loss.backward()  # to get the grad
        objective_value_loss_grad = get_flat_grad_params(
            value_net).detach().cpu().numpy().astype(np.float64)
        return objective_value_loss_grad

    if optimizer_value is None:
        """
        update by scipy optimization; for details about L-BFGS-B see:
        https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html#optimize-minimize-lbfgsb
        """
        value_net_flat_params_old = get_flat_params(
            value_net).detach().cpu().numpy().astype(np.float64)  # initial guess
        res = opt.minimize(value_objective_func,
                           value_net_flat_params_old,
                           method='L-BFGS-B',
                           jac=value_objective_grad_func,
                           options={
                               "maxiter": 30,
                               "disp": False
                           })
        # print("Call L-BFGS-B, result: ", res)
        value_net_flat_params_new = res.x
        set_flat_params(value_net, FLOAT(value_net_flat_params_new))
    else:
        """update by gradient descent"""
        for _ in range(10):
            values_pred = value_net(states)
            value_loss = nn.MSELoss()(values_pred, returns)
            # weight decay
            for param in value_net.parameters():
                value_loss += param.pow(2).sum() * l2_reg
            optimizer_value.zero_grad()
            value_loss.backward()
            optimizer_value.step()

    """update policy"""
    update_policy(policy_net, states, actions, old_log_probs, advantages,
                  max_kl, damping)
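# Hedged, self-contained illustration (nothing here comes from the original
# code): the L-BFGS-B branch in trpo_step follows scipy's standard interface,
# where the objective takes/returns float64 numpy values and `jac` returns the
# flat gradient. The same calling pattern on a toy quadratic:
import numpy as np
import scipy.optimize as opt


def toy_loss(x):
    return float(np.sum((x - 1.0) ** 2))


def toy_loss_grad(x):
    return (2.0 * (x - 1.0)).astype(np.float64)


x0 = np.zeros(3, dtype=np.float64)
res = opt.minimize(toy_loss, x0, method='L-BFGS-B', jac=toy_loss_grad,
                   options={"maxiter": 30, "disp": False})
# res.x is close to [1, 1, 1]; trpo_step uses res.x to overwrite the value net parameters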
def sac_alpha_step(policy_net, q_net_1, q_net_2, alpha, q_net_target_1,
                   q_net_target_2, optimizer_policy, optimizer_q_net_1,
                   optimizer_q_net_2, optimizer_a, states, actions, rewards,
                   next_states, masks, gamma, polyak, target_entropy,
                   update_target=False):
    rewards = rewards.unsqueeze(-1)
    masks = masks.unsqueeze(-1)

    """update qvalue net"""
    with torch.no_grad():
        next_actions, next_log_probs = policy_net.rsample(next_states)
        target_q_value = torch.min(
            q_net_target_1(next_states, next_actions),
            q_net_target_2(next_states, next_actions)) - alpha * next_log_probs
        target_q_value = rewards + gamma * masks * target_q_value

    q_value_1 = q_net_1(states, actions)
    q_value_loss_1 = nn.MSELoss()(q_value_1, target_q_value)
    optimizer_q_net_1.zero_grad()
    q_value_loss_1.backward()
    optimizer_q_net_1.step()

    q_value_2 = q_net_2(states, actions)
    q_value_loss_2 = nn.MSELoss()(q_value_2, target_q_value)
    optimizer_q_net_2.zero_grad()
    q_value_loss_2.backward()
    optimizer_q_net_2.step()

    """update policy net"""
    new_actions, log_probs = policy_net.rsample(states)
    min_q = torch.min(q_net_1(states, new_actions),
                      q_net_2(states, new_actions))
    policy_loss = (alpha * log_probs - min_q).mean()
    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()

    """update alpha"""
    alpha_loss = -alpha * (log_probs.detach() + target_entropy).mean()
    optimizer_a.zero_grad()
    alpha_loss.backward()
    optimizer_a.step()

    if update_target:
        """soft update target qvalue net"""
        q_net_1_flat_params = get_flat_params(q_net_1)
        q_net_target_1_flat_params = get_flat_params(q_net_target_1)
        set_flat_params(q_net_target_1,
                        (1 - polyak) * q_net_1_flat_params +
                        polyak * q_net_target_1_flat_params)

        q_net_2_flat_params = get_flat_params(q_net_2)
        q_net_target_2_flat_params = get_flat_params(q_net_target_2)
        set_flat_params(q_net_target_2,
                        (1 - polyak) * q_net_2_flat_params +
                        polyak * q_net_target_2_flat_params)

    return {
        "q_value_loss_1": q_value_loss_1,
        "q_value_loss_2": q_value_loss_2,
        "policy_loss": policy_loss,
        "alpha_loss": alpha_loss
    }
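# Hedged sketch (not from the original code): sac_alpha_step treats `alpha` as a
# learnable temperature. One common way the temperature, target entropy, and its
# optimizer are set up is shown below; the repo's training script may instead
# parameterize log(alpha) and exponentiate it.
import torch

action_dim = 6                                       # hypothetical action dimension
target_entropy = -action_dim                         # common heuristic: -|A|
alpha = torch.ones(1, requires_grad=True)            # temperature, optimized directly here
optimizer_a = torch.optim.Adam([alpha], lr=3e-4)     # illustrative learning rate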
def sac_step(policy_net, value_net, value_net_target, q_net_1, q_net_2,
             optimizer_policy, optimizer_value, optimizer_q_net_1,
             optimizer_q_net_2, states, actions, rewards, next_states, masks,
             gamma, polyak, update_target=False):
    rewards = rewards.unsqueeze(-1)
    masks = masks.unsqueeze(-1)

    """update qvalue net"""
    q_value_1 = q_net_1(states, actions)
    q_value_2 = q_net_2(states, actions)
    with torch.no_grad():
        target_next_value = rewards + gamma * masks * value_net_target(next_states)

    q_value_loss_1 = nn.MSELoss()(q_value_1, target_next_value)
    optimizer_q_net_1.zero_grad()
    q_value_loss_1.backward()
    optimizer_q_net_1.step()

    q_value_loss_2 = nn.MSELoss()(q_value_2, target_next_value)
    optimizer_q_net_2.zero_grad()
    q_value_loss_2.backward()
    optimizer_q_net_2.step()

    """update policy net"""
    new_actions, log_probs = policy_net.rsample(states)
    min_q = torch.min(q_net_1(states, new_actions),
                      q_net_2(states, new_actions))
    policy_loss = (log_probs - min_q).mean()
    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()

    """update value net"""
    target_value = (min_q - log_probs).detach()
    value_loss = nn.MSELoss()(value_net(states), target_value)
    optimizer_value.zero_grad()
    value_loss.backward()
    optimizer_value.step()

    if update_target:
        """update target value net"""
        value_net_target_flat_params = get_flat_params(value_net_target)
        value_net_flat_params = get_flat_params(value_net)
        set_flat_params(value_net_target,
                        (1 - polyak) * value_net_flat_params +
                        polyak * value_net_target_flat_params)

    return {
        "target_value_loss": value_loss,
        "q_value_loss_1": q_value_loss_1,
        "q_value_loss_2": q_value_loss_2,
        "policy_loss": policy_loss
    }
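# Hedged sketch (not from the original code): both sac_step and sac_alpha_step
# assume policy_net.rsample(states) returns a reparameterized action sample
# together with its log-probability of shape [batch, 1]. A minimal tanh-squashed
# Gaussian policy with that interface is sketched below; names, sizes, and clamp
# ranges are illustrative and the repo's policy class may differ.
import torch
import torch.nn as nn


class TanhGaussianPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU())
        self.mu_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def rsample(self, states):
        h = self.backbone(states)
        mu = self.mu_head(h)
        log_std = self.log_std_head(h).clamp(-20, 2)
        dist = torch.distributions.Normal(mu, log_std.exp())
        pre_tanh = dist.rsample()                  # reparameterized sample (keeps gradients)
        actions = torch.tanh(pre_tanh)             # squash into [-1, 1]
        # change-of-variables correction for the tanh squashing
        log_probs = dist.log_prob(pre_tanh) - torch.log(1 - actions.pow(2) + 1e-6)
        return actions, log_probs.sum(dim=-1, keepdim=True)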
def td3_step(policy_net, policy_net_target, value_net_1, value_net_target_1,
             value_net_2, value_net_target_2, optimizer_policy,
             optimizer_value_1, optimizer_value_2, states, actions, rewards,
             next_states, masks, gamma, polyak, target_action_noise_std,
             target_action_noise_clip, action_high, update_policy=False):
    rewards = rewards.unsqueeze(-1)
    masks = masks.unsqueeze(-1)

    """update critic"""
    with torch.no_grad():
        target_action = policy_net_target(next_states)
        # target policy smoothing: clipped noise on the target action
        target_action_noise = torch.clamp(
            torch.randn_like(target_action) * target_action_noise_std,
            -target_action_noise_clip, target_action_noise_clip)
        target_action = torch.clamp(target_action + target_action_noise,
                                    -action_high, action_high)
        target_values = rewards + gamma * masks * torch.min(
            value_net_target_1(next_states, target_action),
            value_net_target_2(next_states, target_action))

    """update value net 1"""
    values_1 = value_net_1(states, actions)
    value_loss_1 = nn.MSELoss()(target_values, values_1)
    optimizer_value_1.zero_grad()
    value_loss_1.backward()
    optimizer_value_1.step()

    """update value net 2"""
    values_2 = value_net_2(states, actions)
    value_loss_2 = nn.MSELoss()(target_values, values_2)
    optimizer_value_2.zero_grad()
    value_loss_2.backward()
    optimizer_value_2.step()

    policy_loss = None
    if update_policy:
        """update policy (delayed)"""
        policy_loss = -value_net_1(states, policy_net(states)).mean()
        optimizer_policy.zero_grad()
        policy_loss.backward()
        optimizer_policy.step()

        """soft update target nets"""
        policy_net_flat_params = get_flat_params(policy_net)
        policy_net_target_flat_params = get_flat_params(policy_net_target)
        set_flat_params(
            policy_net_target,
            polyak * policy_net_target_flat_params +
            (1 - polyak) * policy_net_flat_params)

        value_net_1_flat_params = get_flat_params(value_net_1)
        value_net_1_target_flat_params = get_flat_params(value_net_target_1)
        set_flat_params(
            value_net_target_1,
            polyak * value_net_1_target_flat_params +
            (1 - polyak) * value_net_1_flat_params)

        value_net_2_flat_params = get_flat_params(value_net_2)
        value_net_2_target_flat_params = get_flat_params(value_net_target_2)
        set_flat_params(
            value_net_target_2,
            polyak * value_net_2_target_flat_params +
            (1 - polyak) * value_net_2_flat_params)

    return value_loss_1, value_loss_2, policy_loss
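# Hedged illustration (not from the original code): the critic target in
# td3_step uses "target policy smoothing" -- clipped Gaussian noise is added to
# the target action, which is then re-clipped to the action bounds. The
# update_policy flag is expected to be set on a delayed schedule (e.g. every
# second call) by the training loop. The noise transform in isolation, with
# illustrative values:
import torch

example_target_action = torch.tensor([[0.9, -0.2]])                     # illustrative action batch
noise = torch.clamp(torch.randn_like(example_target_action) * 0.2,      # target_action_noise_std = 0.2
                    -0.5, 0.5)                                           # target_action_noise_clip = 0.5
smoothed_action = torch.clamp(example_target_action + noise, -1.0, 1.0)  # action_high = 1.0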