def train_from_paths(self, paths):
    # Concatenate from all the trajectories
    observations = np.concatenate([path["observations"] for path in paths])
    actions = np.concatenate([path["actions"] for path in paths])
    advantages = np.concatenate([path["advantages"] for path in paths])

    # Advantage whitening
    advantages = (advantages - np.mean(advantages)) / (np.std(advantages) + 1e-6)
    # NOTE: advantages should be zero mean in expectation.
    # The normalized step size is invariant to advantage scaling,
    # but scaling can help with least squares.

    self.n_steps += len(advantages)

    # Cache return statistics for the paths
    path_returns = [sum(p["rewards"]) for p in paths]
    mean_return = np.mean(path_returns)
    std_return = np.std(path_returns)
    min_return = np.amin(path_returns)
    max_return = np.amax(path_returns)
    base_stats = [mean_return, std_return, min_return, max_return]
    self.running_score = mean_return if self.running_score is None else \
        0.9 * self.running_score + 0.1 * mean_return  # approx. average of last 10 iterations
    if self.save_logs:
        self.log_rollout_statistics(paths)

    # Keep track of times for various computations
    t_gLL = 0.0
    t_FIM = 0.0

    self.optim.zero_grad()

    # Optimization. Negate the gradient since the optimizer is minimizing.
    vpg_grad = -self.flat_vpg(observations, actions, advantages)
    vector_to_gradients(Variable(torch.from_numpy(vpg_grad).float()),
                        self.policy.trainable_params)

    closure = self.kl_closure(self.policy, observations, actions, self.policy_kl_fn)
    info = self.optim.step(closure)
    self.policy.set_param_values(self.policy.get_param_values())

    # Log information
    if self.save_logs:
        self.logger.log_kv('alpha', info['alpha'])
        self.logger.log_kv('delta', info['delta'])
        # self.logger.log_kv('time_vpg', t_gLL)
        # self.logger.log_kv('time_npg', t_FIM)
        self.logger.log_kv('running_score', self.running_score)
        self.logger.log_kv('steps', self.n_steps)
        try:
            success_rate = self.env.env.env.evaluate_success(paths)
            self.logger.log_kv('success_rate', success_rate)
        except Exception:
            pass

    return base_stats
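# The optimizer step above consumes a KL closure and internally forms Fisher-vector
# products from it. As a rough, standalone illustration (not this repo's kl_closure or
# make_fvp_fun), a Fisher-vector product can be obtained by double backprop through a
# mean-KL term: F v = d/dtheta ((dKL/dtheta) . v). All names below are hypothetical.
import torch


def fisher_vector_product(mean_kl, params, v, damping=1e-4):
    # First pass: gradient of the KL w.r.t. the parameters, keeping the graph
    grads = torch.autograd.grad(mean_kl, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    # Second pass: differentiate the inner product (grad . v) to obtain F v
    grad_v = (flat_grad * v).sum()
    hvp = torch.autograd.grad(grad_v, params, retain_graph=True)
    flat_hvp = torch.cat([h.reshape(-1) for h in hvp])
    return flat_hvp + damping * v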
def step(self, closure, execute_update=True):
    """Perform a single optimization step.

    Arguments:
        closure (callable): A closure that re-evaluates the quantity (e.g. the mean KL
            divergence or the loss) used to build Fisher- or Gauss-Newton-vector
            products over the model parameters.
        execute_update (bool): If True, apply the natural-gradient step to the
            parameters; otherwise only compute and return it.
    """
    state = self.state

    # State initialization
    if len(state) == 0:
        state['step'] = 0
        # Set shrinkage to defaults, i.e. no shrinkage
        state['rho'] = 0.0
        state['diag_shrunk'] = 1.0
    state['step'] += 1

    # Get flat grad
    g = gradients_to_vector(self._params)

    if 'ng_prior' not in state:
        state['ng_prior'] = torch.zeros_like(g)

    curv_type = self._param_group['curv_type']
    if curv_type not in self.valid_curv_types:
        raise ValueError("Invalid curv_type.")

    # Create closure to pass to Lanczos and CG
    if curv_type == 'fisher':
        Fvp_theta_fn = make_fvp_fun(closure, self._params)
    elif curv_type == 'gauss_newton':
        Fvp_theta_fn = make_gnvp_fun(closure, self._params)

    shrinkage_method = self._param_group['shrinkage_method']
    lanczos_amortization = self._param_group['lanczos_amortization']
    if shrinkage_method == 'lanczos' and (state['step'] - 1) % lanczos_amortization == 0:
        w = lanczos_iteration(Fvp_theta_fn, self._numel(),
                              k=self._param_group['lanczos_iters'])
        rho, diag_shrunk = estimate_shrinkage(w, self._numel(),
                                              self._param_group['batch_size'])
        state['rho'] = rho
        state['diag_shrunk'] = diag_shrunk

    M = None
    if self._param_group['cg_precondition_empirical']:
        # Empirical Fisher diagonal is g * g
        M = (g * g + self._param_group['cg_precondition_regu_coef']
             * torch.ones_like(g))**self._param_group['cg_precondition_exp']

    # Do CG solve with the curvature-vector-product closure
    extract_tridiag = self._param_group['shrinkage_method'] == 'cg'
    cg_result = cg_solve(Fvp_theta_fn,
                         g.data.clone(),
                         x_0=self._param_group['cg_prev_init_coef'] * state['ng_prior'],
                         M=M,
                         cg_iters=self._param_group['cg_iters'],
                         cg_residual_tol=self._param_group['cg_residual_tol'],
                         shrunk=self._param_group['shrinkage_method'] is not None,
                         rho=state['rho'],
                         Dshrunk=state['diag_shrunk'],
                         extract_tridiag=extract_tridiag)

    if extract_tridiag:
        ng, (diag_elems, off_diag_elems) = cg_result
        w = eigvalsh_tridiagonal(diag_elems, off_diag_elems)
        rho, diag_shrunk = estimate_shrinkage(w, self._numel(),
                                              self._param_group['batch_size'])
        state['rho'] = rho
        state['diag_shrunk'] = diag_shrunk
    else:
        ng = cg_result

    state['ng_prior'] = ng.data.clone()

    # Normalize NG
    lr = self._param_group['lr']
    alpha = torch.sqrt(torch.abs(lr / (torch.dot(g, ng) + 1e-20)))

    # Unflatten grad
    vector_to_gradients(ng, self._params)

    if execute_update:
        # Apply step
        for p in self._params:
            if p.grad is None:
                continue
            d_p = p.grad.data
            p.data.add_(-alpha, d_p)

    return dict(alpha=alpha, delta=lr, natural_grad=ng)
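# A stripped-down conjugate-gradient sketch of what cg_solve is assumed to do above:
# solve F x = g using only matrix-vector products, omitting the shrinkage and
# preconditioning options. Illustrative only; not the repo's cg_solve.
import torch


def cg(mvp_fn, g, iters=10, residual_tol=1e-10):
    x = torch.zeros_like(g)
    r = g.clone()          # residual for the initial guess x = 0
    p = g.clone()          # search direction
    rdotr = torch.dot(r, r)
    for _ in range(iters):
        z = mvp_fn(p)
        step = rdotr / torch.dot(p, z)
        x += step * p
        r -= step * z
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x


# Example: recover A^{-1} g for a small SPD matrix.
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
g_toy = torch.tensor([1.0, -1.0])
x_toy = cg(lambda v: A @ v, g_toy)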
def step(self, closure, execute_update=True):
    """Perform a single optimization step.

    Arguments:
        closure (callable): A closure that re-evaluates the quantity (e.g. the mean KL
            divergence or the loss) used to build Fisher- or Gauss-Newton-vector
            products over the model parameters.
        execute_update (bool): If True, apply the natural-gradient step to the
            parameters; otherwise only compute and return it.
    """
    state = self.state
    param_vec = parameters_to_vector(self._params)

    # State initialization
    if len(state) == 0:
        state['step'] = 0
        # Exponential moving average of gradient values
        state['m'] = torch.zeros_like(param_vec.data)
        # Maintain adaptive preconditioner if needed
        if self._param_group['cg_precondition_empirical']:
            state['M'] = torch.zeros_like(param_vec.data)
        # Set shrinkage to defaults, i.e. no shrinkage
        state['rho'] = 0.0
        state['diag_shrunk'] = 1.0

    m = state['m']
    beta1, beta2 = self._param_group['betas']
    state['step'] += 1
    bias_correction1 = 1 - beta1**state['step']
    bias_correction2 = 1 - beta2**state['step']

    # Get flat grad
    g = gradients_to_vector(self._params)

    # Update moving average mean
    m.mul_(beta1).add_(1 - beta1, g)
    g_hat = m / bias_correction1

    theta = parameters_to_vector(self._params)
    theta_old = parameters_to_vector(self._params_old)

    if 'ng_prior' not in state:
        state['ng_prior'] = torch.zeros_like(g_hat)
    if 'max_fisher_spectral_norm' not in state:
        state['max_fisher_spectral_norm'] = 0.0

    curv_type = self._param_group['curv_type']
    if curv_type not in self.valid_curv_types:
        raise ValueError("Invalid curv_type.")

    if curv_type == 'fisher':
        weighted_fvp_fn_div_beta2 = self._make_combined_fvp_fun(
            closure, self._params, self._params_old, bias_correction2=bias_correction2)
    elif curv_type == 'gauss_newton':
        weighted_fvp_fn_div_beta2 = self._make_combined_gnvp_fun(
            closure, self._params, self._params_old, bias_correction2=bias_correction2)

    fisher_norm = lanczos_iteration(weighted_fvp_fn_div_beta2, self._numel(), k=1)[0]
    is_max_norm = fisher_norm > state['max_fisher_spectral_norm'] or state['step'] == 1
    if is_max_norm:
        state['max_fisher_spectral_norm'] = fisher_norm

    if is_max_norm:
        if self._param_group['assume_locally_linear']:
            # Move theta_old a (1 - beta2) portion towards theta
            theta_old = beta2 * theta_old + (1 - beta2) * theta
        else:
            # Do a line search first to update theta_old; then CG needs only one
            # curvature-vector product per iteration.
            ng = self.state['ng_prior'].clone() if state['step'] > 1 else g_hat.data.clone()
            if curv_type == 'fisher':
                weighted_fvp_fn = self._make_combined_fvp_fun(
                    closure, self._params, self._params_old)
                f = make_fvp_obj_fun(closure, weighted_fvp_fn, ng)
            elif curv_type == 'gauss_newton':
                weighted_fvp_fn = self._make_combined_gnvp_fun(
                    closure, self._params, self._params_old)
                f = make_gnvp_obj_fun(closure, weighted_fvp_fn, ng)
            xmin, fmin, alpha = randomized_linesearch(f, theta_old.data, theta.data)
            theta_old = Variable(xmin.float())
        vector_to_parameters(theta_old, self._params_old)

    # Now that theta_old has been updated, do CG with theta_old only.
    # If the norm was not a new maximum, theta_old keeps its previous value.
    if curv_type == 'fisher':
        fvp_fn_div_beta2 = make_fvp_fun(closure, self._params_old,
                                        bias_correction2=bias_correction2)
    elif curv_type == 'gauss_newton':
        fvp_fn_div_beta2 = make_gnvp_fun(closure, self._params_old,
                                         bias_correction2=bias_correction2)

    shrinkage_method = self._param_group['shrinkage_method']
    lanczos_amortization = self._param_group['lanczos_amortization']
    if shrinkage_method == 'lanczos' and (state['step'] - 1) % lanczos_amortization == 0:
        w = lanczos_iteration(fvp_fn_div_beta2, self._numel(),
                              k=self._param_group['lanczos_iters'])
        rho, diag_shrunk = estimate_shrinkage(w, self._numel(),
                                              self._param_group['batch_size'])
        state['rho'] = rho
        state['diag_shrunk'] = diag_shrunk

    M = None
    if self._param_group['cg_precondition_empirical']:
        # Empirical Fisher diagonal is g * g
        V = state['M']
        Mt = (g * g + self._param_group['cg_precondition_regu_coef']
              * torch.ones_like(g))**self._param_group['cg_precondition_exp']
        Vhat = V.mul(beta2).add(1 - beta2, Mt) / bias_correction2
        V = torch.max(V, Vhat)
        M = V

    extract_tridiag = self._param_group['shrinkage_method'] == 'cg'
    cg_result = cg_solve(fvp_fn_div_beta2,
                         g_hat.data.clone(),
                         x_0=self._param_group['cg_prev_init_coef'] * state['ng_prior'],
                         M=M,
                         cg_iters=self._param_group['cg_iters'],
                         cg_residual_tol=self._param_group['cg_residual_tol'],
                         shrunk=self._param_group['shrinkage_method'] is not None,
                         rho=state['rho'],
                         Dshrunk=state['diag_shrunk'],
                         extract_tridiag=extract_tridiag)

    if extract_tridiag:
        ng, (diag_elems, off_diag_elems) = cg_result
        w = eigvalsh_tridiagonal(diag_elems, off_diag_elems)
        rho, diag_shrunk = estimate_shrinkage(w, self._numel(),
                                              self._param_group['batch_size'])
        state['rho'] = rho
        state['diag_shrunk'] = diag_shrunk
    else:
        ng = cg_result

    self.state['ng_prior'] = ng.data.clone()

    # Normalize NG
    lr = self._param_group['lr']
    alpha = torch.sqrt(torch.abs(lr / (torch.dot(g_hat, ng) + 1e-20)))

    # Unflatten grad
    vector_to_gradients(ng, self._params)

    if execute_update:
        # Apply step
        for p in self._params:
            if p.grad is None:
                continue
            d_p = p.grad.data
            p.data.add_(-alpha, d_p)

    return dict(alpha=alpha, delta=lr, natural_grad=ng)
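# Small numeric sketch of the Adam-style bias correction used above: the running average
# m_t = beta*m_{t-1} + (1-beta)*g_t underestimates the mean early on, and dividing by
# (1 - beta**t) removes that bias; the same correction is applied to the beta2-weighted
# curvature via bias_correction2. Illustrative only; the constant gradient is made up.
import torch

beta = 0.9
g_const = torch.tensor([1.0, 2.0])   # pretend the gradient is constant over steps
m = torch.zeros_like(g_const)
for t in range(1, 6):
    m = beta * m + (1 - beta) * g_const
    m_hat = m / (1 - beta ** t)
    # m alone is far from g_const at small t; m_hat matches it exactly in this toy case
    print(t, m.tolist(), m_hat.tolist())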
def step(self, closure, execute_update=True):
    """Perform a single optimization step over all parameter groups (block-diagonal).

    Arguments:
        closure (callable): A closure that re-evaluates the quantity (e.g. the mean KL
            divergence or the loss) used to build Fisher- or Gauss-Newton-vector
            products over the model parameters.
        execute_update (bool): If True, apply the natural-gradient step to the
            parameters; otherwise only compute and return it.
    """
    # Update theta_old for all blocks first; only the approximate (locally linear)
    # update is supported here.
    params_i = 0
    params_j = 0
    for gi, group in enumerate(self.param_groups):
        params = group['params']
        params_j += len(params)
        num_params = self._numel(gi, params)

        state = self.state[gi]
        if len(state) == 0:
            state['step'] = 0
            # Exponential moving average of gradient values
            state['m'] = torch.zeros(num_params)
            # Maintain adaptive preconditioner if needed
            if group['cg_precondition_empirical']:
                state['M'] = torch.zeros(num_params)
            # Set shrinkage to defaults, i.e. no shrinkage
            state['rho'] = 0.0
            state['diag_shrunk'] = 1.0
            # Lagged copy of the parameters, lightly perturbed
            state['lagged'] = []
            for i in range(len(params)):
                state['lagged'].append(params[i] + torch.randn(params[i].shape) * 0.0001)

        beta1, beta2 = group['betas']
        theta = parameters_to_vector(params)
        theta_old = parameters_to_vector(state['lagged'])
        # Move theta_old a (1 - beta2) portion towards theta
        theta_old = beta2 * theta_old + (1 - beta2) * theta
        vector_to_parameters(theta_old, state['lagged'])

    info = {}

    # If doing block diag, perform the update for each param group
    params_i = 0
    params_j = 0
    for gi, group in enumerate(self.param_groups):
        params = group['params']
        params_j += len(params)
        num_params = self._numel(gi, params)

        # NOTE: state is initialized above
        state = self.state[gi]
        m = state['m']
        beta1, beta2 = group['betas']
        state['step'] += 1
        params_old = state['lagged']

        bias_correction1 = 1 - beta1**state['step']
        bias_correction2 = 1 - beta2**state['step']

        # Get flat grad
        g = gradients_to_vector(params)

        # Update moving average mean
        m.mul_(beta1).add_(1 - beta1, g)
        g_hat = m / bias_correction1

        if 'ng_prior' not in state:
            state['ng_prior'] = torch.zeros_like(g)

        curv_type = group['curv_type']
        if curv_type not in self.valid_curv_types:
            raise ValueError("Invalid curv_type.")

        # Now that theta_old has been updated, do CG with theta_old only
        if curv_type == 'fisher':
            fvp_fn_div_beta2 = make_fvp_fun_idx(closure, params_old, params_i, params_j,
                                                bias_correction2=bias_correction2)
        elif curv_type == 'gauss_newton':
            fvp_fn_div_beta2 = make_gnvp_fun(closure, params_old,
                                             bias_correction2=bias_correction2)

        shrinkage_method = group['shrinkage_method']
        lanczos_amortization = group['lanczos_amortization']
        if shrinkage_method == 'lanczos' and (state['step'] - 1) % lanczos_amortization == 0:
            w = lanczos_iteration(fvp_fn_div_beta2, num_params, k=group['lanczos_iters'])
            rho, diag_shrunk = estimate_shrinkage(w, num_params, group['batch_size'])
            state['rho'] = rho
            state['diag_shrunk'] = diag_shrunk

        M = None
        if group['cg_precondition_empirical']:
            # Empirical Fisher diagonal is g * g
            V = state['M']
            Mt = (g * g + group['cg_precondition_regu_coef']
                  * torch.ones_like(g))**group['cg_precondition_exp']
            Vhat = V.mul(beta2).add(1 - beta2, Mt) / bias_correction2
            V = torch.max(V, Vhat)
            M = V

        extract_tridiag = group['shrinkage_method'] == 'cg'
        cg_result = cg_solve(fvp_fn_div_beta2,
                             g_hat.data.clone(),
                             x_0=group['cg_prev_init_coef'] * state['ng_prior'],
                             M=M,
                             cg_iters=group['cg_iters'],
                             cg_residual_tol=group['cg_residual_tol'],
                             shrunk=group['shrinkage_method'] is not None,
                             rho=state['rho'],
                             Dshrunk=state['diag_shrunk'],
                             extract_tridiag=extract_tridiag)

        if extract_tridiag:
            ng, (diag_elems, off_diag_elems) = cg_result
            w = eigvalsh_tridiagonal(diag_elems, off_diag_elems)
            rho, diag_shrunk = estimate_shrinkage(w, num_params, group['batch_size'])
            state['rho'] = rho
            state['diag_shrunk'] = diag_shrunk
        else:
            ng = cg_result

        state['ng_prior'] = ng.data.clone()

        # Normalize NG
        lr = group['lr']
        alpha = torch.sqrt(torch.abs(lr / (torch.dot(g_hat, ng) + 1e-20)))

        # Unflatten grad
        vector_to_gradients(ng, params)

        if execute_update:
            # Apply step
            for p in params:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                p.data.add_(-alpha, d_p)

        params_i = params_j
        info[gi] = dict(alpha=alpha, delta=lr, natural_grad=ng)

    return info
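# When shrinkage_method == 'cg', the CG recurrence also yields a symmetric tridiagonal
# (Lanczos) matrix whose eigenvalues (Ritz values) approximate the curvature spectrum and
# feed estimate_shrinkage. A tiny standalone sketch of that last step, assuming SciPy is
# available; the diagonal/off-diagonal values below are made up.
import numpy as np
from scipy.linalg import eigvalsh_tridiagonal

diag_elems = np.array([2.0, 1.5, 1.0])   # main diagonal of the tridiagonal matrix
off_diag_elems = np.array([0.3, 0.1])    # first off-diagonal
ritz_values = eigvalsh_tridiagonal(diag_elems, off_diag_elems)
print(ritz_values)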
def step(self, loss):
    """Perform a single optimization step using the exact curvature matrix.

    Arguments:
        loss (Tensor): The loss whose exact curvature (Fisher or Hessian, via self.H)
            is used to compute the natural gradient.
    """
    state = self.state

    # State initialization
    if len(state) == 0:
        state['step'] = 0
        state['m'] = torch.zeros((self._numel(),))
        state['Ft'] = torch.zeros((self._numel(), self._numel()))
    state['step'] += 1

    # Get flat grad
    g = gradients_to_vector(self._params)

    # Compute the exact curvature matrix for the current loss
    Gt = self.H(self._params, loss)

    if self.adaptive:
        m = state['m']
        Ft = state['Ft']
        beta1, beta2 = self._param_group['betas']
        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']

        m.mul_(beta1).add_(1 - beta1, g)
        g_hat = m / bias_correction1

        Ft.mul_(beta2).add_(1 - beta2, Gt)
        Ft_hat = Ft / bias_correction2

        ng = torch.pinverse(Ft_hat) @ g_hat
        H = Ft_hat
    else:
        ng = torch.pinverse(Gt) @ g
        H = Gt

    lr = self._param_group['lr']
    alpha = torch.sqrt(torch.abs(lr / (torch.dot(g, ng) + 1e-20)))

    # Unflatten grad
    vector_to_gradients(ng, self._params)

    # Apply step
    for p in self._params:
        if p.grad is None:
            continue
        d_p = p.grad.data
        p.data.add_(-alpha, d_p)

    return dict(alpha=alpha, H=H.clone())
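# Toy sketch of the exact (small-model) natural-gradient step above: ng = F^+ g, with
# the step size normalized so the quadratic form stays near the trust-region radius lr.
# The matrix and gradient are made up; torch.pinverse is the same call used in step().
import torch

F = torch.tensor([[2.0, 0.5], [0.5, 1.0]])   # stand-in curvature matrix
g_toy = torch.tensor([0.3, -0.7])            # stand-in flat gradient
lr = 0.01

ng = torch.pinverse(F) @ g_toy
alpha = torch.sqrt(torch.abs(lr / (torch.dot(g_toy, ng) + 1e-20)))
update = -alpha * ng                          # the parameter change the step would apply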
def train_from_paths(self, paths):
    # Concatenate from all the trajectories
    observations = np.concatenate([path["observations"] for path in paths])
    actions = np.concatenate([path["actions"] for path in paths])
    advantages = np.concatenate([path["advantages"] for path in paths])

    # Advantage whitening
    advantages = (advantages - np.mean(advantages)) / (np.std(advantages) + 1e-6)
    # NOTE: advantages should be zero mean in expectation.
    # The normalized step size is invariant to advantage scaling,
    # but scaling can help with least squares.

    self.n_steps += len(advantages)

    # Cache return statistics for the paths
    path_returns = [sum(p["rewards"]) for p in paths]
    mean_return = np.mean(path_returns)
    std_return = np.std(path_returns)
    min_return = np.amin(path_returns)
    max_return = np.amax(path_returns)
    base_stats = [mean_return, std_return, min_return, max_return]
    self.running_score = mean_return if self.running_score is None else \
        0.9 * self.running_score + 0.1 * mean_return  # approx. average of last 10 iterations
    if self.save_logs:
        self.log_rollout_statistics(paths)

    # Keep track of times for various computations
    t_gLL = 0.0
    t_FIM = 0.0

    # Optimization algorithm
    # --------------------------
    self.optim.zero_grad()
    surr_before = self.CPI_surrogate(observations, actions,
                                     advantages).data.numpy().ravel()[0]

    # VPG
    ts = timer.time()
    vpg_grad = self.flat_vpg(observations, actions, advantages)
    vector_to_gradients(Variable(torch.from_numpy(vpg_grad).float()),
                        self.policy.trainable_params)
    t_gLL += timer.time() - ts

    # NPG
    # Note: unlike the standard NPG, negation is not needed here since the optimizer
    # does not apply the update step.
    ts = timer.time()
    closure = self.kl_closure(self.policy, observations, actions, self.policy_kl_fn)
    info = self.optim.step(closure, execute_update=False)
    npg_grad = info['natural_grad'].data.numpy()
    t_FIM += timer.time() - ts

    # Step size computation
    # --------------------------
    n_step_size = 2.0 * self.kl_dist
    alpha = np.sqrt(np.abs(n_step_size / (np.dot(vpg_grad.T, npg_grad) + 1e-20)))

    # Policy update with backtracking on the KL constraint
    # --------------------------
    curr_params = self.policy.get_param_values()
    for k in range(100):
        new_params = curr_params + alpha * npg_grad
        self.policy.set_param_values(new_params, set_new=True, set_old=False)
        kl_dist = self.kl_old_new(observations, actions).data.numpy().ravel()[0]
        surr_after = self.CPI_surrogate(observations, actions,
                                        advantages).data.numpy().ravel()[0]
        if kl_dist < self.kl_dist:
            break
        else:
            alpha = 0.9 * alpha  # backtrack
        if k == 99:
            # No acceptable step size found; fall back to no update
            alpha = 0.0

    new_params = curr_params + alpha * npg_grad
    self.policy.set_param_values(new_params, set_new=True, set_old=False)
    kl_dist = self.kl_old_new(observations, actions).data.numpy().ravel()[0]
    surr_after = self.CPI_surrogate(observations, actions,
                                    advantages).data.numpy().ravel()[0]
    self.policy.set_param_values(new_params, set_new=True, set_old=True)

    # Log information
    if self.save_logs:
        self.logger.log_kv('alpha', alpha)
        self.logger.log_kv('delta', n_step_size)
        self.logger.log_kv('time_vpg', t_gLL)
        self.logger.log_kv('time_npg', t_FIM)
        self.logger.log_kv('kl_dist', kl_dist)
        self.logger.log_kv('surr_improvement', surr_after - surr_before)
        self.logger.log_kv('running_score', self.running_score)
        self.logger.log_kv('steps', self.n_steps)

    return base_stats
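# Toy sketch of the backtracking loop above: shrink the step size geometrically until a
# KL-style constraint is met, and zero the step if no feasible size is found. kl_of is a
# stand-in for self.kl_old_new; all values and names here are illustrative.
import numpy as np


def backtrack(curr_params, direction, alpha, kl_of, kl_limit, max_tries=100, decay=0.9):
    for _ in range(max_tries):
        new_params = curr_params + alpha * direction
        if kl_of(new_params) < kl_limit:
            return new_params, alpha
        alpha *= decay
    return curr_params, 0.0  # give up: keep the old parameters


# Example with a quadratic "KL": the constraint is satisfied once the step is small enough.
theta = np.zeros(3)
d = np.ones(3)
new_theta, accepted_alpha = backtrack(theta, d, alpha=1.0,
                                      kl_of=lambda p: float(p @ p), kl_limit=0.05)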
def step(self, closure, execute_update=True):
    """Perform a single optimization step over all parameter groups (block-diagonal).

    Arguments:
        closure (callable): A closure that re-evaluates the quantity (e.g. the mean KL
            divergence or the loss) used to build Fisher- or Gauss-Newton-vector
            products over the model parameters.
        execute_update (bool): If True, apply the natural-gradient step to the
            parameters; otherwise only compute and return it.
    """
    info = {}

    # If doing block diag, perform the update for each param group
    params_i = 0
    params_j = 0
    for gi, group in enumerate(self.param_groups):
        params = group['params']
        params_j += len(params)

        state = self.state[gi]
        if len(state) == 0:
            state['step'] = 0
            # Set shrinkage to defaults, i.e. no shrinkage
            state['rho'] = 0.0
            state['diag_shrunk'] = 1.0
        state['step'] += 1

        g = gradients_to_vector(params)

        if 'ng_prior' not in state:
            state['ng_prior'] = torch.zeros_like(g)

        curv_type = group['curv_type']
        if curv_type not in self.valid_curv_types:
            raise ValueError("Invalid curv_type.")

        # Create closure to pass to Lanczos and CG
        if curv_type == 'fisher':
            Fvp_theta_fn = make_fvp_fun_idx(closure, params, params_i, params_j)
        elif curv_type == 'gauss_newton':
            # Pass indices instead of the actual params, since these params should match
            # the model params anyway. The closure should then set only that subset of
            # params and return only the tmp_params from that subset.
            # This requires the param groups to be ordered in a specific manner?
            Fvp_theta_fn = make_gnvp_fun_idx(closure, params, params_i, params_j)

        num_params = self._numel(gi, params)

        shrinkage_method = group['shrinkage_method']
        lanczos_amortization = group['lanczos_amortization']
        if shrinkage_method == 'lanczos' and (state['step'] - 1) % lanczos_amortization == 0:
            w = lanczos_iteration(Fvp_theta_fn, num_params, k=group['lanczos_iters'])
            rho, diag_shrunk = estimate_shrinkage(w, num_params, group['batch_size'])
            state['rho'] = rho
            state['diag_shrunk'] = diag_shrunk

        M = None
        if group['cg_precondition_empirical']:
            # Empirical Fisher diagonal is g * g
            M = (g * g + group['cg_precondition_regu_coef']
                 * torch.ones_like(g))**group['cg_precondition_exp']

        # Do CG solve with the curvature-vector-product closure
        extract_tridiag = group['shrinkage_method'] == 'cg'
        cg_result = cg_solve(Fvp_theta_fn,
                             g.data.clone(),
                             x_0=group['cg_prev_init_coef'] * state['ng_prior'],
                             M=M,
                             cg_iters=group['cg_iters'],
                             cg_residual_tol=group['cg_residual_tol'],
                             shrunk=group['shrinkage_method'] is not None,
                             rho=state['rho'],
                             Dshrunk=state['diag_shrunk'],
                             extract_tridiag=extract_tridiag)

        if extract_tridiag:
            ng, (diag_elems, off_diag_elems) = cg_result
            w = eigvalsh_tridiagonal(diag_elems, off_diag_elems)
            rho, diag_shrunk = estimate_shrinkage(w, num_params, group['batch_size'])
            state['rho'] = rho
            state['diag_shrunk'] = diag_shrunk
        else:
            ng = cg_result

        state['ng_prior'] = ng.data.clone()

        # Normalize NG
        lr = group['lr']
        alpha = torch.sqrt(torch.abs(lr / (torch.dot(g, ng) + 1e-20)))

        # Unflatten grad
        vector_to_gradients(ng, params)

        if execute_update:
            # Apply step
            for p in params:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                p.data.add_(-alpha, d_p)

        params_i = params_j
        info[gi] = dict(alpha=alpha, delta=lr, natural_grad=ng)

    return info
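# Sketch of a plain Lanczos iteration of the kind lanczos_iteration is assumed to
# perform: k matrix-vector products produce a k x k tridiagonal matrix whose eigenvalues
# (Ritz values) approximate extreme eigenvalues of the curvature; with enough iterations
# the largest Ritz value approaches the spectral norm. Illustrative only; no
# reorthogonalization or edge-case handling.
import numpy as np


def lanczos_eigs(mvp_fn, n, k=10, seed=0):
    rng = np.random.default_rng(seed)
    q = rng.standard_normal(n)
    q /= np.linalg.norm(q)
    q_prev = np.zeros(n)
    beta = 0.0
    alphas, betas = [], []
    for _ in range(k):
        z = mvp_fn(q)
        a = q @ z
        z = z - a * q - beta * q_prev
        beta = np.linalg.norm(z)
        alphas.append(a)
        if beta < 1e-12:
            break
        betas.append(beta)
        q_prev, q = q, z / beta
    T = (np.diag(alphas)
         + np.diag(betas[:len(alphas) - 1], 1)
         + np.diag(betas[:len(alphas) - 1], -1))
    return np.linalg.eigvalsh(T)


# Example: Ritz values of a small SPD matrix approach its true eigenvalues.
A = np.diag([5.0, 2.0, 1.0, 0.5])
ritz = lanczos_eigs(lambda v: A @ v, n=4, k=4)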