Example #1
    def __init__(self, num_cmd_type, num_target_type, target_feat_dim,
                 attribute_dim):
        """
        num_cmd_type: num of cmd types
        cmd_type_emb_dim: cmd emb output dim
        target_type_emb: emb function for target unit type, shared with UnitEmbedding
        target_type_emb_dim: unit type emb output dim
        xy_feat_dim: output dim for xy-net
        target_feat_dim: feat dim of enemy/resource encoder

        """
        super().__init__()

        self.num_cmd_type = num_cmd_type
        self.num_target_type = num_target_type
        self.target_feat_dim = target_feat_dim

        self.attribute_dim = attribute_dim
        self.out_dim = 4 * attribute_dim

        self.cmd_type_emb = nn.Embedding(num_cmd_type, attribute_dim)
        self.target_type_emb = nn.Embedding(num_target_type, attribute_dim)
        self.xy_net = nn.Sequential(
            weight_norm(nn.Linear(2, attribute_dim), dim=None), )
        self.enemy_feat_net = nn.Sequential(
            weight_norm(nn.Linear(target_feat_dim, attribute_dim), dim=None), )
        self.resource_feat_net = nn.Sequential(
            weight_norm(nn.Linear(target_feat_dim, attribute_dim), dim=None), )

        # compute locations in output feature tensor
        self.cmd_type_emb_end = attribute_dim
        self.target_type_emb_end = 2 * attribute_dim
        self.xy_feat_end = 3 * attribute_dim
        self.target_feat_end = 4 * attribute_dim
        assert_eq(self.target_feat_end, self.out_dim)
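Every example on this page leans on an assert_eq helper that the page never shows; a minimal sketch, consistent with how it is called (the real helper may differ):

def assert_eq(real, expected):
    assert real == expected, '%s (real) vs %s (expected)' % (real, expected)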
Example #2
    def compute_eval_loss(self, batch):
        batch = self._format_language_input_with_candidate(batch)
        glob_feat = self._forward(batch)

        cont = 1 - batch['is_base_frame']
        cont_loss = self.cont_cls.compute_loss(glob_feat, cont)

        lang_logp = self.inst_selector.compute_prob(
            batch['inst_input'],
            batch['inst'],
            glob_feat,
            log=True
        )
        lang_loss = -lang_logp.gather(1, batch['inst_idx'].unsqueeze(1)).squeeze(1)

        assert_eq(cont_loss.size(), lang_loss.size())
        lang_loss = (1 - cont.float()) * lang_loss
        loss = cont_loss + lang_loss
        loss = loss.mean()
        all_loss = {
            'loss': loss,
            'cont_loss': cont_loss.mean(),
            'lang_loss': lang_loss.mean()
        }
        return loss, all_loss
Example #3
    def compute_loss(self,
                     ufeat,
                     globfeat,
                     target_type,
                     mask,
                     *,
                     include_nil=False):
        """loss, averaged by sum(mask)

        ufeat: [batch, padded_num_unit, ufeat_dim]
        globfeat: [batch, padded_num_unit, globfeat_dim]
        target_type: [batch, padded_num_unit]
          target_type[i, j] is a real unit_type iff unit_{i,j} is building a unit
        mask: [batch, padded_num_unit]
          mask[i, j] = 1 iff unit_{i,j} is a real unit and its cmd_type == BUILD_UNIT
        """
        batch, pnum_unit, _ = ufeat.size()
        assert_eq(target_type.size(), (batch, pnum_unit))
        assert_eq(mask.size(), (batch, pnum_unit))

        logit = self.forward(ufeat, globfeat)
        # logit [batch, pnum_unit, num_unit_type]
        logp = logit2logp(logit, target_type)
        # logp [batch, pnum_unit]
        loss = -(logp * mask).sum(1)
        if not include_nil:
            return loss

        nil_type = torch.zeros_like(target_type)
        nil_logp = logit2logp(logit, nil_type)
        return loss, nil_logp
Example #4
    def _get_human_instruction(self, batch):
        assert_eq(batch['prev_inst'].size(0), 1)
        device = batch['prev_inst'].device

        inst = input('Please input your instruction\n')

        inst_idx = torch.zeros((1, )).long().to(device)
        inst_idx[0] = self.executor.inst_dict.get_inst_idx(inst)
        inst_cont = torch.zeros((1, )).long().to(device)
        if len(inst) == 0:
            # empty input: keep the previous instruction and mark continue
            inst = self.prev_inst
            inst_cont[0] = 1

        self.prev_inst = inst
        raw_inst = convert_to_raw_instruction(inst, self.max_raw_chars)
        inst, inst_len = self.executor.inst_dict.parse(inst, True)
        inst = torch.LongTensor(inst).unsqueeze(0).to(device)
        inst_len = torch.LongTensor([inst_len]).to(device)
        raw_inst = torch.LongTensor([raw_inst]).to(device)

        reply = {
            'inst': inst_idx.unsqueeze(1),
            'inst_pi':
            torch.ones(1, self.num_insts).to(device) / self.num_insts,
            'cont': inst_cont.unsqueeze(1),
            'cont_pi': torch.ones(1, 2).to(device) / 2,
            'raw_inst': raw_inst
        }

        return inst, inst_len, inst_cont, reply
Example #5
    def _forward1d(self, inst, inst_len, context):
        assert_eq(inst.size(0), context.size(0))
        inst_feat = self.inst_proj(self.encoder(inst, inst_len))
        logit = (inst_feat * context).sum(1)
        return logit
Example #6
    def sample(self, cont_probs, probs, prev_samples):
        """Categorical sampler that can persist the previous samples
        with some probability.

        Args:
            cont_probs: probabilities of keeping the previous samples
                    [batch, 2]
            probs: probabilities returned by model forward
                    [batch, num_actions]
            prev_samples: previously sampled actions
                    [batch]
        return:
            cont_samples: [batch]
            samples: [batch]
        """
        assert_eq(cont_probs.size(1), 2)
        cont_samples = cont_probs.multinomial(1).squeeze(1)
        new_samples = probs.multinomial(1).squeeze(1)

        assert_eq(prev_samples.size(), new_samples.size())
        # cont_samples == 1 keeps the previous sample, 0 takes the new one
        samples = cont_samples * prev_samples + (1 - cont_samples) * new_samples
        return {
            self.cont_key: cont_samples,
            self.key: samples,
        }
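A toy run of the persistence arithmetic above, with hypothetical values; rows where cont_samples == 1 keep the previous action, the rest take the fresh sample:

import torch

cont_probs = torch.tensor([[0.9, 0.1], [0.1, 0.9]])   # [batch=2, 2]
probs = torch.tensor([[0.5, 0.5], [0.2, 0.8]])        # [batch=2, num_actions=2]
prev_samples = torch.tensor([1, 0])

cont_samples = cont_probs.multinomial(1).squeeze(1)   # 0 or 1 per row
new_samples = probs.multinomial(1).squeeze(1)
samples = cont_samples * prev_samples + (1 - cont_samples) * new_samples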
Example #7
    def compute_loss(self, batch):
        """used for pre-training the model with dataset"""
        batch = self._format_supervised_language_input(batch)
        glob_feat = self._forward(batch)

        cont = 1 - batch["is_base_frame"]
        cont_loss = self.cont_cls.compute_loss(glob_feat, cont)
        lang_loss = self.inst_selector.compute_loss(
            batch["pos_cand_inst"],
            batch["pos_cand_inst_len"],
            batch["neg_cand_inst"],
            batch["neg_cand_inst_len"],
            batch["inst"],
            batch["inst_len"],
            glob_feat,
            batch["inst_idx"],
        )

        assert_eq(cont_loss.size(), lang_loss.size())
        lang_loss = (1 - cont.float()) * lang_loss
        loss = cont_loss + lang_loss
        loss = loss.mean()
        all_loss = {
            "loss": loss,
            "cont_loss": cont_loss.mean(),
            "lang_loss": lang_loss.mean(),
        }
        return loss, all_loss
Example #8
    def forward(self, ufeat, efeat, globfeat, num_enemy):
        """return masked prob that each real enemy is the target

        return: prob: [batch, pnum_unit, pnum_enemy]
        """
        batch, pnum_unit, _ = ufeat.size()
        pnum_enemy = efeat.size(1)
        assert_eq(num_enemy.size(), (batch, ))

        assert globfeat is None and self.globfeat_dim == 0
        infeat = ufeat
        # infeat [batch, pnum_unit, in_dim]
        proj = self.net(infeat)
        proj = proj.unsqueeze(2).repeat(1, 1, pnum_enemy, 1)
        # proj [batch, pnum_unit, pnum_enemy, efeat_dim]
        efeat = efeat.unsqueeze(1).repeat(1, pnum_unit, 1, 1)
        # efeat [batch, pnum_unit, pnum_enemy, efeat_dim]
        logit = (proj * efeat).sum(3) / self.norm
        # logit [batch, pnum_unit, pnum_enemy]
        enemy_mask = create_real_unit_mask(num_enemy, pnum_enemy)
        # enemy_mask [batch, pnum_enemy]
        enemy_mask = enemy_mask.unsqueeze(1).repeat(1, pnum_unit, 1)

        prob = masked_softmax(logit, enemy_mask, 2)
        # prob [batch, pnum_unit, pnum_enemy]
        return prob
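create_real_unit_mask is not defined on this page; a plausible sketch, assuming it marks the first num_unit slots of each padded row as real:

import torch

def create_real_unit_mask(num_unit, padded_num_unit):
    """mask: [batch, padded_num_unit], mask[i, j] = 1 iff j < num_unit[i]"""
    arange = torch.arange(padded_num_unit, device=num_unit.device)
    return (arange.unsqueeze(0) < num_unit.unsqueeze(1)).float()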
Example #9
    def compute_loss(self, ufeat, efeat, globfeat, num_enemy, target_idx,
                     mask):
        """loss, averaged by sum(mask)

        ufeat: [batch, padded_num_unit, ufeat_dim]
        efeat: [batch, padded_num_enemy, efeat_dim]
        globfeat: [batch, padded_num_unit, globfeat_dim]
        num_enemy: [batch]
        target_idx: [batch, padded_num_unit]
          target_idx[i, j] is real target idx iff unit_{i,j} is attacking
        mask: [batch, padded_num_unit]
          mask[i, j] = 1 iff unit_{i,j} is a real unit and its cmd_type == ATTACK
        """
        # every target index must be in range for its batch entry
        assert target_idx.min() >= 0
        assert not (target_idx.max(1)[0] > num_enemy).any()

        batch, pnum_unit, _ = ufeat.size()
        assert_eq(target_idx.size(), (batch, pnum_unit))
        assert_eq(mask.size(), (batch, pnum_unit))

        prob = self.forward(ufeat, efeat, globfeat, num_enemy)
        # prob [batch, pnum_unit, pnum_enemy]
        prob = prob.gather(2, target_idx.unsqueeze(2)).squeeze(2)
        # prob [batch, pnum_unit]
        logp = (prob + 1e-6).log()

        loss = -(logp * mask).sum(1)
        return loss
Example #10
    def _select_feat_with_loc(feat, x, y):
        """select feature given locations

        feat: [batch, nc, h, w]
        x: [batch, pnum_unit], range: [0, MAP_X)
        y: [batch, pnum_unit], range: [0, MAP_Y)

        return:
        selected_feat: [batch, pnum_unit, nc(feat_dim)]
        Note: returned feature is not masked at all
        """
        # feat: [batch, nc, h, w]
        x = x.long()
        y = y.long()
        assert x.max() < gc.MAP_X and x.min() >= 0
        assert y.max() < gc.MAP_Y and y.min() >= 0

        loc = y * gc.MAP_X + x
        # loc: [batch, pnum_unit]

        batch, nc, h, w = feat.size()
        assert_eq(h, gc.MAP_Y)
        assert_eq(w, gc.MAP_X)
        feat = feat.view(batch, nc, h * w)
        # feat: [batch, nc, h * w]

        loc = loc.unsqueeze(1).repeat(1, nc, 1)
        selected_feat = feat.gather(2, loc)
        # selected_feat: [batch, nc, pnum_unit]
        selected_feat = selected_feat.transpose(1, 2).contiguous()
        return selected_feat
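A tiny numeric check of the row-major flattening loc = y * MAP_X + x used above, on a hypothetical 2 x 3 (h x w) map with one channel:

import torch

feat = torch.arange(6.0).view(1, 1, 2, 3)   # [batch=1, nc=1, h=2, w=3]
x = torch.tensor([[2]])                      # [batch=1, pnum_unit=1]
y = torch.tensor([[1]])
loc = y * 3 + x                              # 1 * 3 + 2 = 5
flat = feat.view(1, 1, 6)
picked = flat.gather(2, loc.unsqueeze(1))    # tensor([[[5.]]])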
Example #11
    def forward(self, prev_cmd, num_cmd):
        """take cmd_type, produce cmd_repr

        prev_cmd: [batch, num_padded_unit, num_padded_prev_cmds]
          0 <= ctype < num_ctype; padding cmds should carry a valid cmd_type
        num_cmd: [batch, num_padded_unit] (unused here)
        """
        assert_eq(prev_cmd.dim(), 3)

        ctype_emb = self.ctype_emb(prev_cmd)
        # ctype_emb [batch, num_padded_unit, num_padded_prev_cmds, emb_dim]
        ctype_emb = ctype_emb.sum(2)
        return ctype_emb
Example #12
def logit2logp(logit, index):
    """return log_softmax(logit)[index] along dim=2

    logit: [batch, padded_num_unit, dim]
    index: [batch, padded_num_unit]

    return logp: [batch, padded_num_unit]
    """
    assert_eq(logit.size()[:2], index.size())
    assert_eq(logit.dim(), 3)

    logp = nn.functional.log_softmax(logit, 2)
    logp = logp.gather(2, index.unsqueeze(2)).squeeze(2)
    return logp
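A usage sketch with toy shapes (assumes torch, nn, and assert_eq are in scope):

import torch

logit = torch.randn(2, 3, 5)            # [batch=2, padded_num_unit=3, dim=5]
index = torch.randint(0, 5, (2, 3))     # [batch, padded_num_unit]
logp = logit2logp(logit, index)         # [2, 3]; entries are log-probabilities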
Example #13
def masked_softmax(logit, mask, dim):
    """masked softmax along `dim`

    (shapes below assume 3-D inputs with dim = 2)
    logit: [batch, pnum_unit, pnum_enemy]
    mask: [batch, pnum_unit, pnum_enemy]
      if mask[i, j] == 0 everywhere along `dim` (i.e. no enemy), the
      returned row is uniform

    return prob: [batch, pnum_unit, pnum_enemy]
    """
    assert_eq(logit.size(), mask.size())

    logit = logit - (1 - mask) * 1e9
    logit_max = logit.max(dim, keepdim=True)[0]
    exp = (logit - logit_max).exp()
    denom = exp.sum(dim, keepdim=True)
    prob = exp / denom
    return prob
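A usage sketch: a partially masked row renormalizes over its valid entries, while an all-zero row degrades to uniform (in float32 the 1e9 shift swamps moderate logits, so the shifted values round to the same number):

import torch

logit = torch.tensor([[[1.0, 2.0, 3.0]]])
prob = masked_softmax(logit, torch.tensor([[[1.0, 1.0, 0.0]]]), 2)
# prob ~ [[[0.27, 0.73, 0.00]]]
prob = masked_softmax(logit, torch.zeros(1, 1, 3), 2)
# prob == [[[1/3, 1/3, 1/3]]]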
Example #14
    def sample(self, cont_probs, probs, prev_samples):
        """Categorical sampler that can persist the previous samples
        with some probability.

        Args:
            cont_probs: probabilities of keeping the previous samples
                    [batch, 2]
            probs: probabilities returned by model forward
                    [batch, num_actions]
            prev_samples: previously sampled actions
                    [batch]
        return:
            cont_samples: [batch]
            samples: [batch]
        """
        assert_eq(cont_probs.size(1), 2)

        # cont_probs must be valid (non-NaN) distributions
        assert not torch.isnan(cont_probs).any()

        cont_samples = cont_probs.multinomial(1).squeeze(1)
        new_samples = probs.multinomial(1).squeeze(1)

        assert_eq(prev_samples.size(), new_samples.size())

        # prev_samples == -1 marks "no previous sample"; force a fresh sample
        # and report cont = 0 in that case
        if (prev_samples == -1).any():
            samples = new_samples
            cont_samples = cont_samples * 0
        else:
            samples = cont_samples * prev_samples + (1 - cont_samples) * new_samples

        return {
            self.cont_key: cont_samples,
            self.key: samples,
        }
Example #15
    def _process_units(self, units, padded_size):
        types = []
        xs = []
        ys = []
        hps = []
        for u in units:
            # clamp off-map coordinates to the map border
            x = max(u['x'], 0)
            y = max(u['y'], 0)
            types.append(u['unit_type'])
            xs.append(x)
            ys.append(y)
            hps.append(u['hp'])

        types = np.array(types, dtype=np.int64)
        xs = np.array(xs, dtype=np.int64)
        ys = np.array(ys, dtype=np.int64)
        hps = np.array(hps, dtype=np.int64)
        num_units = len(units)
        assert_eq(num_units, types.shape[0])

        types = self._pad(types, padded_size, 0)
        xs = self._pad(xs, padded_size, 0)
        ys = self._pad(ys, padded_size, 0)
        hps = self._pad(hps, padded_size, 0)

        units_info = {
            'types': types,
            'xs': xs,
            'ys': ys,
            'hps': hps,
            'num_units': num_units
        }
        return units_info
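The _pad helper called above (a method on the same class) is not shown; a plausible sketch, assuming it right-pads a 1-D array to padded_size with a constant value:

import numpy as np

def _pad(arr, padded_size, pad_value):
    assert len(arr) <= padded_size
    padded = np.full(padded_size, pad_value, dtype=arr.dtype)
    padded[:len(arr)] = arr
    return padded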
Example #16
    def format_executor_input(self, batch, inst, inst_len, inst_cont):
        """convert batch (in c++ format) to input used by executor

        inst: [batch, max_sentence_len], LongTensor, parsed
        inst_len: [batch]
        """
        assert_eq(inst.dim(), 2)
        assert_eq(inst_len.dim(), 1)
        # assert_eq(inst_cont.dim(), 1)

        my_units = {
            'types': batch['army_type'],
            'hps': batch['army_hp'],
            'xs': batch['army_x'],
            'ys': batch['army_y'],
            'num_units': batch['num_army'].squeeze(1),
        }
        enemy_units = {
            'types': batch['enemy_type'],
            'hps': batch['enemy_hp'],
            'xs': batch['enemy_x'],
            'ys': batch['enemy_y'],
            'num_units': batch['num_enemy'].squeeze(1),
        }
        resource_units = {
            'types': batch['resource_type'],
            'hps': batch['resource_hp'],
            'xs': batch['resource_x'],
            'ys': batch['resource_y'],
            'num_units': batch['num_resource'].squeeze(1),
        }
        current_cmds = {
            'cmd_type': batch['current_cmd_type'],
            'target_type': batch['current_cmd_unit_type'],
            'target_x': batch['current_cmd_x'],
            'target_y': batch['current_cmd_y'],
            'target_gather_idx': batch['current_cmd_gather_idx'],
            'target_attack_idx': batch['current_cmd_attack_idx'],
        }

        prev_cmds = batch['prev_cmd']

        hist_inst, hist_inst_len, hist_inst_diff = parse_hist_inst(
            self.inst_dict, batch['hist_inst'], batch['hist_inst_diff'], inst,
            inst_len, inst_cont, is_word_based(self.args.inst_encoder_type))

        data = {
            'inst': inst,
            'inst_len': inst_len,
            'hist_inst': hist_inst,
            'hist_inst_len': hist_inst_len,
            'hist_inst_diff': hist_inst_diff,
            'resource_bin': batch['resource_bin'],
            'my_units': my_units,
            'enemy_units': enemy_units,
            'resource_units': resource_units,
            'prev_cmds': prev_cmds,
            'current_cmds': current_cmds,
            'map': batch['map'],
        }
        return data
Example #17
    def format_rl_executor_input(self, batch, inst, inst_len, inst_cont):
        """convert batch (in c++ format) to input used by executor

        inst: [batch, max_sentence_len], LongTensor, parsed
        inst_len: [batch]
        """
        assert_eq(inst.dim(), 2)
        assert_eq(inst_len.dim(), 1)
        # assert_eq(inst_cont.dim(), 1)

        my_units = {
            "types": batch["army_type"],
            "hps": batch["army_hp"],
            "xs": batch["army_x"],
            "ys": batch["army_y"],
            "num_units": batch["num_army"].squeeze(1),
        }
        enemy_units = {
            "types": batch["enemy_type"],
            "hps": batch["enemy_hp"],
            "xs": batch["enemy_x"],
            "ys": batch["enemy_y"],
            "num_units": batch["num_enemy"].squeeze(1),
        }
        resource_units = {
            "types": batch["resource_type"],
            "hps": batch["resource_hp"],
            "xs": batch["resource_x"],
            "ys": batch["resource_y"],
            "num_units": batch["num_resource"].squeeze(1),
        }
        current_cmds = {
            "cmd_type": batch["current_cmd_type"],
            "target_type": batch["current_cmd_unit_type"],
            "target_x": batch["current_cmd_x"],
            "target_y": batch["current_cmd_y"],
            "target_gather_idx": batch["current_cmd_gather_idx"],
            "target_attack_idx": batch["current_cmd_attack_idx"],
        }

        prev_cmds = batch["prev_cmd"]

        hist_inst, hist_inst_len, hist_inst_diff = parse_hist_inst(
            self.inst_dict,
            batch["hist_inst"],
            batch["hist_inst_diff"],
            inst,
            inst_len,
            inst_cont,
            is_word_based(self.args.inst_encoder_type),
        )

        data = {
            "inst": inst,
            "inst_len": inst_len,
            "hist_inst": hist_inst,
            "hist_inst_len": hist_inst_len,
            "hist_inst_diff": hist_inst_diff,
            "resource_bin": batch["resource_bin"],
            "my_units": my_units,
            "enemy_units": enemy_units,
            "resource_units": resource_units,
            "prev_cmds": prev_cmds,
            "current_cmds": current_cmds,
            "map": batch["map"],
        }
        return data
Example #18
    def compute_rl_log_probs(self, batch):
        features = self.conv_encoder(batch)

        real_unit_mask, cmd_type_mask = create_loss_masks(
            batch["my_units"]["num_units"],
            batch["target_cmds"]["cmd_type"],
            self.num_cmd_type,
        )

        # global continue classifier ## TODO: Implement the global continue classifier

        army_inst = torch.cat([features["army_feat"], features["inst_feat"]], 2)
        # action-arg classifiers
        gather_loss = self.gather_cls.compute_loss(
            army_inst,
            features["resource_feat"],
            None,
            batch["resource_units"]["num_units"],
            batch["target_cmds"]["target_gather_idx"],
            cmd_type_mask[:, :, gc.CmdTypes.GATHER.value],
        )
        attack_loss = self.attack_cls.compute_loss(
            army_inst,
            features["enemy_feat"],
            None,  # glob_feat,
            batch["enemy_units"]["num_units"],
            batch["target_cmds"]["target_attack_idx"],
            cmd_type_mask[:, :, gc.CmdTypes.ATTACK.value],
        )
        build_building_loss = self.build_building_cls.compute_loss(
            army_inst,
            features["map_feat"],
            None,  # glob_feat,
            batch["target_cmds"]["target_type"],
            batch["target_cmds"]["target_x"],
            batch["target_cmds"]["target_y"],
            cmd_type_mask[:, :, gc.CmdTypes.BUILD_BUILDING.value],
        )
        build_unit_loss, nil_build_logp = self.build_unit_cls.compute_loss(
            army_inst,
            None,  # glob_feat,
            batch["target_cmds"]["target_type"],
            cmd_type_mask[:, :, gc.CmdTypes.BUILD_UNIT.value],
            include_nil=True,
        )
        move_loss = self.move_cls.compute_loss(
            army_inst,
            features["map_feat"],
            None,  # glob_feat,
            batch["target_cmds"]["target_x"],
            batch["target_cmds"]["target_y"],
            cmd_type_mask[:, :, gc.CmdTypes.MOVE.value],
        )

        # type loss
        ctype_context = torch.cat(
            [
                features["sum_army"],
                features["sum_enemy"],
                features["sum_resource"],
                features["money_feat"],
            ],
            1,
        )
        cmd_type_logp = self.cmd_type_cls.compute_prob(
            army_inst, ctype_context, logp=True
        )
        # cmd_type_logp: [batch, num_unit, num_cmd_type]

        # extra continue: choosing BUILD_UNIT with the nil target type is
        # folded into CONT below
        build_unit_type_logp = cmd_type_logp[:, :, gc.CmdTypes.BUILD_UNIT.value]
        extra_cont_logp = build_unit_type_logp + nil_build_logp
        # extra_cont_logp: [batch, num_unit]
        # the following hack only works if CONT is the last cmd type
        assert gc.CmdTypes.CONT.value == len(gc.CmdTypes) - 1
        assert (
            extra_cont_logp.size() == cmd_type_logp[:, :, gc.CmdTypes.CONT.value].size()
        )
        cont_logp = log_sum_exp(
            cmd_type_logp[:, :, gc.CmdTypes.CONT.value], extra_cont_logp
        )
        # cont_logp: [batch, num_unit]
        cmd_type_logp = torch.cat(
            [cmd_type_logp[:, :, : gc.CmdTypes.CONT.value], cont_logp.unsqueeze(2)], 2
        )
        # cmd_type_logp: [batch, num_unit, num_cmd_type]
        cmd_type_logp = cmd_type_logp.gather(
            2, batch["target_cmds"]["cmd_type"].unsqueeze(2)
        ).squeeze(2)
        # cmd_type_logp: [batch, num_unit]
        cmd_type_loss = -(cmd_type_logp * real_unit_mask).sum(1)

        # aggregate losses
        num_my_units_size = batch["my_units"]["num_units"].size()
        assert_eq(cmd_type_loss.size(), num_my_units_size)
        assert_eq(move_loss.size(), num_my_units_size)
        assert_eq(attack_loss.size(), num_my_units_size)
        assert_eq(gather_loss.size(), num_my_units_size)
        assert_eq(build_unit_loss.size(), num_my_units_size)
        assert_eq(build_building_loss.size(), num_my_units_size)
        unit_loss = (
            cmd_type_loss
            + move_loss
            + attack_loss
            + gather_loss
            + build_unit_loss
            + build_building_loss
        )

        log_prob = -unit_loss

        all_loss = {
            "unit_loss": unit_loss.detach(),
            "cmd_type_loss": cmd_type_loss.detach(),
            "move_loss": move_loss.detach(),
            "attack_loss": attack_loss.detach(),
            "gather_loss": gather_loss.detach(),
            "build_unit_loss": build_unit_loss.detach(),
            "build_building_loss": build_building_loss.detach(),
        }

        return log_prob, all_loss
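log_sum_exp above combines two same-shaped log-probability tensors; a plausible sketch using the standard max trick for numerical stability:

import torch

def log_sum_exp(a, b):
    """elementwise log(exp(a) + exp(b)), computed stably"""
    m = torch.max(a, b)
    return m + ((a - m).exp() + (b - m).exp()).log()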
Example #19
    def compute_loss(self, batch, *, mean=True):
        features = self.conv_encoder(batch)

        real_unit_mask, cmd_type_mask = create_loss_masks(
            batch['my_units']['num_units'], batch['target_cmds']['cmd_type'],
            self.num_cmd_type)

        # global continue classifier
        glob_feat = torch.cat([
            features['sum_army'], features['sum_enemy'],
            features['sum_resource'], features['money_feat'],
            features['sum_inst']
        ], 1)
        glob_cont_loss = self.glob_cont_cls.compute_loss(
            glob_feat, batch['glob_cont'])

        army_inst = torch.cat([features['army_feat'], features['inst_feat']],
                              2)
        # action-arg classifiers
        gather_loss = self.gather_cls.compute_loss(
            army_inst, features['resource_feat'], None,
            batch['resource_units']['num_units'],
            batch['target_cmds']['target_gather_idx'],
            cmd_type_mask[:, :, gc.CmdTypes.GATHER.value])
        attack_loss = self.attack_cls.compute_loss(
            army_inst,
            features['enemy_feat'],
            None,  # glob_feat,
            batch['enemy_units']['num_units'],
            batch['target_cmds']['target_attack_idx'],
            cmd_type_mask[:, :, gc.CmdTypes.ATTACK.value])
        build_building_loss = self.build_building_cls.compute_loss(
            army_inst,
            features['map_feat'],
            None,  # glob_feat,
            batch['target_cmds']['target_type'],
            batch['target_cmds']['target_x'],
            batch['target_cmds']['target_y'],
            cmd_type_mask[:, :, gc.CmdTypes.BUILD_BUILDING.value])
        build_unit_loss, nil_build_logp = self.build_unit_cls.compute_loss(
            army_inst,
            None,  # glob_feat,
            batch['target_cmds']['target_type'],
            cmd_type_mask[:, :, gc.CmdTypes.BUILD_UNIT.value],
            include_nil=True)
        move_loss = self.move_cls.compute_loss(
            army_inst,
            features['map_feat'],
            None,  # glob_feat,
            batch['target_cmds']['target_x'],
            batch['target_cmds']['target_y'],
            cmd_type_mask[:, :, gc.CmdTypes.MOVE.value])

        # type loss
        ctype_context = torch.cat([
            features['sum_army'], features['sum_enemy'],
            features['sum_resource'], features['money_feat']
        ], 1)
        cmd_type_logp = self.cmd_type_cls.compute_prob(army_inst,
                                                       ctype_context,
                                                       logp=True)
        # cmd_type_logp: [batch, num_unit, num_cmd_type]

        # extra continue: choosing BUILD_UNIT with the nil target type is
        # folded into CONT below
        build_unit_type_logp = cmd_type_logp[:, :, gc.CmdTypes.BUILD_UNIT.value]
        extra_cont_logp = build_unit_type_logp + nil_build_logp
        # extra_cont_logp: [batch, num_unit]
        # the following hack only works if CONT is the last cmd type
        assert gc.CmdTypes.CONT.value == len(gc.CmdTypes) - 1
        assert extra_cont_logp.size() == \
            cmd_type_logp[:, :, gc.CmdTypes.CONT.value].size()
        cont_logp = log_sum_exp(cmd_type_logp[:, :, gc.CmdTypes.CONT.value],
                                extra_cont_logp)
        # cont_logp: [batch, num_unit]
        cmd_type_logp = torch.cat([
            cmd_type_logp[:, :, :gc.CmdTypes.CONT.value],
            cont_logp.unsqueeze(2)
        ], 2)
        # cmd_type_logp: [batch, num_unit, num_cmd_type]
        cmd_type_logp = cmd_type_logp.gather(
            2, batch['target_cmds']['cmd_type'].unsqueeze(2)).squeeze(2)
        # cmd_type_logp: [batch, num_unit]
        cmd_type_loss = -(cmd_type_logp * real_unit_mask).sum(1)

        # aggregate losses
        num_my_units_size = batch['my_units']['num_units'].size()
        assert_eq(glob_cont_loss.size(), num_my_units_size)
        assert_eq(cmd_type_loss.size(), num_my_units_size)
        assert_eq(move_loss.size(), num_my_units_size)
        assert_eq(attack_loss.size(), num_my_units_size)
        assert_eq(gather_loss.size(), num_my_units_size)
        assert_eq(build_unit_loss.size(), num_my_units_size)
        assert_eq(build_building_loss.size(), num_my_units_size)
        unit_loss = (cmd_type_loss + move_loss + attack_loss + gather_loss +
                     build_unit_loss + build_building_loss)

        unit_loss = (1 - batch['glob_cont'].float()) * unit_loss
        loss = glob_cont_loss + unit_loss

        all_loss = {
            'loss': loss.detach(),
            'unit_loss': unit_loss.detach(),
            'cmd_type_loss': cmd_type_loss.detach(),
            'move_loss': move_loss.detach(),
            'attack_loss': attack_loss.detach(),
            'gather_loss': gather_loss.detach(),
            'build_unit_loss': build_unit_loss.detach(),
            'build_building_loss': build_building_loss.detach(),
            'glob_cont_loss': glob_cont_loss.detach()
        }

        if mean:
            for k, v in all_loss.items():
                all_loss[k] = v.mean()
            loss = loss.mean()

        return loss, all_loss
Example #20
    def forward(self, num_real_unit, cmd_type, target_type, x, y,
                target_attack_idx, target_gather_idx, enemy_feat,
                resource_feat):
        """take cmd_type, produce cmd_repr

        num_real_unit:  [batch]
        cmd_type: [batch, num_padded_unit]
          0 <= cmd_type < num_cmd_type; padding cmds should carry a valid cmd_type
        target_type: [batch, num_padded_unit]
        x: [batch, num_padded_unit]
        y: [batch, num_padded_unit]
        target_attack_idx: [batch, num_padded_unit]
        target_gather_idx: [batch, num_padded_unit]
        enemy_feat: [batch, num_padded_enemy, target_feat_dim]
        resource_feat: [batch, num_padded_resource, target_feat_dim]

        return: cmd_feat: [cmd_type_emb; target_type_emb; xy_feat; target_feat]
        """
        batch, pnum_unit = cmd_type.size()

        assert_eq(cmd_type.size(), target_type.size())
        assert_eq(x.dim(), 2)
        assert_eq(y.dim(), 2)

        cmd_type_emb = self.cmd_type_emb(cmd_type)
        target_type_emb = self.target_type_emb(target_type)
        xy_feat = self.xy_net(torch.stack([x, y], 2))

        enemy_feat = self.enemy_feat_net(enemy_feat)
        resource_feat = self.resource_feat_net(resource_feat)
        # select the targeted enemy's feature
        target_enemy_feat = self.select_and_slice_target(
            enemy_feat, target_attack_idx, self.attribute_dim)
        target_resource_feat = self.select_and_slice_target(
            resource_feat, target_gather_idx, self.attribute_dim)

        cmd_mask = torch.zeros((batch, pnum_unit, self.num_cmd_type),
                               device=cmd_type.device)
        cmd_mask.scatter_(2, cmd_type.unsqueeze(2), 1)
        # cmd_mask [batch, pnum_unit, num_cmd_type]

        outfeat = torch.zeros((batch, pnum_unit, self.out_dim),
                              device=cmd_type.device)

        outfeat[:, :, 0:self.cmd_type_emb_end] = cmd_type_emb

        target_type_mask = (cmd_mask[:, :, gc.CmdTypes.BUILD_BUILDING.value] +
                            cmd_mask[:, :, gc.CmdTypes.BUILD_UNIT.value])
        outfeat[:, :, self.cmd_type_emb_end:self.target_type_emb_end] = (
            target_type_mask.unsqueeze(2) * target_type_emb)

        xy_feat_mask = (cmd_mask[:, :, gc.CmdTypes.BUILD_BUILDING.value] +
                        cmd_mask[:, :, gc.CmdTypes.MOVE.value])
        outfeat[:, :, self.target_type_emb_end:self.xy_feat_end] = (
            xy_feat_mask.unsqueeze(2) * xy_feat)

        enemy_feat_mask = cmd_mask[:, :, gc.CmdTypes.ATTACK.value]
        resource_feat_mask = cmd_mask[:, :, gc.CmdTypes.GATHER.value]
        outfeat[:, :, self.xy_feat_end:self.target_feat_end] = (
            enemy_feat_mask.unsqueeze(2) * target_enemy_feat +
            resource_feat_mask.unsqueeze(2) * target_resource_feat)
        return outfeat
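select_and_slice_target is not shown on this page; a plausible sketch, assuming it gathers each unit's target row from feat and keeps the first dim channels:

def select_and_slice_target(feat, target_idx, dim):
    """
    feat: [batch, pnum_target, feat_dim]
    target_idx: [batch, pnum_unit]
    return: [batch, pnum_unit, dim]
    """
    batch, pnum_unit = target_idx.size()
    idx = target_idx.unsqueeze(2).expand(batch, pnum_unit, feat.size(2))
    return feat.gather(1, idx)[:, :, :dim]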