# Line-search objective used inside trpo_step (a closure over net, all_states,
# actions, advs, old_log_ps, pds, surr_rew, initial_parameters and params):
# it evaluates the candidate parameters initial_parameters + s and returns the
# surrogate improvement, or -inf if the reward or KL constraint is violated.
def backtrack_fn(s):
    assign(initial_parameters + s.data, net.parameters())
    test_pds = net(all_states)
    test_action_log_probs = net.get_loglikelihood(test_pds, actions)
    new_reward = surrogate_reward(advs, new=test_action_log_probs,
                                  old=old_log_ps).mean()
    if new_reward <= surr_rew or net.calc_kl(pds, test_pds).mean() > params.MAX_KL:
        return -float('inf')
    return new_reward - surr_rew
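
# A minimal sketch of surrogate_reward, which the functions in this file call
# but do not define here. This is an assumption inferred from the call sites
# (new/old log probabilities, optional clip_eps, mask and normalize keywords),
# not the repository's actual implementation.
import torch as ch

def surrogate_reward(advs, *, new, old, clip_eps=None, mask=None, normalize=True):
    """Importance-weighted advantage ratio(theta) * A, optionally PPO-clipped."""
    if normalize:
        advs = (advs - advs.mean()) / (advs.std() + 1e-8)
    ratio = ch.exp(new - old)                  # pi_new(a|s) / pi_old(a|s)
    if clip_eps is not None:
        ratio = ch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
    rew = ratio * advs
    if mask is not None:
        rew = rew * mask                       # zero out padded timesteps (LSTM path)
    return rew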
def trpo_step(all_states, actions, old_log_ps, rewards, returns, not_dones,
              advs, net, params, store, opt_step):
    '''
    Trust Region Policy Optimization
    Runs K epochs of TRPO as in https://arxiv.org/abs/1502.05477
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The TRPO loss; main job is to mutate the net
    '''
    # Initial setup
    initial_parameters = flatten(net.parameters()).clone()
    # all_states is in shape (experience_size, observation_size). Usually 2048 experiences.
    # Get mean and std of the action distribution for all experiences.
    pds = net(all_states)
    # And compute the log probabilities for the actions chosen at rollout time.
    action_log_probs = net.get_loglikelihood(pds, actions)

    # Calculate losses
    surr_rew = surrogate_reward(advs, new=action_log_probs, old=old_log_ps).mean()
    grad = ch.autograd.grad(surr_rew, net.parameters(), retain_graph=True)
    # This represents the computation of the gradient, and will be used to obtain
    # second-order information.
    flat_grad = flatten(grad)

    # Make the Fisher product estimator. Only use a fraction of examples.
    num_samples = int(all_states.shape[0] * params.FISHER_FRAC_SAMPLES)
    selected = np.random.choice(range(all_states.shape[0]), num_samples, replace=False)
    detached_selected_pds = select_prob_dists(pds, selected, detach=True)
    selected_pds = select_prob_dists(pds, selected, detach=False)

    # Construct the KL divergence which we will optimize on. Its value is essentially 0,
    # but what we care about is the Hessian: how the KL divergence of the network
    # output changes as the network parameters change.
    kl = net.calc_kl(detached_selected_pds, selected_pds).mean()
    # g is the gradient of the KL divergence w.r.t. the parameters. It is 0 at the
    # starting point.
    g = flatten(ch.autograd.grad(kl, net.parameters(), create_graph=True))

    def fisher_product(x, damp_coef=1.):
        '''
        Fisher matrix to vector x product. Essentially, a Hessian-vector product
        of the KL divergence w.r.t. the network parameters.
        '''
        contig_flat = lambda q: ch.cat([y.contiguous().view(-1) for y in q])
        # z is the gradient-vector product. Differentiate it again to get the
        # Hessian-vector product.
        z = g @ x
        hv = ch.autograd.grad(z, net.parameters(), retain_graph=True)
        return contig_flat(hv).detach() + x * params.DAMPING * damp_coef

    # Find the KL-constrained gradient step.
    # The Fisher matrix A is unknown, but we can compute products with it.
    # flat_grad is the right-hand side b. We want to solve for x in Ax = b.
    step = cg_solve(fisher_product, flat_grad, params.CG_STEPS)
    # The solution "step" has the size of the network parameters.
    max_step_coeff = (2 * params.MAX_KL / (step @ fisher_product(step)))**(0.5)
    max_trpo_step = max_step_coeff * step

    if store and params.SHOULD_LOG_KL:
        kl_approximation_logging(all_states, pds, flat_grad, step, net, store)
        kl_vs_second_order_approx(all_states, pds, net, max_trpo_step, params,
                                  store, opt_step)

    # Backtracking line search
    with ch.no_grad():
        # Backtracking function, which gives the improvement on the objective
        # given an update direction s.
        def backtrack_fn(s):
            assign(initial_parameters + s.data, net.parameters())
            test_pds = net(all_states)
            test_action_log_probs = net.get_loglikelihood(test_pds, actions)
            new_reward = surrogate_reward(advs, new=test_action_log_probs,
                                          old=old_log_ps).mean()
            # surr_rew is the surrogate before optimization. We need to make sure
            # the objective is improving, and that the KL from the old
            # probabilities is not too large.
            if params.TRPO_KL_REDUCE_FUNC == 'mean':
                kl_metric = net.calc_kl(pds, test_pds).mean()
            elif params.TRPO_KL_REDUCE_FUNC == 'max':
                kl_metric = net.calc_kl(pds, test_pds).max()
            else:
                raise ValueError("unknown reduce function " + params.TRPO_KL_REDUCE_FUNC)
            if new_reward <= surr_rew or kl_metric > params.MAX_KL:
                return -float('inf')
            return new_reward - surr_rew

        # max_trpo_step is the search direction; backtracking line search finds a
        # scalar for it. expected_improve is the improvement in the objective
        # estimated from the gradient.
        expected_improve = flat_grad @ max_trpo_step
        # backtracking_line_search tries scalars 0.5, 0.25, 0.125, etc. to achieve
        # the expected improvement.
        final_step = backtracking_line_search(backtrack_fn, max_trpo_step,
                                              expected_improve,
                                              num_tries=params.MAX_BACKTRACK)
        assign(initial_parameters + final_step, net.parameters())

    # Entropy regularization is not used for TRPO, so return 0.
    return surr_rew.item(), 0.0, 0.0
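
# A minimal sketch of cg_solve, assuming a standard conjugate-gradient solver.
# It is consistent with the call above (fisher_product, flat_grad,
# params.CG_STEPS) but is not necessarily the repository's implementation.
import torch as ch

def cg_solve(fvp_fn, b, num_steps=10, tol=1e-10):
    """Solve A x = b given only the matrix-vector product fvp_fn(v) = A @ v."""
    x = ch.zeros_like(b)
    r = b.clone()                 # residual r = b - A @ x, with x starting at 0
    p = b.clone()                 # search direction
    rdotr = r @ r
    for _ in range(num_steps):
        Ap = fvp_fn(p)
        alpha = rdotr / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = r @ r
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x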
def robust_ppo_step(all_states, actions, old_log_ps, rewards, returns, not_dones,
                    advs, net, params, store, opt_step, relaxed_net,
                    eps_scheduler, beta_scheduler):
    '''
    Proximal Policy Optimization with robustness regularizer
    Runs K epochs of PPO as in https://arxiv.org/abs/1707.06347
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The PPO loss; main job is to mutate the net
    '''
    # Storing batches of stuff
    # if store is not None:
    #     orig_dists = net(all_states)

    ### ACTUAL PPO OPTIMIZATION START
    if params.SHARE_WEIGHTS:
        orig_vs = net.get_value(all_states).squeeze(-1).view([params.NUM_ACTORS, -1])
        old_vs = orig_vs.detach()

    # We treat all PPO epochs as one epoch.
    eps_scheduler.set_epoch_length(params.PPO_EPOCHS * params.NUM_MINIBATCHES)
    beta_scheduler.set_epoch_length(params.PPO_EPOCHS * params.NUM_MINIBATCHES)
    # We count from 1.
    eps_scheduler.step_epoch()
    beta_scheduler.step_epoch()

    if params.HISTORY_LENGTH > 0:
        # LSTM policy. Need to go over all episodes instead of states.
        # We normalize all advantages at once instead of batch by batch, since
        # each batch may contain a different number of samples.
        normalized_advs = adv_normalize(advs)
        batches, alive_masks, time_masks, lengths = pack_history(
            [all_states, actions, old_log_ps, normalized_advs], not_dones,
            max_length=params.HISTORY_LENGTH)

    for _ in range(params.PPO_EPOCHS):
        if params.HISTORY_LENGTH > 0:
            # LSTM policy. Need to go over all episodes instead of states.
            params.POLICY_ADAM.zero_grad()
            hidden = None
            surrogate = 0.0
            for i, batch in enumerate(batches):
                # Now we get chunks of time sequences, each of them with a maximum
                # length of params.HISTORY_LENGTH.
                # Select log probabilities and advantages of this minibatch.
                batch_states, batch_actions, batch_old_log_ps, batch_advs = batch
                mask = time_masks[i]
                """
                print('batch states', batch_states.size())
                print('batch actions', batch_actions.size())
                print('batch old_log_ps', batch_old_log_ps.size())
                print('batch advs', batch_advs.size())
                print('alive mask', alive_masks[i].size(), alive_masks[i].sum())
                print('mask', mask.size())
                """
                # Keep only the alive hidden states.
                if hidden is not None:
                    # print('hidden[0]', hidden[0].size())
                    hidden = [h[:, alive_masks[i], :].detach() for h in hidden]
                    # print('hidden[0]', hidden[0].size())
                # dist contains mean and variance of the Gaussian.
                mean, std, hidden = net.multi_forward(batch_states, hidden=hidden)
                dist = mean, std
                # Convert the state distribution to a log likelihood.
                new_log_ps = net.get_loglikelihood(dist, batch_actions)
                # print('batch new_log_ps', new_log_ps.size())
                """
                print('old')
                print(batch_old_log_ps)
                print('new')
                print(new_log_ps * mask)
                print('diff')
                print((batch_old_log_ps - new_log_ps * mask).pow(2).sum().item())
                """
                shape_equal_cmp(new_log_ps, batch_old_log_ps)

                # Calculate rewards.
                # The surrogate reward is basically exp(new_log_ps - old_log_ps) * advantage;
                # its dimension is the same as the minibatch size.
                # We already normalized advs before. No need to normalize here.
                unclp_rew = surrogate_reward(batch_advs, new=new_log_ps,
                                             old=batch_old_log_ps,
                                             mask=mask, normalize=False)
                clp_rew = surrogate_reward(batch_advs, new=new_log_ps,
                                           old=batch_old_log_ps,
                                           clip_eps=params.CLIP_EPS,
                                           mask=mask, normalize=False)

                # The total loss is the min of clipped and unclipped reward for each
                # state. We sum the batch loss here because each batch contains an
                # uneven number of trajectories.
                surrogate_batch = (-ch.min(unclp_rew, clp_rew) * mask).sum()
                surrogate = surrogate + surrogate_batch
            # Divide the surrogate loss by the total number of samples.
            surrogate = surrogate / all_states.size(0)

            # Calculate entropy bonus.
            # So far, the entropy only depends on std and does not depend on time.
            # No need to mask.
            entropy_bonus = net.entropies(dist)

            # Calculate regularizer under state perturbation.
            eps_scheduler.step_batch()
            beta_scheduler.step_batch()
            batch_action_means = None
            current_eps = eps_scheduler.get_eps()
            stdev = ch.exp(net.log_stdev)
            if params.ROBUST_PPO_DETACH_STDEV:
                # Detach stdev so that it won't be too large.
                stdev = stdev.detach()
            if params.ROBUST_PPO_METHOD == "sgld":
                kl_upper_bound = get_state_kl_bound_sgld(net, all_states, None,
                                                         eps=current_eps,
                                                         steps=params.ROBUST_PPO_PGD_STEPS,
                                                         stdev=stdev,
                                                         not_dones=not_dones).mean()
            else:
                raise ValueError(f"Unsupported robust PPO method {params.ROBUST_PPO_METHOD}")

            entropy = -params.ENTROPY_COEFF * entropy_bonus
            loss = surrogate + entropy + params.ROBUST_PPO_REG * kl_upper_bound

            # Optimizer (only Adam).
            loss.backward()
            if params.CLIP_GRAD_NORM != -1:
                ch.nn.utils.clip_grad_norm_(net.parameters(), params.CLIP_GRAD_NORM)
            params.POLICY_ADAM.step()
        else:
            # Memoryless policy.
            # State is in shape (experience_size, observation_size). Usually 2048.
            state_indices = np.arange(all_states.shape[0])
            np.random.shuffle(state_indices)
            # We use a minibatch of states to do optimization, and each epoch
            # contains several iterations.
            splits = np.array_split(state_indices, params.NUM_MINIBATCHES)
            # A typical minibatch size is 2048/32 = 64.
            for selected in splits:
                def sel(*args):
                    return [v[selected] for v in args]

                # old_log_ps: log probabilities of actions sampled, based on the
                # experience buffer.
                # advs: advantages of these states.
                # Both old_log_ps and advs are in shape (experience_size,) = 2048.
                tup = sel(all_states, actions, old_log_ps, advs)
                # Select log probabilities and advantages of this minibatch.
                batch_states, batch_actions, batch_old_log_ps, batch_advs = tup

                # Forward propagation on the current parameters (being constantly
                # updated) to get the distribution of these states.
                # dist contains mean and variance of the Gaussian.
                dist = net(batch_states)
                # Convert the state distribution to a log likelihood.
                new_log_ps = net.get_loglikelihood(dist, batch_actions)
                shape_equal_cmp(new_log_ps, batch_old_log_ps)

                # Calculate rewards.
                # The surrogate reward is basically exp(new_log_ps - old_log_ps) * advantage;
                # its dimension is the same as the minibatch size.
                unclp_rew = surrogate_reward(batch_advs, new=new_log_ps,
                                             old=batch_old_log_ps)
                clp_rew = surrogate_reward(batch_advs, new=new_log_ps,
                                           old=batch_old_log_ps,
                                           clip_eps=params.CLIP_EPS)

                # Calculate entropy bonus.
                entropy_bonus = net.entropies(dist).mean()

                # Calculate regularizer under state perturbation.
                eps_scheduler.step_batch()
                beta_scheduler.step_batch()
                batch_action_means = dist[0]
                current_eps = eps_scheduler.get_eps()
                stdev = ch.exp(net.log_stdev)
                if params.ROBUST_PPO_DETACH_STDEV:
                    # Detach stdev so that it won't be too large.
                    stdev = stdev.detach()
                if params.ROBUST_PPO_METHOD == "convex-relax":
                    kl_upper_bound = get_state_kl_bound(relaxed_net, batch_states,
                                                        batch_action_means,
                                                        eps=current_eps,
                                                        beta=beta_scheduler.get_eps(),
                                                        stdev=stdev).mean()
                elif params.ROBUST_PPO_METHOD == "sgld":
                    kl_upper_bound = get_state_kl_bound_sgld(net, batch_states,
                                                             batch_action_means,
                                                             eps=current_eps,
                                                             steps=params.ROBUST_PPO_PGD_STEPS,
                                                             stdev=stdev).mean()
                else:
                    raise ValueError(f"Unsupported robust PPO method {params.ROBUST_PPO_METHOD}")

                # The total loss is the min of clipped and unclipped reward for each
                # state, averaged.
                surrogate = -ch.min(unclp_rew, clp_rew).mean()
                entropy = -params.ENTROPY_COEFF * entropy_bonus
                loss = surrogate + entropy + params.ROBUST_PPO_REG * kl_upper_bound

                # If we are sharing weights, take the value step simultaneously
                # (since the policy and value networks depend on the same weights).
                if params.SHARE_WEIGHTS:
                    tup = sel(returns, not_dones, old_vs)
                    batch_returns, batch_not_dones, batch_old_vs = tup
                    val_loss = value_step(batch_states, batch_returns, batch_advs,
                                          batch_not_dones, net.get_value, None,
                                          params, store, old_vs=batch_old_vs,
                                          opt_step=opt_step)
                    loss += params.VALUE_MULTIPLIER * val_loss

                # Optimizer step (Adam or SGD)
                if params.POLICY_ADAM is None:
                    grad = ch.autograd.grad(loss, net.parameters())
                    flat_grad = flatten(grad)
                    if params.CLIP_GRAD_NORM != -1:
                        norm_grad = ch.norm(flat_grad)
                        flat_grad = flat_grad if norm_grad <= params.CLIP_GRAD_NORM else \
                            flat_grad / norm_grad * params.CLIP_GRAD_NORM
                    assign(flatten(net.parameters()) - params.PPO_LR * flat_grad,
                           net.parameters())
                else:
                    params.POLICY_ADAM.zero_grad()
                    loss.backward()
                    if params.CLIP_GRAD_NORM != -1:
                        ch.nn.utils.clip_grad_norm_(net.parameters(),
                                                    params.CLIP_GRAD_NORM)
                    params.POLICY_ADAM.step()

    # Logging.
    kl_upper_bound = kl_upper_bound.item()
    surrogate = surrogate.item()
    entropy_bonus = entropy_bonus.item()
    print(f'eps={eps_scheduler.get_eps():8.6f}, beta={beta_scheduler.get_eps():8.6f}, '
          f'kl={kl_upper_bound:10.5g}, surrogate={surrogate:8.5f}, '
          f'entropy={entropy_bonus:8.5f}, loss={loss.item():8.5f}')
    std = ch.exp(net.log_stdev)
    print(f'std_min={std.min().item():8.5f}, std_max={std.max().item():8.5f}, '
          f'std_mean={std.mean().item():8.5f}')
    if store is not None:
        # TODO: ADV: add row name suffix
        row = {
            'eps': eps_scheduler.get_eps(),
            'beta': beta_scheduler.get_eps(),
            'kl': kl_upper_bound,
            'surrogate': surrogate,
            'entropy': entropy_bonus,
            'loss': loss.item(),
        }
        store.log_table_and_tb('robust_ppo_data', row)

    return loss.item(), surrogate, entropy_bonus
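
# A minimal sketch of get_state_kl_bound_sgld, not from the source: it estimates
# max over ||s' - s|| <= eps of KL(pi(.|s) || pi(.|s')) by gradient ascent on the
# perturbation (plain PGD-style sign steps here; the name suggests the real code
# adds SGLD noise). Assumes the memoryless net(states) -> (mean, std) interface
# and a Gaussian policy with state-independent stdev, for which
# KL = sum((mu - mu')^2 / (2 * sigma^2)). The step schedule and random start are
# assumptions.
import torch as ch

def get_state_kl_bound_sgld(net, states, action_means, eps, steps, stdev,
                            not_dones=None):
    # not_dones is accepted for API compatibility with the recurrent call site;
    # this memoryless sketch does not use it.
    if action_means is None:
        with ch.no_grad():
            action_means = net(states)[0]      # means of the unperturbed policy
    var = stdev.pow(2)
    step_size = eps / max(steps, 1)
    # Start from a random point inside the eps-ball.
    noise = ch.empty_like(states).uniform_(-eps, eps)
    for _ in range(steps):
        noise.requires_grad_(True)
        perturbed_means = net(states + noise)[0]
        # KL between two Gaussians with the same (state-independent) stdev.
        kl = ((perturbed_means - action_means).pow(2) / (2 * var)).sum(dim=-1)
        grad = ch.autograd.grad(kl.sum(), noise)[0]
        with ch.no_grad():
            noise = (noise + step_size * grad.sign()).clamp_(-eps, eps)
    # Final KL at the found perturbation, differentiable w.r.t. net parameters.
    perturbed_means = net(states + noise.detach())[0]
    return ((perturbed_means - action_means).pow(2) / (2 * var)).sum(dim=-1)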
def robust_ppo_step(all_states, actions, old_log_ps, rewards, returns, not_dones,
                    advs, net, params, store, opt_step, relaxed_net,
                    eps_scheduler, beta_scheduler):
    '''
    Proximal Policy Optimization with robustness regularizer
    Runs K epochs of PPO as in https://arxiv.org/abs/1707.06347
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The PPO loss; main job is to mutate the net
    '''
    # Storing batches of stuff
    # if store is not None:
    #     orig_dists = net(all_states)

    ### ACTUAL PPO OPTIMIZATION START
    if params.SHARE_WEIGHTS:
        orig_vs = net.get_value(all_states).squeeze(-1).view([params.NUM_ACTORS, -1])
        old_vs = orig_vs.detach()

    # We treat all PPO epochs as one epoch.
    eps_scheduler.set_epoch_length(params.PPO_EPOCHS * params.NUM_MINIBATCHES)
    beta_scheduler.set_epoch_length(params.PPO_EPOCHS * params.NUM_MINIBATCHES)
    # We count from 1.
    eps_scheduler.step_epoch()
    beta_scheduler.step_epoch()

    for _ in range(params.PPO_EPOCHS):
        # State is in shape (experience_size, observation_size). Usually 2048.
        state_indices = np.arange(all_states.shape[0])
        np.random.shuffle(state_indices)
        # We use a minibatch of states to do optimization, and each epoch contains
        # several iterations.
        splits = np.array_split(state_indices, params.NUM_MINIBATCHES)
        # A typical minibatch size is 2048/32 = 64.
        for selected in splits:
            def sel(*args):
                return [v[selected] for v in args]

            # old_log_ps: log probabilities of actions sampled, based on the
            # experience buffer.
            # advs: advantages of these states.
            # Both old_log_ps and advs are in shape (experience_size,) = 2048.
            tup = sel(all_states, actions, old_log_ps, advs)
            # Select log probabilities and advantages of this minibatch.
            batch_states, batch_actions, batch_old_log_ps, batch_advs = tup

            # Forward propagation on the current parameters (being constantly
            # updated) to get the distribution of these states.
            # dist contains mean and variance of the Gaussian.
            dist = net(batch_states)
            # Convert the state distribution to a log likelihood.
            new_log_ps = net.get_loglikelihood(dist, batch_actions)
            shape_equal_cmp(new_log_ps, batch_old_log_ps)

            # Calculate rewards.
            # The surrogate reward is basically exp(new_log_ps - old_log_ps) * advantage;
            # its dimension is the same as the minibatch size.
            unclp_rew = surrogate_reward(batch_advs, new=new_log_ps,
                                         old=batch_old_log_ps)
            clp_rew = surrogate_reward(batch_advs, new=new_log_ps,
                                       old=batch_old_log_ps,
                                       clip_eps=params.CLIP_EPS)

            # Calculate entropy bonus.
            entropy_bonus = net.entropies(dist).mean()

            # Calculate regularizer under state perturbation.
            eps_scheduler.step_batch()
            beta_scheduler.step_batch()
            batch_action_means = dist[0]
            current_eps = eps_scheduler.get_eps()
            stdev = ch.exp(net.log_stdev)
            if params.ROBUST_PPO_DETACH_STDEV:
                # Detach stdev so that it won't be too large.
                stdev = stdev.detach()
            if params.ROBUST_PPO_METHOD == "convex-relax":
                kl_upper_bound = get_state_kl_bound(relaxed_net, batch_states,
                                                    batch_action_means,
                                                    eps=current_eps,
                                                    beta=beta_scheduler.get_eps(),
                                                    stdev=stdev).mean()
            elif params.ROBUST_PPO_METHOD == "sgld":
                kl_upper_bound = get_state_kl_bound_sgld(net, batch_states,
                                                         batch_action_means,
                                                         eps=current_eps,
                                                         steps=params.ROBUST_PPO_PGD_STEPS,
                                                         stdev=stdev).mean()
            else:
                raise ValueError(f"Unsupported robust PPO method {params.ROBUST_PPO_METHOD}")

            # The total loss is the min of clipped and unclipped reward for each
            # state, averaged.
            surrogate = -ch.min(unclp_rew, clp_rew).mean()
            entropy = -params.ENTROPY_COEFF * entropy_bonus
            loss = surrogate + entropy + params.ROBUST_PPO_REG * kl_upper_bound

            # If we are sharing weights, take the value step simultaneously
            # (since the policy and value networks depend on the same weights).
            if params.SHARE_WEIGHTS:
                tup = sel(returns, not_dones, old_vs)
                batch_returns, batch_not_dones, batch_old_vs = tup
                val_loss = value_step(batch_states, batch_returns, batch_advs,
                                      batch_not_dones, net.get_value, None, params,
                                      store, old_vs=batch_old_vs, opt_step=opt_step)
                loss += params.VALUE_MULTIPLIER * val_loss

            # Optimizer step (Adam or SGD)
            if params.POLICY_ADAM is None:
                grad = ch.autograd.grad(loss, net.parameters())
                flat_grad = flatten(grad)
                if params.CLIP_GRAD_NORM != -1:
                    norm_grad = ch.norm(flat_grad)
                    flat_grad = flat_grad if norm_grad <= params.CLIP_GRAD_NORM else \
                        flat_grad / norm_grad * params.CLIP_GRAD_NORM
                assign(flatten(net.parameters()) - params.PPO_LR * flat_grad,
                       net.parameters())
            else:
                params.POLICY_ADAM.zero_grad()
                loss.backward()
                if params.CLIP_GRAD_NORM != -1:
                    ch.nn.utils.clip_grad_norm_(net.parameters(),
                                                params.CLIP_GRAD_NORM)
                params.POLICY_ADAM.step()

    # Logging.
    kl_upper_bound = kl_upper_bound.item()
    surrogate = surrogate.item()
    entropy_bonus = entropy_bonus.item()
    print(f'eps={eps_scheduler.get_eps():8.6f}, beta={beta_scheduler.get_eps():8.6f}, '
          f'kl={kl_upper_bound:8.5f}, surrogate={surrogate:8.5f}, '
          f'entropy={entropy_bonus:8.5f}, loss={loss.item():8.5f}')
    std = ch.exp(net.log_stdev)
    print(f'std_min={std.min().item():8.5f}, std_max={std.max().item():8.5f}, '
          f'std_mean={std.mean().item():8.5f}')
    if store is not None:
        row = {
            'eps': eps_scheduler.get_eps(),
            'beta': beta_scheduler.get_eps(),
            'kl': kl_upper_bound,
            'surrogate': surrogate,
            'entropy': entropy_bonus,
            'loss': loss.item(),
        }
        store.log_table_and_tb('robust_ppo_data', row)

    return loss
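
# Minimal sketches of the flatten/assign parameter helpers that the steps above
# rely on; signatures inferred from the call sites, not taken from the source.
import torch as ch

def flatten(tensors):
    """Concatenate an iterable of tensors (e.g. net.parameters()) into one 1-D vector."""
    return ch.cat([t.contiguous().view(-1) for t in tensors])

def assign(flat_vector, parameters):
    """Copy a flat vector back into the shaped parameter tensors, in place."""
    offset = 0
    with ch.no_grad():
        for p in parameters:
            n = p.numel()
            p.copy_(flat_vector[offset:offset + n].view_as(p))
            offset += n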
def ppo_step(all_states, actions, old_log_ps, rewards, returns, not_dones,
             advs, net, params, store, opt_step):
    '''
    Proximal Policy Optimization
    Runs K epochs of PPO as in https://arxiv.org/abs/1707.06347
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The PPO loss; main job is to mutate the net
    '''
    # Storing batches of stuff
    # if store is not None:
    #     orig_dists = net(all_states)

    ### ACTUAL PPO OPTIMIZATION START
    if params.SHARE_WEIGHTS:
        orig_vs = net.get_value(all_states).squeeze(-1).view([params.NUM_ACTORS, -1])
        old_vs = orig_vs.detach()

    for _ in range(params.PPO_EPOCHS):
        # State is in shape (experience_size, observation_size). Usually 2048.
        state_indices = np.arange(all_states.shape[0])
        np.random.shuffle(state_indices)
        # We use a minibatch of states to do optimization, and each epoch contains
        # several iterations.
        splits = np.array_split(state_indices, params.NUM_MINIBATCHES)
        # A typical minibatch size is 2048/32 = 64.
        for selected in splits:
            def sel(*args):
                return [v[selected] for v in args]

            # old_log_ps: log probabilities of actions sampled, based on the
            # experience buffer.
            # advs: advantages of these states.
            # Both old_log_ps and advs are in shape (experience_size,) = 2048.
            tup = sel(all_states, actions, old_log_ps, advs)
            # Select log probabilities and advantages of this minibatch.
            batch_states, batch_actions, batch_old_log_ps, batch_advs = tup

            # Forward propagation on the current parameters (being constantly
            # updated) to get the distribution of these states.
            # dist contains mean and variance of the Gaussian.
            dist = net(batch_states)
            # Convert the state distribution to a log likelihood.
            new_log_ps = net.get_loglikelihood(dist, batch_actions)
            shape_equal_cmp(new_log_ps, batch_old_log_ps)

            # Calculate rewards.
            # The surrogate reward is basically exp(new_log_ps - old_log_ps) * advantage;
            # its dimension is the same as the minibatch size.
            unclp_rew = surrogate_reward(batch_advs, new=new_log_ps,
                                         old=batch_old_log_ps)
            clp_rew = surrogate_reward(batch_advs, new=new_log_ps,
                                       old=batch_old_log_ps,
                                       clip_eps=params.CLIP_EPS)

            # Calculate entropy bonus.
            entropy_bonus = net.entropies(dist).mean()

            # The total loss is the min of clipped and unclipped reward for each
            # state, averaged.
            surrogate = -ch.min(unclp_rew, clp_rew).mean()
            entropy = -params.ENTROPY_COEFF * entropy_bonus
            loss = surrogate + entropy

            # If we are sharing weights, take the value step simultaneously
            # (since the policy and value networks depend on the same weights).
            if params.SHARE_WEIGHTS:
                tup = sel(returns, not_dones, old_vs)
                batch_returns, batch_not_dones, batch_old_vs = tup
                val_loss = value_step(batch_states, batch_returns, batch_advs,
                                      batch_not_dones, net.get_value, None, params,
                                      store, old_vs=batch_old_vs, opt_step=opt_step)
                loss += params.VALUE_MULTIPLIER * val_loss

            # Optimizer step (Adam or SGD)
            if params.POLICY_ADAM is None:
                grad = ch.autograd.grad(loss, net.parameters())
                flat_grad = flatten(grad)
                if params.CLIP_GRAD_NORM != -1:
                    norm_grad = ch.norm(flat_grad)
                    flat_grad = flat_grad if norm_grad <= params.CLIP_GRAD_NORM else \
                        flat_grad / norm_grad * params.CLIP_GRAD_NORM
                assign(flatten(net.parameters()) - params.PPO_LR * flat_grad,
                       net.parameters())
            else:
                params.POLICY_ADAM.zero_grad()
                loss.backward()
                if params.CLIP_GRAD_NORM != -1:
                    ch.nn.utils.clip_grad_norm_(net.parameters(),
                                                params.CLIP_GRAD_NORM)
                params.POLICY_ADAM.step()

    print(f'surrogate={surrogate.item():8.5f}, entropy={entropy_bonus.item():8.5f}, '
          f'loss={loss.item():8.5f}')
    std = ch.exp(net.log_stdev)
    print(f'std_min={std.min().item():8.5f}, std_max={std.max().item():8.5f}, '
          f'std_mean={std.mean().item():8.5f}')

    return loss
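
# Minimal sketches of the small sanity/normalization helpers referenced above
# (shape_equal_cmp, adv_normalize); behavior inferred from usage, not from the
# source.
import torch as ch

def shape_equal_cmp(*tensors):
    """Assert that all given tensors share the same shape (cheap sanity check)."""
    shapes = [t.shape for t in tensors]
    assert all(s == shapes[0] for s in shapes), f"shape mismatch: {shapes}"

def adv_normalize(advs, eps=1e-8):
    """Standardize advantages to zero mean and unit variance."""
    return (advs - advs.mean()) / (advs.std() + eps)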
def trpo_step(all_states, actions, old_log_ps, rewards, returns, not_dones,
              advs, net, params, store, opt_step):
    '''
    Trust Region Policy Optimization
    Runs K epochs of TRPO as in https://arxiv.org/abs/1502.05477
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The TRPO loss; main job is to mutate the net
    '''
    # Initial setup
    initial_parameters = flatten(net.parameters()).clone()
    pds = net(all_states)
    action_log_probs = net.get_loglikelihood(pds, actions)

    # Calculate losses
    surr_rew = surrogate_reward(advs, new=action_log_probs, old=old_log_ps).mean()
    grad = ch.autograd.grad(surr_rew, net.parameters(), retain_graph=True)
    flat_grad = flatten(grad)

    if params.USE_CONJ:
        # Make the Fisher product estimator
        num_samples = int(all_states.shape[0] * params.FISHER_FRAC_SAMPLES)
        selected = np.random.choice(range(all_states.shape[0]), num_samples,
                                    replace=False)
        detached_selected_pds = select_prob_dists(pds, selected, detach=True)
        selected_pds = select_prob_dists(pds, selected, detach=False)

        kl = net.calc_kl(detached_selected_pds, selected_pds).mean()
        g = flatten(ch.autograd.grad(kl, net.parameters(), create_graph=True))

        def fisher_product(x, damp_coef=1.):
            contig_flat = lambda q: ch.cat([y.contiguous().view(-1) for y in q])
            z = g @ x
            hv = ch.autograd.grad(z, net.parameters(), retain_graph=True)
            return contig_flat(hv).detach() + x * params.DAMPING * damp_coef

        # Find the KL-constrained gradient step
        step = cg_solve(fisher_product, flat_grad, params.CG_STEPS)
        max_step_coeff = (2 * params.MAX_KL / (step @ fisher_product(step)))**(0.5)
        max_trpo_step = max_step_coeff * step

        if store and params.SHOULD_LOG_KL:
            kl_approximation_logging(all_states, pds, flat_grad, step, net, store)
            kl_vs_second_order_approx(all_states, pds, net, max_trpo_step, params,
                                      store, opt_step)
    else:
        max_trpo_step = flat_grad.clone() * params.PPO_LR_ADAM

    # Backtracking line search
    with ch.no_grad():
        # Backtracking function
        def backtrack_fn(s):
            assign(initial_parameters + s.data, net.parameters())
            test_pds = net(all_states)
            test_action_log_probs = net.get_loglikelihood(test_pds, actions)
            new_reward = surrogate_reward(advs, new=test_action_log_probs,
                                          old=old_log_ps).mean()
            if params.USE_CONS == 'all':
                if new_reward <= surr_rew or \
                        net.calc_kl(pds, test_pds).mean() > params.MAX_KL:
                    return -float('inf')
            elif params.USE_CONS == 'kl':
                if net.calc_kl(pds, test_pds).mean() > params.MAX_KL:
                    return -float('inf')
            elif params.USE_CONS == 'rew':
                if new_reward <= surr_rew:
                    return -float('inf')
            elif params.USE_CONS == 'none':
                pass
            else:
                raise NotImplementedError("No such constraints")
            return new_reward - surr_rew

        expected_improve = flat_grad @ max_trpo_step
        final_step = backtracking_line_search(backtrack_fn, max_trpo_step,
                                              expected_improve,
                                              num_tries=params.MAX_BACKTRACK)
        assign(initial_parameters + final_step, net.parameters())

    return surr_rew
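
# A minimal sketch of backtracking_line_search, assuming the behavior the
# comments above describe (try geometrically shrinking scales of the max step);
# accept_ratio and the zero-step fallback are assumptions, not from the source.
def backtracking_line_search(f, max_step, expected_improve, num_tries=10,
                             accept_ratio=0.1):
    for i in range(num_tries):
        scale = 0.5 ** i                  # try the full step, then 0.5, 0.25, ...
        step = max_step * scale
        improvement = f(step)             # -inf when a constraint is violated
        if improvement >= accept_ratio * expected_improve * scale:
            return step
    return max_step * 0.                  # no acceptable step: keep old parameters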
def ppo_step(all_states, actions, old_log_ps, rewards, returns, not_dones,
             advs, net, params, store, opt_step):
    '''
    Proximal Policy Optimization
    Runs K epochs of PPO as in https://arxiv.org/abs/1707.06347
    Inputs:
    - all_states, the historical value of all the states
    - actions, the actions that the policy sampled
    - old_log_ps, the log probability of the actions that the policy sampled
    - advs, advantages as estimated by GAE
    - net, policy network to train [WILL BE MUTATED]
    - params, additional placeholder for parameters like EPS
    Returns:
    - The PPO loss; main job is to mutate the net
    '''
    # Storing batches of stuff
    if store is not None:
        orig_dists = net(all_states)

    ### ACTUAL PPO OPTIMIZATION START
    if params.SHARE_WEIGHTS:
        orig_vs = net.get_value(all_states).squeeze(-1).view([params.NUM_ACTORS, -1])
        old_vs = orig_vs.detach()

    for _ in range(params.PPO_EPOCHS):
        state_indices = np.arange(all_states.shape[0])
        np.random.shuffle(state_indices)
        splits = np.array_split(state_indices, params.NUM_MINIBATCHES)
        for selected in splits:
            def sel(*args):
                return [v[selected] for v in args]

            tup = sel(all_states, actions, old_log_ps, advs)
            batch_states, batch_actions, batch_old_log_ps, batch_advs = tup
            dist = net(batch_states)
            new_log_ps = net.get_loglikelihood(dist, batch_actions)
            shape_equal_cmp(new_log_ps, batch_old_log_ps)

            # Calculate rewards
            unclp_rew = surrogate_reward(batch_advs, new=new_log_ps,
                                         old=batch_old_log_ps)
            clp_rew = surrogate_reward(batch_advs, new=new_log_ps,
                                       old=batch_old_log_ps,
                                       clip_eps=params.CLIP_EPS)

            # Calculate entropy bonus
            entropy_bonus = net.entropies(dist).mean()

            # Total loss
            surrogate = -ch.min(unclp_rew, clp_rew).mean()
            entropy = -params.ENTROPY_COEFF * entropy_bonus
            loss = surrogate + entropy

            # If we are sharing weights, take the value step simultaneously
            # (since the policy and value networks depend on the same weights)
            if params.SHARE_WEIGHTS:
                tup = sel(returns, not_dones, old_vs)
                batch_returns, batch_not_dones, batch_old_vs = tup
                val_loss = value_step(batch_states, batch_returns, batch_advs,
                                      batch_not_dones, net.get_value, None, params,
                                      store, old_vs=batch_old_vs, opt_step=opt_step)
                loss += params.VALUE_MULTIPLIER * val_loss

            # Optimizer step (Adam or SGD)
            if params.POLICY_ADAM is None:
                grad = ch.autograd.grad(loss, net.parameters())
                flat_grad = flatten(grad)
                if params.CLIP_GRAD_NORM != -1:
                    norm_grad = ch.norm(flat_grad)
                    flat_grad = flat_grad if norm_grad <= params.CLIP_GRAD_NORM else \
                        flat_grad / norm_grad * params.CLIP_GRAD_NORM
                assign(flatten(net.parameters()) - params.PPO_LR * flat_grad,
                       net.parameters())
            else:
                params.POLICY_ADAM.zero_grad()
                loss.backward()
                if params.CLIP_GRAD_NORM != -1:
                    ch.nn.utils.clip_grad_norm_(net.parameters(),
                                                params.CLIP_GRAD_NORM)
                params.POLICY_ADAM.step()

    return loss