Пример #1
0
    def compute_qs(self, graph, node_feature, maximum_num_enemy):
        """Compute the Q-values of the ally units together with their tags.

        Args:
            graph: graph whose ``ndata`` holds ``'tag'`` and, optionally,
                ``'enemy_tag'``.
            node_feature: node features forwarded to the network call.
            maximum_num_enemy: forwarded unchanged to the network call.

        Returns:
            dict with three entries, each restricted to the ally nodes:
                ``'qs'``: move and attack outputs concatenated along the
                    last dimension (used for RL training).
                ``'ally_tags'``: rows of ``graph.ndata['tag']``.
                ``'enemy_tags'``: rows of ``graph.ndata['enemy_tag']``, or
                    a single-column zero placeholder when that field is
                    absent (both used for SC2 interfacing).
        """
        # The network returns outputs for every node, enemies included.
        move_arg, attack_arg = self(graph, node_feature, maximum_num_enemy)
        qs = torch.cat((move_arg, attack_arg), dim=-1)

        ally_node_type_index = self.conf.qnet['ally_node_type_index']
        ally_node_indices = get_filtered_node_index_by_type(
            graph, ally_node_type_index)
        qs = qs[ally_node_indices, :]  # keep only the ally rows

        node_tags = graph.ndata['tag']
        ally_tags = node_tags[ally_node_indices]

        if 'enemy_tag' in graph.ndata.keys():
            enemy_tags = graph.ndata['enemy_tag']
        else:
            # The placeholder must have one row per *graph node*, because it
            # is indexed with graph-wide ally indices right below. Building
            # it from the already filtered ally tags would go out of bounds
            # whenever the allies are not the leading nodes of the graph.
            enemy_tags = torch.zeros_like(node_tags).view(-1, 1)

        enemy_tags = enemy_tags[ally_node_indices, :]

        return {
            'qs': qs,
            'ally_tags': ally_tags,
            'enemy_tags': enemy_tags,
        }
Пример #2
0
    def get_q(self, graph, node_feature, qs, ws=None):
        """Aggregate per-ally q values into per-graph, per-cluster q values.

        Args:
            graph: graph to read out over; ``ndata`` is used as scratch
                space and the temporary fields are removed before returning.
            node_feature: node features; summed per graph and fed to
                ``self.q_b_net`` to obtain the additive bias term.
            qs: ally q values, flattened to one column before weighting.
            ws: optional ally-to-cluster weights; computed with
                ``self.get_w`` when omitted.

        Returns:
            Sum of the weighted ally q values per graph plus the output of
            ``self.q_b_net``  # [#. graph x #. clusters]
        """
        device = node_feature.device
        ally_rows = get_filtered_node_index_by_type(graph, NODE_ALLY)

        if ws is None:
            ws = self.get_w(graph, node_feature)  # [#. allies x #. clusters]

        # Non-ally rows stay zero so they do not contribute to the readout.
        node_q = torch.zeros(size=(graph.number_of_nodes(), self.num_clusters), device=device)
        node_q[ally_rows, :] = qs.view(-1, 1) * ws  # [#. allies x #. clusters]

        graph.ndata['q'] = node_q
        cluster_q = dgl.sum_nodes(graph, 'q')  # [#. graph x #. clusters]

        # Bias term computed from the per-graph sum of node features.
        graph.ndata['node_feature'] = node_feature
        pooled_feature = dgl.sum_nodes(graph, 'node_feature')  # [#. graph x feature dim]
        bias = self.q_b_net(pooled_feature)  # [#. graph x #. clusters]

        # Remove the scratch fields written above.
        graph.ndata.pop('node_feature')
        graph.ndata.pop('q')

        return cluster_q + bias  # [#. graph x #. clusters]
Пример #3
0
    def get_feat(self, graph, node_feature, ws=None):
        """Build cluster-wise weighted features and ally/cluster similarities.

        Args:
            graph: a single graph or a ``dgl.BatchedDGLGraph``; ``ndata`` is
                used as scratch space for one temporary field.
            node_feature: node features  # [#. nodes x feature dim]
            ws: optional ally-to-cluster weights; computed with
                ``self.get_w`` when omitted  # [#. allies x #. clusters]

        Returns:
            Tuple ``(wf, ally_normed_group_dot_prd, normed_group_dot_prd)``:
                wf: per-graph weighted sum of ally features for each
                    cluster  # [#. graph x #. clusters x feature dim]
                ally_normed_group_dot_prd: dot product between each ally's
                    feature and its graph's cluster features, divided by the
                    product of their norms  # [#. allies x #. clusters]
                normed_group_dot_prd: the same scores written into the ally
                    rows of an all-node zero tensor
                    # [#. nodes x #. clusters]
        """
        if ws is None:
            ws = self.get_w(graph, node_feature)

        single_graph = not isinstance(graph, dgl.BatchedDGLGraph)

        ally_indices = get_filtered_node_index_by_type(graph, NODE_ALLY)
        ally_nf = node_feature[ally_indices, :]  # [#. allies x feature dim]
        num_allies, feat_dim = ally_nf.shape[0], ally_nf.shape[1]

        # Broadcasting [#. allies x 1 x #. clusters] against
        # [#. allies x feature dim x 1] yields the weighted features without
        # materialising a repeated copy of the ally features.
        weighted_feat = ws.view(num_allies, 1, self.num_clusters) * ally_nf.unsqueeze(dim=-1)

        # Non-ally rows stay zero so they do not contribute to the readout.
        _wf = torch.zeros(size=(graph.number_of_nodes(), feat_dim, self.num_clusters), device=ws.device)
        _wf[ally_indices, :] = weighted_feat

        graph.ndata['weighted_feat'] = _wf
        weighted_feat = dgl.sum_nodes(graph, 'weighted_feat')  # [#. graph x feature dim x #. clusters]
        # The scratch field is no longer needed once it has been read out.
        graph.ndata.pop('weighted_feat')

        if single_graph:
            # Add the graph dimension so both cases share the same layout.
            weighted_feat = weighted_feat.unsqueeze(0)

        wf = weighted_feat.transpose(2, 1)  # [#. graph x #. clusters x feature dim]

        # compute group-wise compatibility scores
        if single_graph:
            num_ally_nodes = get_number_of_ally_nodes([graph])
        else:
            num_ally_nodes = get_number_of_ally_nodes(dgl.unbatch(graph))

        # Give every ally its own copy of its graph's cluster features.
        repeat_wf = torch.repeat_interleave(wf,
                                            torch.tensor(num_ally_nodes, device=node_feature.device),
                                            dim=0)  # [#. allies x #. clusters x feature dim]

        group_dot_prd = (ally_nf.unsqueeze(-1) * repeat_wf.transpose(2, 1)).sum(1)  # [#. allies x #. clusters]

        nf_norm_allies = torch.norm(ally_nf, dim=1, keepdim=True)  # [#. allies x 1]
        wf_norm = torch.norm(repeat_wf, dim=2)  # [#. allies x #. clusters]

        # Clamp the denominator: an all-zero ally feature or cluster feature
        # would otherwise produce 0 / 0 = NaN. With the clamp such entries
        # evaluate to 0 because their numerator is 0 as well.
        norm_product = (nf_norm_allies * wf_norm).clamp(min=1e-8)
        ally_normed_group_dot_prd = group_dot_prd / norm_product  # [#. allies x #. clusters]

        normed_group_dot_prd = torch.zeros(size=(node_feature.shape[0], ally_normed_group_dot_prd.shape[1]),
                                           device=node_feature.device)
        normed_group_dot_prd[ally_indices, :] = ally_normed_group_dot_prd

        return wf, ally_normed_group_dot_prd, normed_group_dot_prd
Пример #4
0
    def get_w(self, graph, node_feature):
        """Return the softmax-normalised cluster weights of the ally nodes.

        The output of ``self.w_net`` is restricted to the ally rows,
        optionally clamped to ``[VERY_SMALL_NUMBER, 10]`` when
        ``self.use_clipped_score`` is set, and then normalised with a
        softmax over dimension 1.

        Returns:
            Ally weights  # [#. allies x #. clusters]
        """
        raw_scores = self.w_net(graph, node_feature)
        ally_rows = get_filtered_node_index_by_type(graph, NODE_ALLY)
        ally_scores = raw_scores[ally_rows, :]  # [#. allies x #. clusters]

        if self.use_clipped_score:
            ally_scores = torch.clamp(ally_scores, min=VERY_SMALL_NUMBER, max=10)

        return torch.softmax(ally_scores, dim=1)
Пример #5
0
    def eval_1_episode(self):
        """Run one evaluation episode and summarise its outcome.

        The environment's ``winning_ratio`` is saved before the episode and
        written back afterwards, so the evaluation run leaves it unchanged
        even when the episode raises.

        Returns:
            dict with keys:
                ``'name'`` (str): ``self.env.name``.
                ``'win'`` (bool): True when the final state contains more
                    ally nodes than enemy nodes.
                ``'sum_reward'`` (float): sum of the rewards over the
                    trajectory.
        """
        running_wr = self.env.winning_ratio
        try:
            env_name = self.env.name
            traj = self.run_1_episode()
        finally:
            self.env.winning_ratio = running_wr

        last_graph = traj[-1].state
        # Compare explicit element counts instead of torch.Size objects,
        # whose ordering is only a lexicographic tuple comparison.
        num_allies = get_filtered_node_index_by_type(last_graph, NODE_ALLY).numel()
        num_enemies = get_filtered_node_index_by_type(last_graph, NODE_ENEMY).numel()

        sum_reward = np.sum([exp.reward for exp in traj._trajectory])

        return {
            'name': env_name,
            'win': num_allies > num_enemies,
            'sum_reward': sum_reward,
        }
Пример #6
0
    def forward(self,
                graph,
                node_feature,
                update_node_type_indices=(NODE_ALLY,)):
        """Apply the node function to the nodes of the selected types.

        Args:
            graph: graph to update; ``ndata`` is used as scratch space and
                the temporary fields are removed before returning.
            node_feature: node features exposed to the node function as
                ``ndata['node_feature']``.
            update_node_type_indices: iterable of node type indices whose
                nodes are updated. The default is an immutable tuple so the
                default object cannot be modified between calls.

        Returns:
            The ``'updated_node_feature'`` field taken from ``graph.ndata``
            after the node function has been applied.
        """
        graph.ndata['node_feature'] = node_feature
        for ntype_idx in update_node_type_indices:
            node_index = get_filtered_node_index_by_type(graph, ntype_idx)
            # Bind the node type so the node function knows which one it
            # is being applied to.
            apply_func = partial(self.apply_node_function, ntype_idx=ntype_idx)
            graph.apply_nodes(func=apply_func, v=node_index)

        # Remove the scratch input and hand back the produced field.
        graph.ndata.pop('node_feature')
        return graph.ndata.pop('updated_node_feature')
    def forward(self, graph, node_feature):
        """Run one round of message passing followed by per-type node updates.

        Args:
            graph: graph to operate on; ``ndata`` is used as scratch space
                and the temporary fields are removed before returning.
            node_feature: input node features.

        Returns:
            The ``'updated_node_feature'`` field produced by the node
            updaters, with ``node_feature`` added when
            ``self.use_residual`` is set.
        """
        # Optionally append the initial node features to the input.
        if self.use_concat:
            layer_input = torch.cat([node_feature, graph.ndata['init_node_feature']], dim=1)
        else:
            layer_input = node_feature
        graph.ndata['node_feature'] = layer_input

        # Message passing over every edge of the graph.
        graph.send_and_recv(graph.edges(),
                            message_func=self.message_function,
                            reduce_func=self.reduce_function)

        # Each node type has its own updater, looked up by name.
        for ntype_idx in self.node_types:
            typed_nodes = get_filtered_node_index_by_type(graph, ntype_idx)
            updater = self.node_updater['updater{}'.format(ntype_idx)]
            graph.apply_nodes(
                partial(self.apply_node_function_multi_type, updater=updater),
                v=typed_nodes)

        # Collect the result and remove the scratch fields.
        result = graph.ndata.pop('updated_node_feature')
        graph.ndata.pop('aggregated_node_feature')
        graph.ndata.pop('node_feature')

        if not self.use_residual:
            return result
        return result + node_feature