def step(policy_net, optimizer_policy, states, actions, advantages, traj_num):
    """Update the policy; the objective function is the average trajectory return."""
    log_probs = policy_net.get_log_prob(Variable(states), Variable(actions))
    print("old log probs", log_probs)
    policy_loss = -(log_probs * Variable(advantages)).sum() / traj_num

    # flat_grad = torch_utils.compute_flat_grad(policy_loss, policy_net.parameters(), create_graph=True).detach().numpy()
    # logger.log("gradient:" + str(flat_grad))
    #
    # # check what the outcome would be if we simply added the gradient to the current parameters
    # prev_params = torch_utils.get_flat_params_from(policy_net).detach().numpy()
    # logger.log("old_parameters" + str(prev_params))
    # logger.log("new_parameters handcoded plus" + str(prev_params + flat_grad * 0.01))
    # logger.log("new_parameters handcoded minus" + str(prev_params - flat_grad * 0.01))

    prev_params = torch_utils.get_flat_params_from(policy_net).detach().numpy()
    logger.log("old_parameters" + str(prev_params))

    optimizer_policy.zero_grad()
    policy_loss.backward()
    logger.record_tabular("policy_loss before", policy_loss.item())
    # for param in policy_net.parameters():
    #     logger.log("parameter_grad:" + str(param.grad))
    optimizer_policy.step()

    new_params = torch_utils.get_flat_params_from(policy_net).detach().numpy()
    logger.log("new_parameters" + str(new_params))

    # calculate the new loss after the update
    log_probs = policy_net.get_log_prob(Variable(states), Variable(actions))
    print("new log probs", log_probs)
    policy_loss = -(log_probs * Variable(advantages)).sum() / traj_num
    logger.record_tabular("policy_loss after", policy_loss.item())
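# --- Hypothetical usage sketch (not part of the repo) ------------------------
# Illustrates the contract that step() above assumes: a policy exposing
# get_log_prob(states, actions) and a standard torch optimizer. The toy policy,
# the tensor shapes, and the data below are made-up placeholders.
import torch
import torch.nn as nn


class ToyGaussianPolicy(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.mean = nn.Linear(obs_dim, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def get_log_prob(self, states, actions):
        # per-sample log-probability under a diagonal Gaussian
        dist = torch.distributions.Normal(self.mean(states), self.log_std.exp())
        return dist.log_prob(actions).sum(dim=1, keepdim=True)


toy_policy = ToyGaussianPolicy(obs_dim=4, act_dim=1)
toy_optimizer = torch.optim.Adam(toy_policy.parameters(), lr=1e-3)
# dummy batch standing in for 8 collected trajectories with 32 transitions total
toy_states, toy_actions, toy_advs = torch.randn(32, 4), torch.randn(32, 1), torch.randn(32, 1)
step(toy_policy, toy_optimizer, toy_states, toy_actions, toy_advs, traj_num=8)
# -----------------------------------------------------------------------------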
def _train_SGD(self):
    # TODO: we need to get the right observations, actions and next_observations for the model here
    # expert_observations, expert_actions, expert_next_observations = create_torch_var_from_paths(self.expert_data)
    # now train the imitation policy on the collected batch of expert_data via MLE on the log prob, since we have a Gaussian
    # TODO: do we train mean and variance? or only the mean?
    torch_input_batch, torch_output_batch = self.create_torch_var_from_paths(self.expert_data)

    # split the data randomly into training and validation sets, using a 70-30 split
    numTotalSamples = torch_input_batch.size(0)
    trainingSize = int(numTotalSamples * 0.7)
    randomIndices = np.random.permutation(np.arange(numTotalSamples))
    trainingIndices = randomIndices[:trainingSize]
    validationIndices = randomIndices[trainingSize:]
    validation_input_batch = torch_input_batch[validationIndices]
    validation_output_batch = torch_output_batch[validationIndices]
    torch_input_batch = torch_input_batch[trainingIndices]
    torch_output_batch = torch_output_batch[trainingIndices]

    best_loss = np.inf
    losses = np.array([best_loss] * 25)

    with tqdm(total=self.n_itr, file=sys.stdout) as pbar:
        for epoch in range(self.n_itr + 1):
            with logger.prefix('epoch #%d | ' % epoch):
                # split into mini batches for training
                total_batchsize = torch_input_batch.size(0)
                logger.record_tabular('Iteration', epoch)
                indices = np.random.permutation(np.arange(total_batchsize))
                if isinstance(self.imitationModel, CartPoleModel):
                    logger.record_tabular("theta", str(self.imitationModel.theta.detach().numpy()))
                    logger.record_tabular("std", str(self.imitationModel.std.detach().numpy()))

                # go through the whole batch
                for k in range(int(total_batchsize / self.mini_batchsize)):
                    idx = indices[self.mini_batchsize * k:self.mini_batchsize * (k + 1)]
                    # TODO: how about numerical stability?
                    log_prob = self.imitationModel.get_log_prob(torch_input_batch[idx, :],
                                                                torch_output_batch[idx, :])
                    # note that L2 regularization is handled by the weight decay of the optimizer
                    loss = -torch.mean(log_prob)  # negative since we want to minimize, not maximize
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # calculate the loss on the validation set
                log_prob = self.imitationModel.get_log_prob(validation_input_batch, validation_output_batch)
                loss = -torch.mean(log_prob)
                # Note: add the L2 regularization (weight decay) here so that the logged loss
                # matches the objective actually being optimized
                for param in self.imitationModel.parameters():
                    loss += param.pow(2).sum() * self.l2_reg
                logger.record_tabular("loss", loss.item())

                # check whether the validation loss has decreased within the last 25 iterations;
                # if not, stop training and restore the best parameters found so far
                losses[1:] = losses[0:-1]
                losses[0] = loss.item()
                if epoch == 0:
                    best_loss = np.min(losses)
                    best_flat_parameters = torch_utils.get_flat_params_from(self.imitationModel).detach().numpy()
                    logger.record_tabular("current_best_loss", best_loss)
                elif np.min(losses) <= best_loss and not (np.mean(losses) == best_loss):
                    # the second condition guards against the whole window holding the same value;
                    # keep the smaller of the new window minimum and the previous best
                    best_loss = np.min(losses)
                    best_flat_parameters = torch_utils.get_flat_params_from(self.imitationModel).detach().numpy()
                    logger.record_tabular("current_best_loss", best_loss)
                else:
                    pbar.close()
                    print("best loss did not decrease in the last 25 steps")
                    print("saving best result...")
                    logger.log("best loss did not decrease in the last 25 steps")
                    torch_utils.set_flat_params_to(self.imitationModel,
                                                   torch_utils.torch.from_numpy(best_flat_parameters))
                    logger.log("SGD converged")
                    logger.log("saving best result...")
                    params, torch_params = self.get_itr_snapshot(epoch)
                    if params is not None:
                        params["algo"] = self
                    logger.save_itr_params(self.n_itr, params, torch_params)
                    logger.log("saved")
                    break

                pbar.set_description('epoch: %d' % (1 + epoch))
                pbar.update(1)

                # save result
                logger.log("saving snapshot...")
                params, torch_params = self.get_itr_snapshot(epoch)
                if params is not None:
                    params["algo"] = self
                logger.save_itr_params(epoch, params, torch_params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
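# --- Illustration sketch (not part of the repo) -------------------------------
# The same 25-step rolling-window stopping rule as in _train_SGD above, run in
# isolation on a synthetic validation-loss curve so the mechanics are easier to
# follow. All data here is made up.
import numpy as np

rng = np.random.default_rng(0)
# losses improve for 60 epochs, then plateau slightly above the best value
val_losses = np.concatenate([np.linspace(2.0, 0.5, 60),
                             0.5 + 0.01 * rng.random(200)])

window = np.full(25, np.inf)  # last 25 validation losses, newest first
best_loss = np.inf
for epoch, loss in enumerate(val_losses):
    window[1:] = window[:-1]  # shift the window by one
    window[0] = loss
    if epoch == 0 or (np.min(window) <= best_loss and not np.mean(window) == best_loss):
        # a value at least as good as the best is still inside the window
        best_loss = np.min(window)
    else:
        print("no improvement within the last 25 epochs, stopping at epoch", epoch)
        break
print("best validation loss:", best_loss)
# ------------------------------------------------------------------------------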
def _train_BGFS(self):
    if not isinstance(self.imitationModel, CartPoleModel):
        raise NotImplementedError("_train_BGFS can only be called with a CartPoleModel")
    expert_observations = torch.from_numpy(self.expert_data["observations"]).float()
    expert_actions = torch.from_numpy(self.expert_data["actions"]).float()
    expert_obs_diff = torch.from_numpy(self.expert_data["env_infos"]["obs_diff"]).float()

    # now train the imitation policy on the collected batch of expert_data via MLE on the log prob, since we have a Gaussian
    # TODO: do we train mean and variance? or only the mean?
    if self.mode == "imitate_env":
        input = torch.cat([expert_observations, expert_actions], dim=1)
        output = expert_obs_diff
    else:
        raise ValueError("invalid mode")

    imitation_model = self.imitationModel
    total_batchsize = input.size(0)

    def get_negative_likelihood_loss(flat_params):
        torch_utils.set_flat_params_to(imitation_model, torch_utils.torch.from_numpy(flat_params))
        for param in imitation_model.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)
        indices = np.random.permutation(np.arange(total_batchsize))
        loss = -torch.mean(imitation_model.get_log_prob(input[indices[:self.mini_batchsize]],
                                                        output[indices[:self.mini_batchsize]]))
        # weight decay
        for param in imitation_model.parameters():
            loss += param.pow(2).sum() * self.l2_reg
        loss.backward()
        # FIX: removed the [0] indexing; mean() already reduces the loss to a scalar in newer torch versions
        return loss.detach().numpy(), \
            torch_utils.get_flat_grad_from(imitation_model.parameters()).detach().numpy().astype(np.float64)

    curr_itr = 0

    def callback_fun(flat_params):
        nonlocal curr_itr
        torch_utils.set_flat_params_to(imitation_model, torch_utils.torch.from_numpy(flat_params))
        # calculate the loss on the whole batch
        loss = -torch.mean(imitation_model.get_log_prob(input, output))
        # weight decay
        for param in imitation_model.parameters():
            loss += param.pow(2).sum() * self.l2_reg
        loss.backward()
        if isinstance(self.imitationModel, CartPoleModel):
            logger.record_tabular("theta", str(self.imitationModel.theta.detach().numpy()))
            logger.record_tabular("std", str(self.imitationModel.std.detach().numpy()))
        logger.record_tabular('Iteration', curr_itr)
        logger.record_tabular("loss", loss.item())
        logger.dump_tabular(with_prefix=False)
        curr_itr += 1

    x0 = torch_utils.get_flat_params_from(self.imitationModel).detach().numpy()
    # only allow non-negative values, since the masses and the variance cannot be negative
    bounds = [(0, np.inf) for _ in x0]
    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(
        get_negative_likelihood_loss, x0, maxiter=self.n_itr, bounds=bounds, callback=callback_fun)
    logger.log(str(opt_info))
    torch_utils.set_flat_params_to(self.imitationModel, torch.from_numpy(flat_params))

    # save result
    logger.log("saving snapshot...")
    params, torch_params = self.get_itr_snapshot(0)
    params["algo"] = self
    logger.save_itr_params(self.n_itr, params, torch_params)
    logger.log("saved")
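# --- Illustration sketch (not part of the repo) -------------------------------
# _train_BGFS above drives a torch model with scipy.optimize.fmin_l_bfgs_b via
# flat parameter vectors. This shows the same flatten/unflatten pattern on a toy
# least-squares model; the helpers and the model are made-up placeholders, not
# the repo's torch_utils.
import numpy as np
import scipy.optimize
import torch
import torch.nn as nn

toy_model = nn.Linear(3, 1, bias=False)
x_data = torch.randn(256, 3)
y_data = x_data @ torch.tensor([[1.0], [-2.0], [0.5]])


def _set_flat_params(module, flat):
    # write a flat float64 numpy vector back into the module's parameters
    offset = 0
    for p in module.parameters():
        n = p.numel()
        p.data.copy_(torch.from_numpy(flat[offset:offset + n]).view_as(p).float())
        offset += n


def _loss_and_grad(flat):
    _set_flat_params(toy_model, flat)
    toy_model.zero_grad()
    loss = ((toy_model(x_data) - y_data) ** 2).mean()
    loss.backward()
    grad = np.concatenate([p.grad.view(-1).numpy() for p in toy_model.parameters()])
    return loss.item(), grad.astype(np.float64)


x0 = np.concatenate([p.detach().view(-1).numpy()
                     for p in toy_model.parameters()]).astype(np.float64)
flat_opt, final_loss, _ = scipy.optimize.fmin_l_bfgs_b(_loss_and_grad, x0, maxiter=50)
_set_flat_params(toy_model, flat_opt)
print("final loss:", final_loss)
# ------------------------------------------------------------------------------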
def step(self, policy_net, value_net, states, actions, returns, advantages):
    """update critic"""
    values_target = Variable(returns)

    """calculates the mean KL difference between two parameter settings"""
    def get_kl_diff(old_param, new_param):
        prev_params = torch_utils.get_flat_params_from(policy_net)
        with torch.no_grad():
            torch_utils.set_flat_params_to(policy_net, old_param)
            log_old_prob = torch.clamp(policy_net.get_log_prob(
                Variable(states, volatile=True), Variable(actions)), min=np.log(1e-6))
            torch_utils.set_flat_params_to(policy_net, new_param)
            log_new_prob = torch.clamp(policy_net.get_log_prob(
                Variable(states, volatile=True), Variable(actions)), min=np.log(1e-6))
            torch_utils.set_flat_params_to(policy_net, prev_params)
            return torch.mean(torch.exp(log_old_prob) * (log_old_prob - log_new_prob)).numpy()

    def get_value_loss(flat_params):
        torch_utils.set_flat_params_to(value_net, torch_utils.torch.from_numpy(flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)
        values_pred = value_net(Variable(states))
        value_loss = (values_pred - values_target).pow(2).mean()
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * self.l2_reg
        value_loss.backward()
        # FIX: removed the [0] indexing; mean() already reduces the loss to a scalar in newer torch versions
        return value_loss.data.cpu().numpy(), \
            torch_utils.get_flat_grad_from(value_net.parameters()).data.cpu().numpy().astype(np.float64)

    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(
        get_value_loss, torch_utils.get_flat_params_from(value_net).cpu().numpy(), maxiter=25)
    torch_utils.set_flat_params_to(value_net, torch.from_numpy(flat_params))

    """update policy"""
    fixed_log_probs = torch.clamp(policy_net.get_log_prob(
        Variable(states, volatile=True), Variable(actions)), min=np.log(1e-6)).data

    """define the loss function for TRPO"""
    def get_loss(volatile=False):
        # more numerically stable: clamp to a minimum value so that we don't get -inf
        log_probs = torch.clamp(policy_net.get_log_prob(
            Variable(states, volatile=volatile), Variable(actions)), min=np.log(1e-6))
        ent = policy_net.get_entropy(Variable(states, volatile=volatile), Variable(actions)).mean()
        action_loss = -Variable(advantages) * torch.exp(log_probs - Variable(fixed_log_probs))
        # logger.log("advantage" + str(advantages))
        # logger.log("log_probs" + str(log_probs))
        # logger.log("mean" + str(torch.mean(torch.exp(log_probs - Variable(fixed_log_probs)))))
        # logger.log("action_loss_no_mean" + str(-action_loss))
        return action_loss.mean() - self.entropy_coeff * ent

    """use the Fisher information matrix for Hessian*vector"""
    def Fvp_fim(v):
        M, mu, info = policy_net.get_fim(Variable(states))
        mu = mu.view(-1)
        filter_input_ids = set() if policy_net.is_disc_action else {info['std_id']}
        t = M.new(mu.size())
        t[:] = 1
        t = Variable(t, requires_grad=True)
        mu_t = (mu * t).sum()
        Jt = torch_utils.compute_flat_grad(mu_t, policy_net.parameters(),
                                           filter_input_ids=filter_input_ids, create_graph=True)
        Jtv = (Jt * Variable(v)).sum()
        Jv = torch.autograd.grad(Jtv, t, retain_graph=True)[0]
        MJv = Variable(M * Jv.data)
        mu_MJv = (MJv * mu).sum()
        JTMJv = torch_utils.compute_flat_grad(mu_MJv, policy_net.parameters(),
                                              filter_input_ids=filter_input_ids, retain_graph=True).data
        JTMJv /= states.shape[0]
        if not policy_net.is_disc_action:
            std_index = info['std_index']
            JTMJv[std_index: std_index + M.shape[0]] += 2 * v[std_index: std_index + M.shape[0]]
        return JTMJv + v * self.damping

    """directly compute Hessian*vector from the KL divergence"""
    def Fvp_direct(v):
        kl = policy_net.get_kl(Variable(states))
        kl = kl.mean()
        grads = torch.autograd.grad(kl, policy_net.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
        kl_v = (flat_grad_kl * Variable(v)).sum()
        grads = torch.autograd.grad(kl_v, policy_net.parameters())
        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data
        return flat_grad_grad_kl + v * self.damping

    Fvp = Fvp_fim if self.use_fim else Fvp_direct

    loss = get_loss()
    grads = torch.autograd.grad(loss, policy_net.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data
    stepdir = conjugate_gradients(Fvp, -loss_grad, 10)

    shs = stepdir.dot(Fvp(stepdir))
    lm = np.sqrt(2 * self.max_kl / (shs + 1e-8))
    if np.isnan(lm):
        lm = 1.
    fullstep = stepdir * lm
    expected_improve = -loss_grad.dot(fullstep)

    prev_params = torch_utils.get_flat_params_from(policy_net)
    success, new_params = line_search(policy_net, get_loss, prev_params, fullstep,
                                      expected_improve, get_kl_diff, self.max_kl)
    logger.record_tabular('TRPO_linesearch_success', int(success))
    logger.record_tabular("KL_diff", get_kl_diff(prev_params, new_params))
    torch_utils.set_flat_params_to(policy_net, new_params)
    logger.log("old_parameters" + str(prev_params.detach().numpy()))
    logger.log("new_parameters" + str(new_params.detach().numpy()))

    return success
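# --- Reference sketch (not part of the repo) ----------------------------------
# step() above relies on conjugate_gradients(Fvp, -loss_grad, 10) and
# line_search(...) helpers defined elsewhere. For orientation, a standard
# conjugate-gradient solver compatible with that call could look like this;
# it is an assumption about the interface, not the repo's actual implementation.
import torch


def _conjugate_gradients_sketch(Avp, b, nsteps, residual_tol=1e-10):
    """Solve A x = b, where Avp(v) returns the matrix-vector product A v."""
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A x for x = 0
    p = b.clone()  # search direction
    rdotr = torch.dot(r, r)
    for _ in range(nsteps):
        Avp_p = Avp(p)
        alpha = rdotr / torch.dot(p, Avp_p)
        x += alpha * p
        r -= alpha * Avp_p
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x


# toy check on a small symmetric positive-definite system
A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
b = torch.tensor([1.0, 2.0])
print(_conjugate_gradients_sketch(lambda v: A @ v, b, nsteps=10))  # approx. [0.0909, 0.6364]
# ------------------------------------------------------------------------------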