Example 1
    def get_random_actions(self, obs, available_actions=None):
        batch_size = obs.shape[0]
        if available_actions is not None:
            logits = torch.ones(batch_size, self.act_dim)
            random_actions = avail_choose(logits, available_actions)
            random_actions = random_actions.sample()
            random_actions = make_onehot(random_actions, batch_size,
                                         self.act_dim).cpu().numpy()
        else:
            if self.discrete_action:
                if self.multidiscrete:
                    random_actions = [
                        OneHotCategorical(logits=torch.ones(
                            batch_size, self.act_dim[i])).sample().numpy()
                        for i in range(len(self.act_dim))
                    ]
                    random_actions = np.concatenate(random_actions, axis=-1)
                else:
                    random_actions = OneHotCategorical(logits=torch.ones(
                        batch_size, self.act_dim)).sample().numpy()
            else:
                random_actions = np.random.uniform(self.act_space.low,
                                                   self.act_space.high,
                                                   size=(batch_size,
                                                         self.act_dim))

        return random_actions
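
The method above relies on two helpers, avail_choose and make_onehot, that are not shown in this listing. A minimal sketch of what they might look like, assuming avail_choose masks unavailable actions with a very negative logit and returns a Categorical, and make_onehot turns integer action indices into one-hot vectors (both are illustrative reconstructions, not the original helpers):

import torch
from torch.distributions import Categorical

def avail_choose(logits, available_actions=None):
    # Hypothetical helper: give unavailable actions a very negative logit
    # so they are never sampled, then return a Categorical distribution.
    if available_actions is not None:
        avail = torch.as_tensor(available_actions, dtype=torch.float32)
        logits = logits.clone()
        logits[avail == 0.0] = -1e10
    return Categorical(logits=logits)

def make_onehot(actions, batch_size, act_dim, seq_len=None):
    # Hypothetical helper: turn integer action indices of shape (batch_size,)
    # (or (seq_len, batch_size)) into one-hot vectors along a new last dim.
    actions = torch.as_tensor(actions).long()
    onehot = torch.zeros(*actions.shape, act_dim)
    onehot.scatter_(-1, actions.unsqueeze(-1), 1.0)
    return onehot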
Example 2
    def get_random_actions_discrete(self, obs, available_actions=None):
        assert len(obs.shape) == 2, "No random actions on sequence"
        batch_size = obs.shape[0]
        if available_actions is not None:
            logits = torch.ones(batch_size, self.act_dim)
            random_actions = avail_choose(logits, available_actions)
            random_actions = random_actions.sample()
            random_actions = make_onehot(random_actions, batch_size, self.act_dim).cpu().numpy()
        else:
            if self.multidiscrete:
                random_actions = [
                    OneHotCategorical(logits=torch.ones(
                        batch_size, self.act_dim[i])).sample().numpy()
                    for i in range(len(self.act_dim))
                ]
                random_actions = np.concatenate(random_actions, axis=-1)
            else:
                random_actions = OneHotCategorical(logits=torch.ones(batch_size, self.act_dim)).sample().numpy()

        return random_actions
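
Both random-action methods share one idiom: sampling uniformly random one-hot actions from OneHotCategorical with constant logits. A standalone sketch of that idiom with illustrative shapes:

import torch
from torch.distributions import OneHotCategorical

batch_size, act_dim = 4, 5
# Constant logits give a uniform distribution over act_dim actions;
# sampling yields one-hot rows of shape (batch_size, act_dim).
uniform_onehot = OneHotCategorical(
    logits=torch.ones(batch_size, act_dim)).sample()
print(uniform_onehot.shape)  # torch.Size([4, 5])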
Example 3
    def train_on_batch(self, batch, use_cent_agent_obs):
        # unpack the batch
        (obs_batch, cent_obs_batch, act_batch, rew_batch, nobs_batch,
         cent_nobs_batch, dones_batch, avail_act_batch,
         navail_act_batch) = batch

        if use_cent_agent_obs:
            agent_0_pol = self.policy_mapping_fn(self.agent_ids[0])
            cent_obs_batch = torch.FloatTensor(cent_obs_batch[agent_0_pol][0]).to(self.device)
            cent_nobs_batch = torch.FloatTensor(cent_nobs_batch[agent_0_pol][0]).to(self.device)
        else:
            cent_obs_batch = torch.FloatTensor(cent_obs_batch[self.policy_ids[0]]).to(self.device)
            cent_nobs_batch = torch.FloatTensor(cent_nobs_batch[self.policy_ids[0]]).to(self.device)

        rew_batch = torch.FloatTensor(rew_batch[self.policy_ids[0]]).to(self.device)
        dones_batch = torch.FloatTensor(dones_batch['env']).to(self.device)

        # individual agent q value sequences: each element is of shape (ep_len, batch_size, 1)
        agent_q_sequences = []
        # individual agent next step q value sequences
        agent_next_q_sequences = []
        batch_size = None

        for p_id in self.policy_ids:
            # get data related to the policy id
            curr_obs_batch = torch.FloatTensor(obs_batch[p_id]).to(self.device)
            curr_act_batch = torch.FloatTensor(act_batch[p_id]).to(self.device)
            curr_nobs_batch = torch.FloatTensor(nobs_batch[p_id]).to(self.device)
        
            # stack over agents to process them all at once
            stacked_act_batch = torch.cat(list(curr_act_batch), dim=-2)
            stacked_obs_batch = torch.cat(list(curr_obs_batch), dim=-2)
            stacked_nobs_batch = torch.cat(list(curr_nobs_batch), dim=-2)

            if navail_act_batch[p_id] is not None:
                curr_navail_act_batch = torch.FloatTensor(navail_act_batch[p_id]).to(self.device)
                stacked_navail_act_batch = torch.cat(list(curr_navail_act_batch), dim=-2)
            else:
                stacked_navail_act_batch = None
            

            policy = self.policies[p_id]
            batch_size = curr_obs_batch.shape[2]
            total_batch_size = batch_size * len(self.policy_agents[p_id])
            seq_len = curr_obs_batch.shape[1]
            # form previous action sequence and get all q values for every possible action
            if isinstance(policy.act_dim, np.ndarray):
                # multidiscrete case
                sum_act_dim = int(sum(policy.act_dim))
            else:
                sum_act_dim = policy.act_dim
            pol_prev_act_buffer_seq = torch.cat((torch.zeros(1, total_batch_size, sum_act_dim).to(self.device),
                                                 stacked_act_batch[:-1]))
            pol_all_q_out_sequence, pol_final_hidden = policy.get_q_values(stacked_obs_batch, pol_prev_act_buffer_seq, policy.init_hidden(-1, total_batch_size))

            if isinstance(pol_all_q_out_sequence, list):
                # multidiscrete case
                ind = 0
                Q_per_part = []
                for i in range(len(policy.act_dim)):
                    curr_stacked_act_batch = stacked_act_batch[:, :, ind : ind + policy.act_dim[i]]
                    curr_stacked_act_batch_ind = curr_stacked_act_batch.max(dim=-1)[1]
                    curr_all_q_out_sequence = pol_all_q_out_sequence[i]
                    curr_pol_q_out_sequence = torch.gather(curr_all_q_out_sequence, 2, curr_stacked_act_batch_ind.unsqueeze(dim=-1))
                    Q_per_part.append(curr_pol_q_out_sequence)
                    ind += policy.act_dim[i]
                Q_sequence_combined_parts = torch.cat(Q_per_part, dim=-1)
                pol_agents_q_out_sequence = Q_sequence_combined_parts.split(split_size=batch_size, dim=-2)
            else:
                # get the q values associated with the actions taken according to the batch
                stacked_act_batch_ind = stacked_act_batch.max(dim=-1)[1]
                pol_q_out_sequence = torch.gather(pol_all_q_out_sequence, 2, stacked_act_batch_ind.unsqueeze(dim=-1))
                # separate into agent q sequences for each agent, then cat along the final dimension to prepare for mixer input
                pol_agents_q_out_sequence = pol_q_out_sequence.split(split_size=batch_size, dim=-2)
            agent_q_sequences.append(torch.cat(pol_agents_q_out_sequence, dim=-1))

            target_policy = self.target_policies[p_id]
            with torch.no_grad():
                if isinstance(target_policy.act_dim, np.ndarray):
                    # multidiscrete case
                    sum_act_dim = int(sum(target_policy.act_dim))
                else:
                    sum_act_dim = target_policy.act_dim
                _, new_target_hiddens = target_policy.get_q_values(stacked_obs_batch[0], torch.zeros(total_batch_size, sum_act_dim).to(self.device), target_policy.init_hidden(-1, total_batch_size))

                if self.args.double_q:
                    # actions come from live q; get the q values for the final nobs
                    pol_final_qs, _ = policy.get_q_values(stacked_nobs_batch[-1], stacked_act_batch[-1], pol_final_hidden)

                    if type(pol_final_qs) == list:
                        # multidiscrete case
                        assert stacked_navail_act_batch is None, "Available actions not supported for multidiscrete"
                        pol_nacts = []
                        for i in range(len(pol_final_qs)):
                            pol_final_curr_qs = pol_final_qs[i]
                            pol_all_curr_q_out_seq = pol_all_q_out_sequence[i]
                            pol_all_nq_out_curr_seq = torch.cat((pol_all_curr_q_out_seq[1:], pol_final_curr_qs[None]))
                            pol_curr_nacts = pol_all_nq_out_curr_seq.max(dim=-1)[1]
                            curr_act_dim = policy.act_dim[i]
                            pol_curr_nacts = make_onehot(pol_curr_nacts, total_batch_size, curr_act_dim, seq_len=seq_len)
                            pol_nacts.append(pol_curr_nacts)
                        pol_nacts = torch.cat(pol_nacts, dim=-1)
                        targ_pol_nq_seq, _ = target_policy.get_q_values(stacked_nobs_batch, stacked_act_batch, new_target_hiddens, action_batch=pol_nacts)
                    else:
                        # cat to form all the next step qs
                        pol_all_nq_out_sequence = torch.cat((pol_all_q_out_sequence[1:], pol_final_qs[None]))
                        # mask out the unavailable actions
                        if stacked_navail_act_batch is not None:
                            pol_all_nq_out_sequence[stacked_navail_act_batch == 0.0] = -1e10
                        # greedily choose actions which maximize the q values and convert these actions to onehot
                        pol_nacts = pol_all_nq_out_sequence.max(dim=-1)[1]
                        if isinstance(policy.act_dim, np.ndarray):
                            # multidiscrete case
                            sum_act_dim = int(sum(policy.act_dim))
                        else:
                            sum_act_dim = policy.act_dim
                        pol_nacts = make_onehot(pol_nacts, total_batch_size, sum_act_dim, seq_len=seq_len)

                        # q values given by target but evaluated at actions taken by live
                        targ_pol_nq_seq, _ = target_policy.get_q_values(stacked_nobs_batch, stacked_act_batch, new_target_hiddens, action_batch=pol_nacts)

                else:
                    # just choose actions from target policy
                    _, targ_pol_nq_seq, _, _ = target_policy.get_actions(stacked_nobs_batch, stacked_act_batch, new_target_hiddens, t_env=None, available_actions=stacked_navail_act_batch, explore=False)

                # separate the next qs into sequences for each agent
                pol_agents_nq_sequence = targ_pol_nq_seq.split(split_size=batch_size, dim=-2)
            # cat target qs along the final dim
            agent_next_q_sequences.append(torch.cat(pol_agents_nq_sequence, dim=-1))

        # combine the agent q value sequences to feed into mixer networks
        agent_q_sequences = torch.cat(agent_q_sequences, dim=-1)
        agent_next_q_sequences = torch.cat(agent_next_q_sequences, dim=-1)

        # store the sequences of predicted and next step Q_tot values to form Bellman errors
        predicted_Q_tot_vals = []
        next_step_Q_tot_vals = []
        for t in range(len(agent_q_sequences)):
            curr_state = cent_obs_batch[t]  # global state should be same across agents
            next_state = cent_nobs_batch[t]
            curr_agent_qs = agent_q_sequences[t]
            next_step_agent_qs = agent_next_q_sequences[t]

            curr_Q_tot = self.mixer(curr_agent_qs, curr_state)
            next_step_Q_tot = self.target_mixer(next_step_agent_qs, next_state)

            predicted_Q_tot_vals.append(curr_Q_tot.squeeze(-1))

            next_step_Q_tot_vals.append(next_step_Q_tot.squeeze(-1))

        # stack over time dimension
        predicted_Q_tot_vals = torch.stack(predicted_Q_tot_vals)
        next_step_Q_tot_vals = torch.stack(next_step_Q_tot_vals)
        # all agents must share reward, so get the reward sequence for an agent
        rewards = rew_batch[0]
        # get the done sequence for the env
        dones = dones_batch
        # form bootstrapped targets
        Q_tot_targets = rewards + (1 - dones.float()) * self.args.gamma * next_step_Q_tot_vals
        # form mask to mask out sequence elements corresponding to states at which the episode already ended
        curr_dones_mask = torch.cat(
            (torch.zeros(1, batch_size, 1).float().to(self.device), dones[:self.episode_length - 1, :, :]))

        predicted_Q_tots = predicted_Q_tot_vals * (1 - curr_dones_mask)

        Q_tot_targets = Q_tot_targets * (1 - curr_dones_mask)
        # loss is MSE Bellman Error
        loss = (((predicted_Q_tots - Q_tot_targets.detach()) ** 2).sum()) / (1 - curr_dones_mask).sum()
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters, self.args.grad_norm_clip)
        self.optimizer.step()

        return loss, grad_norm, predicted_Q_tots.mean()
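
The loss at the end of train_on_batch is a masked TD error against the bootstrapped target r_t + gamma * (1 - done_t) * Q_tot(s_{t+1}). A minimal standalone sketch of that masked MSE with dummy tensors (the (ep_len, batch_size) shapes are illustrative):

import torch

ep_len, batch_size, gamma = 8, 4, 0.99
predicted_Q_tot = torch.randn(ep_len, batch_size)
next_step_Q_tot = torch.randn(ep_len, batch_size)
rewards = torch.randn(ep_len, batch_size)
dones = torch.zeros(ep_len, batch_size)

# Bootstrapped targets: r_t + gamma * (1 - done_t) * Q_tot(s_{t+1})
targets = rewards + gamma * (1 - dones) * next_step_Q_tot
# Mask out steps that occur after the episode has already terminated
mask = torch.cat((torch.zeros(1, batch_size), dones[:-1]))
masked_error = (predicted_Q_tot - targets.detach()) * (1 - mask)
loss = (masked_error ** 2).sum() / (1 - mask).sum()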
Example 4
    def get_actions(self,
                    observation_batch,
                    prev_action_batch,
                    hidden_states,
                    t_env,
                    available_actions=None,
                    explore=True,
                    warmup=False):
        """
        get actions in epsilon-greedy manner, if specified
        """

        if len(observation_batch.shape) == 2:
            batch_size = observation_batch.shape[0]
            no_sequence = True
        else:
            batch_size = observation_batch.shape[1]
            seq_len = observation_batch.shape[0]
            no_sequence = False

        q_values, new_hidden_states = self.get_q_values(
            observation_batch, prev_action_batch, hidden_states)
        # mask out unavailable actions by giving them a very negative q value
        if available_actions is not None:
            q_values[available_actions == 0.0] = -1e10
        #greedy_Qs, greedy_actions = list(map(lambda a: a.max(dim=-1), q_values))
        if self.multidiscrete:
            onehot_actions = []
            greedy_Qs = []
            for i in range(len(self.act_dim)):
                greedy_Q, greedy_action = q_values[i].max(dim=-1)

                if explore:
                    assert no_sequence, "Can only explore on non-sequences"
                    if warmup:
                        eps = 1.0
                    else:
                        eps = self.schedule.eval(t_env)
                    # rand_like samples from Uniform[0, 1]
                    rand_number = torch.rand_like(observation_batch[:, 0])
                    # random actions sample uniformly from action space
                    random_action = Categorical(logits=torch.ones(
                        batch_size, self.act_dim[i])).sample().to(self.device)
                    take_random = (rand_number < eps).float()

                    action = ((1.0 - take_random) * greedy_action.float()
                              + take_random * random_action)
                    onehot_action = make_onehot(
                        action.cpu(), batch_size,
                        self.act_dim[i]).to(self.device).detach()
                else:
                    greedy_Q = greedy_Q.unsqueeze(-1)
                    if no_sequence:
                        onehot_action = make_onehot(greedy_action.cpu(),
                                                    batch_size,
                                                    self.act_dim[i])
                    else:
                        onehot_action = make_onehot(greedy_action.cpu(),
                                                    batch_size,
                                                    self.act_dim[i],
                                                    seq_len=seq_len)

                onehot_actions.append(onehot_action)
                greedy_Qs.append(greedy_Q)

            onehot_actions = torch.cat(onehot_actions, dim=-1)
            greedy_Qs = torch.cat(greedy_Qs, dim=-1)

            if explore:
                return (onehot_actions, greedy_Qs,
                        new_hidden_states.detach(), eps)
            else:
                return onehot_actions, greedy_Qs, new_hidden_states, None
        else:
            greedy_Qs, greedy_actions = q_values.max(dim=-1)
            if explore:
                assert no_sequence, "Can only explore on non-sequences"
                if warmup:
                    eps = 1.0
                else:
                    eps = self.schedule.eval(t_env)
                # rand_like samples from Uniform[0, 1]
                rand_numbers = torch.rand_like(observation_batch[:, 0])
                # random actions sample uniformly from action space
                logits = torch.ones_like(prev_action_batch)
                random_actions = avail_choose(logits,
                                              available_actions).sample().to(
                                                  self.device).float()
                take_random = (rand_numbers < eps).float()

                actions = ((1.0 - take_random) * greedy_actions.float()
                           + take_random * random_actions)
                onehot_actions = make_onehot(
                    actions.cpu(), batch_size,
                    self.act_dim).to(self.device).detach()
                return onehot_actions, greedy_Qs, new_hidden_states.detach(), eps

            else:
                greedy_Qs = greedy_Qs.unsqueeze(-1)
                if no_sequence:
                    return make_onehot(
                        greedy_actions, batch_size,
                        self.act_dim), greedy_Qs, new_hidden_states, None
                else:
                    return make_onehot(
                        greedy_actions,
                        batch_size,
                        self.act_dim,
                        seq_len=seq_len), greedy_Qs, new_hidden_states, None
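
self.schedule.eval(t_env) above is assumed to return a decaying epsilon. A minimal linear-decay sketch of such a schedule, together with the epsilon-greedy mixing used in the method (the class name and constants are illustrative, not from the original source):

import torch

class LinearEpsilonSchedule:
    # Hypothetical schedule: decay epsilon linearly from `start` to `finish`
    # over the first `anneal_time` environment steps.
    def __init__(self, start=1.0, finish=0.05, anneal_time=50000):
        self.start, self.finish, self.anneal_time = start, finish, anneal_time

    def eval(self, t_env):
        frac = min(t_env / self.anneal_time, 1.0)
        return self.start + frac * (self.finish - self.start)

# Epsilon-greedy mixing as in the example: with probability eps take a
# uniformly random action index, otherwise keep the greedy (argmax-Q) one.
eps = LinearEpsilonSchedule().eval(t_env=10000)
greedy_actions = torch.tensor([2, 0, 1, 3]).float()
random_actions = torch.randint(0, 4, (4,)).float()
take_random = (torch.rand(4) < eps).float()
actions = (1.0 - take_random) * greedy_actions + take_random * random_actions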
Example 5
    def get_actions(self,
                    obs,
                    prev_actions,
                    actor_rnn_states,
                    available_actions=None,
                    use_target=False,
                    t_env=None,
                    use_gumbel=False,
                    explore=False):

        assert prev_actions is None or len(obs.shape) == len(
            prev_actions.shape)
        # obs is either an array of shape (batch_size, obs_dim) or (seq_len, batch_size, obs_dim)
        if len(obs.shape) == 2:
            batch_size = obs.shape[0]
            no_sequence = True
        else:
            batch_size = obs.shape[1]
            no_sequence = False

        eps = None
        if use_target:
            actor_out, new_rnn_states = self.target_actor(
                obs, prev_actions, actor_rnn_states)
        else:
            actor_out, new_rnn_states = self.actor(obs, prev_actions,
                                                   actor_rnn_states)

        if self.discrete_action:
            if self.multidiscrete:
                if use_gumbel or explore or use_target:
                    onehot_actions = list(
                        map(lambda a: gumbel_softmax(a, hard=True), actor_out))
                else:
                    onehot_actions = list(map(onehot_from_logits, actor_out))

                onehot_actions = torch.cat(onehot_actions, dim=-1)
                if explore:
                    # eps greedy exploration
                    batch_size = obs.shape[0]
                    eps = self.exploration.eval(t_env)
                    rand_numbers = torch.rand((batch_size, 1))
                    take_random = (rand_numbers < eps).int().view(-1, 1)

                    # random actions sample uniformly from action space
                    random_actions = [
                        OneHotCategorical(logits=torch.ones(
                            batch_size, self.act_dim[i])).sample()
                        for i in range(len(self.act_dim))
                    ]
                    random_actions = torch.cat(random_actions, dim=1)
                    actions = ((1 - take_random) * onehot_actions
                               + take_random * random_actions)
                else:
                    actions = onehot_actions
            else:
                if use_gumbel or explore or use_target:
                    onehot_actions = gumbel_softmax(
                        actor_out, available_actions,
                        hard=True)  # gumbel has a gradient
                else:
                    onehot_actions = onehot_from_logits(
                        actor_out, available_actions)  # no gradient

                if explore:
                    assert no_sequence, "Doesn't make sense to do exploration on a sequence!"
                    # eps greedy exploration
                    eps = self.exploration.eval(t_env)
                    rand_numbers = np.random.rand(batch_size, 1)
                    # random actions sample uniformly from action space
                    logits = torch.ones(batch_size, self.act_dim)
                    random_actions = avail_choose(logits,
                                                  available_actions).sample()
                    random_actions = make_onehot(random_actions, batch_size,
                                                 self.act_dim)
                    take_random = (rand_numbers < eps).astype(float)
                    actions = ((1.0 - take_random)
                               * onehot_actions.detach().cpu().numpy()
                               + take_random * random_actions.cpu().numpy())
                else:
                    actions = onehot_actions
        else:
            if explore:
                assert no_sequence, "Cannot do exploration on a sequence!"
                actions = gaussian_noise(actor_out.shape,
                                         self.args.act_noise_std) + actor_out
            elif use_target:
                target_noise = gaussian_noise(
                    actor_out.shape, self.args.target_noise_std).clamp(
                        -self.args.target_noise_clip,
                        self.args.target_noise_clip)
                actions = actor_out + target_noise
            else:
                actions = actor_out
            # # clip the actions at the bounds of the action space
            # actions = torch.max(torch.min(actions, torch.from_numpy(self.act_space.high)), torch.from_numpy(self.act_space.low))

        return actions, new_rnn_states, eps
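
The discrete branches above depend on gumbel_softmax and onehot_from_logits, which are not part of this listing. A minimal straight-through sketch of both, assuming no available-actions masking (the version called above also accepts an available_actions argument, omitted here):

import torch
import torch.nn.functional as F

def onehot_from_logits(logits):
    # Hypothetical helper: hard argmax converted to a one-hot vector
    # (no gradient flows through the selection).
    argmax = logits.argmax(dim=-1, keepdim=True)
    return torch.zeros_like(logits).scatter_(-1, argmax, 1.0)

def gumbel_softmax(logits, temperature=1.0, hard=True):
    # Hypothetical straight-through Gumbel-softmax: sample a soft one-hot
    # with Gumbel noise, then snap to a hard one-hot while keeping the soft
    # sample's gradient.
    noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + noise) / temperature, dim=-1)
    if not hard:
        return y_soft
    y_hard = onehot_from_logits(y_soft)
    return (y_hard - y_soft).detach() + y_soft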