def optimize_policy(self, all_samples_data, log=True):
    """
    Performs MAML outer step

    Args:
        all_samples_data (list): list of lists of lists of samples (each is a dict)
            split by gradient update and meta task
        log (bool): whether to log statistics

    Returns:
        None
    """
    meta_op_input_dict = self._extract_input_dict_meta_op(all_samples_data,
                                                          self._optimization_keys)

    logger.log("Computing KL before")
    mean_kl_before = self.optimizer.constraint_val(meta_op_input_dict)

    logger.log("Computing loss before")
    loss_before = self.optimizer.loss(meta_op_input_dict)
    logger.log("Optimizing")
    self.optimizer.optimize(meta_op_input_dict)
    logger.log("Computing loss after")
    loss_after = self.optimizer.loss(meta_op_input_dict)

    logger.log("Computing KL after")
    mean_kl = self.optimizer.constraint_val(meta_op_input_dict)

    if log:
        logger.logkv('MeanKLBefore', mean_kl_before)
        logger.logkv('MeanKL', mean_kl)

        logger.logkv('LossBefore', loss_before)
        logger.logkv('LossAfter', loss_after)
        logger.logkv('dLoss', loss_before - loss_after)
def optimize_policy(self, all_samples_data, log=True):
    """
    Performs MAML outer step

    Args:
        all_samples_data (list): list of lists of lists of samples (each is a dict)
            split by gradient update and meta task
        log (bool): whether to log statistics

    Returns:
        None
    """
    meta_op_input_dict = self._extract_input_dict_meta_op(all_samples_data,
                                                          self._optimization_keys)

    if log:
        logger.log("Optimizing")
    loss_before = self.optimizer.optimize(input_val_dict=meta_op_input_dict)

    if log:
        logger.log("Computing statistics")
    loss_after = self.optimizer.loss(input_val_dict=meta_op_input_dict)

    if log:
        logger.logkv('LossBefore', loss_before)
        logger.logkv('LossAfter', loss_after)
def optimize(self, input_val_dict):
    """
    Carries out the optimization step

    Args:
        input_val_dict (dict): dict containing the values to be fed into the computation graph

    Returns:
        (float) loss before optimization
    """
    sess = tf.compat.v1.get_default_session()
    batch_size, seq_len, *_ = list(input_val_dict.values())[0].shape

    loss_before_opt = None
    for epoch in range(self._max_epochs):
        hidden_batch = self._target.get_zero_state(batch_size)
        if self._verbose:
            logger.log("Epoch %d" % epoch)

        # Truncated backprop through time: run the recurrent policy over the sequence
        # in chunks of self._backprop_steps, carrying the hidden state across chunks
        # and collecting the per-chunk loss and gradients.
        loss = []
        all_grads = []
        for i in range(0, seq_len, self._backprop_steps):
            n_i = i + self._backprop_steps
            feed_dict = {self._input_ph_dict[key]: input_val_dict[key][:, i:n_i]
                         for key in self._input_ph_dict.keys()}
            feed_dict[self._hidden_ph] = hidden_batch
            batch_loss, grads, hidden_batch = sess.run(
                [self._loss, self._gradients_var, self._next_hidden_var],
                feed_dict=feed_dict)
            loss.append(batch_loss)
            all_grads.append(grads)

        # Average each variable's gradient over the chunks and apply the update
        # by feeding the averaged gradients through the gradient placeholders.
        grads = [np.mean(grad, axis=0) for grad in zip(*all_grads)]
        feed_dict = dict(zip(self._gradients_ph, grads))
        _ = sess.run(self._train_op, feed_dict=feed_dict)

        if loss_before_opt is None:
            loss_before_opt = np.mean(loss)

    return loss_before_opt
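# The loop above fetches numeric gradients per truncated-BPTT segment and feeds their
# average back through gradient placeholders. As a rough illustration only (an
# assumption, not this repo's actual graph-building code), the placeholders and train
# op could be wired along these lines; `build_apply_grads_graph` is a hypothetical
# helper name.
def build_apply_grads_graph(loss, params, learning_rate=1e-3):
    """Hypothetical sketch: symbolic gradients are fetched per segment, while the
    update is applied from externally supplied (averaged) gradient values."""
    gradients_var = tf.compat.v1.gradients(loss, params)            # fetched each segment
    gradients_ph = [tf.compat.v1.placeholder(dtype=p.dtype.base_dtype, shape=p.shape)
                    for p in params]                                 # fed with averaged numpy grads
    train_op = tf.compat.v1.train.AdamOptimizer(learning_rate).apply_gradients(
        zip(gradients_ph, params))
    return gradients_var, gradients_ph, train_op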
def optimize_policy(self, all_samples_data, log=True):
    """
    Performs MAML outer step

    Args:
        all_samples_data (list): list of lists of lists of samples (each is a dict)
            split by gradient update and meta task
        log (bool): whether to log statistics

    Returns:
        None
    """
    meta_op_input_dict = self._extract_input_dict_meta_op(all_samples_data,
                                                          self._optimization_keys)

    # add kl_coeffs / clip_eps to meta_op_input_dict
    meta_op_input_dict['inner_kl_coeff'] = self.inner_kl_coeff
    if self.clip_outer:
        meta_op_input_dict['clip_eps'] = self.clip_eps
    else:
        meta_op_input_dict['outer_kl_coeff'] = self.outer_kl_coeff

    if log:
        logger.log("Optimizing")
    loss_before = self.optimizer.optimize(input_val_dict=meta_op_input_dict)

    if log:
        logger.log("Computing statistics")
    loss_after, inner_kls, outer_kl = self.optimizer.compute_stats(
        input_val_dict=meta_op_input_dict)

    if self.adaptive_inner_kl_penalty:
        if log:
            logger.log("Updating inner KL loss coefficients")
        self.inner_kl_coeff = self.adapt_kl_coeff(self.inner_kl_coeff,
                                                  inner_kls,
                                                  self.target_inner_step)

    if self.adaptive_outer_kl_penalty:
        if log:
            logger.log("Updating outer KL loss coefficients")
        self.outer_kl_coeff = self.adapt_kl_coeff(self.outer_kl_coeff,
                                                  outer_kl,
                                                  self.target_outer_step)

    if log:
        logger.logkv('LossBefore', loss_before)
        logger.logkv('LossAfter', loss_after)
        logger.logkv('KLInner', np.mean(inner_kls))
        logger.logkv('KLCoeffInner', np.mean(self.inner_kl_coeff))
        if not self.clip_outer:
            logger.logkv('KLOuter', outer_kl)
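# adapt_kl_coeff is called above but not shown in this section. A minimal sketch,
# assuming the standard PPO-style adaptive-KL rule (the repo's actual update may
# differ): raise a coefficient when its measured KL overshoots the target, lower it
# when it undershoots. `adapt_kl_coeff_sketch` is a hypothetical name.
def adapt_kl_coeff_sketch(kl_coeffs, kl_values, kl_target):
    kl_coeffs = np.atleast_1d(kl_coeffs).astype(float)
    kl_values = np.atleast_1d(kl_values)
    new_coeffs = []
    for coeff, kl in zip(kl_coeffs, kl_values):
        if kl > 1.5 * kl_target:
            coeff = coeff * 2.0   # KL too large -> strengthen the penalty
        elif kl < kl_target / 1.5:
            coeff = coeff / 2.0   # KL too small -> relax the penalty
        new_coeffs.append(coeff)
    return new_coeffs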
def optimize(self, input_val_dict):
    """
    Carries out the optimization step

    Args:
        input_val_dict (dict): dict containing the values to be fed into the computation graph

    Returns:
        (float) loss before optimization
    """
    sess = tf.compat.v1.get_default_session()
    feed_dict = self.create_feed_dict(input_val_dict)

    # TODO: reimplement minibatches (the full batch is currently fed at once)

    loss_before_opt = None
    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % epoch)

        loss, _ = sess.run([self._loss, self._train_op], feed_dict)

        if loss_before_opt is None:
            loss_before_opt = loss

    return loss_before_opt
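# create_feed_dict is assumed to pair each input placeholder with the value stored
# under the same key, roughly as sketched below (the repo's version may handle nested
# dicts or missing keys differently); `input_ph_dict` mirrors self._input_ph_dict.
def create_feed_dict_sketch(input_ph_dict, input_val_dict):
    return {input_ph_dict[key]: input_val_dict[key] for key in input_ph_dict.keys()}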
def validate(self, sess):
    """
    Tests the policy on the environment using the algorithm

    Pseudocode:
        for itr in n_itr:
            for step in num_inner_grad_steps:
                sampler.sample()
                algo.compute_updated_dists()
            algo.optimize_policy()
            sampler.update_goals()

    Returns:
        (float) average return over the validation iterations
    """
    avg_returns = []
    for itr in range(self.n_itr):
        logger.log("\n ---------------- Iteration %d ----------------" % itr)

        self.policy.switch_to_pre_update()  # Switch to pre-update policy
        self.sampler.update_tasks()
        for step in range(self.num_inner_grad_steps + 1):
            logger.log('** Step ' + str(step) + ' **')

            """ -------------------- Sampling --------------------------"""
            paths = self.sampler.obtain_samples(log=False)

            """ ----------------- Processing Samples ---------------------"""
            samples_data = self.sample_processor.process_samples(
                paths, log='all', log_prefix='Step_%d-' % step)
            self.log_diagnostics(sum(list(paths.values()), []), prefix='Step_%d-' % step)

            """ ------------------- Inner Policy Update --------------------"""
            if step < self.num_inner_grad_steps:
                self.algo._adapt(samples_data)

        avg_returns.append(self.sample_processor.avg_return)

    logger.log(f'Average validation reward: {np.mean(avg_returns)}')
    return np.mean(avg_returns)
def train(self, tester):
    """
    Trains the policy on the environment using the algorithm

    Pseudocode:
        for itr in n_itr:
            for step in num_inner_grad_steps:
                sampler.sample()
                algo.compute_updated_dists()
            algo.optimize_policy()
            sampler.update_goals()
    """
    best_train_reward = -np.inf
    with self.sess.as_default() as sess:

        # initialize uninitialized vars (only initialize vars that were not loaded)
        uninit_vars = [var for var in tf.compat.v1.global_variables()
                       if not sess.run(tf.compat.v1.is_variable_initialized(var))]
        sess.run(tf.compat.v1.variables_initializer(uninit_vars))

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)
            logger.log("Sampling set of tasks/goals for this meta-batch...")

            self.sampler.update_tasks()
            self.policy.switch_to_pre_update()  # Switch to pre-update policy

            all_samples_data, all_paths = [], []
            list_sampling_time, list_inner_step_time, list_proc_samples_time = [], [], []
            start_total_inner_time = time.time()
            for step in range(self.num_inner_grad_steps + 1):
                logger.log('** Step ' + str(step) + ' **')

                """ -------------------- Sampling --------------------------"""
                logger.log("Obtaining samples...")
                time_env_sampling_start = time.time()
                paths = self.sampler.obtain_samples(log=True, log_prefix='Step_%d-' % step)
                list_sampling_time.append(time.time() - time_env_sampling_start)
                all_paths.append(paths)

                """ ----------------- Processing Samples ---------------------"""
                logger.log("Processing samples...")
                time_proc_samples_start = time.time()
                samples_data = self.sample_processor.process_samples(
                    paths, log='all', log_prefix='Step_%d-' % step)
                all_samples_data.append(samples_data)
                list_proc_samples_time.append(time.time() - time_proc_samples_start)

                self.log_diagnostics(sum(list(paths.values()), []), prefix='Step_%d-' % step)

                """ ------------------- Inner Policy Update --------------------"""
                time_inner_step_start = time.time()
                if step < self.num_inner_grad_steps:
                    logger.log("Computing inner policy updates...")
                    self.algo._adapt(samples_data)
                list_inner_step_time.append(time.time() - time_inner_step_start)
            total_inner_time = time.time() - start_total_inner_time

            time_maml_opt_start = time.time()

            """ ------------------ Outer Policy Update ---------------------"""
            logger.log("Optimizing policy...")
            # This needs to take all samples_data so that it can construct the graph
            # for meta-optimization.
            time_outer_step_start = time.time()
            self.algo.optimize_policy(all_samples_data)

            """ ------------------- Logging Stuff --------------------------"""
            logger.logkv('Itr', itr)
            logger.logkv('n_timesteps', self.sampler.total_timesteps_sampled)

            logger.logkv('Time-OuterStep', time.time() - time_outer_step_start)
            logger.logkv('Time-TotalInner', total_inner_time)
            logger.logkv('Time-InnerStep', np.sum(list_inner_step_time))
            logger.logkv('Time-SampleProc', np.sum(list_proc_samples_time))
            logger.logkv('Time-Sampling', np.sum(list_sampling_time))

            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)
            logger.logkv('Time-MAMLSteps', time.time() - time_maml_opt_start)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

            # Validate either periodically (every 80 iterations) or whenever the training
            # return improves by at least 5% after iteration 50 and exceeds 300.
            if (itr % 80 == 79) or (
                    (self.sample_processor.avg_return > best_train_reward * 1.05)
                    and (itr > 50)
                    and (self.sample_processor.avg_return > 300.0)):
                if self.sample_processor.avg_return > best_train_reward:
                    best_train_reward = self.sample_processor.avg_return
                print('TESTING')
                delim = '====================\n'
                print(delim, delim, delim)
                val_reward = tester.validate(sess)
                if self.best_val_reward < val_reward:
                    self.best_itr = itr
                    self.best_val_reward = val_reward
                print(delim, delim, delim)

    logger.log("Training finished")
    logger.log(f"Best iteration is {self.best_itr} with reward {self.best_val_reward}")
    self.sess.close()
    return self.best_itr
def optimize(self, input_val_dict):
    """
    Carries out the optimization step

    Args:
        input_val_dict (dict): dict containing the values to be fed into the computation graph
    """
    logger.log("Start CG optimization")

    logger.log("computing loss before")
    loss_before = self.loss(input_val_dict)

    logger.log("performing update")

    logger.log("computing gradient")
    gradient = self.gradient(input_val_dict)
    logger.log("gradient computed")

    logger.log("computing descent direction")
    Hx = self._hvp_approach.build_eval(input_val_dict)
    descent_direction = conjugate_gradients(Hx, gradient, cg_iters=self._cg_iters)

    # Scale the step so that the quadratic approximation of the constraint is met:
    # step_size = sqrt(2 * max_constraint_val / (d^T H d))
    rat = 1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)
    initial_step_size = np.sqrt(2.0 * self._max_constraint_val * rat)
    if np.isnan(initial_step_size):
        logger.log("Initial step size is NaN! Rejecting the step!")
        return

    initial_descent_step = initial_step_size * descent_direction
    logger.log("descent direction computed")

    # Backtracking line search: shrink the step until the loss improves and the
    # constraint is satisfied.
    prev_params = self._target.get_param_values()
    prev_params_values = _flatten_params(prev_params)

    loss, constraint_val, n_iter, violated = 0, 0, 0, False
    for n_iter, ratio in enumerate(self._backtrack_ratio ** np.arange(self._max_backtracks)):
        cur_step = ratio * initial_descent_step
        cur_params_values = prev_params_values - cur_step
        cur_params = _unflatten_params(cur_params_values, params_example=prev_params)
        self._target.set_params(cur_params)

        loss, constraint_val = self.loss(input_val_dict), self.constraint_val(input_val_dict)
        if loss < loss_before and constraint_val <= self._max_constraint_val:
            break

    """ ------------------- Logging Stuff -------------------------- """
    if np.isnan(loss):
        violated = True
        logger.log("Line search violated because loss is NaN")
    if np.isnan(constraint_val):
        violated = True
        logger.log("Line search violated because constraint %s is NaN"
                   % self._constraint_name)
    if loss >= loss_before:
        violated = True
        logger.log("Line search violated because loss not improving")
    if constraint_val >= self._max_constraint_val:
        violated = True
        logger.log("Line search violated because constraint %s is violated"
                   % self._constraint_name)

    if violated and not self._accept_violation:
        logger.log("Line search condition violated. Rejecting the step!")
        self._target.set_params(prev_params)

    logger.log("backtrack iters: %d" % n_iter)
    logger.log("computing loss after")
    logger.log("optimization finished")
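# conjugate_gradients is used above to solve H x = g for the descent direction without
# materializing the Hessian. Below is a sketch of a standard conjugate-gradient solver,
# assuming hvp_fn(v) returns the Hessian-vector product H v; the repo's actual helper
# may differ in damping and stopping details, and `conjugate_gradients_sketch` is a
# hypothetical name.
def conjugate_gradients_sketch(hvp_fn, g, cg_iters=10, residual_tol=1e-10):
    x = np.zeros_like(g)
    r = g.copy()                 # residual g - H x (x starts at zero)
    p = g.copy()                 # current search direction
    r_dot_r = r.dot(r)
    for _ in range(cg_iters):
        Hp = hvp_fn(p)
        alpha = r_dot_r / (p.dot(Hp) + 1e-8)
        x += alpha * p
        r -= alpha * Hp
        new_r_dot_r = r.dot(r)
        if new_r_dot_r < residual_tol:
            break
        p = r + (new_r_dot_r / r_dot_r) * p
        r_dot_r = new_r_dot_r
    return x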
def train(self):
    """
    Trains the policy on the environment using the algorithm

    Pseudocode:
        for itr in n_itr:
            sampler.set_tasks()
            sampler.sample()
            algo.optimize_policy()
    """
    with self.sess.as_default() as sess:

        # initialize uninitialized vars (only initialize vars that were not loaded)
        uninit_vars = [var for var in tf.compat.v1.global_variables()
                       if not sess.run(tf.compat.v1.is_variable_initialized(var))]
        sess.run(tf.compat.v1.variables_initializer(uninit_vars))

        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            self.task = self.env.sample_tasks(self.sampler.meta_batch_size)
            self.sampler.set_tasks(self.task)
            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)
            logger.log("Sampling set of tasks/goals for this meta-batch...")

            """ -------------------- Sampling --------------------------"""
            logger.log("Obtaining samples...")
            time_env_sampling_start = time.time()
            paths = self.sampler.obtain_samples(log=True, log_prefix='train-')
            sampling_time = time.time() - time_env_sampling_start

            """ ----------------- Processing Samples ---------------------"""
            logger.log("Processing samples...")
            time_proc_samples_start = time.time()
            samples_data = self.sample_processor.process_samples(
                paths, log='all', log_prefix='train-')
            proc_samples_time = time.time() - time_proc_samples_start

            self.log_diagnostics(sum(paths.values(), []), prefix='train-')

            """ ------------------ Policy Update ---------------------"""
            logger.log("Optimizing policy...")
            time_optimization_step_start = time.time()
            self.algo.optimize_policy(samples_data)

            """ ------------------- Logging Stuff --------------------------"""
            logger.logkv('Itr', itr)
            logger.logkv('n_timesteps', self.sampler.total_timesteps_sampled)

            logger.logkv('Time-Optimization', time.time() - time_optimization_step_start)
            logger.logkv('Time-SampleProc', proc_samples_time)
            logger.logkv('Time-Sampling', sampling_time)

            logger.logkv('Time', time.time() - start_time)
            logger.logkv('ItrTime', time.time() - itr_start_time)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")

            logger.dumpkvs()
            if itr == 0:
                sess.graph.finalize()

    logger.log("Training finished")
    self.sess.close()