Example #1
0
    def __init__(self,
                 args,
                 max_raw_chars,
                 max_instruction_span,
                 coach_mode,
                 num_resource_bin,
                 *,
                 num_unit_type=len(gc.UnitTypes),
                 num_cmd_type=len(gc.CmdTypes)):
        """Build the coach's encoders and output heads.

        Args:
            args: parsed argument namespace holding the hyper-parameters.
                NOTE: ``args.inst_dict_path`` is rewritten in place.
            max_raw_chars: stored as ``self.max_raw_chars``.
            max_instruction_span: the frame-passed embedding gets
                ``max_instruction_span + 2`` rows.
            coach_mode: ``'rnn'`` or ``'bow'`` build an ``RnnSelector`` over
                the matching instruction encoder; ``'onehot'`` and
                ``'rnn_gen'`` leave ``self.inst_selector`` as ``None``.
            num_resource_bin: forwarded to ``ConvGlobEncoder``.
            num_unit_type: input width of the cons-count and moving-average
                encoders; also forwarded to ``ConvGlobEncoder``.
            num_cmd_type: forwarded to ``ConvGlobEncoder``.

        Raises:
            ValueError: if ``coach_mode`` is not one of the modes above.
        """
        super().__init__()

        # Constructor kwargs; presumably pickled so the model can be rebuilt
        # with ``cls(**params)`` on load — confirm against save/load.
        self.params = {
            'args': args,
            'max_raw_chars': max_raw_chars,
            'max_instruction_span': max_instruction_span,
            'coach_mode': coach_mode,
            'num_resource_bin': num_resource_bin,
            'num_unit_type': num_unit_type,
            'num_cmd_type': num_cmd_type,
        }

        self.args = args
        self.max_raw_chars = max_raw_chars
        self.max_instruction_span = max_instruction_span
        self.coach_mode = coach_mode

        # Lazily-built, device-keyed caches of candidate instruction tensors.
        self.pos_candidate_inst = None
        self.neg_candidate_inst = None

        # NOTE(review): hard-coded path relocation that mutates the caller's
        # ``args``; presumably for configs saved on another machine — confirm.
        self.args.inst_dict_path = self.args.inst_dict_path.replace(
            'scratch/rts_data', 'rts-replays')
        self.inst_dict = self.load_inst_dict(self.args.inst_dict_path)

        self.prev_inst_encoder = LSTMInstructionEncoder(
            self.inst_dict.total_vocab_size,
            self.args.word_emb_dim,
            self.args.word_emb_dropout,
            self.args.inst_hid_dim,
            self.inst_dict.pad_word_idx,
        )

        self.glob_encoder = ConvGlobEncoder(args, num_unit_type, num_cmd_type,
                                            num_resource_bin,
                                            self.prev_inst_encoder)

        # count encoders
        half_dim = self.args.count_hid_dim // 2
        self.count_encoder = MlpEncoder(self.args.num_count_channels * 2,
                                        self.args.count_hid_dim,
                                        self.args.count_hid_dim,
                                        self.args.count_hid_layers - 1,
                                        activate_out=True)
        self.cons_count_encoder = MlpEncoder(num_unit_type,
                                             half_dim,
                                             half_dim,
                                             self.args.count_hid_layers - 1,
                                             activate_out=True)
        self.moving_avg_encoder = MlpEncoder(num_unit_type,
                                             half_dim,
                                             half_dim,
                                             self.args.count_hid_layers - 1,
                                             activate_out=True)
        # Rows cover frame_passed values 0 .. max_instruction_span + 1.
        self.frame_passed_encoder = nn.Embedding(
            max_instruction_span + 2,
            half_dim,
        )

        if self.args.glob_dropout > 0:
            self.glob_dropout = nn.Dropout(self.args.glob_dropout)

        # Width of the global feature: the glob encoder output plus one
        # full-width count feature and three half-width features (cons count,
        # moving average, frame passed). Summing integer widths is exact even
        # when count_hid_dim is odd; int(2.5 * count_hid_dim) over-counted
        # by one in that case.
        self.glob_feat_dim = int(self.args.count_hid_dim + 3 * half_dim +
                                 self.glob_encoder.glob_dim)
        self.cont_cls = GlobClsHead(
            self.glob_feat_dim,
            self.args.inst_hid_dim,  # for reducing hyper-parameter
            2)

        if self.coach_mode == 'rnn':
            encoder = self.prev_inst_encoder
        elif self.coach_mode == 'bow':
            encoder = MeanBOWInstructionEncoder(
                self.inst_dict.total_vocab_size, self.args.inst_hid_dim,
                self.args.word_emb_dropout, self.inst_dict.pad_word_idx)
        elif self.coach_mode in ('onehot', 'rnn_gen'):
            encoder = None
        else:
            # Explicit raise: an ``assert`` would be stripped under -O.
            raise ValueError('unknown coach mode: %s' % self.coach_mode)

        # Only the 'rnn' and 'bow' modes select among candidate instructions.
        if encoder is not None:
            self.inst_selector = RnnSelector(encoder, self.glob_feat_dim)
        else:
            self.inst_selector = None

        self.value = nn.utils.weight_norm(nn.Linear(self.glob_feat_dim, 1),
                                          dim=None)
        self.sampler = ContSoftmaxSampler('cont', 'cont_pi', 'inst', 'inst_pi')
Example #2
0
class ConvRnnCoach(nn.Module):
    """Coach that encodes a game state and selects (continue, instruction).

    ``rl_forward`` returns a two-way ``cont_pi`` distribution, a distribution
    ``inst_pi`` over the first ``args.num_pos_inst`` dictionary instructions,
    and a value estimate ``v``. ``compute_loss`` / ``compute_eval_loss`` are
    the supervised pre-training objectives.
    """

    # Candidate indices kept by ``rl_forward(batch, "mask")``: a hard-coded
    # list from an offline analysis of winning trajectories. Every index must
    # be smaller than the number of candidate instructions.
    _WINNING_TRAJ_MASK = (
        449, 265, 330, 323, 38, 268, 198, 207, 188, 336,
        358, 406, 276, 196, 202, 33, 169, 480, 383, 267,
        365, 226, 105, 474, 79, 439, 341, 315, 143, 414,
        331, 447, 239, 91, 282, 215, 458, 446, 108, 98,
        351, 31, 218, 160, 34, 296, 145, 440, 374, 132,
        468, 83, 376, 254, 94, 61, 71, 256, 277, 298,
        495, 200, 322, 388, 8, 423, 355, 378, 450, 387,
        41, 154, 392, 308, 479, 11, 216, 396, 62, 287,
        193, 137, 208, 32, 261, 171, 163, 109, 338, 476,
        238, 107, 190, 53, 135, 303, 249, 44, 217, 490,
    )

    @staticmethod
    def get_arg_parser():
        """Return ConvGlobEncoder's argument parser extended with coach options."""
        parser = ConvGlobEncoder.get_arg_parser()

        # data related
        parser.add_argument("--inst_dict_path",
                            type=str,
                            required=True,
                            help="path to dictionary")
        parser.add_argument("--max_sentence_length", type=int, default=15)
        parser.add_argument("--num_pos_inst", type=int, default=50)
        parser.add_argument("--num_neg_inst", type=int, default=50)

        # prev_inst encoder
        parser.add_argument("--word_emb_dim", type=int, default=32)
        parser.add_argument("--word_emb_dropout", type=float, default=0.0)
        parser.add_argument("--inst_hid_dim", type=int, default=128)

        # count feat encoder
        parser.add_argument(
            "--num_count_channels",
            type=int,
            default=CoachDataset.get_num_count_channels(),
        )
        parser.add_argument("--count_hid_dim", type=int, default=128)
        parser.add_argument("--count_hid_layers", type=int, default=2)

        parser.add_argument("--glob_dropout", type=float, default=0.0)

        return parser

    def save(self, model_file):
        """Write weights to ``model_file`` and ctor kwargs to ``model_file + ".params"``."""
        torch.save(self.state_dict(), model_file)
        # ``with`` guarantees the params file is flushed and closed.
        with open(model_file + ".params", "wb") as f:
            pickle.dump(self.params, f)

    @classmethod
    def load(cls, model_file):
        """Rebuild a model saved with ``save`` and load its weights strictly."""
        # NOTE: unpickling executes arbitrary code; only load trusted files.
        with open(model_file + ".params", "rb") as f:
            params = pickle.load(f)
        print(params)
        model = cls(**params)
        model.load_state_dict(torch.load(model_file))
        return model

    @classmethod
    def rl_load(
        cls,
        model_file,
        coach_rule_emb_size=0,
        inst_dict_path=None,
        coach_random_init=False,
    ):
        """Rebuild a saved coach for RL use.

        Args:
            model_file: checkpoint path; ``model_file + ".params"`` must exist.
            coach_rule_emb_size: when > 0, build the rule embedding/LSTM and
                load the checkpoint non-strictly (presumably because older
                checkpoints lack those weights — confirm).
            inst_dict_path: optional override of the saved dictionary path.
            coach_random_init: skip loading weights and keep the fresh init.

        All ``*dropout*`` entries of the saved args are forced to 0.0.
        """
        print(f"Using coach model path: {model_file}")
        with open(model_file + ".params", "rb") as f:
            params = pickle.load(f)
        arg_dict = params["args"].__dict__
        if inst_dict_path:
            print("Using instructor dictionary path...")
            arg_dict["inst_dict_path"] = inst_dict_path

        ## Setting all dropouts to 0.0
        # Reassigning existing keys during iteration is safe (size unchanged).
        for key in arg_dict:
            if "dropout" in key:
                arg_dict[key] = 0.0

        params["coach_rule_emb_size"] = coach_rule_emb_size
        model = cls(**params)
        model_dict = torch.load(model_file)

        strict = coach_rule_emb_size <= 0
        if not strict:
            print("Coach rule Embedding Size has been set.")

        if not coach_random_init:
            model.load_state_dict(model_dict, strict=strict)
        else:
            print("Randomly initializing coach.")

        return model

    def load_inst_dict(self, inst_dict_path):
        """Unpickle the instruction dictionary; return None for an empty path."""
        print("loading cmd dict from: ", inst_dict_path)
        if not inst_dict_path:
            return None

        with open(inst_dict_path, "rb") as f:
            inst_dict = pickle.load(f)
        inst_dict.set_max_sentence_length(self.args.max_sentence_length)
        return inst_dict

    def __init__(
            self,
            args,
            max_raw_chars,
            max_instruction_span,
            coach_mode,
            num_resource_bin,
            *,
            num_unit_type=len(gc.UnitTypes),
            num_cmd_type=len(gc.CmdTypes),
            coach_rule_emb_size=0,
    ):
        """Build the coach's encoders and output heads.

        Args:
            args: parsed argument namespace (see ``get_arg_parser``).
                NOTE: ``args.inst_dict_path`` is rewritten in place.
            max_raw_chars: max raw-instruction length used by ``sample``.
            max_instruction_span: ``frame_passed`` is clamped to
                ``max_instruction_span + 1`` in ``format_coach_input``.
            coach_mode: ``"rnn"`` / ``"bow"`` build an ``RnnSelector``;
                ``"onehot"`` / ``"rnn_gen"`` leave ``inst_selector`` None.
            num_resource_bin: forwarded to ``ConvGlobEncoder``.
            num_unit_type: input width of the cons-count/moving-avg encoders.
            num_cmd_type: forwarded to ``ConvGlobEncoder``.
            coach_rule_emb_size: when > 0, adds a rule embedding + LSTM whose
                last output is added to the global feature in ``_forward``.

        Raises:
            ValueError: for an unknown ``coach_mode``.
        """
        super().__init__()

        # Constructor kwargs pickled by ``save`` and replayed by ``load``.
        # coach_rule_emb_size is included so a save/load round trip rebuilds
        # the rule_emb/rule_lstm modules the state dict contains.
        self.params = {
            "args": args,
            "max_raw_chars": max_raw_chars,
            "max_instruction_span": max_instruction_span,
            "coach_mode": coach_mode,
            "num_resource_bin": num_resource_bin,
            "num_unit_type": num_unit_type,
            "num_cmd_type": num_cmd_type,
            "coach_rule_emb_size": coach_rule_emb_size,
        }

        self.num_unit_types = NUM_UNIT_TYPES
        self.coach_rule_emb_size = coach_rule_emb_size

        self.args = args
        self.max_raw_chars = max_raw_chars
        self.max_instruction_span = max_instruction_span
        self.coach_mode = coach_mode

        # Device-keyed caches of parsed candidate instructions.
        self.pos_candidate_inst = None
        self.neg_candidate_inst = None

        # NOTE(review): hard-coded path relocation that mutates ``args``.
        self.args.inst_dict_path = self.args.inst_dict_path.replace(
            "scratch/rts_data", "rts-replays")
        self.inst_dict = self.load_inst_dict(self.args.inst_dict_path)

        self.prev_inst_encoder = LSTMInstructionEncoder(
            self.inst_dict.total_vocab_size,
            self.args.word_emb_dim,
            self.args.word_emb_dropout,
            self.args.inst_hid_dim,
            self.inst_dict.pad_word_idx,
        )

        self.glob_encoder = ConvGlobEncoder(args, num_unit_type, num_cmd_type,
                                            num_resource_bin,
                                            self.prev_inst_encoder)

        # count encoders
        half_dim = self.args.count_hid_dim // 2
        self.count_encoder = MlpEncoder(
            self.args.num_count_channels * 2,
            self.args.count_hid_dim,
            self.args.count_hid_dim,
            self.args.count_hid_layers - 1,
            activate_out=True,
        )
        self.cons_count_encoder = MlpEncoder(
            num_unit_type,
            half_dim,
            half_dim,
            self.args.count_hid_layers - 1,
            activate_out=True,
        )
        self.moving_avg_encoder = MlpEncoder(
            num_unit_type,
            half_dim,
            half_dim,
            self.args.count_hid_layers - 1,
            activate_out=True,
        )
        # Rows cover frame_passed values 0 .. max_instruction_span + 1.
        self.frame_passed_encoder = nn.Embedding(
            max_instruction_span + 2,
            half_dim,
        )

        if self.args.glob_dropout > 0:
            self.glob_dropout = nn.Dropout(self.args.glob_dropout)

        # glob_feat (see ``_forward``) = glob encoder output + count_feat
        # (count_hid_dim) + three half-width features. Summing the integer
        # widths stays exact for odd count_hid_dim, where the former
        # int(2.5 * count_hid_dim) over-counted by one.
        self.glob_feat_dim = int(self.args.count_hid_dim + 3 * half_dim +
                                 self.glob_encoder.glob_dim)
        self.cont_cls = GlobClsHead(
            self.glob_feat_dim,
            self.args.inst_hid_dim,  # for reducing hyper-parameter
            2,
        )

        if self.coach_rule_emb_size > 0:
            # TODO: Remove num units hardcoding
            self.rule_emb = nn.Embedding(NUM_UNIT_TYPES,
                                         self.coach_rule_emb_size)
            # Hidden size equals glob_feat_dim so the last output can be
            # added to glob_feat in ``_forward``.
            self.rule_lstm = torch.nn.LSTM(self.coach_rule_emb_size,
                                           self.glob_feat_dim,
                                           batch_first=True)

        if self.coach_mode == "rnn":
            encoder = self.prev_inst_encoder
        elif self.coach_mode == "bow":
            encoder = MeanBOWInstructionEncoder(
                self.inst_dict.total_vocab_size,
                self.args.inst_hid_dim,
                self.args.word_emb_dropout,
                self.inst_dict.pad_word_idx,
            )
        elif self.coach_mode in ("onehot", "rnn_gen"):
            encoder = None
        else:
            # Explicit raise: an ``assert`` would be stripped under -O.
            raise ValueError("unknown coach mode: %s" % self.coach_mode)

        if encoder is not None:
            self.inst_selector = RnnSelector(encoder, self.glob_feat_dim)
        else:
            self.inst_selector = None

        self.value = nn.utils.weight_norm(nn.Linear(self.glob_feat_dim, 1),
                                          dim=None)
        self.sampler = ContSoftmaxSampler("cont", "cont_pi", "inst", "inst_pi")

    @property
    def num_instructions(self):
        """Number of positive candidate instructions (``args.num_pos_inst``)."""
        return self.args.num_pos_inst

    def _forward(self, batch):
        """shared forward function to compute glob feature"""
        count_input = torch.cat(
            [batch["count"], batch["base_count"] - batch["count"]], 1)
        count_feat = self.count_encoder(count_input)
        cons_count_feat = self.cons_count_encoder(batch["cons_count"])
        moving_avg_feat = self.moving_avg_encoder(batch["moving_enemy_count"])
        frame_passed_feat = self.frame_passed_encoder(batch["frame_passed"])
        features = self.glob_encoder(batch, use_prev_inst=True)

        glob = torch.cat(
            [
                features["sum_inst"],
                features["sum_army"],
                features["sum_enemy"],
                features["sum_resource"],
                features["money_feat"],
            ],
            dim=1,
        )

        glob_feat = torch.cat(
            [
                glob,
                count_feat,
                cons_count_feat,
                moving_avg_feat,
                frame_passed_feat,
            ],
            dim=1,
        )

        if "rule_tensor" in batch:
            assert self.coach_rule_emb_size > 0
            rule_emb = self.rule_emb(batch["rule_tensor"])
            # batch_first LSTM: output[:, -1] is the last step's output.
            output, _ = self.rule_lstm(rule_emb)
            glob_feat = glob_feat + output[:, -1]

        if self.args.glob_dropout > 0:
            glob_feat = self.glob_dropout(glob_feat)

        return glob_feat

    @staticmethod
    def _combine_losses(cont, cont_loss, lang_loss):
        """Mask the language loss to frames where ``cont == 0`` and sum.

        Returns ``(loss, all_loss)`` where ``loss`` is the batch mean and
        ``all_loss`` also reports the mean of each component.
        """
        assert_eq(cont_loss.size(), lang_loss.size())
        # lang_loss only counts where cont == 0 (is_base_frame set).
        lang_loss = (1 - cont.float()) * lang_loss
        loss = cont_loss + lang_loss
        loss = loss.mean()
        all_loss = {
            "loss": loss,
            "cont_loss": cont_loss.mean(),
            "lang_loss": lang_loss.mean(),
        }
        return loss, all_loss

    def compute_loss(self, batch):
        """used for pre-training the model with dataset"""
        batch = self._format_supervised_language_input(batch)
        glob_feat = self._forward(batch)

        cont = 1 - batch["is_base_frame"]
        cont_loss = self.cont_cls.compute_loss(glob_feat, cont)
        lang_loss = self.inst_selector.compute_loss(
            batch["pos_cand_inst"],
            batch["pos_cand_inst_len"],
            batch["neg_cand_inst"],
            batch["neg_cand_inst_len"],
            batch["inst"],
            batch["inst_len"],
            glob_feat,
            batch["inst_idx"],
        )
        return self._combine_losses(cont, cont_loss, lang_loss)

    def compute_eval_loss(self, batch):
        """Evaluation counterpart of ``compute_loss`` (positive candidates only)."""
        batch = self._format_supervised_language_input(batch)
        glob_feat = self._forward(batch)

        cont = 1 - batch["is_base_frame"]
        cont_loss = self.cont_cls.compute_loss(glob_feat, cont)
        lang_loss = self.inst_selector.eval_loss(
            batch["pos_cand_inst"],
            batch["pos_cand_inst_len"],
            glob_feat,
            batch["inst_idx"],
        )
        return self._combine_losses(cont, cont_loss, lang_loss)

    def _format_supervised_language_input(self, batch):
        """Add positive/negative candidate instruction tensors to ``batch``."""
        device = batch["prev_inst"].device
        pos_inst, pos_inst_len = self._get_pos_candidate_inst(device)
        neg_inst, neg_inst_len = self._get_neg_candidate_inst(
            device, batch["inst_idx"])
        batch["pos_cand_inst"] = pos_inst
        batch["pos_cand_inst_len"] = pos_inst_len
        batch["neg_cand_inst"] = neg_inst
        batch["neg_cand_inst_len"] = neg_inst_len
        return batch

    def _get_pos_candidate_inst(self, device):
        """Return parsed instructions ``0 .. num_pos_inst - 1``, cached per device."""
        if (self.pos_candidate_inst is not None
                and self.pos_candidate_inst[0].device == device):
            inst, inst_len = self.pos_candidate_inst
        else:
            inst, inst_len = parse_batch_inst(self.inst_dict,
                                              range(self.args.num_pos_inst),
                                              device)
            self.pos_candidate_inst = (inst, inst_len)

        return inst, inst_len

    def _get_neg_candidate_inst(self, device, exclude_idx):
        """Sample ``num_neg_inst`` negatives from the non-positive instructions.

        Negatives whose dictionary index appears in ``exclude_idx`` (other
        than the unknown-instruction index) get zero probability.
        """
        if (self.neg_candidate_inst is not None
                and self.neg_candidate_inst[0].device == device):
            inst, inst_len = self.neg_candidate_inst
        else:
            inst, inst_len = parse_batch_inst(
                self.inst_dict,
                range(self.args.num_pos_inst, self.inst_dict.num_insts),
                device,
            )
            self.neg_candidate_inst = (inst, inst_len)

        # inst: [num_candidate, max_sentence_len]
        num_candidate = inst.size(0)
        prob = np.ones((num_candidate, ), dtype=np.float32)

        for idx in exclude_idx:
            if idx == self.inst_dict.unknown_inst_idx:
                continue
            # Shift into negative-pool coordinates; positives become < 0.
            idx = idx - self.args.num_pos_inst
            if idx >= 0:
                prob[idx] = 0
        prob = prob / prob.sum()

        select = np.random.choice(num_candidate,
                                  self.args.num_neg_inst,
                                  replace=False,
                                  p=prob)
        select = torch.LongTensor(select).to(device)
        # select: [num_inst,]
        inst_len = inst_len.gather(0, select)
        select = select.unsqueeze(1).repeat(1, inst.size(1))
        inst = inst.gather(0, select)
        return inst, inst_len

    # ============ RL related ============
    def format_coach_input(self, batch, prefix=""):
        """Select and rename the game-state tensors ``_forward`` consumes."""
        frame_passed = batch["frame_passed"].squeeze(1)
        # Keep within the frame_passed embedding's max_instruction_span + 2 rows.
        frame_passed = frame_passed.clamp(max=self.max_instruction_span + 1)
        data = {
            "prev_inst_idx": batch["prev_inst"].squeeze(1),
            "frame_passed": frame_passed,
            "count": batch[prefix + "count"],
            "base_count": batch[prefix + "base_count"],
            "cons_count": batch[prefix + "cons_count"],
            "moving_enemy_count": batch[prefix + "moving_enemy_count"],
        }

        if self.coach_rule_emb_size > 0:
            assert "rule_tensor" in batch
            data["rule_tensor"] = batch["rule_tensor"]

        extra_data = self.glob_encoder.format_input(batch, prefix)
        data.update(extra_data)
        return data

    def _format_rl_language_input(self, batch):
        """Parse ``prev_inst_idx`` into tokens and attach positive candidates."""
        prev_inst, prev_inst_len = parse_batch_inst(
            self.inst_dict,
            batch["prev_inst_idx"].cpu().numpy(),
            batch["prev_inst_idx"].device,
        )
        batch["prev_inst"] = prev_inst
        batch["prev_inst_len"] = prev_inst_len

        inst, inst_len = self._get_pos_candidate_inst(prev_inst.device)
        batch["cand_inst"] = inst
        batch["cand_inst_len"] = inst_len
        return batch

    def rl_forward(self, batch, mode):
        """forward function use by RL

        Args:
            batch: dict produced by ``format_coach_input``.
            mode: ``"full"`` uses every candidate; ``"mask"`` restricts
                ``inst_pi`` to ``_WINNING_TRAJ_MASK`` and renormalises.

        Raises:
            ValueError: for any other ``mode``.
        """
        batch = self._format_rl_language_input(batch)
        glob_feat = self._forward(batch)
        v = self.value(glob_feat).squeeze()

        cont_prob = self.cont_cls.compute_prob(glob_feat)
        inst_prob = self.inst_selector.compute_prob(batch["cand_inst"],
                                                    batch["cand_inst_len"],
                                                    glob_feat)

        # ``==`` rather than ``is``: identity of equal strings is not
        # guaranteed, so ``mode is "mask"`` could silently be False.
        if mode == "mask":
            # One mask entry per candidate column (formerly fixed at 500).
            mask = torch.zeros(inst_prob.size(1), device=inst_prob.device)
            mask[list(self._WINNING_TRAJ_MASK)] = 1
            inst_prob = inst_prob * mask.unsqueeze(0)
            inst_prob = inst_prob / inst_prob.sum(1, keepdim=True)
        elif mode != "full":
            raise ValueError(f"unknown rl_forward mode: {mode!r}")

        output = {"cont_pi": cont_prob, "inst_pi": inst_prob, "v": v}
        return output

    def sample(self, batch, mode, word_based=True):
        """used for actor in ELF and visually evaluating model

        return
            inst: [batch, max_sentence_len], even inst is one-hot
            inst_len: [batch]
        """
        output = self.rl_forward(batch, mode)
        samples = self.sampler.sample(output["cont_pi"], output["inst_pi"],
                                      batch["prev_inst_idx"])

        log_prob_reply = {
            "samples": samples,
            "probs": {
                self.sampler.cont_prob_key: output["cont_pi"],
                self.sampler.prob_key: output["inst_pi"],
            },
            "value": output["v"],
        }

        reply = {
            "cont": samples["cont"].unsqueeze(1),
            "cont_pi": output["cont_pi"],
            "inst": samples["inst"].unsqueeze(1),
            "inst_pi": output["inst_pi"],
        }

        # convert format needed by executor
        token_lists = []
        lengths = []
        raws = []
        for idx in reply["inst"]:
            inst = self.inst_dict.get_inst(int(idx.item()))
            tokens, length = self.inst_dict.parse(inst, True)
            token_lists.append(tokens)
            lengths.append(length)
            raws.append(convert_to_raw_instruction(inst, self.max_raw_chars))

        device = reply["cont"].device
        if word_based:
            # for word based
            inst = torch.LongTensor(token_lists).to(device)
        else:
            inst = reply["inst"]

        inst_len = torch.LongTensor(lengths).to(device)
        reply["raw_inst"] = torch.LongTensor(raws).to(device)

        return inst, inst_len, reply["cont"], reply, log_prob_reply
Example #3
0
class ConvRnnCoach(nn.Module):
    @staticmethod
    def get_arg_parser():
        """Extend ConvGlobEncoder's parser with the coach-specific options."""
        parser = ConvGlobEncoder.get_arg_parser()
        add = parser.add_argument

        # data related
        add('--inst_dict_path', type=str, required=True,
            help='path to dictionary')
        for flag, default in (('--max_sentence_length', 15),
                              ('--num_pos_inst', 50),
                              ('--num_neg_inst', 50)):
            add(flag, type=int, default=default)

        # prev_inst encoder
        add('--word_emb_dim', type=int, default=32)
        add('--word_emb_dropout', type=float, default=0.0)
        add('--inst_hid_dim', type=int, default=128)

        # count feat encoder
        add('--num_count_channels', type=int,
            default=CoachDataset.get_num_count_channels())
        add('--count_hid_dim', type=int, default=128)
        add('--count_hid_layers', type=int, default=2)

        add('--glob_dropout', type=float, default=0.0)
        return parser

    def save(self, model_file):
        """Write weights to ``model_file`` and ctor kwargs to ``model_file + '.params'``."""
        torch.save(self.state_dict(), model_file)
        # ``with`` guarantees the params file is flushed and closed.
        with open(model_file + '.params', 'wb') as f:
            pickle.dump(self.params, f)

    @classmethod
    def load(cls, model_file):
        """Rebuild a model written by ``save`` and load its weights strictly."""
        # NOTE: unpickling executes arbitrary code; only load trusted files.
        with open(model_file + '.params', 'rb') as f:
            params = pickle.load(f)
        print(params)
        model = cls(**params)
        model.load_state_dict(torch.load(model_file))
        return model

    def load_inst_dict(self, inst_dict_path):
        """Unpickle the instruction dictionary at ``inst_dict_path``.

        Returns None when the path is None or empty. Otherwise sets the
        dictionary's max sentence length from ``args.max_sentence_length``.
        """
        print('loading cmd dict from: ', inst_dict_path)
        if not inst_dict_path:
            return None

        # ``with`` closes the handle; pickle on trusted files only.
        with open(inst_dict_path, 'rb') as f:
            inst_dict = pickle.load(f)
        inst_dict.set_max_sentence_length(self.args.max_sentence_length)
        return inst_dict

    def __init__(self,
                 args,
                 max_raw_chars,
                 max_instruction_span,
                 coach_mode,
                 num_resource_bin,
                 *,
                 num_unit_type=len(gc.UnitTypes),
                 num_cmd_type=len(gc.CmdTypes)):
        """Build the coach's encoders and output heads.

        Args:
            args: parsed argument namespace (see ``get_arg_parser``).
                NOTE: ``args.inst_dict_path`` is rewritten in place.
            max_raw_chars: stored as ``self.max_raw_chars``.
            max_instruction_span: the frame-passed embedding gets
                ``max_instruction_span + 2`` rows.
            coach_mode: ``'rnn'`` or ``'bow'`` build an ``RnnSelector`` over
                the matching instruction encoder; ``'onehot'`` and
                ``'rnn_gen'`` leave ``self.inst_selector`` as ``None``.
            num_resource_bin: forwarded to ``ConvGlobEncoder``.
            num_unit_type: input width of the cons-count and moving-average
                encoders; also forwarded to ``ConvGlobEncoder``.
            num_cmd_type: forwarded to ``ConvGlobEncoder``.

        Raises:
            ValueError: if ``coach_mode`` is not one of the modes above.
        """
        super().__init__()

        # Constructor kwargs pickled by ``save`` and replayed by ``load``.
        self.params = {
            'args': args,
            'max_raw_chars': max_raw_chars,
            'max_instruction_span': max_instruction_span,
            'coach_mode': coach_mode,
            'num_resource_bin': num_resource_bin,
            'num_unit_type': num_unit_type,
            'num_cmd_type': num_cmd_type,
        }

        self.args = args
        self.max_raw_chars = max_raw_chars
        self.max_instruction_span = max_instruction_span
        self.coach_mode = coach_mode

        # Device-keyed caches of parsed candidate instructions.
        self.pos_candidate_inst = None
        self.neg_candidate_inst = None

        # NOTE(review): hard-coded path relocation that mutates ``args``.
        self.args.inst_dict_path = self.args.inst_dict_path.replace(
            'scratch/rts_data', 'rts-replays')
        self.inst_dict = self.load_inst_dict(self.args.inst_dict_path)

        self.prev_inst_encoder = LSTMInstructionEncoder(
            self.inst_dict.total_vocab_size,
            self.args.word_emb_dim,
            self.args.word_emb_dropout,
            self.args.inst_hid_dim,
            self.inst_dict.pad_word_idx,
        )

        self.glob_encoder = ConvGlobEncoder(args, num_unit_type, num_cmd_type,
                                            num_resource_bin,
                                            self.prev_inst_encoder)

        # count encoders
        half_dim = self.args.count_hid_dim // 2
        self.count_encoder = MlpEncoder(self.args.num_count_channels * 2,
                                        self.args.count_hid_dim,
                                        self.args.count_hid_dim,
                                        self.args.count_hid_layers - 1,
                                        activate_out=True)
        self.cons_count_encoder = MlpEncoder(num_unit_type,
                                             half_dim,
                                             half_dim,
                                             self.args.count_hid_layers - 1,
                                             activate_out=True)
        self.moving_avg_encoder = MlpEncoder(num_unit_type,
                                             half_dim,
                                             half_dim,
                                             self.args.count_hid_layers - 1,
                                             activate_out=True)
        # Rows cover frame_passed values 0 .. max_instruction_span + 1.
        self.frame_passed_encoder = nn.Embedding(
            max_instruction_span + 2,
            half_dim,
        )

        if self.args.glob_dropout > 0:
            self.glob_dropout = nn.Dropout(self.args.glob_dropout)

        # glob_feat (see ``_forward``) = glob encoder output + count_feat
        # (count_hid_dim) + cons_count / moving_avg / frame_passed features
        # (count_hid_dim // 2 each). Summing the integer widths stays exact
        # for odd count_hid_dim, where int(2.5 * count_hid_dim) over-counted.
        self.glob_feat_dim = int(self.args.count_hid_dim + 3 * half_dim +
                                 self.glob_encoder.glob_dim)
        self.cont_cls = GlobClsHead(
            self.glob_feat_dim,
            self.args.inst_hid_dim,  # for reducing hyper-parameter
            2)

        if self.coach_mode == 'rnn':
            encoder = self.prev_inst_encoder
        elif self.coach_mode == 'bow':
            encoder = MeanBOWInstructionEncoder(
                self.inst_dict.total_vocab_size, self.args.inst_hid_dim,
                self.args.word_emb_dropout, self.inst_dict.pad_word_idx)
        elif self.coach_mode in ('onehot', 'rnn_gen'):
            encoder = None
        else:
            # Explicit raise: an ``assert`` would be stripped under -O.
            raise ValueError('unknown coach mode: %s' % self.coach_mode)

        # Only the 'rnn' and 'bow' modes select among candidate instructions.
        if encoder is not None:
            self.inst_selector = RnnSelector(encoder, self.glob_feat_dim)
        else:
            self.inst_selector = None

        self.value = nn.utils.weight_norm(nn.Linear(self.glob_feat_dim, 1),
                                          dim=None)
        self.sampler = ContSoftmaxSampler('cont', 'cont_pi', 'inst', 'inst_pi')

    @property
    def num_instructions(self):
        """Size of the positive candidate instruction set (``args.num_pos_inst``)."""
        args = self.args
        return args.num_pos_inst

    def _forward(self, batch):
        """Encode ``batch`` into one flat global feature per sample.

        The result concatenates, along dim 1, the glob encoder summaries
        (sum_inst, sum_army, sum_enemy, sum_resource, money_feat) followed by
        the count, cons_count, moving-enemy-count and frame-passed features.
        Dropout is applied when ``args.glob_dropout > 0``.
        """
        # Own counts next to the remaining gap to base_count.
        remaining = batch['base_count'] - batch['count']
        count_feat = self.count_encoder(
            torch.cat([batch['count'], remaining], 1))
        cons_count_feat = self.cons_count_encoder(batch['cons_count'])
        moving_avg_feat = self.moving_avg_encoder(batch['moving_enemy_count'])
        frame_passed_feat = self.frame_passed_encoder(batch['frame_passed'])
        features = self.glob_encoder(batch, use_prev_inst=True)

        summary_keys = ('sum_inst', 'sum_army', 'sum_enemy', 'sum_resource',
                        'money_feat')
        pieces = [features[key] for key in summary_keys]
        pieces.extend(
            (count_feat, cons_count_feat, moving_avg_feat, frame_passed_feat))
        glob_feat = torch.cat(pieces, dim=1)

        if self.args.glob_dropout > 0:
            glob_feat = self.glob_dropout(glob_feat)
        return glob_feat

    def compute_loss(self, batch):
        """Supervised pre-training loss for one dataset batch.

        Returns ``(loss, all_loss)``. ``loss`` is the mean of the continue
        loss plus the language loss, which only counts where ``cont == 0``.
        ``all_loss`` additionally reports the mean of each component.
        """
        batch = self._format_supervised_language_input(batch)
        glob_feat = self._forward(batch)

        cont = 1 - batch['is_base_frame']
        cont_loss = self.cont_cls.compute_loss(glob_feat, cont)
        selector_keys = ('pos_cand_inst', 'pos_cand_inst_len', 'neg_cand_inst',
                         'neg_cand_inst_len', 'inst', 'inst_len')
        lang_loss = self.inst_selector.compute_loss(
            *(batch[key] for key in selector_keys), glob_feat,
            batch['inst_idx'])

        assert_eq(cont_loss.size(), lang_loss.size())
        lang_weight = 1 - cont.float()
        lang_loss = lang_weight * lang_loss
        loss = (cont_loss + lang_loss).mean()
        return loss, {
            'loss': loss,
            'cont_loss': cont_loss.mean(),
            'lang_loss': lang_loss.mean(),
        }

    def compute_eval_loss(self, batch):
        """Evaluation counterpart of ``compute_loss``.

        Losses are combined the same way as in ``compute_loss``. The language
        loss comes from ``inst_selector.eval_loss``, which only receives the
        positive candidates; negatives are still attached to ``batch`` by
        ``_format_supervised_language_input`` but are not used here.

        Returns:
            (loss, all_loss): the mean combined loss, and a dict holding that
            loss together with the mean ``cont_loss`` and (masked)
            ``lang_loss``.
        """
        batch = self._format_supervised_language_input(batch)
        glob_feat = self._forward(batch)

        cont = 1 - batch['is_base_frame']
        cont_loss = self.cont_cls.compute_loss(glob_feat, cont)
        lang_loss = self.inst_selector.eval_loss(
            batch['pos_cand_inst'],
            batch['pos_cand_inst_len'],
            glob_feat,
            batch['inst_idx'],
        )
        assert_eq(cont_loss.size(), lang_loss.size())

        # language loss only counts on samples where cont == 0
        lang_weight = 1 - cont.float()
        masked_lang_loss = lang_weight * lang_loss
        total = (cont_loss + masked_lang_loss).mean()
        return total, {
            'loss': total,
            'cont_loss': cont_loss.mean(),
            'lang_loss': masked_lang_loss.mean(),
        }

    def _format_supervised_language_input(self, batch):
        """Attach positive and negative candidate instructions to ``batch``.

        Candidates are placed on the device of ``batch['prev_inst']``.
        ``batch['inst_idx']`` is forwarded as the exclusion list for negative
        sampling. ``batch`` is updated in place and also returned.
        """
        device = batch['prev_inst'].device
        pos_inst, pos_len = self._get_pos_candidate_inst(device)
        neg_inst, neg_len = self._get_neg_candidate_inst(device,
                                                         batch['inst_idx'])
        batch.update({
            'pos_cand_inst': pos_inst,
            'pos_cand_inst_len': pos_len,
            'neg_cand_inst': neg_inst,
            'neg_cand_inst_len': neg_len,
        })
        return batch

    def _get_pos_candidate_inst(self, device):
        """Return ``(inst, inst_len)`` for the positive candidate instructions.

        The first ``args.num_pos_inst`` instructions of ``inst_dict`` are
        parsed with ``parse_batch_inst`` and cached in
        ``self.pos_candidate_inst``. They are re-parsed only when the cache is
        empty or lives on a different device than ``device``.
        """
        cached = self.pos_candidate_inst
        if cached is None or not cached[0].device == device:
            inst, inst_len = parse_batch_inst(self.inst_dict,
                                              range(self.args.num_pos_inst),
                                              device)
            self.pos_candidate_inst = (inst, inst_len)
        else:
            inst, inst_len = cached
        return inst, inst_len

    def _get_neg_candidate_inst(self, device, exclude_idx):
        """Sample ``args.num_neg_inst`` negative candidate instructions.

        The negative pool is every instruction index in
        ``[args.num_pos_inst, inst_dict.num_insts)``. Its parsed form is
        cached in ``self.neg_candidate_inst`` and re-parsed only when the
        cache is empty or lives on a different device.

        Args:
            device: device the returned tensors live on.
            exclude_idx: instruction indices (e.g. the batch's ``inst_idx``)
                that must not be drawn. Entries equal to
                ``inst_dict.unknown_inst_idx`` and indices below
                ``args.num_pos_inst`` are ignored. Any shape is accepted; it
                is flattened.

        Returns:
            (inst, inst_len): ``inst`` is ``[num_neg_inst, max_sentence_len]``
            and ``inst_len`` is ``[num_neg_inst]``. Rows are drawn uniformly
            without replacement from the non-excluded candidates.

        Raises:
            ValueError: if fewer than ``args.num_neg_inst`` candidates remain
                after exclusion.
        """
        if (self.neg_candidate_inst is not None
                and self.neg_candidate_inst[0].device == device):
            inst, inst_len = self.neg_candidate_inst
        else:
            inst, inst_len = parse_batch_inst(
                self.inst_dict,
                range(self.args.num_pos_inst, self.inst_dict.num_insts),
                device)
            self.neg_candidate_inst = (inst, inst_len)

        # inst: [num_candidate, max_sentence_len]
        num_candidate = inst.size(0)
        prob = np.ones((num_candidate, ), dtype=np.float32)

        # Bring the indices to the host once. Iterating a (possibly CUDA)
        # tensor element by element forces one device sync per element.
        exclude = torch.as_tensor(exclude_idx,
                                  dtype=torch.long).reshape(-1).cpu().numpy()
        exclude = exclude[exclude != self.inst_dict.unknown_inst_idx]
        # shift to candidate-local indices; positive instructions become < 0
        exclude = exclude - self.args.num_pos_inst
        prob[exclude[exclude >= 0]] = 0

        num_allowed = int(np.count_nonzero(prob))
        if num_allowed < self.args.num_neg_inst:
            raise ValueError(
                'only %d negative candidates left after exclusion, need %d' %
                (num_allowed, self.args.num_neg_inst))
        prob = prob / prob.sum()

        select = np.random.choice(num_candidate,
                                  self.args.num_neg_inst,
                                  replace=False,
                                  p=prob)
        select = torch.as_tensor(select, dtype=torch.long).to(device)
        # select: [num_neg_inst,]
        inst_len = inst_len.index_select(0, select)
        inst = inst.index_select(0, select)
        return inst, inst_len

    # ============ RL related ============
    def format_coach_input(self, batch, prefix=''):
        """Build the coach input dict from a raw batch.

        Args:
            batch: raw batch. ``prev_inst`` and ``frame_passed`` are read
                without a prefix; the four count fields are read with
                ``prefix`` prepended.
            prefix: key prefix for the count fields. It is also passed on to
                ``glob_encoder.format_input``.

        Returns:
            dict with ``prev_inst_idx``, ``frame_passed`` (clamped to at most
            ``max_instruction_span + 1``), ``count``, ``base_count``,
            ``cons_count`` and ``moving_enemy_count``, updated with the
            entries returned by ``glob_encoder.format_input``.
        """
        # presumably clamped so it stays a valid index for frame_passed_encoder
        span_limit = self.max_instruction_span + 1
        frame_passed = batch['frame_passed'].squeeze(1).clamp(max=span_limit)

        data = {
            'prev_inst_idx': batch['prev_inst'].squeeze(1),
            'frame_passed': frame_passed,
        }
        for key in ('count', 'base_count', 'cons_count',
                    'moving_enemy_count'):
            data[key] = batch[prefix + key]

        data.update(self.glob_encoder.format_input(batch, prefix))
        return data

    def _format_rl_language_input(self, batch):
        """Add parsed previous instructions and positive candidates to ``batch``.

        ``batch['prev_inst_idx']`` is parsed with ``parse_batch_inst`` into
        ``prev_inst`` / ``prev_inst_len`` on the same device. The (cached)
        positive candidates are stored as ``cand_inst`` / ``cand_inst_len``.
        ``batch`` is updated in place and also returned.
        """
        prev_idx = batch['prev_inst_idx']
        parsed, parsed_len = parse_batch_inst(self.inst_dict,
                                              prev_idx.cpu().numpy(),
                                              prev_idx.device)
        batch['prev_inst'] = parsed
        batch['prev_inst_len'] = parsed_len

        candidates, candidates_len = self._get_pos_candidate_inst(
            parsed.device)
        batch['cand_inst'] = candidates
        batch['cand_inst_len'] = candidates_len
        return batch

    def rl_forward(self, batch, mode, agent_mask=None):
        """Forward pass used by RL.

        Args:
            batch: RL input dict containing ``prev_inst_idx`` and the fields
                read by ``_forward``.
            mode: ``'full'`` leaves the instruction distribution unchanged.
                ``'good'`` / ``'better'`` multiply it by ``gc.good_inst_mask``
                / ``gc.better_inst_mask``. ``'custom'`` multiplies it by
                ``agent_mask``. Masked distributions are renormalized per row.
            agent_mask: per-candidate mask, required for ``mode == 'custom'``.

        Returns:
            dict with ``cont_pi`` (from ``cont_cls.compute_prob``), ``inst_pi``
            (from ``inst_selector.compute_prob``, masked/renormalized as
            above) and ``v`` (value head output with its last dim squeezed).
        """
        batch = self._format_rl_language_input(batch)
        glob_feat = self._forward(batch)
        # The value head produces one output per sample ([batch, 1]).
        # Squeeze only that dim, so batch size 1 still yields shape [1]
        # rather than a 0-d scalar.
        v = self.value(glob_feat).squeeze(1)
        cont_prob = self.cont_cls.compute_prob(glob_feat)
        inst_prob = self.inst_selector.compute_prob(batch['cand_inst'],
                                                    batch['cand_inst_len'],
                                                    glob_feat)

        if mode in ('good', 'better', 'custom'):
            if mode == 'better':
                raw_mask = gc.better_inst_mask
            elif mode == 'good':
                raw_mask = gc.good_inst_mask
            else:
                assert agent_mask is not None, \
                    "mode 'custom' requires agent_mask"
                raw_mask = agent_mask

            mask = torch.tensor(raw_mask,
                                dtype=torch.float32,
                                device=inst_prob.device).unsqueeze(0)
            inst_prob = inst_prob * mask
            inst_prob = inst_prob / inst_prob.sum(1, keepdim=True)
        else:
            assert mode == 'full', 'unknown mode: %s' % mode

        output = {'cont_pi': cont_prob, 'inst_pi': inst_prob, 'v': v}
        return output

    def sample(self, batch, mode, word_based=True, agent_mask=None):
        """Sample a continue decision and an instruction per batch entry.

        Used by the actor in ELF and for visually evaluating the model.

        Args:
            batch: RL input dict (see ``rl_forward``).
            mode: instruction-masking mode forwarded to ``rl_forward``.
            word_based: if True, ``inst`` is the token matrix from
                ``inst_dict.parse``; otherwise it is ``reply['inst']``.
            agent_mask: forwarded to ``rl_forward`` (used by ``'custom'``).

        Returns:
            inst: [batch, max_sentence_len] token indices when ``word_based``
                (even though the chosen instruction is one-hot); otherwise
                ``reply['inst']``.
            inst_len: [batch] token lengths.
            cont: ``reply['cont']``.
            reply: dict with ``cont``, ``cont_pi``, ``inst``, ``inst_pi`` and
                ``raw_inst`` (output of ``convert_to_raw_instruction``).
        """
        output = self.rl_forward(batch, mode, agent_mask=agent_mask)
        sampled = self.sampler.sample(output['cont_pi'], output['inst_pi'],
                                      batch['prev_inst_idx'])

        reply = {
            'cont': sampled['cont'].unsqueeze(1),
            'cont_pi': output['cont_pi'],
            'inst': sampled['inst'].unsqueeze(1),
            'inst_pi': output['inst_pi'],
        }

        # convert to the format needed by the executor
        token_rows = []
        lengths = []
        raws = []
        for idx in reply['inst']:
            inst_str = self.inst_dict.get_inst(int(idx.item()))
            tokens, length = self.inst_dict.parse(inst_str, True)
            token_rows.append(tokens)
            lengths.append(length)
            raws.append(
                convert_to_raw_instruction(inst_str, self.max_raw_chars))

        device = reply['cont'].device
        if word_based:
            # for word based
            inst = torch.LongTensor(token_rows).to(device)
        else:
            inst = reply['inst']

        inst_len = torch.LongTensor(lengths).to(device)
        reply['raw_inst'] = torch.LongTensor(raws).to(device)
        return inst, inst_len, reply['cont'], reply
Exemple #4
0
    def __init__(
            self,
            args,
            max_raw_chars,
            max_instruction_span,
            coach_mode,
            num_resource_bin,
            *,
            num_unit_type=len(gc.UnitTypes),
            num_cmd_type=len(gc.CmdTypes),
            coach_rule_emb_size=0,
    ):
        super().__init__()

        self.params = {
            "args": args,
            "max_raw_chars": max_raw_chars,
            "max_instruction_span": max_instruction_span,
            "coach_mode": coach_mode,
            "num_resource_bin": num_resource_bin,
            "num_unit_type": num_unit_type,
            "num_cmd_type": num_cmd_type,
        }

        self.num_unit_types = NUM_UNIT_TYPES
        self.coach_rule_emb_size = coach_rule_emb_size

        self.args = args
        self.max_raw_chars = max_raw_chars
        self.max_instruction_span = max_instruction_span
        self.coach_mode = coach_mode

        self.pos_candidate_inst = None
        self.neg_candidate_inst = None

        self.args.inst_dict_path = self.args.inst_dict_path.replace(
            "scratch/rts_data", "rts-replays")
        self.inst_dict = self.load_inst_dict(self.args.inst_dict_path)

        self.prev_inst_encoder = LSTMInstructionEncoder(
            self.inst_dict.total_vocab_size,
            self.args.word_emb_dim,
            self.args.word_emb_dropout,
            self.args.inst_hid_dim,
            self.inst_dict.pad_word_idx,
        )

        self.glob_encoder = ConvGlobEncoder(args, num_unit_type, num_cmd_type,
                                            num_resource_bin,
                                            self.prev_inst_encoder)

        # count encoders
        self.count_encoder = MlpEncoder(
            self.args.num_count_channels * 2,
            self.args.count_hid_dim,
            self.args.count_hid_dim,
            self.args.count_hid_layers - 1,
            activate_out=True,
        )
        self.cons_count_encoder = MlpEncoder(
            num_unit_type,
            self.args.count_hid_dim // 2,
            self.args.count_hid_dim // 2,
            self.args.count_hid_layers - 1,
            activate_out=True,
        )
        self.moving_avg_encoder = MlpEncoder(
            num_unit_type,
            self.args.count_hid_dim // 2,
            self.args.count_hid_dim // 2,
            self.args.count_hid_layers - 1,
            activate_out=True,
        )
        self.frame_passed_encoder = nn.Embedding(
            max_instruction_span + 2,
            self.args.count_hid_dim // 2,
        )

        if self.args.glob_dropout > 0:
            self.glob_dropout = nn.Dropout(self.args.glob_dropout)

        self.glob_feat_dim = int(2.5 * self.args.count_hid_dim +
                                 self.glob_encoder.glob_dim)
        self.cont_cls = GlobClsHead(
            self.glob_feat_dim,
            self.args.inst_hid_dim,  # for reducing hyper-parameter
            2,
        )

        if self.coach_rule_emb_size > 0:
            # TODO: Remove num units hardcoding
            self.rule_emb = nn.Embedding(NUM_UNIT_TYPES,
                                         self.coach_rule_emb_size)
            self.rule_lstm = torch.nn.LSTM(self.coach_rule_emb_size,
                                           self.glob_feat_dim,
                                           batch_first=True)

        if self.coach_mode == "rnn":
            encoder = self.prev_inst_encoder
        elif self.coach_mode == "bow":
            encoder = MeanBOWInstructionEncoder(
                self.inst_dict.total_vocab_size,
                self.args.inst_hid_dim,
                self.args.word_emb_dropout,
                self.inst_dict.pad_word_idx,
            )
        elif self.coach_mode == "onehot" or self.coach_mode == "rnn_gen":
            pass
        else:
            assert False, "unknown coach mode: %s" % self.coach_mode

        if self.coach_mode == "rnn" or self.coach_mode == "bow":
            self.inst_selector = RnnSelector(encoder, self.glob_feat_dim)
        else:
            self.inst_selector = None

        self.value = nn.utils.weight_norm(nn.Linear(self.glob_feat_dim, 1),
                                          dim=None)
        self.sampler = ContSoftmaxSampler("cont", "cont_pi", "inst", "inst_pi")