def _build_inputs(self): """ Builds input variables (and trivial views thereof) for the loss function network """ observation_space = self.policy.observation_space action_space = self.policy.action_space task_space = self.policy.task_space latent_space = self.policy.latent_space trajectory_space = self.inference.input_space policy_dist = self.policy._dist embed_dist = self.policy.embedding._dist infer_dist = self.inference._dist with tf.name_scope("inputs"): obs_var = observation_space.new_tensor_variable( 'obs', extra_dims=1 + 1, ) task_var = task_space.new_tensor_variable( 'task', extra_dims=1 + 1, ) action_var = action_space.new_tensor_variable( 'action', extra_dims=1 + 1, ) reward_var = tensor_utils.new_tensor( 'reward', ndim=1 + 1, dtype=tf.float32, ) latent_var = latent_space.new_tensor_variable( 'latent', extra_dims=1 + 1, ) baseline_var = tensor_utils.new_tensor( 'baseline', ndim=1 + 1, dtype=tf.float32, ) trajectory_var = trajectory_space.new_tensor_variable( 'trajectory', extra_dims=1 + 1, ) valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid") # Policy state (for RNNs) policy_state_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name=k) for k, shape in self.policy.state_info_specs } policy_state_info_vars_list = [ policy_state_info_vars[k] for k in self.policy.state_info_keys ] # Old policy distribution (for KL) policy_old_dist_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name='policy_old_%s' % k) for k, shape in policy_dist.dist_info_specs } policy_old_dist_info_vars_list = [ policy_old_dist_info_vars[k] for k in policy_dist.dist_info_keys ] # Embedding state (for RNNs) embed_state_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name='embed_%s' % k) for k, shape in self.policy.embedding.state_info_specs } embed_state_info_vars_list = [ embed_state_info_vars[k] for k in self.policy.embedding.state_info_keys ] # Old embedding distribution (for KL) embed_old_dist_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name='embed_old_%s' % k) for k, shape in embed_dist.dist_info_specs } embed_old_dist_info_vars_list = [ embed_old_dist_info_vars[k] for k in embed_dist.dist_info_keys ] # Inference distribution state (for RNNs) infer_state_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name='infer_%s' % k) for k, shape in self.inference.state_info_specs } infer_state_info_vars_list = [ infer_state_info_vars[k] for k in self.inference.state_info_keys ] # Old inference distribution (for KL) infer_old_dist_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name='infer_old_%s' % k) for k, shape in infer_dist.dist_info_specs } infer_old_dist_info_vars_list = [ infer_old_dist_info_vars[k] for k in infer_dist.dist_info_keys ] # Flattened view with tf.name_scope("flat"): obs_flat = flatten_batch(obs_var, name="obs_flat") task_flat = flatten_batch(task_var, name="task_flat") action_flat = flatten_batch(action_var, name="action_flat") reward_flat = flatten_batch(reward_var, name="reward_flat") latent_flat = flatten_batch(latent_var, name="latent_flat") trajectory_flat = flatten_batch(trajectory_var, name="trajectory_flat") valid_flat = flatten_batch(valid_var, name="valid_flat") policy_state_info_vars_flat = flatten_batch_dict( policy_state_info_vars, name="policy_state_info_vars_flat") policy_old_dist_info_vars_flat = flatten_batch_dict( policy_old_dist_info_vars, 
name="policy_old_dist_info_vars_flat") embed_state_info_vars_flat = flatten_batch_dict( embed_state_info_vars, name="embed_state_info_vars_flat") embed_old_dist_info_vars_flat = flatten_batch_dict( embed_old_dist_info_vars, name="embed_old_dist_info_vars_flat") infer_state_info_vars_flat = flatten_batch_dict( infer_state_info_vars, name="infer_state_info_vars_flat") infer_old_dist_info_vars_flat = flatten_batch_dict( infer_old_dist_info_vars, name="infer_old_dist_info_vars_flat") # Valid view with tf.name_scope("valid"): action_valid = filter_valids(action_flat, valid_flat, name="action_valid") policy_state_info_vars_valid = filter_valids_dict( policy_state_info_vars_flat, valid_flat, name="policy_state_info_vars_valid") policy_old_dist_info_vars_valid = filter_valids_dict( policy_old_dist_info_vars_flat, valid_flat, name="policy_old_dist_info_vars_valid") embed_old_dist_info_vars_valid = filter_valids_dict( embed_old_dist_info_vars_flat, valid_flat, name="embed_old_dist_info_vars_valid") infer_old_dist_info_vars_valid = filter_valids_dict( infer_old_dist_info_vars_flat, valid_flat, name="infer_old_dist_info_vars_valid") # Policy and embedding network loss and optimizer inputs pol_flat = graph_inputs( "PolicyLossInputsFlat", obs_var=obs_flat, task_var=task_flat, action_var=action_flat, reward_var=reward_flat, latent_var=latent_flat, trajectory_var=trajectory_flat, valid_var=valid_flat, policy_state_info_vars=policy_state_info_vars_flat, policy_old_dist_info_vars=policy_old_dist_info_vars_flat, embed_state_info_vars=embed_state_info_vars_flat, embed_old_dist_info_vars=embed_old_dist_info_vars_flat, ) pol_valid = graph_inputs( "PolicyLossInputsValid", action_var=action_valid, policy_state_info_vars=policy_state_info_vars_valid, policy_old_dist_info_vars=policy_old_dist_info_vars_valid, embed_old_dist_info_vars=embed_old_dist_info_vars_valid, ) policy_loss_inputs = graph_inputs( "PolicyLossInputs", obs_var=obs_var, action_var=action_var, reward_var=reward_var, baseline_var=baseline_var, trajectory_var=trajectory_var, task_var=task_var, latent_var=latent_var, valid_var=valid_var, policy_state_info_vars=policy_state_info_vars, policy_old_dist_info_vars=policy_old_dist_info_vars, embed_state_info_vars=embed_state_info_vars, embed_old_dist_info_vars=embed_old_dist_info_vars, flat=pol_flat, valid=pol_valid, ) # Special variant for the optimizer # * Uses lists instead of dicts for the distribution parameters # * Omits flats and valids # TODO: eliminate policy_opt_inputs = graph_inputs( "PolicyOptInputs", obs_var=obs_var, action_var=action_var, reward_var=reward_var, baseline_var=baseline_var, trajectory_var=trajectory_var, task_var=task_var, latent_var=latent_var, valid_var=valid_var, policy_state_info_vars_list=policy_state_info_vars_list, policy_old_dist_info_vars_list=policy_old_dist_info_vars_list, embed_state_info_vars_list=embed_state_info_vars_list, embed_old_dist_info_vars_list=embed_old_dist_info_vars_list, ) # Inference network loss and optimizer inputs infer_flat = graph_inputs( "InferenceLossInputsFlat", latent_var=latent_flat, trajectory_var=trajectory_flat, valid_var=valid_flat, infer_state_info_vars=infer_state_info_vars_flat, infer_old_dist_info_vars=infer_old_dist_info_vars_flat, ) infer_valid = graph_inputs( "InferenceLossInputsValid", infer_old_dist_info_vars=infer_old_dist_info_vars_valid, ) inference_loss_inputs = graph_inputs( "InferenceLossInputs", latent_var=latent_var, trajectory_var=trajectory_var, valid_var=valid_var, infer_state_info_vars=infer_state_info_vars, 
infer_old_dist_info_vars=infer_old_dist_info_vars, flat=infer_flat, valid=infer_valid, ) # Special variant for the optimizer # * Uses lists instead of dicts for the distribution parameters # * Omits flats and valids # TODO: eliminate inference_opt_inputs = graph_inputs( "InferenceOptInputs", latent_var=latent_var, trajectory_var=trajectory_var, valid_var=valid_var, infer_state_info_vars_list=infer_state_info_vars_list, infer_old_dist_info_vars_list=infer_old_dist_info_vars_list, ) return (policy_loss_inputs, policy_opt_inputs, inference_loss_inputs, inference_opt_inputs)
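# The snippets in this file call `graph_inputs` without defining it. A minimal
# sketch of what it is assumed to do: pack keyword arguments into an ad-hoc
# namedtuple so that input collections can be passed around and accessed by
# field name. This is an assumption for illustration, not necessarily the
# library's actual definition.
import collections


def graph_inputs(name, **kwargs):
    """Bundle named graph inputs into a namedtuple instance (sketch)."""
    Singleton = collections.namedtuple(name, list(kwargs.keys()))
    return Singleton(**kwargs)


# Hypothetical usage: fields are read back by name, e.g.
#   inputs = graph_inputs('PolicyOptInputs', obs_var=obs_var, valid_var=valid_var)
#   feed = {inputs.obs_var: obs_batch, inputs.valid_var: valid_batch}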
def _build_inputs(self): """Decalre graph inputs variables.""" observation_space = self.policy.observation_space action_space = self.policy.action_space policy_dist = self.policy.distribution with tf.name_scope('inputs'): obs_var = observation_space.to_tf_placeholder( name='obs', batch_dims=2) # yapf: disable action_var = action_space.to_tf_placeholder( name='action', batch_dims=2) # yapf: disable reward_var = tensor_utils.new_tensor( name='reward', ndim=2, dtype=tf.float32) # yapf: disable valid_var = tensor_utils.new_tensor( name='valid', ndim=2, dtype=tf.float32) # yapf: disable feat_diff = tensor_utils.new_tensor( name='feat_diff', ndim=2, dtype=tf.float32) # yapf: disable param_v = tensor_utils.new_tensor( name='param_v', ndim=1, dtype=tf.float32) # yapf: disable param_eta = tensor_utils.new_tensor( name='param_eta', ndim=0, dtype=tf.float32) # yapf: disable policy_state_info_vars = { k: tf.placeholder( tf.float32, shape=[None] * 2 + list(shape), name=k) for k, shape in self.policy.state_info_specs } # yapf: disable policy_state_info_vars_list = [ policy_state_info_vars[k] for k in self.policy.state_info_keys ] # yapf: disable policy_old_dist_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * 2 + list(shape), name='policy_old_%s' % k) for k, shape in policy_dist.dist_info_specs } policy_old_dist_info_vars_list = [ policy_old_dist_info_vars[k] for k in policy_dist.dist_info_keys ] with tf.name_scope('flat'): obs_flat = flatten_batch(obs_var, name='obs_flat') action_flat = flatten_batch(action_var, name='action_flat') reward_flat = flatten_batch(reward_var, name='reward_flat') valid_flat = flatten_batch(valid_var, name='valid_flat') feat_diff_flat = flatten_batch( feat_diff, name='feat_diff_flat') # yapf: disable policy_state_info_vars_flat = flatten_batch_dict( policy_state_info_vars, name='policy_state_info_vars_flat') # yapf: disable policy_old_dist_info_vars_flat = flatten_batch_dict( policy_old_dist_info_vars, name='policy_old_dist_info_vars_flat') with tf.name_scope('valid'): reward_valid = filter_valids( reward_flat, valid_flat, name='reward_valid') # yapf: disable action_valid = filter_valids( action_flat, valid_flat, name='action_valid') # yapf: disable policy_state_info_vars_valid = filter_valids_dict( policy_state_info_vars_flat, valid_flat, name='policy_state_info_vars_valid') policy_old_dist_info_vars_valid = filter_valids_dict( policy_old_dist_info_vars_flat, valid_flat, name='policy_old_dist_info_vars_valid') pol_flat = graph_inputs( 'PolicyLossInputsFlat', obs_var=obs_flat, action_var=action_flat, reward_var=reward_flat, valid_var=valid_flat, feat_diff=feat_diff_flat, policy_state_info_vars=policy_state_info_vars_flat, policy_old_dist_info_vars=policy_old_dist_info_vars_flat, ) pol_valid = graph_inputs( 'PolicyLossInputsValid', reward_var=reward_valid, action_var=action_valid, policy_state_info_vars=policy_state_info_vars_valid, policy_old_dist_info_vars=policy_old_dist_info_vars_valid, ) policy_loss_inputs = graph_inputs( 'PolicyLossInputs', obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars=policy_state_info_vars, policy_old_dist_info_vars=policy_old_dist_info_vars, flat=pol_flat, valid=pol_valid, ) policy_opt_inputs = graph_inputs( 'PolicyOptInputs', obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars_list=policy_state_info_vars_list, 
policy_old_dist_info_vars_list=policy_old_dist_info_vars_list, ) dual_opt_inputs = graph_inputs( 'DualOptInputs', reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars_list=policy_state_info_vars_list, policy_old_dist_info_vars_list=policy_old_dist_info_vars_list, ) return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs
def _build_inputs(self):
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space
    policy_dist = self.policy.distribution

    with tf.name_scope("inputs"):
        obs_var = observation_space.new_tensor_variable(
            name="obs", extra_dims=2)
        action_var = action_space.new_tensor_variable(
            name="action", extra_dims=2)
        reward_var = tensor_utils.new_tensor(
            name="reward", ndim=2, dtype=tf.float32)
        valid_var = tf.placeholder(
            tf.float32, shape=[None, None], name="valid")
        baseline_var = tensor_utils.new_tensor(
            name="baseline", ndim=2, dtype=tf.float32)

        policy_state_info_vars = {
            k: tf.placeholder(
                tf.float32, shape=[None] * 2 + list(shape), name=k)
            for k, shape in self.policy.state_info_specs
        }
        policy_state_info_vars_list = [
            policy_state_info_vars[k] for k in self.policy.state_info_keys
        ]

        # old policy distribution
        policy_old_dist_info_vars = {
            k: tf.placeholder(
                tf.float32,
                shape=[None] * 2 + list(shape),
                name="policy_old_%s" % k)
            for k, shape in policy_dist.dist_info_specs
        }
        policy_old_dist_info_vars_list = [
            policy_old_dist_info_vars[k]
            for k in policy_dist.dist_info_keys
        ]

        # flattened view
        with tf.name_scope("flat"):
            obs_flat = flatten_batch(obs_var, name="obs_flat")
            action_flat = flatten_batch(action_var, name="action_flat")
            reward_flat = flatten_batch(reward_var, name="reward_flat")
            valid_flat = flatten_batch(valid_var, name="valid_flat")
            policy_state_info_vars_flat = flatten_batch_dict(
                policy_state_info_vars, name="policy_state_info_vars_flat")
            policy_old_dist_info_vars_flat = flatten_batch_dict(
                policy_old_dist_info_vars,
                name="policy_old_dist_info_vars_flat")

        # valid view
        with tf.name_scope("valid"):
            action_valid = filter_valids(
                action_flat, valid_flat, name="action_valid")
            policy_state_info_vars_valid = filter_valids_dict(
                policy_state_info_vars_flat,
                valid_flat,
                name="policy_state_info_vars_valid")
            policy_old_dist_info_vars_valid = filter_valids_dict(
                policy_old_dist_info_vars_flat,
                valid_flat,
                name="policy_old_dist_info_vars_valid")

    # policy loss and optimizer inputs
    pol_flat = graph_inputs(
        "PolicyLossInputsFlat",
        obs_var=obs_flat,
        action_var=action_flat,
        reward_var=reward_flat,
        valid_var=valid_flat,
        policy_state_info_vars=policy_state_info_vars_flat,
        policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
    )
    pol_valid = graph_inputs(
        "PolicyLossInputsValid",
        action_var=action_valid,
        policy_state_info_vars=policy_state_info_vars_valid,
        policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
    )
    policy_loss_inputs = graph_inputs(
        "PolicyLossInputs",
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars=policy_state_info_vars,
        policy_old_dist_info_vars=policy_old_dist_info_vars,
        flat=pol_flat,
        valid=pol_valid,
    )
    policy_opt_inputs = graph_inputs(
        "PolicyOptInputs",
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
    )

    return policy_loss_inputs, policy_opt_inputs
def _build_inputs(self): """Decalre graph inputs variables.""" observation_space = self.policy.observation_space action_space = self.policy.action_space policy_dist = self.policy.distribution with tf.name_scope("inputs"): obs_var = observation_space.new_tensor_variable( name="obs", extra_dims=2) # yapf: disable action_var = action_space.new_tensor_variable( name="action", extra_dims=2) # yapf: disable reward_var = tensor_utils.new_tensor( name="reward", ndim=2, dtype=tf.float32) # yapf: disable valid_var = tensor_utils.new_tensor( name="valid", ndim=2, dtype=tf.float32) # yapf: disable feat_diff = tensor_utils.new_tensor( name="feat_diff", ndim=2, dtype=tf.float32) # yapf: disable param_v = tensor_utils.new_tensor( name="param_v", ndim=1, dtype=tf.float32) # yapf: disable param_eta = tensor_utils.new_tensor( name="param_eta", ndim=0, dtype=tf.float32) # yapf: disable policy_state_info_vars = { k: tf.placeholder( tf.float32, shape=[None] * 2 + list(shape), name=k) for k, shape in self.policy.state_info_specs } # yapf: disable policy_state_info_vars_list = [ policy_state_info_vars[k] for k in self.policy.state_info_keys ] # yapf: disable policy_old_dist_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * 2 + list(shape), name="policy_old_%s" % k) for k, shape in policy_dist.dist_info_specs } policy_old_dist_info_vars_list = [ policy_old_dist_info_vars[k] for k in policy_dist.dist_info_keys ] with tf.name_scope("flat"): obs_flat = flatten_batch(obs_var, name="obs_flat") action_flat = flatten_batch(action_var, name="action_flat") reward_flat = flatten_batch(reward_var, name="reward_flat") valid_flat = flatten_batch(valid_var, name="valid_flat") feat_diff_flat = flatten_batch( feat_diff, name="feat_diff_flat") # yapf: disable policy_state_info_vars_flat = flatten_batch_dict( policy_state_info_vars, name="policy_state_info_vars_flat") # yapf: disable policy_old_dist_info_vars_flat = flatten_batch_dict( policy_old_dist_info_vars, name="policy_old_dist_info_vars_flat") with tf.name_scope("valid"): reward_valid = filter_valids( reward_flat, valid_flat, name="reward_valid") # yapf: disable action_valid = filter_valids( action_flat, valid_flat, name="action_valid") # yapf: disable policy_state_info_vars_valid = filter_valids_dict( policy_state_info_vars_flat, valid_flat, name="policy_state_info_vars_valid") policy_old_dist_info_vars_valid = filter_valids_dict( policy_old_dist_info_vars_flat, valid_flat, name="policy_old_dist_info_vars_valid") pol_flat = graph_inputs( "PolicyLossInputsFlat", obs_var=obs_flat, action_var=action_flat, reward_var=reward_flat, valid_var=valid_flat, feat_diff=feat_diff_flat, policy_state_info_vars=policy_state_info_vars_flat, policy_old_dist_info_vars=policy_old_dist_info_vars_flat, ) pol_valid = graph_inputs( "PolicyLossInputsValid", reward_var=reward_valid, action_var=action_valid, policy_state_info_vars=policy_state_info_vars_valid, policy_old_dist_info_vars=policy_old_dist_info_vars_valid, ) policy_loss_inputs = graph_inputs( "PolicyLossInputs", obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars=policy_state_info_vars, policy_old_dist_info_vars=policy_old_dist_info_vars, flat=pol_flat, valid=pol_valid, ) policy_opt_inputs = graph_inputs( "PolicyOptInputs", obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, 
policy_state_info_vars_list=policy_state_info_vars_list, policy_old_dist_info_vars_list=policy_old_dist_info_vars_list, ) dual_opt_inputs = graph_inputs( "DualOptInputs", reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars_list=policy_state_info_vars_list, policy_old_dist_info_vars_list=policy_old_dist_info_vars_list, ) return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs
def _build_inputs(self): """Build input variables. Returns: namedtuple: Collection of variables to compute policy loss. namedtuple: Collection of variables to do policy optimization. """ observation_space = self.policy.observation_space action_space = self.policy.action_space with tf.name_scope('inputs'): obs_var = observation_space.to_tf_placeholder( name='obs', batch_dims=2) # yapf: disable action_var = action_space.to_tf_placeholder( name='action', batch_dims=2) # yapf: disable reward_var = tensor_utils.new_tensor( name='reward', ndim=2, dtype=tf.float32) # yapf: disable valid_var = tensor_utils.new_tensor( name='valid', ndim=2, dtype=tf.float32) # yapf: disable feat_diff = tensor_utils.new_tensor( name='feat_diff', ndim=2, dtype=tf.float32) # yapf: disable param_v = tensor_utils.new_tensor( name='param_v', ndim=1, dtype=tf.float32) # yapf: disable param_eta = tensor_utils.new_tensor( name='param_eta', ndim=0, dtype=tf.float32) # yapf: disable policy_state_info_vars = { k: tf.compat.v1.placeholder( tf.float32, shape=[None] * 2 + list(shape), name=k) for k, shape in self.policy.state_info_specs } # yapf: disable policy_state_info_vars_list = [ policy_state_info_vars[k] for k in self.policy.state_info_keys ] # yapf: disable self.policy.build(obs_var) self._old_policy.build(obs_var) policy_loss_inputs = graph_inputs( 'PolicyLossInputs', obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars=policy_state_info_vars, ) policy_opt_inputs = graph_inputs( 'PolicyOptInputs', obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars_list=policy_state_info_vars_list, ) dual_opt_inputs = graph_inputs( 'DualOptInputs', reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars_list=policy_state_info_vars_list, ) return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs
def _build_inputs(self):
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space
    policy_dist = self.policy.distribution

    with tf.name_scope('inputs'):
        if self.flatten_input:
            obs_var = tf.compat.v1.placeholder(
                tf.float32,
                shape=[None, None, observation_space.flat_dim],
                name='obs')
        else:
            obs_var = observation_space.to_tf_placeholder(name='obs',
                                                          batch_dims=2)
        action_var = action_space.to_tf_placeholder(name='action',
                                                    batch_dims=2)
        reward_var = tensor_utils.new_tensor(name='reward',
                                             ndim=2,
                                             dtype=tf.float32)
        valid_var = tf.compat.v1.placeholder(tf.float32,
                                             shape=[None, None],
                                             name='valid')
        baseline_var = tensor_utils.new_tensor(name='baseline',
                                               ndim=2,
                                               dtype=tf.float32)

        policy_state_info_vars = {
            k: tf.compat.v1.placeholder(tf.float32,
                                        shape=[None] * 2 + list(shape),
                                        name=k)
            for k, shape in self.policy.state_info_specs
        }
        policy_state_info_vars_list = [
            policy_state_info_vars[k] for k in self.policy.state_info_keys
        ]

        # old policy distribution
        policy_old_dist_info_vars = {
            k: tf.compat.v1.placeholder(tf.float32,
                                        shape=[None] * 2 + list(shape),
                                        name='policy_old_%s' % k)
            for k, shape in policy_dist.dist_info_specs
        }
        policy_old_dist_info_vars_list = [
            policy_old_dist_info_vars[k]
            for k in policy_dist.dist_info_keys
        ]

        # flattened view
        with tf.name_scope('flat'):
            obs_flat = flatten_batch(obs_var, name='obs_flat')
            action_flat = flatten_batch(action_var, name='action_flat')
            reward_flat = flatten_batch(reward_var, name='reward_flat')
            valid_flat = flatten_batch(valid_var, name='valid_flat')
            policy_state_info_vars_flat = flatten_batch_dict(
                policy_state_info_vars, name='policy_state_info_vars_flat')
            policy_old_dist_info_vars_flat = flatten_batch_dict(
                policy_old_dist_info_vars,
                name='policy_old_dist_info_vars_flat')

        # valid view
        with tf.name_scope('valid'):
            action_valid = filter_valids(action_flat,
                                         valid_flat,
                                         name='action_valid')
            policy_state_info_vars_valid = filter_valids_dict(
                policy_state_info_vars_flat,
                valid_flat,
                name='policy_state_info_vars_valid')
            policy_old_dist_info_vars_valid = filter_valids_dict(
                policy_old_dist_info_vars_flat,
                valid_flat,
                name='policy_old_dist_info_vars_valid')

    # policy loss and optimizer inputs
    pol_flat = graph_inputs(
        'PolicyLossInputsFlat',
        obs_var=obs_flat,
        action_var=action_flat,
        reward_var=reward_flat,
        valid_var=valid_flat,
        policy_state_info_vars=policy_state_info_vars_flat,
        policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
    )
    pol_valid = graph_inputs(
        'PolicyLossInputsValid',
        action_var=action_valid,
        policy_state_info_vars=policy_state_info_vars_valid,
        policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
    )
    policy_loss_inputs = graph_inputs(
        'PolicyLossInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars=policy_state_info_vars,
        policy_old_dist_info_vars=policy_old_dist_info_vars,
        flat=pol_flat,
        valid=pol_valid,
    )
    policy_opt_inputs = graph_inputs(
        'PolicyOptInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
    )

    return policy_loss_inputs, policy_opt_inputs
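# The flatten/filter helpers used by the `_build_inputs` variants above are
# also undefined in this file. A minimal sketch of their assumed semantics,
# written against TF 1.x-style static graphs: `flatten_batch` merges the
# (batch, time) leading dimensions, and `filter_valids` keeps only the rows
# whose `valid` mask entry is 1 (dropping padding past the episode end). The
# real library implementations may differ; these names and bodies are
# assumptions for illustration only.
import tensorflow as tf


def flatten_batch(t, name='flatten_batch'):
    # Collapse the first two dimensions into one, keeping trailing dims.
    new_shape = tf.concat([[-1], tf.shape(t)[2:]], axis=0)
    return tf.reshape(t, new_shape, name=name)


def flatten_batch_dict(d, name='flatten_batch_dict'):
    # Apply flatten_batch to every tensor in a dict of tensors.
    with tf.name_scope(name):
        return {k: flatten_batch(v) for k, v in d.items()}


def filter_valids(t, valid, name='filter_valids'):
    # Keep rows of t where the (flattened) 0/1 mask is set.
    return tf.boolean_mask(t, tf.cast(valid, tf.bool), name=name)


def filter_valids_dict(d, valid, name='filter_valids_dict'):
    # Apply filter_valids to every tensor in a dict of tensors.
    with tf.name_scope(name):
        return {k: filter_valids(v, valid) for k, v in d.items()}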
def init_opt(self):
    with tf.name_scope("inputs"):
        observations_var = self.env.observation_space.new_tensor_variable(
            'observations', extra_dims=1)
        actions_var = self.env.action_space.new_tensor_variable(
            'actions', extra_dims=1)
        advantages_var = tensor_utils.new_tensor(
            'advantage', ndim=1, dtype=tf.float32)

    dist = self.policy.distribution

    dist_info_vars = self.policy.dist_info_sym(observations_var)
    old_dist_info_vars = self.backup_policy.dist_info_sym(
        observations_var)
    kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
    mean_kl = tf.reduce_mean(kl)
    max_kl = tf.reduce_max(kl)

    pos_eps_dist_info_vars = self.pos_eps_policy.dist_info_sym(
        observations_var)
    neg_eps_dist_info_vars = self.neg_eps_policy.dist_info_sym(
        observations_var)
    mix_dist_info_vars = self.mix_policy.dist_info_sym(observations_var)

    surr = tf.reduce_mean(
        dist.log_likelihood_sym(actions_var, dist_info_vars) *
        advantages_var)
    surr_pos_eps = tf.reduce_mean(
        dist.log_likelihood_sym(actions_var, pos_eps_dist_info_vars) *
        advantages_var)
    surr_neg_eps = tf.reduce_mean(
        dist.log_likelihood_sym(actions_var, neg_eps_dist_info_vars) *
        advantages_var)
    surr_mix = tf.reduce_mean(
        dist.log_likelihood_sym(actions_var, mix_dist_info_vars) *
        advantages_var)
    surr_loglikelihood = tf.reduce_sum(
        dist.log_likelihood_sym(actions_var, mix_dist_info_vars))

    params = self.policy.get_params(trainable=True)
    mix_params = self.mix_policy.get_params(trainable=True)
    pos_eps_params = self.pos_eps_policy.get_params(trainable=True)
    neg_eps_params = self.neg_eps_policy.get_params(trainable=True)

    grads = tf.gradients(surr, params)
    grad_pos_eps = tf.gradients(surr_pos_eps, pos_eps_params)
    grad_neg_eps = tf.gradients(surr_neg_eps, neg_eps_params)
    grad_mix = tf.gradients(surr_mix, mix_params)
    grad_mix_lh = tf.gradients(surr_loglikelihood, mix_params)

    self._opt_fun = ext.LazyDict(
        f_loss=lambda: tensor_utils.compile_function(
            inputs=[observations_var, actions_var, advantages_var],
            outputs=surr,
            log_name="f_loss",
        ),
        f_train=lambda: tensor_utils.compile_function(
            inputs=[observations_var, actions_var, advantages_var],
            outputs=grads,
            log_name="f_grad"),
        f_mix_grad=lambda: tensor_utils.compile_function(
            inputs=[observations_var, actions_var, advantages_var],
            outputs=grad_mix,
            log_name="f_mix_grad"),
        f_pos_grad=lambda: tensor_utils.compile_function(
            inputs=[observations_var, actions_var, advantages_var],
            outputs=grad_pos_eps),
        f_neg_grad=lambda: tensor_utils.compile_function(
            inputs=[observations_var, actions_var, advantages_var],
            outputs=grad_neg_eps),
        f_mix_lh=lambda: tensor_utils.compile_function(
            inputs=[observations_var, actions_var],
            outputs=grad_mix_lh),
        f_kl=lambda: tensor_utils.compile_function(
            inputs=[observations_var],
            outputs=[mean_kl, max_kl],
        ))
def init_opt(self):
    observations_var = self.env.observation_space.new_tensor_variable(
        'obs',
        extra_dims=1,
    )
    actions_var = self.env.action_space.new_tensor_variable(
        'action',
        extra_dims=1,
    )
    advantages_var = tensor_utils.new_tensor(
        name='advantage',
        ndim=1,
        dtype=tf.float32,
    )
    dist = self.policy.distribution

    old_dist_info_vars = self.backup_policy.dist_info_sym(observations_var)
    dist_info_vars = self.policy.dist_info_sym(observations_var)
    kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
    mean_kl = tf.reduce_mean(kl)
    max_kl = tf.reduce_max(kl)

    pos_eps_dist_info_vars = self.pos_eps_policy.dist_info_sym(
        observations_var)
    neg_eps_dist_info_vars = self.neg_eps_policy.dist_info_sym(
        observations_var)
    mix_dist_info_vars = self.mix_policy.dist_info_sym(observations_var)

    # formulate as a minimization problem
    # The gradient of the surrogate objective is the policy gradient
    surr = -tf.reduce_mean(
        dist.log_likelihood_sym(actions_var, dist_info_vars) *
        advantages_var)
    surr_pos_eps = -tf.reduce_mean(
        dist.log_likelihood_sym(actions_var, pos_eps_dist_info_vars) *
        advantages_var)
    surr_neg_eps = -tf.reduce_mean(
        dist.log_likelihood_sym(actions_var, neg_eps_dist_info_vars) *
        advantages_var)
    surr_mix = -tf.reduce_mean(
        dist.log_likelihood_sym(actions_var, mix_dist_info_vars) *
        advantages_var)
    surr_loglikelihood = tf.reduce_sum(
        dist.log_likelihood_sym(actions_var, mix_dist_info_vars))

    params = self.policy.get_params(trainable=True)
    mix_params = self.mix_policy.get_params(trainable=True)
    pos_eps_params = self.pos_eps_policy.get_params(trainable=True)
    neg_eps_params = self.neg_eps_policy.get_params(trainable=True)

    grads = tf.gradients(surr, params)
    grad_pos_eps = tf.gradients(surr_pos_eps, pos_eps_params)
    grad_neg_eps = tf.gradients(surr_neg_eps, neg_eps_params)
    grad_mix = tf.gradients(surr_mix, mix_params)
    grad_mix_lh = tf.gradients(surr_loglikelihood, mix_params)

    inputs_list = [observations_var, actions_var, advantages_var]

    self.optimizer.update_opt(loss=surr,
                              target=self.policy,
                              leq_constraint=(mean_kl, self.delta),
                              inputs=inputs_list)

    self._opt_fun = ext.LazyDict(
        f_loss=lambda: tensor_utils.compile_function(
            inputs=inputs_list,
            outputs=surr,
            log_name="f_loss",
        ),
        f_train=lambda: tensor_utils.compile_function(
            inputs=inputs_list, outputs=grads, log_name="f_grad"),
        f_mix_grad=lambda: tensor_utils.compile_function(
            inputs=inputs_list, outputs=grad_mix, log_name="f_mix_grad"),
        f_pos_grad=lambda: tensor_utils.compile_function(
            inputs=inputs_list, outputs=grad_pos_eps),
        f_neg_grad=lambda: tensor_utils.compile_function(
            inputs=inputs_list, outputs=grad_neg_eps),
        f_mix_lh=lambda: tensor_utils.compile_function(
            inputs=inputs_list, outputs=grad_mix_lh),
        f_kl=lambda: tensor_utils.compile_function(
            inputs=inputs_list,
            outputs=[mean_kl, max_kl],
        ))
def init_opt(self): with tf.name_scope(self.name, "NPO"): is_recurrent = int(self.policy.recurrent) with tf.name_scope("inputs"): obs_var = self.env.observation_space.new_tensor_variable( 'obs', extra_dims=1 + is_recurrent, ) action_var = self.env.action_space.new_tensor_variable( 'action', extra_dims=1 + is_recurrent, ) advantage_var = tensor_utils.new_tensor( 'advantage', ndim=1 + is_recurrent, dtype=tf.float32, ) dist = self.policy.distribution old_dist_info_vars = { k: tf.placeholder( tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k) for k, shape in dist.dist_info_specs } old_dist_info_vars_list = [ old_dist_info_vars[k] for k in dist.dist_info_keys ] state_info_vars = { k: tf.placeholder( tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k) for k, shape in self.policy.state_info_specs } state_info_vars_list = [ state_info_vars[k] for k in self.policy.state_info_keys ] if is_recurrent: valid_var = tf.placeholder( tf.float32, shape=[None, None], name="valid") else: valid_var = None dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars) kl = dist.kl_sym(old_dist_info_vars, dist_info_vars) lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars) with tf.name_scope("mean_kl", values=[kl, valid_var]): if is_recurrent: mean_kl = tf.reduce_sum( kl * valid_var) / tf.reduce_sum(valid_var) else: mean_kl = tf.reduce_mean(kl, name="reduce_mean_er") tf.identity(mean_kl, name="mean_kl") with tf.name_scope( "surr_loss", values=[lr, advantage_var, valid_var]): if is_recurrent: surr_loss = (-tf.reduce_sum(lr * advantage_var * valid_var) / tf.reduce_sum(valid_var)) else: surr_loss = -tf.reduce_mean(lr * advantage_var) tf.identity(surr_loss, name="surr_loss") input_list = [ obs_var, action_var, advantage_var, ] + state_info_vars_list + old_dist_info_vars_list if is_recurrent: input_list.append(valid_var) self.optimizer.update_opt( loss=surr_loss, target=self.policy, leq_constraint=(mean_kl, self.step_size), inputs=input_list, constraint_name="mean_kl", name="update_opt_surr_loss") return dict()
def init_opt(self): with tf.name_scope(self.name, "VPG"): is_recurrent = int(self.policy.recurrent) with tf.name_scope("inputs"): obs_var = self.env.observation_space.new_tensor_variable( 'obs', extra_dims=1 + is_recurrent, ) action_var = self.env.action_space.new_tensor_variable( 'action', extra_dims=1 + is_recurrent, ) advantage_var = tensor_utils.new_tensor( name='advantage', ndim=1 + is_recurrent, dtype=tf.float32, ) dist = self.policy.distribution old_dist_info_vars = { k: tf.placeholder( tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k) for k, shape in dist.dist_info_specs } old_dist_info_vars_list = [ old_dist_info_vars[k] for k in dist.dist_info_keys ] state_info_vars = { k: tf.placeholder( tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k) for k, shape in self.policy.state_info_specs } state_info_vars_list = [ state_info_vars[k] for k in self.policy.state_info_keys ] if is_recurrent: valid_var = tf.placeholder( tf.float32, shape=[None, None], name="valid") else: valid_var = None dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars) logli = dist.log_likelihood_sym(action_var, dist_info_vars) kl = dist.kl_sym(old_dist_info_vars, dist_info_vars) # formulate as a minimization problem # The gradient of the surrogate objective is the policy gradient with tf.name_scope( "surr_obj", values=[logli, advantage_var, valid_var]): if is_recurrent: surr_obj = ( -tf.reduce_sum(logli * advantage_var * valid_var) / tf.reduce_sum(valid_var)) else: surr_obj = -tf.reduce_mean(logli * advantage_var) tf.identity(surr_obj, name="surr_obj") with tf.name_scope("mean_kl", values=[kl, valid_var]): if is_recurrent: mean_kl = tf.reduce_sum( kl * valid_var) / tf.reduce_sum(valid_var) else: mean_kl = tf.reduce_mean(kl) tf.identity(mean_kl, name="mean_kl") with tf.name_scope("max_kl", values=[kl, valid_var]): if is_recurrent: max_kl = tf.reduce_max(kl * valid_var) else: max_kl = tf.reduce_max(kl) tf.identity(max_kl, name="max_kl") input_list = [obs_var, action_var, advantage_var ] + state_info_vars_list if is_recurrent: input_list.append(valid_var) self.optimizer.update_opt( loss=surr_obj, target=self.policy, inputs=input_list) f_kl = tensor_utils.compile_function( inputs=input_list + old_dist_info_vars_list, outputs=[mean_kl, max_kl], ) self.opt_info = dict(f_kl=f_kl, )