Example #1
    def get_action(self, state, add_state, hxs, masks, step_info):
        should_resets = [m == 0.0 for m in masks]
        # Reset the noise for the beginning of every episode.
        for should_reset, noise_gen in zip(should_resets, self.noise_gens):
            if should_reset:
                noise_gen.reset_states()

        n_procs = state.shape[0]

        action = self.forward(state, add_state, hxs, masks)
        if not step_info.is_eval:
            cur_step = step_info.cur_num_steps
            if (cur_step >= self.args.n_rnd_steps) and (np.random.rand() >= self.args.rnd_prob):
                # Get the added noise.
                noise = torch.FloatTensor([ng.sample(cur_step)
                    for ng in self.noise_gens]).to(self.args.device)
                action += noise

                # Multi-dimensional clamp the action to the action space range.
                action = torch.min(torch.max(action, self.ac_low_bound), self.ac_high_bound)
            else:
                # Take a uniformly random action instead (during the first
                # n_rnd_steps of training or with probability rnd_prob).
                action = torch.tensor([self.action_space.sample()
                    for _ in range(n_procs)]).to(self.args.device)

        return create_simple_action_data(action, hxs)
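
The snippet assumes each noise_gen exposes reset_states() and sample(step). A minimal sketch of a compatible Ornstein-Uhlenbeck generator, a common choice for temporally correlated exploration noise (the class and its defaults here are illustrative, not taken from the original code):

    import numpy as np

    class OUNoise:
        def __init__(self, action_dim, mu=0.0, theta=0.15, sigma=0.2):
            self.mu = mu * np.ones(action_dim)
            self.theta = theta
            self.sigma = sigma
            self.reset_states()

        def reset_states(self):
            # Restart the process at its mean at every episode boundary.
            self.state = self.mu.copy()

        def sample(self, cur_step):
            # cur_step matches the interface above; a schedule could use it
            # to decay sigma over training.
            dx = self.theta * (self.mu - self.state) \
                + self.sigma * np.random.randn(len(self.state))
            self.state = self.state + dx
            return self.state
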
Example #2
    def get_action(self, state, add_state, hxs, masks, step_info):
        dist = self.forward(state, add_state, hxs, masks)
        if step_info.is_eval:
            # Greedy action at evaluation time.
            action = torch.argmax(dist, dim=1)
        else:
            # Sample from the categorical distribution during training.
            action = torch.distributions.Categorical(dist).sample()
        return create_simple_action_data(action, hxs)
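
Assuming forward returns normalized per-action probabilities, the eval/train split above is just argmax versus a categorical draw. A standalone illustration (if forward returned raw logits instead, Categorical(logits=...) would be the right constructor):

    import torch

    probs = torch.tensor([[0.1, 0.7, 0.2]])
    greedy = torch.argmax(probs, dim=1)                        # always index 1
    sampled = torch.distributions.Categorical(probs).sample()  # index 1 with prob. 0.7
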
Example #3
    def get_action(self, state, add_state, hxs, masks, step_info):
        n_procs = state.shape[0]
        cur_step = step_info.cur_num_steps

        # Take uniformly random actions for the first n_rnd_steps of training.
        if not step_info.is_eval and cur_step < self.args.n_rnd_steps:
            action = torch.tensor([self.action_space.sample()
                for _ in range(n_procs)]).to(self.args.device)
            return create_simple_action_data(action, hxs)

        dist = self.forward(state, add_state, hxs, masks)
        if step_info.is_eval:
            # Deterministic evaluation: use the distribution mean.
            action = dist.mean
        else:
            action = dist.sample()

        # Element-wise clamp the action to the action space range.
        action = torch.min(torch.max(action, self.ac_low_bound), self.ac_high_bound)

        return create_simple_action_data(action, hxs)
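
The min/max pair is an element-wise clamp against per-dimension bounds. A small illustration with made-up bound values (on PyTorch 1.9+, torch.clamp also accepts tensor bounds and is equivalent):

    import torch

    low = torch.tensor([-1.0, -0.5])
    high = torch.tensor([1.0, 0.5])
    action = torch.tensor([[2.0, -3.0]])

    clamped = torch.min(torch.max(action, low), high)  # tensor([[ 1.0, -0.5]])
    clamped = torch.clamp(action, low, high)           # same result on PyTorch >= 1.9
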
Example #4
    def get_action(self, state, add_state, hxs, masks, step_info):
        n_procs = rutils.get_def_obs(state).shape[0]
        action = torch.tensor([
            self.action_space.sample() for _ in range(n_procs)
        ]).to(self.args.device)
        if isinstance(self.action_space, spaces.Discrete):
            # Discrete samples are scalars; add a trailing dim so the
            # batch has shape (n_procs, 1).
            action = action.unsqueeze(-1)

        return create_simple_action_data(action, hxs)
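
This is a pure random policy; spaces here is gym.spaces. The unsqueeze matters because Discrete samples are scalar, so the batched tensor needs a trailing action dimension:

    import torch
    from gym import spaces

    space = spaces.Discrete(4)
    action = torch.tensor([space.sample() for _ in range(3)])  # shape (3,)
    action = action.unsqueeze(-1)                              # shape (3, 1)
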
Example #5
    def get_action(self, state, add_state, rnn_hxs, mask, step_info):
        base_features, _ = self.base_net(state, rnn_hxs, mask)

        ret_action = self.action_head(base_features)
        if step_info.is_eval or not self.is_stoch:
            ret_action = rutils.get_ac_compact(self.action_space, ret_action)
        else:
            if rutils.is_discrete(self.action_space):
                dist = torch.distributions.Categorical(
                    ret_action.softmax(dim=-1))
            else:
                # Continuous case: Gaussian policy with a state-dependent std.
                std = self.std(base_features)
                dist = torch.distributions.Normal(ret_action, std)
            ret_action = dist.sample()

        return create_simple_action_data(ret_action, rnn_hxs)
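
In the continuous branch, the action head output is treated as the Gaussian mean and a separate head produces the std. A minimal standalone sketch of that pattern (the values are invented):

    import torch

    mean = torch.tensor([[0.0, 0.5]])
    std = torch.tensor([[0.1, 0.2]])
    dist = torch.distributions.Normal(mean, std)
    action = dist.sample()            # one stochastic draw per action dim
    log_prob = dist.log_prob(action)  # per-dim log density, used by PG losses
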
Example #6
    def get_action(self, state, hxs, masks, step_info):
        if self.is_first:
            masks = torch.zeros_like(masks)

        sel_actions = []
        for i, mask in enumerate(masks):
            if mask == 0.0:
                if len(self.all_actions[i]) != 0:
                    # We should have exhausted all of the actions.
                    assert self.ep_idx[i] == len(self.all_actions[i]) - 1
                # The env has reset, solve it.
                actions = self._solve_env(state[i])
                self.all_actions[i] = actions
                self.ep_idx[i] = 0
            else:
                self.ep_idx[i] += 1
            sel_actions.append(self.all_actions[i][self.ep_idx[i]])
        sel_actions = torch.tensor(sel_actions).unsqueeze(-1)
        self.is_first = False
        return create_simple_action_data(sel_actions)
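
This one is a scripted policy: whenever an env resets (mask == 0), _solve_env plans a full action sequence that is then replayed step by step. The snippet assumes per-process bookkeeping set up elsewhere; a minimal sketch of compatible state, with a stand-in solver (everything here is illustrative, not from the original code):

    import torch

    class PlannerPolicy:
        def __init__(self, n_procs):
            # One cached action sequence and a playback index per env.
            self.all_actions = [[] for _ in range(n_procs)]
            self.ep_idx = [0] * n_procs
            self.is_first = True

        def _solve_env(self, state):
            # Stand-in: a real solver would plan a whole episode's worth
            # of actions from the observed state.
            return [0, 0, 0]
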
Example #7
    def get_action(self, state, add_state, hxs, masks, step_info):
        if step_info.is_eval:
            eps_threshold = 0
        else:
            num_steps = step_info.cur_num_steps
            eps_threshold = self.args.eps_end + \
                (self.args.eps_start - self.args.eps_end) * \
                math.exp(-1.0 * num_steps / self.args.eps_decay)

        sample = random.random()
        if sample > eps_threshold:
            # Exploit: greedy action, i.e. the argmax over Q-values.
            q_vals = self.forward(state, add_state, hxs, masks)
            ret_action = q_vals.max(1)[1].unsqueeze(-1)
        else:
            # Take a random action.
            ret_action = torch.LongTensor([[random.randrange(self.action_space.n)]
                for i in range(state.shape[0])]).to(self.args.device)

        return create_simple_action_data(ret_action, hxs, {
            'alg_add_eps': eps_threshold
        })
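
The threshold follows the standard exponential epsilon decay used in DQN-style epsilon-greedy exploration; a quick check of the schedule with made-up hyperparameters:

    import math

    eps_start, eps_end, eps_decay = 1.0, 0.05, 10000
    for num_steps in (0, 10000, 50000):
        eps = eps_end + (eps_start - eps_end) * math.exp(-1.0 * num_steps / eps_decay)
        print(num_steps, round(eps, 3))  # 1.0 -> 0.399 -> 0.056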