def save(self, epoch, paths=None):
    """Save a snapshot of the current runner state.

    The snapshot bundles the setup and training arguments, the environment,
    the algorithm, the epoch index and (optionally) sampled paths, and hands
    them to ``snapshotter.save_itr_params`` keyed by ``epoch``.

    Args:
        epoch (int): Index of the epoch being saved. Also stored in the
            snapshot as ``'last_epoch'``.
        paths (dict): Batch of samples after preprocessing. Only stored
            when truthy, so ``None`` and empty containers are omitted.

    Raises:
        AssertionError: If ``self.has_setup`` is falsy. Like any ``assert``,
            this check is skipped when Python runs with ``-O``.

    """
    assert self.has_setup, 'save() called before setup'

    logger.log('Saving snapshot...')

    params = dict()
    # Save arguments
    params['setup_args'] = self.setup_args
    params['train_args'] = self.train_args

    # Save states
    params['env'] = self.env
    params['algo'] = self.algo
    if paths:
        params['paths'] = paths
    params['last_epoch'] = epoch
    snapshotter.save_itr_params(epoch, params)

    logger.log('Saved')
def train(self):
    """Run training iterations from ``self.current_itr`` up to ``self.n_itr``.

    Each iteration obtains and processes samples through ``self.sampler``,
    logs diagnostics, optimizes the policy and saves a snapshot through
    ``snapshotter``. When ``self.plot`` is set, the live plot is updated
    after each iteration, optionally pausing for user input.

    The plotter and the sampler worker are released in ``finally`` clauses,
    so they are cleaned up even if an iteration raises.
    """
    plotter = Plotter()
    if self.plot:
        plotter.init_plot(self.env, self.policy)
    self.start_worker()
    try:
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #{} | '.format(itr)):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log('Saving snapshot...')
                params = self.get_itr_snapshot(itr, samples_data)
                # Advanced before the snapshot is written; presumably so the
                # saved algo (``self``) resumes at the next iteration -- TODO
                # confirm against the resume path.
                self.current_itr = itr + 1
                params['algo'] = self
                if self.store_paths:
                    params['paths'] = samples_data['paths']
                snapshotter.save_itr_params(itr, params)
                logger.log('saved')
                logger.log(tabular)
                if self.plot:
                    plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input('Plotting evaluation run: Press Enter to '
                              'continue...')
    finally:
        # Nested so the worker is shut down even if closing the plotter
        # fails; keeps the original close-then-shutdown order.
        try:
            plotter.close()
        finally:
            self.shutdown_worker()
def train(self, sess=None):
    """Run the training loop, reporting lifecycle events over a connection.

    Connects to ``('localhost', 6000)`` and sends an ``ExpLifecycle`` event
    before each training phase (start, sampling, processing, optimization,
    plot update, shutdown).

    Args:
        sess (tf.compat.v1.Session): Session to train in. If ``None``, a new
            session is created and entered here, and closed when training
            ends (also when it ends with an exception).

    Returns:
        The ``'average_return'`` of the last processed batch, or ``None`` if
        no iteration ran.

    """
    address = ('localhost', 6000)
    conn = Client(address)
    last_average_return = None
    try:
        created_session = sess is None
        if created_session:
            sess = tf.compat.v1.Session()
            # NOTE(review): entered but never exited; only close() is called
            # below -- confirm nothing relies on it leaving the default-
            # session stack.
            sess.__enter__()
        try:
            sess.run(tf.compat.v1.global_variables_initializer())
            conn.send(ExpLifecycle.START)
            self.start_worker(sess)
            try:
                start_time = time.time()
                for itr in range(self.start_itr, self.n_itr):
                    itr_start_time = time.time()
                    with logger.prefix('itr #%d | ' % itr):
                        logger.log('Obtaining samples...')
                        conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                        paths = self.obtain_samples(itr)
                        logger.log('Processing samples...')
                        conn.send(ExpLifecycle.PROCESS_SAMPLES)
                        samples_data = self.process_samples(itr, paths)
                        last_average_return = samples_data['average_return']
                        logger.log('Logging diagnostics...')
                        self.log_diagnostics(paths)
                        logger.log('Optimizing policy...')
                        conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                        self.optimize_policy(itr, samples_data)
                        logger.log('Saving snapshot...')
                        params = self.get_itr_snapshot(itr)
                        if self.store_paths:
                            params['paths'] = samples_data['paths']
                        snapshotter.save_itr_params(itr, params)
                        logger.log('Saved')
                        tabular.record('Time', time.time() - start_time)
                        tabular.record('ItrTime',
                                       time.time() - itr_start_time)
                        logger.log(tabular)
                        if self.plot:
                            conn.send(ExpLifecycle.UPDATE_PLOT)
                            self.plotter.update_plot(self.policy,
                                                     self.max_path_length)
                            if self.pause_for_plot:
                                input('Plotting evaluation run: Press Enter '
                                      'to continue...')
                # Only announced on normal completion, as before.
                conn.send(ExpLifecycle.SHUTDOWN)
            finally:
                self.shutdown_worker()
        finally:
            # Only close sessions this method created; a caller-supplied
            # session stays open.
            if created_session:
                sess.close()
    finally:
        conn.close()
    return last_average_return
def train(self):
    """Run training iterations, reporting lifecycle events over a connection.

    Connects to ``('localhost', 6000)`` and sends an ``ExpLifecycle`` event
    before each phase (start, sampling, processing, optimization, plot
    update, shutdown). Iterations run from ``self.current_itr`` up to
    ``self.n_itr``, and a snapshot is saved through ``snapshotter`` after
    each one.

    The plotter, the sampler worker and the connection are released in
    ``finally`` clauses, so they are cleaned up even if an iteration raises.
    """
    address = ('localhost', 6000)
    conn = Client(address)
    try:
        plotter = Plotter()
        if self.plot:
            plotter.init_plot(self.env, self.policy)
        conn.send(ExpLifecycle.START)
        self.start_worker()
        try:
            self.init_opt()
            for itr in range(self.current_itr, self.n_itr):
                with logger.prefix('itr #{} | '.format(itr)):
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.sampler.obtain_samples(itr)
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.sampler.process_samples(itr, paths)
                    self.log_diagnostics(paths)
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log('saving snapshot...')
                    params = self.get_itr_snapshot(itr, samples_data)
                    # Advanced before the snapshot is written; presumably so
                    # the saved algo (``self``) resumes at the next iteration
                    # -- TODO confirm against the resume path.
                    self.current_itr = itr + 1
                    params['algo'] = self
                    if self.store_paths:
                        params['paths'] = samples_data['paths']
                    snapshotter.save_itr_params(itr, params)
                    logger.log('saved')
                    logger.log(tabular)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        plotter.update_plot(self.policy, self.max_path_length)
                        if self.pause_for_plot:
                            input('Plotting evaluation run: Press Enter to '
                                  'continue...')
            # Only announced on normal completion, as before.
            conn.send(ExpLifecycle.SHUTDOWN)
        finally:
            # Nested so the worker is shut down even if closing the plotter
            # fails; keeps the original close-then-shutdown order.
            try:
                plotter.close()
            finally:
                self.shutdown_worker()
    finally:
        conn.close()