def _build_inputs(self):
    """Build input variables.

    Returns:
        namedtuple: Collection of variables to compute policy loss.
        namedtuple: Collection of variables to do policy optimization.
        namedtuple: Collection of variables to do dual optimization.

    """
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space

    with tf.name_scope('inputs'):
        obs_var = observation_space.to_tf_placeholder(name='obs',
                                                      batch_dims=2)
        action_var = action_space.to_tf_placeholder(name='action',
                                                    batch_dims=2)
        reward_var = new_tensor(name='reward', ndim=2, dtype=tf.float32)
        valid_var = new_tensor(name='valid', ndim=2, dtype=tf.float32)
        feat_diff = new_tensor(name='feat_diff', ndim=2, dtype=tf.float32)
        param_v = new_tensor(name='param_v', ndim=1, dtype=tf.float32)
        param_eta = new_tensor(name='param_eta', ndim=0, dtype=tf.float32)
        policy_state_info_vars = {
            k: tf.compat.v1.placeholder(tf.float32,
                                        shape=[None] * 2 + list(shape),
                                        name=k)
            for k, shape in self.policy.state_info_specs
        }
        policy_state_info_vars_list = [
            policy_state_info_vars[k] for k in self.policy.state_info_keys
        ]

    self._policy_network = self.policy.build(obs_var, name='policy')
    self._old_policy_network = self._old_policy.build(obs_var,
                                                      name='policy')

    policy_loss_inputs = graph_inputs(
        'PolicyLossInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        valid_var=valid_var,
        feat_diff=feat_diff,
        param_eta=param_eta,
        param_v=param_v,
        policy_state_info_vars=policy_state_info_vars,
    )
    policy_opt_inputs = graph_inputs(
        'PolicyOptInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        valid_var=valid_var,
        feat_diff=feat_diff,
        param_eta=param_eta,
        param_v=param_v,
        policy_state_info_vars_list=policy_state_info_vars_list,
    )
    dual_opt_inputs = graph_inputs(
        'DualOptInputs',
        reward_var=reward_var,
        valid_var=valid_var,
        feat_diff=feat_diff,
        param_eta=param_eta,
        param_v=param_v,
        policy_state_info_vars_list=policy_state_info_vars_list,
    )

    return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs
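
# The method above leans on two small helpers from garage's TF tensor
# utilities. As a point of reference, here is a minimal sketch of the
# behavior it assumes; the actual garage implementations may differ in
# detail (this is an illustration, not the library source):

import collections

import tensorflow as tf


def new_tensor(name, ndim, dtype):
    """Create a placeholder with `ndim` dimensions, each of unknown size."""
    return tf.compat.v1.placeholder(dtype=dtype,
                                    shape=[None] * ndim,
                                    name=name)


def graph_inputs(name, **kwargs):
    """Bundle keyword arguments into an ad-hoc namedtuple instance."""
    Singleton = collections.namedtuple(name, kwargs.keys())
    return Singleton(**kwargs)
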
def _build_inputs(self):
    """Build input variables.

    Returns:
        namedtuple: Collection of variables to compute policy loss.
        namedtuple: Collection of variables to do policy optimization.
        namedtuple: Collection of variables to compute inference loss.
        namedtuple: Collection of variables to do inference optimization.

    """
    # pylint: disable=too-many-statements
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space
    task_space = self.policy.task_space
    latent_space = self.policy.latent_space
    trajectory_space = self._inference.spec.input_space

    with tf.name_scope('inputs'):
        obs_var = observation_space.to_tf_placeholder(name='obs',
                                                      batch_dims=2)
        task_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, task_space.flat_dim],
            name='task')
        trajectory_var = tf.compat.v1.placeholder(
            tf.float32, shape=[None, None, trajectory_space.flat_dim])
        latent_var = tf.compat.v1.placeholder(
            tf.float32, shape=[None, None, latent_space.flat_dim])
        action_var = action_space.to_tf_placeholder(name='action',
                                                    batch_dims=2)
        reward_var = tf.compat.v1.placeholder(tf.float32,
                                              shape=[None, None],
                                              name='reward')
        baseline_var = tf.compat.v1.placeholder(tf.float32,
                                                shape=[None, None],
                                                name='baseline')
        valid_var = tf.compat.v1.placeholder(tf.float32,
                                             shape=[None, None],
                                             name='valid')

        # Policy state (for RNNs)
        policy_state_info_vars = {
            k: tf.compat.v1.placeholder(tf.float32,
                                        shape=[None] * 2 + list(shape),
                                        name=k)
            for k, shape in self.policy.state_info_specs
        }
        policy_state_info_vars_list = [
            policy_state_info_vars[k] for k in self.policy.state_info_keys
        ]

        # Encoder state (for RNNs)
        embed_state_info_vars = {
            k: tf.compat.v1.placeholder(tf.float32,
                                        shape=[None] * 2 + list(shape),
                                        name='embed_%s' % k)
            for k, shape in self.policy.encoder.state_info_specs
        }
        embed_state_info_vars_list = [
            embed_state_info_vars[k]
            for k in self.policy.encoder.state_info_keys
        ]

        # Inference distribution state (for RNNs)
        infer_state_info_vars = {
            k: tf.compat.v1.placeholder(tf.float32,
                                        shape=[None] * 2 + list(shape),
                                        name='infer_%s' % k)
            for k, shape in self._inference.state_info_specs
        }
        infer_state_info_vars_list = [
            infer_state_info_vars[k]
            for k in self._inference.state_info_keys
        ]

    extra_obs_var = [
        tf.cast(v, tf.float32) for v in policy_state_info_vars_list
    ]
    # Pylint false alarm
    # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
    augmented_obs_var = tf.concat([obs_var] + extra_obs_var, axis=-1)
    extra_traj_var = [
        tf.cast(v, tf.float32) for v in infer_state_info_vars_list
    ]
    augmented_traj_var = tf.concat([trajectory_var] + extra_traj_var, -1)

    # Policy and encoder network loss and optimizer inputs
    policy_loss_inputs = graph_inputs(
        'PolicyLossInputs',
        augmented_obs_var=augmented_obs_var,
        augmented_traj_var=augmented_traj_var,
        task_var=task_var,
        latent_var=latent_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var)
    policy_opt_inputs = graph_inputs(
        'PolicyOptInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        trajectory_var=trajectory_var,
        task_var=task_var,
        latent_var=latent_var,
        valid_var=valid_var,
        policy_state_info_vars_list=policy_state_info_vars_list,
        embed_state_info_vars_list=embed_state_info_vars_list,
    )

    # Inference network loss and optimizer inputs
    inference_loss_inputs = graph_inputs('InferenceLossInputs',
                                         latent_var=latent_var,
                                         valid_var=valid_var)
    inference_opt_inputs = graph_inputs(
        'InferenceOptInputs',
        latent_var=latent_var,
        trajectory_var=trajectory_var,
        valid_var=valid_var,
        infer_state_info_vars_list=infer_state_info_vars_list,
    )

    return (policy_loss_inputs, policy_opt_inputs, inference_loss_inputs,
            inference_opt_inputs)
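
# The augmentation above just widens the last (feature) axis so that
# recurrent state-info values ride along with the observations and
# trajectories. A toy, runnable illustration with hypothetical sizes
# (obs_dim=4, one extra state variable of width 2); none of these names
# come from garage:

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

obs = tf.compat.v1.placeholder(tf.float32, shape=[None, None, 4])
extra = tf.compat.v1.placeholder(tf.float32, shape=[None, None, 2])
augmented = tf.concat([obs, extra], axis=-1)  # static shape [None, None, 6]

with tf.compat.v1.Session() as sess:
    out = sess.run(augmented,
                   feed_dict={
                       obs: np.zeros((3, 5, 4)),  # 3 episodes, 5 time steps
                       extra: np.ones((3, 5, 2)),
                   })
assert out.shape == (3, 5, 6)
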
def _build_inputs(self):
    """Build input variables.

    Returns:
        namedtuple: Collection of variables to compute policy loss.
        namedtuple: Collection of variables to do policy optimization.

    """
    observation_space = self.policy.observation_space
    action_space = self.policy.action_space

    with tf.name_scope('inputs'):
        obs_var = observation_space.to_tf_placeholder(name='obs',
                                                      batch_dims=2)
        action_var = action_space.to_tf_placeholder(name='action',
                                                    batch_dims=2)
        reward_var = tf.compat.v1.placeholder(tf.float32,
                                              shape=[None, None],
                                              name='reward')
        valid_var = tf.compat.v1.placeholder(tf.float32,
                                             shape=[None, None],
                                             name='valid')
        baseline_var = tf.compat.v1.placeholder(tf.float32,
                                                shape=[None, None],
                                                name='baseline')

        policy_state_info_vars = {
            k: tf.compat.v1.placeholder(tf.float32,
                                        shape=[None] * 2 + list(shape),
                                        name=k)
            for k, shape in self.policy.state_info_specs
        }
        policy_state_info_vars_list = [
            policy_state_info_vars[k] for k in self.policy.state_info_keys
        ]

    augmented_obs_var = obs_var
    for k in self.policy.state_info_keys:
        extra_state_var = policy_state_info_vars[k]
        extra_state_var = tf.cast(extra_state_var, tf.float32)
        augmented_obs_var = tf.concat([augmented_obs_var, extra_state_var],
                                      -1)

    self._policy_network = self.policy.build(augmented_obs_var,
                                             name='policy')
    self._old_policy_network = self._old_policy.build(augmented_obs_var,
                                                      name='policy')

    policy_loss_inputs = graph_inputs(
        'PolicyLossInputs',
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars=policy_state_info_vars,
    )
    policy_opt_inputs = graph_inputs(
        'PolicyOptInputs',
        obs_var=obs_var,
        action_var=action_var,
        reward_var=reward_var,
        baseline_var=baseline_var,
        valid_var=valid_var,
        policy_state_info_vars_list=policy_state_info_vars_list,
    )

    return policy_loss_inputs, policy_opt_inputs
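
# Downstream, the optimizer consumes these namedtuples as flat lists of
# placeholders (garage pairs them with flatten_inputs/compile_function in
# its tensor utilities). A minimal sketch of that flattening, assuming
# nested fields are plain lists; the real garage helper may differ:


def flatten_inputs(deep):
    """Recursively flatten nested lists/tuples of placeholders."""

    def _flatten(items):
        for item in items:
            if isinstance(item, (list, tuple)):
                yield from _flatten(item)
            else:
                yield item

    return list(_flatten(deep))


# Since a namedtuple is a tuple, flatten_inputs(policy_opt_inputs) yields
# the placeholders in field order, with policy_state_info_vars_list
# expanded in place:
#   [obs_var, action_var, reward_var, baseline_var, valid_var, *state_vars]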