class MAMLNPO(BatchMAMLPolopt):
    """
    Natural Policy Optimization.
    """

    def __init__(self,
                 optimizer=None,
                 optimizer_args=None,
                 step_size=0.01,
                 use_maml=True,
                 **kwargs):
        assert optimizer is not None  # only for use with MAML TRPO
        self.optimizer = optimizer
        self.offPolicy_optimizer = FirstOrderOptimizer(max_epochs=1)
        self.step_size = step_size
        self.use_maml = use_maml
        self.kl_constrain_step = -1  # needs to be 0 or -1 (original pol params, or new pol params)
        super(MAMLNPO, self).__init__(**kwargs)

    def make_vars(self, stepnum='0'):
        # lists over the meta_batch_size
        obs_vars, action_vars, adv_vars, imp_vars = [], [], [], []
        for i in range(self.meta_batch_size):
            obs_vars.append(
                self.env.observation_space.new_tensor_variable(
                    'obs' + stepnum + '_' + str(i),
                    extra_dims=1,
                ))
            action_vars.append(
                self.env.action_space.new_tensor_variable(
                    'action' + stepnum + '_' + str(i),
                    extra_dims=1,
                ))
            adv_vars.append(
                tensor_utils.new_tensor(
                    name='advantage' + stepnum + '_' + str(i),
                    ndim=1,
                    dtype=tf.float32,
                ))
            imp_vars.append(
                tensor_utils.new_tensor(
                    name='imp_ratios' + stepnum + '_' + str(i),
                    ndim=1,
                    dtype=tf.float32,
                ))
        return obs_vars, action_vars, adv_vars, imp_vars

    @overrides
    def init_opt(self):
        is_recurrent = int(self.policy.recurrent)
        assert not is_recurrent  # not supported

        dist = self.policy.distribution

        old_dist_info_vars, old_dist_info_vars_list = [], []
        for i in range(self.meta_batch_size):
            old_dist_info_vars.append({
                k: tf.placeholder(tf.float32,
                                  shape=[None] + list(shape),
                                  name='old_%s_%s' % (i, k))
                for k, shape in dist.dist_info_specs
            })
            old_dist_info_vars_list += [
                old_dist_info_vars[i][k] for k in dist.dist_info_keys
            ]

        state_info_vars, state_info_vars_list = {}, []

        all_surr_objs, input_list = [], []
        new_params = None
        for j in range(self.num_grad_updates):
            obs_vars, action_vars, adv_vars, _ = self.make_vars(str(j))
            surr_objs = []

            cur_params = new_params
            new_params = []  # if there are several grad_updates the new_params are overwritten
            kls = []

            for i in range(self.meta_batch_size):
                if j == 0:
                    dist_info_vars, params = self.policy.dist_info_sym(
                        obs_vars[i],
                        state_info_vars,
                        all_params=self.policy.all_params)
                    if self.kl_constrain_step == 0:
                        kl = dist.kl_sym(old_dist_info_vars[i], dist_info_vars)
                        kls.append(kl)
                else:
                    dist_info_vars, params = self.policy.updated_dist_info_sym(
                        i,
                        all_surr_objs[-1][i],
                        obs_vars[i],
                        params_dict=cur_params[i])

                new_params.append(params)
                logli = dist.log_likelihood_sym(action_vars[i], dist_info_vars)

                # formulate as a minimization problem
                # The gradient of the surrogate objective is the policy gradient
                surr_objs.append(-tf.reduce_mean(logli * adv_vars[i]))

            input_list += obs_vars + action_vars + adv_vars + state_info_vars_list
            if j == 0:
                # For computing the fast update for sampling
                self.policy.set_init_surr_obj(input_list, surr_objs)
                init_input_list = input_list

            all_surr_objs.append(surr_objs)

        obs_vars, action_vars, adv_vars, _ = self.make_vars('test')
        surr_objs = []
        for i in range(self.meta_batch_size):
            dist_info_vars, _ = self.policy.updated_dist_info_sym(
                i,
                all_surr_objs[-1][i],
                obs_vars[i],
                params_dict=new_params[i])
            if self.kl_constrain_step == -1:
                # if we only care about the kl of the last step, the last item in kls will be the overall
                kl = dist.kl_sym(old_dist_info_vars[i], dist_info_vars)
                kls.append(kl)
            lr = dist.likelihood_ratio_sym(action_vars[i],
                                           old_dist_info_vars[i],
                                           dist_info_vars)
            surr_objs.append(-tf.reduce_mean(lr * adv_vars[i]))

        if self.use_maml:
            surr_obj = tf.reduce_mean(tf.stack(surr_objs, 0))  # mean over meta_batch_size (the diff tasks)
            input_list += obs_vars + action_vars + adv_vars + old_dist_info_vars_list
        else:
            surr_obj = tf.reduce_mean(tf.stack(all_surr_objs[0], 0))  # if not meta, just use the first surr_obj
            input_list = init_input_list

        if self.use_maml:
            mean_kl = tf.reduce_mean(tf.concat(kls, 0))  ##CF shouldn't this have the option of self.kl_constrain_step == -1?
            max_kl = tf.reduce_max(tf.concat(kls, 0))

            self.optimizer.update_opt(loss=surr_obj,
                                      target=self.policy,
                                      leq_constraint=(mean_kl, self.step_size),
                                      inputs=input_list,
                                      constraint_name="mean_kl")
        else:
            self.optimizer.update_opt(
                loss=surr_obj,
                target=self.policy,
                inputs=input_list,
            )
        return dict()

    @overrides
    def init_opt_offPolicy(self):
        is_recurrent = int(self.policy.recurrent)
        assert not is_recurrent  # not supported

        dist = self.policy.distribution

        state_info_vars, state_info_vars_list = {}, []

        all_surr_objs, input_list = [], []
        new_params = None
        for j in range(self.num_grad_updates):
            obs_vars, action_vars, adv_vars, imp_vars = self.make_vars(str(j))
            surr_objs = []

            cur_params = new_params
            new_params = []  # if there are several grad_updates the new_params are overwritten

            for i in range(self.meta_batch_size):
                if j == 0:
                    dist_info_vars, params = self.policy.dist_info_sym(
                        obs_vars[i],
                        state_info_vars,
                        all_params=self.policy.all_params)
                else:
                    dist_info_vars, params = self.policy.updated_dist_info_sym(
                        i,
                        all_surr_objs[-1][i],
                        obs_vars[i],
                        params_dict=cur_params[i])

                new_params.append(params)
                logli = dist.log_likelihood_sym(action_vars[i], dist_info_vars)

                # formulate as a minimization problem
                # The gradient of the surrogate objective is the policy gradient
                surr_objs.append(-tf.reduce_mean(logli * imp_vars[i] * adv_vars[i]))

            input_list += obs_vars + action_vars + adv_vars + imp_vars
            all_surr_objs.append(surr_objs)

        obs_vars, action_vars, _, _ = self.make_vars('test')
        surr_objs = []
        for i in range(self.meta_batch_size):
            dist_info_vars, _ = self.policy.updated_dist_info_sym(
                i,
                all_surr_objs[-1][i],
                obs_vars[i],
                params_dict=new_params[i])
            logli = dist.log_likelihood_sym(action_vars[i], dist_info_vars)
            surr_objs.append(-tf.reduce_mean(logli))

        surr_obj = tf.reduce_mean(tf.stack(surr_objs, 0))  # mean over meta_batch_size (the diff tasks)
        input_list += obs_vars + action_vars

        self.offPolicy_optimizer.update_opt(
            loss=surr_obj,
            target=self.policy,
            inputs=input_list,
        )

    def offPolicy_optimization_step(self, samples_data, expert_data):
        input_list = []
        # for step in range(len(all_samples_data)):  # these are the gradient steps
        obs_list, action_list, adv_list, imp_list, expert_obs_list, expert_action_list = [], [], [], [], [], []
        for i in range(self.meta_batch_size):
            inputs = ext.extract(samples_data[i], "observations", "actions",
                                 "advantages", 'traj_imp_weights')
            obs_list.append(inputs[0])
            action_list.append(inputs[1])
            adv_list.append(inputs[2])
            imp_list.append(inputs[3])

            expert_inputs = ext.extract(expert_data[i], "observations", "actions")
            expert_obs_list.append(expert_inputs[0])
            expert_action_list.append(expert_inputs[1])

        input_list += obs_list + action_list + adv_list + imp_list + expert_obs_list + expert_action_list
        self.offPolicy_optimizer.optimize(input_list)

    @overrides
    def optimize_policy(self, itr, all_samples_data):
        assert len(all_samples_data) == self.num_grad_updates + 1  # we collected the rollouts to compute the grads and then the test!

        if not self.use_maml:
            all_samples_data = [all_samples_data[0]]

        input_list = []
        for step in range(len(all_samples_data)):  # these are the gradient steps
            obs_list, action_list, adv_list = [], [], []
            for i in range(self.meta_batch_size):
                inputs = ext.extract(all_samples_data[step][i],
                                     "observations", "actions", "advantages")
                obs_list.append(inputs[0])
                action_list.append(inputs[1])
                adv_list.append(inputs[2])

            input_list += obs_list + action_list + adv_list  # [ [obs_0], [act_0], [adv_0], [obs_1], ... ]

            if step == 0:  ##CF not used?
                init_inputs = input_list

        if self.use_maml:
            dist_info_list = []
            for i in range(self.meta_batch_size):
                agent_infos = all_samples_data[self.kl_constrain_step][i]['agent_infos']
                dist_info_list += [
                    agent_infos[k]
                    for k in self.policy.distribution.dist_info_keys
                ]
            input_list += tuple(dist_info_list)
            logger.log("Computing KL before")
            mean_kl_before = self.optimizer.constraint_val(input_list)

        logger.log("Computing loss before")
        loss_before = self.optimizer.loss(input_list)
        logger.log("Optimizing")
        self.optimizer.optimize(input_list)
        logger.log("Computing loss after")
        loss_after = self.optimizer.loss(input_list)

        if self.use_maml:
            logger.log("Computing KL after")
            mean_kl = self.optimizer.constraint_val(input_list)
            logger.record_tabular('MeanKLBefore', mean_kl_before)  # this now won't be 0!
            logger.record_tabular('MeanKL', mean_kl)

        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict()

    @overrides
    def get_itr_snapshot(self, itr, samples_data):
        return dict(
            itr=itr,
            policy=self.policy,
            baseline=self.baseline,
            env=self.env,
        )
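# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the algorithm above): a toy,
# self-contained version of the two-level graph that MAMLNPO.init_opt builds.
# The manual `theta - inner_lr * grad` step below stands in for what
# policy.updated_dist_info_sym does with the per-task surrogate objectives,
# and the per-task quadratic losses stand in for the policy-gradient
# surrogates.  All names in this helper are hypothetical; it is never called
# by the classes in this file.
# ---------------------------------------------------------------------------
def _maml_objective_sketch():
    import tensorflow as tf

    inner_lr = 0.1
    theta = tf.Variable(0.0, name='sketch_theta')  # "pre-update" parameters
    task_targets = [
        tf.placeholder(tf.float32, shape=[], name='sketch_target%d' % i)
        for i in range(2)
    ]  # one placeholder per task, like the per-task obs/action/adv vars above

    test_losses = []
    for target in task_targets:                      # loop over the meta batch
        inner_loss = tf.square(theta - target)       # step-0 surrogate for this task
        grad = tf.gradients(inner_loss, [theta])[0]
        theta_adapted = theta - inner_lr * grad      # fast / inner update
        test_losses.append(tf.square(theta_adapted - target))  # post-update ("test") loss

    # mean over tasks, analogous to surr_obj = tf.reduce_mean(tf.stack(surr_objs, 0));
    # in init_opt this outer loss is additionally optimized under a mean-KL constraint.
    meta_loss = tf.reduce_mean(tf.stack(test_losses, 0))

    with tf.Session() as sess:
        sess.run(tf.variables_initializer([theta]))
        return sess.run(meta_loss,
                        feed_dict={task_targets[0]: 1.0, task_targets[1]: -1.0})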
class NPO(BatchPolopt):
    """
    Natural Policy Optimization.
    """

    def __init__(self,
                 optimizer_class=None,
                 optimizer_args=None,
                 step_size=0.01,
                 penalty=0.0,
                 **kwargs):
        self.optimizer_class = default(optimizer_class, PenaltyLbfgsOptimizer)
        self.optimizer_args = default(optimizer_args, dict())
        self.penalty = penalty
        self.constrain_together = penalty > 0
        self.step_size = step_size
        self.metrics = []
        super(NPO, self).__init__(**kwargs)

    @overrides
    def init_opt(self):
        ###############################
        #
        # Variable Definitions
        #
        ###############################

        all_task_dist_info_vars = []
        all_obs_vars = []
        for i, policy in enumerate(self.local_policies):
            task_obs_var = self.env_partitions[i].observation_space.new_tensor_variable(
                'obs%d' % i, extra_dims=1)
            task_dist_info_vars = []
            for j, other_policy in enumerate(self.local_policies):
                state_info_vars = dict()  # Not handling recurrent policies
                dist_info_vars = other_policy.dist_info_sym(task_obs_var,
                                                            state_info_vars)
                task_dist_info_vars.append(dist_info_vars)

            all_obs_vars.append(task_obs_var)
            all_task_dist_info_vars.append(task_dist_info_vars)

        obs_var = self.env.observation_space.new_tensor_variable('obs',
                                                                 extra_dims=1)
        action_var = self.env.action_space.new_tensor_variable('action',
                                                               extra_dims=1)
        advantage_var = tensor_utils.new_tensor('advantage',
                                                ndim=1,
                                                dtype=tf.float32)

        old_dist_info_vars = {
            k: tf.placeholder(tf.float32,
                              shape=[None] + list(shape),
                              name='old_%s' % k)
            for k, shape in self.policy.distribution.dist_info_specs
        }
        old_dist_info_vars_list = [
            old_dist_info_vars[k]
            for k in self.policy.distribution.dist_info_keys
        ]

        input_list = [obs_var, action_var, advantage_var] + old_dist_info_vars_list + all_obs_vars

        ###############################
        #
        # Local Policy Optimization
        #
        ###############################

        self.optimizers = []
        self.metrics = []
        for n, policy in enumerate(self.local_policies):
            state_info_vars = dict()
            dist_info_vars = policy.dist_info_sym(obs_var, state_info_vars)
            dist = policy.distribution

            kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
            lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
                                           dist_info_vars)
            surr_loss = -tf.reduce_mean(lr * advantage_var)

            if self.constrain_together:
                additional_loss = Metrics.kl_on_others(n, dist,
                                                       all_task_dist_info_vars)
            else:
                additional_loss = tf.constant(0.0)

            local_loss = surr_loss + self.penalty * additional_loss

            kl_metric = tensor_utils.compile_function(inputs=input_list,
                                                      outputs=additional_loss,
                                                      log_name="KLPenalty%d" % n)
            self.metrics.append(kl_metric)

            mean_kl_constraint = tf.reduce_mean(kl)

            optimizer = self.optimizer_class(**self.optimizer_args)
            optimizer.update_opt(
                loss=local_loss,
                target=policy,
                leq_constraint=(mean_kl_constraint, self.step_size),
                inputs=input_list,
                constraint_name="mean_kl_%d" % n,
            )
            self.optimizers.append(optimizer)

        ###############################
        #
        # Global Policy Optimization
        #
        ###############################

        # Behaviour Cloning Loss
        state_info_vars = dict()
        center_dist_info_vars = self.policy.dist_info_sym(obs_var,
                                                          state_info_vars)
        behaviour_cloning_loss = tf.losses.mean_squared_error(
            action_var, center_dist_info_vars['mean'])
        self.center_optimizer = FirstOrderOptimizer(max_epochs=1,
                                                    verbose=True,
                                                    batch_size=1000)
        self.center_optimizer.update_opt(behaviour_cloning_loss, self.policy,
                                         [obs_var, action_var])

        # TRPO Loss
        kl = dist.kl_sym(old_dist_info_vars, center_dist_info_vars)
        lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
                                       center_dist_info_vars)
        center_trpo_loss = -tf.reduce_mean(lr * advantage_var)
        mean_kl_constraint = tf.reduce_mean(kl)

        optimizer = self.optimizer_class(**self.optimizer_args)
        optimizer.update_opt(
            loss=center_trpo_loss,
            target=self.policy,
            leq_constraint=(mean_kl_constraint, self.step_size),
            inputs=[obs_var, action_var, advantage_var] + old_dist_info_vars_list,
            constraint_name="mean_kl_center",
        )
        self.center_trpo_optimizer = optimizer

        # Reset Local Policies to Global Policy
        assignment_operations = []
        for policy in self.local_policies:
            for param_local, param_center in zip(
                    policy.get_params_internal(),
                    self.policy.get_params_internal()):
                if 'std' not in param_local.name:
                    assignment_operations.append(
                        tf.assign(param_local, param_center))

        self.reset_to_center = tf.group(*assignment_operations)

        return dict()

    def optimize_local_policies(self, itr, all_samples_data):
        dist_info_keys = self.policy.distribution.dist_info_keys

        for n, optimizer in enumerate(self.optimizers):
            obs_act_adv_values = tuple(
                ext.extract(all_samples_data[n], "observations", "actions",
                            "advantages"))
            dist_info_list = tuple([
                all_samples_data[n]["agent_infos"][k] for k in dist_info_keys
            ])
            all_task_obs_values = tuple([
                samples_data["observations"]
                for samples_data in all_samples_data
            ])

            all_input_values = obs_act_adv_values + dist_info_list + all_task_obs_values
            optimizer.optimize(all_input_values)

            kl_penalty = sliced_fun(self.metrics[n], 1)(all_input_values)
            logger.record_tabular('KLPenalty%d' % n, kl_penalty)

    def optimize_global_policy(self, itr, all_samples_data):
        all_observations = np.concatenate([
            samples_data['observations'] for samples_data in all_samples_data
        ])
        all_actions = np.concatenate([
            samples_data['agent_infos']['mean']
            for samples_data in all_samples_data
        ])

        num_itrs = 1 if itr % self.distillation_period != 0 else 30
        for _ in range(num_itrs):
            self.center_optimizer.optimize([all_observations, all_actions])

        paths = self.global_sampler.obtain_samples(itr)
        samples_data = self.global_sampler.process_samples(itr, paths)

        obs_values = tuple(
            ext.extract(samples_data, "observations", "actions", "advantages"))
        dist_info_list = [
            samples_data["agent_infos"][k]
            for k in self.policy.distribution.dist_info_keys
        ]
        all_input_values = obs_values + tuple(dist_info_list)

        self.center_trpo_optimizer.optimize(all_input_values)
        self.env.log_diagnostics(paths)

    @overrides
    def optimize_policy(self, itr, all_samples_data):
        self.optimize_local_policies(itr, all_samples_data)
        self.optimize_global_policy(itr, all_samples_data)

        if itr % self.distillation_period == 0:
            sess = tf.get_default_session()
            sess.run(self.reset_to_center)
            logger.log('Reset Local Policies to Global Policies')

        return dict()
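# ---------------------------------------------------------------------------
# Illustrative sketch only: a minimal, self-contained version of the
# "reset local policies to the global policy" mechanism assembled at the end
# of NPO.init_opt and triggered every `distillation_period` iterations in
# optimize_policy.  Toy variables stand in for policy parameters; as in the
# real graph, parameters whose name contains 'std' are left untouched.  All
# names in this helper are hypothetical; it is never called by NPO itself.
# ---------------------------------------------------------------------------
def _reset_to_center_sketch():
    import tensorflow as tf

    center_params = [tf.Variable([1.0, 2.0], name='sketch_center/w'),
                     tf.Variable([0.9], name='sketch_center/std')]
    local_params = [tf.Variable([0.0, 0.0], name='sketch_local0/w'),
                    tf.Variable([0.5], name='sketch_local0/std')]

    assignments = []
    for param_local, param_center in zip(local_params, center_params):
        if 'std' not in param_local.name:  # skip exploration-noise parameters
            assignments.append(tf.assign(param_local, param_center))
    reset_op = tf.group(*assignments)

    with tf.Session() as sess:
        sess.run(tf.variables_initializer(center_params + local_params))
        sess.run(reset_op)
        # local 'w' now matches the center's; the local 'std' keeps its value
        return sess.run(local_params)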