def train(self, sess=None):
    created_session = True if (sess is None) else False
    if sess is None:
        sess = tf.Session()
        sess.__enter__()
    sess.run(tf.global_variables_initializer())
    self.start_worker(sess)
    start_time = time.time()
    for itr in range(self.start_itr, self.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            params = self.optimize_policy(itr)
            if self.plot:
                self.plotter.update_plot(self.policy, self.max_path_length)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
            logger.log("Saving snapshot...")
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('IterTime', time.time() - itr_start_time)
            logger.record_tabular('Time', time.time() - start_time)
            logger.dump_tabular()
    self.shutdown_worker()
    if created_session:
        sess.close()
def train(self):
    plotter = Plotter()
    if self.plot:
        plotter.init_plot(self.env, self.policy)
    self.start_worker()
    self.init_opt()
    for itr in range(self.current_itr, self.n_itr):
        with logger.prefix('itr #%d | ' % itr):
            paths = self.sampler.obtain_samples(itr)
            samples_data = self.sampler.process_samples(itr, paths)
            self.log_diagnostics(paths)
            self.optimize_policy(itr, samples_data)
            logger.log("saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)
            self.current_itr = itr + 1
            params["algo"] = self
            if self.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("saved")
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                plotter.update_plot(self.policy, self.max_path_length)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
    plotter.close()
    self.shutdown_worker()
def optimize_policy(self, itr, samples_data):
    self._top_algo.optimize_policy(itr, samples_data)
    with logger.prefix('ASA | '):
        make_skill, skill_subpath = self.decide_new_skill(samples_data)
        if make_skill:
            new_skill_pol, new_skill_id = self.create_and_train_new_skill(skill_subpath)
            self.integrate_new_skill(new_skill_id, skill_subpath)
def train_once(self, itr, paths):
    itr_start_time = time.time()
    with logger.prefix('itr #%d | ' % itr):
        self.log_diagnostics(paths)
        logger.log("Optimizing policy...")
        self.optimize_policy(itr, paths)
        logger.record_tabular('IterTime', time.time() - itr_start_time)
        logger.dump_tabular()
def train(self, sess=None):
    address = ("localhost", 6000)
    conn = Client(address)
    last_average_return = None
    try:
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()
        sess.run(tf.global_variables_initializer())
        conn.send(ExpLifecycle.START)
        self.start_worker(sess)
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                conn.send(ExpLifecycle.PROCESS_SAMPLES)
                samples_data = self.process_samples(itr, paths)
                last_average_return = samples_data["average_return"]
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    conn.send(ExpLifecycle.UPDATE_PLOT)
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        conn.send(ExpLifecycle.SHUTDOWN)
        self.shutdown_worker()
        if created_session:
            sess.close()
    finally:
        conn.close()
    return last_average_return
def _training_step(self, itr):
    itr_start_time = time.time()
    with logger.prefix('itr #%d | ' % itr):
        self._sampling()
        self._bookkeeping()
        self._memory_selection(itr)
        self._policy_optimization(itr)
        if itr % self.evaluation_interval == 0:
            self._policy_evaluation()
        self._log_diagnostics(itr)
        logger.record_tabular('Time', time.time() - self.start_time)
        logger.record_tabular('ItrTime', time.time() - itr_start_time)
        logger.dump_tabular(with_prefix=False)
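# --- Illustrative sketch, not part of the original source ---
# `_training_step` above is presumably driven by an outer iteration loop,
# analogous to the `train` methods elsewhere in this file. A minimal driver
# under that assumption (the attribute names `start_time` and `n_itr` mirror
# the other variants and are assumptions here):
def train(self):
    self.start_time = time.time()  # read by _training_step for the 'Time' tabular entry
    for itr in range(self.n_itr):
        self._training_step(itr)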
def train(self, sess=None):
    created_session = True if (sess is None) else False
    if sess is None:
        sess = tf.Session()
        sess.__enter__()
    sess.run(tf.global_variables_initializer())
    self.start_worker(sess)
    start_time = time.time()
    last_average_return = None
    for itr in range(self.start_itr, self.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = self.obtain_samples(itr)
            logger.log("Processing samples...")
            samples_data = self.process_samples(itr, paths)
            last_average_return = samples_data["average_return"]
            logger.log("Logging diagnostics...")
            self.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            self.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)
            if self.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.plotter.update_plot(self.policy, self.max_path_length)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
    self.shutdown_worker()
    if created_session:
        sess.close()
    return last_average_return
def train(self):
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        self.start_worker(sess)
        start_time = time.time()
        self.num_samples = 0
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining new samples...")
                paths = self.obtain_samples(itr)
                for path in paths:
                    self.num_samples += len(path["rewards"])
                logger.log("total num samples..." + str(self.num_samples))
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.outer_optimize(samples_data)
                for sub_itr in range(self.n_sub_itr):
                    logger.log("Minibatch Optimizing...")
                    self.inner_optimize(samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                # if self.plot:
                #     self.update_plot()
                #     if self.pause_for_plot:
                #         input("Plotting evaluation run: Press Enter to "
                #               "continue...")
        self.shutdown_worker()
def train(self):
    address = ("localhost", 6000)
    conn = Client(address)
    try:
        plotter = Plotter()
        if self.plot:
            plotter.init_plot(self.env, self.policy)
        conn.send(ExpLifecycle.START)
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                paths = self.sampler.obtain_samples(itr)
                conn.send(ExpLifecycle.PROCESS_SAMPLES)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    conn.send(ExpLifecycle.UPDATE_PLOT)
                    plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        conn.send(ExpLifecycle.SHUTDOWN)
        plotter.close()
        self.shutdown_worker()
    finally:
        conn.close()
def train(self, sess=None):
    created_session = True if (sess is None) else False
    if sess is None:
        sess = tf.Session()
        sess.__enter__()
    sess.run(tf.global_variables_initializer())
    self.start_worker(sess)
    if self.use_target:
        self.f_init_target()
    episode_rewards = []
    episode_policy_losses = []
    episode_qf_losses = []
    epoch_ys = []
    epoch_qs = []
    last_average_return = None
    for epoch in range(self.n_epochs):
        self.success_history.clear()
        with logger.prefix('epoch #%d | ' % epoch):
            for epoch_cycle in range(self.n_epoch_cycles):
                paths = self.obtain_samples(epoch)
                samples_data = self.process_samples(epoch, paths)
                episode_rewards.extend(samples_data["undiscounted_returns"])
                self.success_history.extend(samples_data["success_history"])
                self.log_diagnostics(paths)
                for train_itr in range(self.n_train_steps):
                    if self.replay_buffer.n_transitions_stored >= self.min_buffer_size:  # noqa: E501
                        self.evaluate = True
                        qf_loss, y, q, policy_loss = self.optimize_policy(
                            epoch, samples_data)
                        episode_policy_losses.append(policy_loss)
                        episode_qf_losses.append(qf_loss)
                        epoch_ys.append(y)
                        epoch_qs.append(q)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
            logger.log("Training finished")
            logger.log("Saving snapshot #{}".format(epoch))
            params = self.get_itr_snapshot(epoch, samples_data)
            logger.save_itr_params(epoch, params)
            logger.log("Saved")
            if self.evaluate:
                logger.record_tabular('Epoch', epoch)
                logger.record_tabular('AverageReturn', np.mean(episode_rewards))
                logger.record_tabular('StdReturn', np.std(episode_rewards))
                logger.record_tabular('Policy/AveragePolicyLoss',
                                      np.mean(episode_policy_losses))
                logger.record_tabular('QFunction/AverageQFunctionLoss',
                                      np.mean(episode_qf_losses))
                logger.record_tabular('QFunction/AverageQ', np.mean(epoch_qs))
                logger.record_tabular('QFunction/MaxQ', np.max(epoch_qs))
                logger.record_tabular('QFunction/AverageAbsQ',
                                      np.mean(np.abs(epoch_qs)))
                logger.record_tabular('QFunction/AverageY', np.mean(epoch_ys))
                logger.record_tabular('QFunction/MaxY', np.max(epoch_ys))
                logger.record_tabular('QFunction/AverageAbsY',
                                      np.mean(np.abs(epoch_ys)))
                if self.input_include_goal:
                    logger.record_tabular('AverageSuccessRate',
                                          np.mean(self.success_history))
                last_average_return = np.mean(episode_rewards)
            if not self.smooth_return:
                episode_rewards = []
                episode_policy_losses = []
                episode_qf_losses = []
                epoch_ys = []
                epoch_qs = []
            logger.dump_tabular(with_prefix=False)
    self.shutdown_worker()
    if created_session:
        sess.close()
    return last_average_return
def train(self, sess=None):
    created_session = True if (sess is None) else False
    if sess is None:
        sess = tf.Session()
        sess.__enter__()
    sess.run(tf.global_variables_initializer())

    # Initialize some missing variables
    uninitialized_vars = []
    for var in tf.all_variables():
        try:
            sess.run(var)
        except tf.errors.FailedPreconditionError:
            print("Uninitialized var: ", var)
            uninitialized_vars.append(var)
    init_new_vars_op = tf.initialize_variables(uninitialized_vars)
    sess.run(init_new_vars_op)

    self.start_worker(sess)
    start_time = time.time()
    last_average_return = None
    samples_total = 0
    for itr in range(self.start_itr, self.n_itr):
        if samples_total >= self.max_samples:
            print("WARNING: Total max num of samples collected: %d >= %d"
                  % (samples_total, self.max_samples))
            break
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = self.obtain_samples(itr)
            samples_total += self.batch_size
            logger.log("Processing samples...")
            samples_data = self.process_samples(itr, paths)
            last_average_return = samples_data["average_return"]
            logger.log("Logging diagnostics...")
            self.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            self.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)
            # import pdb; pdb.set_trace()
            if self.store_paths:
                # WARN: Beware that data is saved to hdf in float32 by default
                # (see param float_nptype)
                h5u.append_train_iter_data(h5file=self.hdf,
                                           data=samples_data["paths"],
                                           data_group="traj_data/",
                                           teacher_indx=self.teacher_indx,
                                           itr=None,
                                           float_nptype=np.float32)
                # params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            self.log_env_info(samples_data["env_infos"])
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.plotter.update_plot(self.policy, self.max_path_length)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to continue...")
            # Show the policy from time to time
            if self.record_every_itr is not None and self.record_every_itr > 0 \
                    and itr % self.record_every_itr == 0:
                self.record_policy(env=self.env, policy=self.policy, itr=itr)
            if self.play_every_itr is not None and self.play_every_itr > 0 \
                    and itr % self.play_every_itr == 0:
                self.play_policy(env=self.env, policy=self.policy)

    # Record a few episodes at the end
    if self.record_end_ep_num is not None:
        for i in range(self.record_end_ep_num):
            self.record_policy(env=self.env, policy=self.policy, itr=itr,
                               postfix="_%02d" % i)

    # Report termination criteria
    if itr >= self.n_itr - 1:
        print("TERM CRITERIA: Max number of iterations reached itr: %d , itr_max: %d"
              % (itr, self.n_itr - 1))
    if samples_total >= self.max_samples:
        print("TERM CRITERIA: Total max num of samples collected: %d >= %d"
              % (samples_total, self.max_samples))

    self.shutdown_worker()
    if created_session:
        sess.close()
def run_task(*_):
    # Configure TF session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config).as_default() as tf_session:
        ## Load data from itr_N.pkl
        with open(snapshot_file, 'rb') as file:
            saved_data = dill.load(file)

        ## Construct PathTrie and find missing skill description
        # This is basically ASA.decide_new_skill
        min_length = 3
        max_length = 5
        action_map = {0: 's', 1: 'L', 2: 'R'}
        min_f_score = 1
        max_results = 10
        aggregations = []  # sublist of ['mean', 'most_freq', 'nearest_mean', 'medoid'] or 'all'
        paths = saved_data['paths']
        path_trie = PathTrie(saved_data['hrl_policy'].num_skills)
        for path in paths:
            actions = path['actions'].argmax(axis=1).tolist()
            observations = path['observations']
            path_trie.add_all_subpaths(
                actions,
                observations,
                min_length=min_length,
                max_length=max_length
            )
        logger.log('Searched {} rollouts'.format(len(paths)))

        frequent_paths = path_trie.items(
            action_map=action_map,
            min_count=10,  # len(paths) * 2
            min_f_score=min_f_score,
            max_results=max_results,
            aggregations=aggregations
        )
        logger.log('Found {} frequent paths: [index, actions, count, f-score]'.format(len(frequent_paths)))
        for i, f_path in enumerate(frequent_paths):
            logger.log('  {:2}: {:{pad}}\t{}\t{:.3f}'.format(
                i, f_path['actions_text'], f_path['count'], f_path['f_score'], pad=max_length))

        top_subpath = frequent_paths[0]
        start_obss = top_subpath['start_observations']
        end_obss = top_subpath['end_observations']

        ## Prepare elements for training
        # Environment
        base_env = saved_data['env'].env.env  # <NormalizedEnv<MinibotEnv instance>>
        skill_learning_env = TfEnv(
            SkillLearningEnv(
                # base env that was wrapped in HierarchizedEnv (not fully unwrapped - may be normalized!)
                env=base_env,
                start_obss=start_obss,
                end_obss=end_obss
            )
        )

        # Skill policy
        hrl_policy = saved_data['hrl_policy']
        new_skill_policy, new_skill_id = hrl_policy.create_new_skill(
            end_obss=end_obss
        )

        # Baseline - clone baseline specified in low_algo_kwargs, or top-algo`s baseline
        low_algo_kwargs = dict(saved_data['low_algo_kwargs'])
        baseline_to_clone = low_algo_kwargs.get('baseline', saved_data['baseline'])
        baseline = Serializable.clone(  # to create blank baseline
            obj=baseline_to_clone,
            name='{}Skill{}'.format(type(baseline_to_clone).__name__, new_skill_id)
        )
        low_algo_kwargs['baseline'] = baseline
        low_algo_cls = saved_data['low_algo_cls']

        # Set custom training params (should`ve been set in asa_basic_run)
        low_algo_kwargs['batch_size'] = 2500
        low_algo_kwargs['max_path_length'] = 50
        low_algo_kwargs['n_itr'] = 500

        # Algorithm
        algo = low_algo_cls(
            env=skill_learning_env,
            policy=new_skill_policy,
            **low_algo_kwargs
        )

        # Logger parameters
        logger_snapshot_dir_before = logger.get_snapshot_dir()
        logger_snapshot_mode_before = logger.get_snapshot_mode()
        logger_snapshot_gap_before = logger.get_snapshot_gap()
        # No need to change snapshot dir in this script, it is used in ASA-algo.create_and_train_new_skill()
        # logger.set_snapshot_dir(os.path.join(
        #     logger_snapshot_dir_before,
        #     'skill{}'.format(new_skill_id)
        # ))
        logger.set_snapshot_mode('none')
        logger.set_tensorboard_step_key('Iteration')

        ## Train new skill
        with logger.prefix('Skill {} | '.format(new_skill_id)):
            algo.train(sess=tf_session)

        ## Save new policy and its end_obss (we`ll construct skill stopping function
        # from them manually in asa_resume_with_new_skill.py)
        out_file = os.path.join(logger.get_snapshot_dir(), 'final.pkl')
        with open(out_file, 'wb') as file:
            out_data = {
                'policy': new_skill_policy,
                'subpath': top_subpath
            }
            dill.dump(out_data, file)

        # Restore logger parameters
        logger.set_snapshot_dir(logger_snapshot_dir_before)
        logger.set_snapshot_mode(logger_snapshot_mode_before)
        logger.set_snapshot_gap(logger_snapshot_gap_before)
def create_and_train_new_skill(self, skill_subpath):
    """
    Create and train a new skill based on given subpath. The new skill
    policy and ID are returned, and also saved in self._hrl_policy.
    """
    ## Prepare elements for training
    # Environment
    skill_learning_env = TfEnv(
        SkillLearningEnv(
            # base env that was wrapped in HierarchizedEnv (not fully unwrapped - may be normalized!)
            env=self.env.env.env,
            start_obss=skill_subpath['start_observations'],
            end_obss=skill_subpath['end_observations']
        )
    )

    # Skill policy
    new_skill_pol, new_skill_id = self._hrl_policy.create_new_skill(
        skill_subpath['end_observations'])  # blank policy to be trained

    # Baseline - clone baseline specified in low_algo_kwargs, or top-algo`s baseline.
    # We need to clone the baseline, as each skill policy must have its own instance.
    la_kwargs = dict(self._low_algo_kwargs)
    baseline_to_clone = la_kwargs.get('baseline', self.baseline)
    baseline = Serializable.clone(  # to create blank baseline
        obj=baseline_to_clone,
        name='{}Skill{}'.format(type(baseline_to_clone).__name__, new_skill_id)
    )
    la_kwargs['baseline'] = baseline

    # Algorithm
    algo = self._low_algo_cls(
        env=skill_learning_env,
        policy=new_skill_pol,
        **la_kwargs
    )

    # Logger parameters
    logger.dump_tabular(with_prefix=False)
    logger.log('Launching training of the new skill')
    logger_snapshot_dir_before = logger.get_snapshot_dir()
    logger_snapshot_mode_before = logger.get_snapshot_mode()
    logger_snapshot_gap_before = logger.get_snapshot_gap()
    logger.set_snapshot_dir(os.path.join(
        logger_snapshot_dir_before,
        'skill{}'.format(new_skill_id)
    ))
    logger.set_snapshot_mode('none')
    # logger.set_snapshot_gap(max(1, np.floor(la_kwargs['n_itr'] / 10)))
    logger.push_tabular_prefix('Skill{}/'.format(new_skill_id))
    logger.set_tensorboard_step_key('Iteration')

    # Train new skill
    with logger.prefix('Skill {} | '.format(new_skill_id)):
        algo.train(sess=self._tf_sess)

    # Restore logger parameters
    logger.pop_tabular_prefix()
    logger.set_snapshot_dir(logger_snapshot_dir_before)
    logger.set_snapshot_mode(logger_snapshot_mode_before)
    logger.set_snapshot_gap(logger_snapshot_gap_before)
    logger.log('Training of the new skill finished')

    return new_skill_pol, new_skill_id