Code example #1
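This first variant appears to come from the REPS implementation in the garage RL library (`garage/tf/algos/reps.py`): besides the usual observation/action/reward placeholders it creates `feat_diff`, `param_v`, and `param_eta` tensors for the REPS dual problem and returns a third `DualOptInputs` namedtuple. `new_tensor` and `graph_inputs` are garage helpers (see the sketch after the code) that wrap `tf.compat.v1.placeholder` and `collections.namedtuple`.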
    def _build_inputs(self):
        """Build input variables.

        Returns:
            namedtuple: Collection of variables to compute policy loss.
            namedtuple: Collection of variables to do policy optimization.
            namedtuple: Collection of variables to do dual optimization.

        """
        observation_space = self.policy.observation_space
        action_space = self.policy.action_space

        with tf.name_scope('inputs'):
            obs_var = observation_space.to_tf_placeholder(
                name='obs',
                batch_dims=2)
            action_var = action_space.to_tf_placeholder(
                name='action',
                batch_dims=2)
            reward_var = new_tensor(
                name='reward',
                ndim=2,
                dtype=tf.float32)
            valid_var = new_tensor(
                name='valid',
                ndim=2,
                dtype=tf.float32)
            # Feature differences and dual variables (v, eta) for the
            # REPS dual optimization problem
            feat_diff = new_tensor(
                name='feat_diff',
                ndim=2,
                dtype=tf.float32)
            param_v = new_tensor(
                name='param_v',
                ndim=1,
                dtype=tf.float32)
            param_eta = new_tensor(
                name='param_eta',
                ndim=0,
                dtype=tf.float32)
            # Policy state (for RNNs)
            policy_state_info_vars = {
                k: tf.compat.v1.placeholder(
                    tf.float32,
                    shape=[None] * 2 + list(shape),
                    name=k)
                for k, shape in self.policy.state_info_specs
            }
            policy_state_info_vars_list = [
                policy_state_info_vars[k]
                for k in self.policy.state_info_keys
            ]

        self._policy_network = self.policy.build(obs_var, name='policy')
        self._old_policy_network = self._old_policy.build(obs_var,
                                                          name='policy')

        policy_loss_inputs = graph_inputs(
            'PolicyLossInputs',
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            valid_var=valid_var,
            feat_diff=feat_diff,
            param_eta=param_eta,
            param_v=param_v,
            policy_state_info_vars=policy_state_info_vars,
        )
        policy_opt_inputs = graph_inputs(
            'PolicyOptInputs',
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            valid_var=valid_var,
            feat_diff=feat_diff,
            param_eta=param_eta,
            param_v=param_v,
            policy_state_info_vars_list=policy_state_info_vars_list,
        )
        dual_opt_inputs = graph_inputs(
            'DualOptInputs',
            reward_var=reward_var,
            valid_var=valid_var,
            feat_diff=feat_diff,
            param_eta=param_eta,
            param_v=param_v,
            policy_state_info_vars_list=policy_state_info_vars_list,
        )

        return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs
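For readers who want to run these snippets standalone, here is a minimal sketch of the two helpers they assume, modeled on `garage.tf.misc.tensor_utils`; treat it as an approximation under those assumptions, not the library's exact source.

import collections

import tensorflow as tf


def new_tensor(name, ndim, dtype):
    """Create a placeholder with `ndim` dimensions of unknown size."""
    return tf.compat.v1.placeholder(dtype=dtype,
                                    shape=[None] * ndim,
                                    name=name)


def graph_inputs(name, **kwargs):
    """Bundle keyword arguments into an ad-hoc namedtuple instance."""
    Singleton = collections.namedtuple(name, kwargs.keys())
    return Singleton(**kwargs)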
Code example #2
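This variant appears to come from garage's Task Embedding algorithm (`garage/tf/algos/te_npo.py`). On top of the standard inputs it builds task, latent, and trajectory placeholders, plus separate recurrent-state inputs for the policy, the encoder, and the inference network, and returns four namedtuples instead of two.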
    def _build_inputs(self):
        """Build input variables.

        Returns:
            namedtuple: Collection of variables to compute policy loss.
            namedtuple: Collection of variables to do policy optimization.
            namedtuple: Collection of variables to compute inference loss.
            namedtuple: Collection of variables to do inference optimization.

        """
        # pylint: disable=too-many-statements
        observation_space = self.policy.observation_space
        action_space = self.policy.action_space
        task_space = self.policy.task_space
        latent_space = self.policy.latent_space
        trajectory_space = self._inference.spec.input_space

        with tf.name_scope('inputs'):
            obs_var = observation_space.to_tf_placeholder(name='obs',
                                                          batch_dims=2)
            task_var = tf.compat.v1.placeholder(
                tf.float32,
                shape=[None, None, task_space.flat_dim],
                name='task')
            trajectory_var = tf.compat.v1.placeholder(
                tf.float32, shape=[None, None, trajectory_space.flat_dim])
            latent_var = tf.compat.v1.placeholder(
                tf.float32, shape=[None, None, latent_space.flat_dim])

            action_var = action_space.to_tf_placeholder(name='action',
                                                        batch_dims=2)
            reward_var = tf.compat.v1.placeholder(tf.float32,
                                                  shape=[None, None],
                                                  name='reward')
            baseline_var = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[None, None],
                                                    name='baseline')

            valid_var = tf.compat.v1.placeholder(tf.float32,
                                                 shape=[None, None],
                                                 name='valid')

            # Policy state (for RNNs)
            policy_state_info_vars = {
                k: tf.compat.v1.placeholder(tf.float32,
                                            shape=[None] * 2 + list(shape),
                                            name=k)
                for k, shape in self.policy.state_info_specs
            }
            policy_state_info_vars_list = [
                policy_state_info_vars[k] for k in self.policy.state_info_keys
            ]

            # Encoder state (for RNNs)
            embed_state_info_vars = {
                k: tf.compat.v1.placeholder(tf.float32,
                                            shape=[None] * 2 + list(shape),
                                            name='embed_%s' % k)
                for k, shape in self.policy.encoder.state_info_specs
            }
            embed_state_info_vars_list = [
                embed_state_info_vars[k]
                for k in self.policy.encoder.state_info_keys
            ]

            # Inference distribution state (for RNNs)
            infer_state_info_vars = {
                k: tf.compat.v1.placeholder(tf.float32,
                                            shape=[None] * 2 + list(shape),
                                            name='infer_%s' % k)
                for k, shape in self._inference.state_info_specs
            }
            infer_state_info_vars_list = [
                infer_state_info_vars[k]
                for k in self._inference.state_info_keys
            ]

        extra_obs_var = [
            tf.cast(v, tf.float32) for v in policy_state_info_vars_list
        ]
        # Pylint false alarm
        # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
        augmented_obs_var = tf.concat([obs_var] + extra_obs_var, axis=-1)
        extra_traj_var = [
            tf.cast(v, tf.float32) for v in infer_state_info_vars_list
        ]
        augmented_traj_var = tf.concat([trajectory_var] + extra_traj_var, -1)

        # Policy and encoder network loss and optimizer inputs
        policy_loss_inputs = graph_inputs(
            'PolicyLossInputs',
            augmented_obs_var=augmented_obs_var,
            augmented_traj_var=augmented_traj_var,
            task_var=task_var,
            latent_var=latent_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            valid_var=valid_var)
        policy_opt_inputs = graph_inputs(
            'PolicyOptInputs',
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            trajectory_var=trajectory_var,
            task_var=task_var,
            latent_var=latent_var,
            valid_var=valid_var,
            policy_state_info_vars_list=policy_state_info_vars_list,
            embed_state_info_vars_list=embed_state_info_vars_list,
        )

        # Inference network loss and optimizer inputs
        inference_loss_inputs = graph_inputs('InferenceLossInputs',
                                             latent_var=latent_var,
                                             valid_var=valid_var)
        inference_opt_inputs = graph_inputs(
            'InferenceOptInputs',
            latent_var=latent_var,
            trajectory_var=trajectory_var,
            valid_var=valid_var,
            infer_state_info_vars_list=infer_state_info_vars_list,
        )

        return (policy_loss_inputs, policy_opt_inputs, inference_loss_inputs,
                inference_opt_inputs)
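The augmentation step above (casting state-info tensors to float32 and concatenating them onto the observation along the last axis) is the same pattern used in example #3. A toy, self-contained shape check, with illustrative feature sizes of my own choosing:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# [batch, time, feature] placeholders with made-up feature sizes
obs_var = tf.compat.v1.placeholder(tf.float32, [None, None, 4], name='obs')
state_var = tf.compat.v1.placeholder(tf.float32, [None, None, 8], name='state')

augmented = tf.concat([obs_var, tf.cast(state_var, tf.float32)], axis=-1)

# Batch and time stay unknown; feature dims add up: 4 + 8 = 12
assert augmented.shape.as_list() == [None, None, 12]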
Code example #3
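The simplest of the three, this variant appears to match garage's NPO (`garage/tf/algos/npo.py`). Note the difference from example #1: here both policy networks are built on `augmented_obs_var`, so a recurrent policy receives its state-info tensors concatenated onto the observation.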
    def _build_inputs(self):
        """Build input variables.

        Returns:
            namedtuple: Collection of variables to compute policy loss.
            namedtuple: Collection of variables to do policy optimization.

        """
        observation_space = self.policy.observation_space
        action_space = self.policy.action_space

        with tf.name_scope('inputs'):
            obs_var = observation_space.to_tf_placeholder(name='obs',
                                                          batch_dims=2)
            action_var = action_space.to_tf_placeholder(name='action',
                                                        batch_dims=2)
            reward_var = tf.compat.v1.placeholder(tf.float32,
                                                  shape=[None, None],
                                                  name='reward')
            valid_var = tf.compat.v1.placeholder(tf.float32,
                                                 shape=[None, None],
                                                 name='valid')
            baseline_var = tf.compat.v1.placeholder(tf.float32,
                                                    shape=[None, None],
                                                    name='baseline')

            # Policy state (for RNNs)
            policy_state_info_vars = {
                k: tf.compat.v1.placeholder(tf.float32,
                                            shape=[None] * 2 + list(shape),
                                            name=k)
                for k, shape in self.policy.state_info_specs
            }
            policy_state_info_vars_list = [
                policy_state_info_vars[k] for k in self.policy.state_info_keys
            ]

        # Concatenate policy state info onto the observation so a
        # recurrent policy receives its state as part of its input
        augmented_obs_var = obs_var
        for k in self.policy.state_info_keys:
            extra_state_var = policy_state_info_vars[k]
            extra_state_var = tf.cast(extra_state_var, tf.float32)
            augmented_obs_var = tf.concat([augmented_obs_var, extra_state_var],
                                          -1)

        self._policy_network = self.policy.build(augmented_obs_var,
                                                 name='policy')
        self._old_policy_network = self._old_policy.build(augmented_obs_var,
                                                          name='policy')

        policy_loss_inputs = graph_inputs(
            'PolicyLossInputs',
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            valid_var=valid_var,
            policy_state_info_vars=policy_state_info_vars,
        )
        policy_opt_inputs = graph_inputs(
            'PolicyOptInputs',
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            valid_var=valid_var,
            policy_state_info_vars_list=policy_state_info_vars_list,
        )

        return policy_loss_inputs, policy_opt_inputs
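In all three examples the state-info placeholders are exposed twice: as a dict, for keyed lookup while building the symbolic loss, and as a list, for a deterministic ordering when feeding a compiled function. Below is a hedged sketch of how the opt-inputs namedtuple is typically consumed downstream; `flatten_inputs`, the `samples_data` keys, and the `self._policy_opt_inputs` attribute are assumptions patterned on garage, not verified API.

def _policy_opt_input_values(self, samples_data):
    """Map one batch of rollout data onto PolicyOptInputs' field order."""
    # Assumed samples_data layout; keys are illustrative
    policy_state_info_list = [
        samples_data['agent_infos'][k]
        for k in self.policy.state_info_keys
    ]
    values = self._policy_opt_inputs._replace(
        obs_var=samples_data['observations'],
        action_var=samples_data['actions'],
        reward_var=samples_data['rewards'],
        baseline_var=samples_data['baselines'],
        valid_var=samples_data['valids'],
        policy_state_info_vars_list=policy_state_info_list,
    )
    # flatten_inputs (assumed helper) turns the nested namedtuple into a
    # flat list matching the placeholder order expected by the optimizer
    return flatten_inputs(values)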