Example #1
import os
from typing import Any

import torch
from transformers import BertConfig, BertModel

# `args` (hyperparameter dict), `XLDecoder` (the project's custom decoder), and
# `device` are defined elsewhere in the original project.


class MyModel(torch.nn.Module):
    def _forward_unimplemented(self, *input: Any) -> None:
        pass

    def __init__(self, pre_train_dir: str):
        super().__init__()
        self.roberta_encoder = BertModel(
            config=BertConfig.from_json_file(os.path.join(pre_train_dir, "config.json")))
        self.decoder_layer = XLDecoder(
            dim=args["dimension"], embedding_matrix=self.roberta_encoder.get_input_embeddings(),
            seq_len=args["max_dec_len"])

    def forward(self, input_ids, input_mask, input_seg, decode_input=None, decode_target=None):
        encoder_rep = self.roberta_encoder(input_ids, input_mask, input_seg)[0]
        return self.decoder_layer(input_ids, encoder_rep, input_mask, decode_input, decode_target,
                                  args["use_beam_search"],
                                  args["beam_width"])

Example #2
class QuestionGeneration(torch.nn.Module):
    def _forward_unimplemented(self, *input: Any) -> None:
        pass

    def __init__(self, pre_train_dir: str):
        super().__init__()
        if os.path.isdir(pre_train_dir):
            self.roberta_encoder = BertModel(
                config=BertConfig.from_json_file(os.path.join(pre_train_dir, "config.json")))
        else:
            self.roberta_encoder = BertModel(
                config=BertConfig.from_pretrained(pre_train_dir))
        self.decoder_layer = XLDecoder(dim=args["dimension"],
                                       embedding_matrix=self.roberta_encoder.get_input_embeddings(),
                                       seq_len=args["max_dec_len"])

    def forward(self, input_ids, input_mask, input_seg, decode_input=None, decode_target=None):
        encoder_rep = self.roberta_encoder(input_ids, input_mask, input_seg)[0]
        # print("encoder shape: {}, encoder vector: {}".format(encoder_rep.shape, encoder_rep))
        return self.decoder_layer(input_ids, encoder_rep, input_mask, decode_input, decode_target,
                                  args["use_beam_search"],
                                  args["beam_width"])

    @classmethod
    def from_pretrained(cls, pretrained_model_path=None):
        """ load model
        :param pretrained_model_path: 模型文件绝对路径
        :return:
        """
        model = cls(pre_train_dir=args["pre_train_dir"])
        if pretrained_model_path:
            model_path = pretrained_model_path
        elif args["save_path"]:
            model_path = args["save_path"]
        else:
            raise ValueError("Please provide a model file path.")
        model.load_state_dict(torch.load(model_path, map_location=device), strict=False)

        return model
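A minimal usage sketch for the loader above (assumptions: `args` and `device` are defined as sketched earlier, and the three input tensors come from a matching tokenizer, e.g. as shown after Example #3):

# Usage sketch, not part of the original example.
model = QuestionGeneration.from_pretrained()  # falls back to args["save_path"]
model.to(device)
model.eval()
with torch.no_grad():
    # At inference time decode_input/decode_target stay None, so the decoder
    # presumably runs generation (with beam search if args enables it).
    preds = model(input_ids, input_mask, input_seg)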

Example #3
class MyModel(torch.nn.Module):
    def _forward_unimplemented(self, *input: Any) -> None:
        pass

    def __init__(self, pre_train_dir: str):
        super().__init__()
        if os.path.isdir(pre_train_dir):
            self.roberta_encoder = BertModel(
                config=BertConfig.from_json_file(os.path.join(pre_train_dir, "config.json")))
        else:
            self.roberta_encoder = BertModel(
                config=BertConfig.from_pretrained(pre_train_dir))
        self.decoder_layer = XLDecoder(
            dim=args["dimension"], embedding_matrix=self.roberta_encoder.get_input_embeddings(),
            seq_len=args["max_dec_len"])

    def forward(self, input_ids, input_mask, input_seg, decode_input=None, decode_target=None):
        # input_seg: token_type_id
        encoder_rep = self.roberta_encoder(input_ids, input_mask, input_seg)[0]
        # print("encoder shape: {}, encoder vector: {}".format(encoder_rep.shape, encoder_rep))
        return self.decoder_layer(input_ids, encoder_rep, input_mask, decode_input, decode_target,
                                  args["use_beam_search"],
                                  args["beam_width"])