def _build_inputs(self): """ Builds input variables (and trivial views thereof) for the loss function network """ observation_space = self.policy.observation_space action_space = self.policy.action_space task_space = self.policy.task_space latent_space = self.policy.latent_space trajectory_space = self.inference.input_space policy_dist = self.policy._dist embed_dist = self.policy.embedding._dist infer_dist = self.inference._dist with tf.name_scope("inputs"): obs_var = observation_space.new_tensor_variable( 'obs', extra_dims=1 + 1, ) task_var = task_space.new_tensor_variable( 'task', extra_dims=1 + 1, ) action_var = action_space.new_tensor_variable( 'action', extra_dims=1 + 1, ) reward_var = tensor_utils.new_tensor( 'reward', ndim=1 + 1, dtype=tf.float32, ) latent_var = latent_space.new_tensor_variable( 'latent', extra_dims=1 + 1, ) baseline_var = tensor_utils.new_tensor( 'baseline', ndim=1 + 1, dtype=tf.float32, ) trajectory_var = trajectory_space.new_tensor_variable( 'trajectory', extra_dims=1 + 1, ) valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid") # Policy state (for RNNs) policy_state_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name=k) for k, shape in self.policy.state_info_specs } policy_state_info_vars_list = [ policy_state_info_vars[k] for k in self.policy.state_info_keys ] # Old policy distribution (for KL) policy_old_dist_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name='policy_old_%s' % k) for k, shape in policy_dist.dist_info_specs } policy_old_dist_info_vars_list = [ policy_old_dist_info_vars[k] for k in policy_dist.dist_info_keys ] # Embedding state (for RNNs) embed_state_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name='embed_%s' % k) for k, shape in self.policy.embedding.state_info_specs } embed_state_info_vars_list = [ embed_state_info_vars[k] for k in self.policy.embedding.state_info_keys ] # Old embedding distribution (for KL) embed_old_dist_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name='embed_old_%s' % k) for k, shape in embed_dist.dist_info_specs } embed_old_dist_info_vars_list = [ embed_old_dist_info_vars[k] for k in embed_dist.dist_info_keys ] # Inference distribution state (for RNNs) infer_state_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name='infer_%s' % k) for k, shape in self.inference.state_info_specs } infer_state_info_vars_list = [ infer_state_info_vars[k] for k in self.inference.state_info_keys ] # Old inference distribution (for KL) infer_old_dist_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * (1 + 1) + list(shape), name='infer_old_%s' % k) for k, shape in infer_dist.dist_info_specs } infer_old_dist_info_vars_list = [ infer_old_dist_info_vars[k] for k in infer_dist.dist_info_keys ] # Flattened view with tf.name_scope("flat"): obs_flat = flatten_batch(obs_var, name="obs_flat") task_flat = flatten_batch(task_var, name="task_flat") action_flat = flatten_batch(action_var, name="action_flat") reward_flat = flatten_batch(reward_var, name="reward_flat") latent_flat = flatten_batch(latent_var, name="latent_flat") trajectory_flat = flatten_batch(trajectory_var, name="trajectory_flat") valid_flat = flatten_batch(valid_var, name="valid_flat") policy_state_info_vars_flat = flatten_batch_dict( policy_state_info_vars, name="policy_state_info_vars_flat") policy_old_dist_info_vars_flat = flatten_batch_dict( policy_old_dist_info_vars, 
name="policy_old_dist_info_vars_flat") embed_state_info_vars_flat = flatten_batch_dict( embed_state_info_vars, name="embed_state_info_vars_flat") embed_old_dist_info_vars_flat = flatten_batch_dict( embed_old_dist_info_vars, name="embed_old_dist_info_vars_flat") infer_state_info_vars_flat = flatten_batch_dict( infer_state_info_vars, name="infer_state_info_vars_flat") infer_old_dist_info_vars_flat = flatten_batch_dict( infer_old_dist_info_vars, name="infer_old_dist_info_vars_flat") # Valid view with tf.name_scope("valid"): action_valid = filter_valids(action_flat, valid_flat, name="action_valid") policy_state_info_vars_valid = filter_valids_dict( policy_state_info_vars_flat, valid_flat, name="policy_state_info_vars_valid") policy_old_dist_info_vars_valid = filter_valids_dict( policy_old_dist_info_vars_flat, valid_flat, name="policy_old_dist_info_vars_valid") embed_old_dist_info_vars_valid = filter_valids_dict( embed_old_dist_info_vars_flat, valid_flat, name="embed_old_dist_info_vars_valid") infer_old_dist_info_vars_valid = filter_valids_dict( infer_old_dist_info_vars_flat, valid_flat, name="infer_old_dist_info_vars_valid") # Policy and embedding network loss and optimizer inputs pol_flat = graph_inputs( "PolicyLossInputsFlat", obs_var=obs_flat, task_var=task_flat, action_var=action_flat, reward_var=reward_flat, latent_var=latent_flat, trajectory_var=trajectory_flat, valid_var=valid_flat, policy_state_info_vars=policy_state_info_vars_flat, policy_old_dist_info_vars=policy_old_dist_info_vars_flat, embed_state_info_vars=embed_state_info_vars_flat, embed_old_dist_info_vars=embed_old_dist_info_vars_flat, ) pol_valid = graph_inputs( "PolicyLossInputsValid", action_var=action_valid, policy_state_info_vars=policy_state_info_vars_valid, policy_old_dist_info_vars=policy_old_dist_info_vars_valid, embed_old_dist_info_vars=embed_old_dist_info_vars_valid, ) policy_loss_inputs = graph_inputs( "PolicyLossInputs", obs_var=obs_var, action_var=action_var, reward_var=reward_var, baseline_var=baseline_var, trajectory_var=trajectory_var, task_var=task_var, latent_var=latent_var, valid_var=valid_var, policy_state_info_vars=policy_state_info_vars, policy_old_dist_info_vars=policy_old_dist_info_vars, embed_state_info_vars=embed_state_info_vars, embed_old_dist_info_vars=embed_old_dist_info_vars, flat=pol_flat, valid=pol_valid, ) # Special variant for the optimizer # * Uses lists instead of dicts for the distribution parameters # * Omits flats and valids # TODO: eliminate policy_opt_inputs = graph_inputs( "PolicyOptInputs", obs_var=obs_var, action_var=action_var, reward_var=reward_var, baseline_var=baseline_var, trajectory_var=trajectory_var, task_var=task_var, latent_var=latent_var, valid_var=valid_var, policy_state_info_vars_list=policy_state_info_vars_list, policy_old_dist_info_vars_list=policy_old_dist_info_vars_list, embed_state_info_vars_list=embed_state_info_vars_list, embed_old_dist_info_vars_list=embed_old_dist_info_vars_list, ) # Inference network loss and optimizer inputs infer_flat = graph_inputs( "InferenceLossInputsFlat", latent_var=latent_flat, trajectory_var=trajectory_flat, valid_var=valid_flat, infer_state_info_vars=infer_state_info_vars_flat, infer_old_dist_info_vars=infer_old_dist_info_vars_flat, ) infer_valid = graph_inputs( "InferenceLossInputsValid", infer_old_dist_info_vars=infer_old_dist_info_vars_valid, ) inference_loss_inputs = graph_inputs( "InferenceLossInputs", latent_var=latent_var, trajectory_var=trajectory_var, valid_var=valid_var, infer_state_info_vars=infer_state_info_vars, 
infer_old_dist_info_vars=infer_old_dist_info_vars, flat=infer_flat, valid=infer_valid, ) # Special variant for the optimizer # * Uses lists instead of dicts for the distribution parameters # * Omits flats and valids # TODO: eliminate inference_opt_inputs = graph_inputs( "InferenceOptInputs", latent_var=latent_var, trajectory_var=trajectory_var, valid_var=valid_var, infer_state_info_vars_list=infer_state_info_vars_list, infer_old_dist_info_vars_list=infer_old_dist_info_vars_list, ) return (policy_loss_inputs, policy_opt_inputs, inference_loss_inputs, inference_opt_inputs)
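
# NOTE: `graph_inputs`, used throughout these builders, is assumed to be a
# thin namedtuple factory: it bundles its keyword arguments into an immutable,
# named collection so that later loss- and optimizer-building code can refer
# to the inputs by field name. A minimal sketch under that assumption (not
# necessarily the library's exact implementation):

import collections


def graph_inputs(name, **kwargs):
    """Bundle keyword arguments into an ad-hoc namedtuple instance."""
    Singleton = collections.namedtuple(name, kwargs.keys())
    return Singleton(**kwargs)
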
def _build_inputs(self): """Build input variables. Returns: namedtuple: Collection of variables to compute policy loss. namedtuple: Collection of variables to do policy optimization. """ observation_space = self.policy.observation_space action_space = self.policy.action_space with tf.name_scope('inputs'): if self._flatten_input: obs_var = tf.compat.v1.placeholder( tf.float32, shape=[None, None, observation_space.flat_dim], name='obs') else: obs_var = observation_space.to_tf_placeholder(name='obs', batch_dims=2) action_var = action_space.to_tf_placeholder(name='action', batch_dims=2) reward_var = tf.compat.v1.placeholder(tf.float32, shape=[None, None], name='reward') valid_var = tf.compat.v1.placeholder(tf.float32, shape=[None, None], name='valid') baseline_var = tf.compat.v1.placeholder(tf.float32, shape=[None, None], name='baseline') policy_state_info_vars = { k: tf.compat.v1.placeholder(tf.float32, shape=[None] * 2 + list(shape), name=k) for k, shape in self.policy.state_info_specs } policy_state_info_vars_list = [ policy_state_info_vars[k] for k in self.policy.state_info_keys ] # concat the action input with obs_var to become the final # state input augmented_obs_var = obs_var for k in self.policy.state_info_keys: extra_state_var = policy_state_info_vars[k] extra_state_var = tf.cast(extra_state_var, tf.float32) augmented_obs_var = tf.concat([augmented_obs_var, extra_state_var], -1) self.policy.build(augmented_obs_var) self._old_policy.build(augmented_obs_var) self._old_policy.model.parameters = self.policy.model.parameters policy_loss_inputs = graph_inputs( 'PolicyLossInputs', action_var=action_var, reward_var=reward_var, baseline_var=baseline_var, valid_var=valid_var, policy_state_info_vars=policy_state_info_vars, ) policy_opt_inputs = graph_inputs( 'PolicyOptInputs', obs_var=obs_var, action_var=action_var, reward_var=reward_var, baseline_var=baseline_var, valid_var=valid_var, policy_state_info_vars_list=policy_state_info_vars_list, ) return policy_loss_inputs, policy_opt_inputs
def _build_inputs(self):
    """Build input variables for the policy loss and optimizer."""
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space
    policy_dist = self.policy.distribution

    with tf.name_scope("inputs"):
        obs_var = observation_space.new_tensor_variable(
            name="obs", extra_dims=2)
        action_var = action_space.new_tensor_variable(
            name="action", extra_dims=2)
        reward_var = tensor_utils.new_tensor(
            name="reward", ndim=2, dtype=tf.float32)
        valid_var = tf.placeholder(
            tf.float32, shape=[None, None], name="valid")
        baseline_var = tensor_utils.new_tensor(
            name="baseline", ndim=2, dtype=tf.float32)

        policy_state_info_vars = {
            k: tf.placeholder(
                tf.float32, shape=[None] * 2 + list(shape), name=k)
            for k, shape in self.policy.state_info_specs
        }
        policy_state_info_vars_list = [
            policy_state_info_vars[k] for k in self.policy.state_info_keys
        ]

        # old policy distribution
        policy_old_dist_info_vars = {
            k: tf.placeholder(
                tf.float32,
                shape=[None] * 2 + list(shape),
                name="policy_old_%s" % k)
            for k, shape in policy_dist.dist_info_specs
        }
        policy_old_dist_info_vars_list = [
            policy_old_dist_info_vars[k]
            for k in policy_dist.dist_info_keys
        ]

        # flattened view
        with tf.name_scope("flat"):
            obs_flat = flatten_batch(obs_var, name="obs_flat")
            action_flat = flatten_batch(action_var, name="action_flat")
            reward_flat = flatten_batch(reward_var, name="reward_flat")
            valid_flat = flatten_batch(valid_var, name="valid_flat")
            policy_state_info_vars_flat = flatten_batch_dict(
                policy_state_info_vars, name="policy_state_info_vars_flat")
            policy_old_dist_info_vars_flat = flatten_batch_dict(
                policy_old_dist_info_vars,
                name="policy_old_dist_info_vars_flat")

        # valid view
        with tf.name_scope("valid"):
            action_valid = filter_valids(
                action_flat, valid_flat, name="action_valid")
            policy_state_info_vars_valid = filter_valids_dict(
                policy_state_info_vars_flat,
                valid_flat,
                name="policy_state_info_vars_valid")
            policy_old_dist_info_vars_valid = filter_valids_dict(
                policy_old_dist_info_vars_flat,
                valid_flat,
                name="policy_old_dist_info_vars_valid")

    # policy loss and optimizer inputs
    pol_flat = graph_inputs(
        "PolicyLossInputsFlat",
        obs_var=obs_flat,
        action_var=action_flat,
        reward_var=reward_flat,
        valid_var=valid_flat,
        policy_state_info_vars=policy_state_info_vars_flat,
        policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
    )
    pol_valid = graph_inputs(
        "PolicyLossInputsValid",
        action_var=action_valid,
        policy_state_info_vars=policy_state_info_vars_valid,
        policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
    )
    policy_loss_inputs = graph_inputs(
        "PolicyLossInputs",
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars=policy_state_info_vars,
        policy_old_dist_info_vars=policy_old_dist_info_vars,
        flat=pol_flat,
        valid=pol_valid,
    )
    policy_opt_inputs = graph_inputs(
        "PolicyOptInputs",
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
    )

    return policy_loss_inputs, policy_opt_inputs
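
# NOTE: The flat/valid "views" above rely on four small tensor helpers
# (`flatten_batch`, `flatten_batch_dict`, `filter_valids`,
# `filter_valids_dict`). The sketch below shows plausible minimal
# implementations of their semantics; the library's real versions may differ
# in details such as the masking op used.

import tensorflow as tf


def flatten_batch(t, name='flatten_batch'):
    """Collapse the leading [N, T] dims of a tensor into a single N * T dim."""
    new_shape = tf.concat([[-1], tf.shape(t)[2:]], axis=0)
    return tf.reshape(t, new_shape, name=name)


def flatten_batch_dict(d, name='flatten_batch_dict'):
    """Apply flatten_batch to every tensor in a dict, keeping the keys."""
    with tf.name_scope(name):
        return {k: flatten_batch(v) for k, v in d.items()}


def filter_valids(t, valid, name='filter_valids'):
    """Keep only the (flattened) time steps whose `valid` flag is nonzero."""
    return tf.boolean_mask(t, tf.cast(valid, tf.bool), name=name)


def filter_valids_dict(d, valid, name='filter_valids_dict'):
    """Apply filter_valids to every tensor in a dict, keeping the keys."""
    with tf.name_scope(name):
        return {k: filter_valids(v, valid) for k, v in d.items()}
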
def _build_inputs(self): """Decalre graph inputs variables.""" observation_space = self.policy.observation_space action_space = self.policy.action_space policy_dist = self.policy.distribution with tf.name_scope('inputs'): obs_var = observation_space.to_tf_placeholder( name='obs', batch_dims=2) # yapf: disable action_var = action_space.to_tf_placeholder( name='action', batch_dims=2) # yapf: disable reward_var = tensor_utils.new_tensor( name='reward', ndim=2, dtype=tf.float32) # yapf: disable valid_var = tensor_utils.new_tensor( name='valid', ndim=2, dtype=tf.float32) # yapf: disable feat_diff = tensor_utils.new_tensor( name='feat_diff', ndim=2, dtype=tf.float32) # yapf: disable param_v = tensor_utils.new_tensor( name='param_v', ndim=1, dtype=tf.float32) # yapf: disable param_eta = tensor_utils.new_tensor( name='param_eta', ndim=0, dtype=tf.float32) # yapf: disable policy_state_info_vars = { k: tf.placeholder( tf.float32, shape=[None] * 2 + list(shape), name=k) for k, shape in self.policy.state_info_specs } # yapf: disable policy_state_info_vars_list = [ policy_state_info_vars[k] for k in self.policy.state_info_keys ] # yapf: disable policy_old_dist_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * 2 + list(shape), name='policy_old_%s' % k) for k, shape in policy_dist.dist_info_specs } policy_old_dist_info_vars_list = [ policy_old_dist_info_vars[k] for k in policy_dist.dist_info_keys ] with tf.name_scope('flat'): obs_flat = flatten_batch(obs_var, name='obs_flat') action_flat = flatten_batch(action_var, name='action_flat') reward_flat = flatten_batch(reward_var, name='reward_flat') valid_flat = flatten_batch(valid_var, name='valid_flat') feat_diff_flat = flatten_batch( feat_diff, name='feat_diff_flat') # yapf: disable policy_state_info_vars_flat = flatten_batch_dict( policy_state_info_vars, name='policy_state_info_vars_flat') # yapf: disable policy_old_dist_info_vars_flat = flatten_batch_dict( policy_old_dist_info_vars, name='policy_old_dist_info_vars_flat') with tf.name_scope('valid'): reward_valid = filter_valids( reward_flat, valid_flat, name='reward_valid') # yapf: disable action_valid = filter_valids( action_flat, valid_flat, name='action_valid') # yapf: disable policy_state_info_vars_valid = filter_valids_dict( policy_state_info_vars_flat, valid_flat, name='policy_state_info_vars_valid') policy_old_dist_info_vars_valid = filter_valids_dict( policy_old_dist_info_vars_flat, valid_flat, name='policy_old_dist_info_vars_valid') pol_flat = graph_inputs( 'PolicyLossInputsFlat', obs_var=obs_flat, action_var=action_flat, reward_var=reward_flat, valid_var=valid_flat, feat_diff=feat_diff_flat, policy_state_info_vars=policy_state_info_vars_flat, policy_old_dist_info_vars=policy_old_dist_info_vars_flat, ) pol_valid = graph_inputs( 'PolicyLossInputsValid', reward_var=reward_valid, action_var=action_valid, policy_state_info_vars=policy_state_info_vars_valid, policy_old_dist_info_vars=policy_old_dist_info_vars_valid, ) policy_loss_inputs = graph_inputs( 'PolicyLossInputs', obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars=policy_state_info_vars, policy_old_dist_info_vars=policy_old_dist_info_vars, flat=pol_flat, valid=pol_valid, ) policy_opt_inputs = graph_inputs( 'PolicyOptInputs', obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars_list=policy_state_info_vars_list, 
policy_old_dist_info_vars_list=policy_old_dist_info_vars_list, ) dual_opt_inputs = graph_inputs( 'DualOptInputs', reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars_list=policy_state_info_vars_list, policy_old_dist_info_vars_list=policy_old_dist_info_vars_list, ) return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs
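
# NOTE: `tensor_utils.new_tensor` above is assumed to be a small placeholder
# factory keyed on rank, which is what makes `param_eta` a scalar (ndim=0),
# `param_v` a vector (ndim=1), and `feat_diff` a 2-D batch tensor (ndim=2).
# A minimal sketch under that assumption:

import tensorflow as tf


def new_tensor(name, ndim, dtype):
    """Create a placeholder of the given rank with all dimensions unknown."""
    return tf.compat.v1.placeholder(dtype=dtype,
                                    shape=[None] * ndim,
                                    name=name)
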
def _build_inputs(self): """Build input variables. Returns: namedtuple: Collection of variables to compute policy loss. namedtuple: Collection of variables to do policy optimization. """ observation_space = self.policy.observation_space action_space = self.policy.action_space with tf.name_scope('inputs'): obs_var = observation_space.to_tf_placeholder( name='obs', batch_dims=2) # yapf: disable action_var = action_space.to_tf_placeholder( name='action', batch_dims=2) # yapf: disable reward_var = tensor_utils.new_tensor( name='reward', ndim=2, dtype=tf.float32) # yapf: disable valid_var = tensor_utils.new_tensor( name='valid', ndim=2, dtype=tf.float32) # yapf: disable feat_diff = tensor_utils.new_tensor( name='feat_diff', ndim=2, dtype=tf.float32) # yapf: disable param_v = tensor_utils.new_tensor( name='param_v', ndim=1, dtype=tf.float32) # yapf: disable param_eta = tensor_utils.new_tensor( name='param_eta', ndim=0, dtype=tf.float32) # yapf: disable policy_state_info_vars = { k: tf.compat.v1.placeholder( tf.float32, shape=[None] * 2 + list(shape), name=k) for k, shape in self.policy.state_info_specs } # yapf: disable policy_state_info_vars_list = [ policy_state_info_vars[k] for k in self.policy.state_info_keys ] # yapf: disable self.policy.build(obs_var) self._old_policy.build(obs_var) policy_loss_inputs = graph_inputs( 'PolicyLossInputs', obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars=policy_state_info_vars, ) policy_opt_inputs = graph_inputs( 'PolicyOptInputs', obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars_list=policy_state_info_vars_list, ) dual_opt_inputs = graph_inputs( 'DualOptInputs', reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars_list=policy_state_info_vars_list, ) return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs
def _build_inputs(self): """Build input variables. Returns: namedtuple: Collection of variables to compute policy loss. namedtuple: Collection of variables to do policy optimization. namedtuple: Collection of variables to compute inference loss. namedtuple: Collection of variables to do inference optimization. """ # pylint: disable=too-many-statements observation_space = self.policy.observation_space action_space = self.policy.action_space task_space = self.policy.task_space latent_space = self.policy.latent_space trajectory_space = self._inference.spec.input_space with tf.name_scope('inputs'): obs_var = observation_space.to_tf_placeholder(name='obs', batch_dims=2) task_var = tf.compat.v1.placeholder( tf.float32, shape=[None, None, task_space.flat_dim], name='task') trajectory_var = tf.compat.v1.placeholder( tf.float32, shape=[None, None, trajectory_space.flat_dim]) latent_var = tf.compat.v1.placeholder( tf.float32, shape=[None, None, latent_space.flat_dim]) action_var = action_space.to_tf_placeholder(name='action', batch_dims=2) reward_var = tf.compat.v1.placeholder(tf.float32, shape=[None, None], name='reward') baseline_var = tf.compat.v1.placeholder(tf.float32, shape=[None, None], name='baseline') valid_var = tf.compat.v1.placeholder(tf.float32, shape=[None, None], name='valid') # Policy state (for RNNs) policy_state_info_vars = { k: tf.compat.v1.placeholder(tf.float32, shape=[None] * 2 + list(shape), name=k) for k, shape in self.policy.state_info_specs } policy_state_info_vars_list = [ policy_state_info_vars[k] for k in self.policy.state_info_keys ] # Encoder state (for RNNs) embed_state_info_vars = { k: tf.compat.v1.placeholder(tf.float32, shape=[None] * 2 + list(shape), name='embed_%s' % k) for k, shape in self.policy.encoder.state_info_specs } embed_state_info_vars_list = [ embed_state_info_vars[k] for k in self.policy.encoder.state_info_keys ] # Inference distribution state (for RNNs) infer_state_info_vars = { k: tf.compat.v1.placeholder(tf.float32, shape=[None] * 2 + list(shape), name='infer_%s' % k) for k, shape in self._inference.state_info_specs } infer_state_info_vars_list = [ infer_state_info_vars[k] for k in self._inference.state_info_keys ] extra_obs_var = [ tf.cast(v, tf.float32) for v in policy_state_info_vars_list ] # Pylint false alarm # pylint: disable=unexpected-keyword-arg, no-value-for-parameter augmented_obs_var = tf.concat([obs_var] + extra_obs_var, axis=-1) extra_traj_var = [ tf.cast(v, tf.float32) for v in infer_state_info_vars_list ] augmented_traj_var = tf.concat([trajectory_var] + extra_traj_var, -1) # Policy and encoder network loss and optimizer inputs policy_loss_inputs = graph_inputs( 'PolicyLossInputs', augmented_obs_var=augmented_obs_var, augmented_traj_var=augmented_traj_var, task_var=task_var, latent_var=latent_var, action_var=action_var, reward_var=reward_var, baseline_var=baseline_var, valid_var=valid_var) policy_opt_inputs = graph_inputs( 'PolicyOptInputs', obs_var=obs_var, action_var=action_var, reward_var=reward_var, baseline_var=baseline_var, trajectory_var=trajectory_var, task_var=task_var, latent_var=latent_var, valid_var=valid_var, policy_state_info_vars_list=policy_state_info_vars_list, embed_state_info_vars_list=embed_state_info_vars_list, ) # Inference network loss and optimizer inputs inference_loss_inputs = graph_inputs('InferenceLossInputs', latent_var=latent_var, valid_var=valid_var) inference_opt_inputs = graph_inputs( 'InferenceOptInputs', latent_var=latent_var, trajectory_var=trajectory_var, valid_var=valid_var, 
infer_state_info_vars_list=infer_state_info_vars_list, ) return (policy_loss_inputs, policy_opt_inputs, inference_loss_inputs, inference_opt_inputs)
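
# NOTE: A small, self-contained demonstration of the observation-augmentation
# pattern used above: extra per-step state info is cast to float32 and
# concatenated onto the observation along the last axis. The shapes and names
# are illustrative only.

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # graph-mode placeholders under TF2

demo_obs = tf.compat.v1.placeholder(tf.float32, [None, None, 4],
                                    name='demo_obs')
demo_extra = tf.compat.v1.placeholder(tf.int32, [None, None, 2],
                                      name='demo_extra')
demo_augmented = tf.concat([demo_obs, tf.cast(demo_extra, tf.float32)],
                           axis=-1)

with tf.compat.v1.Session() as sess:
    out = sess.run(
        demo_augmented,
        feed_dict={
            demo_obs: np.zeros((3, 5, 4), dtype=np.float32),
            demo_extra: np.ones((3, 5, 2), dtype=np.int32),
        })
    assert out.shape == (3, 5, 6)  # 4 obs dims + 2 extra state dims
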
def _build_inputs(self): """Decalre graph inputs variables.""" observation_space = self.policy.observation_space action_space = self.policy.action_space policy_dist = self.policy.distribution with tf.name_scope("inputs"): obs_var = observation_space.new_tensor_variable( name="obs", extra_dims=2) # yapf: disable action_var = action_space.new_tensor_variable( name="action", extra_dims=2) # yapf: disable reward_var = tensor_utils.new_tensor( name="reward", ndim=2, dtype=tf.float32) # yapf: disable valid_var = tensor_utils.new_tensor( name="valid", ndim=2, dtype=tf.float32) # yapf: disable feat_diff = tensor_utils.new_tensor( name="feat_diff", ndim=2, dtype=tf.float32) # yapf: disable param_v = tensor_utils.new_tensor( name="param_v", ndim=1, dtype=tf.float32) # yapf: disable param_eta = tensor_utils.new_tensor( name="param_eta", ndim=0, dtype=tf.float32) # yapf: disable policy_state_info_vars = { k: tf.placeholder( tf.float32, shape=[None] * 2 + list(shape), name=k) for k, shape in self.policy.state_info_specs } # yapf: disable policy_state_info_vars_list = [ policy_state_info_vars[k] for k in self.policy.state_info_keys ] # yapf: disable policy_old_dist_info_vars = { k: tf.placeholder(tf.float32, shape=[None] * 2 + list(shape), name="policy_old_%s" % k) for k, shape in policy_dist.dist_info_specs } policy_old_dist_info_vars_list = [ policy_old_dist_info_vars[k] for k in policy_dist.dist_info_keys ] with tf.name_scope("flat"): obs_flat = flatten_batch(obs_var, name="obs_flat") action_flat = flatten_batch(action_var, name="action_flat") reward_flat = flatten_batch(reward_var, name="reward_flat") valid_flat = flatten_batch(valid_var, name="valid_flat") feat_diff_flat = flatten_batch( feat_diff, name="feat_diff_flat") # yapf: disable policy_state_info_vars_flat = flatten_batch_dict( policy_state_info_vars, name="policy_state_info_vars_flat") # yapf: disable policy_old_dist_info_vars_flat = flatten_batch_dict( policy_old_dist_info_vars, name="policy_old_dist_info_vars_flat") with tf.name_scope("valid"): reward_valid = filter_valids( reward_flat, valid_flat, name="reward_valid") # yapf: disable action_valid = filter_valids( action_flat, valid_flat, name="action_valid") # yapf: disable policy_state_info_vars_valid = filter_valids_dict( policy_state_info_vars_flat, valid_flat, name="policy_state_info_vars_valid") policy_old_dist_info_vars_valid = filter_valids_dict( policy_old_dist_info_vars_flat, valid_flat, name="policy_old_dist_info_vars_valid") pol_flat = graph_inputs( "PolicyLossInputsFlat", obs_var=obs_flat, action_var=action_flat, reward_var=reward_flat, valid_var=valid_flat, feat_diff=feat_diff_flat, policy_state_info_vars=policy_state_info_vars_flat, policy_old_dist_info_vars=policy_old_dist_info_vars_flat, ) pol_valid = graph_inputs( "PolicyLossInputsValid", reward_var=reward_valid, action_var=action_valid, policy_state_info_vars=policy_state_info_vars_valid, policy_old_dist_info_vars=policy_old_dist_info_vars_valid, ) policy_loss_inputs = graph_inputs( "PolicyLossInputs", obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars=policy_state_info_vars, policy_old_dist_info_vars=policy_old_dist_info_vars, flat=pol_flat, valid=pol_valid, ) policy_opt_inputs = graph_inputs( "PolicyOptInputs", obs_var=obs_var, action_var=action_var, reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, 
policy_state_info_vars_list=policy_state_info_vars_list, policy_old_dist_info_vars_list=policy_old_dist_info_vars_list, ) dual_opt_inputs = graph_inputs( "DualOptInputs", reward_var=reward_var, valid_var=valid_var, feat_diff=feat_diff, param_eta=param_eta, param_v=param_v, policy_state_info_vars_list=policy_state_info_vars_list, policy_old_dist_info_vars_list=policy_old_dist_info_vars_list, ) return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs
def _build_inputs(self):
    """Build input variables for the policy loss and optimizer."""
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space
    policy_dist = self.policy.distribution

    with tf.name_scope('inputs'):
        if self.flatten_input:
            obs_var = tf.compat.v1.placeholder(
                tf.float32,
                shape=[None, None, observation_space.flat_dim],
                name='obs')
        else:
            obs_var = observation_space.to_tf_placeholder(name='obs',
                                                          batch_dims=2)
        action_var = action_space.to_tf_placeholder(name='action',
                                                    batch_dims=2)
        reward_var = tensor_utils.new_tensor(name='reward',
                                             ndim=2,
                                             dtype=tf.float32)
        valid_var = tf.compat.v1.placeholder(tf.float32,
                                             shape=[None, None],
                                             name='valid')
        baseline_var = tensor_utils.new_tensor(name='baseline',
                                               ndim=2,
                                               dtype=tf.float32)

        policy_state_info_vars = {
            k: tf.compat.v1.placeholder(tf.float32,
                                        shape=[None] * 2 + list(shape),
                                        name=k)
            for k, shape in self.policy.state_info_specs
        }
        policy_state_info_vars_list = [
            policy_state_info_vars[k] for k in self.policy.state_info_keys
        ]

        # old policy distribution
        policy_old_dist_info_vars = {
            k: tf.compat.v1.placeholder(tf.float32,
                                        shape=[None] * 2 + list(shape),
                                        name='policy_old_%s' % k)
            for k, shape in policy_dist.dist_info_specs
        }
        policy_old_dist_info_vars_list = [
            policy_old_dist_info_vars[k]
            for k in policy_dist.dist_info_keys
        ]

        # flattened view
        with tf.name_scope('flat'):
            obs_flat = flatten_batch(obs_var, name='obs_flat')
            action_flat = flatten_batch(action_var, name='action_flat')
            reward_flat = flatten_batch(reward_var, name='reward_flat')
            valid_flat = flatten_batch(valid_var, name='valid_flat')
            policy_state_info_vars_flat = flatten_batch_dict(
                policy_state_info_vars, name='policy_state_info_vars_flat')
            policy_old_dist_info_vars_flat = flatten_batch_dict(
                policy_old_dist_info_vars,
                name='policy_old_dist_info_vars_flat')

        # valid view
        with tf.name_scope('valid'):
            action_valid = filter_valids(action_flat,
                                         valid_flat,
                                         name='action_valid')
            policy_state_info_vars_valid = filter_valids_dict(
                policy_state_info_vars_flat,
                valid_flat,
                name='policy_state_info_vars_valid')
            policy_old_dist_info_vars_valid = filter_valids_dict(
                policy_old_dist_info_vars_flat,
                valid_flat,
                name='policy_old_dist_info_vars_valid')

    # policy loss and optimizer inputs
    pol_flat = graph_inputs(
        'PolicyLossInputsFlat',
        obs_var=obs_flat,
        action_var=action_flat,
        reward_var=reward_flat,
        valid_var=valid_flat,
        policy_state_info_vars=policy_state_info_vars_flat,
        policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
    )
    pol_valid = graph_inputs(
        'PolicyLossInputsValid',
        action_var=action_valid,
        policy_state_info_vars=policy_state_info_vars_valid,
        policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
    )
    policy_loss_inputs = graph_inputs(
        'PolicyLossInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars=policy_state_info_vars,
        policy_old_dist_info_vars=policy_old_dist_info_vars,
        flat=pol_flat,
        valid=pol_valid,
    )
    policy_opt_inputs = graph_inputs(
        'PolicyOptInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars_list=policy_state_info_vars_list,
        policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
    )

    return policy_loss_inputs, policy_opt_inputs
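
# NOTE: A hedged sketch of how the two namedtuples returned above are
# typically consumed: the loss inputs drive the symbolic loss construction,
# while the optimizer inputs are flattened into the ordered placeholder list
# that the optimizer feeds when optimizing. `_build_policy_loss`,
# `flatten_inputs`, `self._max_kl_step`, and the optimizer API shown are
# assumptions based on the surrounding code, not a definitive interface.

def init_opt(self):
    """Initialize the optimization procedure (sketch)."""
    pol_loss_inputs, pol_opt_inputs = self._build_inputs()
    self._policy_opt_inputs = pol_opt_inputs

    # Build the surrogate loss (and KL constraint) from the symbolic inputs.
    pol_loss, pol_kl = self._build_policy_loss(pol_loss_inputs)

    # Register the loss and the flat placeholder list with the optimizer;
    # concrete values are supplied later, in the same order, at optimize time.
    self.optimizer.update_opt(loss=pol_loss,
                              target=self.policy,
                              leq_constraint=(pol_kl, self._max_kl_step),
                              inputs=flatten_inputs(self._policy_opt_inputs),
                              constraint_name='mean_kl')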