예제 #1
0
    def get_value(self, rollouts, args):
        """Return the critic's value estimate for the last step in `rollouts`.

        Input wiring mirrors `act`: `args.algo` selects which rollout fields
        feed the actor and critic halves of the base network.

        Args:
            rollouts: per-step storage with `obs`, `belief`, `state`,
                `actor_rnn_states`, `critic_rnn_states` and `masks`
                (shapes are project-defined — assumed indexable by step).
            args: namespace whose `algo` string picks the input wiring.

        Returns:
            The value output of `self.base` for the final step.

        Raises:
            ValueError: if `args.algo` is not a recognized algorithm name
                (previously this surfaced as an UnboundLocalError).
        """
        if self.use_embedding:
            inputs_actor = self._embed_if_needed(rollouts.obs[-1])
            inputs_critic = rollouts.belief[-1]
        # NOTE(review): 'bc' added here for parity with act(), which already
        # routed 'bc' through the belief/belief branch.
        elif args.algo in ['ab-cb', 'bc']:
            inputs_actor = rollouts.belief[-1]
            inputs_critic = rollouts.belief[-1]
        elif args.algo in ['ah-cb']:
            inputs_actor = rollouts.obs[-1]
            inputs_critic = rollouts.belief[-1]
        elif args.algo in ['ah-chs']:
            inputs_actor = rollouts.obs[-1]
            inputs_critic = rollouts.state[-1]
        elif args.algo in ['ah-csb']:
            inputs_actor = rollouts.obs[-1]
            inputs_critic = bind(rollouts.state[-1], rollouts.belief[-1])
        else:
            raise ValueError('unknown algo: {}'.format(args.algo))

        actor_rnn_hxs = rollouts.actor_rnn_states[-1]
        critic_rnn_hxs = rollouts.critic_rnn_states[-1]
        masks = rollouts.masks[-1]

        value, _, _, _, _ = self.base(inputs_actor, inputs_critic,
                                      actor_rnn_hxs, critic_rnn_hxs, masks)
        return value
예제 #2
0
 def _embed_if_needed(self, obs):
     if not self.use_embedding:
         return obs
     actions_embedded = self.action_embed(obs[:, 1])
     obs_embedded = self.obs_embed(obs[:, 0])
     return bind(actions_embedded,
                 obs_embedded).reshape(-1, 2 * self.embed_size)
예제 #3
0
    def _embed_if_needed(self, obs):
        if not self.use_embedding:
            return obs

        # The past `n_reactive` observations have already been concatenated together, but
        # we don't want the embedding layer to treat them as a single mega-observation.
        # So we separate them here, and then concatenate them again after the embedding.
        # Note the shape is unaffected when n=1, which is exactly the behavior we want.

        n = self.n_reactive
        obs = obs.view(obs.shape[0] * n, -1)
        actions_embedded = self.action_embed(obs[:, 1])
        obs_embedded = self.obs_embed(obs[:, 0])
        return bind(actions_embedded,
                    obs_embedded).reshape(-1, 2 * self.embed_size * n)
예제 #4
0
    def act(self, rollouts, step, args, deterministic=False):
        """Sample (or pick the mode of) an action for rollout step `step`.

        `args.algo` selects which rollout fields feed the actor and critic
        halves of the base network; the resulting actor features parameterize
        the action distribution.

        Args:
            rollouts: per-step storage with `obs`, `belief`, `state`,
                `actor_rnn_states`, `critic_rnn_states` and `masks`
                (shapes are project-defined — assumed indexable by step).
            step: index of the rollout step to act at.
            args: namespace whose `algo` string picks the input wiring.
            deterministic: if True, take the distribution mode instead of
                sampling.

        Returns:
            Tuple of (value, action, action_log_probs, actor_rnn_hxs,
            critic_rnn_hxs).

        Raises:
            ValueError: if `args.algo` is not a recognized algorithm name
                (previously this surfaced as an UnboundLocalError).
        """
        if self.use_embedding:
            inputs_actor = self._embed_if_needed(rollouts.obs[step])
            inputs_critic = rollouts.belief[step]
        elif args.algo in ['ab-cb', 'bc']:
            inputs_actor = rollouts.belief[step]
            inputs_critic = rollouts.belief[step]
        elif args.algo in ['ah-cb']:
            inputs_actor = rollouts.obs[step]
            inputs_critic = rollouts.belief[step]
        elif args.algo in ['ah-chs']:
            inputs_actor = rollouts.obs[step]
            inputs_critic = rollouts.state[step]
        elif args.algo in ['ah-csb']:
            inputs_actor = rollouts.obs[step]
            inputs_critic = bind(rollouts.state[step],
                                 rollouts.belief[step])
        else:
            raise ValueError('unknown algo: {}'.format(args.algo))

        actor_rnn_hxs = rollouts.actor_rnn_states[step]
        critic_rnn_hxs = rollouts.critic_rnn_states[step]
        masks = rollouts.masks[step]

        value, actor_features, _, actor_rnn_hxs, critic_rnn_hxs = self.base(
            inputs_actor, inputs_critic, actor_rnn_hxs, critic_rnn_hxs, masks)
        dist = self.dist(actor_features)

        if deterministic:
            action = dist.mode()
        else:
            action = dist.sample()

        action_log_probs = dist.log_probs(action)

        return value, action, action_log_probs, actor_rnn_hxs, critic_rnn_hxs