Code example #1
    def forward(self, policies, advantages, actions, seq_lens, tformat, *args, **kwargs):

        assert tformat in ["a*bs*t*v"], "invalid input format!"

        policies = policies.clone()
        advantages = advantages.clone().detach().unsqueeze(0).repeat(policies.shape[0],1,1,1)
        actions = actions.clone()

        # last elements of advantages are NaNs
        mask = advantages.clone().fill_(1.0).byte()
        _pad_zero(mask, tformat, seq_lens)
        mask[:, :, :-1, :] = mask[:, :, 1:, :] # account for terminal NaNs of targets
        mask[:, :, -1, :] = 0.0  # handles case of seq_len=limit_len
        _pad_zero(policies, tformat, seq_lens)
        _pad_zero(actions, tformat, seq_lens)
        advantages[~mask.bool()] = 0.0  # .bool() keeps ~ a logical NOT on newer PyTorch

        pi_taken = th.gather(policies, _vdim(tformat), actions.long())
        pi_taken_mask = (pi_taken < 10e-40)  # 10e-40 == 1e-39, guards the log below
        log_pi_taken = th.log(pi_taken)
        log_pi_taken[pi_taken_mask] = 0.0

        loss = - log_pi_taken * advantages * mask.float()
        norm = mask.sum(_bsdim(tformat), keepdim=True)
        norm[norm == 0.0] = 1.0
        loss_mean = loss.sum(_bsdim(tformat), keepdim=True) / norm.float()
        loss_mean = loss_mean.squeeze(_vdim(tformat)).squeeze(_bsdim(tformat))

        output_tformat = "a*t"
        return loss_mean, output_tformat
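The loss above is a masked mean: padded and terminal steps are zeroed out, then the sum over the batch dimension is divided by the count of valid entries, with a floor of 1 to avoid division by zero. A minimal standalone sketch of that pattern with illustrative shapes (here `th` is `torch`; `_pad_zero`, `_vdim`, and `_bsdim` in the snippet are project utilities, not shown):

import torch as th

# Toy layout matching "a*bs*t*v": agents x batch x time x 1.
loss = th.randn(2, 4, 5, 1)
mask = th.ones(2, 4, 5, 1)
mask[:, :, -1, :] = 0.0                      # drop the terminal step

norm = mask.sum(dim=1, keepdim=True)         # valid entries per (agent, t)
norm[norm == 0.0] = 1.0                      # avoid division by zero
loss_mean = (loss * mask).sum(dim=1, keepdim=True) / norm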
Code example #2
File: mackrl.py Project: ewanlee/mackrl
    def forward(self, policies, advantages, tformat, seq_lens, *args,
                **kwargs):
        assert tformat in ["a*bs*t*v"], "invalid input format!"
        n_agents = kwargs["n_agents"]

        policies = policies.clone()
        advantages = advantages.clone().detach().unsqueeze(0).repeat(
            policies.shape[0], 1, 1, 1)

        _pad_zero(policies, tformat, seq_lens)
        # last elements of advantages are NaNs
        mask = advantages.clone().fill_(1.0).byte()
        _pad_zero(mask, tformat, seq_lens)
        mask[:, :, :-1, :] = mask[:, :, 1:, :]  # account for terminal NaNs of targets
        mask[:, :, -1, :] = 0.0  # handles case of seq_len=limit_len
        advantages[~mask.bool()] = 0.0

        _pad(policies, tformat, seq_lens, 1.0)  # pad with 1.0 so log() of padded steps is 0
        policy_mask = (policies < 10e-40)  # 10e-40 == 1e-39, guards the log below
        log_policies = th.log(policies)
        log_policies[policy_mask] = 0.0

        nan_mask = policies.clone().fill_(1.0)
        _pad_zero(nan_mask, tformat, seq_lens)
        loss = -log_policies * advantages * nan_mask.float()
        norm = nan_mask.sum(_bsdim(tformat), keepdim=True)
        norm[norm == 0.0] = 1.0
        loss_mean = loss.sum(_bsdim(tformat), keepdim=True) / norm.float()
        loss_mean = loss_mean.squeeze(_vdim(tformat)).squeeze(_bsdim(tformat))

        loss_mean = loss_mean / n_agents

        output_tformat = "a*t"
        return loss_mean, output_tformat
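The `10e-40` threshold (i.e. 1e-39) guards against taking the log of zero probabilities on padded or unavailable actions: the log is computed, then the offending entries are overwritten with 0 so they drop out of the loss. A hedged standalone sketch of that safe-log pattern:

import torch as th

probs = th.tensor([0.5, 0.0, 0.5])
eps_mask = probs < 10e-40      # same threshold as above, == 1e-39
log_probs = th.log(probs)      # log(0) produces -inf here
log_probs[eps_mask] = 0.0      # zero those terms so they drop out of the loss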
Code example #3
File: basic.py Project: wwxFromTju/mackrl
    def forward(self, inputs, tformat, loss_fn=None, hidden_states=None, **kwargs):
        _check_inputs_validity(inputs, self.input_shapes, tformat)

        # Execute model branch "main"
        x, params, tformat = _to_batch(inputs["main"], tformat)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = _from_batch(x, params, tformat)

        losses = None
        if self.output_type in ["policies"]:
            log_softmax = kwargs.get("log_softmax", False)
            if log_softmax:
                x = F.log_softmax(x, dim=_vdim(tformat))
            else:
                x = F.softmax(x, dim=_vdim(tformat))
        if loss_fn is not None:
            losses, _ = loss_fn(x, tformat=tformat)

        return x, hidden_states, losses, tformat  # output, hidden states, losses, tformat
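`_to_batch` collapses the leading `a*bs*t` dimensions so the linear layers see a plain 2-D batch, and `_from_batch` restores the original layout afterwards. A minimal stand-in for that round trip (shapes are illustrative; the real helpers live in the repo):

import torch as th
import torch.nn.functional as F

x = th.randn(3, 4, 5, 8)                         # a*bs*t*v
flat = x.reshape(-1, x.shape[-1])                # (a*bs*t) x v, what fc1/fc2 consume
out = F.softmax(flat, dim=-1)
out = out.reshape(*x.shape[:-1], out.shape[-1])  # back to a*bs*t*v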
Code example #4
File: flounderl_runner.py Project: wwxFromTju/mackrl
    def _add_episode_stats(self, T_env, **kwargs):
        super()._add_episode_stats(T_env, **kwargs)

        test_suffix = "" if not self.test_mode else "_test"
        if self.args.obs_noise and self.test_mode:
            obs_noise_std = kwargs.get("obs_noise_std", 0.0)
            test_suffix = "_noise{}_test".format(obs_noise_std)

        tmp = self.episode_buffer["policies_level1"][0]
        entropy1 = np.nanmean(np.nansum((-th.log(tmp)*tmp).cpu().numpy(), axis=2))
        self._add_stat("policy_level1_entropy",
                       entropy1,
                       T_env=T_env,
                       suffix=test_suffix)

        for _i in range(_n_agent_pair_samples(self.n_agents)):
            tmp = self.episode_buffer["policies_level2__sample{}".format(_i)][0]
            entropy2 = np.nanmean(np.nansum((-th.log(tmp) * tmp).cpu().numpy(), axis=2))
            self._add_stat("policy_level2_entropy_sample{}".format(_i),
                           entropy2,
                           T_env=T_env,
                           suffix=test_suffix)

        #entropy3 = np.nanmean(np.nansum((-th.log(tmp) * tmp).cpu().numpy(), axis=2))
        self._add_stat("policy_level3_entropy",
                        self.episode_buffer.get_stat("policy_entropy", policy_label="policies_level3"),
                        T_env=T_env,
                        suffix=test_suffix)

        actions_level2 = []
        for i in range(self.n_agents // 2):
            actions_level2_tmp, _ = self.episode_buffer.get_col(col="actions_level2__sample{}".format(i))
            actions_level2.append(actions_level2_tmp)
        actions_level2 = th.cat(actions_level2, 0)
        # fraction of level-2 actions that delegate (action 0), excluding NaN entries
        n_valid = (actions_level2.nelement() - th.sum(actions_level2 != actions_level2)).float()
        delegation_rate = (th.sum(actions_level2 == 0.0).float() / n_valid).item()
        self._add_stat("level2_delegation_rate",
                       delegation_rate,
                       T_env=T_env,
                       suffix=test_suffix)

        # common knowledge overlap between all agents
        overlap_all = th.sum((self.episode_buffer["obs_intersection_all"][0] > 0.0), dim=_vdim("bs*t*v")).float().mean().item()
        self._add_stat("obs_intersection_all_rate",
                       overlap_all,
                       T_env=T_env,
                       suffix=test_suffix)

        # common knowledge overlap between agent pairs
        overlaps = []
        for _pair_id, (id_1, id_2) in enumerate(_ordered_agent_pairings(self.n_agents)):
            overlap_pair = th.sum(self.episode_buffer["obs_intersection__pair{}".format(_pair_id)][0] > 0.0, dim=_vdim("bs*t*v")).float().mean().item()
            overlaps.append(overlap_pair)
            self._add_stat("obs_intersection_pair{}".format(_pair_id),
                           overlap_pair,
                           T_env=T_env,
                           suffix=test_suffix)

        self._add_stat("obs_intersection_pairs_all".format(_pair_id),
                       sum(overlaps) / float(len(overlaps)),
                       T_env=T_env,
                       suffix=test_suffix)

        # TODO: Policy entropy across levels! (Use suffix)
        return
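The entropy statistics rely on NaN-aware reductions: `np.nansum` over the action dimension treats NaN-padded entries as zero, and `np.nanmean` then averages what remains. A small sketch of the same computation on a toy policy tensor (shapes are illustrative):

import numpy as np

# bs x t x actions; the second timestep is NaN padding.
policies = np.array([[[0.25, 0.75],
                      [np.nan, np.nan]]])
entropy = np.nanmean(np.nansum(-np.log(policies) * policies, axis=2))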
Code example #5
File: mackrl.py Project: wwxFromTju/mackrl
    def forward(self, inputs, actions, hidden_states, tformat, loss_fn=None, **kwargs):
        seq_lens = kwargs["seq_lens"]
        # TODO: How do we handle the loss propagation for recurrent layers??

        # generate level 1-3 outputs
        out_level1, hidden_states_level1, losses_level1, tformat_level1 = self.model_level1(inputs=inputs["level1"]["agent_input_level1"],
                                                                                            hidden_states=hidden_states["level1"],
                                                                                            loss_fn=None, #loss_fn,
                                                                                            tformat=tformat["level1"],
                                                                                            #n_agents=self.n_agents,
                                                                                            **kwargs)

        pairwise_avail_actions = inputs["level2"]["agent_input_level2"]["avail_actions_pair"]
        # delegation (joint action 0) is always available
        delegation_avails = th.ones(pairwise_avail_actions.shape[0],
                                    pairwise_avail_actions.shape[1],
                                    pairwise_avail_actions.shape[2], 1,
                                    device=pairwise_avail_actions.device)
        pairwise_avail_actions = th.cat([delegation_avails, pairwise_avail_actions], dim=_vdim(tformat["level2"]))
        out_level2, hidden_states_level2, losses_level2, tformat_level2 = self.models["level2_{}".format(0)](inputs=inputs["level2"]["agent_input_level2"],
                                                                                                            hidden_states=hidden_states["level2"],
                                                                                                            loss_fn=None, #loss_fn,
                                                                                                            tformat=tformat["level2"],
                                                                                                            pairwise_avail_actions=pairwise_avail_actions,
                                                                                                            **kwargs)

        out_level3, hidden_states_level3, losses_level3, tformat_level3 = self.models["level3_{}".format(0)](inputs["level3"]["agent_input_level3"],
                                                                                                            hidden_states=hidden_states["level3"],
                                                                                                            loss_fn=None,
                                                                                                            tformat=tformat["level3"],
                                                                                                            #seq_lens=seq_lens,
                                                                                                            **kwargs)


        # for each agent pair (a,b), calculate p_a_b = p(u_a, u_b|tau_a tau_b CK_ab) = p(u_d|CK_ab)*pi(u_a|tau_a)*pi(u_b|tau_b) + p(u_ab|CK_ab)
        # output dim of p_a_b is (n_agents choose 2) x bs x t x n_actions**2

        # Bulk NaN masking
        # out_level1[out_level1 != out_level1] = 0.0
        # out_level2[out_level2 != out_level2] = 0.0
        # out_level3[out_level3 != out_level3] = 0.0
        # actions = actions.detach()
        # actions[actions!=actions] = 0.0

        p_d = out_level2[:, :, :, 0:1]
        p_ab = out_level2[:, :, :, 1:-1]

        _actions = actions.clone()
        _actions[actions != actions] = 0.0
        pi = out_level3
        pi_actions_selected = out_level3.gather(_vdim(tformat_level3), _actions.long()).clone()
        # pi_actions_selected[pi_actions_selected  != pi_actions_selected ] = 0.0 #float("nan")
        pi_actions_selected_mask = (actions != actions)
        avail_actions_level3 = inputs["level3"]["agent_input_level3"]["avail_actions"].clone().data
        avail_actions_selected = avail_actions_level3.gather(_vdim(tformat_level3), _actions.long()).clone()

        # RENORMALIZE THIS STUFF
        pi_a_cross_pi_b_list = []
        pi_ab_list = []
        pi_corr_list = []

        # DEBUG: if prob is 0 throw a debug!
        # debug
        # 0 (0, 1)
        # 1 (0, 2)
        # 2 (1, 2)

        for _i, (_a, _b) in enumerate(self.ordered_agent_pairings):
            # calculate pi_a_cross_pi_b # TODO: Set disallowed joint actions to NaN!
            pi_a = pi[_a:_a+1].clone()
            pi_b = pi[_b:_b+1].clone()
            x, params_x, tformat_x = _to_batch(out_level3[_a:_a+1], tformat["level3"])
            y, params_y, tformat_y  = _to_batch(out_level3[_b:_b+1], tformat["level3"])
            actions_masked = actions.clone()
            actions_masked[actions!=actions] = 0.0
            _actions_x, _actions_params, _actions_tformat = _to_batch(actions_masked[_a:_a+1].clone(), tformat["level3"])
            _actions_y, _actions_params, _actions_tformat = _to_batch(actions_masked[_b:_b+1].clone(), tformat["level3"])
            _x = x.gather(1, _actions_x.long())
            _y = y.gather(1, _actions_y.long())
            z = _x * _y
            u = _from_batch(z, params_x, tformat_x)
            pi_a_cross_pi_b_list.append(u)
            # calculate p_ab_selected
            _p_ab = p_ab[_i:_i+1].clone()

            joint_actions = _action_pair_2_joint_actions((actions_masked[_a:_a+1], actions_masked[_b:_b+1]), self.n_actions)
            _z = _p_ab.gather(_vdim(tformat_level2), joint_actions.long())
            # Set probabilities corresponding to jointly-disallowed actions to 0.0
            # +1 skips the delegation column prepended to pairwise_avail_actions
            avail_flags = pairwise_avail_actions[_i:_i+1].gather(_vdim(tformat_level2), joint_actions.long() + 1).clone()
            _z[avail_flags==0.0] = 0.0 # TODO: RENORMALIZE?
            pi_ab_list.append(_z)
            if not ((hasattr(self.args, "mackrl_always_delegate") and self.args.mackrl_always_delegate) or \
                    (hasattr(self.args, "mackrl_delegate_if_zero_ck") and self.args.mackrl_delegate_if_zero_ck)):
                # Calculate corrective delegation (when unavailable actions are selected)
                aa_a, _, _ = _to_batch(avail_actions_level3[_a:_a+1].clone(), tformat_level3)
                aa_b, _, _ = _to_batch(avail_actions_level3[_b:_b+1].clone(), tformat_level3)
                paa, params_paa, tformat_paa = _to_batch(pairwise_avail_actions[_i:_i+1].clone(), tformat_level3)
                _pi_a, _, _ = _to_batch(pi_a, tformat_level3)
                _pi_b, _, _ = _to_batch(pi_b, tformat_level3)
                # x = th.bmm(th.unsqueeze(aa_a, 2),  th.unsqueeze(aa_b, 1))
                # At least one action unavailable.
                diff_matrix =  th.relu(paa[:, 1:-1].view(-1, self.n_actions, self.n_actions) -
                                       th.bmm(th.unsqueeze(aa_a[:, :-1], 2),  th.unsqueeze(aa_b[:, :-1], 1)))

                diff_matrix = diff_matrix * _p_ab.view(-1, self.n_actions, self.n_actions)

                both_unavailable = th.relu(paa[:, 1:-1].view(-1, self.n_actions, self.n_actions) -
                                      th.add(th.unsqueeze(aa_a[:, :-1], 2), th.unsqueeze(aa_b[:, :-1], 1)))

                both_unavailable = both_unavailable * _p_ab.view(-1, self.n_actions, self.n_actions)
                both_unavailable_weight = th.sum(both_unavailable.view(-1, self.n_actions*self.n_actions), -1, keepdim=True)

                # If neither component of the joint action is available both get resampled, with the probability of the independent actors.
                correction = (both_unavailable_weight *
                              pi_actions_selected[_a:_a + 1].view(-1, 1) *
                              pi_actions_selected[_b:_b + 1].view(-1, 1))

                act_a = actions_masked[_a:_a + 1].clone().view(-1, 1,1).long()
                act_b = actions_masked[_b:_b + 1].clone().view(-1, 1,1).long()
                b_resamples = th.sum(th.gather(diff_matrix, 1, act_a.repeat([1,1, self.n_actions])),-1)
                b_resamples = b_resamples * pi_actions_selected[_b:_b + 1].view(-1,1)

                a_resamples = th.sum(th.gather(diff_matrix, 2, act_b.repeat([1, self.n_actions,1])), 1)
                a_resamples = a_resamples * pi_actions_selected[_a:_a + 1].view(-1,1)

                correction = correction + b_resamples + a_resamples
                pi_corr_list.append(_from_batch(correction, params_paa, tformat_paa))

        pi_a_cross_pi_b = th.cat(pi_a_cross_pi_b_list, dim=0)
        pi_ab_selected = th.cat(pi_ab_list, dim=0)

        p_a_b = p_d * pi_a_cross_pi_b + pi_ab_selected
        agent_pairs = self.ordered_agent_pairings
        if not ( (hasattr(self.args, "mackrl_always_delegate") and self.args.mackrl_always_delegate) or \
               (hasattr(self.args, "mackrl_delegate_if_zero_ck") and self.args.mackrl_delegate_if_zero_ck) ):
            pi_corr = th.cat(pi_corr_list, dim=0)
            p_a_b += pi_corr # DEBUG

        pi_c_prod_list = []
        p_a_b_pair_list = []

        for pairs in self.ordered_2_agent_pairings:
            p_a_b_pair_list.append(reduce(operator.mul, [p_a_b[pair] for pair in pairs]))

            agents_in_pairs = list(itertools.chain.from_iterable([agent_pairs[pair] for pair in pairs]))

            # calculate pi_c_prod
            _pi_actions_selected = pi_actions_selected.clone()
            _pi_actions_selected[pi_actions_selected_mask] = 0.0
            for _agent in agents_in_pairs:
                _pi_actions_selected[_agent:_agent + 1] = 1.0
            _k = th.prod(_pi_actions_selected, dim=_adim(tformat_level3), keepdim=True)
            pi_c_prod_list.append(_k)

        pi_ab_pair = th.stack(p_a_b_pair_list, dim=0)
        pi_c_prod = th.cat(pi_c_prod_list, dim=0)

        # next, calculate p_a_b * prod(p, -a-b)
        p_prod = pi_ab_pair * pi_c_prod

        # now, calculate p_a_b_c
        _tmp =  out_level1.transpose(_adim(tformat_level1), _vdim(tformat_level1))
        _tmp[_tmp!=_tmp] = 0.0
        p_a_b_c = (p_prod * _tmp).sum(dim=_adim(tformat_level1), keepdim=True)

        if self.args.debug_mode in ["check_probs"]:
            if not hasattr(self, "action_table"):
                self.action_table = {}
            if not hasattr(self, "actions_sampled"):
                self.actions_sampled = 0
            actions_flat = actions.view(self.n_agents, -1)
            for id in range(actions_flat.shape[1]):
                act = tuple(actions_flat[:, id].tolist())
                if act in self.action_table:
                    self.action_table[act] += 1
                else:
                    self.action_table[act] = 0
            self.actions_sampled += actions_flat.shape[0]

        if self.args.debug_mode in ["check_probs"]:
            actions_flat = actions.view(self.n_agents, -1)
            for id in range(actions_flat.shape[1]):
                print("sampled: ",
                      self.action_table[tuple(actions_flat[:, id].tolist())] / self.actions_sampled,
                      " pred: ",
                      p_a_b_c.view(-1)[id])

        # DEBUG MODE HERE!
        # _check_nan(pi_c_prod)
        # _check_nan(p_d)
        # _check_nan(pi_ab_selected)
        # _check_nan(pi_a_cross_pi_b)
        # _check_nan(p_a_b)
        # _check_nan(p_a_b_c)

        # agent_parameters = list(self.parameters())
        # pi_c_prod.sum().backward(retain_graph=True) # p_prod throws NaN
        # _check_nan(agent_parameters)
        # p_a_b.sum().backward(retain_graph=True)
        # _check_nan(agent_parameters)
        # p_prod.sum().backward(retain_graph=True)
        # _check_nan(agent_parameters)

        hidden_states = {"level1": hidden_states_level1,
                         "level2": hidden_states_level2,
                         "level3": hidden_states_level3
                        }

        loss = loss_fn(policies=p_a_b_c, tformat=tformat_level3) if loss_fn is not None else None

        # loss = p_a_b_c.sum(), "a*bs*t*v"
        #loss[0].sum().backward(retain_graph=True)
        # loss[0].backward(retain_graph=True)

        # try:
        #     _check_nan(agent_parameters)
        # except Exception as e:
        #     for a, b in self.named_parameters():
        #         print("{}:{}".format(a, b.grad))
        #     a = 5
        #     pass

        return p_a_b_c, hidden_states, loss, tformat_level3 # note: policies can have NaNs in it!!
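Under the decomposition described in the comments above, the joint probability of a pair acting independently is an outer product of the two individual policies, and `_action_pair_2_joint_actions` flattens a pair `(u_a, u_b)` into one index over `n_actions**2` joint actions, presumably as `u_a * n_actions + u_b`. A hedged stand-in for those two steps (the real helper lives in the repo's transforms module):

import torch as th

n_actions = 3
pi_a = th.tensor([0.2, 0.3, 0.5])
pi_b = th.tensor([0.1, 0.6, 0.3])

joint = th.outer(pi_a, pi_b).reshape(-1)   # p(u_a, u_b) under independence
u_a, u_b = 1, 2
joint_idx = u_a * n_actions + u_b          # flattened joint-action index
assert th.isclose(joint[joint_idx], pi_a[u_a] * pi_b[u_b])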