def step(self, random=False):
    time_step = time.time()

    '''------------- Obtaining samples from the environment -----------'''
    if self.verbose:
        logger.log("Data is obtaining samples...")
    env_paths = self.env_sampler.obtain_samples(
        log=True,
        random=random,
        log_prefix='Data-EnvSampler-',
    )

    '''-------------- Processing environment samples -------------------'''
    if self.verbose:
        logger.log("Data is processing environment samples...")
    samples_data = self.dynamics_sample_processor.process_samples(
        env_paths,
        log=True,
        log_prefix='Data-EnvTrajs-',
    )
    self.samples_data_arr.append(samples_data)

    # sleep so that wall-clock time per step matches simulation_sleep
    time_step = time.time() - time_step
    time_sleep = max(self.simulation_sleep - time_step, 0)
    time.sleep(time_sleep)

    logger.logkv('Data-TimeStep', time_step)
    logger.logkv('Data-TimeSleep', time_sleep)

    # save snapshot
    params = self.get_itr_snapshot()
    logger.save_itr_params(self.itr_counter, params)
def step(self, random=False):
    time_step = time.time()

    '''------------- Obtaining samples from the environment -----------'''
    if self.verbose:
        logger.log("Data is obtaining samples...")
    env_paths = self.env_sampler.obtain_samples(
        log=True,
        random=random,
        log_prefix='Data-EnvSampler-',
    )

    '''-------------- Processing environment samples -------------------'''
    if self.verbose:
        logger.log("Data is processing environment samples...")
    samples_data = self.dynamics_sample_processor.process_samples(
        env_paths,
        log=True,
        log_prefix='Data-EnvTrajs-',
    )

    time_step = time.time() - time_step
    time_sleep = max(self.time_sleep - time_step, 0)
    time.sleep(time_sleep)

    logger.logkv('Data-TimeStep', time_step)
    logger.logkv('Data-TimeSleep', time_sleep)

    return samples_data
def process_queue(self):
    do_push, do_synch = 0, 0
    data = None
    while True:
        try:
            if self.verbose:
                logger.log('{} try'.format(self.name))
            new_data = self.queue.get_nowait()
            if new_data == 'push':  # only happens when next worker has need_query = True
                if do_push == 0:  # only push once
                    do_push += 1
                    self.push()
            else:
                do_synch = 1
                data = new_data
        except Empty:
            break

    if do_synch:
        self._synch(data)

    do_step = 1  # - do_synch

    if self.verbose:
        logger.log('{} finishes processing queue with {}, {}, {}......'.format(
            self.name, do_push, do_synch, do_step))

    return do_push, do_synch, do_step
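# A tiny runnable illustration of the queue protocol that process_queue()
# handles above (a sketch only: the real workers exchange messages over
# multiprocessing queues, and data messages are pickled sample batches).
# The string token 'push' requests a parameter push; anything else is data
# to synchronize on.
from queue import Queue, Empty

def _demo_queue_protocol():
    q = Queue()
    q.put('push')
    q.put({'observations': []})  # stands in for a pickled samples_data batch
    pushes, data = 0, []
    while True:
        try:
            msg = q.get_nowait()
        except Empty:
            break
        if msg == 'push':
            pushes += 1
        else:
            data.append(msg)
    assert pushes == 1 and len(data) == 1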
def _synch(self, samples_data_arr, check_init=False):
    time_synch = time.time()
    if self.verbose:
        logger.log('Model at {} is synchronizing...'.format(self.itr_counter))
    obs = np.concatenate([
        samples_data['observations'] for samples_data in samples_data_arr
    ])
    act = np.concatenate([
        samples_data['actions'] for samples_data in samples_data_arr
    ])
    obs_next = np.concatenate([
        samples_data['next_observations'] for samples_data in samples_data_arr
    ])
    self.dynamics_model.update_buffer(
        obs=obs,
        act=act,
        obs_next=obs_next,
        check_init=check_init,
    )

    logger.logkv('Model-AvgEpochs',
                 self.sum_model_itr / self.dynamics_model.num_models)
    self.sum_model_itr = 0

    # Reset variables for the early-stopping condition
    self.with_new_data = True
    self.remaining_model_idx = list(range(self.dynamics_model.num_models))
    self.valid_loss_rolling_average = None

    time_synch = time.time() - time_synch
    logger.logkv('Model-TimeSynch', time_synch)
def optimize_policy(self, samples_data, log=True, prefix='', verbose=False):
    """
    Performs a policy optimization step

    Args:
        samples_data (dict) : processed sample batch to optimize the policy on
        log (bool) : whether to log statistics

    Returns:
        None
    """
    input_dict = self._extract_input_dict(samples_data,
                                          self._optimization_keys,
                                          prefix='train')

    if verbose:
        logger.log("Optimizing")
    loss_before = self.optimizer.optimize(input_val_dict=input_dict)

    if verbose:
        logger.log("Computing statistics")
    loss_after = self.optimizer.loss(input_val_dict=input_dict)

    if log:
        logger.logkv(prefix + 'LossBefore', loss_before)
        logger.logkv(prefix + 'LossAfter', loss_after)
def pull(self):
    time_synch = time.time()
    if self.verbose:
        logger.log('Policy is synchronizing...')
    model_params = ray.get(self.model_ps.pull.remote())
    assert isinstance(model_params, dict)
    self.model_sampler.dynamics_model.set_shared_params(model_params)
    if hasattr(self.model_sampler, 'vec_env'):
        self.model_sampler.vec_env.dynamics_model.set_shared_params(model_params)
    logger.logkv('Policy-TimePull', time.time() - time_synch)
def _synch(self, dynamics_model_state_pickle):
    time_synch = time.time()
    if self.verbose:
        logger.log('Policy is synchronizing...')
    dynamics_model_state = pickle.loads(dynamics_model_state_pickle)
    assert isinstance(dynamics_model_state, dict)
    self.model_sampler.dynamics_model.set_shared_params(dynamics_model_state)
    if hasattr(self.model_sampler, 'vec_env'):
        self.model_sampler.vec_env.dynamics_model.set_shared_params(
            dynamics_model_state)
    time_synch = time.time() - time_synch
    logger.logkv('Policy-TimeSynch', time_synch)
def push(self):
    time_push = time.time()
    state_pickle = pickle.dumps(self.dynamics_model.get_shared_param_values())
    assert state_pickle is not None
    while self.queue_next.qsize() > 5:
        try:
            logger.log('Model is offloading data from queue_next...')
            _ = self.queue_next.get_nowait()
        except Empty:
            break
    self.queue_next.put(state_pickle)
    time_push = time.time() - time_push
    logger.logkv('Model-TimePush', time_push)
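# Hedged sketch of the push pattern used above: drain queue_next down to a
# small backlog before putting the newest snapshot, so the downstream worker
# always sees near-fresh parameters instead of a growing queue of stale ones.
# Names here are illustrative, not part of the repo's API; also note that
# multiprocessing.Queue.qsize() is approximate on some platforms.
from queue import Empty

def push_latest(queue_next, state_pickle, max_backlog=5):
    while queue_next.qsize() > max_backlog:
        try:
            _ = queue_next.get_nowait()  # drop stale snapshots
        except Empty:
            break
    queue_next.put(state_pickle)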
def train(self):
    """
    Trains policy on env using algo
    """
    time_total = time.time()

    ''' --------------- worker looping --------------- '''

    futures = [worker.start.remote() for worker in self.workers]
    logger.log('Start looping...')
    ray.get(futures)
    logger.logkv('Trainer-TimeTotal', time.time() - time_total)
    logger.dumpkvs()
    logger.log('***** Training finished ******')
def push(self):
    time_push = time.time()
    policy_state_pickle = pickle.dumps(self.policy.get_shared_param_values())
    assert policy_state_pickle is not None
    while self.queue_next.qsize() > 5:
        try:
            logger.log('Policy is offloading data from queue_next...')
            _ = self.queue_next.get_nowait()
        except Empty:
            # very rare chance to reach here
            break
    self.queue_next.put(policy_state_pickle)
    time_push = time.time() - time_push
    logger.logkv('Policy-TimePush', time_push)
def step(self, obs=None, act=None, obs_next=None):
    time_model_fit = time.time()

    """ --------------- fit dynamics model --------------- """
    if self.verbose:
        logger.log('Model at iteration {} is training for one epoch...'.format(
            self.step_counter))
    self.remaining_model_idx, self.valid_loss_rolling_average = \
        self.dynamics_model.fit_one_epoch(
            remaining_model_idx=self.remaining_model_idx,
            valid_loss_rolling_average_prev=self.valid_loss_rolling_average,
            with_new_data=self.with_new_data,
            verbose=self.verbose,
            log_tabular=True,
            prefix='Model-',
        )
    self.with_new_data = False

    logger.logkv('Model-TimeStep', time.time() - time_model_fit)
def step(self, random=False):
    time_step = time.time()

    '''------------- Obtaining samples from the environment -----------'''
    if self.verbose:
        logger.log("Data is obtaining samples...")
    env_paths = self.env_sampler.obtain_samples(
        log=True,
        random=random,
        log_prefix='Data-EnvSampler-',
    )

    '''-------------- Processing environment samples -------------------'''
    if self.verbose:
        logger.log("Data is processing samples...")

    # subsample num_rollouts_per_iter rollouts from the collected paths
    if isinstance(env_paths, (dict, OrderedDict)):
        env_paths = list(env_paths.values())
        idxs = np.random.choice(range(len(env_paths)),
                                size=self.num_rollouts_per_iter,
                                replace=False)
        env_paths = sum([env_paths[idx] for idx in idxs], [])
    elif isinstance(env_paths, list):
        idxs = np.random.choice(range(len(env_paths)),
                                size=self.num_rollouts_per_iter,
                                replace=False)
        env_paths = [env_paths[idx] for idx in idxs]
    else:
        raise TypeError('env_paths must be a dict or a list, got {}'.format(
            type(env_paths)))

    samples_data = self.dynamics_sample_processor.process_samples(
        env_paths,
        log=True,
        log_prefix='Data-EnvTrajs-',
    )
    self.samples_data_arr.append(samples_data)

    time_step = time.time() - time_step
    time_sleep = max(self.simulation_sleep - time_step, 0)
    time.sleep(time_sleep)

    logger.logkv('Data-TimeStep', time_step)
    logger.logkv('Data-TimeSleep', time_sleep)
def optimize_policy(self, all_samples_data, log=True, prefix='', verbose=False):
    """
    Performs MAML outer step

    Args:
        all_samples_data (list) : list of lists of lists of samples (each is a dict)
            split by gradient update and meta task
        log (bool) : whether to log statistics

    Returns:
        None
    """
    meta_op_input_dict = self._extract_input_dict_meta_op(
        all_samples_data, self._optimization_keys)

    if verbose:
        logger.log("Computing KL before")
    mean_kl_before = self.optimizer.constraint_val(meta_op_input_dict)

    if verbose:
        logger.log("Computing loss before")
    loss_before = self.optimizer.loss(meta_op_input_dict)
    if verbose:
        logger.log("Optimizing")
    self.optimizer.optimize(meta_op_input_dict)
    if verbose:
        logger.log("Computing loss after")
    loss_after = self.optimizer.loss(meta_op_input_dict)

    if verbose:
        logger.log("Computing KL after")
    mean_kl = self.optimizer.constraint_val(meta_op_input_dict)
    if log:
        logger.logkv(prefix + 'MeanKLBefore', mean_kl_before)
        logger.logkv(prefix + 'MeanKL', mean_kl)
        logger.logkv(prefix + 'LossBefore', loss_before)
        logger.logkv(prefix + 'LossAfter', loss_after)
        logger.logkv(prefix + 'dLoss', loss_before - loss_after)
def start(self):
    logger.log(f"\n================ {self.name} starts ===============")
    time_start = time.time()

    with self.sess.as_default():

        # loop
        while not ray.get(self.stop_cond.is_set.remote()):
            do_synch, do_step = self.step_wrapper()
            self.synch_counter += do_synch
            self.step_counter += do_step

            # logging
            logger.logkv(self.name + '-TimeSoFar', time.time() - time_start)
            logger.logkv(self.name + '-TotalStep', self.step_counter)
            logger.logkv(self.name + '-TotalSynch', self.synch_counter)
            logger.dumpkvs()

    self.set_stop_cond()

    logger.log(f"\n================== {self.name} closed ===================")
def step(self):
    time_step = time.time()

    """ -------------------- Sampling --------------------------"""
    if self.verbose:
        logger.log("Policy is obtaining samples ...")
    paths = self.model_sampler.obtain_samples(log=True, log_prefix='Policy-')

    """ ----------------- Processing Samples ---------------------"""
    if self.verbose:
        logger.log("Policy is processing samples ...")
    samples_data = self.model_sample_processor.process_samples(
        paths, log='all', log_prefix='Policy-')

    if type(paths) is list:
        self.log_diagnostics(paths, prefix='Policy-')
    else:
        self.log_diagnostics(sum(paths.values(), []), prefix='Policy-')

    """ ------------------ Policy Update ---------------------"""
    if self.verbose:
        logger.log("Policy optimization...")
    # This needs to take all samples_data so that it can construct graph for meta-optimization.
    self.algo.optimize_policy(samples_data, log=True, verbose=False,
                              prefix='Policy-')

    self.policy = self.model_sampler.policy
    logger.logkv('Policy-TimeStep', time.time() - time_step)
def step(self):
    time_step = time.time()

    ''' --------------- MAML steps --------------- '''
    self.policy.switch_to_pre_update()  # Switch to pre-update policy

    all_samples_data = []
    for step in range(self.num_inner_grad_steps + 1):
        if self.verbose:
            logger.log("** Policy Adaptation-Step %d **" % step)

        """ -------------------- Sampling --------------------------"""
        paths = self.model_sampler.obtain_samples(log=True,
                                                  log_prefix='Policy-',
                                                  buffer=None)

        """ ----------------- Processing Samples ---------------------"""
        samples_data = self.model_sample_processor.process_samples(
            paths, log='all', log_prefix='Policy-')
        all_samples_data.append(samples_data)

        self.log_diagnostics(sum(list(paths.values()), []), prefix='Policy-')

        """ ------------------- Inner Policy Update --------------------"""
        if step < self.num_inner_grad_steps:
            self.algo._adapt(samples_data)

    """ ------------------ Outer Policy Update ---------------------"""
    if self.verbose:
        logger.log("Policy is optimizing...")
    # This needs to take all samples_data so that it can construct graph for meta-optimization.
    self.algo.optimize_policy(all_samples_data, prefix='Policy-')

    time_step = time.time() - time_step
    self.policy = self.model_sampler.policy
    logger.logkv('Policy-TimeStep', time_step)
def train(self):
    """
    Trains policy on env using algo
    """
    worker_data_queue, worker_model_queue, worker_policy_queue = self.queues
    worker_data_remote, worker_model_remote, worker_policy_remote = self.remotes

    for p in self.ps:
        p.start()

    ''' --------------- worker warm-up --------------- '''

    logger.log('Prepare start...')

    worker_data_remote.send('prepare start')
    worker_data_queue.put(self.initial_random_samples)
    assert worker_data_remote.recv() == 'loop ready'

    worker_model_remote.send('prepare start')
    assert worker_model_remote.recv() == 'loop ready'

    worker_policy_remote.send('prepare start')
    assert worker_policy_remote.recv() == 'loop ready'

    time_total = time.time()

    ''' --------------- worker looping --------------- '''

    logger.log('Start looping...')
    for remote in self.remotes:
        remote.send('start loop')

    ''' --------------- collect info --------------- '''

    for remote in self.remotes:
        assert remote.recv() == 'loop done'
    logger.log('\n------------all workers exit loops -------------')

    for remote in self.remotes:
        assert remote.recv() == 'worker closed'

    for p in self.ps:
        p.terminate()

    logger.logkv('Trainer-TimeTotal', time.time() - time_total)
    logger.dumpkvs()
    logger.log("***** Training finished")
def step(self, random_sinusoid=(False, False)):
    time_step = time.time()

    if self.itr_counter == 1 and self.env_sampler.policy.dynamics_model.normalization is None:
        if self.verbose:
            logger.log('Data starts first step...')
        self.env_sampler.policy.dynamics_model = pickle.loads(self.queue.get())
        if self.verbose:
            logger.log('Data first step done...')

    '''------------- Obtaining samples from the environment -----------'''
    if self.verbose:
        logger.log("Data is obtaining samples...")
    env_paths = self.env_sampler.obtain_samples(
        log=True,
        random=random_sinusoid[0],
        sinusoid=random_sinusoid[1],
        log_prefix='Data-EnvSampler-',
    )

    '''-------------- Processing environment samples -------------------'''
    if self.verbose:
        logger.log("Data is processing samples...")
    samples_data = self.dynamics_sample_processor.process_samples(
        env_paths,
        log=True,
        log_prefix='Data-EnvTrajs-',
    )
    self.samples_data_arr.append(samples_data)

    time_step = time.time() - time_step
    time_sleep = max(self.simulation_sleep - time_step, 0)
    time.sleep(time_sleep)

    logger.logkv('Data-TimeStep', time_step)
    logger.logkv('Data-TimeSleep', time_sleep)
def process_queue(self):
    do_push = 0
    samples_data_arr = []
    while True:
        try:
            if not self.remaining_model_idx:
                logger.log('Model at iteration {} is block waiting for data'.format(
                    self.itr_counter))
                # FIXME: check stop_cond
                time_wait = time.time()
                samples_data_arr_pickle = self.queue.get()
                time_wait = time.time() - time_wait
                logger.logkv('Model-TimeBlockWait', time_wait)
                self.remaining_model_idx = list(range(self.dynamics_model.num_models))
            else:
                if self.verbose:
                    logger.log('Model try get_nowait.........')
                samples_data_arr_pickle = self.queue.get_nowait()
            if samples_data_arr_pickle == 'push':
                # Only push once before executing another step
                if do_push == 0:
                    do_push = 1
                    self.push()
            else:
                samples_data_arr.extend(pickle.loads(samples_data_arr_pickle))
        except Empty:
            break

    do_synch = len(samples_data_arr)
    if do_synch:
        self._synch(samples_data_arr)

    do_step = 1

    if self.verbose:
        logger.log('Model finishes processing queue with {}, {}, {}......'.format(
            do_push, do_synch, do_step))

    return do_push, do_synch, do_step
from asynch_mb.logger import logger
def train(self):
    """
    Trains policy on env using algo

    Pseudocode:
        for itr in n_itr:
            for step in num_inner_grad_steps:
                sampler.sample()
                algo.compute_updated_dists()
            algo.optimize_policy()
            sampler.update_goals()
    """
    with self.sess.as_default() as sess:

        # initialize uninitialized vars (only initialize vars that were not loaded)
        uninit_vars = [
            var for var in tf.global_variables()
            if not sess.run(tf.is_variable_initialized(var))
        ]
        sess.run(tf.variables_initializer(uninit_vars))

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            self.sampler.update_tasks()
            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)
            logger.log("Sampling set of tasks/goals for this meta-batch...")

            """ -------------------- Sampling --------------------------"""
            logger.log("Obtaining samples...")
            time_env_sampling_start = time.time()
            paths = self.sampler.obtain_samples(log=True, log_prefix='train-')
            sampling_time = time.time() - time_env_sampling_start

            """ ----------------- Processing Samples ---------------------"""
            logger.log("Processing samples...")
            time_proc_samples_start = time.time()
            samples_data = self.sample_processor.process_samples(
                paths, log='all', log_prefix='train-')
            proc_samples_time = time.time() - time_proc_samples_start

            if type(paths) is list:
                self.log_diagnostics(paths, prefix='train-')
            else:
                self.log_diagnostics(sum(paths.values(), []), prefix='train-')

            """ ------------------ Policy Update ---------------------"""
            logger.log("Optimizing policy...")
            # This needs to take all samples_data so that it can construct graph for meta-optimization.
            time_optimization_step_start = time.time()
            self.algo.optimize_policy(samples_data)

            """ ------------------- Logging Stuff --------------------------"""
            logger.logkv('Itr', itr)
            logger.logkv('n_timesteps', self.sampler.total_timesteps_sampled)

            logger.logkv('Time-Optimization',
                         time.time() - time_optimization_step_start)
            logger.logkv('Time-SampleProc', np.sum(proc_samples_time))
            logger.logkv('Time-Sampling', sampling_time)

            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

    logger.log("Training finished")
    self.sess.close()
def __call__(
        self,
        exp_dir,
        policy_pickle,
        env_pickle,
        baseline_pickle,
        dynamics_model_pickle,
        feed_dict,
        queue_prev,
        queue,
        queue_next,
        remote,
        start_itr,
        n_itr,
        stop_cond,
        need_query,
        auto_push,
        config,
):
    time_start = time.time()

    self.name = current_process().name
    logger.configure(dir=exp_dir + '/' + self.name,
                     format_strs=['csv', 'stdout', 'log'],
                     snapshot_mode=self.snapshot_mode,
                     snapshot_gap=self.snapshot_gap)

    self.n_itr = n_itr
    self.queue_prev = queue_prev
    self.queue = queue
    self.queue_next = queue_next
    self.stop_cond = stop_cond

    # FIXME: specify CPU/GPU usage here

    import tensorflow as tf

    def _init_vars():
        sess = tf.get_default_session()
        sess.run(tf.initializers.global_variables())

    with tf.Session(config=config).as_default():

        self.construct_from_feed_dict(
            policy_pickle,
            env_pickle,
            baseline_pickle,
            dynamics_model_pickle,
            feed_dict,
        )

        _init_vars()

        # warm up
        self.itr_counter = start_itr
        if self.verbose:
            print('{} waiting for starting msg from trainer...'.format(self.name))
        assert remote.recv() == 'prepare start'
        self.prepare_start()
        remote.send('loop ready')
        logger.dumpkvs()
        logger.log("\n============== {} is ready =============".format(self.name))

        assert remote.recv() == 'start loop'
        total_push, total_synch, total_step = 0, 0, 0
        while not self.stop_cond.is_set():
            if self.verbose:
                logger.log("\n------------------------- {} starting new loop ------------------".format(
                    self.name))
            if need_query:
                # poll the previous worker for fresh data
                time_poll = time.time()
                queue_prev.put('push')
                time_poll = time.time() - time_poll
                logger.logkv('{}-TimePoll'.format(self.name), time_poll)
            do_push, do_synch, do_step = self.process_queue()
            # step
            if do_step:
                self.itr_counter += 1
                self.step()
                if auto_push:
                    do_push += 1
                    self.push()
                # Assuming auto-push is enabled for all workers
                assert do_push == 1
                assert do_step == 1

            total_push += do_push
            total_synch += do_synch
            total_step += do_step
            logger.logkv(self.name + '-TimeSoFar', time.time() - time_start)
            logger.logkv(self.name + '-TotalPush', total_push)
            logger.logkv(self.name + '-TotalSynch', total_synch)
            logger.logkv(self.name + '-TotalStep', total_step)
            if total_synch > 0:
                logger.logkv(self.name + '-StepPerSynch', total_step / total_synch)
            logger.dumpkvs()
            logger.log("\n========================== {} {}, total {} ===================".format(
                self.name,
                (do_push, do_synch, do_step),
                (total_push, total_synch, total_step),
            ))

        self.set_stop_cond()

        remote.send('loop done')

    logger.log("\n================== {} closed ===================".format(self.name))

    remote.send('worker closed')
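# For orientation, the trainer/worker handshake implied above, grounded in
# the message strings used in this file (see also the trainer's train()):
#
#   trainer -> worker: 'prepare start'    worker -> trainer: 'loop ready'
#   trainer -> worker: 'start loop'       ... worker loops until stop_cond ...
#   worker  -> trainer: 'loop done'       worker -> trainer: 'worker closed'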
def optimize(self, input_val_dict, verbose=False):
    """
    Carries out the optimization step

    Args:
        input_val_dict (dict): inputs for the optimization
        verbose (bool): logging verbosity
    """
    if verbose:
        logger.log("Start CG optimization")
        logger.log("computing loss before")
    loss_before = self.loss(input_val_dict)

    if verbose:
        logger.log("performing update")
        logger.log("computing gradient")
    gradient = self.gradient(input_val_dict)
    if verbose:
        logger.log("gradient computed")
        logger.log("computing descent direction")

    Hx = self._hvp_approach.build_eval(input_val_dict)
    descent_direction = conjugate_gradients(Hx, gradient, cg_iters=self._cg_iters)

    # step size chosen so the quadratic model of the constraint hits max_constraint_val
    initial_step_size = np.sqrt(
        2.0 * self._max_constraint_val *
        (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)))
    if np.isnan(initial_step_size):
        logger.log("Initial step size is NaN! Rejecting the step!")
        return
    initial_descent_step = initial_step_size * descent_direction
    if verbose:
        logger.log("descent direction computed")

    prev_params = self._target.get_param_values()
    prev_params_values = _flatten_params(prev_params)

    # backtracking line search over exponentially shrinking step sizes
    loss, constraint_val, n_iter, violated = 0, 0, 0, False
    for n_iter, ratio in enumerate(self._backtrack_ratio **
                                   np.arange(self._max_backtracks)):
        cur_step = ratio * initial_descent_step
        cur_params_values = prev_params_values - cur_step
        cur_params = _unflatten_params(cur_params_values,
                                       params_example=prev_params)
        self._target.set_params(cur_params)

        loss, constraint_val = self.loss(input_val_dict), self.constraint_val(input_val_dict)
        if loss < loss_before and constraint_val <= self._max_constraint_val:
            break

    """ ------------------- Logging Stuff -------------------------- """
    if np.isnan(loss):
        violated = True
        logger.log("Line search violated because loss is NaN")
    if np.isnan(constraint_val):
        violated = True
        logger.log("Line search violated because constraint %s is NaN" %
                   self._constraint_name)
    if loss >= loss_before:
        violated = True
        logger.log("Line search violated because loss not improving")
    if constraint_val >= self._max_constraint_val:
        violated = True
        logger.log("Line search violated because constraint %s is violated" %
                   self._constraint_name)

    if violated and not self._accept_violation:
        logger.log("Line search condition violated. Rejecting the step!")
        self._target.set_params(prev_params)

    if verbose:
        logger.log("backtrack iters: %d" % n_iter)
        logger.log("computing loss after")
        logger.log("optimization finished")
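# For reference, a minimal conjugate-gradient sketch of what
# conjugate_gradients() above is assumed to compute: an approximate solution
# of H x = g given only the Hessian-vector product Hx(v). This is the
# textbook algorithm, not the repo's implementation.
import numpy as np

def conjugate_gradients_sketch(Hx, g, cg_iters=10, residual_tol=1e-10):
    x = np.zeros_like(g)
    r = g.copy()   # residual g - H x (x starts at zero)
    p = r.copy()   # search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        z = Hx(p)
        alpha = rdotr / (p.dot(z) + 1e-8)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x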
def train(self):
    """
    Trains policy on env using algo

    Pseudocode:
        for itr in n_itr:
            for step in num_inner_grad_steps:
                sampler.sample()
                algo.compute_updated_dists()
            algo.optimize_policy()
            sampler.update_goals()
    """
    with self.sess.as_default() as sess:

        # initialize uninitialized vars (only initialize vars that were not loaded)
        # uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
        # sess.run(tf.variables_initializer(uninit_vars))
        sess.run(tf.global_variables_initializer())

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)

            time_env_sampling_start = time.time()

            if itr == 0:
                logger.log("Obtaining random samples from the environment...")
                self.env_sampler.total_samples *= self.num_rollouts_per_iter
                env_paths = self.env_sampler.obtain_samples(
                    log=True,
                    random=self.initial_random_samples,
                    log_prefix='Data-EnvSampler-',
                    verbose=True)
                self.env_sampler.total_samples /= self.num_rollouts_per_iter

            time_env_samp_proc = time.time()
            samples_data = self.dynamics_sample_processor.process_samples(
                env_paths, log=True, log_prefix='Data-EnvTrajs-')
            self.env.log_diagnostics(env_paths, prefix='Data-EnvTrajs-')
            logger.record_tabular('Data-TimeEnvSampleProc',
                                  time.time() - time_env_samp_proc)

            buffer = samples_data if self.sample_from_buffer else None

            ''' --------------- fit dynamics model --------------- '''

            logger.log("Training dynamics model for %i epochs ..." %
                       self.dynamics_model_max_epochs)
            time_fit_start = time.time()
            self.dynamics_model.fit(samples_data['observations'],
                                    samples_data['actions'],
                                    samples_data['next_observations'],
                                    epochs=self.dynamics_model_max_epochs,
                                    verbose=False,
                                    log_tabular=True,
                                    prefix='Model-')
            logger.record_tabular('Model-TimeModelFit', time.time() - time_fit_start)

            env_paths = []
            for id_rollout in range(self.num_rollouts_per_iter):
                times_dyn_sampling = []
                times_dyn_sample_processing = []
                times_optimization = []
                times_step = []

                grad_steps_per_rollout = self.grad_steps_per_rollout
                for step in range(grad_steps_per_rollout):
                    # logger.log("\n ---------------- Grad-Step %d ----------------" %
                    #            int(grad_steps_per_rollout * itr * self.num_rollouts_per_iter
                    #                + id_rollout * grad_steps_per_rollout + step))
                    step_start_time = time.time()

                    """ -------------------- Sampling --------------------------"""
                    logger.log("Obtaining samples from the model...")
                    time_env_sampling_start = time.time()
                    paths = self.model_sampler.obtain_samples(
                        log=True, log_prefix='Policy-', buffer=buffer)
                    sampling_time = time.time() - time_env_sampling_start

                    """ ----------------- Processing Samples ---------------------"""
                    logger.log("Processing samples from the model...")
                    time_proc_samples_start = time.time()
                    samples_data = self.model_sample_processor.process_samples(
                        paths, log='all', log_prefix='Policy-')
                    proc_samples_time = time.time() - time_proc_samples_start

                    if type(paths) is list:
                        self.log_diagnostics(paths, prefix='Policy-')
                    else:
                        self.log_diagnostics(sum(paths.values(), []), prefix='Policy-')

                    """ ------------------ Policy Update ---------------------"""
                    logger.log("Optimizing policy...")
                    time_optimization_step_start = time.time()
                    self.algo.optimize_policy(samples_data)
                    optimization_time = time.time() - time_optimization_step_start

                    times_dyn_sampling.append(sampling_time)
                    times_dyn_sample_processing.append(proc_samples_time)
                    times_optimization.append(optimization_time)
                    times_step.append(time.time() - step_start_time)

                logger.log("Obtaining random samples from the environment...")
                env_paths.extend(
                    self.env_sampler.obtain_samples(
                        log=True, log_prefix='Data-EnvSampler-', verbose=True))
                logger.record_tabular('Data-TimeEnvSampling',
                                      time.time() - time_env_sampling_start)
                logger.log("Processing environment samples...")

            """ ------------------- Logging Stuff --------------------------"""
            logger.logkv('Iteration', itr)
            logger.logkv('n_timesteps', self.env_sampler.total_timesteps_sampled)
            logger.logkv('Policy-TimeSampleProc', np.sum(times_dyn_sample_processing))
            logger.logkv('Policy-TimeSampling', np.sum(times_dyn_sampling))
            logger.logkv('Policy-TimeAlgoOpt', np.sum(times_optimization))
            logger.logkv('Policy-TimeStep', np.sum(times_step))
            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

        logger.logkv('Trainer-TimeTotal', time.time() - start_time)
        logger.log("Training finished")
    self.sess.close()
def fit(self,
        obs,
        act,
        obs_next,
        epochs=1000,
        compute_normalization=True,
        verbose=False,
        valid_split_ratio=None,
        rolling_average_persitency=None,
        log_tabular=False,
        early_stopping=True,
        prefix=''):
    """
    Fits the NN dynamics model

    :param obs: observations - numpy array of shape (n_samples, ndim_obs)
    :param act: actions - numpy array of shape (n_samples, ndim_act)
    :param obs_next: observations after taking action - numpy array of shape (n_samples, ndim_obs)
    :param epochs: number of training epochs
    :param compute_normalization: boolean indicating whether normalization shall be (re-)computed given the data
    :param valid_split_ratio: relative size of validation split (float between 0.0 and 1.0)
    :param verbose: logging verbosity
    """
    assert obs.ndim == 2 and obs.shape[1] == self.obs_space_dims
    assert obs_next.ndim == 2 and obs_next.shape[1] == self.obs_space_dims
    assert act.ndim == 2 and act.shape[1] == self.action_space_dims

    if valid_split_ratio is None:
        valid_split_ratio = self.valid_split_ratio
    if rolling_average_persitency is None:
        rolling_average_persitency = self.rolling_average_persitency

    assert 1 > valid_split_ratio >= 0

    sess = tf.get_default_session()

    # split into valid and test set
    delta = obs_next - obs
    obs_train, act_train, delta_train, obs_test, act_test, delta_test = train_test_split(
        obs, act, delta, test_split_ratio=valid_split_ratio)

    if self._dataset_test is None:
        self._dataset_test = dict(obs=obs_test, act=act_test, delta=delta_test)
        self._dataset_train = dict(obs=obs_train, act=act_train, delta=delta_train)
    else:
        # append to the buffers, dropping the oldest samples beyond buffer_size
        n_test_new_samples = len(obs_test)
        n_max_test = self.buffer_size - n_test_new_samples
        n_train_new_samples = len(obs_train)
        n_max_train = self.buffer_size - n_train_new_samples

        self._dataset_test['obs'] = np.concatenate(
            [self._dataset_test['obs'][-n_max_test:], obs_test])
        self._dataset_test['act'] = np.concatenate(
            [self._dataset_test['act'][-n_max_test:], act_test])
        self._dataset_test['delta'] = np.concatenate(
            [self._dataset_test['delta'][-n_max_test:], delta_test])

        self._dataset_train['obs'] = np.concatenate(
            [self._dataset_train['obs'][-n_max_train:], obs_train])
        self._dataset_train['act'] = np.concatenate(
            [self._dataset_train['act'][-n_max_train:], act_train])
        self._dataset_train['delta'] = np.concatenate(
            [self._dataset_train['delta'][-n_max_train:], delta_train])

    # create data queue
    if self.next_batch is None:
        self.next_batch, self.iterator = self._data_input_fn(
            self._dataset_train['obs'],
            self._dataset_train['act'],
            self._dataset_train['delta'],
            batch_size=self.batch_size,
            buffer_size=self.buffer_size)

    valid_loss_rolling_average = None

    if (self.normalization is None or compute_normalization) and self.normalize_input:
        self.compute_normalization(self._dataset_train['obs'],
                                   self._dataset_train['act'],
                                   self._dataset_train['delta'])

    if self.normalize_input:
        # normalize data
        obs_train, act_train, delta_train = self._normalize_data(
            self._dataset_train['obs'], self._dataset_train['act'],
            self._dataset_train['delta'])
        assert obs_train.ndim == act_train.ndim == delta_train.ndim == 2
    else:
        obs_train = self._dataset_train['obs']
        act_train = self._dataset_train['act']
        delta_train = self._dataset_train['delta']

    # Training loop
    for epoch in range(epochs):

        # initialize data queue
        sess.run(self.iterator.initializer,
                 feed_dict={
                     self.obs_dataset_ph: obs_train,
                     self.act_dataset_ph: act_train,
                     self.delta_dataset_ph: delta_train
                 })

        batch_losses = []
        while True:
            try:
                obs_batch, act_batch, delta_batch = sess.run(self.next_batch)

                # run train op
                batch_loss, _ = sess.run(
                    [self.loss, self.train_op],
                    feed_dict={
                        self.obs_ph: obs_batch,
                        self.act_ph: act_batch,
                        self.delta_ph: delta_batch
                    })
                batch_losses.append(batch_loss)

            except tf.errors.OutOfRangeError:
                # compute validation loss
                if self.normalize_input:
                    # normalize data
                    obs_test, act_test, delta_test = self._normalize_data(
                        self._dataset_test['obs'], self._dataset_test['act'],
                        self._dataset_test['delta'])
                    assert obs_test.ndim == act_test.ndim == delta_test.ndim == 2
                else:
                    obs_test = self._dataset_test['obs']
                    act_test = self._dataset_test['act']
                    delta_test = self._dataset_test['delta']

                valid_loss = sess.run(self.loss,
                                      feed_dict={
                                          self.obs_ph: obs_test,
                                          self.act_ph: act_test,
                                          self.delta_ph: delta_test
                                      })

                if valid_loss_rolling_average is None:
                    # set initial rolling averages high to avoid stopping too early
                    valid_loss_rolling_average = 1.5 * valid_loss
                    valid_loss_rolling_average_prev = 2.0 * valid_loss

                valid_loss_rolling_average = \
                    rolling_average_persitency * valid_loss_rolling_average + \
                    (1.0 - rolling_average_persitency) * valid_loss

                if verbose:
                    logger.log(
                        "Training NNDynamicsModel - finished epoch %i -- "
                        "train loss: %.4f valid loss: %.4f valid_loss_mov_avg: %.4f"
                        % (epoch, float(np.mean(batch_losses)), valid_loss,
                           valid_loss_rolling_average))
                break

        if early_stopping and valid_loss_rolling_average_prev < valid_loss_rolling_average:
            logger.log(
                'Stopping DynamicsEnsemble training since valid_loss_rolling_average stopped decreasing')
            break
        valid_loss_rolling_average_prev = valid_loss_rolling_average
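# The early-stopping rule above, distilled into a hedged stand-alone sketch:
# track an exponential moving average (EMA) of the validation loss and stop
# once the average rises between epochs. The seeding factors (1.5x, 2.0x)
# match the code above and assume positive losses; the ensemble variant
# further below additionally special-cases negative losses.
def ema_early_stop(valid_losses, persistency=0.99):
    """Return the epoch index at which training would stop, or None."""
    ema = ema_prev = None
    for epoch, loss in enumerate(valid_losses):
        if ema is None:
            ema, ema_prev = 1.5 * loss, 2.0 * loss  # seed high: no stop at epoch 0
        ema_prev, ema = ema, persistency * ema + (1.0 - persistency) * loss
        if ema_prev < ema:
            return epoch
    return None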
def update_buffer(self, obs, act, obs_next, valid_split_ratio=None, check_init=True):
    assert obs.ndim == 2 and obs.shape[1] == self.obs_space_dims
    assert obs_next.ndim == 2 and obs_next.shape[1] == self.obs_space_dims
    assert act.ndim == 2 and act.shape[1] == self.action_space_dims

    self.timesteps_counter += obs.shape[0]

    if valid_split_ratio is None:
        valid_split_ratio = self.valid_split_ratio

    assert 1 > valid_split_ratio >= 0

    # split into valid and test set, independently per ensemble member
    obs_train_batches = []
    act_train_batches = []
    delta_train_batches = []
    obs_test_batches = []
    act_test_batches = []
    delta_test_batches = []

    delta = obs_next - obs
    for i in range(self.num_models):
        obs_train, act_train, delta_train, obs_test, act_test, delta_test = train_test_split(
            obs, act, delta, test_split_ratio=valid_split_ratio)
        obs_train_batches.append(obs_train)
        act_train_batches.append(act_train)
        delta_train_batches.append(delta_train)
        obs_test_batches.append(obs_test)
        act_test_batches.append(act_test)
        delta_test_batches.append(delta_test)

    # create data queue
    # This branch should be entered exactly once
    if check_init and self._dataset_test is None:
        self._dataset_test = dict(obs=obs_test_batches,
                                  act=act_test_batches,
                                  delta=delta_test_batches)
        self._dataset_train = dict(obs=obs_train_batches,
                                   act=act_train_batches,
                                   delta=delta_train_batches)

        # assert self.next_batch is None
        self.next_batch, self.iterator = self._data_input_fn(
            self._dataset_train['obs'],
            self._dataset_train['act'],
            self._dataset_train['delta'],
            batch_size=self.batch_size)

        # assert self.normalization is None
        if self.normalize_input:
            self.compute_normalization(self._dataset_train['obs'],
                                       self._dataset_train['act'],
                                       self._dataset_train['delta'])
    else:
        # append to the per-model buffers, dropping the oldest samples
        n_test_new_samples = len(obs_test_batches[0])
        n_max_test = self.buffer_size_test - n_test_new_samples
        n_train_new_samples = len(obs_train_batches[0])
        n_max_train = self.buffer_size_train - n_train_new_samples

        for i in range(self.num_models):
            self._dataset_test['obs'][i] = np.concatenate([
                self._dataset_test['obs'][i][-n_max_test:],
                obs_test_batches[i]
            ])
            self._dataset_test['act'][i] = np.concatenate([
                self._dataset_test['act'][i][-n_max_test:],
                act_test_batches[i]
            ])
            self._dataset_test['delta'][i] = np.concatenate([
                self._dataset_test['delta'][i][-n_max_test:],
                delta_test_batches[i]
            ])

            self._dataset_train['obs'][i] = np.concatenate([
                self._dataset_train['obs'][i][-n_max_train:],
                obs_train_batches[i]
            ])
            self._dataset_train['act'][i] = np.concatenate([
                self._dataset_train['act'][i][-n_max_train:],
                act_train_batches[i]
            ])
            self._dataset_train['delta'][i] = np.concatenate([
                self._dataset_train['delta'][i][-n_max_train:],
                delta_train_batches[i]
            ])

    logger.log('Model has dataset_train, dataset_test with size {}, {}'.format(
        len(self._dataset_train['obs'][0]), len(self._dataset_test['obs'][0])))
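# Assumed semantics of the train_test_split() call above (a sketch, not the
# repo's implementation): an independent random split per call, so each of
# the num_models ensemble members ends up with its own train/validation
# partition of the same data.
import numpy as np

def train_test_split_sketch(obs, act, delta, test_split_ratio=0.2):
    n = obs.shape[0]
    idx = np.random.permutation(n)
    n_test = int(n * test_split_ratio)
    test, train = idx[:n_test], idx[n_test:]
    return (obs[train], act[train], delta[train],
            obs[test], act[test], delta[test])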
def train(self):
    """
    Trains policy on env using algo

    Pseudocode:
        for itr in n_itr:
            for step in num_inner_grad_steps:
                sampler.sample()
                algo.compute_updated_dists()
            algo.optimize_policy()
            sampler.update_goals()
    """
    with self.sess.as_default() as sess:

        # initialize uninitialized vars (only initialize vars that were not loaded)
        # uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
        # sess.run(tf.variables_initializer(uninit_vars))
        sess.run(tf.global_variables_initializer())

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)

            time_env_sampling_start = time.time()

            if self.initial_random_samples and itr == 0:
                logger.log("Obtaining random samples from the environment...")
                env_paths = self.env_sampler.obtain_samples(
                    log=True, random=True, log_prefix='Data-EnvSampler-')
            else:
                logger.log("Obtaining samples from the environment using the policy...")
                env_paths = self.env_sampler.obtain_samples(
                    log=True, log_prefix='Data-EnvSampler-')

            # Add sleeping time to match parallel experiment
            # time.sleep(10)

            logger.record_tabular('Data-TimeEnvSampling',
                                  time.time() - time_env_sampling_start)
            logger.log("Processing environment samples...")

            # first processing just for logging purposes
            time_env_samp_proc = time.time()
            samples_data = self.dynamics_sample_processor.process_samples(
                env_paths, log=True, log_prefix='Data-EnvTrajs-')
            self.env.log_diagnostics(env_paths, prefix='Data-EnvTrajs-')
            logger.record_tabular('Data-TimeEnvSampleProc',
                                  time.time() - time_env_samp_proc)

            ''' --------------- fit dynamics model --------------- '''

            time_fit_start = time.time()
            self.dynamics_model.update_buffer(samples_data['observations'],
                                              samples_data['actions'],
                                              samples_data['next_observations'],
                                              check_init=True)
            buffer = None if not self.sample_from_buffer else samples_data
            logger.record_tabular('Model-TimeModelFit', time.time() - time_fit_start)

            ''' --------------- MAML steps --------------- '''
            times_dyn_sampling = []
            times_dyn_sample_processing = []
            times_optimization = []
            times_step = []

            remaining_model_idx = list(range(self.dynamics_model.num_models))
            valid_loss_rolling_average_prev = None
            with_new_data = True
            for id_step in range(self.repeat_steps):
                for epoch in range(self.num_epochs_per_step):
                    logger.log("Training dynamics model for 1 epoch ...")
                    remaining_model_idx, valid_loss_rolling_average = \
                        self.dynamics_model.fit_one_epoch(
                            remaining_model_idx,
                            valid_loss_rolling_average_prev,
                            with_new_data,
                            log_tabular=True,
                            prefix='Model-')
                    # carry the rolling average forward so the next epoch's
                    # early-stopping comparison sees it (mirrors the
                    # asynchronous model worker)
                    valid_loss_rolling_average_prev = valid_loss_rolling_average
                    with_new_data = False

                for step in range(self.num_grad_policy_per_step):
                    logger.log("\n ---------------- Grad-Step %d ----------------" %
                               int(itr * self.repeat_steps * self.num_grad_policy_per_step +
                                   id_step * self.num_grad_policy_per_step + step))
                    step_start_time = time.time()

                    """ -------------------- Sampling --------------------------"""
                    logger.log("Obtaining samples from the model...")
                    time_env_sampling_start = time.time()
                    paths = self.model_sampler.obtain_samples(
                        log=True, log_prefix='Policy-', buffer=buffer)
                    sampling_time = time.time() - time_env_sampling_start

                    """ ----------------- Processing Samples ---------------------"""
                    logger.log("Processing samples from the model...")
                    time_proc_samples_start = time.time()
                    samples_data = self.model_sample_processor.process_samples(
                        paths, log='all', log_prefix='Policy-')
                    proc_samples_time = time.time() - time_proc_samples_start

                    if type(paths) is list:
                        self.log_diagnostics(paths, prefix='Policy-')
                    else:
                        self.log_diagnostics(sum(paths.values(), []), prefix='Policy-')

                    """ ------------------ Policy Update ---------------------"""
                    logger.log("Optimizing policy...")
                    # This needs to take all samples_data so that it can construct graph for meta-optimization.
                    time_optimization_step_start = time.time()
                    self.algo.optimize_policy(samples_data)
                    optimization_time = time.time() - time_optimization_step_start

                    times_dyn_sampling.append(sampling_time)
                    times_dyn_sample_processing.append(proc_samples_time)
                    times_optimization.append(optimization_time)
                    times_step.append(time.time() - step_start_time)

            """ ------------------- Logging Stuff --------------------------"""
            logger.logkv('Iteration', itr)
            logger.logkv('n_timesteps', self.env_sampler.total_timesteps_sampled)
            logger.logkv('Policy-TimeSampleProc', np.sum(times_dyn_sample_processing))
            logger.logkv('Policy-TimeSampling', np.sum(times_dyn_sampling))
            logger.logkv('Policy-TimeAlgoOpt', np.sum(times_optimization))
            logger.logkv('Policy-TimeStep', np.sum(times_step))
            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

        logger.logkv('Trainer-TimeTotal', time.time() - start_time)
        logger.log("Training finished")
    self.sess.close()
def train(self):
    """
    Trains policy on env using algo

    Pseudocode:
        for itr in n_itr:
            for step in num_inner_grad_steps:
                sampler.sample()
                algo.compute_updated_dists()
            algo.optimize_policy()
            sampler.update_goals()
    """
    with self.sess.as_default() as sess:

        # initialize uninitialized vars (only initialize vars that were not loaded)
        # uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
        # sess.run(tf.variables_initializer(uninit_vars))
        sess.run(tf.global_variables_initializer())

        # anneal the number of meta steps per iteration if a (start, end) tuple is given
        if type(self.meta_steps_per_iter) is tuple:
            meta_steps_per_iter = np.linspace(self.meta_steps_per_iter[0],
                                              self.meta_steps_per_iter[1],
                                              self.n_itr).astype(int)
        else:
            meta_steps_per_iter = [self.meta_steps_per_iter] * self.n_itr

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)

            time_env_sampling_start = time.time()

            if self.initial_random_samples and itr == 0:
                logger.log("Obtaining random samples from the environment...")
                env_paths = self.env_sampler.obtain_samples(
                    log=True, random=True, log_prefix='EnvSampler-')
            else:
                logger.log("Obtaining samples from the environment using the policy...")
                env_paths = self.env_sampler.obtain_samples(
                    log=True, log_prefix='EnvSampler-')

            logger.record_tabular('Time-EnvSampling',
                                  time.time() - time_env_sampling_start)
            logger.log("Processing environment samples...")

            # first processing just for logging purposes
            time_env_samp_proc = time.time()

            # subsample num_rollouts_per_iter rollouts from the collected paths
            if type(env_paths) is dict or type(env_paths) is collections.OrderedDict:
                env_paths = list(env_paths.values())
                idxs = np.random.choice(range(len(env_paths)),
                                        size=self.num_rollouts_per_iter,
                                        replace=False)
                env_paths = sum([env_paths[idx] for idx in idxs], [])
            elif type(env_paths) is list:
                idxs = np.random.choice(range(len(env_paths)),
                                        size=self.num_rollouts_per_iter,
                                        replace=False)
                env_paths = [env_paths[idx] for idx in idxs]
            else:
                raise TypeError('env_paths must be a dict or a list, got {}'.format(
                    type(env_paths)))

            samples_data = self.dynamics_sample_processor.process_samples(
                env_paths, log=True, log_prefix='EnvTrajs-')

            self.env.log_diagnostics(env_paths, prefix='EnvTrajs-')

            logger.record_tabular('Time-EnvSampleProc',
                                  time.time() - time_env_samp_proc)

            ''' --------------- fit dynamics model --------------- '''

            time_fit_start = time.time()
            logger.log("Training dynamics model for %i epochs ..." %
                       self.dynamics_model_max_epochs)
            self.dynamics_model.fit(samples_data['observations'],
                                    samples_data['actions'],
                                    samples_data['next_observations'],
                                    epochs=self.dynamics_model_max_epochs,
                                    verbose=True,
                                    log_tabular=True)

            buffer = None if not self.sample_from_buffer else samples_data

            logger.record_tabular('Time-ModelFit', time.time() - time_fit_start)

            ''' ------------ log real performance --------------- '''

            if self.log_real_performance:
                logger.log("Evaluating the performance of the real policy")
                self.policy.switch_to_pre_update()
                env_paths = self.env_sampler.obtain_samples(
                    log=True, log_prefix='PrePolicy-')
                samples_data = self.model_sample_processor.process_samples(
                    env_paths, log='all', log_prefix='PrePolicy-')
                self.algo._adapt(samples_data)
                env_paths = self.env_sampler.obtain_samples(
                    log=True, log_prefix='PostPolicy-')
                self.model_sample_processor.process_samples(
                    env_paths, log='all', log_prefix='PostPolicy-')

            ''' --------------- MAML steps --------------- '''
            times_dyn_sampling = []
            times_dyn_sample_processing = []
            times_meta_sampling = []
            times_inner_step = []
            times_total_inner_step = []
            times_outer_step = []
            times_maml_steps = []

            for meta_itr in range(meta_steps_per_iter[itr]):

                logger.log("\n ---------------- Meta-Step %d ----------------" %
                           int(sum(meta_steps_per_iter[:itr]) + meta_itr))
                self.policy.switch_to_pre_update()  # Switch to pre-update policy

                all_samples_data, all_paths = [], []
                list_sampling_time, list_inner_step_time, list_outer_step_time, list_proc_samples_time = [], [], [], []
                time_maml_steps_start = time.time()
                start_total_inner_time = time.time()
                for step in range(self.num_inner_grad_steps + 1):
                    logger.log("\n ** Adaptation-Step %d **" % step)

                    """ -------------------- Sampling --------------------------"""
                    logger.log("Obtaining samples...")
                    time_env_sampling_start = time.time()
                    paths = self.model_sampler.obtain_samples(
                        log=True, log_prefix='Step_%d-' % step, buffer=buffer)
                    list_sampling_time.append(time.time() - time_env_sampling_start)
                    all_paths.append(paths)

                    """ ----------------- Processing Samples ---------------------"""
                    logger.log("Processing samples...")
                    time_proc_samples_start = time.time()
                    samples_data = self.model_sample_processor.process_samples(
                        paths, log='all', log_prefix='Step_%d-' % step)
                    all_samples_data.append(samples_data)
                    list_proc_samples_time.append(time.time() - time_proc_samples_start)

                    self.log_diagnostics(sum(list(paths.values()), []),
                                         prefix='Step_%d-' % step)

                    """ ------------------- Inner Policy Update --------------------"""
                    time_inner_step_start = time.time()
                    if step < self.num_inner_grad_steps:
                        logger.log("Computing inner policy updates...")
                        self.algo._adapt(samples_data)
                    list_inner_step_time.append(time.time() - time_inner_step_start)
                total_inner_time = time.time() - start_total_inner_time

                time_maml_opt_start = time.time()

                """ ------------------ Outer Policy Update ---------------------"""
                logger.log("Optimizing policy...")
                # This needs to take all samples_data so that it can construct graph for meta-optimization.
                time_outer_step_start = time.time()
                self.algo.optimize_policy(all_samples_data)

                times_inner_step.append(list_inner_step_time)
                times_total_inner_step.append(total_inner_time)
                times_outer_step.append(time.time() - time_outer_step_start)
                times_meta_sampling.append(np.sum(list_sampling_time))
                times_dyn_sampling.append(list_sampling_time)
                times_dyn_sample_processing.append(list_proc_samples_time)
                times_maml_steps.append(time.time() - time_maml_steps_start)

            """ ------------------- Logging Stuff --------------------------"""
            logger.logkv('Itr', itr)
            if self.log_real_performance:
                logger.logkv('n_timesteps',
                             self.env_sampler.total_timesteps_sampled /
                             (3 * self.policy.meta_batch_size) *
                             self.num_rollouts_per_iter)
            else:
                logger.logkv('n_timesteps',
                             self.env_sampler.total_timesteps_sampled /
                             self.policy.meta_batch_size *
                             self.num_rollouts_per_iter)

            logger.logkv('AvgTime-OuterStep', np.mean(times_outer_step))
            logger.logkv('AvgTime-InnerStep', np.mean(times_inner_step))
            logger.logkv('AvgTime-TotalInner', np.mean(times_total_inner_step))
            logger.logkv('AvgTime-SampleProc', np.mean(times_dyn_sample_processing))
            logger.logkv('AvgTime-Sampling', np.mean(times_dyn_sampling))
            logger.logkv('AvgTime-MAMLSteps', np.mean(times_maml_steps))

            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

    logger.log("Training finished")
    self.sess.close()
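# Worked example of the meta_steps_per_iter annealing above (hypothetical
# values): a (start, end) tuple is interpolated linearly over n_itr
# iterations, e.g. with n_itr = 5:
#
#   np.linspace(30, 10, 5).astype(int)  ->  [30, 25, 20, 15, 10]
#
# so early iterations take more meta gradient steps per freshly fitted
# dynamics model than later ones.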
def fit_one_epoch(self,
                  remaining_model_idx,
                  valid_loss_rolling_average_prev,
                  with_new_data,
                  compute_normalization=True,
                  rolling_average_persitency=None,
                  verbose=False,
                  log_tabular=False,
                  prefix=''):

    if rolling_average_persitency is None:
        rolling_average_persitency = self.rolling_average_persitency

    sess = tf.get_default_session()

    if with_new_data:
        if compute_normalization and self.normalize_input:
            self.compute_normalization(self._dataset_train['obs'],
                                       self._dataset_train['act'],
                                       self._dataset_train['delta'])
        self.used_timesteps_counter += len(self._dataset_train['obs'][0])

    if self.normalize_input:
        # normalize data
        obs_train, act_train, delta_train = self._normalize_data(
            self._dataset_train['obs'], self._dataset_train['act'],
            self._dataset_train['delta'])
    else:
        obs_train, act_train, delta_train = self._dataset_train['obs'], self._dataset_train['act'], \
                                            self._dataset_train['delta']

    valid_loss_rolling_average = valid_loss_rolling_average_prev
    assert remaining_model_idx is not None
    train_op_to_do = [
        op for idx, op in enumerate(self.train_op_model_batches)
        if idx in remaining_model_idx
    ]

    # initialize data queue
    feed_dict = dict(
        list(zip(self.obs_batches_dataset_ph, obs_train)) +
        list(zip(self.act_batches_dataset_ph, act_train)) +
        list(zip(self.delta_batches_dataset_ph, delta_train)))
    sess.run(self.iterator.initializer, feed_dict=feed_dict)

    # preparations for recording training stats
    batch_losses = []

    """ ------- Looping through the shuffled and batched dataset for one epoch -------"""
    while True:
        try:
            obs_act_delta = sess.run(self.next_batch)
            obs_batch_stack = np.concatenate(obs_act_delta[:self.num_models], axis=0)
            act_batch_stack = np.concatenate(
                obs_act_delta[self.num_models:2 * self.num_models], axis=0)
            delta_batch_stack = np.concatenate(
                obs_act_delta[2 * self.num_models:], axis=0)

            # run train op
            batch_loss_train_ops = sess.run(
                self.loss_model_batches + train_op_to_do,
                feed_dict={
                    self.obs_model_batches_stack_ph: obs_batch_stack,
                    self.act_model_batches_stack_ph: act_batch_stack,
                    self.delta_model_batches_stack_ph: delta_batch_stack
                })

            batch_loss = np.array(batch_loss_train_ops[:self.num_models])
            batch_losses.append(batch_loss)

        except tf.errors.OutOfRangeError:
            if self.normalize_input:
                # TODO: if not with_new_data, don't recompute
                # normalize data
                obs_test, act_test, delta_test = self._normalize_data(
                    self._dataset_test['obs'], self._dataset_test['act'],
                    self._dataset_test['delta'])
            else:
                obs_test, act_test, delta_test = self._dataset_test['obs'], self._dataset_test['act'], \
                                                 self._dataset_test['delta']

            obs_test_stack = np.concatenate(obs_test, axis=0)
            act_test_stack = np.concatenate(act_test, axis=0)
            delta_test_stack = np.concatenate(delta_test, axis=0)

            # compute validation loss
            valid_loss = sess.run(
                self.loss_model_batches,
                feed_dict={
                    self.obs_model_batches_stack_ph: obs_test_stack,
                    self.act_model_batches_stack_ph: act_test_stack,
                    self.delta_model_batches_stack_ph: delta_test_stack
                })
            valid_loss = np.array(valid_loss)

            if valid_loss_rolling_average is None:
                # set initial rolling averages high to avoid stopping too early
                valid_loss_rolling_average = 1.5 * valid_loss
                valid_loss_rolling_average_prev = 2.0 * valid_loss
                for i in range(len(valid_loss)):
                    if valid_loss[i] < 0:
                        # negative losses need the opposite scaling
                        valid_loss_rolling_average[i] = valid_loss[i] / 1.5
                        valid_loss_rolling_average_prev[i] = valid_loss[i] / 2.0

            valid_loss_rolling_average = rolling_average_persitency * valid_loss_rolling_average \
                                         + (1.0 - rolling_average_persitency) * valid_loss

            if verbose:
                str_mean_batch_losses = ' '.join(
                    ['%.4f' % x for x in np.mean(batch_losses, axis=0)])
                str_valid_loss = ' '.join(['%.4f' % x for x in valid_loss])
                str_valid_loss_rolling_averge = ' '.join(
                    ['%.4f' % x for x in valid_loss_rolling_average])
                logger.log(
                    "Training NNDynamicsModel - finished one epoch\n"
                    "train loss: %s\nvalid loss: %s\nvalid_loss_mov_avg: %s" %
                    (str_mean_batch_losses, str_valid_loss,
                     str_valid_loss_rolling_averge))
            break

    # iterate over a copy, since models are removed from the list inside the loop
    for i in list(remaining_model_idx):
        if valid_loss_rolling_average_prev[i] < valid_loss_rolling_average[i]:
            remaining_model_idx.remove(i)
            logger.log(
                'Stop model {} since its valid_loss_rolling_average stopped decreasing'.format(i))

    """ ------- Tabular Logging ------- """
    if log_tabular:
        logger.logkv(prefix + 'TimeStepsCtr', self.timesteps_counter)
        logger.logkv(prefix + 'UsedTimeStepsCtr', self.used_timesteps_counter)
        logger.logkv(prefix + 'AvgSampleUsage',
                     self.used_timesteps_counter / self.timesteps_counter)
        logger.logkv(prefix + 'NumModelRemaining', len(remaining_model_idx))
        logger.logkv(prefix + 'AvgTrainLoss', np.mean(batch_losses))
        logger.logkv(prefix + 'AvgValidLoss', np.mean(valid_loss))
        logger.logkv(prefix + 'AvgValidLossRoll',
                     np.mean(valid_loss_rolling_average))

    return remaining_model_idx, valid_loss_rolling_average
def fit(self,
        obs,
        act,
        obs_next,
        epochs=1000,
        compute_normalization=True,
        valid_split_ratio=None,
        rolling_average_persitency=None,
        verbose=False,
        log_tabular=False,
        prefix=''):
    """
    Fits the NN dynamics model

    :param obs: observations - numpy array of shape (n_samples, ndim_obs)
    :param act: actions - numpy array of shape (n_samples, ndim_act)
    :param obs_next: observations after taking action - numpy array of shape (n_samples, ndim_obs)
    :param epochs: number of training epochs
    :param compute_normalization: boolean indicating whether normalization shall be (re-)computed given the data
    :param valid_split_ratio: relative size of validation split (float between 0.0 and 1.0)
    :param log_tabular: (boolean) whether to log training stats in tabular format
    :param verbose: logging verbosity
    """
    if rolling_average_persitency is None:
        rolling_average_persitency = self.rolling_average_persitency

    sess = tf.get_default_session()

    if obs is not None:
        self.update_buffer(obs, act, obs_next, valid_split_ratio,
                           compute_normalization)

    if compute_normalization and self.normalize_input:
        self.compute_normalization(self._dataset_train['obs'],
                                   self._dataset_train['act'],
                                   self._dataset_train['delta'])

    if self.normalize_input:
        # normalize data
        obs_train, act_train, delta_train = self._normalize_data(
            self._dataset_train['obs'], self._dataset_train['act'],
            self._dataset_train['delta'])
    else:
        obs_train, act_train, delta_train = self._dataset_train['obs'], self._dataset_train['act'], \
                                            self._dataset_train['delta']

    valid_loss_rolling_average = None
    train_op_to_do = self.train_op_model_batches
    idx_to_remove = []
    epoch_times = []
    epochs_per_model = []

    """ ------- Looping over training epochs ------- """
    for epoch in range(epochs):

        # initialize data queue
        feed_dict = dict(
            list(zip(self.obs_batches_dataset_ph, obs_train)) +
            list(zip(self.act_batches_dataset_ph, act_train)) +
            list(zip(self.delta_batches_dataset_ph, delta_train)))
        sess.run(self.iterator.initializer, feed_dict=feed_dict)

        # preparations for recording training stats
        epoch_start_time = time.time()
        batch_losses = []

        """ ------- Looping through the shuffled and batched dataset for one epoch -------"""
        while True:
            try:
                obs_act_delta = sess.run(self.next_batch)
                obs_batch_stack = np.concatenate(obs_act_delta[:self.num_models], axis=0)
                act_batch_stack = np.concatenate(
                    obs_act_delta[self.num_models:2 * self.num_models], axis=0)
                delta_batch_stack = np.concatenate(
                    obs_act_delta[2 * self.num_models:], axis=0)

                # run train op
                batch_loss_train_ops = sess.run(
                    self.loss_model_batches + train_op_to_do,
                    feed_dict={
                        self.obs_model_batches_stack_ph: obs_batch_stack,
                        self.act_model_batches_stack_ph: act_batch_stack,
                        self.delta_model_batches_stack_ph: delta_batch_stack
                    })

                batch_loss = np.array(batch_loss_train_ops[:self.num_models])
                batch_losses.append(batch_loss)

            except tf.errors.OutOfRangeError:
                if self.normalize_input:
                    # normalize data
                    obs_test, act_test, delta_test = self._normalize_data(
                        self._dataset_test['obs'], self._dataset_test['act'],
                        self._dataset_test['delta'])
                else:
                    obs_test, act_test, delta_test = self._dataset_test['obs'], self._dataset_test['act'], \
                                                     self._dataset_test['delta']

                obs_test_stack = np.concatenate(obs_test, axis=0)
                act_test_stack = np.concatenate(act_test, axis=0)
                delta_test_stack = np.concatenate(delta_test, axis=0)

                # compute validation loss
                valid_loss = sess.run(
                    self.loss_model_batches,
                    feed_dict={
                        self.obs_model_batches_stack_ph: obs_test_stack,
                        self.act_model_batches_stack_ph: act_test_stack,
                        self.delta_model_batches_stack_ph: delta_test_stack
                    })
                valid_loss = np.array(valid_loss)

                if valid_loss_rolling_average is None:
                    # set initial rolling averages high to avoid stopping too early
                    valid_loss_rolling_average = 1.5 * valid_loss
                    valid_loss_rolling_average_prev = 2.0 * valid_loss
                    for i in range(len(valid_loss)):
                        if valid_loss[i] < 0:
                            # negative losses need the opposite scaling
                            valid_loss_rolling_average[i] = valid_loss[i] / 1.5
                            valid_loss_rolling_average_prev[i] = valid_loss[i] / 2.0

                valid_loss_rolling_average = rolling_average_persitency * valid_loss_rolling_average \
                                             + (1.0 - rolling_average_persitency) * valid_loss

                if verbose:
                    str_mean_batch_losses = ' '.join(
                        ['%.4f' % x for x in np.mean(batch_losses, axis=0)])
                    str_valid_loss = ' '.join(['%.4f' % x for x in valid_loss])
                    str_valid_loss_rolling_averge = ' '.join(
                        ['%.4f' % x for x in valid_loss_rolling_average])
                    logger.log(
                        "Training NNDynamicsModel - finished epoch %i --\n"
                        "train loss: %s\nvalid loss: %s\nvalid_loss_mov_avg: %s"
                        % (epoch, str_mean_batch_losses, str_valid_loss,
                           str_valid_loss_rolling_averge))
                break

        for i in range(self.num_models):
            if (valid_loss_rolling_average_prev[i] < valid_loss_rolling_average[i]
                    or epoch == epochs - 1) and i not in idx_to_remove:
                idx_to_remove.append(i)
                epochs_per_model.append(epoch)
                if epoch < epochs - 1:
                    logger.log(
                        'At Epoch {}, stop model {} since its valid_loss_rolling_average stopped decreasing'
                        .format(epoch, i))

        train_op_to_do = [
            op for idx, op in enumerate(self.train_op_model_batches)
            if idx not in idx_to_remove
        ]

        if not idx_to_remove:
            # only track epoch times while all models are trained
            epoch_times.append(time.time() - epoch_start_time)

        if not train_op_to_do:
            if verbose and epoch < epochs - 1:
                logger.log('Stopping all DynamicsEnsemble training before reaching max_num_epochs')
            break

        valid_loss_rolling_average_prev = valid_loss_rolling_average

    """ ------- Tabular Logging ------- """
    if log_tabular:
        logger.logkv(prefix + 'AvgModelEpochTime', np.mean(epoch_times))
        assert len(epochs_per_model) == self.num_models
        logger.logkv(prefix + 'AvgEpochs', np.mean(epochs_per_model))
        logger.logkv(prefix + 'StdEpochs', np.std(epochs_per_model))
        logger.logkv(prefix + 'MaxEpochs', np.max(epochs_per_model))
        logger.logkv(prefix + 'MinEpochs', np.min(epochs_per_model))
        logger.logkv(prefix + 'AvgFinalTrainLoss', np.mean(batch_losses))
        logger.logkv(prefix + 'AvgFinalValidLoss', np.mean(valid_loss))
        logger.logkv(prefix + 'AvgFinalValidLossRoll',
                     np.mean(valid_loss_rolling_average))