def iteration(self, sample_lists, itr):
    # Store the samples and evaluate the costs.
    for m in range(self.M):
        self.cur[m].sample_list = sample_lists[m]
        self._eval_cost(m)

    with Timer(self.timers, 'pol_update'):
        self._update_policy()

    # Prepare for next iteration
    self._advance_iteration_variables()
def _update_trajectories(self):
    """Compute new linear Gaussian controllers."""
    if not hasattr(self, 'new_traj_distr'):
        self.new_traj_distr = [
            self.cur[cond].traj_distr for cond in range(self.M)
        ]
    with Timer(self.timers, 'traj_opt'):
        for cond in range(self.M):
            (self.new_traj_distr[cond], self.cur[cond].eta,
             self.new_mu[cond],
             self.new_sigma[cond]) = self.traj_opt_update(cond)
    self.visualize_local_policy(0)
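# The methods in this listing time their phases with a `Timer` context
# manager that is not shown here. Below is a minimal sketch of such a
# helper, assuming the `timers` argument is a plain dict that accumulates
# wall-clock seconds per phase name; the actual implementation in the
# codebase may differ.
import time


class Timer:
    """Context manager that adds the elapsed wall-clock time of its block
    to ``timers[name]``."""

    def __init__(self, timers, name):
        self.timers = timers
        self.name = name

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Accumulate so repeated phases (e.g. 'sampling') sum over iterations.
        self.timers[self.name] = self.timers.get(self.name, 0.0) + (
            time.time() - self.start)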
def iteration(self, sample_lists, _):
    """
    Run iteration of MDGPS-based guided policy search.

    Args:
        sample_lists: List of SampleList objects for each condition.
        _: to match parent class
    """
    # Store the samples and evaluate the costs.
    for m in range(self.M):
        self.cur[m].sample_list = sample_lists[m]
        self._eval_cost(m)

    # Update dynamics linearizations.
    self._update_dynamics()

    # On the first iteration, need to catch policy up to init_traj_distr.
    if self.iteration_count == 0:
        self.new_traj_distr = [
            self.cur[cond].traj_distr for cond in range(self.M)
        ]
        self._update_policy()

    # Update policy linearizations.
    with Timer(self.timers, 'pol_lin'):
        for m in range(self.M):
            self._update_policy_fit(m)

    # C-step
    if self.iteration_count > 0:
        self._stepadjust()
    self._update_trajectories()

    # S-step
    with Timer(self.timers, 'pol_update'):
        self._update_policy()

    # Prepare for next iteration
    self._advance_iteration_variables()
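# Usage sketch for the MDGPS iteration above: the caller collects one
# SampleList per condition and passes the current iteration index, as the
# run() methods below do. The bare names `agent`, `algorithm`, `train_idx`
# and `num_samples` are stand-ins for the trainer attributes used there.
traj_sample_lists = [
    agent.get_samples(cond, -num_samples) for cond in train_idx
]
algorithm.iteration(traj_sample_lists, itr)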
def run(self):
    """Runs training by alternately taking samples and optimizing the policy."""
    if 'load_model' in self._hyperparams:
        self.iteration_count = self._hyperparams['load_model'][1]
        self.algorithm.policy_opt.iteration_count = self.iteration_count
        self.algorithm.policy_opt.restore_model(
            *self._hyperparams['load_model'])
        # Global policy static resets
        if self._hyperparams['num_pol_samples_static'] > 0:
            self.export_samples(self._take_policy_samples(
                N=self._hyperparams['num_pol_samples_static'],
                pol=self.algorithm.policy_opt.policy,
                rnd=False),
                '_pol-static',
                visualize=True)
        return

    for itr in range(self._hyperparams['iterations']):
        self.iteration_count = itr
        if hasattr(self.algorithm, 'traj_opt'):
            self.algorithm.traj_opt.iteration_count = itr
        if hasattr(self.algorithm, 'policy_opt'):
            self.algorithm.policy_opt.iteration_count = itr
        print("*** Iteration %02d ***" % itr)

        if itr == 0 and 'load_initial_samples' in self._hyperparams:
            # Load trajectory samples
            print('Loading initial samples ...')
            sample_files = self._hyperparams['load_initial_samples']
            traj_sample_lists = [[] for _ in range(self.algorithm.M)]
            for sample_file in sample_files:
                data = np.load(sample_file)
                X, U = data['X'], data['U']
                assert X.shape[0] == self.algorithm.M
                for m in range(self.algorithm.M):
                    for n in range(X.shape[1]):
                        traj_sample_lists[m].append(
                            self.agent.pack_sample(X[m, n], U[m, n]))
            traj_sample_lists = [
                SampleList(traj_samples)
                for traj_samples in traj_sample_lists
            ]
        else:
            # Take trajectory samples
            with Timer(self.algorithm.timers, 'sampling'):
                for cond in self._train_idx:
                    for i in trange(self._hyperparams['num_samples'],
                                    desc='Taking samples'):
                        self._take_sample(cond, i)
            traj_sample_lists = [
                self.agent.get_samples(cond,
                                       -self._hyperparams['num_samples'])
                for cond in self._train_idx
            ]
        self.export_samples(traj_sample_lists, visualize=True)

        # Iteration
        with Timer(self.algorithm.timers, 'iteration'):
            self.algorithm.iteration(traj_sample_lists, itr)
        self.export_dynamics()
        self.export_controllers()
        self.export_times()
        if hasattr(self.algorithm, 'policy_opt') and hasattr(
                self.algorithm.policy_opt, 'store_model'):
            self.algorithm.policy_opt.store_model()

        # Sample learned policies for visualization
        # LQR policies static resets
        if self._hyperparams['num_lqr_samples_static'] > 0:
            self.export_samples(self._take_policy_samples(
                N=self._hyperparams['num_lqr_samples_static'],
                pol=None,
                rnd=False),
                '_lqr-static',
                visualize=True)
        # LQR policies random resets
        if self._hyperparams['num_lqr_samples_random'] > 0:
            self.export_samples(self._take_policy_samples(
                N=self._hyperparams['num_lqr_samples_random'],
                pol=None,
                rnd=True),
                '_lqr-random',
                visualize=True)
        # LQR policies state noise
        if self._hyperparams['num_lqr_samples_random'] > 0:
            self.export_samples(self._take_policy_samples(
                N=self._hyperparams['num_lqr_samples_random'],
                pol=None,
                rnd=False,
                randomize_initial_state=24),
                '_lqr-static-randomized',
                visualize=True)
        if hasattr(self.algorithm, 'policy_opt'):
            # Global policy static resets
            if self._hyperparams['num_pol_samples_static'] > 0:
                self.export_samples(self._take_policy_samples(
                    N=self._hyperparams['num_pol_samples_static'],
                    pol=self.algorithm.policy_opt.policy,
                    rnd=False),
                    '_pol-static',
                    visualize=True)
            # Global policy random resets
            if self._hyperparams['num_pol_samples_random'] > 0:
                self.export_samples(self._take_policy_samples(
                    N=self._hyperparams['num_pol_samples_random'],
                    pol=self.algorithm.policy_opt.policy,
                    rnd=True),
                    '_pol-random',
                    visualize=True)
            # Global policy state noise
            if self._hyperparams['num_pol_samples_random'] > 0:
                self.export_samples(self._take_policy_samples(
                    N=self._hyperparams['num_pol_samples_random'],
                    pol=self.algorithm.policy_opt.policy,
                    rnd=False,
                    randomize_initial_state=24),
                    '_pol-static-randomized',
                    visualize=True)
        self.visualize_training_progress()
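# The run() method above reads a number of hyperparameter keys. Below is an
# illustrative configuration dict: only the key names come from the code
# above, the values are placeholders, and the optional 'load_model' /
# 'load_initial_samples' entries are commented out because their exact
# formats (checkpoint tuple, sample archives readable by np.load with 'X'
# and 'U' arrays) depend on the surrounding codebase.
example_hyperparams = {
    'iterations': 15,               # outer GPS iterations
    'num_samples': 5,               # trajectory samples per condition
    'num_lqr_samples_static': 5,    # LQR policy samples, fixed resets
    'num_lqr_samples_random': 5,    # LQR policy samples, random resets
    'num_pol_samples_static': 5,    # global policy samples, fixed resets
    'num_pol_samples_random': 5,    # global policy samples, random resets
    # 'load_model': ('path/to/checkpoint', 10),     # restore and evaluate only
    # 'load_initial_samples': ['samples_00.npz'],   # skip sampling at itr 0
}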
def run(self, session_id, itr_load=None):
    """
    Run training by iteratively sampling and taking an iteration.

    Args:
        session_id: Identifier for this training session (stored if not None).
        itr_load: If specified, loads algorithm state from that iteration,
            and resumes training at the next iteration.

    Returns:
        None
    """
    itr_start = self._initialize(itr_load)
    if session_id is not None:
        self.session_id = session_id

    for itr in range(itr_start, self._hyperparams['iterations']):
        self.iteration_count = itr
        if hasattr(self.algorithm, 'policy_opt'):
            self.algorithm.policy_opt.iteration_count = itr
        print("*** Iteration %02d ***" % itr)

        # Take trajectory samples
        with Timer(self.algorithm.timers, 'sampling'):
            for cond in self._train_idx:
                for i in trange(self._hyperparams['num_samples'],
                                desc='Taking samples'):
                    self._take_sample(itr, cond, i)
        traj_sample_lists = [
            self.agent.get_samples(cond,
                                   -self._hyperparams['num_samples'])
            for cond in self._train_idx
        ]
        self.export_samples(traj_sample_lists)

        # Iteration
        with Timer(self.algorithm.timers, 'iteration'):
            self.algorithm.iteration(traj_sample_lists, itr)
        self.export_dynamics()
        self.export_controllers()
        self.export_times()

        # Sample learned policies for visualization
        # LQR policies static resets
        if self._hyperparams['num_lqr_samples_static'] > 0:
            self.export_samples(
                self._take_policy_samples(
                    N=self._hyperparams['num_lqr_samples_static'],
                    pol=None,
                    rnd=False), '_lqr-static')
        # LQR policies random resets
        if self._hyperparams['num_lqr_samples_random'] > 0:
            self.export_samples(
                self._take_policy_samples(
                    N=self._hyperparams['num_lqr_samples_random'],
                    pol=None,
                    rnd=True), '_lqr-random')
        if hasattr(self.algorithm, 'policy_opt'):
            # Global policy static resets
            if self._hyperparams['num_pol_samples_static'] > 0:
                self.export_samples(
                    self._take_policy_samples(
                        N=self._hyperparams['num_pol_samples_static'],
                        pol=self.algorithm.policy_opt.policy,
                        rnd=False), '_pol-static')
            # Global policy random resets
            if self._hyperparams['num_pol_samples_random'] > 0:
                self.export_samples(
                    self._take_policy_samples(
                        N=self._hyperparams['num_pol_samples_random'],
                        pol=self.algorithm.policy_opt.policy,
                        rnd=True), '_pol-random')
    self._end()
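# Usage sketch for the resumable run() variant above: restart a previous
# session from a stored iteration. `Trainer` is a hypothetical name for the
# class that owns run(); only the run(session_id, itr_load) signature and
# the hyperparameter keys come from the listing.
trainer = Trainer(example_hyperparams)
trainer.run(session_id='exp-01', itr_load=9)  # resume training after iteration 9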