def make_vars(self, stepnum='0'):
    # lists over the meta_batch_size
    obs_vars, action_vars, adv_vars = [], [], []
    for i in range(self.meta_batch_size):
        obs_vars.append(self.env.observation_space.new_tensor_variable(
            'obs' + stepnum + '_' + str(i),
            extra_dims=1,
        ))
        action_vars.append(self.env.action_space.new_tensor_variable(
            'action' + stepnum + '_' + str(i),
            extra_dims=1,
        ))
        adv_vars.append(tensor_utils.new_tensor(
            name='advantage' + stepnum + '_' + str(i),
            ndim=1,
            dtype=tf.float32,
        ))
    return obs_vars, action_vars, adv_vars
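
# Illustrative sketch (not part of this class): make_vars builds one
# (obs, action, advantage) placeholder triple per meta-task, and a feed_dict for
# them might be assembled as below.  The placeholder factory and the per-task
# data layout here are hypothetical stand-ins, not the class's own API.
#
# import numpy as np
# import tensorflow as tf
#
# def make_adv_vars(meta_batch_size, stepnum='0'):
#     # one 1-D advantage placeholder per task, named the same way make_vars names them
#     return [tf.placeholder(tf.float32, shape=[None],
#                            name='advantage' + stepnum + '_' + str(i))
#             for i in range(meta_batch_size)]
#
# adv_vars = make_adv_vars(meta_batch_size=2)
# per_task_advantages = [np.random.randn(5).astype(np.float32) for _ in adv_vars]
# feed_dict = {ph: adv for ph, adv in zip(adv_vars, per_task_advantages)}
# with tf.Session() as sess:
#     print(sess.run(tf.add_n([tf.reduce_sum(v) for v in adv_vars]), feed_dict))
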
def init_opt(self):
    is_recurrent = int(self.policy.recurrent)
    obs_var = self.env.observation_space.new_tensor_variable(
        'obs',
        extra_dims=1 + is_recurrent,
    )
    action_var = self.env.action_space.new_tensor_variable(
        'action',
        extra_dims=1 + is_recurrent,
    )
    advantage_var = tensor_utils.new_tensor(
        name='advantage',
        ndim=1 + is_recurrent,
        dtype=tf.float32,
    )
    dist = self.policy.distribution

    # placeholders for the previous policy's distribution parameters,
    # used here only for the KL diagnostics computed by f_kl
    old_dist_info_vars = {
        k: tf.placeholder(tf.float32,
                          shape=[None] * (1 + is_recurrent) + list(shape),
                          name='old_%s' % k)
        for k, shape in dist.dist_info_specs
    }
    old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys]

    # extra inputs the policy needs to compute its distribution (e.g. recurrent state)
    state_info_vars = {
        k: tf.placeholder(tf.float32,
                          shape=[None] * (1 + is_recurrent) + list(shape),
                          name=k)
        for k, shape in self.policy.state_info_specs
    }
    state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys]

    if is_recurrent:
        valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid")
    else:
        valid_var = None

    dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
    logli = dist.log_likelihood_sym(action_var, dist_info_vars)
    kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)

    # formulate as a minimization problem
    # the gradient of the surrogate objective is the policy gradient
    if is_recurrent:
        surr_obj = -tf.reduce_sum(
            logli * advantage_var * valid_var) / tf.reduce_sum(valid_var)
        mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
        max_kl = tf.reduce_max(kl * valid_var)
    else:
        surr_obj = -tf.reduce_mean(logli * advantage_var)
        mean_kl = tf.reduce_mean(kl)
        max_kl = tf.reduce_max(kl)

    input_list = [obs_var, action_var, advantage_var] + state_info_vars_list
    if is_recurrent:
        input_list.append(valid_var)

    # self.policy.set_init_surr_obj(input_list, [surr_obj])  # debugging
    self.optimizer.update_opt(loss=surr_obj, target=self.policy, inputs=input_list)

    f_kl = tensor_utils.compile_function(
        inputs=input_list + old_dist_info_vars_list,
        outputs=[mean_kl, max_kl],
    )
    self.opt_info = dict(f_kl=f_kl)
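
# Illustrative sketch (not from this file): the non-recurrent branch above builds
# surr_obj = -E[log pi(a|s) * A], so minimizing it follows the policy gradient.
# The same construction for a toy fixed-variance Gaussian policy, with all names
# below hypothetical:
#
# import numpy as np
# import tensorflow as tf
#
# obs_ph = tf.placeholder(tf.float32, shape=[None, 3])
# act_ph = tf.placeholder(tf.float32, shape=[None])
# adv_ph = tf.placeholder(tf.float32, shape=[None])
#
# w = tf.Variable(tf.zeros([3]))                                  # linear Gaussian mean
# mean = tf.reduce_sum(obs_ph * w, axis=1)
# logli = -0.5 * tf.square(act_ph - mean) - 0.5 * np.log(2 * np.pi)  # unit variance
#
# # negated so that minimizing surr_obj ascends the policy-gradient direction
# surr_obj = -tf.reduce_mean(logli * adv_ph)
# train_op = tf.train.GradientDescentOptimizer(1e-2).minimize(surr_obj)
#
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     feed = {obs_ph: np.random.randn(8, 3),
#             act_ph: np.random.randn(8),
#             adv_ph: np.random.randn(8)}
#     print(sess.run([surr_obj, train_op], feed)[0])
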
def init_opt(self):
    is_recurrent = int(self.policy.recurrent)
    obs_var = self.env.observation_space.new_tensor_variable(
        'obs',
        extra_dims=1 + is_recurrent,
    )
    action_var = self.env.action_space.new_tensor_variable(
        'action',
        extra_dims=1 + is_recurrent,
    )
    advantage_var = tensor_utils.new_tensor(
        'advantage',
        ndim=1 + is_recurrent,
        dtype=tf.float32,
    )
    dist = self.policy.distribution

    # placeholders for the previous policy's distribution parameters; these feed
    # both the likelihood ratio and the KL trust-region constraint
    old_dist_info_vars = {
        k: tf.placeholder(tf.float32,
                          shape=[None] * (1 + is_recurrent) + list(shape),
                          name='old_%s' % k)
        for k, shape in dist.dist_info_specs
    }
    old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys]

    # extra inputs the policy needs to compute its distribution (e.g. recurrent state)
    state_info_vars = {
        k: tf.placeholder(tf.float32,
                          shape=[None] * (1 + is_recurrent) + list(shape),
                          name=k)
        for k, shape in self.policy.state_info_specs
    }
    state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys]

    if is_recurrent:
        valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid")
    else:
        valid_var = None

    dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
    kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
    lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars)

    if is_recurrent:
        mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
        surr_loss = -tf.reduce_sum(
            lr * advantage_var * valid_var) / tf.reduce_sum(valid_var)
    else:
        mean_kl = tf.reduce_mean(kl)
        surr_loss = -tf.reduce_mean(lr * advantage_var)

    input_list = [
        obs_var,
        action_var,
        advantage_var,
    ] + state_info_vars_list + old_dist_info_vars_list
    if is_recurrent:
        input_list.append(valid_var)

    self.optimizer.update_opt(
        loss=surr_loss,
        target=self.policy,
        leq_constraint=(mean_kl, self.step_size),
        inputs=input_list,
        constraint_name="mean_kl",
    )
    return dict()
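
# Illustrative sketch (not from this file): here the surrogate loss uses the
# likelihood ratio pi_new/pi_old instead of the log likelihood, and mean_kl is
# handed to the optimizer as the trust-region constraint via
# leq_constraint=(mean_kl, self.step_size).  The two quantities for a pair of
# unit-variance Gaussians (hypothetical names):
#
# import numpy as np
# import tensorflow as tf
#
# act_ph = tf.placeholder(tf.float32, [None])
# adv_ph = tf.placeholder(tf.float32, [None])
# old_mean_ph = tf.placeholder(tf.float32, [None])
# new_mean = tf.Variable(np.zeros(4, dtype=np.float32))
#
# logli_new = -0.5 * tf.square(act_ph - new_mean)
# logli_old = -0.5 * tf.square(act_ph - old_mean_ph)
# lr = tf.exp(logli_new - logli_old)                   # likelihood ratio
# surr_loss = -tf.reduce_mean(lr * adv_ph)
# mean_kl = tf.reduce_mean(0.5 * tf.square(new_mean - old_mean_ph))  # KL between unit-variance Gaussians
#
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     feed = {act_ph: np.random.randn(4).astype(np.float32),
#             adv_ph: np.ones(4, dtype=np.float32),
#             old_mean_ph: np.zeros(4, dtype=np.float32)}
#     print(sess.run([surr_loss, mean_kl], feed))
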