Example #1
0
    def forward(self, graph, node_feature, update_node_type_indices,
                update_edge_type_indices):
        """Update node features, optionally after one round of message passing.

        ``node_feature`` is stored under ``graph.ndata['node_feature']`` for
        the duration of the call. When ``self.neighbor_degree`` is 0 only the
        per-type "no neighbor" node function runs; otherwise messages are
        sent and reduced over each requested edge type first and the
        "yes neighbor" node function runs afterwards.

        :param graph: graph whose ``ndata`` is used as scratch storage
        :param node_feature: value assigned to ``graph.ndata['node_feature']``
        :param update_node_type_indices: node types whose nodes are updated
        :param update_edge_type_indices: edge types used for message passing
            (not used when ``self.neighbor_degree`` is 0)
        :return: the ``'node_feature'`` entry popped from ``graph.ndata``
        """
        graph.ndata['node_feature'] = node_feature

        if self.neighbor_degree == 0:
            # Own features only: no message passing step.
            node_function = self.apply_node_function_no_neighbor
        else:
            # Aggregate 1-hop neighbor messages, one edge type at a time.
            for etype_idx in update_edge_type_indices:
                edges_of_type = get_filtered_edge_index_by_type(
                    graph, etype_idx)
                graph.send_and_recv(edges_of_type,
                                    message_func=self.message_function,
                                    reduce_func=self.reduce_function)
            node_function = self.apply_node_function_yes_neighbor

        for ntype_idx in update_node_type_indices:
            nodes_of_type = get_filtered_node_index_by_type(graph, ntype_idx)
            graph.apply_nodes(func=partial(node_function,
                                           ntype_idx=ntype_idx),
                              v=nodes_of_type)

        result = graph.ndata.pop('node_feature')
        if self.neighbor_degree >= 1:
            # Drop the scratch field left behind by message passing
            # (presumably written by the reduce function -- TODO confirm).
            graph.ndata.pop('aggregated_message')
        return result
Example #2
0
    def forward(self, graph, node_feature, sub_q_tots):
        """Combine per-group ``sub_q_tots`` into one total value per graph.

        For group ``i`` the features of the nodes returned by
        ``get_filtered_node_index_by_assignment(graph, i)`` are summed per
        graph, passed through ``self.w`` and rectified (``'abs'`` or
        ``'softplus'``); the result weights ``sub_q_tots[i]``. A term obtained
        from ``self.v`` on the summed ally node features is added at the end.

        Side effects: prints the group sizes, their ratios and the mean of
        ``graph.ndata['normalized_score']`` over ally nodes. The fields
        ``'node_feature'`` and ``'masked_node_feature'`` are written to
        ``graph.ndata`` temporarily and removed again.

        :param graph: batched graph exposing ``batch_size``; its ``ndata``
            must contain ``'normalized_score'``
        :param node_feature: 2-D tensor with one row per node
        :param sub_q_tots: iterable with one entry per group
        :return: the mixed total value, ``q_tot``
        :raises RuntimeError: if ``self.rectifier`` is neither ``'abs'`` nor
            ``'softplus'`` (only reached when ``sub_q_tots`` is non-empty)
        """
        graph.ndata['node_feature'] = node_feature
        device = node_feature.device

        group_sizes = []
        q_tot = torch.zeros(graph.batch_size, device=device)
        for group_idx, sub_q_tot in enumerate(sub_q_tots):
            members = get_filtered_node_index_by_assignment(graph, group_idx)
            group_sizes.append(len(members))

            # Zero out the features of every node outside this group.
            member_mask = torch.zeros(size=(node_feature.shape[0], 1),
                                      device=device)
            member_mask[members, :] = 1
            graph.ndata['masked_node_feature'] = (
                graph.ndata['node_feature'] * member_mask)
            group_readout = dgl.sum_nodes(graph, 'masked_node_feature')

            if self.rectifier == 'abs':
                rectify = torch.abs
            elif self.rectifier == 'softplus':
                rectify = F.softplus
            else:
                raise RuntimeError("Not implemented rectifier")
            group_weight = rectify(self.w(group_readout)).view(-1)
            q_tot = q_tot + group_weight * sub_q_tot

            graph.ndata.pop('masked_node_feature')

        # testing: replace the stored features with an ally-only copy
        graph.ndata.pop('node_feature')
        ally_indices = get_filtered_node_index_by_type(graph, NODE_ALLY)
        ally_only_feature = torch.zeros(
            size=(graph.number_of_nodes(), node_feature.shape[1]),
            device=device)
        ally_only_feature[ally_indices, :] = node_feature[ally_indices, :]
        graph.ndata['node_feature'] = ally_only_feature
        # testing

        state_value = self.v(dgl.sum_nodes(graph, 'node_feature')).view(-1)
        q_tot = q_tot + state_value
        graph.ndata.pop('node_feature')

        # Diagnostics only; none of this feeds back into q_tot.
        group_sizes = np.array(group_sizes)
        group_ratios = group_sizes / np.sum(group_sizes)
        print("Num elements in groups {}".format(group_sizes))
        print("Num elements ratio {}".format(group_ratios))

        ally_indices = get_filtered_node_index_by_type(graph, NODE_ALLY)
        ally_scores = graph.ndata['normalized_score'][ally_indices]
        print("Average normalized scores {}".format(ally_scores.mean(0)))

        return q_tot
Example #3
0
    def forward(self, graph, node_feature, qs, ally_node_type_index=NODE_ALLY):
        """Mix per-ally values ``qs`` into one total value per graph.

        Each ally value is scaled by ``abs(self.w_ff(self.w_gn(...)))`` for
        that ally, summed per graph, and the per-graph sum of
        ``self.v_ff(self.v_gn(...))`` over the same nodes is added.

        :param graph: must be a ``dgl.BatchedDGLGraph``
        :param node_feature: node features fed to ``self.w_gn``/``self.v_gn``
        :param qs: one value per node of type ``ally_node_type_index``
            (flattened with ``view(-1, 1)``)
        :param ally_node_type_index: node type whose nodes contribute
        :return: flattened tensor of the per-graph sums
        """
        assert isinstance(graph, dgl.BatchedDGLGraph)

        weight_emb = self.w_gn(graph, node_feature)  # [# nodes x # node_dim]
        weight = torch.abs(self.w_ff(graph, weight_emb))  # [# nodes x 1]
        ally_ids = get_filtered_node_index_by_type(graph,
                                                   ally_node_type_index)
        device = weight_emb.device

        def _sum_over_selected(selected_values):
            # Scatter the selected rows into an all-zero per-node column so
            # that the per-graph sum ignores every other node.
            column = torch.zeros(size=(graph.number_of_nodes(), 1),
                                 device=device)
            column[ally_ids, :] = selected_values
            graph.ndata['node_feature'] = column
            per_graph = dgl.sum_nodes(graph, 'node_feature')
            graph.ndata.pop('node_feature')
            return per_graph

        weighted_q = _sum_over_selected(weight[ally_ids, :] * qs.view(-1, 1))

        bias_emb = self.v_gn(graph, node_feature)  # [# nodes x # node_dim]
        bias = self.v_ff(graph, bias_emb)  # [# nodes x 1]
        summed_bias = _sum_over_selected(bias[ally_ids, :])

        return (weighted_q + summed_bias).view(-1)
Example #4
0
    def compute_probs(self,
                      graph,
                      node_feature,
                      maximum_num_enemy,
                      ally_node_type_index=NODE_ALLY,
                      attack_edge_type_index=EDGE_IN_ATTACK_RANGE):
        """Compute per-ally action probabilities and related quantities.

        Calls the module itself to obtain move/hold/attack logits and a
        high-level probability, keeps only the rows of ally nodes, applies a
        softmax over the concatenated logits and multiplies each action group
        (move, hold, attack) by its high-level probability.

        :param graph: graph whose ``ndata`` contains ``'tag'`` and optionally
            ``'enemy_tag'``
        :param node_feature: node features forwarded to ``self(...)``
        :param maximum_num_enemy: forwarded to ``self(...)``; also used as
            the width of the attack slice when splitting ``log_ps``
        :param ally_node_type_index: node type treated as "ally"
        :param attack_edge_type_index: forwarded to ``self(...)``
        :return: dict with keys ``'unnormed_ps'``, ``'probs'``, ``'log_ps'``,
            ``'log_p_move'``, ``'log_p_hold'``, ``'log_p_attack'``,
            ``'ally_entropy'``, ``'ally_tags'`` and ``'enemy_tags'``
        """
        # get logits of each action
        move_arg, hold_arg, attack_arg, high_level_prob = self(
            graph, node_feature, maximum_num_enemy, attack_edge_type_index)

        device = move_arg.device

        # Prepare un-normalized probability of attacks
        unnormed_ps = torch.cat((move_arg, hold_arg, attack_arg),
                                dim=-1)  # of all units including enemies

        ally_node_indices = get_filtered_node_index_by_type(
            graph, ally_node_type_index)
        unnormed_ps = unnormed_ps[ally_node_indices, :]  # of only ally units

        node_tags = graph.ndata['tag']
        ally_tags = node_tags[ally_node_indices]
        if 'enemy_tag' in graph.ndata.keys():
            enemy_tags = graph.ndata['enemy_tag']
        else:
            # Dummy placeholder. It is built from the tags of ALL nodes so
            # that the node-level ally indices below stay in range; building
            # it from the ally-filtered tags breaks whenever an ally node id
            # exceeds the number of allies.
            enemy_tags = torch.zeros_like(node_tags,
                                          device=device).view(-1, 1)

        enemy_tags = enemy_tags[ally_node_indices, :]

        ps = torch.nn.functional.softmax(unnormed_ps, dim=-1)
        high_level_prob = high_level_prob[
            ally_node_indices, :]  # [#. ally units x 3]

        # Expand the 3 high-level probabilities so that each low-level action
        # column is paired with the probability of its group.
        attack_dim = ps.shape[-1] - self.move_dim - 1
        num_reps = torch.tensor([self.move_dim, 1, attack_dim],
                                dtype=torch.long,
                                device=device)
        high_level_prob = high_level_prob.repeat_interleave(num_reps, dim=1)
        ps = high_level_prob * ps

        # The small constant keeps log() finite for zero probabilities.
        log_ps = torch.log(ps + VERY_SMALL_NUMBER)
        ally_entropy = -torch.sum(log_ps * ps, dim=-1)  # per unit entropy
        log_p_move, log_p_hold, log_p_attack = torch.split(
            log_ps, [self.move_dim, 1, maximum_num_enemy], dim=1)

        return_dict = dict()
        # for RL training
        return_dict['unnormed_ps'] = unnormed_ps
        return_dict['probs'] = ps
        return_dict['log_ps'] = log_ps
        return_dict['log_p_move'] = log_p_move
        return_dict['log_p_hold'] = log_p_hold
        return_dict['log_p_attack'] = log_p_attack
        return_dict['ally_entropy'] = ally_entropy
        # for SC2 interfacing
        return_dict['ally_tags'] = ally_tags
        return_dict['enemy_tags'] = enemy_tags
        return return_dict
Example #5
0
 def get_q(self, graph, node_feature, maximum_num_enemy, critic=None):
     """Return the critic outputs of ally nodes, concatenated per node.

     :param graph: graph passed to the critic
     :param node_feature: node features passed to the critic
     :param maximum_num_enemy: passed through to the critic
     :param critic: callable returning three tensors; defaults to
         ``self.critic`` when ``None``
     :return: the three critic outputs concatenated along the last
         dimension, restricted to the rows of ``NODE_ALLY`` nodes
     """
     q_network = self.critic if critic is None else critic
     ally_rows = get_filtered_node_index_by_type(graph, NODE_ALLY)
     q_move, q_hold, q_attack = q_network(graph, node_feature,
                                          maximum_num_enemy)
     q_all_nodes = torch.cat([q_move, q_hold, q_attack], dim=-1)
     return q_all_nodes[ally_rows, :]
Example #6
0
    def forward(self, graph, node_feature, update_node_type_indices):
        """Apply the per-type node function and return the updated features.

        :param graph: graph whose ``ndata`` is used as scratch storage
        :param node_feature: value assigned to ``graph.ndata['node_feature']``
        :param update_node_type_indices: node types whose nodes are updated
        :return: the ``'updated_node_feature'`` entry popped from
            ``graph.ndata``
        """
        graph.ndata['node_feature'] = node_feature

        for ntype_idx in update_node_type_indices:
            nodes_of_type = get_filtered_node_index_by_type(graph, ntype_idx)
            graph.apply_nodes(
                func=partial(self.apply_node_function, ntype_idx=ntype_idx),
                v=nodes_of_type)

        # Remove both scratch fields so the graph is left as it was found.
        graph.ndata.pop('node_feature')
        return graph.ndata.pop('updated_node_feature')
Example #7
0
    def eval_1_episode(self):
        """Run one episode and summarise it without touching the win ratio.

        ``self.env.winning_ratio`` is saved before the episode and written
        back afterwards, so evaluation does not alter it.

        :return: dict ``{'name': env name, 'win': bool, 'sum_reward': sum of
            the rewards in the trajectory}``; ``'win'`` is True when the last
            state holds more ally nodes than enemy nodes
        """
        running_wr = self.env.winning_ratio

        env_name = self.env.name
        traj = self.run_1_episode()

        last_graph = traj[-1].state
        # Compare node COUNTS. ``.size()`` would yield torch.Size objects,
        # whose comparison only works through tuple ordering.
        num_allies = len(get_filtered_node_index_by_type(last_graph,
                                                         NODE_ALLY))
        num_enemies = len(get_filtered_node_index_by_type(last_graph,
                                                          NODE_ENEMY))

        sum_reward = np.sum([exp.reward for exp in traj._trajectory])

        eval_dict = dict()
        eval_dict['name'] = env_name
        eval_dict['win'] = num_allies > num_enemies
        eval_dict['sum_reward'] = sum_reward

        # Restore the ratio captured before the evaluation episode.
        self.env.winning_ratio = running_wr
        return eval_dict
Example #8
0
    def forward(self, graph, node_feature, update_node_type_indices,
                update_edge_type_indices):
        """Run one message-passing step followed by a per-type node update.

        :param graph: graph used for message passing; when
            ``self.use_concat`` is set its ``ndata`` must contain
            ``'init_node_feature'``
        :param node_feature: current node features
        :param update_node_type_indices: node types whose nodes are updated
        :param update_edge_type_indices: forwarded to the message and reduce
            functions
        :return: the ``'updated_node_feature'`` entry popped from
            ``graph.ndata``
        """
        if self.use_concat:
            # Append the initial features column-wise to the current ones.
            node_input = torch.cat(
                [node_feature, graph.ndata['init_node_feature']], dim=1)
        else:
            node_input = node_feature
        graph.ndata['node_feature'] = node_input

        graph.send_and_recv(
            graph.edges(),
            message_func=partial(
                self.message_function,
                update_edge_type_indices=update_edge_type_indices),
            reduce_func=partial(
                self.reduce_function,
                update_edge_type_indices=update_edge_type_indices))

        shared_updater = not self.use_multi_node_types  # default behavior
        for ntype_idx in update_node_type_indices:
            nodes_of_type = get_filtered_node_index_by_type(graph, ntype_idx)
            if shared_updater:
                node_func = self.apply_node_function
            else:  # testing: one dedicated updater per node type
                type_updater = self.node_updater['updater{}'.format(
                    ntype_idx)]
                node_func = partial(self.apply_node_function_multi_type,
                                    updater=type_updater)
            graph.apply_nodes(node_func, v=nodes_of_type)

        result = graph.ndata.pop('updated_node_feature')
        graph.ndata.pop('aggregated_node_feature')
        graph.ndata.pop('node_feature')
        return result
Example #9
0
    def forward(self, graph, node_feature):
        """Compute assignment outputs and mask out non-ally nodes with -1.

        ``self.apply_node_function`` is applied to all nodes and is expected
        to leave ``'prob'``, ``'assignment'`` and ``'normalized_score'`` in
        ``graph.ndata``; all three are popped here. ``'node_feature'`` is
        left on the graph.

        :param graph: graph whose ``ndata`` is used as scratch storage
        :param node_feature: value assigned to ``graph.ndata['node_feature']``
        :return: ``(prob, assignment, normalized_score)``; in the last two
            every entry of a node that is not ``NODE_ALLY`` is set to -1
        """
        device = node_feature.device

        graph.ndata['node_feature'] = node_feature
        graph.apply_nodes(func=self.apply_node_function)
        prob = graph.ndata.pop('prob')
        ally_indices = get_filtered_node_index_by_type(graph, NODE_ALLY)

        def _keep_allies_only(values):
            # Start from all -1 and copy back only the ally rows.
            masked = torch.ones_like(values, device=device) * -1
            masked[ally_indices] = values[ally_indices]
            return masked

        assignment = _keep_allies_only(graph.ndata.pop('assignment'))
        normalized_score = _keep_allies_only(
            graph.ndata.pop('normalized_score'))

        return prob, assignment, normalized_score
Example #10
0
    def forward(self, graph, node_feature, update_node_type_indices, update_edge_type_indices):
        """Run one message-passing step and update the selected node types.

        :param graph: structure only graph; its ``ndata`` is used as scratch
            storage during the call
        :param node_feature: value assigned to ``graph.ndata['node_feature']``
        :param update_node_type_indices: node types whose nodes are updated
        :param update_edge_type_indices: forwarded to the message function
        :return: the ``'node_feature'`` entry popped from ``graph.ndata``
        """
        graph.ndata['node_feature'] = node_feature

        graph.send_and_recv(
            graph.edges(),
            message_func=partial(
                self.message_function,
                update_edge_type_indices=update_edge_type_indices),
            reduce_func=self.reduce_function)

        for ntype_idx in update_node_type_indices:
            nodes_of_type = get_filtered_node_index_by_type(graph, ntype_idx)
            graph.apply_nodes(self.apply_node_function, v=nodes_of_type)

        result = graph.ndata.pop('node_feature')
        # 'z' is a scratch field left by the update (presumably the reduce
        # output -- TODO confirm); remove it before returning.
        graph.ndata.pop('z')
        return result
Example #11
0
    def forward(self, graph, node_feature, qs, ally_node_type_index=NODE_ALLY):
        """Mix the values of allies assigned to ``self.target_assignment``.

        Only ally nodes whose ``graph.ndata['assignment']`` equals
        ``self.target_assignment`` contribute. Their ``qs`` entries are
        scaled by a rectified weight from ``self.w_gn``/``self.w_ff``, summed
        per graph, and the per-graph sum of ``self.v_gn``/``self.v_ff`` over
        the same nodes is added.

        NOTE(review): ``ally_node_type_index`` is accepted but not used;
        ``NODE_ALLY`` is hard-coded below.

        :param graph: must be a ``dgl.BatchedDGLGraph`` carrying
            ``'assignment'`` in ``ndata``
        :param node_feature: node features fed to ``self.w_gn``/``self.v_gn``
        :param qs: values indexed by position within the ally subset
        :param ally_node_type_index: currently unused (see note above)
        :return: flattened tensor of the per-graph sums
        :raises RuntimeError: if ``self.rectifier`` is neither ``'abs'`` nor
            ``'softplus'``
        """
        assert isinstance(graph, dgl.BatchedDGLGraph)

        w_emb = self.w_gn(graph, node_feature)  # [# nodes x # node_dim]
        if self.rectifier == 'abs':
            w = torch.abs(self.w_ff(graph, w_emb))  # [# nodes x # 1]
        elif self.rectifier == 'softplus':
            w = F.softplus(self.w_ff(graph, w_emb))  # [# nodes x # 1]
        else:
            raise RuntimeError("Not supported rectifier")

        # Curee's trick
        ally_indices = get_filtered_node_index_by_type(graph, NODE_ALLY)
        allies_assignment = graph.ndata['assignment'][ally_indices]
        target_allies = allies_assignment == self.target_assignment

        # Two index spaces are involved and must not be mixed:
        #  * ally_positions: position of each target ally within the ally
        #    subset; this is what ``qs`` is indexed by.
        #  * target_indices: node ids in the graph; this is what the per-node
        #    tensors (w, v, _qs, _v) are indexed by.
        # arange is created on the mask's device so the boolean indexing
        # does not mix devices.
        ally_positions = torch.arange(
            target_allies.size(0),
            device=target_allies.device)[target_allies]
        target_indices = ally_indices[ally_positions.to(ally_indices.device)]

        device = w_emb.device

        _qs = torch.zeros(size=(graph.number_of_nodes(), 1), device=device)
        w = w[target_indices, :]  # [# assignments x 1]
        _qs[target_indices, :] = w * qs[ally_positions].view(-1, 1)
        graph.ndata['node_feature'] = _qs
        q_tot = dgl.sum_nodes(graph, 'node_feature')
        _ = graph.ndata.pop('node_feature')

        v_emb = self.v_gn(graph, node_feature)  # [# nodes x # node_dim]
        v = self.v_ff(graph, v_emb)  # [# nodes x # 1]
        v = v[target_indices, :]  # [# assignments x 1]
        _v = torch.zeros(size=(graph.number_of_nodes(), 1), device=device)
        _v[target_indices, :] = v

        graph.ndata['node_feature'] = _v
        v = dgl.sum_nodes(graph, 'node_feature')
        _ = graph.ndata.pop('node_feature')

        q_tot = q_tot + v
        return q_tot.view(-1)
Example #12
0
    def forward(self, graph, node_feature, qs, ally_node_type_index=NODE_ALLY):
        """Mix per-ally ``qs`` using the normalized score as the weight.

        Each ally value is multiplied by column ``self.target_assignment`` of
        ``graph.ndata['normalized_score']`` for that ally and summed per
        graph; the per-graph sum of ``self.v_ff(self.v_gn(...))`` over ally
        nodes is added.

        NOTE(review): ``ally_node_type_index`` is accepted but not used;
        ``NODE_ALLY`` is hard-coded below.

        :param graph: must be a ``dgl.BatchedDGLGraph`` carrying
            ``'normalized_score'`` in ``ndata``
        :param node_feature: node features fed to ``self.v_gn``
        :param qs: one value per ally node (flattened with ``view(-1, 1)``)
        :param ally_node_type_index: currently unused (see note above)
        :return: flattened tensor of the per-graph sums
        """
        assert isinstance(graph, dgl.BatchedDGLGraph)

        # Curee's trick
        ally_indices = get_filtered_node_index_by_type(graph, NODE_ALLY)
        ally_scores = graph.ndata['normalized_score'][ally_indices]
        target_assignment_weight = ally_scores[:, self.target_assignment]

        device = node_feature.device

        def _sum_over_allies(ally_values):
            # Scatter ally rows into an all-zero per-node column so that the
            # per-graph sum ignores every other node.
            column = torch.zeros(size=(graph.number_of_nodes(), 1),
                                 device=device)
            column[ally_indices, :] = ally_values
            graph.ndata['node_feature'] = column
            per_graph = dgl.sum_nodes(graph, 'node_feature')
            graph.ndata.pop('node_feature')
            return per_graph

        weighted_q = _sum_over_allies(
            qs.view(-1, 1) * target_assignment_weight.view(-1, 1))

        bias_emb = self.v_gn(graph, node_feature)  # [# nodes x # node_dim]
        bias = self.v_ff(graph, bias_emb)  # [# nodes x 1]
        summed_bias = _sum_over_allies(bias[ally_indices, :])

        return (weighted_q + summed_bias).view(-1)
Example #13
0
    def compute_qs(self,
                   graph,
                   node_feature,
                   maximum_num_enemy,
                   ally_node_type_index=NODE_ALLY,
                   attack_edge_type_index=None):
        """Compute per-ally action values plus tag bookkeeping.

        Calls the module itself to obtain move/hold/attack values together
        with ``assignment`` and ``normalized_score``, concatenates the three
        value tensors and keeps only the rows of ally nodes.

        :param graph: graph whose ``ndata`` contains ``'tag'`` and optionally
            ``'enemy_tag'``
        :param node_feature: node features forwarded to ``self(...)``
        :param maximum_num_enemy: forwarded to ``self(...)``
        :param ally_node_type_index: node type treated as "ally"
        :param attack_edge_type_index: forwarded to ``self(...)``; falls back
            to ``self.attack_edge_type_index`` when ``None``
        :return: dict with keys ``'qs'``, ``'assignment'``,
            ``'normalized_score'``, ``'ally_tags'`` and ``'enemy_tags'``
        """
        if attack_edge_type_index is None:
            attack_edge_type_index = self.attack_edge_type_index

        move_arg, hold_arg, attack_arg, assignment, normalized_score = self(
            graph, node_feature, maximum_num_enemy, attack_edge_type_index)
        qs = torch.cat((move_arg, hold_arg, attack_arg),
                       dim=-1)  # of all units including enemies

        ally_node_indices = get_filtered_node_index_by_type(
            graph, ally_node_type_index)
        qs = qs[ally_node_indices, :]  # of only ally units

        node_tags = graph.ndata['tag']
        ally_tags = node_tags[ally_node_indices]
        if 'enemy_tag' in graph.ndata.keys():
            enemy_tags = graph.ndata['enemy_tag']
        else:
            # Dummy placeholder. It is built from the tags of ALL nodes so
            # that the node-level ally indices below stay in range; building
            # it from the ally-filtered tags breaks whenever an ally node id
            # exceeds the number of allies.
            enemy_tags = torch.zeros_like(node_tags).view(-1, 1)

        enemy_tags = enemy_tags[ally_node_indices, :]

        return_dict = dict()
        # for RL training
        return_dict['qs'] = qs
        return_dict['assignment'] = assignment
        return_dict['normalized_score'] = normalized_score

        # for SC2 interfacing
        return_dict['ally_tags'] = ally_tags
        return_dict['enemy_tags'] = enemy_tags
        return return_dict
Example #14
0
    def fit(self, device='cpu'):
        """Sample a batch from the buffer and run one ``self.brain.fit`` step.

        Naming: the prefix ``c`` marks *current* time-stamp inputs, the
        prefix ``n`` marks *next* time-stamp inputs.

        Expected sample layout (bs = batch_size, nt = hist_num_time_steps):
          * history graphs: list of graph lists
            ``[[g_(0,0), ..., g_(0,nt)], ..., [g_(bs,0), ..., g_(bs,nt)]]``
          * current graphs: list of graphs ``[g_(0), ..., g_(bs)]``

        Rewards and dones are given per sample and repeated once per ally
        node of the corresponding current graph.

        NOTE(review): any ``device`` other than ``'cpu'`` moves data to
        ``torch.device('cuda')``, and the return value of ``graph.to`` is
        discarded, which relies on that call working in place -- confirm
        against the installed DGL version.

        :param device: ``'cpu'`` keeps everything where it is; anything else
            triggers the move to CUDA described above
        :return: whatever ``self.brain.fit`` returns
        """
        fit_conf = self.conf.fit_conf
        batch_size = fit_conf['batch_size']
        hist_num_time_steps = fit_conf['hist_num_time_steps']

        (c_h_graph, c_graph, actions, rewards,
         n_h_graph, n_graph, dones) = self.buffer.sample(batch_size)

        c_maximum_num_enemy = get_largest_number_of_enemy_nodes(c_graph)
        n_maximum_num_enemy = get_largest_number_of_enemy_nodes(n_graph)

        # Merge the per-sample action tensors into one long tensor.
        actions = torch.cat(actions).long()

        # Number of ally nodes in each current graph.
        ally_counts = torch.Tensor([
            len(get_filtered_node_index_by_type(sample_graph, NODE_ALLY))
            for sample_graph in c_graph
        ]).long()

        # One reward / done flag per ally instead of one per sample.
        rewards = torch.Tensor(rewards).repeat_interleave(ally_counts, dim=0)
        dones = torch.Tensor(dones).repeat_interleave(ally_counts, dim=0)

        # Flatten the nested history lists before batching.
        c_hist_graph = dgl.batch(
            [g for history in c_h_graph for g in history])
        n_hist_graph = dgl.batch(
            [g for history in n_h_graph for g in history])
        c_curr_graph = dgl.batch(c_graph)
        n_curr_graph = dgl.batch(n_graph)

        if device != 'cpu':
            cuda = torch.device('cuda')
            for batched_graph in (c_hist_graph, n_hist_graph,
                                  c_curr_graph, n_curr_graph):
                batched_graph.to(cuda)
            actions = actions.to(cuda)
            rewards = rewards.to(cuda)
            dones = dones.to(cuda)

        # Detach the features from the graphs; they are passed separately.
        c_hist_feature = c_hist_graph.ndata.pop('node_feature')
        c_curr_feature = c_curr_graph.ndata.pop('node_feature')
        n_hist_feature = n_hist_graph.ndata.pop('node_feature')
        n_curr_feature = n_curr_graph.ndata.pop('node_feature')

        brain_inputs = dict(
            num_time_steps=hist_num_time_steps,
            c_hist_graph=c_hist_graph,
            c_hist_feature=c_hist_feature,
            c_curr_graph=c_curr_graph,
            c_curr_feature=c_curr_feature,
            c_maximum_num_enemy=c_maximum_num_enemy,
            n_hist_graph=n_hist_graph,
            n_hist_feature=n_hist_feature,
            n_curr_graph=n_curr_graph,
            n_curr_feature=n_curr_feature,
            n_maximum_num_enemy=n_maximum_num_enemy,
            actions=actions,
            rewards=rewards,
            dones=dones)
        return self.brain.fit(**brain_inputs)