def train(self):
    """
    Trains policy on env using algo.

    Launches ``start`` on every Ray worker actor in ``self.workers``, blocks
    until all of them return, then logs the total wall-clock time under
    'Trainer-TimeTotal'.
    """
    loop_started_at = time.time()

    # --------------- worker looping ---------------
    pending = [actor.start.remote() for actor in self.workers]
    logger.log('Start looping...')
    ray.get(pending)

    elapsed = time.time() - loop_started_at
    logger.logkv('Trainer-TimeTotal', elapsed)
    logger.dumpkvs()
    logger.log('***** Training finished ******')
def prepare_start(self, env_pickle, policy_pickle, baseline_pickle, dynamics_model_pickle, feed_dict, algo_str, config):
    """
    Build the TF session and RL components from pickled objects, then take
    one ``step`` and ``push``.

    Args:
        env_pickle, policy_pickle, baseline_pickle, dynamics_model_pickle:
            ``pickle`` byte strings of the corresponding objects.
        feed_dict: dict holding keyword-argument dicts under the keys
            'model_sampler', 'model_sample_processor' and 'algo'.
        algo_str: 'meppo' selects PPO, 'metrpo' selects TRPO.
        config: session config passed to ``tf.Session``.

    Returns:
        1

    Raises:
        NotImplementedError: if ``algo_str`` is not one of the names above.
    """
    import tensorflow as tf

    self.sess = tf.Session(config=config)
    with self.sess.as_default():
        # --------------------- Construct instances -------------------
        from asynch_mb.samplers.bptt_samplers.bptt_sampler import BPTTSampler
        from asynch_mb.samplers.base import SampleProcessor
        from asynch_mb.algos.ppo import PPO
        from asynch_mb.algos.trpo import TRPO

        # NOTE: pickle.loads must only ever receive trusted payloads.
        env = pickle.loads(env_pickle)
        policy = pickle.loads(policy_pickle)
        baseline = pickle.loads(baseline_pickle)
        model = pickle.loads(dynamics_model_pickle)

        self.sess.run(tf.initializers.global_variables())

        self.policy = policy
        self.baseline = baseline
        self.model_sampler = BPTTSampler(
            env=env,
            policy=policy,
            dynamics_model=model,
            **feed_dict['model_sampler'])
        self.model_sample_processor = SampleProcessor(
            baseline=baseline,
            **feed_dict['model_sample_processor'])

        if algo_str == 'meppo':
            algo_cls = PPO
        elif algo_str == 'metrpo':
            algo_cls = TRPO
        else:
            raise NotImplementedError(f'got algo_str {algo_str}')
        self.algo = algo_cls(policy=policy, **feed_dict['algo'])

        # ------------ Pull pickled model from model parameter server ------------
        # Unpickles the same payload a second time and installs that fresh
        # copy on the sampler (and on its vec_env, when it has one).
        fresh_model = pickle.loads(dynamics_model_pickle)
        self.model_sampler.dynamics_model = fresh_model
        if hasattr(self.model_sampler, 'vec_env'):
            self.model_sampler.vec_env.dynamics_model = fresh_model

        # -------------------- Step and Push -------------------
        self.step()
        self.push()

        logger.dumpkvs()
        return 1
def train(self):
    """
    Trains policy on env using algo.

    Starts the worker processes in ``self.ps`` and drives the
    data/model/policy workers through the pipe handshake on
    ``self.remotes``:

    * warm-up: send 'prepare start', expect 'loop ready' (the data worker
      additionally receives ``self.initial_random_samples`` on its queue);
    * main loop: send 'start loop' to all, then expect 'loop done' and
      'worker closed' from each.

    Worker processes are terminated even if the handshake fails.

    Raises:
        RuntimeError: if a worker replies with an unexpected message.
    """
    def _expect(remote, expected):
        # Receive outside an ``assert``: under ``python -O`` asserts are
        # stripped, which would silently skip the blocking recv() and break
        # the handshake.
        msg = remote.recv()
        if msg != expected:
            raise RuntimeError(
                f'expected {expected!r} from worker, got {msg!r}')

    worker_data_queue, _, _ = self.queues
    worker_data_remote, worker_model_remote, worker_policy_remote = self.remotes

    for p in self.ps:
        p.start()

    try:
        # --------------- worker warm-up ---------------
        logger.log('Prepare start...')

        worker_data_remote.send('prepare start')
        worker_data_queue.put(self.initial_random_samples)
        _expect(worker_data_remote, 'loop ready')

        worker_model_remote.send('prepare start')
        _expect(worker_model_remote, 'loop ready')

        worker_policy_remote.send('prepare start')
        _expect(worker_policy_remote, 'loop ready')

        time_total = time.time()

        # --------------- worker looping ---------------
        logger.log('Start looping...')
        for remote in self.remotes:
            remote.send('start loop')

        # --------------- collect info ---------------
        for remote in self.remotes:
            _expect(remote, 'loop done')
        logger.log('\n------------all workers exit loops -------------')
        for remote in self.remotes:
            _expect(remote, 'worker closed')
    finally:
        # Always reap the worker processes, including on handshake failure.
        for p in self.ps:
            p.terminate()

    logger.logkv('Trainer-TimeTotal', time.time() - time_total)
    logger.dumpkvs()
    logger.log("*****Training finished")
def start(self):
    """
    Run this worker's loop until the shared Ray stop condition is set.

    Each pass calls ``step_wrapper`` (which returns a synch count and a step
    count, accumulated into ``synch_counter`` and ``step_counter``) and then
    dumps elapsed time and the running counters. ``set_stop_cond`` is called
    once the loop exits.
    """
    logger.log(f"\n================ {self.name} starts ===============")
    started_at = time.time()
    with self.sess.as_default():
        while True:
            # The stop condition lives in a Ray actor, so query it remotely.
            if ray.get(self.stop_cond.is_set.remote()):
                break

            synched, stepped = self.step_wrapper()
            self.synch_counter += synched
            self.step_counter += stepped

            logger.logkv(self.name + '-TimeSoFar', time.time() - started_at)
            logger.logkv(self.name + '-TotalStep', self.step_counter)
            logger.logkv(self.name + '-TotalSynch', self.synch_counter)
            logger.dumpkvs()

        self.set_stop_cond()
    logger.log(
        f"\n================== {self.name} closed ===================")
def train(self):
    """
    Trains policy on env using algo.

    Per iteration:
      1. collect real-environment paths (random actions on iteration 0 when
         ``self.initial_random_samples`` is set, the policy otherwise) and
         add them to the dynamics model's buffer;
      2. ``self.repeat_steps`` times: fit the dynamics model for
         ``self.num_epochs_per_step`` single epochs, then take
         ``self.num_grad_policy_per_step`` policy-optimization steps on
         paths sampled from the model;
      3. log timings and save an iteration snapshot.
    """
    with self.sess.as_default() as sess:
        # NOTE: initializes *all* global variables, so any values loaded
        # before training are overwritten.
        sess.run(tf.global_variables_initializer())

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)

            # --------------- real environment samples ---------------
            time_env_sampling_start = time.time()
            if self.initial_random_samples and itr == 0:
                logger.log("Obtaining random samples from the environment...")
                env_paths = self.env_sampler.obtain_samples(
                    log=True, random=True, log_prefix='Data-EnvSampler-')
            else:
                logger.log("Obtaining samples from the environment using the policy...")
                env_paths = self.env_sampler.obtain_samples(
                    log=True, log_prefix='Data-EnvSampler-')
            logger.record_tabular('Data-TimeEnvSampling',
                                  time.time() - time_env_sampling_start)

            logger.log("Processing environment samples...")
            time_env_samp_proc = time.time()
            samples_data = self.dynamics_sample_processor.process_samples(
                env_paths, log=True, log_prefix='Data-EnvTrajs-')
            self.env.log_diagnostics(env_paths, prefix='Data-EnvTrajs-')
            logger.record_tabular('Data-TimeEnvSampleProc',
                                  time.time() - time_env_samp_proc)

            # --------------- add data to the dynamics model ---------------
            # Only the buffer update is timed here; the fitting epochs run in
            # the loop below.
            time_fit_start = time.time()
            self.dynamics_model.update_buffer(samples_data['observations'],
                                              samples_data['actions'],
                                              samples_data['next_observations'],
                                              check_init=True)
            buffer = samples_data if self.sample_from_buffer else None
            logger.record_tabular('Model-TimeModelFit', time.time() - time_fit_start)

            # --------------- alternate model fitting and policy steps ---------------
            times_dyn_sampling = []
            times_dyn_sample_processing = []
            times_optimization = []
            times_step = []

            remaining_model_idx = list(range(self.dynamics_model.num_models))
            valid_loss_rolling_average_prev = None
            with_new_data = True
            for id_step in range(self.repeat_steps):
                for epoch in range(self.num_epochs_per_step):
                    logger.log("Training dynamics model for %i epochs ..." % 1)
                    remaining_model_idx, valid_loss_rolling_average = \
                        self.dynamics_model.fit_one_epoch(
                            remaining_model_idx, valid_loss_rolling_average_prev,
                            with_new_data, log_tabular=True, prefix='Model-')
                    # Carry the rolling validation loss into the next epoch;
                    # without this every call received None as the previous
                    # average.
                    valid_loss_rolling_average_prev = valid_loss_rolling_average
                    with_new_data = False

                for step in range(self.num_grad_policy_per_step):
                    logger.log("\n ---------------- Grad-Step %d ----------------" % int(
                        itr * self.repeat_steps * self.num_grad_policy_per_step
                        + id_step * self.num_grad_policy_per_step + step))
                    step_start_time = time.time()

                    # -------------------- Sampling --------------------------
                    logger.log("Obtaining samples from the model...")
                    time_env_sampling_start = time.time()
                    paths = self.model_sampler.obtain_samples(
                        log=True, log_prefix='Policy-', buffer=buffer)
                    sampling_time = time.time() - time_env_sampling_start

                    # ----------------- Processing Samples ---------------------
                    logger.log("Processing samples from the model...")
                    time_proc_samples_start = time.time()
                    samples_data = self.model_sample_processor.process_samples(
                        paths, log='all', log_prefix='Policy-')
                    proc_samples_time = time.time() - time_proc_samples_start

                    if isinstance(paths, list):
                        self.log_diagnostics(paths, prefix='Policy-')
                    else:
                        # presumably a mapping of per-task path lists; flatten
                        # them in iteration order.
                        self.log_diagnostics(
                            [path for task_paths in paths.values() for path in task_paths],
                            prefix='Policy-')

                    # ------------------ Policy Update ---------------------
                    logger.log("Optimizing policy...")
                    time_optimization_step_start = time.time()
                    self.algo.optimize_policy(samples_data)
                    optimization_time = time.time() - time_optimization_step_start

                    times_dyn_sampling.append(sampling_time)
                    times_dyn_sample_processing.append(proc_samples_time)
                    times_optimization.append(optimization_time)
                    times_step.append(time.time() - step_start_time)

            # ------------------- Logging Stuff --------------------------
            logger.logkv('Iteration', itr)
            logger.logkv('n_timesteps', self.env_sampler.total_timesteps_sampled)
            logger.logkv('Policy-TimeSampleProc', np.sum(times_dyn_sample_processing))
            logger.logkv('Policy-TimeSampling', np.sum(times_dyn_sampling))
            logger.logkv('Policy-TimeAlgoOpt', np.sum(times_optimization))
            logger.logkv('Policy-TimeStep', np.sum(times_step))

            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

        logger.logkv('Trainer-TimeTotal', time.time() - start_time)
        # logkv values are emitted by dumpkvs, as everywhere else in this
        # method; without it the total was never written out.
        logger.dumpkvs()

    logger.log("Training finished")
    self.sess.close()
def train(self):
    """
    Trains policy on env using algo (model-based meta-policy optimization).

    Per iteration:
      1. collect real-environment paths, keep ``self.num_rollouts_per_iter``
         rollouts chosen at random without replacement, and fit the
         dynamics model on them;
      2. if ``self.log_real_performance``, evaluate the pre- and
         post-adaptation policy on the real environment;
      3. run ``meta_steps_per_iter[itr]`` meta-steps, each consisting of
         ``self.num_inner_grad_steps`` inner adaptation steps on model
         samples (plus one final sampling pass) and one outer policy update;
      4. log timings and save an iteration snapshot.

    ``self.meta_steps_per_iter`` is either a count used for every iteration
    or a ``(start, end)`` tuple, linearly interpolated over ``self.n_itr``
    iterations and truncated to ints.

    Raises:
        TypeError: if the environment sampler returns neither a dict nor a list.
    """
    with self.sess.as_default() as sess:
        # NOTE: initializes *all* global variables, so any values loaded
        # before training are overwritten.
        sess.run(tf.global_variables_initializer())

        if isinstance(self.meta_steps_per_iter, tuple):
            # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is the
            # same dtype.
            meta_steps_per_iter = np.linspace(self.meta_steps_per_iter[0],
                                              self.meta_steps_per_iter[1],
                                              self.n_itr).astype(int)
        else:
            meta_steps_per_iter = [self.meta_steps_per_iter] * self.n_itr

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)

            time_env_sampling_start = time.time()
            if self.initial_random_samples and itr == 0:
                logger.log("Obtaining random samples from the environment...")
                env_paths = self.env_sampler.obtain_samples(
                    log=True, random=True, log_prefix='EnvSampler-')
            else:
                logger.log("Obtaining samples from the environment using the policy...")
                env_paths = self.env_sampler.obtain_samples(
                    log=True, log_prefix='EnvSampler-')
            logger.record_tabular('Time-EnvSampling', time.time() - time_env_sampling_start)

            logger.log("Processing environment samples...")
            # first processing just for logging purposes
            time_env_samp_proc = time.time()
            # Keep ``num_rollouts_per_iter`` rollouts chosen without replacement.
            if isinstance(env_paths, dict):  # also covers collections.OrderedDict
                rollout_groups = list(env_paths.values())
                idxs = np.random.choice(range(len(rollout_groups)),
                                        size=self.num_rollouts_per_iter, replace=False)
                env_paths = [path for idx in idxs for path in rollout_groups[idx]]
            elif isinstance(env_paths, list):
                idxs = np.random.choice(range(len(env_paths)),
                                        size=self.num_rollouts_per_iter, replace=False)
                env_paths = [env_paths[idx] for idx in idxs]
            else:
                raise TypeError('env_paths must be a dict or a list, got %s'
                                % type(env_paths).__name__)

            samples_data = self.dynamics_sample_processor.process_samples(
                env_paths, log=True, log_prefix='EnvTrajs-')
            self.env.log_diagnostics(env_paths, prefix='EnvTrajs-')
            logger.record_tabular('Time-EnvSampleProc', time.time() - time_env_samp_proc)

            # --------------- fit dynamics model ---------------
            time_fit_start = time.time()
            logger.log("Training dynamics model for %i epochs ..." % (self.dynamics_model_max_epochs))
            self.dynamics_model.fit(samples_data['observations'],
                                    samples_data['actions'],
                                    samples_data['next_observations'],
                                    epochs=self.dynamics_model_max_epochs,
                                    verbose=True,
                                    log_tabular=True)
            buffer = samples_data if self.sample_from_buffer else None
            logger.record_tabular('Time-ModelFit', time.time() - time_fit_start)

            # ------------ log real performance ---------------
            if self.log_real_performance:
                logger.log("Evaluating the performance of the real policy")
                self.policy.switch_to_pre_update()
                env_paths = self.env_sampler.obtain_samples(log=True, log_prefix='PrePolicy-')
                samples_data = self.model_sample_processor.process_samples(
                    env_paths, log='all', log_prefix='PrePolicy-')
                self.algo._adapt(samples_data)
                env_paths = self.env_sampler.obtain_samples(log=True, log_prefix='PostPolicy-')
                self.model_sample_processor.process_samples(
                    env_paths, log='all', log_prefix='PostPolicy-')

            # --------------- MAML steps ---------------
            times_dyn_sampling = []
            times_dyn_sample_processing = []
            times_meta_sampling = []
            times_inner_step = []
            times_total_inner_step = []
            times_outer_step = []
            times_maml_steps = []

            for meta_itr in range(meta_steps_per_iter[itr]):
                logger.log("\n ---------------- Meta-Step %d ----------------"
                           % int(sum(meta_steps_per_iter[:itr]) + meta_itr))
                self.policy.switch_to_pre_update()  # Switch to pre-update policy

                all_samples_data = []
                list_sampling_time, list_inner_step_time, list_proc_samples_time = [], [], []
                time_maml_steps_start = time.time()
                start_total_inner_time = time.time()

                for step in range(self.num_inner_grad_steps + 1):
                    logger.log("\n ** Adaptation-Step %d **" % step)

                    # -------------------- Sampling --------------------------
                    logger.log("Obtaining samples...")
                    time_env_sampling_start = time.time()
                    paths = self.model_sampler.obtain_samples(
                        log=True, log_prefix='Step_%d-' % step, buffer=buffer)
                    list_sampling_time.append(time.time() - time_env_sampling_start)

                    # ----------------- Processing Samples ---------------------
                    logger.log("Processing samples...")
                    time_proc_samples_start = time.time()
                    samples_data = self.model_sample_processor.process_samples(
                        paths, log='all', log_prefix='Step_%d-' % step)
                    all_samples_data.append(samples_data)
                    list_proc_samples_time.append(time.time() - time_proc_samples_start)

                    # Flatten the per-key path lists in iteration order.
                    self.log_diagnostics(
                        [path for task_paths in paths.values() for path in task_paths],
                        prefix='Step_%d-' % step)

                    # ------------------- Inner Policy Update --------------------
                    time_inner_step_start = time.time()
                    if step < self.num_inner_grad_steps:
                        logger.log("Computing inner policy updates...")
                        self.algo._adapt(samples_data)
                    list_inner_step_time.append(time.time() - time_inner_step_start)
                total_inner_time = time.time() - start_total_inner_time

                # ------------------ Outer Policy Update ---------------------
                logger.log("Optimizing policy...")
                # This needs to take all samples_data so that it can construct
                # graph for meta-optimization.
                time_outer_step_start = time.time()
                self.algo.optimize_policy(all_samples_data)

                times_inner_step.append(list_inner_step_time)
                times_total_inner_step.append(total_inner_time)
                times_outer_step.append(time.time() - time_outer_step_start)
                times_meta_sampling.append(np.sum(list_sampling_time))
                times_dyn_sampling.append(list_sampling_time)
                times_dyn_sample_processing.append(list_proc_samples_time)
                times_maml_steps.append(time.time() - time_maml_steps_start)

            # ------------------- Logging Stuff --------------------------
            logger.logkv('Itr', itr)
            if self.log_real_performance:
                logger.logkv('n_timesteps',
                             self.env_sampler.total_timesteps_sampled
                             / (3 * self.policy.meta_batch_size) * self.num_rollouts_per_iter)
            else:
                logger.logkv('n_timesteps',
                             self.env_sampler.total_timesteps_sampled
                             / self.policy.meta_batch_size * self.num_rollouts_per_iter)

            logger.logkv('AvgTime-OuterStep', np.mean(times_outer_step))
            logger.logkv('AvgTime-InnerStep', np.mean(times_inner_step))
            logger.logkv('AvgTime-TotalInner', np.mean(times_total_inner_step))
            logger.logkv('AvgTime-SampleProc', np.mean(times_dyn_sample_processing))
            logger.logkv('AvgTime-Sampling', np.mean(times_dyn_sampling))
            logger.logkv('AvgTime-MAMLSteps', np.mean(times_maml_steps))

            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

    logger.log("Training finished")
    self.sess.close()
def train(self):
    """
    Trains policy on env using algo.

    Iteration 0 collects ``self.num_rollouts_per_iter`` times the env
    sampler's usual sample budget (random actions if
    ``self.initial_random_samples``). Every iteration then fits the dynamics
    model on the pending environment paths. For each of
    ``self.num_rollouts_per_iter`` rollouts, it takes
    ``self.grad_steps_per_rollout`` policy-optimization steps on
    model-generated samples and then gathers one batch of real-environment
    samples with the current policy. Those real samples are fitted at the
    start of the next iteration.
    """
    with self.sess.as_default() as sess:
        # NOTE: initializes *all* global variables, so any values loaded
        # before training are overwritten.
        sess.run(tf.global_variables_initializer())

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            logger.log(
                "\n ---------------- Iteration %d ----------------" % itr)

            if itr == 0:
                # NOTE(review): env_paths is only bound here or at the end of
                # the previous iteration, so start_itr > 0 raises NameError
                # below -- confirm resuming is unsupported.
                if self.initial_random_samples:
                    logger.log(
                        "Obtaining random samples from the environment...")
                else:
                    logger.log(
                        "Obtaining samples from the environment using the policy...")
                # Temporarily scale the sampler's budget and restore the exact
                # original value afterwards (``/=`` used to turn an int budget
                # into a float, and never ran if sampling raised).
                base_total_samples = self.env_sampler.total_samples
                self.env_sampler.total_samples = base_total_samples * self.num_rollouts_per_iter
                try:
                    env_paths = self.env_sampler.obtain_samples(
                        log=True,
                        random=self.initial_random_samples,
                        log_prefix='Data-EnvSampler-',
                        verbose=True)
                finally:
                    self.env_sampler.total_samples = base_total_samples

            time_env_samp_proc = time.time()
            samples_data = self.dynamics_sample_processor.process_samples(
                env_paths, log=True, log_prefix='Data-EnvTrajs-')
            self.env.log_diagnostics(env_paths, prefix='Data-EnvTrajs-')
            logger.record_tabular('Data-TimeEnvSampleProc',
                                  time.time() - time_env_samp_proc)
            buffer = samples_data if self.sample_from_buffer else None

            # --------------- fit dynamics model ---------------
            logger.log("Training dynamics model for %i epochs ..."
                       % self.dynamics_model_max_epochs)
            time_fit_start = time.time()
            self.dynamics_model.fit(samples_data['observations'],
                                    samples_data['actions'],
                                    samples_data['next_observations'],
                                    epochs=self.dynamics_model_max_epochs,
                                    verbose=False,
                                    log_tabular=True,
                                    prefix='Model-')
            logger.record_tabular('Model-TimeModelFit',
                                  time.time() - time_fit_start)

            # Per-iteration accumulators: they were previously reset inside
            # the rollout loop, so only the last rollout was logged (and a
            # zero-rollout iteration raised NameError).
            times_dyn_sampling = []
            times_dyn_sample_processing = []
            times_optimization = []
            times_step = []
            times_env_sampling = []

            env_paths = []
            for id_rollout in range(self.num_rollouts_per_iter):
                for step in range(self.grad_steps_per_rollout):
                    step_start_time = time.time()

                    # -------------------- Sampling --------------------------
                    logger.log("Obtaining samples from the model...")
                    time_model_sampling_start = time.time()
                    paths = self.model_sampler.obtain_samples(
                        log=True, log_prefix='Policy-', buffer=buffer)
                    sampling_time = time.time() - time_model_sampling_start

                    # ----------------- Processing Samples ---------------------
                    logger.log("Processing samples from the model...")
                    time_proc_samples_start = time.time()
                    samples_data = self.model_sample_processor.process_samples(
                        paths, log='all', log_prefix='Policy-')
                    proc_samples_time = time.time() - time_proc_samples_start

                    if isinstance(paths, list):
                        self.log_diagnostics(paths, prefix='Policy-')
                    else:
                        # presumably a mapping of per-task path lists; flatten
                        # them in iteration order.
                        self.log_diagnostics(
                            [path for task_paths in paths.values() for path in task_paths],
                            prefix='Policy-')

                    # ------------------ Policy Update ---------------------
                    logger.log("Optimizing policy...")
                    time_optimization_step_start = time.time()
                    self.algo.optimize_policy(samples_data)
                    optimization_time = time.time() - time_optimization_step_start

                    times_dyn_sampling.append(sampling_time)
                    times_dyn_sample_processing.append(proc_samples_time)
                    times_optimization.append(optimization_time)
                    times_step.append(time.time() - step_start_time)

                # Real-environment rollout with the current policy; consumed
                # by model fitting in the next iteration.
                logger.log(
                    "Obtaining samples from the environment using the policy...")
                time_env_sampling_start = time.time()
                # NOTE(review): extend() assumes obtain_samples returns a list
                # of paths here -- confirm for this sampler.
                env_paths.extend(
                    self.env_sampler.obtain_samples(
                        log=True, log_prefix='Data-EnvSampler-', verbose=True))
                times_env_sampling.append(time.time() - time_env_sampling_start)
                logger.log("Processing environment samples...")

            # Time spent sampling the real environment only (was previously
            # measured from the last *model* sampling start).
            logger.record_tabular('Data-TimeEnvSampling', np.sum(times_env_sampling))

            # ------------------- Logging Stuff --------------------------
            logger.logkv('Iteration', itr)
            logger.logkv('n_timesteps', self.env_sampler.total_timesteps_sampled)
            logger.logkv('Policy-TimeSampleProc', np.sum(times_dyn_sample_processing))
            logger.logkv('Policy-TimeSampling', np.sum(times_dyn_sampling))
            logger.logkv('Policy-TimeAlgoOpt', np.sum(times_optimization))
            logger.logkv('Policy-TimeStep', np.sum(times_step))

            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

        logger.logkv('Trainer-TimeTotal', time.time() - start_time)
        # logkv values are emitted by dumpkvs, as elsewhere in this method.
        logger.dumpkvs()

    logger.log("Training finished")
    self.sess.close()
def __call__(
    self,
    exp_dir,
    policy_pickle,
    env_pickle,
    baseline_pickle,
    dynamics_model_pickle,
    feed_dict,
    queue_prev,
    queue,
    queue_next,
    remote,
    start_itr,
    n_itr,
    stop_cond,
    need_query,
    auto_push,
    config,
):
    """
    Entry point of a worker process.

    Configures logging under ``exp_dir/<process name>``, builds the worker's
    objects inside a new TF session, then follows the trainer's handshake
    over ``remote``:

    * wait for 'prepare start', run ``prepare_start`` and reply 'loop ready';
    * wait for 'start loop', then loop until ``stop_cond`` is set. Each pass
      optionally polls ``queue_prev`` (``need_query``), processes the queue,
      and steps/pushes.
    * finally, send 'loop done' and 'worker closed'.

    Raises:
        RuntimeError: if the trainer sends an unexpected handshake message.
    """
    time_start = time.time()

    self.name = current_process().name
    logger.configure(dir=exp_dir + '/' + self.name,
                     format_strs=['csv', 'stdout', 'log'],
                     snapshot_mode=self.snapshot_mode,
                     snapshot_gap=self.snapshot_gap)

    self.n_itr = n_itr
    self.queue_prev = queue_prev
    self.queue = queue
    self.queue_next = queue_next
    self.stop_cond = stop_cond

    # FIXME: specify CPU/GPU usage here

    import tensorflow as tf

    def _init_vars():
        sess = tf.get_default_session()
        sess.run(tf.initializers.global_variables())

    def _expect(expected):
        # Receive outside an ``assert``: under ``python -O`` asserts are
        # stripped, which would skip the blocking recv() and desynchronize
        # the handshake with the trainer.
        msg = remote.recv()
        if msg != expected:
            raise RuntimeError('{} expected {!r} from trainer, got {!r}'.format(
                self.name, expected, msg))

    with tf.Session(config=config).as_default():

        self.construct_from_feed_dict(
            policy_pickle,
            env_pickle,
            baseline_pickle,
            dynamics_model_pickle,
            feed_dict,
        )

        _init_vars()

        # warm up
        self.itr_counter = start_itr
        if self.verbose:
            print('{} waiting for starting msg from trainer...'.format(
                self.name))
        _expect('prepare start')
        self.prepare_start()
        remote.send('loop ready')
        logger.dumpkvs()
        logger.log("\n============== {} is ready =============".format(
            self.name))

        _expect('start loop')
        total_push, total_synch, total_step = 0, 0, 0
        while not self.stop_cond.is_set():
            if self.verbose:
                logger.log(
                    "\n------------------------- {} starting new loop ------------------"
                    .format(self.name))
            if need_query:  # poll
                time_poll = time.time()
                queue_prev.put('push')
                time_poll = time.time() - time_poll
                logger.logkv('{}-TimePoll'.format(self.name), time_poll)
            do_push, do_synch, do_step = self.process_queue()
            # step
            if do_step:
                self.itr_counter += 1
                self.step()
                if auto_push:
                    do_push += 1
                    self.push()
                # Internal invariant: assumes auto-push is used by all workers.
                assert do_push == 1
                assert do_step == 1

            total_push += do_push
            total_synch += do_synch
            total_step += do_step
            logger.logkv(self.name + '-TimeSoFar', time.time() - time_start)
            logger.logkv(self.name + '-TotalPush', total_push)
            logger.logkv(self.name + '-TotalSynch', total_synch)
            logger.logkv(self.name + '-TotalStep', total_step)
            if total_synch > 0:
                logger.logkv(self.name + '-StepPerSynch',
                             total_step / total_synch)
            logger.dumpkvs()
            logger.log(
                "\n========================== {} {}, total {} ==================="
                .format(
                    self.name,
                    (do_push, do_synch, do_step),
                    (total_push, total_synch, total_step),
                ))

        self.set_stop_cond()
        remote.send('loop done')

    logger.log("\n================== {} closed ===================".format(
        self.name))

    remote.send('worker closed')
def train(self):
    """
    Trains policy on env using algo.

    Per iteration: update the sampler's tasks, sample paths, process them,
    take one ``algo.optimize_policy`` step, then log timings and save a
    snapshot. Only variables that are still uninitialized are initialized,
    so values loaded beforehand are kept.
    """
    with self.sess.as_default() as sess:
        # initialize uninitialized vars (only initialize vars that were not loaded)
        global_vars = tf.global_variables()
        # Fetch every initialization flag in a single session call instead of
        # one sess.run round-trip per variable.
        init_flags = sess.run([tf.is_variable_initialized(var) for var in global_vars])
        uninit_vars = [var for var, is_init in zip(global_vars, init_flags) if not is_init]
        sess.run(tf.variables_initializer(uninit_vars))

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            self.sampler.update_tasks()
            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)
            logger.log("Sampling set of tasks/goals for this meta-batch...")

            # -------------------- Sampling --------------------------
            logger.log("Obtaining samples...")
            time_env_sampling_start = time.time()
            paths = self.sampler.obtain_samples(log=True, log_prefix='train-')
            sampling_time = time.time() - time_env_sampling_start

            # ----------------- Processing Samples ---------------------
            logger.log("Processing samples...")
            time_proc_samples_start = time.time()
            samples_data = self.sample_processor.process_samples(
                paths, log='all', log_prefix='train-')
            proc_samples_time = time.time() - time_proc_samples_start

            if isinstance(paths, list):
                self.log_diagnostics(paths, prefix='train-')
            else:
                # presumably a mapping of per-task path lists; flatten them in
                # iteration order.
                self.log_diagnostics(
                    [path for task_paths in paths.values() for path in task_paths],
                    prefix='train-')

            # ------------------ Policy Update ---------------------
            logger.log("Optimizing policy...")
            # This needs to take all samples_data so that it can construct
            # graph for meta-optimization.
            time_optimization_step_start = time.time()
            self.algo.optimize_policy(samples_data)

            # ------------------- Logging Stuff --------------------------
            logger.logkv('Itr', itr)
            logger.logkv('n_timesteps', self.sampler.total_timesteps_sampled)

            logger.logkv('Time-Optimization', time.time() - time_optimization_step_start)
            logger.logkv('Time-SampleProc', proc_samples_time)
            logger.logkv('Time-Sampling', sampling_time)

            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

    logger.log("Training finished")
    self.sess.close()
def train(self):
    """
    Collect environment samples and fit the dynamics model every iteration.

    On iteration 0 the samples use random actions if
    ``self.initial_random_samples`` is set, or ``sinusoid=True`` if
    ``self.initial_sinusoid_samples`` is set; otherwise the policy is used.
    After fitting, diagnostics are logged and a snapshot saved. Only
    variables that are still uninitialized are initialized, so values loaded
    beforehand are kept.
    """
    with self.sess.as_default() as sess:
        # initialize uninitialized vars (only initialize vars that were not loaded)
        global_vars = tf.global_variables()
        # Fetch every initialization flag in a single session call instead of
        # one sess.run round-trip per variable.
        init_flags = sess.run([tf.is_variable_initialized(var) for var in global_vars])
        uninit_vars = [var for var, is_init in zip(global_vars, init_flags) if not is_init]
        sess.run(tf.variables_initializer(uninit_vars))

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            logger.log(
                "\n ---------------- Iteration %d ----------------" % itr)

            time_env_sampling_start = time.time()

            if self.initial_random_samples and itr == 0:
                logger.log(
                    "Obtaining random samples from the environment...")
                env_paths = self.sampler.obtain_samples(log=True,
                                                        random=True,
                                                        log_prefix='')

            elif self.initial_sinusoid_samples and itr == 0:
                logger.log(
                    "Obtaining sinusoidal samples from the environment using the policy..."
                )
                env_paths = self.sampler.obtain_samples(log=True,
                                                        log_prefix='',
                                                        sinusoid=True)

            else:
                logger.log(
                    "Obtaining samples from the environment using the policy..."
                )
                env_paths = self.sampler.obtain_samples(log=True,
                                                        log_prefix='')

            logger.record_tabular('Time-EnvSampling',
                                  time.time() - time_env_sampling_start)
            logger.log("Processing environment samples...")

            # first processing just for logging purposes
            time_env_samp_proc = time.time()
            samples_data = self.dynamics_sample_processor.process_samples(
                env_paths, log=True, log_prefix='EnvTrajs-')
            logger.record_tabular('Time-EnvSampleProc',
                                  time.time() - time_env_samp_proc)

            # --------------- fit dynamics model ---------------
            time_fit_start = time.time()

            logger.log("Training dynamics model for %i epochs ..."
                       % (self.dynamics_model_max_epochs))
            self.dynamics_model.fit(samples_data['observations'],
                                    samples_data['actions'],
                                    samples_data['next_observations'],
                                    epochs=self.dynamics_model_max_epochs,
                                    verbose=False,
                                    log_tabular=True)

            logger.record_tabular('Time-ModelFit',
                                  time.time() - time_fit_start)

            # ------------------- Logging Stuff --------------------------
            logger.logkv('Itr', itr)
            logger.logkv('n_timesteps', self.sampler.total_timesteps_sampled)

            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            self.log_diagnostics(env_paths, '')
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

    logger.log("Training finished")
    self.sess.close()