Example #1
    def get_action(self, belief, state, det=False):
        actor_out = self.forward(belief, state)
        if self._dist == 'tanh_normal':
            # actor_out.size() == (N x (action_size * 2))
            # raw_init_std is the inverse softplus of the initial std,
            # i.e. softplus(raw_init_std) == self._init_std
            raw_init_std = np.log(np.exp(self._init_std) - 1)
            # Equivalent torch-based version, kept for reference:
            # tmp = torch.tensor(self._init_std,
            #                    device=actor_out.get_device())
            # raw_init_std = torch.log(torch.exp(tmp) - 1)
            action_mean, action_std_dev = torch.chunk(actor_out, 2, dim=1)
            action_mean = self._mean_scale * torch.tanh(
                action_mean / self._mean_scale)
            action_std = F.softplus(action_std_dev +
                                    raw_init_std) + self._min_std

            dist = Normal(action_mean, action_std)
            dist = TransformedDistribution(dist, TanhBijector())
            dist = torch.distributions.Independent(dist, 1)
            dist = SampleDist(dist)
        elif self._dist == 'onehot':
            # actor_out.size() == (N x action_size)
            # Squash logits into [0, 1]; workaround for
            # "RuntimeError: CUDA error: device-side assert triggered"
            actor_out = (torch.tanh(actor_out) + 1.0) * 0.5
            dist = Categorical(logits=actor_out)
            dist = OneHotDist(dist)
        else:
            raise NotImplementedError(self._dist)
        if det:
            return dist.mode()
        else:
            return dist.sample()
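
The snippet above relies on a handful of imports that are not shown. A minimal set, assuming the standard PyTorch distribution classes plus the project-local helpers TanhBijector, SampleDist and OneHotDist (their module path depends on the repository and is only a guess here):

import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Normal, Categorical, TransformedDistribution

# Project-local helpers; the import path below is an assumption
# from models import TanhBijector, SampleDist, OneHotDist

A hypothetical call site would be `action = policy.get_action(belief, state, det=not training)`, where `det=True` selects the distribution mode instead of drawing a sample.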
Example #2
    def get_action(self, belief, posterior_state, explore=False, det=False):
        state = posterior_state
        # B: batch size, H: deterministic belief size, Z: stochastic state size
        B, H, Z = belief.size(0), belief.size(1), state.size(1)

        # Per-actor proposal means/stds used to warm-start the first
        # optimisation iteration (one entry per actor in self.actor_pool)
        actions_l_mean_lists, actions_l_std_lists = self.get_action_sequence(
            belief, state, B)

        # Expand belief/state once per candidate: (B, H) -> (B * candidates, H)
        belief = belief.unsqueeze(dim=1).expand(B, self.candidates,
                                                H).reshape(-1, H)
        state = state.unsqueeze(dim=1).expand(B, self.candidates,
                                              Z).reshape(-1, Z)

        # Initialize factorized belief over action sequences q(a_t:t+H) ~ N(0, I)
        # action_mean, action_std_dev = torch.zeros(self.planning_horizon, B, 1, self.action_size, device=belief.device), torch.ones(self.planning_horizon, B, 1, self.action_size, device=belief.device)
        action_mean, action_std_dev = None, None
        for iteration in range(self.optimisation_iters):
            # Evaluate J action sequences from the current belief (over entire sequence at once, batched over particles)
            if iteration == 0:
                sub_action_list = []
                for pool_idx in range(len(self.actor_pool)):
                    action = (
                        actions_l_mean_lists[pool_idx] +
                        actions_l_std_lists[pool_idx] *
                        torch.randn(self.top_planning_horizon,
                                    B,
                                    self.candidates // len(self.actor_pool),
                                    self.action_size,
                                    device=belief.device)
                    ).view(
                        self.top_planning_horizon,
                        B * self.candidates // len(self.actor_pool),
                        self.action_size
                    )  # Sample actions (time x (batch x candidates) x actions)
                    sub_action_list.append(action)
                actions = torch.cat(sub_action_list, dim=1)
            else:
                actions = (action_mean + action_std_dev * torch.randn(
                    self.top_planning_horizon,
                    B,
                    self.candidates,
                    self.action_size,
                    device=belief.device)).view(
                        self.top_planning_horizon, B * self.candidates,
                        self.action_size
                    )  # Sample actions (time x (batch x candidates) x actions)
            # Roll the sampled action sequences through the transition model
            beliefs, states, _, _ = self.upper_transition_model(
                state, actions, belief)

            # if args.MultiGPU:
            #   actions_trans = torch.transpose(actions, 0, 1).cuda()
            #   beliefs, states, _, _ = self.transition_model(state, actions_trans, belief)
            #   beliefs, states = list(map(lambda x: x.view(-1, self.candidates, x.shape[2]), [beliefs, states]))
            #
            # else:
            #   beliefs, states, _, _ = self.transition_model(state, actions, belief)
            # beliefs, states, _, _ = self.transition_model(state, actions, belief)# [12, 1000, 200] [12, 1000, 30] : 12 horizon steps; 1000 candidates

            # Calculate expected returns (technically sum of rewards over planning horizon)
            returns = self.reward_model(beliefs.view(-1, H), states.view(
                -1, Z
            )).view(self.top_planning_horizon, -1).sum(
                dim=0)  # output from r-model[12000]->view[12, 1000]->sum[1000]
            # Re-fit belief to the K best action sequences
            _, topk = returns.reshape(B, self.candidates).topk(
                self.top_candidates, dim=1, largest=True, sorted=False)
            topk += self.candidates * torch.arange(
                0, B, dtype=torch.int64, device=topk.device).unsqueeze(
                    dim=1)  # Fix indices for unrolled actions
            best_actions = actions[:, topk.view(-1)].reshape(
                self.top_planning_horizon, B, self.top_candidates,
                self.action_size)
            # Update belief with new means and standard deviations
            action_mean, action_std_dev = best_actions.mean(
                dim=2, keepdim=True), best_actions.std(dim=2,
                                                       unbiased=False,
                                                       keepdim=True)

        # Return sample action from distribution

        dist = Normal(action_mean[0].squeeze(dim=1),
                      action_std_dev[0].squeeze(dim=1))
        dist = TransformedDistribution(dist, TanhBijector())
        dist = torch.distributions.Independent(dist, 1)
        dist = SampleDist(dist)
        if det:
            return dist.mode()
        else:
            return dist.rsample()
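
Example #2 is a cross-entropy-method (CEM) planner: it samples candidate action sequences (warm-started from an actor pool on the first iteration), rolls them through the transition and reward models, keeps the top-K sequences, and refits the Gaussian over action sequences to that elite set. Below is a stripped-down sketch of just the refit loop, with a generic score_fn standing in for the transition-and-reward rollout (the function name and single-batch shape handling are assumptions, not taken from the repository):

import torch

def cem_plan_sketch(score_fn, horizon, action_size,
                    candidates=1000, top_k=100, iters=10):
    # Gaussian belief over action sequences, q(a_t:t+H) ~ N(mean, std^2)
    mean = torch.zeros(horizon, 1, action_size)
    std = torch.ones(horizon, 1, action_size)
    for _ in range(iters):
        # Sample candidate action sequences: (horizon, candidates, action_size)
        actions = mean + std * torch.randn(horizon, candidates, action_size)
        # score_fn returns one summed return per candidate: (candidates,)
        returns = score_fn(actions)
        elite_idx = returns.topk(top_k, largest=True).indices
        elite = actions[:, elite_idx]  # (horizon, top_k, action_size)
        # Re-fit the Gaussian to the elite set
        mean = elite.mean(dim=1, keepdim=True)
        std = elite.std(dim=1, unbiased=False, keepdim=True)
    return mean[0].squeeze(0)  # first action of the refitted sequence

The original method additionally keeps a batch dimension B, tanh-squashes the returned action via TransformedDistribution/TanhBijector, and splits the first iteration's candidates across the actors in self.actor_pool.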