def _build_inference_loss(self, i):
    """Build loss function for the inference network."""
    infer_dist = self.inference._dist
    with tf.name_scope("infer_loss"):
        traj_ll_flat = self.inference.log_likelihood_sym(
            i.flat.trajectory_var, i.flat.latent_var, name="traj_ll_flat")
        traj_ll = tf.reshape(
            traj_ll_flat, [-1, self.max_path_length], name="traj_ll")

        # Calculate loss
        traj_gammas = tf.constant(
            float(self.discount),
            dtype=tf.float32,
            shape=[self.max_path_length])
        traj_discounts = tf.cumprod(
            traj_gammas, exclusive=True, name="traj_discounts")
        discount_traj_ll = traj_discounts * traj_ll
        discount_traj_ll_flat = flatten_batch(
            discount_traj_ll, name="discount_traj_ll_flat")
        discount_traj_ll_valid = filter_valids(
            discount_traj_ll_flat,
            i.flat.valid_var,
            name="discount_traj_ll_valid")

        with tf.name_scope("loss"):
            infer_loss = -tf.reduce_mean(
                discount_traj_ll_valid, name="infer_loss")

        with tf.name_scope("kl"):
            # Calculate predicted embedding distributions for each timestep
            infer_dist_info_flat = self.inference.dist_info_sym(
                i.flat.trajectory_var,
                i.flat.infer_state_info_vars,
                name="infer_dist_info_flat")
            infer_dist_info_valid = filter_valids_dict(
                infer_dist_info_flat,
                i.flat.valid_var,
                name="infer_dist_info_valid")

            # Calculate KL divergence
            kl = infer_dist.kl_sym(i.valid.infer_old_dist_info_vars,
                                   infer_dist_info_valid)
            infer_kl = tf.reduce_mean(kl, name="infer_kl")

    return infer_loss, infer_kl

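# A minimal NumPy sketch (illustrative, not part of the original module) of the
# discount weighting used in _build_inference_loss above: tf.cumprod with
# exclusive=True over a constant vector of gammas yields [1, g, g^2, ...], so
# each timestep's trajectory log-likelihood is weighted by gamma**t.
import numpy as np

def discount_weights(gamma, horizon):
    """Return [1, gamma, gamma**2, ...] of length `horizon` (exclusive cumprod)."""
    gammas = np.full(horizon, gamma, dtype=np.float32)
    # np.cumprod has no `exclusive` flag, so shift the products by one manually.
    return np.concatenate(([1.0], np.cumprod(gammas)[:-1])).astype(np.float32)

# Example: weight per-timestep log-likelihoods of one (hypothetical) trajectory.
traj_ll = np.array([-1.2, -0.8, -0.5, -0.3], dtype=np.float32)
weighted_ll = discount_weights(0.99, 4) * traj_ll  # weights == [1, .99, .9801, .970299]
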
def _build_policy_loss(self, i):
    """Build policy network loss."""
    pol_dist = self.policy._dist

    # Entropy terms
    embedding_entropy, inference_ce, policy_entropy = \
        self._build_entropy_terms(i)

    # Augment the path rewards with entropy terms
    with tf.name_scope("augmented_rewards"):
        rewards = i.reward_var \
                  - (self.inference_ce_coeff * inference_ce) \
                  + (self.policy_ent_coeff * policy_entropy)

    with tf.name_scope("policy_loss"):
        with tf.name_scope("advantages"):
            advantages = compute_advantages(
                self.discount,
                self.gae_lambda,
                self.max_path_length,
                i.baseline_var,
                rewards,
                name="advantages")

            # Flatten and filter valids
            adv_flat = flatten_batch(advantages, name="adv_flat")
            adv_valid = filter_valids(
                adv_flat, i.flat.valid_var, name="adv_valid")

        policy_dist_info_flat = self.policy.dist_info_sym(
            i.flat.task_var,
            i.flat.obs_var,
            i.flat.policy_state_info_vars,
            name="policy_dist_info_flat")
        policy_dist_info_valid = filter_valids_dict(
            policy_dist_info_flat,
            i.flat.valid_var,
            name="policy_dist_info_valid")

        # Optionally normalize advantages
        eps = tf.constant(1e-8, dtype=tf.float32)
        if self.center_adv:
            with tf.name_scope("center_adv"):
                mean, var = tf.nn.moments(adv_valid, axes=[0])
                adv_valid = tf.nn.batch_normalization(
                    adv_valid, mean, var, 0, 1, eps)
        if self.positive_adv:
            with tf.name_scope("positive_adv"):
                m = tf.reduce_min(adv_valid)
                adv_valid = (adv_valid - m) + eps

        # Calculate loss function and KL divergence
        with tf.name_scope("kl"):
            kl = pol_dist.kl_sym(
                i.valid.policy_old_dist_info_vars,
                policy_dist_info_valid,
            )
            pol_mean_kl = tf.reduce_mean(kl)

        # Calculate surrogate loss
        with tf.name_scope("surr_loss"):
            lr = pol_dist.likelihood_ratio_sym(
                i.valid.action_var,
                i.valid.policy_old_dist_info_vars,
                policy_dist_info_valid,
                name="lr")

            # Policy gradient surrogate objective
            surr_vanilla = lr * adv_valid

            if self._pg_loss == PGLoss.VANILLA:
                # VPG, TRPO use the standard surrogate objective
                surr_obj = tf.identity(surr_vanilla, name="surr_obj")
            elif self._pg_loss == PGLoss.CLIP:
                # PPO uses a surrogate objective with clipped LR
                lr_clip = tf.clip_by_value(
                    lr,
                    1 - self.lr_clip_range,
                    1 + self.lr_clip_range,
                    name="lr_clip")
                surr_clip = lr_clip * adv_valid
                surr_obj = tf.minimum(surr_vanilla, surr_clip,
                                      name="surr_obj")
            else:
                raise NotImplementedError("Unknown PGLoss")

            # Maximize E[surrogate objective] by minimizing
            # -E_t[surrogate objective]
            surr_loss = -tf.reduce_mean(surr_obj)

            # Embedding entropy bonus
            surr_loss -= self.embedding_ent_coeff * embedding_entropy

        embed_mean_kl = self._build_embedding_kl(i)

    # Diagnostic functions
    self.f_policy_kl = tensor_utils.compile_function(
        flatten_inputs(self._policy_opt_inputs),
        pol_mean_kl,
        log_name="f_policy_kl")

    self.f_rewards = tensor_utils.compile_function(
        flatten_inputs(self._policy_opt_inputs),
        rewards,
        log_name="f_rewards")

    # returns = self._build_returns(rewards)
    returns = discounted_returns(
        self.discount, self.max_path_length, rewards, name="returns")
    self.f_returns = tensor_utils.compile_function(
        flatten_inputs(self._policy_opt_inputs),
        returns,
        log_name="f_returns")

    return surr_loss, pol_mean_kl, embed_mean_kl

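# A minimal NumPy sketch (illustrative only) of the clipped surrogate objective
# assembled above for PGLoss.CLIP: the PPO-style objective takes the elementwise
# minimum of lr * A and clip(lr, 1 - eps, 1 + eps) * A before averaging, and the
# loss minimized is its negation.
import numpy as np

def ppo_surrogate(lr, adv, clip_range=0.2):
    """Per-sample clipped surrogate objective."""
    lr_clip = np.clip(lr, 1.0 - clip_range, 1.0 + clip_range)
    return np.minimum(lr * adv, lr_clip * adv)

lr = np.array([0.5, 1.0, 1.5], dtype=np.float32)    # hypothetical likelihood ratios
adv = np.array([1.0, -2.0, 3.0], dtype=np.float32)  # hypothetical advantages
loss = -ppo_surrogate(lr, adv).mean()
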
def _build_policy_loss(self, i):
    pol_dist = self.policy.distribution

    policy_entropy = self._build_entropy_term(i)

    with tf.name_scope("augmented_rewards"):
        rewards = i.reward_var + (self.policy_ent_coeff * policy_entropy)

    with tf.name_scope("policy_loss"):
        advantages = compute_advantages(
            self.discount,
            self.gae_lambda,
            self.max_path_length,
            i.baseline_var,
            rewards,
            name="advantages")

        adv_flat = flatten_batch(advantages, name="adv_flat")
        adv_valid = filter_valids(
            adv_flat, i.flat.valid_var, name="adv_valid")

        if self.policy.recurrent:
            advantages = tf.reshape(advantages, [-1, self.max_path_length])

        # Optionally normalize advantages
        eps = tf.constant(1e-8, dtype=tf.float32)
        if self.center_adv:
            with tf.name_scope("center_adv"):
                mean, var = tf.nn.moments(adv_valid, axes=[0])
                adv_valid = tf.nn.batch_normalization(
                    adv_valid, mean, var, 0, 1, eps)
        if self.positive_adv:
            with tf.name_scope("positive_adv"):
                m = tf.reduce_min(adv_valid)
                adv_valid = (adv_valid - m) + eps

        if self.policy.recurrent:
            policy_dist_info = self.policy.dist_info_sym(
                i.obs_var,
                i.policy_state_info_vars,
                name="policy_dist_info")
        else:
            policy_dist_info_flat = self.policy.dist_info_sym(
                i.flat.obs_var,
                i.flat.policy_state_info_vars,
                name="policy_dist_info_flat")
            policy_dist_info_valid = filter_valids_dict(
                policy_dist_info_flat,
                i.flat.valid_var,
                name="policy_dist_info_valid")

        # Calculate loss function and KL divergence
        with tf.name_scope("kl"):
            if self.policy.recurrent:
                kl = pol_dist.kl_sym(
                    i.policy_old_dist_info_vars,
                    policy_dist_info,
                )
                pol_mean_kl = tf.reduce_sum(
                    kl * i.valid_var) / tf.reduce_sum(i.valid_var)
            else:
                kl = pol_dist.kl_sym(
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                )
                pol_mean_kl = tf.reduce_mean(kl)

        # Calculate vanilla loss
        with tf.name_scope("vanilla_loss"):
            if self.policy.recurrent:
                ll = pol_dist.log_likelihood_sym(
                    i.action_var, policy_dist_info, name="log_likelihood")
                vanilla = ll * advantages * i.valid_var
            else:
                ll = pol_dist.log_likelihood_sym(
                    i.valid.action_var,
                    policy_dist_info_valid,
                    name="log_likelihood")
                vanilla = ll * adv_valid

        # Calculate surrogate loss
        with tf.name_scope("surrogate_loss"):
            if self.policy.recurrent:
                lr = pol_dist.likelihood_ratio_sym(
                    i.action_var,
                    i.policy_old_dist_info_vars,
                    policy_dist_info,
                    name="lr")
                surrogate = lr * advantages * i.valid_var
            else:
                lr = pol_dist.likelihood_ratio_sym(
                    i.valid.action_var,
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                    name="lr")
                surrogate = lr * adv_valid

        # Finalize objective function
        with tf.name_scope("loss"):
            if self._pg_loss == PGLoss.VANILLA:
                # VPG uses the vanilla objective
                obj = tf.identity(vanilla, name="vanilla_obj")
            elif self._pg_loss == PGLoss.SURROGATE:
                # TRPO uses the standard surrogate objective
                obj = tf.identity(surrogate, name="surr_obj")
            elif self._pg_loss == PGLoss.SURROGATE_CLIP:
                lr_clip = tf.clip_by_value(
                    lr,
                    1 - self.lr_clip_range,
                    1 + self.lr_clip_range,
                    name="lr_clip")
                if self.policy.recurrent:
                    surr_clip = lr_clip * advantages * i.valid_var
                else:
                    surr_clip = lr_clip * adv_valid
                obj = tf.minimum(surrogate, surr_clip, name="surr_obj")
            else:
                raise NotImplementedError("Unknown PGLoss")

            # Maximize E[surrogate objective] by minimizing
            # -E_t[surrogate objective]
            if self.policy.recurrent:
                loss = -tf.reduce_sum(obj) / tf.reduce_sum(i.valid_var)
            else:
                loss = -tf.reduce_mean(obj)

    # Diagnostic functions
    self.f_policy_kl = tensor_utils.compile_function(
        flatten_inputs(self._policy_opt_inputs),
        pol_mean_kl,
        log_name="f_policy_kl")

    self.f_rewards = tensor_utils.compile_function(
        flatten_inputs(self._policy_opt_inputs),
        rewards,
        log_name="f_rewards")

    returns = discounted_returns(self.discount, self.max_path_length,
                                 rewards)
    self.f_returns = tensor_utils.compile_function(
        flatten_inputs(self._policy_opt_inputs),
        returns,
        log_name="f_returns")

    return loss, pol_mean_kl

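# A minimal NumPy sketch (illustrative, not the framework's compute_advantages)
# of generalized advantage estimation, GAE(gamma, lambda), which the calls above
# rely on: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), and
# A_t = sum_l (gamma * lambda)^l * delta_{t+l}, computed here for a single path.
import numpy as np

def gae_single_path(rewards, baselines, discount, gae_lambda):
    """Advantages for one path; assumes V = 0 past the final timestep."""
    baselines_next = np.append(baselines[1:], 0.0)
    deltas = rewards + discount * baselines_next - baselines
    advantages = np.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = deltas[t] + discount * gae_lambda * running
        advantages[t] = running
    return advantages

# Example with hypothetical rewards and baseline predictions.
adv = gae_single_path(
    rewards=np.array([1.0, 0.0, 1.0], dtype=np.float32),
    baselines=np.array([0.5, 0.4, 0.3], dtype=np.float32),
    discount=0.99,
    gae_lambda=0.97)
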
def _build_inputs(self):
    """Build input variables (and trivial views thereof) for the loss
    function network.
    """
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space
    task_space = self.policy.task_space
    latent_space = self.policy.latent_space
    trajectory_space = self.inference.input_space

    policy_dist = self.policy._dist
    embed_dist = self.policy.embedding._dist
    infer_dist = self.inference._dist

    with tf.name_scope("inputs"):
        obs_var = observation_space.new_tensor_variable(
            'obs',
            extra_dims=1 + 1,
        )
        task_var = task_space.new_tensor_variable(
            'task',
            extra_dims=1 + 1,
        )
        action_var = action_space.new_tensor_variable(
            'action',
            extra_dims=1 + 1,
        )
        reward_var = tensor_utils.new_tensor(
            'reward',
            ndim=1 + 1,
            dtype=tf.float32,
        )
        latent_var = latent_space.new_tensor_variable(
            'latent',
            extra_dims=1 + 1,
        )
        baseline_var = tensor_utils.new_tensor(
            'baseline',
            ndim=1 + 1,
            dtype=tf.float32,
        )
        trajectory_var = trajectory_space.new_tensor_variable(
            'trajectory',
            extra_dims=1 + 1,
        )
        valid_var = tf.placeholder(
            tf.float32, shape=[None, None], name="valid")

        # Policy state (for RNNs)
        policy_state_info_vars = {
            k: tf.placeholder(
                tf.float32, shape=[None] * (1 + 1) + list(shape), name=k)
            for k, shape in self.policy.state_info_specs
        }
        policy_state_info_vars_list = [
            policy_state_info_vars[k] for k in self.policy.state_info_keys
        ]

        # Old policy distribution (for KL)
        policy_old_dist_info_vars = {
            k: tf.placeholder(
                tf.float32,
                shape=[None] * (1 + 1) + list(shape),
                name='policy_old_%s' % k)
            for k, shape in policy_dist.dist_info_specs
        }
        policy_old_dist_info_vars_list = [
            policy_old_dist_info_vars[k]
            for k in policy_dist.dist_info_keys
        ]

        # Embedding state (for RNNs)
        embed_state_info_vars = {
            k: tf.placeholder(
                tf.float32,
                shape=[None] * (1 + 1) + list(shape),
                name='embed_%s' % k)
            for k, shape in self.policy.embedding.state_info_specs
        }
        embed_state_info_vars_list = [
            embed_state_info_vars[k]
            for k in self.policy.embedding.state_info_keys
        ]

        # Old embedding distribution (for KL)
        embed_old_dist_info_vars = {
            k: tf.placeholder(
                tf.float32,
                shape=[None] * (1 + 1) + list(shape),
                name='embed_old_%s' % k)
            for k, shape in embed_dist.dist_info_specs
        }
        embed_old_dist_info_vars_list = [
            embed_old_dist_info_vars[k]
            for k in embed_dist.dist_info_keys
        ]

        # Inference distribution state (for RNNs)
        infer_state_info_vars = {
            k: tf.placeholder(
                tf.float32,
                shape=[None] * (1 + 1) + list(shape),
                name='infer_%s' % k)
            for k, shape in self.inference.state_info_specs
        }
        infer_state_info_vars_list = [
            infer_state_info_vars[k]
            for k in self.inference.state_info_keys
        ]

        # Old inference distribution (for KL)
        infer_old_dist_info_vars = {
            k: tf.placeholder(
                tf.float32,
                shape=[None] * (1 + 1) + list(shape),
                name='infer_old_%s' % k)
            for k, shape in infer_dist.dist_info_specs
        }
        infer_old_dist_info_vars_list = [
            infer_old_dist_info_vars[k]
            for k in infer_dist.dist_info_keys
        ]

        # Flattened view
        with tf.name_scope("flat"):
            obs_flat = flatten_batch(obs_var, name="obs_flat")
            task_flat = flatten_batch(task_var, name="task_flat")
            action_flat = flatten_batch(action_var, name="action_flat")
            reward_flat = flatten_batch(reward_var, name="reward_flat")
            latent_flat = flatten_batch(latent_var, name="latent_flat")
            trajectory_flat = flatten_batch(
                trajectory_var, name="trajectory_flat")
            valid_flat = flatten_batch(valid_var, name="valid_flat")
            policy_state_info_vars_flat = flatten_batch_dict(
                policy_state_info_vars, name="policy_state_info_vars_flat")
            policy_old_dist_info_vars_flat = flatten_batch_dict(
                policy_old_dist_info_vars,
                name="policy_old_dist_info_vars_flat")
            embed_state_info_vars_flat = flatten_batch_dict(
                embed_state_info_vars,
                name="embed_state_info_vars_flat")
            embed_old_dist_info_vars_flat = flatten_batch_dict(
                embed_old_dist_info_vars,
                name="embed_old_dist_info_vars_flat")
            infer_state_info_vars_flat = flatten_batch_dict(
                infer_state_info_vars,
                name="infer_state_info_vars_flat")
            infer_old_dist_info_vars_flat = flatten_batch_dict(
                infer_old_dist_info_vars,
                name="infer_old_dist_info_vars_flat")

        # Valid view
        with tf.name_scope("valid"):
            action_valid = filter_valids(
                action_flat, valid_flat, name="action_valid")
            policy_state_info_vars_valid = filter_valids_dict(
                policy_state_info_vars_flat,
                valid_flat,
                name="policy_state_info_vars_valid")
            policy_old_dist_info_vars_valid = filter_valids_dict(
                policy_old_dist_info_vars_flat,
                valid_flat,
                name="policy_old_dist_info_vars_valid")
            embed_old_dist_info_vars_valid = filter_valids_dict(
                embed_old_dist_info_vars_flat,
                valid_flat,
                name="embed_old_dist_info_vars_valid")
            infer_old_dist_info_vars_valid = filter_valids_dict(
                infer_old_dist_info_vars_flat,
                valid_flat,
                name="infer_old_dist_info_vars_valid")

    # Policy and embedding network loss and optimizer inputs
    pol_flat = graph_inputs(
        "PolicyLossInputsFlat",
        obs_var=obs_flat,
        task_var=task_flat,
        action_var=action_flat,
        reward_var=reward_flat,
        latent_var=latent_flat,
        trajectory_var=trajectory_flat,
        valid_var=valid_flat,
        policy_state_info_vars=policy_state_info_vars_flat,
        policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
        embed_state_info_vars=embed_state_info_vars_flat,
        embed_old_dist_info_vars=embed_old_dist_info_vars_flat,
    )
    pol_valid = graph_inputs(
        "PolicyLossInputsValid",
        action_var=action_valid,
        policy_state_info_vars=policy_state_info_vars_valid,
        policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
        embed_old_dist_info_vars=embed_old_dist_info_vars_valid,
    )
    policy_loss_inputs = graph_inputs(
        "PolicyLossInputs",
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        trajectory_var=trajectory_var,
        task_var=task_var,
        latent_var=latent_var,
        valid_var=valid_var,
        policy_state_info_vars=policy_state_info_vars,
        policy_old_dist_info_vars=policy_old_dist_info_vars,
        embed_state_info_vars=embed_state_info_vars,
        embed_old_dist_info_vars=embed_old_dist_info_vars,
        flat=pol_flat,
        valid=pol_valid,
    )
    # Special variant for the optimizer
    # * Uses lists instead of dicts for the distribution parameters
    # * Omits flats and valids
    # TODO: eliminate
    policy_opt_inputs = graph_inputs(
        "PolicyOptInputs",
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        trajectory_var=trajectory_var,
        task_var=task_var,
        latent_var=latent_var,
        valid_var=valid_var,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
        embed_state_info_vars_list=embed_state_info_vars_list,
        embed_old_dist_info_vars_list=embed_old_dist_info_vars_list,
    )

    # Inference network loss and optimizer inputs
    infer_flat = graph_inputs(
        "InferenceLossInputsFlat",
        latent_var=latent_flat,
        trajectory_var=trajectory_flat,
        valid_var=valid_flat,
        infer_state_info_vars=infer_state_info_vars_flat,
        infer_old_dist_info_vars=infer_old_dist_info_vars_flat,
    )
    infer_valid = graph_inputs(
        "InferenceLossInputsValid",
        infer_old_dist_info_vars=infer_old_dist_info_vars_valid,
    )
    inference_loss_inputs = graph_inputs(
        "InferenceLossInputs",
        latent_var=latent_var,
        trajectory_var=trajectory_var,
        valid_var=valid_var,
        infer_state_info_vars=infer_state_info_vars,
        infer_old_dist_info_vars=infer_old_dist_info_vars,
        flat=infer_flat,
        valid=infer_valid,
    )
    # Special variant for the optimizer
    # * Uses lists instead of dicts for the distribution parameters
    # * Omits flats and valids
    # TODO: eliminate
    inference_opt_inputs = graph_inputs(
        "InferenceOptInputs",
        latent_var=latent_var,
        trajectory_var=trajectory_var,
        valid_var=valid_var,
        infer_state_info_vars_list=infer_state_info_vars_list,
        infer_old_dist_info_vars_list=infer_old_dist_info_vars_list,
    )

    return (policy_loss_inputs, policy_opt_inputs, inference_loss_inputs,
            inference_opt_inputs)

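# A minimal NumPy sketch (illustrative, not the framework's flatten_batch /
# filter_valids) of the "flat" and "valid" views built above: a [N, T, ...]
# batch of padded paths is flattened to [N*T, ...], then the valid mask drops
# padded timesteps so losses are averaged only over real samples.
import numpy as np

def flatten_batch_np(x):
    """[N, T, ...] -> [N*T, ...]."""
    return x.reshape((-1,) + x.shape[2:])

def filter_valids_np(x_flat, valid_flat):
    """Keep only rows whose valid flag is 1."""
    return x_flat[valid_flat.astype(bool)]

obs = np.zeros((2, 3, 4), dtype=np.float32)            # 2 paths, horizon 3, obs dim 4
valid = np.array([[1, 1, 0], [1, 0, 0]], np.float32)   # zeros mark padded timesteps
obs_valid = filter_valids_np(flatten_batch_np(obs), flatten_batch_np(valid))
assert obs_valid.shape == (3, 4)                       # only 3 real timesteps remain
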
def _build_inputs(self):
    """Declare graph input variables."""
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space
    policy_dist = self.policy.distribution

    with tf.name_scope('inputs'):
        obs_var = observation_space.to_tf_placeholder(
            name='obs', batch_dims=2)  # yapf: disable
        action_var = action_space.to_tf_placeholder(
            name='action', batch_dims=2)  # yapf: disable
        reward_var = tensor_utils.new_tensor(
            name='reward', ndim=2, dtype=tf.float32)  # yapf: disable
        valid_var = tensor_utils.new_tensor(
            name='valid', ndim=2, dtype=tf.float32)  # yapf: disable
        feat_diff = tensor_utils.new_tensor(
            name='feat_diff', ndim=2, dtype=tf.float32)  # yapf: disable
        param_v = tensor_utils.new_tensor(
            name='param_v', ndim=1, dtype=tf.float32)  # yapf: disable
        param_eta = tensor_utils.new_tensor(
            name='param_eta', ndim=0, dtype=tf.float32)  # yapf: disable

        policy_state_info_vars = {
            k: tf.placeholder(
                tf.float32, shape=[None] * 2 + list(shape), name=k)
            for k, shape in self.policy.state_info_specs
        }  # yapf: disable
        policy_state_info_vars_list = [
            policy_state_info_vars[k]
            for k in self.policy.state_info_keys
        ]  # yapf: disable

        policy_old_dist_info_vars = {
            k: tf.placeholder(
                tf.float32,
                shape=[None] * 2 + list(shape),
                name='policy_old_%s' % k)
            for k, shape in policy_dist.dist_info_specs
        }
        policy_old_dist_info_vars_list = [
            policy_old_dist_info_vars[k]
            for k in policy_dist.dist_info_keys
        ]

        with tf.name_scope('flat'):
            obs_flat = flatten_batch(obs_var, name='obs_flat')
            action_flat = flatten_batch(action_var, name='action_flat')
            reward_flat = flatten_batch(reward_var, name='reward_flat')
            valid_flat = flatten_batch(valid_var, name='valid_flat')
            feat_diff_flat = flatten_batch(
                feat_diff, name='feat_diff_flat')  # yapf: disable
            policy_state_info_vars_flat = flatten_batch_dict(
                policy_state_info_vars,
                name='policy_state_info_vars_flat')  # yapf: disable
            policy_old_dist_info_vars_flat = flatten_batch_dict(
                policy_old_dist_info_vars,
                name='policy_old_dist_info_vars_flat')

        with tf.name_scope('valid'):
            reward_valid = filter_valids(
                reward_flat, valid_flat, name='reward_valid')  # yapf: disable
            action_valid = filter_valids(
                action_flat, valid_flat, name='action_valid')  # yapf: disable
            policy_state_info_vars_valid = filter_valids_dict(
                policy_state_info_vars_flat,
                valid_flat,
                name='policy_state_info_vars_valid')
            policy_old_dist_info_vars_valid = filter_valids_dict(
                policy_old_dist_info_vars_flat,
                valid_flat,
                name='policy_old_dist_info_vars_valid')

    pol_flat = graph_inputs(
        'PolicyLossInputsFlat',
        obs_var=obs_flat,
        action_var=action_flat,
        reward_var=reward_flat,
        valid_var=valid_flat,
        feat_diff=feat_diff_flat,
        policy_state_info_vars=policy_state_info_vars_flat,
        policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
    )
    pol_valid = graph_inputs(
        'PolicyLossInputsValid',
        reward_var=reward_valid,
        action_var=action_valid,
        policy_state_info_vars=policy_state_info_vars_valid,
        policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
    )
    policy_loss_inputs = graph_inputs(
        'PolicyLossInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        valid_var=valid_var,
        feat_diff=feat_diff,
        param_eta=param_eta,
        param_v=param_v,
        policy_state_info_vars=policy_state_info_vars,
        policy_old_dist_info_vars=policy_old_dist_info_vars,
        flat=pol_flat,
        valid=pol_valid,
    )
    policy_opt_inputs = graph_inputs(
        'PolicyOptInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        valid_var=valid_var,
        feat_diff=feat_diff,
        param_eta=param_eta,
        param_v=param_v,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
    )
    dual_opt_inputs = graph_inputs(
        'DualOptInputs',
        reward_var=reward_var,
        valid_var=valid_var,
        feat_diff=feat_diff,
        param_eta=param_eta,
        param_v=param_v,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
    )

    return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs

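# A minimal NumPy sketch (illustrative; the textbook REPS dual, not necessarily
# this module's exact implementation) showing why the inputs above include
# feat_diff, param_eta and param_v: the dual optimizer minimizes
# g(eta, v) = eta * epsilon + eta * log( mean( exp((r + feat_diff @ v) / eta) ) ).
import numpy as np

def reps_dual(eta, v, rewards, feat_diff, epsilon):
    """Standard REPS dual value for given dual variables eta and v."""
    delta_v = rewards + feat_diff.dot(v)       # per-sample Bellman-style error
    scaled = delta_v / eta
    shift = np.max(scaled)                     # stabilize the log-mean-exp
    return eta * epsilon + eta * (np.log(np.mean(np.exp(scaled - shift))) + shift)

# Example with hypothetical values.
g = reps_dual(
    eta=1.0,
    v=np.zeros(2),
    rewards=np.array([1.0, 0.5, 0.0]),
    feat_diff=np.zeros((3, 2)),
    epsilon=0.1)
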
def _build_inputs(self):
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space
    policy_dist = self.policy.distribution

    with tf.name_scope("inputs"):
        obs_var = observation_space.new_tensor_variable(
            name="obs", extra_dims=2)
        action_var = action_space.new_tensor_variable(
            name="action", extra_dims=2)
        reward_var = tensor_utils.new_tensor(
            name="reward", ndim=2, dtype=tf.float32)
        valid_var = tf.placeholder(
            tf.float32, shape=[None, None], name="valid")
        baseline_var = tensor_utils.new_tensor(
            name="baseline", ndim=2, dtype=tf.float32)

        policy_state_info_vars = {
            k: tf.placeholder(
                tf.float32, shape=[None] * 2 + list(shape), name=k)
            for k, shape in self.policy.state_info_specs
        }
        policy_state_info_vars_list = [
            policy_state_info_vars[k] for k in self.policy.state_info_keys
        ]

        # old policy distribution
        policy_old_dist_info_vars = {
            k: tf.placeholder(
                tf.float32,
                shape=[None] * 2 + list(shape),
                name="policy_old_%s" % k)
            for k, shape in policy_dist.dist_info_specs
        }
        policy_old_dist_info_vars_list = [
            policy_old_dist_info_vars[k]
            for k in policy_dist.dist_info_keys
        ]

        # flattened view
        with tf.name_scope("flat"):
            obs_flat = flatten_batch(obs_var, name="obs_flat")
            action_flat = flatten_batch(action_var, name="action_flat")
            reward_flat = flatten_batch(reward_var, name="reward_flat")
            valid_flat = flatten_batch(valid_var, name="valid_flat")
            policy_state_info_vars_flat = flatten_batch_dict(
                policy_state_info_vars, name="policy_state_info_vars_flat")
            policy_old_dist_info_vars_flat = flatten_batch_dict(
                policy_old_dist_info_vars,
                name="policy_old_dist_info_vars_flat")

        # valid view
        with tf.name_scope("valid"):
            action_valid = filter_valids(
                action_flat, valid_flat, name="action_valid")
            policy_state_info_vars_valid = filter_valids_dict(
                policy_state_info_vars_flat,
                valid_flat,
                name="policy_state_info_vars_valid")
            policy_old_dist_info_vars_valid = filter_valids_dict(
                policy_old_dist_info_vars_flat,
                valid_flat,
                name="policy_old_dist_info_vars_valid")

    # policy loss and optimizer inputs
    pol_flat = graph_inputs(
        "PolicyLossInputsFlat",
        obs_var=obs_flat,
        action_var=action_flat,
        reward_var=reward_flat,
        valid_var=valid_flat,
        policy_state_info_vars=policy_state_info_vars_flat,
        policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
    )
    pol_valid = graph_inputs(
        "PolicyLossInputsValid",
        action_var=action_valid,
        policy_state_info_vars=policy_state_info_vars_valid,
        policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
    )
    policy_loss_inputs = graph_inputs(
        "PolicyLossInputs",
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars=policy_state_info_vars,
        policy_old_dist_info_vars=policy_old_dist_info_vars,
        flat=pol_flat,
        valid=pol_valid,
    )
    policy_opt_inputs = graph_inputs(
        "PolicyOptInputs",
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
    )

    return policy_loss_inputs, policy_opt_inputs

def _build_graph(self):
    self.policy = self.policy_cls(**self.policy_args,
                                  name='ddopg_model_policy')

    self.var_list = self.policy.get_params(trainable=True)
    self.var_shapes = [var.shape for var in self.var_list]
    self.n_params = sum(
        [shape.num_elements() for shape in self.var_shapes])
    self.policy_dist = self.policy.distribution
    self.policy_params_shapes = [
        param.shape for param in self.policy.get_params(trainable=True)
    ]

    observation_space = self.policy.observation_space
    action_space = self.policy.action_space

    self.obs_var = observation_space.new_tensor_variable(
        name='obs', extra_dims=2)  # (n_paths, H, Dy)
    self.action_var = action_space.new_tensor_variable(
        name='action', extra_dims=2)  # (n_paths, H, Du)
    self.path_return_var = tf.placeholder(
        tf.float32, [None], 'path_return')  # (n_paths, )
    self.valid_var = tf.placeholder(
        tf.float32, [None, None], 'valid')  # (n_paths, H)
    self.train_logps = tf.placeholder(
        dtype=tf.float32, shape=[None, None],
        name='train_logps_pre')  # (n_paths, n_train)
    self.log_std_var = tf.placeholder(
        dtype=tf.float32, shape=[], name='log_std')
    self.delta_var = tf.placeholder(
        dtype=tf.float32, shape=[], name='delta')

    self.input_vars = [
        self.obs_var, self.action_var, self.path_return_var, self.valid_var,
        self.train_logps, self.delta_var, self.log_std_var
    ]

    # Flatten observations and actions for vectorized computations
    self.obs_flat = flatten_batch(
        self.obs_var, name='obs_flat')  # (n_paths * H, Dy)
    self.action_flat = flatten_batch(
        self.action_var, name='action_flat')  # (n_paths * H, Du)

    # Shape of training data: (# of paths, path horizon) = (N_train, H)
    self.batch_shape = tf.shape(self.obs_var)[0:2]

    # Compute logps of the sampled paths under the test policy
    dist_info_flat = self.policy.dist_info_sym(
        self.obs_flat, name='dist_info_flat')
    dist_info_flat['log_std'] = self.log_std_var * tf.ones_like(
        dist_info_flat['mean'])
    test_logp_flat = self.policy_dist.log_likelihood_sym(
        self.action_flat, dist_info_flat, name='logp_flat')
    test_logp_full = tf.reshape(
        test_logp_flat, self.batch_shape)  # (n_paths, H)
    self.test_logps = tf.reduce_sum(
        test_logp_full * self.valid_var, axis=1)[None, :]

    self.all_logps = tf.concat(
        (self.train_logps, self.test_logps),
        axis=0)  # (n_train + n_test, n_paths)

    # Prevent exp() overflow by shifting logps
    self.logp_max = tf.reduce_max(self.all_logps, axis=0)  # (n_paths, )
    self.train_logps_0 = self.train_logps - self.logp_max  # (n_train, n_paths)
    self.test_logps_0 = self.test_logps - self.logp_max  # (n_paths, )

    self.train_liks = tf.exp(self.train_logps_0)  # (n_train, n_paths)
    self.test_liks = tf.exp(self.test_logps_0)  # (n_paths, )

    # Mean traj lik for empirical mixture distribution
    self.train_mean_liks = tf.reduce_mean(
        self.train_liks, axis=0) + self.eps  # (n_paths, )

    # Compute prediction for all training policies
    train_res = self._compute_prediction_vec(self.train_liks)
    self.J_train = train_res[0]
    self.J2_train = train_res[1]
    self.J_var_train = train_res[2]
    self.J_unc_train = train_res[3]
    self.w_train = train_res[4]
    self.wn_train = train_res[5]
    self.ess_train = train_res[6]

    # Compute prediction for all test policies
    test_res = self._compute_prediction_vec(self.test_liks)
    self.J_test = test_res[0]
    self.J2_test = test_res[1]
    self.J_var_test = test_res[2]
    self.J_unc_test = test_res[3]
    self.w_test = test_res[4]
    self.wn_test = test_res[5]
    self.ess_test = test_res[6]

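# A minimal NumPy sketch (illustrative only) of the overflow guard used in
# _build_graph above: subtracting the per-path maximum log-probability before
# exponentiating leaves any ratio of likelihoods (e.g. importance weights
# w = p_test / mean(p_train)) unchanged, because the shared factor
# exp(logp_max) cancels, while avoiding exp() under/overflow.
import numpy as np

logps = np.array([[-1000.0, -3.0],
                  [-1001.0, -2.0],
                  [-999.0, -4.0]])           # hypothetical (n_policies, n_paths) path logps
logp_max = logps.max(axis=0)                 # per-path shift
liks = np.exp(logps - logp_max)              # largest entry per path is exp(0) = 1
weights = liks[0] / liks[1:].mean(axis=0)    # same ratios as without the shift,
                                             # but exp(-1000) would underflow to 0
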
def _build_inputs(self):
    """Declare graph input variables."""
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space
    policy_dist = self.policy.distribution

    with tf.name_scope("inputs"):
        obs_var = observation_space.new_tensor_variable(
            name="obs", extra_dims=2)  # yapf: disable
        action_var = action_space.new_tensor_variable(
            name="action", extra_dims=2)  # yapf: disable
        reward_var = tensor_utils.new_tensor(
            name="reward", ndim=2, dtype=tf.float32)  # yapf: disable
        valid_var = tensor_utils.new_tensor(
            name="valid", ndim=2, dtype=tf.float32)  # yapf: disable
        feat_diff = tensor_utils.new_tensor(
            name="feat_diff", ndim=2, dtype=tf.float32)  # yapf: disable
        param_v = tensor_utils.new_tensor(
            name="param_v", ndim=1, dtype=tf.float32)  # yapf: disable
        param_eta = tensor_utils.new_tensor(
            name="param_eta", ndim=0, dtype=tf.float32)  # yapf: disable

        policy_state_info_vars = {
            k: tf.placeholder(
                tf.float32, shape=[None] * 2 + list(shape), name=k)
            for k, shape in self.policy.state_info_specs
        }  # yapf: disable
        policy_state_info_vars_list = [
            policy_state_info_vars[k]
            for k in self.policy.state_info_keys
        ]  # yapf: disable

        policy_old_dist_info_vars = {
            k: tf.placeholder(
                tf.float32,
                shape=[None] * 2 + list(shape),
                name="policy_old_%s" % k)
            for k, shape in policy_dist.dist_info_specs
        }
        policy_old_dist_info_vars_list = [
            policy_old_dist_info_vars[k]
            for k in policy_dist.dist_info_keys
        ]

        with tf.name_scope("flat"):
            obs_flat = flatten_batch(obs_var, name="obs_flat")
            action_flat = flatten_batch(action_var, name="action_flat")
            reward_flat = flatten_batch(reward_var, name="reward_flat")
            valid_flat = flatten_batch(valid_var, name="valid_flat")
            feat_diff_flat = flatten_batch(
                feat_diff, name="feat_diff_flat")  # yapf: disable
            policy_state_info_vars_flat = flatten_batch_dict(
                policy_state_info_vars,
                name="policy_state_info_vars_flat")  # yapf: disable
            policy_old_dist_info_vars_flat = flatten_batch_dict(
                policy_old_dist_info_vars,
                name="policy_old_dist_info_vars_flat")

        with tf.name_scope("valid"):
            reward_valid = filter_valids(
                reward_flat, valid_flat, name="reward_valid")  # yapf: disable
            action_valid = filter_valids(
                action_flat, valid_flat, name="action_valid")  # yapf: disable
            policy_state_info_vars_valid = filter_valids_dict(
                policy_state_info_vars_flat,
                valid_flat,
                name="policy_state_info_vars_valid")
            policy_old_dist_info_vars_valid = filter_valids_dict(
                policy_old_dist_info_vars_flat,
                valid_flat,
                name="policy_old_dist_info_vars_valid")

    pol_flat = graph_inputs(
        "PolicyLossInputsFlat",
        obs_var=obs_flat,
        action_var=action_flat,
        reward_var=reward_flat,
        valid_var=valid_flat,
        feat_diff=feat_diff_flat,
        policy_state_info_vars=policy_state_info_vars_flat,
        policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
    )
    pol_valid = graph_inputs(
        "PolicyLossInputsValid",
        reward_var=reward_valid,
        action_var=action_valid,
        policy_state_info_vars=policy_state_info_vars_valid,
        policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
    )
    policy_loss_inputs = graph_inputs(
        "PolicyLossInputs",
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        valid_var=valid_var,
        feat_diff=feat_diff,
        param_eta=param_eta,
        param_v=param_v,
        policy_state_info_vars=policy_state_info_vars,
        policy_old_dist_info_vars=policy_old_dist_info_vars,
        flat=pol_flat,
        valid=pol_valid,
    )
    policy_opt_inputs = graph_inputs(
        "PolicyOptInputs",
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        valid_var=valid_var,
        feat_diff=feat_diff,
        param_eta=param_eta,
        param_v=param_v,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
    )
    dual_opt_inputs = graph_inputs(
        "DualOptInputs",
        reward_var=reward_var,
        valid_var=valid_var,
        feat_diff=feat_diff,
        param_eta=param_eta,
        param_v=param_v,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
    )

    return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs

def _build_policy_loss(self, i):
    pol_dist = self.policy.distribution
    policy_entropy = self._build_entropy_term(i)
    rewards = i.reward_var

    if self._maximum_entropy:
        with tf.name_scope('augmented_rewards'):
            rewards = i.reward_var + self.policy_ent_coeff * policy_entropy

    with tf.name_scope('policy_loss'):
        adv = compute_advantages(
            self.discount,
            self.gae_lambda,
            self.max_path_length,
            i.baseline_var,
            rewards,
            name='adv')

        adv_flat = flatten_batch(adv, name='adv_flat')
        adv_valid = filter_valids(
            adv_flat, i.flat.valid_var, name='adv_valid')

        if self.policy.recurrent:
            adv = tf.reshape(adv, [-1, self.max_path_length])

        # Optionally normalize advantages
        eps = tf.constant(1e-8, dtype=tf.float32)
        if self.center_adv:
            if self.policy.recurrent:
                adv = center_advs(adv, axes=[0], eps=eps)
            else:
                adv_valid = center_advs(adv_valid, axes=[0], eps=eps)

        if self.positive_adv:
            if self.policy.recurrent:
                adv = positive_advs(adv, eps)
            else:
                adv_valid = positive_advs(adv_valid, eps)

        if self.policy.recurrent:
            policy_dist_info = self.policy.dist_info_sym(
                i.obs_var,
                i.policy_state_info_vars,
                name='policy_dist_info')
        else:
            policy_dist_info_flat = self.policy.dist_info_sym(
                i.flat.obs_var,
                i.flat.policy_state_info_vars,
                name='policy_dist_info_flat')
            policy_dist_info_valid = filter_valids_dict(
                policy_dist_info_flat,
                i.flat.valid_var,
                name='policy_dist_info_valid')
            policy_dist_info = policy_dist_info_valid

        # Calculate loss function and KL divergence
        with tf.name_scope('kl'):
            if self.policy.recurrent:
                kl = pol_dist.kl_sym(
                    i.policy_old_dist_info_vars,
                    policy_dist_info,
                )
                pol_mean_kl = tf.reduce_sum(
                    kl * i.valid_var) / tf.reduce_sum(i.valid_var)
            else:
                kl = pol_dist.kl_sym(
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                )
                pol_mean_kl = tf.reduce_mean(kl)

        # Calculate vanilla loss
        with tf.name_scope('vanilla_loss'):
            if self.policy.recurrent:
                ll = pol_dist.log_likelihood_sym(
                    i.action_var, policy_dist_info, name='log_likelihood')
                vanilla = ll * adv * i.valid_var
            else:
                ll = pol_dist.log_likelihood_sym(
                    i.valid.action_var,
                    policy_dist_info_valid,
                    name='log_likelihood')
                vanilla = ll * adv_valid

        # Calculate surrogate loss
        with tf.name_scope('surrogate_loss'):
            if self.policy.recurrent:
                lr = pol_dist.likelihood_ratio_sym(
                    i.action_var,
                    i.policy_old_dist_info_vars,
                    policy_dist_info,
                    name='lr')
                surrogate = lr * adv * i.valid_var
            else:
                lr = pol_dist.likelihood_ratio_sym(
                    i.valid.action_var,
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                    name='lr')
                surrogate = lr * adv_valid

        # Finalize objective function
        with tf.name_scope('loss'):
            if self._pg_loss == 'vanilla':
                # VPG uses the vanilla objective
                obj = tf.identity(vanilla, name='vanilla_obj')
            elif self._pg_loss == 'surrogate':
                # TRPO uses the standard surrogate objective
                obj = tf.identity(surrogate, name='surr_obj')
            elif self._pg_loss == 'surrogate_clip':
                lr_clip = tf.clip_by_value(
                    lr,
                    1 - self.lr_clip_range,
                    1 + self.lr_clip_range,
                    name='lr_clip')
                if self.policy.recurrent:
                    surr_clip = lr_clip * adv * i.valid_var
                else:
                    surr_clip = lr_clip * adv_valid
                obj = tf.minimum(surrogate, surr_clip, name='surr_obj')

            if self._entropy_regularzied:
                obj += self.policy_ent_coeff * policy_entropy

            # Maximize E[surrogate objective] by minimizing
            # -E_t[surrogate objective]
            if self.policy.recurrent:
                loss = -tf.reduce_sum(obj) / tf.reduce_sum(i.valid_var)
            else:
                loss = -tf.reduce_mean(obj)

    # Diagnostic functions
    self.f_policy_kl = tensor_utils.compile_function(
        flatten_inputs(self._policy_opt_inputs),
        pol_mean_kl,
        log_name='f_policy_kl')

    self.f_rewards = tensor_utils.compile_function(
        flatten_inputs(self._policy_opt_inputs),
        rewards,
        log_name='f_rewards')

    returns = discounted_returns(self.discount, self.max_path_length,
                                 rewards)
    self.f_returns = tensor_utils.compile_function(
        flatten_inputs(self._policy_opt_inputs),
        returns,
        log_name='f_returns')

    return loss, pol_mean_kl

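# A minimal NumPy sketch (illustrative only) of the two ways policy entropy
# enters the objective above: the "maximum entropy" path adds it to the
# per-step rewards before advantages are computed, while the "entropy
# regularized" path adds it directly to the per-sample objective.
import numpy as np

rewards = np.array([1.0, 0.0, 2.0])      # hypothetical per-step rewards
entropy = np.array([0.7, 0.6, 0.5])      # hypothetical per-step policy entropy
objective = np.array([0.2, -0.1, 0.4])   # hypothetical per-sample PG objective
ent_coeff = 0.01

max_ent_rewards = rewards + ent_coeff * entropy        # reward augmentation
ent_reg_objective = objective + ent_coeff * entropy    # objective bonus
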
def _build_inputs(self):
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space
    policy_dist = self.policy.distribution

    with tf.name_scope('inputs'):
        if self.flatten_input:
            obs_var = tf.compat.v1.placeholder(
                tf.float32,
                shape=[None, None, observation_space.flat_dim],
                name='obs')
        else:
            obs_var = observation_space.to_tf_placeholder(
                name='obs', batch_dims=2)
        action_var = action_space.to_tf_placeholder(
            name='action', batch_dims=2)
        reward_var = tensor_utils.new_tensor(
            name='reward', ndim=2, dtype=tf.float32)
        valid_var = tf.compat.v1.placeholder(
            tf.float32, shape=[None, None], name='valid')
        baseline_var = tensor_utils.new_tensor(
            name='baseline', ndim=2, dtype=tf.float32)

        policy_state_info_vars = {
            k: tf.compat.v1.placeholder(
                tf.float32, shape=[None] * 2 + list(shape), name=k)
            for k, shape in self.policy.state_info_specs
        }
        policy_state_info_vars_list = [
            policy_state_info_vars[k] for k in self.policy.state_info_keys
        ]

        # old policy distribution
        policy_old_dist_info_vars = {
            k: tf.compat.v1.placeholder(
                tf.float32,
                shape=[None] * 2 + list(shape),
                name='policy_old_%s' % k)
            for k, shape in policy_dist.dist_info_specs
        }
        policy_old_dist_info_vars_list = [
            policy_old_dist_info_vars[k]
            for k in policy_dist.dist_info_keys
        ]

        # flattened view
        with tf.name_scope('flat'):
            obs_flat = flatten_batch(obs_var, name='obs_flat')
            action_flat = flatten_batch(action_var, name='action_flat')
            reward_flat = flatten_batch(reward_var, name='reward_flat')
            valid_flat = flatten_batch(valid_var, name='valid_flat')
            policy_state_info_vars_flat = flatten_batch_dict(
                policy_state_info_vars, name='policy_state_info_vars_flat')
            policy_old_dist_info_vars_flat = flatten_batch_dict(
                policy_old_dist_info_vars,
                name='policy_old_dist_info_vars_flat')

        # valid view
        with tf.name_scope('valid'):
            action_valid = filter_valids(
                action_flat, valid_flat, name='action_valid')
            policy_state_info_vars_valid = filter_valids_dict(
                policy_state_info_vars_flat,
                valid_flat,
                name='policy_state_info_vars_valid')
            policy_old_dist_info_vars_valid = filter_valids_dict(
                policy_old_dist_info_vars_flat,
                valid_flat,
                name='policy_old_dist_info_vars_valid')

    # policy loss and optimizer inputs
    pol_flat = graph_inputs(
        'PolicyLossInputsFlat',
        obs_var=obs_flat,
        action_var=action_flat,
        reward_var=reward_flat,
        valid_var=valid_flat,
        policy_state_info_vars=policy_state_info_vars_flat,
        policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
    )
    pol_valid = graph_inputs(
        'PolicyLossInputsValid',
        action_var=action_valid,
        policy_state_info_vars=policy_state_info_vars_valid,
        policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
    )
    policy_loss_inputs = graph_inputs(
        'PolicyLossInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars=policy_state_info_vars,
        policy_old_dist_info_vars=policy_old_dist_info_vars,
        flat=pol_flat,
        valid=pol_valid,
    )
    policy_opt_inputs = graph_inputs(
        'PolicyOptInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
    )

    return policy_loss_inputs, policy_opt_inputs

def _build_policy_loss(self, i):
    pol_dist = self.policy.distribution

    policy_entropy = self._build_entropy_term(i)

    with tf.name_scope('augmented_rewards'):
        rewards = i.reward_var + (self.policy_ent_coeff * policy_entropy)

    with tf.name_scope('policy_loss'):
        advantages = compute_advantages(
            self.discount,
            self.gae_lambda,
            self.max_path_length,
            i.baseline_var,
            rewards,
            name='advantages')

        adv_flat = flatten_batch(advantages, name='adv_flat')
        adv_valid = filter_valids(
            adv_flat, i.flat.valid_var, name='adv_valid')

        if self.policy.recurrent:
            advantages = tf.reshape(advantages, [-1, self.max_path_length])

        # Optionally normalize advantages
        eps = tf.constant(1e-8, dtype=tf.float32)
        if self.center_adv:
            with tf.name_scope('center_adv'):
                mean, var = tf.nn.moments(adv_valid, axes=[0])
                adv_valid = tf.nn.batch_normalization(
                    adv_valid, mean, var, 0, 1, eps)
        if self.positive_adv:
            with tf.name_scope('positive_adv'):
                m = tf.reduce_min(adv_valid)
                adv_valid = (adv_valid - m) + eps

        if self.policy.recurrent:
            policy_dist_info = self.policy.dist_info_sym(
                i.obs_var,
                i.policy_state_info_vars,
                name='policy_dist_info')
        else:
            policy_dist_info_flat = self.policy.dist_info_sym(
                i.flat.obs_var,
                i.flat.policy_state_info_vars,
                name='policy_dist_info_flat')
            policy_dist_info_valid = filter_valids_dict(
                policy_dist_info_flat,
                i.flat.valid_var,
                name='policy_dist_info_valid')

        # Calculate loss function and KL divergence
        with tf.name_scope('kl'):
            if self.policy.recurrent:
                kl = pol_dist.kl_sym(
                    i.policy_old_dist_info_vars,
                    policy_dist_info,
                )
                pol_mean_kl = tf.reduce_sum(
                    kl * i.valid_var) / tf.reduce_sum(i.valid_var)
            else:
                kl = pol_dist.kl_sym(
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                )
                pol_mean_kl = tf.reduce_mean(kl)

        # Calculate surrogate loss
        with tf.name_scope('surr_loss'):
            if self.policy.recurrent:
                lr = pol_dist.likelihood_ratio_sym(
                    i.action_var,
                    i.policy_old_dist_info_vars,
                    policy_dist_info,
                    name='lr')
                surr_vanilla = lr * advantages * i.valid_var
            else:
                lr = pol_dist.likelihood_ratio_sym(
                    i.valid.action_var,
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                    name='lr')
                surr_vanilla = lr * adv_valid

            if self._pg_loss == PGLoss.VANILLA:
                # VPG, TRPO use the standard surrogate objective
                surr_obj = tf.identity(surr_vanilla, name='surr_obj')
            elif self._pg_loss == PGLoss.CLIP:
                lr_clip = tf.clip_by_value(
                    lr,
                    1 - self.clip_range,
                    1 + self.clip_range,
                    name='lr_clip')
                if self.policy.recurrent:
                    surr_clip = lr_clip * advantages * i.valid_var
                else:
                    surr_clip = lr_clip * adv_valid
                surr_obj = tf.minimum(surr_vanilla, surr_clip,
                                      name='surr_obj')
            else:
                raise NotImplementedError('Unknown PGLoss')

            # Maximize E[surrogate objective] by minimizing
            # -E_t[surrogate objective]
            if self.policy.recurrent:
                surr_loss = -tf.reduce_sum(surr_obj) / tf.reduce_sum(
                    i.valid_var)
            else:
                surr_loss = -tf.reduce_mean(surr_obj)

    # Diagnostic functions
    self.f_policy_kl = compile_function(
        flatten_inputs(self._policy_opt_inputs),
        pol_mean_kl,
        log_name='f_policy_kl')

    self.f_rewards = compile_function(
        flatten_inputs(self._policy_opt_inputs),
        rewards,
        log_name='f_rewards')

    returns = discounted_returns(self.discount, self.max_path_length,
                                 rewards)
    self.f_returns = compile_function(
        flatten_inputs(self._policy_opt_inputs),
        returns,
        log_name='f_returns')

    return surr_loss, pol_mean_kl

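# A minimal NumPy sketch (illustrative only) of the masked mean used for
# recurrent policies above: with a padded [N, T] objective and a 0/1 valid
# mask, sum(obj * valid) / sum(valid) averages over real timesteps only,
# matching reduce_mean over the filtered (valid-only) view in the
# non-recurrent branch.
import numpy as np

obj = np.array([[0.5, 0.2, 0.0],
                [0.1, 0.0, 0.0]])       # hypothetical per-timestep objective, padded
valid = np.array([[1.0, 1.0, 0.0],
                  [1.0, 0.0, 0.0]])     # 1 = real timestep, 0 = padding
masked_mean = (obj * valid).sum() / valid.sum()  # == (0.5 + 0.2 + 0.1) / 3
loss = -masked_mean
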