def __init__(self, num_cmd_type, num_target_type, target_feat_dim, attribute_dim):
    """
    num_cmd_type: number of cmd types
    num_target_type: number of target unit types
    target_feat_dim: feat dim of the enemy/resource encoder output
    attribute_dim: emb/feature dim of each attribute slot; the output feature is
        [cmd_type_emb; target_type_emb; xy_feat; target_feat], so out_dim = 4 * attribute_dim
    """
    super().__init__()
    self.num_cmd_type = num_cmd_type
    self.num_target_type = num_target_type
    self.target_feat_dim = target_feat_dim
    self.attribute_dim = attribute_dim
    self.out_dim = 4 * attribute_dim

    self.cmd_type_emb = nn.Embedding(num_cmd_type, attribute_dim)
    self.target_type_emb = nn.Embedding(num_target_type, attribute_dim)
    self.xy_net = nn.Sequential(
        weight_norm(nn.Linear(2, attribute_dim), dim=None),
    )
    self.enemy_feat_net = nn.Sequential(
        weight_norm(nn.Linear(target_feat_dim, attribute_dim), dim=None),
    )
    self.resource_feat_net = nn.Sequential(
        weight_norm(nn.Linear(target_feat_dim, attribute_dim), dim=None),
    )

    # compute slot boundaries in the output feature tensor
    self.cmd_type_emb_end = attribute_dim
    self.target_type_emb_end = 2 * attribute_dim
    self.xy_feat_end = 3 * attribute_dim
    self.target_feat_end = 4 * attribute_dim
    assert_eq(self.target_feat_end, self.out_dim)
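# A minimal standalone sketch of the slot layout built above: the output feature
# concatenates four equal-width attribute slices, and slices that do not apply to a
# unit's cmd type stay zero. `_demo_slot_layout` and the sizes here are hypothetical,
# for illustration only.
def _demo_slot_layout():
    import torch

    attribute_dim = 8
    out_dim = 4 * attribute_dim
    outfeat = torch.zeros(2, 5, out_dim)  # [batch, pnum_unit, out_dim]

    # slot boundaries, mirroring cmd_type_emb_end .. target_feat_end
    cmd_slot = outfeat[:, :, 0:attribute_dim]
    type_slot = outfeat[:, :, attribute_dim:2 * attribute_dim]
    xy_slot = outfeat[:, :, 2 * attribute_dim:3 * attribute_dim]
    target_slot = outfeat[:, :, 3 * attribute_dim:4 * attribute_dim]
    assert cmd_slot.size(2) == type_slot.size(2) == xy_slot.size(2) == target_slot.size(2)
    return outfeat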
def compute_eval_loss(self, batch):
    batch = self._format_language_input_with_candidate(batch)
    glob_feat = self._forward(batch)

    cont = 1 - batch['is_base_frame']
    cont_loss = self.cont_cls.compute_loss(glob_feat, cont)
    lang_logp = self.inst_selector.compute_prob(
        batch['inst_input'], batch['inst'], glob_feat, log=True)
    lang_loss = -lang_logp.gather(1, batch['inst_idx'].unsqueeze(1)).squeeze(1)
    assert_eq(cont_loss.size(), lang_loss.size())
    lang_loss = (1 - cont.float()) * lang_loss
    loss = (cont_loss + lang_loss).mean()
    all_loss = {
        'loss': loss,
        'cont_loss': cont_loss.mean(),
        'lang_loss': lang_loss.mean(),
    }
    return loss, all_loss
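# A self-contained sketch of the gather pattern used above to pick each row's
# log-probability at the ground-truth instruction index; `_demo_select_inst_logp`
# is a hypothetical name, not part of this codebase.
def _demo_select_inst_logp():
    import torch

    lang_logp = torch.log_softmax(torch.randn(3, 5), 1)  # [batch, num_inst]
    inst_idx = torch.tensor([4, 0, 2])                   # ground-truth indices
    lang_loss = -lang_logp.gather(1, inst_idx.unsqueeze(1)).squeeze(1)
    assert lang_loss.size() == (3,)  # one NLL term per sample
    return lang_loss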
def compute_loss(self, ufeat, globfeat, target_type, mask, *, include_nil=False):
    """per-sample loss, summed over masked units

    ufeat: [batch, padded_num_unit, ufeat_dim]
    globfeat: [batch, padded_num_unit, globfeat_dim]
    target_type: [batch, padded_num_unit]
      target_type[i, j] is a real unit_type iff unit_{i,j} builds a unit
    mask: [batch, padded_num_unit]
      mask[i, j] = 1 iff unit_{i,j} is a real unit and its cmd_type == BUILD_UNIT
    """
    batch, pnum_unit, _ = ufeat.size()
    assert_eq(target_type.size(), (batch, pnum_unit))
    assert_eq(mask.size(), (batch, pnum_unit))

    logit = self.forward(ufeat, globfeat)
    # logit [batch, pnum_unit, num_unit_type]
    logp = logit2logp(logit, target_type)
    # logp [batch, pnum_unit]
    loss = -(logp * mask).sum(1)
    if not include_nil:
        return loss

    nil_type = torch.zeros_like(target_type)
    nil_logp = logit2logp(logit, nil_type)
    return loss, nil_logp
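# A runnable sketch of the masked negative-log-likelihood computed above, assuming
# only torch; names and sizes here are illustrative.
def _demo_masked_type_loss():
    import torch
    import torch.nn.functional as F

    batch, pnum_unit, num_unit_type = 2, 3, 4
    logit = torch.randn(batch, pnum_unit, num_unit_type)
    target_type = torch.randint(0, num_unit_type, (batch, pnum_unit))
    # 1 only where the unit is real and its cmd is BUILD_UNIT
    mask = torch.tensor([[1., 0., 0.], [1., 1., 0.]])

    logp = F.log_softmax(logit, 2).gather(2, target_type.unsqueeze(2)).squeeze(2)
    loss = -(logp * mask).sum(1)  # [batch]; masked-out units contribute 0
    assert loss.size() == (batch,)
    return loss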
def _get_human_instruction(self, batch):
    assert_eq(batch['prev_inst'].size(0), 1)
    device = batch['prev_inst'].device

    inst = input('Please input your instruction\n')
    inst_idx = torch.zeros((1,)).long().to(device)
    inst_idx[0] = self.executor.inst_dict.get_inst_idx(inst)
    inst_cont = torch.zeros((1,)).long().to(device)
    if len(inst) == 0:
        # empty input means: continue with the previous instruction
        inst = self.prev_inst
        inst_cont[0] = 1

    self.prev_inst = inst
    raw_inst = convert_to_raw_instruction(inst, self.max_raw_chars)
    inst, inst_len = self.executor.inst_dict.parse(inst, True)
    inst = torch.LongTensor(inst).unsqueeze(0).to(device)
    inst_len = torch.LongTensor([inst_len]).to(device)
    raw_inst = torch.LongTensor([raw_inst]).to(device)

    reply = {
        'inst': inst_idx.unsqueeze(1),
        'inst_pi': torch.ones(1, self.num_insts).to(device) / self.num_insts,
        'cont': inst_cont.unsqueeze(1),
        'cont_pi': torch.ones(1, 2).to(device) / 2,
        'raw_inst': raw_inst,
    }
    return inst, inst_len, inst_cont, reply
def _forward1d(self, inst, inst_len, context):
    assert_eq(inst.size(0), context.size(0))
    inst_feat = self.inst_proj(self.encoder(inst, inst_len))
    logit = (inst_feat * context).sum(1)
    return logit
def sample(self, cont_probs, probs, prev_samples):
    """Categorical sampler that can persist the previous samples
    with some probability.

    Args:
        cont_probs: probabilities of keeping the previous samples [batch, 2]
        probs: probabilities returned by model forward [batch, num_actions]
        prev_samples: previously sampled actions [batch]

    return:
        cont_samples: [batch]
        samples: [batch]
    """
    assert_eq(cont_probs.size(1), 2)
    cont_samples = cont_probs.multinomial(1).squeeze(1)
    new_samples = probs.multinomial(1).squeeze(1)
    assert_eq(prev_samples.size(), new_samples.size())
    samples = cont_samples * prev_samples + (1 - cont_samples) * new_samples
    return {
        self.cont_key: cont_samples,
        self.key: samples,
    }
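# A runnable sketch of the persist-or-resample trick above: cont_samples == 1 keeps
# the previous action, 0 draws a fresh one. `_demo_persist_sampling` is a hypothetical
# name for illustration.
def _demo_persist_sampling():
    import torch

    torch.manual_seed(0)
    batch, num_actions = 4, 5
    cont_probs = torch.tensor([0.2, 0.8]).repeat(batch, 1)  # column 1 = P(keep)
    probs = torch.full((batch, num_actions), 1.0 / num_actions)
    prev_samples = torch.tensor([0, 1, 2, 3])

    cont_samples = cont_probs.multinomial(1).squeeze(1)
    new_samples = probs.multinomial(1).squeeze(1)
    samples = cont_samples * prev_samples + (1 - cont_samples) * new_samples
    # wherever cont_samples == 1, the previous action is kept verbatim
    assert ((samples == prev_samples) | (cont_samples == 0)).all()
    return samples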
def compute_loss(self, batch):
    """used for pre-training the model with dataset"""
    batch = self._format_supervised_language_input(batch)
    glob_feat = self._forward(batch)

    cont = 1 - batch["is_base_frame"]
    cont_loss = self.cont_cls.compute_loss(glob_feat, cont)
    lang_loss = self.inst_selector.compute_loss(
        batch["pos_cand_inst"],
        batch["pos_cand_inst_len"],
        batch["neg_cand_inst"],
        batch["neg_cand_inst_len"],
        batch["inst"],
        batch["inst_len"],
        glob_feat,
        batch["inst_idx"],
    )
    assert_eq(cont_loss.size(), lang_loss.size())
    lang_loss = (1 - cont.float()) * lang_loss
    loss = (cont_loss + lang_loss).mean()
    all_loss = {
        "loss": loss,
        "cont_loss": cont_loss.mean(),
        "lang_loss": lang_loss.mean(),
    }
    return loss, all_loss
def forward(self, ufeat, efeat, globfeat, num_enemy):
    """return masked prob that each real enemy is the target

    return:
        prob: [batch, pnum_unit, pnum_enemy]
    """
    batch, pnum_unit, _ = ufeat.size()
    pnum_enemy = efeat.size(1)
    assert_eq(num_enemy.size(), (batch,))
    assert globfeat is None and self.globfeat_dim == 0
    infeat = ufeat
    # infeat [batch, pnum_unit, in_dim]

    proj = self.net(infeat)
    proj = proj.unsqueeze(2).repeat(1, 1, pnum_enemy, 1)
    # proj [batch, pnum_unit, pnum_enemy, efeat_dim]
    efeat = efeat.unsqueeze(1).repeat(1, pnum_unit, 1, 1)
    # efeat [batch, pnum_unit, pnum_enemy, efeat_dim]
    logit = (proj * efeat).sum(3) / self.norm
    # logit [batch, pnum_unit, pnum_enemy]

    enemy_mask = create_real_unit_mask(num_enemy, pnum_enemy)
    # enemy_mask [batch, pnum_enemy]
    enemy_mask = enemy_mask.unsqueeze(1).repeat(1, pnum_unit, 1)
    prob = masked_softmax(logit, enemy_mask, 2)
    # prob [batch, pnum_unit, pnum_enemy]
    return prob
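# create_real_unit_mask is defined elsewhere in the repo; assuming it returns
# mask[i, j] = 1 iff j < num_enemy[i], a minimal equivalent looks like this
# (the demo name is hypothetical).
def _demo_real_unit_mask():
    import torch

    num_enemy = torch.tensor([2, 0, 3])
    pnum_enemy = 4
    arange = torch.arange(pnum_enemy).unsqueeze(0)    # [1, pnum_enemy]
    mask = (arange < num_enemy.unsqueeze(1)).float()  # [batch, pnum_enemy]
    assert mask.tolist() == [[1, 1, 0, 0], [0, 0, 0, 0], [1, 1, 1, 0]]
    return mask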
def compute_loss(self, ufeat, efeat, globfeat, num_enemy, target_idx, mask):
    """per-sample loss, summed over masked units

    ufeat: [batch, padded_num_unit, ufeat_dim]
    efeat: [batch, padded_num_enemy, efeat_dim]
    globfeat: [batch, padded_num_unit, globfeat_dim]
    num_enemy: [batch]
    target_idx: [batch, padded_num_unit]
      target_idx[i, j] is a real target idx iff unit_{i,j} is attacking
    mask: [batch, padded_num_unit]
      mask[i, j] = 1 iff unit_{i,j} is a real unit and its cmd_type == ATTACK
    """
    assert target_idx.min() >= 0
    assert not (target_idx.max(1)[0] > num_enemy).any()

    batch, pnum_unit, _ = ufeat.size()
    assert_eq(target_idx.size(), (batch, pnum_unit))
    assert_eq(mask.size(), (batch, pnum_unit))

    prob = self.forward(ufeat, efeat, globfeat, num_enemy)
    # prob [batch, pnum_unit, pnum_enemy]
    prob = prob.gather(2, target_idx.unsqueeze(2)).squeeze(2)
    # prob [batch, pnum_unit]
    logp = (prob + 1e-6).log()
    loss = -(logp * mask).sum(1)
    return loss
def _select_feat_with_loc(feat, x, y):
    """select features at given locations

    feat: [batch, nc, h, w]
    x: [batch, pnum_unit], range: [0, MAP_X)
    y: [batch, pnum_unit], range: [0, MAP_Y)

    return:
        selected_feat: [batch, pnum_unit, nc(feat_dim)]

    Note: the returned feature is not masked at all
    """
    x = x.long()
    y = y.long()
    assert x.max() < gc.MAP_X and x.min() >= 0
    assert y.max() < gc.MAP_Y and y.min() >= 0

    loc = y * gc.MAP_X + x
    # loc: [batch, pnum_unit]
    batch, nc, h, w = feat.size()
    assert_eq(h, gc.MAP_Y)
    assert_eq(w, gc.MAP_X)
    feat = feat.view(batch, nc, h * w)
    # feat: [batch, nc, h * w]
    loc = loc.unsqueeze(1).repeat(1, nc, 1)
    selected_feat = feat.gather(2, loc)
    # selected_feat: [batch, nc, pnum_unit]
    selected_feat = selected_feat.transpose(1, 2).contiguous()
    return selected_feat
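# A runnable sketch of the flatten-then-gather lookup above: (y, x) map coordinates
# become indices into a flattened h*w axis. Sizes and the demo name are illustrative.
def _demo_select_feat_with_loc():
    import torch

    batch, nc, h, w = 2, 3, 4, 5
    feat = torch.randn(batch, nc, h, w)
    x = torch.tensor([[1, 4], [0, 2]])  # [batch, pnum_unit]
    y = torch.tensor([[0, 3], [2, 1]])

    loc = y * w + x                     # flatten (y, x) into an h*w index
    flat = feat.view(batch, nc, h * w)
    selected = flat.gather(2, loc.unsqueeze(1).repeat(1, nc, 1))
    selected = selected.transpose(1, 2).contiguous()  # [batch, pnum_unit, nc]
    assert torch.equal(selected[0, 1], feat[0, :, 3, 4])
    return selected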
def forward(self, prev_cmd, num_cmd):
    """take prev cmd types, produce a summed cmd embedding

    prev_cmd: [batch, num_padded_unit, num_padded_prev_cmds],
      0 <= ctype < num_ctype; padding cmds should still carry a valid cmd type
    num_cmd: [batch, num_padded_unit]
    """
    assert_eq(prev_cmd.dim(), 3)
    ctype_emb = self.ctype_emb(prev_cmd)
    # ctype_emb [batch, num_padded_unit, num_padded_prev_cmds, emb_dim]
    ctype_emb = ctype_emb.sum(2)
    return ctype_emb
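# A runnable sketch of the order-invariant "bag of commands" embedding above.
# Note the sum includes padded slots, which is why padding cmds must carry a valid
# type; the demo name is hypothetical.
def _demo_prev_cmd_embedding():
    import torch
    import torch.nn as nn

    num_ctype, emb_dim = 6, 8
    ctype_emb = nn.Embedding(num_ctype, emb_dim)
    prev_cmd = torch.randint(0, num_ctype, (2, 3, 4))  # [batch, unit, prev_cmds]

    emb = ctype_emb(prev_cmd)  # [batch, unit, prev_cmds, emb_dim]
    summed = emb.sum(2)        # order-invariant summary of a unit's history
    assert summed.size() == (2, 3, emb_dim)
    return summed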
def logit2logp(logit, index):
    """return log_softmax(logit)[index] along dim=2

    logit: [batch, padded_num_unit, dim]
    index: [batch, padded_num_unit]

    return logp: [batch, padded_num_unit]
    """
    assert_eq(logit.size()[:2], index.size())
    assert_eq(logit.dim(), 3)
    logp = nn.functional.log_softmax(logit, 2)
    logp = logp.gather(2, index.unsqueeze(2)).squeeze(2)
    return logp
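# A self-contained check of logit2logp's contract, assuming only torch
# (the demo name is hypothetical).
def _demo_logit2logp():
    import torch
    import torch.nn.functional as F

    batch, pnum_unit, dim = 2, 3, 5
    logit = torch.randn(batch, pnum_unit, dim)
    index = torch.randint(0, dim, (batch, pnum_unit))

    logp = F.log_softmax(logit, 2)
    picked = logp.gather(2, index.unsqueeze(2)).squeeze(2)
    # picked[i, j] == log_softmax(logit[i, j])[index[i, j]]
    assert torch.allclose(picked[0, 0], logp[0, 0, index[0, 0]])
    return picked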
def masked_softmax(logit, mask, dim):
    """masked softmax (shapes below assume 3-d input with dim = 2)

    logit: [batch, pnum_unit, pnum_enemy]
    mask: [batch, pnum_unit, pnum_enemy]
      if mask[i, j, :] == 0 everywhere (i.e. no enemy), will return
      uniform probability for that entry instead of NaN
    prob: [batch, pnum_unit, pnum_enemy]
    """
    assert_eq(logit.size(), mask.size())
    logit = logit - (1 - mask) * 1e9
    # subtract the max for numerical stability
    logit_max = logit.max(dim, keepdim=True)[0]
    exp = (logit - logit_max).exp()
    prob = exp / exp.sum(dim, keepdim=True)
    return prob
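# A runnable sketch of masked_softmax's two edge cases: masked targets get ~0
# probability, and a fully-masked entry comes out uniform rather than NaN
# (in fp32 the -1e9 shift swamps the logits, so every entry in the row collapses
# to the same value). The demo name is hypothetical.
def _demo_masked_softmax():
    import torch

    logit = torch.randn(2, 3, 4)
    mask = torch.ones(2, 3, 4)
    mask[0, 0, 2:] = 0  # last two targets are padding
    mask[0, 1, :] = 0   # entry with no valid target at all

    shifted = logit - (1 - mask) * 1e9
    exp = (shifted - shifted.max(2, keepdim=True)[0]).exp()
    prob = exp / exp.sum(2, keepdim=True)

    assert prob[0, 0, 2:].max().item() < 1e-6  # masked targets get ~0 prob
    assert torch.allclose(prob[0, 1], torch.full((4,), 0.25))
    return prob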
def sample(self, cont_probs, probs, prev_samples):
    """Categorical sampler that can persist the previous samples
    with some probability.

    Args:
        cont_probs: probabilities of keeping the previous samples [batch, 2]
        probs: probabilities returned by model forward [batch, num_actions]
        prev_samples: previously sampled actions [batch]

    return:
        cont_samples: [batch]
        samples: [batch]
    """
    assert_eq(cont_probs.size(1), 2)
    assert not torch.isnan(cont_probs).any()

    cont_samples = cont_probs.multinomial(1).squeeze(1)
    new_samples = probs.multinomial(1).squeeze(1)
    assert_eq(prev_samples.size(), new_samples.size())
    if -1 in prev_samples:
        # no valid previous sample to persist (e.g. first frame);
        # force a fresh sample and mark cont as 0
        samples = new_samples
        cont_samples = cont_samples * 0
    else:
        samples = cont_samples * prev_samples + (1 - cont_samples) * new_samples
    return {
        self.cont_key: cont_samples,
        self.key: samples,
    }
def _process_units(self, units, padded_size):
    types = []
    xs = []
    ys = []
    hps = []
    for u in units:
        types.append(u['unit_type'])
        # clamp negative (off-map) coordinates to 0
        xs.append(max(u['x'], 0))
        ys.append(max(u['y'], 0))
        hps.append(u['hp'])

    types = np.array(types, dtype=np.int64)
    xs = np.array(xs, dtype=np.int64)
    ys = np.array(ys, dtype=np.int64)
    hps = np.array(hps, dtype=np.int64)
    num_units = len(units)
    assert_eq(num_units, types.shape[0])

    types = self._pad(types, padded_size, 0)
    xs = self._pad(xs, padded_size, 0)
    ys = self._pad(ys, padded_size, 0)
    hps = self._pad(hps, padded_size, 0)
    units_info = {
        'types': types,
        'xs': xs,
        'ys': ys,
        'hps': hps,
        'num_units': num_units,
    }
    return units_info
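# A minimal sketch of the fixed-size padding this relies on; `pad_1d` is a
# hypothetical stand-in for the repo's self._pad helper.
def _demo_pad_units():
    import numpy as np

    def pad_1d(arr, padded_size, pad_value):
        padded = np.full((padded_size,), pad_value, dtype=arr.dtype)
        padded[:arr.shape[0]] = arr
        return padded

    xs = np.array([3, 7], dtype=np.int64)
    padded = pad_1d(xs, 5, 0)
    assert padded.tolist() == [3, 7, 0, 0, 0]  # real values first, zeros after
    return padded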
def format_executor_input(self, batch, inst, inst_len, inst_cont):
    """convert batch (in c++ format) to input used by executor

    inst: [batch, max_sentence_len], LongTensor, parsed
    inst_len: [batch]
    """
    assert_eq(inst.dim(), 2)
    assert_eq(inst_len.dim(), 1)

    my_units = {
        'types': batch['army_type'],
        'hps': batch['army_hp'],
        'xs': batch['army_x'],
        'ys': batch['army_y'],
        'num_units': batch['num_army'].squeeze(1),
    }
    enemy_units = {
        'types': batch['enemy_type'],
        'hps': batch['enemy_hp'],
        'xs': batch['enemy_x'],
        'ys': batch['enemy_y'],
        'num_units': batch['num_enemy'].squeeze(1),
    }
    resource_units = {
        'types': batch['resource_type'],
        'hps': batch['resource_hp'],
        'xs': batch['resource_x'],
        'ys': batch['resource_y'],
        'num_units': batch['num_resource'].squeeze(1),
    }
    current_cmds = {
        'cmd_type': batch['current_cmd_type'],
        'target_type': batch['current_cmd_unit_type'],
        'target_x': batch['current_cmd_x'],
        'target_y': batch['current_cmd_y'],
        'target_gather_idx': batch['current_cmd_gather_idx'],
        'target_attack_idx': batch['current_cmd_attack_idx'],
    }
    prev_cmds = batch['prev_cmd']

    hist_inst, hist_inst_len, hist_inst_diff = parse_hist_inst(
        self.inst_dict,
        batch['hist_inst'],
        batch['hist_inst_diff'],
        inst,
        inst_len,
        inst_cont,
        is_word_based(self.args.inst_encoder_type))

    data = {
        'inst': inst,
        'inst_len': inst_len,
        'hist_inst': hist_inst,
        'hist_inst_len': hist_inst_len,
        'hist_inst_diff': hist_inst_diff,
        'resource_bin': batch['resource_bin'],
        'my_units': my_units,
        'enemy_units': enemy_units,
        'resource_units': resource_units,
        'prev_cmds': prev_cmds,
        'current_cmds': current_cmds,
        'map': batch['map'],
    }
    return data
def format_rl_executor_input(self, batch, inst, inst_len, inst_cont):
    """convert batch (in c++ format) to input used by executor

    inst: [batch, max_sentence_len], LongTensor, parsed
    inst_len: [batch]
    """
    assert_eq(inst.dim(), 2)
    assert_eq(inst_len.dim(), 1)

    my_units = {
        "types": batch["army_type"],
        "hps": batch["army_hp"],
        "xs": batch["army_x"],
        "ys": batch["army_y"],
        "num_units": batch["num_army"].squeeze(1),
    }
    enemy_units = {
        "types": batch["enemy_type"],
        "hps": batch["enemy_hp"],
        "xs": batch["enemy_x"],
        "ys": batch["enemy_y"],
        "num_units": batch["num_enemy"].squeeze(1),
    }
    resource_units = {
        "types": batch["resource_type"],
        "hps": batch["resource_hp"],
        "xs": batch["resource_x"],
        "ys": batch["resource_y"],
        "num_units": batch["num_resource"].squeeze(1),
    }
    current_cmds = {
        "cmd_type": batch["current_cmd_type"],
        "target_type": batch["current_cmd_unit_type"],
        "target_x": batch["current_cmd_x"],
        "target_y": batch["current_cmd_y"],
        "target_gather_idx": batch["current_cmd_gather_idx"],
        "target_attack_idx": batch["current_cmd_attack_idx"],
    }
    prev_cmds = batch["prev_cmd"]

    hist_inst, hist_inst_len, hist_inst_diff = parse_hist_inst(
        self.inst_dict,
        batch["hist_inst"],
        batch["hist_inst_diff"],
        inst,
        inst_len,
        inst_cont,
        is_word_based(self.args.inst_encoder_type),
    )

    data = {
        "inst": inst,
        "inst_len": inst_len,
        "hist_inst": hist_inst,
        "hist_inst_len": hist_inst_len,
        "hist_inst_diff": hist_inst_diff,
        "resource_bin": batch["resource_bin"],
        "my_units": my_units,
        "enemy_units": enemy_units,
        "resource_units": resource_units,
        "prev_cmds": prev_cmds,
        "current_cmds": current_cmds,
        "map": batch["map"],
    }
    return data
def compute_rl_log_probs(self, batch):
    features = self.conv_encoder(batch)

    real_unit_mask, cmd_type_mask = create_loss_masks(
        batch["my_units"]["num_units"],
        batch["target_cmds"]["cmd_type"],
        self.num_cmd_type,
    )

    # global continue classifier
    # TODO: implement the global continue classifier
    army_inst = torch.cat([features["army_feat"], features["inst_feat"]], 2)

    # action-arg classifiers
    gather_loss = self.gather_cls.compute_loss(
        army_inst,
        features["resource_feat"],
        None,
        batch["resource_units"]["num_units"],
        batch["target_cmds"]["target_gather_idx"],
        cmd_type_mask[:, :, gc.CmdTypes.GATHER.value],
    )
    attack_loss = self.attack_cls.compute_loss(
        army_inst,
        features["enemy_feat"],
        None,
        batch["enemy_units"]["num_units"],
        batch["target_cmds"]["target_attack_idx"],
        cmd_type_mask[:, :, gc.CmdTypes.ATTACK.value],
    )
    build_building_loss = self.build_building_cls.compute_loss(
        army_inst,
        features["map_feat"],
        None,
        batch["target_cmds"]["target_type"],
        batch["target_cmds"]["target_x"],
        batch["target_cmds"]["target_y"],
        cmd_type_mask[:, :, gc.CmdTypes.BUILD_BUILDING.value],
    )
    build_unit_loss, nil_build_logp = self.build_unit_cls.compute_loss(
        army_inst,
        None,
        batch["target_cmds"]["target_type"],
        cmd_type_mask[:, :, gc.CmdTypes.BUILD_UNIT.value],
        include_nil=True,
    )
    move_loss = self.move_cls.compute_loss(
        army_inst,
        features["map_feat"],
        None,
        batch["target_cmds"]["target_x"],
        batch["target_cmds"]["target_y"],
        cmd_type_mask[:, :, gc.CmdTypes.MOVE.value],
    )

    # type loss
    ctype_context = torch.cat(
        [
            features["sum_army"],
            features["sum_enemy"],
            features["sum_resource"],
            features["money_feat"],
        ],
        1,
    )
    cmd_type_logp = self.cmd_type_cls.compute_prob(
        army_inst, ctype_context, logp=True
    )
    # cmd_type_logp: [batch, num_unit, num_cmd_type]

    # extra continue: a unit that "builds" the nil type is effectively continuing
    build_unit_type_logp = cmd_type_logp[:, :, gc.CmdTypes.BUILD_UNIT.value]
    extra_cont_logp = build_unit_type_logp + nil_build_logp
    # extra_cont_logp: [batch, num_unit]

    # the following hack only works if CONT is the last cmd type
    assert gc.CmdTypes.CONT.value == len(gc.CmdTypes) - 1
    assert (
        extra_cont_logp.size()
        == cmd_type_logp[:, :, gc.CmdTypes.CONT.value].size()
    )
    cont_logp = log_sum_exp(
        cmd_type_logp[:, :, gc.CmdTypes.CONT.value], extra_cont_logp
    )
    # cont_logp: [batch, num_unit]
    cmd_type_logp = torch.cat(
        [cmd_type_logp[:, :, : gc.CmdTypes.CONT.value], cont_logp.unsqueeze(2)], 2
    )
    # cmd_type_logp: [batch, num_unit, num_cmd_type]
    cmd_type_logp = cmd_type_logp.gather(
        2, batch["target_cmds"]["cmd_type"].unsqueeze(2)
    ).squeeze(2)
    # cmd_type_logp: [batch, num_unit]
    cmd_type_loss = -(cmd_type_logp * real_unit_mask).sum(1)

    # aggregate losses
    num_my_units_size = batch["my_units"]["num_units"].size()
    assert_eq(cmd_type_loss.size(), num_my_units_size)
    assert_eq(move_loss.size(), num_my_units_size)
    assert_eq(attack_loss.size(), num_my_units_size)
    assert_eq(gather_loss.size(), num_my_units_size)
    assert_eq(build_unit_loss.size(), num_my_units_size)
    assert_eq(build_building_loss.size(), num_my_units_size)
    unit_loss = (
        cmd_type_loss
        + move_loss
        + attack_loss
        + gather_loss
        + build_unit_loss
        + build_building_loss
    )

    log_prob = -unit_loss
    all_loss = {
        "unit_loss": unit_loss.detach(),
        "cmd_type_loss": cmd_type_loss.detach(),
        "move_loss": move_loss.detach(),
        "attack_loss": attack_loss.detach(),
        "gather_loss": gather_loss.detach(),
        "build_unit_loss": build_unit_loss.detach(),
        "build_building_loss": build_building_loss.detach(),
    }
    return log_prob, all_loss
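# The CONT hack above merges two log-probabilities of the same "continue" event;
# assuming log_sum_exp(a, b) == log(exp(a) + exp(b)), torch.logaddexp reproduces it
# (the demo name is hypothetical).
def _demo_log_sum_exp_merge():
    import torch

    logp_a = torch.log(torch.tensor([0.1, 0.2]))  # P(CONT) directly
    logp_b = torch.log(torch.tensor([0.3, 0.4]))  # P(BUILD_UNIT) * P(nil type)
    merged = torch.logaddexp(logp_a, logp_b)
    assert torch.allclose(merged.exp(), torch.tensor([0.4, 0.6]))
    return merged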
def compute_loss(self, batch, *, mean=True):
    features = self.conv_encoder(batch)

    real_unit_mask, cmd_type_mask = create_loss_masks(
        batch['my_units']['num_units'],
        batch['target_cmds']['cmd_type'],
        self.num_cmd_type)

    # global continue classifier
    glob_feat = torch.cat([
        features['sum_army'],
        features['sum_enemy'],
        features['sum_resource'],
        features['money_feat'],
        features['sum_inst']
    ], 1)
    glob_cont_loss = self.glob_cont_cls.compute_loss(
        glob_feat, batch['glob_cont'])

    army_inst = torch.cat([features['army_feat'], features['inst_feat']], 2)

    # action-arg classifiers
    gather_loss = self.gather_cls.compute_loss(
        army_inst,
        features['resource_feat'],
        None,
        batch['resource_units']['num_units'],
        batch['target_cmds']['target_gather_idx'],
        cmd_type_mask[:, :, gc.CmdTypes.GATHER.value])
    attack_loss = self.attack_cls.compute_loss(
        army_inst,
        features['enemy_feat'],
        None,
        batch['enemy_units']['num_units'],
        batch['target_cmds']['target_attack_idx'],
        cmd_type_mask[:, :, gc.CmdTypes.ATTACK.value])
    build_building_loss = self.build_building_cls.compute_loss(
        army_inst,
        features['map_feat'],
        None,
        batch['target_cmds']['target_type'],
        batch['target_cmds']['target_x'],
        batch['target_cmds']['target_y'],
        cmd_type_mask[:, :, gc.CmdTypes.BUILD_BUILDING.value])
    build_unit_loss, nil_build_logp = self.build_unit_cls.compute_loss(
        army_inst,
        None,
        batch['target_cmds']['target_type'],
        cmd_type_mask[:, :, gc.CmdTypes.BUILD_UNIT.value],
        include_nil=True)
    move_loss = self.move_cls.compute_loss(
        army_inst,
        features['map_feat'],
        None,
        batch['target_cmds']['target_x'],
        batch['target_cmds']['target_y'],
        cmd_type_mask[:, :, gc.CmdTypes.MOVE.value])

    # type loss
    ctype_context = torch.cat([
        features['sum_army'],
        features['sum_enemy'],
        features['sum_resource'],
        features['money_feat']
    ], 1)
    cmd_type_logp = self.cmd_type_cls.compute_prob(
        army_inst, ctype_context, logp=True)
    # cmd_type_logp: [batch, num_unit, num_cmd_type]

    # extra continue: a unit that "builds" the nil type is effectively continuing
    build_unit_type_logp = cmd_type_logp[:, :, gc.CmdTypes.BUILD_UNIT.value]
    extra_cont_logp = build_unit_type_logp + nil_build_logp
    # extra_cont_logp: [batch, num_unit]

    # the following hack only works if CONT is the last cmd type
    assert gc.CmdTypes.CONT.value == len(gc.CmdTypes) - 1
    assert extra_cont_logp.size() == \
        cmd_type_logp[:, :, gc.CmdTypes.CONT.value].size()
    cont_logp = log_sum_exp(
        cmd_type_logp[:, :, gc.CmdTypes.CONT.value], extra_cont_logp)
    # cont_logp: [batch, num_unit]
    cmd_type_logp = torch.cat([
        cmd_type_logp[:, :, :gc.CmdTypes.CONT.value], cont_logp.unsqueeze(2)
    ], 2)
    # cmd_type_logp: [batch, num_unit, num_cmd_type]
    cmd_type_logp = cmd_type_logp.gather(
        2, batch['target_cmds']['cmd_type'].unsqueeze(2)).squeeze(2)
    # cmd_type_logp: [batch, num_unit]
    cmd_type_loss = -(cmd_type_logp * real_unit_mask).sum(1)

    # aggregate losses
    num_my_units_size = batch['my_units']['num_units'].size()
    assert_eq(glob_cont_loss.size(), num_my_units_size)
    assert_eq(cmd_type_loss.size(), num_my_units_size)
    assert_eq(move_loss.size(), num_my_units_size)
    assert_eq(attack_loss.size(), num_my_units_size)
    assert_eq(gather_loss.size(), num_my_units_size)
    assert_eq(build_unit_loss.size(), num_my_units_size)
    assert_eq(build_building_loss.size(), num_my_units_size)
    unit_loss = (cmd_type_loss + move_loss + attack_loss
                 + gather_loss + build_unit_loss + build_building_loss)
    # unit-level losses only count for frames that are not global continues
    unit_loss = (1 - batch['glob_cont'].float()) * unit_loss
    loss = glob_cont_loss + unit_loss
    all_loss = {
        'loss': loss.detach(),
        'unit_loss': unit_loss.detach(),
        'cmd_type_loss': cmd_type_loss.detach(),
        'move_loss': move_loss.detach(),
        'attack_loss': attack_loss.detach(),
        'gather_loss': gather_loss.detach(),
        'build_unit_loss': build_unit_loss.detach(),
        'build_building_loss': build_building_loss.detach(),
        'glob_cont_loss': glob_cont_loss.detach()
    }

    if mean:
        for k, v in all_loss.items():
            all_loss[k] = v.mean()
        loss = loss.mean()
    return loss, all_loss
def forward(self, num_real_unit, cmd_type, target_type, x, y,
            target_attack_idx, target_gather_idx, enemy_feat, resource_feat):
    """take cmd info, produce cmd_repr

    num_real_unit: [batch]
    cmd_type: [batch, num_padded_unit], 0 <= cmd_type < num_cmd_type;
      padding cmds should still carry valid types so embedding lookups stay in range
    target_type: [batch, num_padded_unit]
    x: [batch, num_padded_unit]
    y: [batch, num_padded_unit]
    target_attack_idx: [batch, num_padded_unit]
    target_gather_idx: [batch, num_padded_unit]
    enemy_feat: [batch, num_padded_enemy, target_feat_dim]
    resource_feat: [batch, num_padded_resource, target_feat_dim]

    return:
        cmd_feat: [cmd_type_emb; target_type_emb; xy_feat; target_feat]
    """
    batch, pnum_unit = cmd_type.size()
    assert_eq(cmd_type.size(), target_type.size())
    assert_eq(x.dim(), 2)
    assert_eq(y.dim(), 2)

    cmd_type_emb = self.cmd_type_emb(cmd_type)
    target_type_emb = self.target_type_emb(target_type)
    xy_feat = self.xy_net(torch.stack([x, y], 2))
    enemy_feat = self.enemy_feat_net(enemy_feat)
    resource_feat = self.resource_feat_net(resource_feat)

    # select the targeted enemy's / resource's feature
    target_enemy_feat = self.select_and_slice_target(
        enemy_feat, target_attack_idx, self.attribute_dim)
    target_resource_feat = self.select_and_slice_target(
        resource_feat, target_gather_idx, self.attribute_dim)

    cmd_mask = torch.zeros(
        (batch, pnum_unit, self.num_cmd_type), device=cmd_type.device)
    cmd_mask.scatter_(2, cmd_type.unsqueeze(2), 1)
    # cmd_mask [batch, pnum_unit, num_cmd_type]

    outfeat = torch.zeros(
        (batch, pnum_unit, self.out_dim), device=cmd_type.device)
    outfeat[:, :, 0:self.cmd_type_emb_end] = cmd_type_emb

    target_type_mask = (cmd_mask[:, :, gc.CmdTypes.BUILD_BUILDING.value]
                        + cmd_mask[:, :, gc.CmdTypes.BUILD_UNIT.value])
    outfeat[:, :, self.cmd_type_emb_end:self.target_type_emb_end] = (
        target_type_mask.unsqueeze(2) * target_type_emb)

    xy_feat_mask = (cmd_mask[:, :, gc.CmdTypes.BUILD_BUILDING.value]
                    + cmd_mask[:, :, gc.CmdTypes.MOVE.value])
    outfeat[:, :, self.target_type_emb_end:self.xy_feat_end] = (
        xy_feat_mask.unsqueeze(2) * xy_feat)

    enemy_feat_mask = cmd_mask[:, :, gc.CmdTypes.ATTACK.value]
    resource_feat_mask = cmd_mask[:, :, gc.CmdTypes.GATHER.value]
    outfeat[:, :, self.xy_feat_end:self.target_feat_end] = (
        enemy_feat_mask.unsqueeze(2) * target_enemy_feat
        + resource_feat_mask.unsqueeze(2) * target_resource_feat)
    return outfeat
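# A runnable sketch of the scatter_-based one-hot cmd mask used above
# (the demo name is hypothetical).
def _demo_cmd_mask():
    import torch

    batch, pnum_unit, num_cmd_type = 2, 3, 5
    cmd_type = torch.randint(0, num_cmd_type, (batch, pnum_unit))
    cmd_mask = torch.zeros(batch, pnum_unit, num_cmd_type)
    cmd_mask.scatter_(2, cmd_type.unsqueeze(2), 1)
    # each cmd_mask[i, j] is one-hot at unit (i, j)'s cmd type
    assert torch.equal(cmd_mask.sum(2), torch.ones(batch, pnum_unit))
    return cmd_mask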