Example #1
def train_and_validate(args):
    set_seed(args.seed)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build model.
    model = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model = load_model(model, args.pretrained_model_path)
    else:
        # Initialize with normal distribution.
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)

    if args.dist_train:
        # Multiprocessing distributed mode.
        mp.spawn(worker,
                 nprocs=args.ranks_num,
                 args=(args.gpu_ranks, args, model),
                 daemon=False)
    elif args.single_gpu:
        # Single GPU mode.
        worker(args.gpu_id, None, args, model)
    else:
        # CPU mode.
        worker(None, None, args, model)
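Note: `worker` itself is not shown in this example. Below is a minimal sketch of a function compatible with the three call sites above (an illustration only, not UER-py's actual worker, which also initializes the distributed process group and runs the training loop):

def worker(proc_id, gpu_ranks, args, model):
    # proc_id is None in CPU mode, a GPU id in single-GPU mode, or the
    # process index supplied by mp.spawn in distributed mode.
    if args.dist_train:
        rank = gpu_ranks[proc_id]
        print("Distributed worker: rank {}, GPU {}".format(rank, proc_id))
    elif args.single_gpu:
        print("Single-GPU worker on GPU {}".format(proc_id))
    else:
        print("CPU worker")
    # ...build the optimizer and data loader, then run the training loop.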
Example #2
    def __init__(self,
                 args,
                 do_lower_case=True,
                 max_len=None,
                 do_basic_tokenize=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BertTokenizer.
        Args:
          args: Arguments object; args.vocab_path points to a
                         one-wordpiece-per-line vocabulary file.
          do_lower_case: Whether to lower case the input.
                         Only has an effect when do_wordpiece_only=False.
          do_basic_tokenize: Whether to do basic tokenization before wordpiece.
          max_len: An artificial maximum length to truncate tokenized sequences to;
                         Effective maximum length is always the minimum of this
                         value (if specified) and the underlying BERT model's
                         sequence length.
          never_split: List of tokens which will never be split during tokenization.
                         Only has an effect when do_wordpiece_only=False.
        """
        self.vocab = Vocab()
        self.vocab.load(args.vocab_path, is_quiet=True)
        self.ids_to_tokens = collections.OrderedDict([
            (ids, tok) for ids, tok in enumerate(self.vocab.i2w)
        ])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)
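A hedged usage sketch for the constructor above; `SimpleNamespace` stands in for the repo's argparse/`Bunch` args, and the vocabulary path is the one assumed by later examples in this listing:

from types import SimpleNamespace

args = SimpleNamespace(vocab_path="models/google_uncased_en_vocab.txt")
tokenizer = BertTokenizer(args)      # loads the vocab, builds basic + wordpiece tokenizers
print(len(tokenizer.ids_to_tokens))  # vocabulary size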
Example #3
    def __init__(self, args, is_src=True):
        self.vocab = None
        self.sp_model = None
        if is_src:
            spm_model_path = args.spm_model_path
            vocab_path = args.vocab_path
        else:
            spm_model_path = args.tgt_spm_model_path
            vocab_path = args.tgt_vocab_path

        if spm_model_path:
            try:
                import sentencepiece as spm
            except ImportError:
                raise ImportError(
                    "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
                    "pip install sentencepiece")
            self.sp_model = spm.SentencePieceProcessor()
            self.sp_model.Load(spm_model_path)
            self.vocab = {
                self.sp_model.IdToPiece(i): i
                for i in range(self.sp_model.GetPieceSize())
            }
        else:
            self.vocab = Vocab()
            self.vocab.load(vocab_path, is_quiet=True)
            self.vocab = self.vocab.w2i
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
Example #4
File: trainer.py (project: yingnengd/UER-py)
def train_and_validate(args):
    set_seed(args.seed)

    # Load vocabulary.
    if args.spm_model_path:
        try:
            import sentencepiece as spm
        except ImportError:
            raise ImportError(
                "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
                "pip install sentencepiece")
        sp_model = spm.SentencePieceProcessor()
        sp_model.Load(args.spm_model_path)
        args.vocab = {
            sp_model.IdToPiece(i): i
            for i in range(sp_model.GetPieceSize())
        }
        args.tokenizer = str2tokenizer[args.tokenizer](args)
        if args.target == "seq2seq":
            tgt_sp_model = spm.SentencePieceProcessor()
            tgt_sp_model.Load(args.tgt_spm_model_path)
            args.tgt_vocab = {
                tgt_sp_model.IdToPiece(i): i
                for i in range(tgt_sp_model.GetPieceSize())
            }
    else:
        args.tokenizer = str2tokenizer[args.tokenizer](args)
        args.vocab = args.tokenizer.vocab
        if args.target == "seq2seq":
            tgt_vocab = Vocab()
            tgt_vocab.load(args.tgt_vocab_path)
            args.tgt_vocab = tgt_vocab.w2i

    # Build model.
    model = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model = load_model(model, args.pretrained_model_path)
    else:
        # Initialize with normal distribution.
        for n, p in list(model.named_parameters()):
            if "gamma" not in n and "beta" not in n:
                p.data.normal_(0, 0.02)

    if args.dist_train:
        # Multiprocessing distributed mode.
        mp.spawn(worker,
                 nprocs=args.ranks_num,
                 args=(args.gpu_ranks, args, model),
                 daemon=False)
    elif args.single_gpu:
        # Single GPU mode.
        worker(args.gpu_id, None, args, model)
    else:
        # CPU mode.
        worker(None, None, args, model)
Example #5
class Tokenizer(object):
    def __init__(self, args):
        self.vocab = Vocab()
        self.vocab.load(args.vocab_path, is_quiet=True)

    def tokenize(self, text):
        raise NotImplementedError

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        for token in tokens:
            ids.append(self.vocab.w2i[token])
        return ids
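Subclasses only need to override `tokenize`. As an illustration (matching the behavior the argparse help texts below ascribe to the space tokenizer, not necessarily UER-py's exact class):

class SpaceTokenizer(Tokenizer):
    """Illustrative subclass: segments a sentence into words on whitespace."""

    def tokenize(self, text):
        return text.strip().split()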
Example #6
    def __init__(self,
                 args,
                 do_lower_case=True,
                 max_len=None,
                 do_basic_tokenize=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BertTokenizer.
        Args:
          args: Arguments object; args.vocab_path points to a
                         one-wordpiece-per-line vocabulary file.
          do_lower_case: Whether to lower case the input.
                         Only has an effect when do_wordpiece_only=False.
          do_basic_tokenize: Whether to do basic tokenization before wordpiece.
          max_len: An artificial maximum length to truncate tokenized sequences to;
                         Effective maximum length is always the minimum of this
                         value (if specified) and the underlying BERT model's
                         sequence length.
          never_split: List of tokens which will never be split during tokenization.
                         Only has an effect when do_wordpiece_only=False.
        """
        self.sp_model = None
        if args.spm_model_path:
            try:
                import sentencepiece as spm
            except ImportError:
                raise ImportError(
                    "Please install sentencepiece from https://github.com/google/sentencepiece."
                )

            self.sp_model = spm.SentencePieceProcessor(
                model_file=args.spm_model_path)
            # Note(mingdachen): For the purpose of a consistent API, we are
            # generating a vocabulary for the sentence piece tokenizer.
            self.vocab = {
                self.sp_model.IdToPiece(i): i
                for i in range(self.sp_model.GetPieceSize())
            }

        else:
            self.vocab = Vocab()
            self.vocab.load(args.vocab_path, is_quiet=True)
            self.ids_to_tokens = collections.OrderedDict([
                (ids, tok) for ids, tok in enumerate(self.vocab.i2w)
            ])
            self.do_basic_tokenize = do_basic_tokenize
            if do_basic_tokenize:
                self.basic_tokenizer = BasicTokenizer(
                    do_lower_case=do_lower_case, never_split=never_split)
            self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)
Example #7
def get_args_aida_task():
    args = Bunch()
    args.seq_len = 64
    args.row_wise_fill = True
    args.mask_mode = 'cross-wise'
    args.additional_ban = 2
    args.table_object = 'first-column'
    args.pooling = 'avg-token'

    args.pretrained_model_path = "./models/bert_model.bin-000"
    args.vocab_path = 'models/google_uncased_en_vocab.txt'
    args.vocab = Vocab()
    args.vocab.load(args.vocab_path)
    args.emb_size = 768
    args.embedding = 'tab'  # before: bert
    args.encoder = 'bertTab'
    args.subword_type = 'none'
    args.tokenizer = 'bert'
    args.tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

    args.feedforward_size = 3072
    args.hidden_size = 768
    args.heads_num = 12
    args.layers_num = 12
    args.learning_rate = 2e-5
    args.batch_size = 4
    args.dropout = 0.1

    # args.target = 'bert'
    return args
Example #8
class Tokenizer(object):
    def __init__(self, args, is_src=True):
        self.vocab = None
        self.sp_model = None
        if is_src:
            spm_model_path = args.spm_model_path
            vocab_path = args.vocab_path
        else:
            spm_model_path = args.tgt_spm_model_path
            vocab_path = args.tgt_vocab_path

        if spm_model_path:
            try:
                import sentencepiece as spm
            except ImportError:
                raise ImportError(
                    "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
                    "pip install sentencepiece")
            self.sp_model = spm.SentencePieceProcessor()
            self.sp_model.Load(spm_model_path)
            self.vocab = {
                self.sp_model.IdToPiece(i): i
                for i in range(self.sp_model.GetPieceSize())
            }
        else:
            self.vocab = Vocab()
            self.vocab.load(vocab_path, is_quiet=True)
            self.vocab = self.vocab.w2i
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

    def tokenize(self, text):
        raise NotImplementedError

    def convert_tokens_to_ids(self, tokens):
        if self.sp_model:
            return [
                self.sp_model.PieceToId(printable_text(token))
                for token in tokens
            ]
        else:
            return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        if self.sp_model:
            return [self.sp_model.IdToPiece(id_) for id_ in ids]
        else:
            return convert_by_vocab(self.inv_vocab, ids)
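`convert_by_vocab` is not part of this excerpt. In Google's BERT reference tokenization code, which this class mirrors, it simply maps each item through the given dictionary; a minimal sketch under that assumption:

def convert_by_vocab(vocab, items):
    # Works in both directions: pass vocab (w2i) to map tokens to ids,
    # or inv_vocab to map ids back to tokens.
    return [vocab[item] for item in items]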
Example #9
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    
    # Path options.
    parser.add_argument("--corpus_path", type=str, required=True,
                        help="Path of the corpus for pretraining.")
    parser.add_argument("--vocab_path", type=str, required=True,
                        help="Path of the vocabulary file.")
    parser.add_argument("--dataset_path", type=str, default="dataset.pt",
                        help="Path of the preprocessed dataset.")
    
    # Preprocess options.
    parser.add_argument("--tokenizer", choices=["bert", "char", "space"], default="bert",
                        help="Specify the tokenizer." 
                             "Original Google BERT uses bert tokenizer on Chinese corpus."
                             "Char tokenizer segments sentences into characters."
                             "Space tokenizer segments sentences into words according to space."
                             )
    parser.add_argument("--processes_num", type=int, default=1,
                        help="Split the whole dataset into `processes_num` parts, "
                             "and each part is fed to a single process in training step.")
    parser.add_argument("--target", choices=["bert", "lm", "cls", "mlm", "nsp", "s2s", "bilm"], default="bert",
                        help="The training target of the pretraining model.")
    parser.add_argument("--docs_buffer_size", type=int, default=100000,
                        help="The buffer size of documents in memory, specific to targets that require negative sampling.")
    parser.add_argument("--instances_buffer_size", type=int, default=25600,
                        help="The buffer size of instances in memory.")
    parser.add_argument("--seq_length", type=int, default=128, help="Sequence length of instances.")
    parser.add_argument("--dup_factor", type=int, default=5,
                        help="Duplicate instances multiple times.")
    parser.add_argument("--short_seq_prob", type=float, default=0.1,
                        help="Probability of truncating sequence."
                             "The larger value, the higher probability of using short (truncated) sequence.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")
    
    args = parser.parse_args()
    
    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    
    # Build tokenizer.
    tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)
        
    # Build and save dataset.
    dataset = globals()[args.target.capitalize() + "Dataset"](args, vocab, tokenizer)
    dataset.build_and_save(args.processes_num)
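The `globals()[args.tokenizer.capitalize() + "Tokenizer"]` idiom resolves a class by name, so `--tokenizer bert` selects `BertTokenizer`. A self-contained sketch of the pattern with stand-in classes (later examples use an explicit `str2tokenizer` dict instead, which avoids depending on module globals):

class BertTokenizer:
    def __init__(self, args):
        self.args = args

class CharTokenizer:
    def __init__(self, args):
        self.args = args

choice = "bert"
tokenizer = globals()[choice.capitalize() + "Tokenizer"](args=None)
print(type(tokenizer).__name__)  # BertTokenizer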
Example #10
def set_args(predefined_dict_groups):
    # options for model
    args = Bunch()
    args.mask_mode = 'cross-wise'  # in ['row-wise', 'col-wise', 'cross-wise', 'cross-and-hier-wise']
    args.additional_ban = 0
    # args.pooling = 'avg-token'
    args.pooling = 'avg-cell-seg'
    args.table_object = 'first-column'
    args.noise_num = 2
    args.seq_len = 100
    args.row_wise_fill = True

    args.pretrained_model_path = "./models/bert_model.bin-000"
    args.vocab_path = 'models/google_uncased_en_vocab.txt'
    args.vocab = Vocab()
    args.vocab.load(args.vocab_path)
    args.emb_size = 768
    args.embedding = 'tab'  # before: bert
    args.encoder = 'bertTab'
    args.subword_type = 'none'
    args.tokenizer = 'bert'
    args.tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

    args.feedforward_size = 3072
    args.hidden_size = 768
    args.heads_num = 12
    args.layers_num = 12
    args.learning_rate = 2e-5
    # args.learning_rate = 1e-4
    args.warmup = 0.1
    args.batch_size = 32
    args.dropout = 0.1
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # other options
    args.shuffle_rows = True
    args.report_steps = 100
    for predefined_dict_group in predefined_dict_groups.values():
        for k, v in predefined_dict_group.items():
            args[k] = v
    args.labels_map = get_labels_map_from_aida_file_2(args.train_path)
    args.labels_num = len(args.labels_map)

    # logger and tensorboard writer
    if args.tx_logger_dir_name:
        rq = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
        args.summary_writer = SummaryWriter(logdir=os.path.join(
            args.tx_logger_dir_name, '-'.join([args.exp_name, rq])))
    else:
        args.summary_writer = None
    if args.logger_dir_name is not None:
        args.logger_name = 'detail'
        args.logger = get_logger(logger_name=args.logger_name,
                                 dir_name=args.logger_dir_name,
                                 file_name=args.logger_file_name)
    else:
        args.logger = None
    return args
Example #11
def set_args_2():
    # options for model
    args = Bunch()
    args.mask_mode = 'cross-wise'  # in ['row-wise', 'col-wise', 'cross-wise', 'cross-and-hier-wise']
    args.additional_ban = 0
    # args.pooling = 'avg-token'
    args.pooling = 'avg-cell-seg'
    args.table_object = 'first-column'
    args.noise_num = 2
    args.seq_len = 100
    args.row_wise_fill = True

    args.pretrained_model_path = "./models/bert_model.bin-000"
    args.vocab_path = 'models/google_uncased_en_vocab.txt'
    args.vocab = Vocab()
    args.vocab.load(args.vocab_path)
    args.emb_size = 768
    args.embedding = 'tab'  # before: bert
    args.encoder = 'bertTab'
    args.subword_type = 'none'
    args.tokenizer = 'bert'

    args.feedforward_size = 3072
    args.hidden_size = 768
    args.heads_num = 12
    args.layers_num = 12
    args.learning_rate = 2e-5
    args.warmup = 0.1
    args.batch_size = 32
    args.dropout = 0.1
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    args.train_path = './data/aida/IO/train_samples'
    args.t2d_path = './data/aida/IO/test_samples_t2d'
    args.limaye_path = './data/aida/IO/test_samples_limaye'
    args.wiki_path = './data/aida/IO/test_samples_wikipedia'

    # other options
    args.report_steps = 100
    args.labels_map = get_labels_map_from_aida_file_2(args.train_path)
    args.labels_num = len(args.labels_map)
    args.tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)
    return args
Example #12
        "--tokenizer",
        choices=["char", "word", "space", "mixed"],
        default="char",
        help="Specify the tokenizer."
        "Original Google BERT uses char tokenizer on Chinese corpus."
        "Word tokenizer supports online word segmentation based on jieba segmentor."
        "Space tokenizer segments sentences into words according to space."
        "Mixed tokenizer segments sentences into words according to space."
        "If words are not in the vocabulary, the tokenizer splits words into characters."
    )

    args = parser.parse_args()
    args = load_hyperparam(args)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build and load model.
    bert_model = build_model(args)
    pretrained_model = torch.load(args.model_path)
    bert_model.load_state_dict(pretrained_model, strict=True)
    seq_encoder = SequenceEncoder(bert_model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        seq_encoder = nn.DataParallel(seq_encoder)
Example #13
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--load_model_path",
                        default=None,
                        type=str,
                        help="Path of the NER model.")
    parser.add_argument("--vocab_path",
                        type=str,
                        required=True,
                        help="Path of the vocabulary file.")
    parser.add_argument("--test_path", type=str, help="Path of the testset.")
    parser.add_argument("--prediction_path",
                        default=None,
                        type=str,
                        help="Path of the prediction file.")
    parser.add_argument("--config_path",
                        default="./models/bert_base_config.json",
                        type=str,
                        help="Path of the config file.")
    parser.add_argument("--label2id_path",
                        type=str,
                        required=True,
                        help="Path of the label2id file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=128,
                        help="Batch_size.")
    parser.add_argument("--seq_length",
                        default=128,
                        type=int,
                        help="Sequence length.")
    parser.add_argument("--embedding",
                        choices=["bert", "word"],
                        default="bert",
                        help="Emebdding type.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                              "cnn", "gatedcnn", "attn", "synt", \
                                              "rcnn", "crnn", "gpt", "bilstm"], \
                                              default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument("--factorized_embedding_parameterization",
                        action="store_true",
                        help="Factorized embedding parameterization.")
    parser.add_argument("--parameter_sharing",
                        action="store_true",
                        help="Parameter sharing.")

    args = parser.parse_args()

    # Load the hyperparameters of the config file.
    args = load_hyperparam(args)

    with open(args.label2id_path, mode="r", encoding="utf-8") as f:
        l2i = json.load(f)
        print("Labels: ", l2i)
        l2i["[PAD]"] = len(l2i)

    i2l = {}
    for key, value in l2i.items():
        i2l[value] = key

    args.l2i = l2i

    args.labels_num = len(l2i)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build sequence labeling model.
    model = NerTagger(args)
    model = load_model(model, args.load_model_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    instances = read_dataset(args, args.test_path)

    src = torch.LongTensor([ins[0] for ins in instances])
    seg = torch.LongTensor([ins[1] for ins in instances])

    instances_num = src.size(0)
    batch_size = args.batch_size

    print("The number of prediction instances: ", instances_num)

    model.eval()

    with open(args.prediction_path, mode="w", encoding="utf-8") as f:
        f.write("pred_label" + "\n")
        for i, (src_batch,
                seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
            src_batch = src_batch.to(device)
            seg_batch = seg_batch.to(device)
            with torch.no_grad():
                _, logits = model(src_batch, None, seg_batch)
            pred = logits.argmax(dim=-1)
            # Storing sequence length of instances in a batch.
            seq_length_batch = []
            for seg in seg_batch.cpu().numpy().tolist():
                for j in range(len(seg) - 1, -1, -1):
                    if seg[j] != 0:
                        break
                seq_length_batch.append(j + 1)
            pred = pred.cpu().numpy().tolist()
            for j in range(0, len(pred), args.seq_length):
                for label_id in pred[j:j +
                                     seq_length_batch[j // args.seq_length]]:
                    f.write(i2l[label_id] + " ")
                f.write("\n")
Example #14
                             "Space tokenizer segments sentences into words according to space."
                             )

    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size.")
    parser.add_argument("--seq_length", type=int, default=128,
                        help="Sequence length.")

    parser.add_argument("--topn", type=int, default=15)

    args = parser.parse_args()
    args = load_hyperparam(args)

    args.spm_model_path = None

    vocab = Vocab()
    vocab.load(args.vocab_path)

    cand_vocab = Vocab()
    cand_vocab.load(args.cand_vocab_path)

    args.tokenizer = str2tokenizer[args.tokenizer](args)

    model = SequenceEncoder(args)    
 
    pretrained_model = torch.load(args.load_model_path)
    model.load_state_dict(pretrained_model, strict=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
Example #15
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/classifier_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path",
                        default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/bert_base_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=64,
                        help="Batch size.")
    parser.add_argument("--seq_length",
                        type=int,
                        default=128,
                        help="Sequence length.")
    parser.add_argument("--embedding",
                        choices=["bert", "word"],
                        default="bert",
                        help="Emebdding type.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                                   "cnn", "gatedcnn", "attn", \
                                                   "rcnn", "crnn", "gpt", "bilstm"], \
                                                   default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument("--pooling",
                        choices=["mean", "max", "first", "last"],
                        default="first",
                        help="Pooling type.")

    # Subword options.
    parser.add_argument("--subword_type",
                        choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path",
                        type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder",
                        choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num",
                        type=int,
                        default=2,
                        help="The number of subencoder layers.")

    # Tokenizer options.
    parser.add_argument(
        "--tokenizer",
        choices=["bert", "char", "space"],
        default="bert",
        help="Specify the tokenizer."
        "Original Google BERT uses bert tokenizer on Chinese corpus."
        "Char tokenizer segments sentences into characters."
        "Space tokenizer segments sentences into words according to space.")

    # Optimizer options.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=3,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    # Evaluation options.
    parser.add_argument("--mean_reciprocal_rank",
                        action="store_true",
                        help="Evaluation metrics for DBQA dataset.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Count the number of labels.
    labels_set = set()
    columns = {}
    with open(args.train_path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            try:
                line = line.strip().split("\t")
                if line_id == 0:
                    for i, column_name in enumerate(line):
                        columns[column_name] = i
                    continue
                label = int(line[columns["label"]])
                labels_set.add(label)
            except:
                pass
    args.labels_num = len(labels_set)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build bert model.
    # A pseudo target is added.
    args.target = "bert"
    model = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model.load_state_dict(torch.load(args.pretrained_model_path),
                              strict=False)
    else:
        # Initialize with normal distribution.
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)

    # Build classification model.
    model = BertClassifier(args, model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids):
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i * batch_size:(i + 1) * batch_size, :]
            label_ids_batch = label_ids[i * batch_size:(i + 1) * batch_size]
            mask_ids_batch = mask_ids[i * batch_size:(i + 1) * batch_size, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num // batch_size *
                                        batch_size:, :]
            label_ids_batch = label_ids[instances_num // batch_size *
                                        batch_size:]
            mask_ids_batch = mask_ids[instances_num // batch_size *
                                      batch_size:, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch

    # Build tokenizer.
    tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

    # Read dataset.
    def read_dataset(path):
        dataset = []
        with open(path, mode="r", encoding="utf-8") as f:
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                try:
                    line = line.strip().split('\t')
                    if len(line) == 2:
                        label = int(line[columns["label"]])
                        text = line[columns["text_a"]]
                        tokens = [
                            vocab.get(t) for t in tokenizer.tokenize(text)
                        ]
                        tokens = [CLS_ID] + tokens
                        mask = [1] * len(tokens)
                        if len(tokens) > args.seq_length:
                            tokens = tokens[:args.seq_length]
                            mask = mask[:args.seq_length]
                        while len(tokens) < args.seq_length:
                            tokens.append(0)
                            mask.append(0)
                        dataset.append((tokens, label, mask))
                    elif len(line) == 3:  # For sentence pair input.
                        label = int(line[columns["label"]])
                        text_a, text_b = line[columns["text_a"]], line[
                            columns["text_b"]]

                        tokens_a = [
                            vocab.get(t) for t in tokenizer.tokenize(text_a)
                        ]
                        tokens_a = [CLS_ID] + tokens_a + [SEP_ID]
                        tokens_b = [
                            vocab.get(t) for t in tokenizer.tokenize(text_b)
                        ]
                        tokens_b = tokens_b + [SEP_ID]

                        tokens = tokens_a + tokens_b
                        mask = [1] * len(tokens_a) + [2] * len(tokens_b)

                        if len(tokens) > args.seq_length:
                            tokens = tokens[:args.seq_length]
                            mask = mask[:args.seq_length]
                        while len(tokens) < args.seq_length:
                            tokens.append(0)
                            mask.append(0)
                        dataset.append((tokens, label, mask))
                    elif len(line) == 4:  # For dbqa input.
                        qid = int(line[columns["qid"]])
                        label = int(line[columns["label"]])
                        text_a, text_b = line[columns["text_a"]], line[
                            columns["text_b"]]

                        tokens_a = [
                            vocab.get(t) for t in tokenizer.tokenize(text_a)
                        ]
                        tokens_a = [CLS_ID] + tokens_a + [SEP_ID]
                        tokens_b = [
                            vocab.get(t) for t in tokenizer.tokenize(text_b)
                        ]
                        tokens_b = tokens_b + [SEP_ID]

                        tokens = tokens_a + tokens_b
                        mask = [1] * len(tokens_a) + [2] * len(tokens_b)

                        if len(tokens) > args.seq_length:
                            tokens = tokens[:args.seq_length]
                            mask = mask[:args.seq_length]
                        while len(tokens) < args.seq_length:
                            tokens.append(0)
                            mask.append(0)
                        dataset.append((tokens, label, mask, qid))
                    else:
                        pass

                except:
                    pass
        return dataset

    # Evaluation function.
    def evaluate(args, is_test):
        if is_test:
            dataset = read_dataset(args.test_path)
        else:
            dataset = read_dataset(args.dev_path)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])

        batch_size = args.batch_size
        instances_num = input_ids.size()[0]
        if is_test:
            print("The number of evaluation instances: ", instances_num)

        correct = 0
        # Confusion matrix.
        confusion = torch.zeros(args.labels_num,
                                args.labels_num,
                                dtype=torch.long)

        model.eval()

        if not args.mean_reciprocal_rank:
            for i, (input_ids_batch, label_ids_batch,
                    mask_ids_batch) in enumerate(
                        batch_loader(batch_size, input_ids, label_ids,
                                     mask_ids)):
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                with torch.no_grad():
                    loss, logits = model(input_ids_batch, label_ids_batch,
                                         mask_ids_batch)
                logits = nn.Softmax(dim=1)(logits)
                pred = torch.argmax(logits, dim=1)
                gold = label_ids_batch
                for j in range(pred.size()[0]):
                    confusion[pred[j], gold[j]] += 1
                correct += torch.sum(pred == gold).item()

            if is_test:
                print("Confusion matrix:")
                print(confusion)
                print("Report precision, recall, and f1:")
            for i in range(confusion.size()[0]):
                # Guard denominators: a label may be absent from preds or golds.
                p = confusion[i, i].item() / max(confusion[i, :].sum().item(), 1)
                r = confusion[i, i].item() / max(confusion[:, i].sum().item(), 1)
                f1 = 2 * p * r / max(p + r, 1e-12)
                if is_test:
                    print("Label {}: {:.3f}, {:.3f}, {:.3f}".format(
                        i, p, r, f1))
            print("Acc. (Correct/Total): {:.4f} ({}/{}) ".format(
                correct / len(dataset), correct, len(dataset)))
            return correct / len(dataset)
        else:
            for i, (input_ids_batch, label_ids_batch,
                    mask_ids_batch) in enumerate(
                        batch_loader(batch_size, input_ids, label_ids,
                                     mask_ids)):
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                with torch.no_grad():
                    loss, logits = model(input_ids_batch, label_ids_batch,
                                         mask_ids_batch)
                logits = nn.Softmax(dim=1)(logits)
                if i == 0:
                    logits_all = logits
                if i >= 1:
                    logits_all = torch.cat((logits_all, logits), 0)

            order = -1
            gold = []
            for i in range(len(dataset)):
                qid = dataset[i][3]
                label = dataset[i][1]
                if qid == order:
                    j += 1
                    if label == 1:
                        gold.append((qid, j))
                else:
                    order = qid
                    j = 0
                    if label == 1:
                        gold.append((qid, j))

            label_order = []
            order = -1
            for i in range(len(gold)):
                if gold[i][0] == order:
                    templist.append(gold[i][1])
                elif gold[i][0] != order:
                    order = gold[i][0]
                    if i > 0:
                        label_order.append(templist)
                    templist = []
                    templist.append(gold[i][1])
            label_order.append(templist)

            order = -1
            score_list = []
            for i in range(len(logits_all)):
                score = float(logits_all[i][1])
                qid = int(dataset[i][3])
                if qid == order:
                    templist.append(score)
                else:
                    order = qid
                    if i > 0:
                        score_list.append(templist)
                    templist = []
                    templist.append(score)
            score_list.append(templist)

            rank = []
            pred = []
            for i in range(len(score_list)):
                if len(label_order[i]) == 1:
                    if label_order[i][0] < len(score_list[i]):
                        true_score = score_list[i][label_order[i][0]]
                        score_list[i].sort(reverse=True)
                        for j in range(len(score_list[i])):
                            if score_list[i][j] == true_score:
                                rank.append(1 / (j + 1))
                    else:
                        rank.append(0)

                else:
                    true_rank = len(score_list[i])
                    for k in range(len(label_order[i])):
                        if label_order[i][k] < len(score_list[i]):
                            true_score = score_list[i][label_order[i][k]]
                            temp = sorted(score_list[i], reverse=True)
                            for j in range(len(temp)):
                                if temp[j] == true_score:
                                    if j < true_rank:
                                        true_rank = j
                    if true_rank < len(score_list[i]):
                        rank.append(1 / (true_rank + 1))
                    else:
                        rank.append(0)
            MRR = sum(rank) / len(rank)
            print(MRR)
            return MRR

    # Training phase.
    print("Start training.")
    trainset = read_dataset(args.train_path)
    random.shuffle(trainset)
    instances_num = len(trainset)
    batch_size = args.batch_size

    input_ids = torch.LongTensor([example[0] for example in trainset])
    label_ids = torch.LongTensor([example[1] for example in trainset])
    mask_ids = torch.LongTensor([example[2] for example in trainset])

    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup,
                         t_total=train_steps)

    total_loss = 0.
    result = 0.0
    best_result = 0.0

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(
                batch_loader(batch_size, input_ids, label_ids, mask_ids)):
            model.zero_grad()

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)

            loss, _ = model(input_ids_batch, label_ids_batch, mask_ids_batch)
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.
            loss.backward()
            optimizer.step()
        result = evaluate(args, False)
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path)
        else:
            continue

    # Evaluation phase.
    if args.test_path is not None:
        print("Test set evaluation.")
        model = load_model(model, args.output_model_path)
        evaluate(args, True)
Example #16
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--corpus_path", required=True)
    parser.add_argument("--vocab_path", required=True)
    parser.add_argument("--workers_num",
                        type=int,
                        default=1,
                        help="The number of processes to build vocabulary.")
    parser.add_argument(
        "--min_count",
        type=int,
        default=1,
        help="The minimum count of words retained in the vocabulary.")

    tokenizer_opts(parser)

    args = parser.parse_args()

    vocab_path = args.vocab_path
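    # Point the tokenizer at a placeholder (reserved) vocabulary below; the real
    # vocabulary is what this script builds and then saves to the original path.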
    args.vocab_path, args.spm_model_path = "./models/reserved_vocab.txt", None

    # Build tokenizer.
    tokenizer = str2tokenizer[args.tokenizer](args)

    # Build and save vocabulary.
    vocab = Vocab()
    vocab.build(args.corpus_path, tokenizer, args.workers_num, args.min_count)
    vocab.save(vocab_path)
Example #17
    new_model[embedding_key] = torch.tensor(new_embedding, dtype=torch.float32)
    if bool:  # NOTE: presumably a boolean flag from the truncated signature above; as written, the builtin `bool` is always truthy.
        new_model[softmax_key] = torch.tensor(new_softmax, dtype=torch.float32)
        new_model[softmax_bias_key] = torch.tensor(new_softmax_bias,
                                                   dtype=torch.float32)
    return new_model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Input options.
    parser.add_argument("--old_model_path", type=str)
    parser.add_argument("--old_vocab_path", type=str)
    parser.add_argument("--new_vocab_path", type=str)

    # Output options.
    parser.add_argument("--new_model_path", type=str)

    args = parser.parse_args()
    old_vocab = Vocab()
    old_vocab.load(args.old_vocab_path)
    new_vocab = Vocab()
    new_vocab.load(args.new_vocab_path)

    old_model = torch.load(args.old_model_path, map_location="cpu")

    new_model = adapter(old_model, old_vocab, new_vocab)
    print("Output adapted new model.")
    torch.save(new_model, args.new_model_path)
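Only the tail of `adapter` survives in this excerpt. Judging from the keys it writes, the function remaps embedding (and optionally output-softmax) rows from the old vocabulary's index order to the new one's. A hedged sketch of the embedding part, assuming `Vocab` exposes `w2i`/`i2w` as in the other examples:

import torch

def remap_embedding(old_weight, old_vocab, new_vocab, std=0.02):
    # Illustrative only: copy rows for words present in both vocabularies and
    # draw rows for new words from N(0, std), matching the init used above.
    new_weight = torch.empty(len(new_vocab.i2w), old_weight.size(1))
    new_weight.normal_(0, std)
    for i, w in enumerate(new_vocab.i2w):
        if w in old_vocab.w2i:
            new_weight[i] = old_weight[old_vocab.w2i[w]]
    return new_weight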
Example #18
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--load_model_path",
                        default=None,
                        type=str,
                        help="Path of the classfier model.")
    parser.add_argument("--vocab_path",
                        type=str,
                        required=True,
                        help="Path of the vocabulary file.")
    parser.add_argument("--test_path", type=str, help="Path of the testset.")
    parser.add_argument("--prediction_path",
                        default=None,
                        type=str,
                        help="Path of the prediction file.")
    parser.add_argument("--config_path",
                        default="./models/bert_base_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=128,
                        help="Batch size.")
    parser.add_argument("--seq_length",
                        type=int,
                        default=128,
                        help="Sequence length.")
    parser.add_argument("--labels_num",
                        type=int,
                        required=True,
                        help="Number of prediction labels.")
    parser.add_argument("--embedding",
                        choices=["bert", "word"],
                        default="bert",
                        help="Emebdding type.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                              "cnn", "gatedcnn", "attn", "synt", \
                                              "rcnn", "crnn", "gpt", "bilstm"], \
                                              default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument("--pooling",
                        choices=["mean", "max", "first", "last"],
                        default="first",
                        help="Pooling type.")
    parser.add_argument("--factorized_embedding_parameterization",
                        action="store_true",
                        help="Factorized embedding parameterization.")
    parser.add_argument("--parameter_sharing",
                        action="store_true",
                        help="Parameter sharing.")

    # Tokenizer options.
    parser.add_argument(
        "--tokenizer",
        choices=["bert", "char", "space"],
        default="bert",
        help="Specify the tokenizer."
        "Original Google BERT uses bert tokenizer on Chinese corpus."
        "Char tokenizer segments sentences into characters."
        "Space tokenizer segments sentences into words according to space.")

    # Output options.
    parser.add_argument("--output_logits",
                        action="store_true",
                        help="Write logits to output file.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build classification model and load parameters.
    args.soft_targets = False
    model = Classifier(args)
    model = load_model(model, args.load_model_path)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    # Build tokenizer.
    args.tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

    dataset = read_dataset(args, args.test_path)

    src = torch.LongTensor([sample[0] for sample in dataset])
    seg = torch.LongTensor([sample[1] for sample in dataset])

    batch_size = args.batch_size
    instances_num = src.size()[0]

    print("The number of prediction instances: ", instances_num)

    model.eval()

    with open(args.prediction_path, mode="w", encoding="utf-8") as f:
        if args.output_logits:
            f.write("label" + "\t" + "logits" + "\n")
        else:
            f.write("label" + "\n")
        for i, (src_batch,
                seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
            src_batch = src_batch.to(device)
            seg_batch = seg_batch.to(device)
            with torch.no_grad():
                _, logits = model(src_batch, None, seg_batch)

            pred = torch.argmax(logits, dim=1)
            pred = pred.cpu().numpy().tolist()
            logits = logits.cpu().numpy().tolist()

            if args.output_logits:
                for j in range(len(pred)):
                    f.write(
                        str(pred[j]) + "\t" +
                        " ".join([str(v) for v in logits[j]]) + "\n")
            else:
                for j in range(len(pred)):
                    f.write(str(pred[j]) + "\n")
Example #19
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--load_model_path",
                        default=None,
                        type=str,
                        help="Path of the classfier model.")
    parser.add_argument("--vocab_path",
                        type=str,
                        required=True,
                        help="Path of the vocabulary file.")
    parser.add_argument("--test_path", type=str, help="Path of the testset.")
    parser.add_argument("--prediction_path",
                        default=None,
                        type=str,
                        help="Path of the prediction file.")
    parser.add_argument("--config_path",
                        default="./models/bert_base_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=64,
                        help="Batch size.")
    parser.add_argument("--seq_length",
                        type=int,
                        default=512,
                        help="Sequence length.")
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument("--embedding",
                        choices=["bert", "word"],
                        default="bert",
                        help="Emebdding type.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                              "cnn", "gatedcnn", "attn", "synt", \
                                              "rcnn", "crnn", "gpt", "bilstm"], \
                                              default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument("--factorized_embedding_parameterization",
                        action="store_true",
                        help="Factorized embedding parameterization.")
    parser.add_argument("--parameter_sharing",
                        action="store_true",
                        help="Parameter sharing.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build model and load parameters.
    model = MachineReadingComprehension(args)
    model = load_model(model, args.load_model_path)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    # Build tokenizer.
    args.tokenizer = CharTokenizer(args)

    dataset, examples = read_dataset(args, args.test_path)

    src = torch.LongTensor([sample[0] for sample in dataset])
    seg = torch.LongTensor([sample[1] for sample in dataset])
    start_position = torch.LongTensor([sample[2] for sample in dataset])
    end_position = torch.LongTensor([sample[3] for sample in dataset])

    batch_size = args.batch_size
    instances_num = len(dataset)

    print("The number of prediction instances: ", instances_num)

    model.eval()

    with open(args.prediction_path, mode="w", encoding="utf-8") as f:

        start_prob_all, end_prob_all = [], []

        for i, (src_batch, seg_batch, start_position_batch,
                end_position_batch) in enumerate(
                    batch_loader(batch_size, src, seg, start_position,
                                 end_position)):
            src_batch = src_batch.to(device)
            seg_batch = seg_batch.to(device)
            start_position_batch = start_position_batch.to(device)
            end_position_batch = end_position_batch.to(device)

            with torch.no_grad():
                loss, start_logits, end_logits = model(src_batch, seg_batch,
                                                       start_position_batch,
                                                       end_position_batch)

            start_prob = nn.Softmax(dim=1)(start_logits)
            end_prob = nn.Softmax(dim=1)(end_logits)

            for j in range(start_prob.size()[0]):
                start_prob_all.append(start_prob[j])
                end_prob_all.append(end_prob[j])

        pred_answers = get_answers(dataset, start_prob_all, end_prob_all)

        output = {}
        for i in range(len(examples)):
            question_id = examples[i][2]
            start_pred_pos = pred_answers[i][1]
            end_pred_pos = pred_answers[i][2]

            prediction = examples[i][0][start_pred_pos:end_pred_pos]
            output[question_id] = prediction

        f.write(json.dumps(output, indent=4, ensure_ascii=False) + "\n")
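
The script relies on get_answers, which is defined elsewhere in the project. As a rough illustration of what it has to produce for the indexing used above (one (index, start, end) tuple per instance), a hypothetical sketch could pair each instance's best start position with the best end position at or after it:

import torch

def get_answers_sketch(dataset, start_prob_all, end_prob_all):
    # Illustrative stand-in, not the project's implementation.
    answers = []
    for i in range(len(dataset)):
        start_pos = int(torch.argmax(start_prob_all[i]))
        # Search for the end only at or after the chosen start.
        end_pos = start_pos + int(torch.argmax(end_prob_all[i][start_pos:]))
        answers.append((i, start_pos, end_pos))
    return answers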
Example #20
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/classifier_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path",
                        default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path",
                        type=str,
                        required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/google_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch size.")
    parser.add_argument("--seq_length",
                        type=int,
                        default=256,
                        help="Sequence length.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                                   "cnn", "gatedcnn", "attn", \
                                                   "rcnn", "crnn", "gpt", "bilstm"], \
                                                   default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument("--pooling",
                        choices=["mean", "max", "first", "last"],
                        default="first",
                        help="Pooling type.")

    # Subword options.
    parser.add_argument("--subword_type",
                        choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path",
                        type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder",
                        choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num",
                        type=int,
                        default=2,
                        help="The number of subencoder layers.")

    # Tokenizer options.
    parser.add_argument(
        "--tokenizer",
        choices=["bert", "char", "word", "space"],
        default="bert",
        help="Specify the tokenizer."
        "Original Google BERT uses bert tokenizer on Chinese corpus."
        "Char tokenizer segments sentences into characters."
        "Word tokenizer supports online word segmentation based on jieba segmentor."
        "Space tokenizer segments sentences into words according to space.")

    # Optimizer options.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=5,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    # Evaluation options.
    parser.add_argument("--mean_reciprocal_rank",
                        action="store_true",
                        help="Evaluation metrics for DBQA dataset.")

    # Knowledge graph options.
    parser.add_argument("--kg_name", required=True, help="KG name or path.")
    parser.add_argument("--workers_num",
                        type=int,
                        default=1,
                        help="Number of processes for loading the dataset.")
    parser.add_argument("--no_vm",
                        action="store_true",
                        help="Disable the visible matrix.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Count the number of labels.
    labels_set = set()
    columns = {}
    with open(args.train_path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            try:
                line = line.strip().split("\t")
                if line_id == 0:
                    for i, column_name in enumerate(line):
                        columns[column_name] = i
                    continue
                label = int(line[columns["label"]])
                labels_set.add(label)
            except (KeyError, ValueError, IndexError):
                continue
    args.labels_num = len(labels_set)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build bert model.
    # A pseudo target is added.
    args.target = "bert"
    model = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model.load_state_dict(torch.load(args.pretrained_model_path),
                              strict=False)
    else:
        # Initialize with normal distribution.
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)

    # Build classification model.
    model = BertClassifier(args, model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms):
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i * batch_size:(i + 1) * batch_size, :]
            label_ids_batch = label_ids[i * batch_size:(i + 1) * batch_size]
            mask_ids_batch = mask_ids[i * batch_size:(i + 1) * batch_size, :]
            pos_ids_batch = pos_ids[i * batch_size:(i + 1) * batch_size, :]
            vms_batch = vms[i * batch_size:(i + 1) * batch_size]
            yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num // batch_size *
                                        batch_size:, :]
            label_ids_batch = label_ids[instances_num // batch_size *
                                        batch_size:]
            mask_ids_batch = mask_ids[instances_num // batch_size *
                                      batch_size:, :]
            pos_ids_batch = pos_ids[instances_num // batch_size *
                                    batch_size:, :]
            vms_batch = vms[instances_num // batch_size * batch_size:]

            yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch
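
    # For example, 10 instances with batch_size 4 yield batches of size
    # 4, 4 and 2: the trailing slice after the loop picks up the remainder
    # left by the integer division.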

    # Build knowledge graph.
    if args.kg_name == 'none':
        spo_files = []
    else:
        spo_files = [args.kg_name]
    kg = KnowledgeGraph(spo_files=spo_files, predicate=True)

    def read_dataset(path, workers_num=1):

        print("Loading sentences from {}".format(path))
        sentences = []
        with open(path, mode='r', encoding="utf-8") as f:
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                sentences.append(line)
        sentence_num = len(sentences)

        print(
            "There are {} sentences in total. We use {} processes to inject knowledge into the sentences."
            .format(sentence_num, workers_num))

        if workers_num > 1:
            params = []
            sentence_per_block = int(sentence_num / workers_num) + 1
            for i in range(workers_num):
                params.append((i, sentences[i * sentence_per_block:(i + 1) *
                                            sentence_per_block], columns, kg,
                               vocab, args))
            pool = Pool(workers_num)
            res = pool.map(add_knowledge_worker, params)
            pool.close()
            pool.join()
            dataset = [sample for block in res for sample in block]
        else:
            params = (0, sentences, columns, kg, vocab, args)
            dataset = add_knowledge_worker(params)

        return dataset
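
    # E.g. 10 sentences with workers_num 3 give sentence_per_block 4, so the
    # three blocks cover sentences[0:4], sentences[4:8] and sentences[8:10].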

    # Evaluation function.
    def evaluate(args, is_test, metrics='Acc'):
        if is_test:
            dataset = read_dataset(args.test_path,
                                   workers_num=args.workers_num)
        else:
            dataset = read_dataset(args.dev_path, workers_num=args.workers_num)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])
        pos_ids = torch.LongTensor([example[3] for example in dataset])
        vms = [example[4] for example in dataset]

        batch_size = args.batch_size
        instances_num = input_ids.size()[0]
        if is_test:
            print("The number of evaluation instances: ", instances_num)

        correct = 0
        # Confusion matrix.
        confusion = torch.zeros(args.labels_num,
                                args.labels_num,
                                dtype=torch.long)

        model.eval()

        if not args.mean_reciprocal_rank:
            for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                    pos_ids_batch, vms_batch) in enumerate(
                        batch_loader(batch_size, input_ids, label_ids,
                                     mask_ids, pos_ids, vms)):

                vms_batch = torch.LongTensor(vms_batch)

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vms_batch = vms_batch.to(device)

                with torch.no_grad():
                    try:
                        loss, logits = model(input_ids_batch, label_ids_batch,
                                             mask_ids_batch, pos_ids_batch,
                                             vms_batch)
                    except Exception:
                        # Print the offending batch, then re-raise: otherwise
                        # logits would be undefined below.
                        print(input_ids_batch)
                        print(input_ids_batch.size())
                        print(vms_batch)
                        print(vms_batch.size())
                        raise

                logits = nn.Softmax(dim=1)(logits)
                pred = torch.argmax(logits, dim=1)
                gold = label_ids_batch
                for j in range(pred.size()[0]):
                    confusion[pred[j], gold[j]] += 1
                correct += torch.sum(pred == gold).item()

            if is_test:
                print("Confusion matrix:")
                print(confusion)
                print("Report precision, recall, and f1:")

            # Per-label precision, recall, and f1 (guarded against empty
            # rows/columns); label 1's f1 feeds the 'f1' metric below.
            label_1_f1 = 0.0
            for i in range(confusion.size()[0]):
                p = confusion[i, i].item() / max(confusion[i, :].sum().item(), 1)
                r = confusion[i, i].item() / max(confusion[:, i].sum().item(), 1)
                f1 = 2 * p * r / max(p + r, 1e-12)
                if i == 1:
                    label_1_f1 = f1
                if is_test:
                    print("Label {}: {:.3f}, {:.3f}, {:.3f}".format(i, p, r, f1))
            print("Acc. (Correct/Total): {:.4f} ({}/{})".format(
                correct / len(dataset), correct, len(dataset)))
            if metrics == 'Acc':
                return correct / len(dataset)
            elif metrics == 'f1':
                return label_1_f1
            else:
                return correct / len(dataset)
        else:
            for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                    pos_ids_batch, vms_batch) in enumerate(
                        batch_loader(batch_size, input_ids, label_ids,
                                     mask_ids, pos_ids, vms)):

                vms_batch = torch.LongTensor(vms_batch)

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vms_batch = vms_batch.to(device)

                with torch.no_grad():
                    loss, logits = model(input_ids_batch, label_ids_batch,
                                         mask_ids_batch, pos_ids_batch,
                                         vms_batch)
                logits = nn.Softmax(dim=1)(logits)
                if i == 0:
                    logits_all = logits
                else:
                    logits_all = torch.cat((logits_all, logits), 0)

            order = -1
            gold = []
            for i in range(len(dataset)):
                qid = dataset[i][-1]
                label = dataset[i][1]
                if qid == order:
                    j += 1
                    if label == 1:
                        gold.append((qid, j))
                else:
                    order = qid
                    j = 0
                    if label == 1:
                        gold.append((qid, j))

            label_order = []
            order = -1
            for i in range(len(gold)):
                if gold[i][0] == order:
                    templist.append(gold[i][1])
                elif gold[i][0] != order:
                    order = gold[i][0]
                    if i > 0:
                        label_order.append(templist)
                    templist = []
                    templist.append(gold[i][1])
            label_order.append(templist)

            order = -1
            score_list = []
            for i in range(len(logits_all)):
                score = float(logits_all[i][1])
                qid = int(dataset[i][-1])
                if qid == order:
                    templist.append(score)
                else:
                    order = qid
                    if i > 0:
                        score_list.append(templist)
                    templist = []
                    templist.append(score)
            score_list.append(templist)

            rank = []
            pred = []
            print(len(score_list))
            print(len(label_order))
            for i in range(len(score_list)):
                if len(label_order[i]) == 1:
                    if label_order[i][0] < len(score_list[i]):
                        true_score = score_list[i][label_order[i][0]]
                        score_list[i].sort(reverse=True)
                        for j in range(len(score_list[i])):
                            if score_list[i][j] == true_score:
                                rank.append(1 / (j + 1))
                    else:
                        rank.append(0)

                else:
                    true_rank = len(score_list[i])
                    for k in range(len(label_order[i])):
                        if label_order[i][k] < len(score_list[i]):
                            true_score = score_list[i][label_order[i][k]]
                            temp = sorted(score_list[i], reverse=True)
                            for j in range(len(temp)):
                                if temp[j] == true_score:
                                    if j < true_rank:
                                        true_rank = j
                    if true_rank < len(score_list[i]):
                        rank.append(1 / (true_rank + 1))
                    else:
                        rank.append(0)
            MRR = sum(rank) / len(rank)
            print("MRR", MRR)
            return MRR

    # Training phase.
    print("Start training.")
    trainset = read_dataset(args.train_path, workers_num=args.workers_num)

    print("Shuffling dataset")
    random.shuffle(trainset)
    instances_num = len(trainset)
    batch_size = args.batch_size

    print("Trans data to tensor.")
    print("input_ids")
    input_ids = torch.LongTensor([example[0] for example in trainset])

    print("label_ids")
    label_ids = torch.LongTensor([example[1] for example in trainset])

    print("mask_ids")
    mask_ids = torch.LongTensor([example[2] for example in trainset])

    print("pos_ids")
    pos_ids = torch.LongTensor([example[3] for example in trainset])

    print("vms")
    vms = [example[4] for example in trainset]

    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup,
                         t_total=train_steps)

    total_loss = 0.
    result = 0.0
    best_result = 0.0

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                pos_ids_batch, vms_batch) in enumerate(
                    batch_loader(batch_size, input_ids, label_ids, mask_ids,
                                 pos_ids, vms)):
            model.zero_grad()

            vms_batch = torch.LongTensor(vms_batch)

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            vms_batch = vms_batch.to(device)

            loss, _ = model(input_ids_batch,
                            label_ids_batch,
                            mask_ids_batch,
                            pos=pos_ids_batch,
                            vm=vms_batch)
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, i + 1, total_loss / args.report_steps))
                sys.stdout.flush()
                total_loss = 0.
            loss.backward()
            optimizer.step()

        print("Start evaluation on dev dataset.")
        result = evaluate(args, False)
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path)
        else:
            continue

        print("Start evaluation on test dataset.")
        evaluate(args, True)

    # Evaluation phase.
    print("Final evaluation on the test dataset.")
    # Reload the best model saved during training.
    if torch.cuda.device_count() > 1:
        model.module.load_state_dict(torch.load(args.output_model_path))
    else:
        model.load_state_dict(torch.load(args.output_model_path))
    evaluate(args, True)
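
The mean-reciprocal-rank branch above reconstructs, for each question id, the rank of the best-ranked gold candidate among that question's scores. The same metric in a compact hypothetical sketch, assuming the scores are already grouped per question and the gold candidate indices are known:

def mean_reciprocal_rank(score_list, gold_positions):
    # score_list[i]: candidate scores for question i.
    # gold_positions[i]: indices of its gold candidates.
    ranks = []
    for scores, gold in zip(score_list, gold_positions):
        order = sorted(range(len(scores)), key=lambda j: -scores[j])
        best = min((order.index(g) for g in gold if g < len(scores)),
                   default=None)
        ranks.append(0.0 if best is None else 1.0 / (best + 1))
    return sum(ranks) / len(ranks)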
Example #21
    parser.add_argument("--spm_model_path", default=None, type=str,
                        help="Path of the sentence piece model.")
    parser.add_argument("--word_embedding_path", default=None, type=str,
                        help="Path of the output word embedding.")

    args = parser.parse_args()

    if args.spm_model_path:
        try:
            import sentencepiece as spm
        except ImportError:
            raise ImportError("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
                                                    "pip install sentencepiece")
        sp_model = spm.SentencePieceProcessor()
        sp_model.Load(args.spm_model_path)
        vocab = Vocab()
        vocab.i2w = {i: sp_model.IdToPiece(i) for i in range(sp_model.GetPieceSize())}
    else:
        vocab = Vocab()
        vocab.load(args.vocab_path)

    pretrained_model = torch.load(args.load_model_path)
    embedding = pretrained_model["embedding.word_embedding.weight"]

    with open(args.word_embedding_path, mode="w", encoding="utf-8") as f:
        head = str(embedding.size(0)) + " " + str(embedding.size(1)) + "\n"
        f.write(head)

        for i in range(len(vocab.i2w)):
            word = vocab.i2w[i]
            word_embedding = embedding[vocab.get(word), :]
            # Write one line per token: the word followed by its vector values.
            f.write(word + " " + " ".join(str(v) for v in word_embedding.tolist()) + "\n")
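
With the "rows columns" header and one "word value value ..." line per token, the exported file follows the textual word2vec format. For instance, assuming gensim 4.x is installed, it can be loaded back for inspection:

from gensim.models import KeyedVectors

vectors = KeyedVectors.load_word2vec_format("word_embedding.txt", binary=False)
print(vectors.most_similar(vectors.index_to_key[0]))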
Example #22
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path", default=None, type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_path", default="./models/tagger_model.bin", type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path", default="./models/google_vocab.txt", type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path", type=str, required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path", type=str, required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path", default="./models/google_config.json", type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Batch_size.")
    parser.add_argument("--seq_length", default=128, type=int,
                        help="Sequence length.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                                   "cnn", "gatedcnn", "attn", \
                                                   "rcnn", "crnn", "gpt", "bilstm"], \
                                                   default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional", action="store_true", help="Specific to recurrent model.")
    
    # Subword options.
    parser.add_argument("--subword_type", choices=["none", "char"], default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path", type=str, default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder", choices=["avg", "lstm", "gru", "cnn"], default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num", type=int, default=2, help="The number of subencoder layers.")

    # Optimizer options.
    parser.add_argument("--learning_rate", type=float, default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup", type=float, default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout.")
    parser.add_argument("--epochs_num", type=int, default=5,
                        help="Number of epochs.")
    parser.add_argument("--report_steps", type=int, default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7,
                        help="Random seed.")

    # Knowledge graph and logging options.
    parser.add_argument("--kg_name", required=True, help="KG name or path")
    parser.add_argument("--log_file", help="Path of the log file.")
    parser.add_argument("--task_name", default=None, type=str)
    parser.add_argument("--mode", default='regular', type=str)
    parser.add_argument("--run_time", default=None, type=str)
    parser.add_argument("--commit_id", default=None, type=str)
    parser.add_argument("--fold_nb", default=0, type=str)
    parser.add_argument("--tensorboard_dir", default=None)

    # Note: argparse's type=bool treats any non-empty string as True, so the
    # boolean flags below only behave as expected when left at their defaults.
    parser.add_argument("--need_birnn", default=False, type=bool)
    parser.add_argument("--rnn_dim", default=128, type=int)
    parser.add_argument("--model_name", default='bert', type=str)
    parser.add_argument("--pku_model_name", default='default', type=str)
    parser.add_argument("--has_token", default=False)

    parser.add_argument("--do_train", default=False, type=bool)
    parser.add_argument("--do_test", default=True, type=bool)

    args = parser.parse_args()
    args.run_time = datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S')

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    s = Save_Log(args)
    logger = init_logger(args.log_file)

    print(args)
    logger.info(args)

    os.makedirs(args.output_path, exist_ok=True)
    writer = SummaryWriter(logdir=os.path.join(
        args.tensorboard_dir, "eval", '{}_{}_{}_{}'.format(
            args.task_name, args.fold_nb, args.run_time, args.commit_id)),
        comment="Linear")

    labels_map = {"[PAD]": 0, "[ENT]": 1}
    begin_ids = []

    # Find tagging labels.
    with open(args.train_path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id == 0:
                continue
            labels = line.strip().split("\t")[1].split()
            for l in labels:
                if l not in labels_map:
                    if l.startswith("B") or l.startswith("S"):
                        begin_ids.append(len(labels_map))
                    labels_map[l] = len(labels_map)
    
    print("Labels: ", labels_map)
    logger.info(labels_map)
    args.labels_num = len(labels_map)
    id2label = {labels_map[key]: key for key in labels_map}
    print("id2label:", id2label)
    logger.info(id2label)
    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build knowledge graph.
    if args.kg_name == 'none':
        spo_files = []
    else:
        spo_files = [args.kg_name]
    kg = KnowledgeGraph(spo_files=spo_files, pku_model_name=args.pku_model_name, predicate=False)

    # Build bert model.
    # A pseudo target is added.
    args.target = "bert"
    model = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model.load_state_dict(torch.load(args.pretrained_model_path), strict=False)  
    else:
        # Initialize with normal distribution.
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)
    
    # Build sequence labeling model.
    if args.model_name == 'bert':
        model = BertTagger(args, model)
    elif args.model_name == 'bertcrf':
        model = BertTagger_with_LSTMCRF(args, model)
    logger.info(model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)
    args.device = device

    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vm_ids, tag_ids):
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i*batch_size: (i+1)*batch_size, :]
            label_ids_batch = label_ids[i*batch_size: (i+1)*batch_size, :]
            mask_ids_batch = mask_ids[i*batch_size: (i+1)*batch_size, :]
            pos_ids_batch = pos_ids[i*batch_size: (i+1)*batch_size, :]
            vm_ids_batch = vm_ids[i*batch_size: (i+1)*batch_size, :, :]
            tag_ids_batch = tag_ids[i*batch_size: (i+1)*batch_size, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vm_ids_batch, tag_ids_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num//batch_size*batch_size:, :]
            label_ids_batch = label_ids[instances_num//batch_size*batch_size:, :]
            mask_ids_batch = mask_ids[instances_num//batch_size*batch_size:, :]
            pos_ids_batch = pos_ids[instances_num//batch_size*batch_size:, :]
            vm_ids_batch = vm_ids[instances_num//batch_size*batch_size:, :, :]
            tag_ids_batch = tag_ids[instances_num//batch_size*batch_size:, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vm_ids_batch, tag_ids_batch

    # Read dataset.
    def read_dataset(path):
        dataset = []
        with open(path, mode="r", encoding="utf-8") as f:
            f.readline()  # Skip the header line.
            for line_id, line in enumerate(f):
                tokens, labels = line.strip().split("\t")
                text = ''.join(tokens.split(" "))
                tokens, pos, vm, tag = kg.add_knowledge_with_vm(
                    [text], add_pad=True, max_length=args.seq_length)
                tokens = tokens[0]

                pos = pos[0]
                vm = vm[0].astype("bool")
                tag = tag[0]

                tokens = [vocab.get(t) for t in tokens]
                labels = [labels_map[l] for l in labels.split(" ")]

                mask = [1] * len(tokens)

                new_labels = []
                j = 0
                for i in range(len(tokens)):
                    if tag[i] == 0 and tokens[i] != PAD_ID:
                        new_labels.append(labels[j])
                        j += 1
                    elif tag[i] == 1 and tokens[i] != PAD_ID:  # An injected entity token.
                        new_labels.append(labels_map['[ENT]'])
                    else:
                        new_labels.append(labels_map[PAD_TOKEN])

                dataset.append([tokens, new_labels, mask, pos, vm, tag])
        
        return dataset

    # Evaluation function.
    def evaluate(args, epoch, is_test):
        f1 = 0
        if is_test:
            dataset = read_dataset(args.test_path)
        else:
            dataset = read_dataset(args.dev_path)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])
        pos_ids = torch.LongTensor([sample[3] for sample in dataset])
        vm_ids = torch.BoolTensor([sample[4] for sample in dataset])
        tag_ids = torch.LongTensor([sample[5] for sample in dataset])

        instances_num = input_ids.size(0)
        batch_size = args.batch_size

        if is_test:
            print("Batch size: ", batch_size)
            print("The number of test instances:", instances_num)
 
        correct = 0
        gold_entities_num = 0
        pred_entities_num = 0

        by_type_correct = {}
        by_type_gold_nb = {}
        by_type_pred_nb = {}

        confusion = torch.zeros(len(labels_map), len(labels_map), dtype=torch.long)

        pred_labels = []
        gold_labels = []
        origin_tokens = []

        model.eval()

        for i, (input_ids_batch, label_ids_batch, mask_ids_batch,
                pos_ids_batch, vm_ids_batch, tag_ids_batch) in enumerate(
                    batch_loader(batch_size, input_ids, label_ids, mask_ids,
                                 pos_ids, vm_ids, tag_ids)):

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            tag_ids_batch = tag_ids_batch.to(device)
            vm_ids_batch = vm_ids_batch.long().to(device)
            # print("batch size:",batch_size)
            loss, _, pred, gold = model(input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vm_ids_batch)
            # print(pred.size(),gold.size())
            # print("pred:",pred)
            # print("gold:",gold)

            """
            pred: tensor([2, 2, 2,  ..., 2, 2, 2], device='cuda:0')
            gold: tensor([2, 2, 2,  ..., 0, 0, 0], device='cuda:0')

            """
            # print("input id batch:",input_ids_batch.size())
            for input_ids in input_ids_batch:
                for id in input_ids:
                    origin_tokens.append(vocab.i2w[id])
            for p,g in zip(pred,gold):

                pred_labels.append(id2label[int(p)] )
                gold_labels.append(id2label[int(g)])

            # pred_labels.append(pred)

            # gold_labels.append(gold)
            # print("pred label",pred_labels)
            # print("gold label:",gold_labels)

            for j in range(gold.size()[0]):
                if gold[j].item() in begin_ids:
                    gold_entities_num += 1
                    by_type_gold_nb[gold[j].item()] = by_type_gold_nb.get(
                        gold[j].item(), 0) + 1

            for j in range(pred.size()[0]):
                if pred[j].item() in begin_ids and gold[j].item() != labels_map["[PAD]"]:
                    pred_entities_num += 1
                    by_type_pred_nb[pred[j].item()] = by_type_pred_nb.get(
                        pred[j].item(), 0) + 1

            pred_entities_pos = []
            gold_entities_pos = []
            start, end = 0, 0

            for j in range(gold.size()[0]):
                if gold[j].item() in begin_ids:
                    start = j
                    entity_type = gold[j].item()

                    for k in range(j + 1, gold.size()[0]):

                        if gold[k].item() == labels_map['[ENT]']:
                            continue

                        if gold[k].item() == labels_map["[PAD]"] or gold[k].item() == labels_map["O"] or gold[k].item() in begin_ids:
                            end = k - 1
                            break
                    else:
                        end = gold.size()[0] - 1
                    gold_entities_pos.append((start, end, entity_type))
            
            for j in range(pred.size()[0]):
                if pred[j].item() in begin_ids and gold[j].item() != labels_map["[PAD]"] and gold[j].item() != labels_map["[ENT]"]:
                    start = j
                    entity_type = pred[j].item()
                    for k in range(j + 1, pred.size()[0]):

                        if gold[k].item() == labels_map['[ENT]']:
                            continue

                        if pred[k].item() == labels_map["[PAD]"] or pred[k].item() == labels_map["O"] or pred[k].item() in begin_ids:
                            end = k - 1
                            break
                    else:
                        end = pred.size()[0] - 1
                    pred_entities_pos.append((start, end, entity_type))

            for entity in pred_entities_pos:
                if entity in gold_entities_pos:
                    correct += 1
                    by_type_correct[entity[2]] = by_type_correct.get(
                        entity[2], 0) + 1



        if not is_test:
            print("Report precision, recall, and f1:")
            logger.info("Report precision, recall, and f1:")
            # Guard against zero denominators on degenerate predictions.
            p = correct / max(pred_entities_num, 1)
            r = correct / max(gold_entities_num, 1)
            f1 = 2 * p * r / max(p + r, 1e-12)
            logger.info("{:.3f}, {:.3f}, {:.3f}".format(p, r, f1))
            print("{:.3f}, {:.3f}, {:.3f}".format(p, r, f1))
            writer.add_scalar("Eval/precision", p, epoch)
            writer.add_scalar("Eval/recall", r, epoch)
            writer.add_scalar("Eval/f1_score", f1, epoch)

            for entity_type in by_type_correct:
                p = by_type_correct[entity_type] / by_type_pred_nb[entity_type]
                r = by_type_correct[entity_type] / by_type_gold_nb[entity_type]
                f1 = 2 * p * r / (p + r)
                print("{}: {:.3f}, {:.3f}, {:.3f}".format(id2label[entity_type][2:], p, r, f1))
                logger.info("{}: {:.3f}, {:.3f}, {:.3f}".format(id2label[entity_type][2:], p, r, f1))
                writer.add_scalar("Eval/precision_{}".format(id2label[entity_type][2:]), p, epoch)
                writer.add_scalar("Eval/recall_{}".format(id2label[entity_type][2:]), r, epoch)
                writer.add_scalar("Eval/f1_score_{}".format(id2label[entity_type][2:]), f1, epoch)

        pred_path = os.path.join(args.output_path,
                                 'pred_label_test1_{}.txt'.format(is_test))
        with open(pred_path, 'w', encoding='utf-8') as file:
            print("Saving predictions in", pred_path)
            i = 0
            while i < len(pred_labels):
                len_ = args.seq_length
                if '[PAD]' in origin_tokens[i:i + args.seq_length]:
                    len_ = origin_tokens[i:i + args.seq_length].index('[PAD]')
                file.write(' '.join(origin_tokens[i:i + len_]))
                file.write('\t' + ' '.join(pred_labels[i:i + len_]))
                file.write('\t' + ' '.join(gold_labels[i:i + len_]) + '\n')

                i += args.seq_length

        return f1

    # Training phase.
    print("args train test:",args.do_train,args.do_test)
    if(args.do_train):
        print("Start training.")
        logger.info("Start training.")
        instances = read_dataset(args.train_path)

        input_ids = torch.LongTensor([ins[0] for ins in instances])
        label_ids = torch.LongTensor([ins[1] for ins in instances])
        mask_ids = torch.LongTensor([ins[2] for ins in instances])
        pos_ids = torch.LongTensor([ins[3] for ins in instances])
        vm_ids = torch.BoolTensor([ins[4] for ins in instances])
        tag_ids = torch.LongTensor([ins[5] for ins in instances])

        instances_num = input_ids.size(0)
        batch_size = args.batch_size
        train_steps = int(instances_num * args.epochs_num / batch_size) + 1

        logger.info("Batch size: {}".format(batch_size))
        print("Batch size: ", batch_size)
        print("The number of training instances:", instances_num)
        logger.info("The number of training instances:{}".format(instances_num))

        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
        ]
        optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup,
                             t_total=train_steps)

        total_loss = 0.
        f1 = 0.0
        best_f1 = 0.0
        total_step = 0
        for epoch in range(1, args.epochs_num + 1):
            print("Epoch ", epoch)
            model.train()
            for i, (
            input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vm_ids_batch, tag_ids_batch) in enumerate(
                    batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vm_ids, tag_ids)):
                model.zero_grad()
                total_step += 1
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                tag_ids_batch = tag_ids_batch.to(device)
                vm_ids_batch = vm_ids_batch.long().to(device)

                loss, _, _, _ = model(input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vm_ids_batch)
                if torch.cuda.device_count() > 1:
                    loss = torch.mean(loss)
                total_loss += loss.item()
                if (i + 1) % args.report_steps == 0:
                    logger.info("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(epoch, i + 1,
                                                                                            total_loss / args.report_steps))

                    writer.add_scalar("Train/loss", total_loss / args.report_steps, total_step)

                    total_loss = 0.

                loss.backward()
                optimizer.step()

            # Evaluation phase.
            print("Start evaluation on the dev dataset.")
            logger.info("Start evaluation on the dev dataset.")
            f1 = evaluate(args, epoch, False)

            if f1 > best_f1:
                best_f1 = f1
                save_model(model, os.path.join(args.output_path, '{}.bin').format(args.task_name))
            else:
                continue

    if args.do_test:
        # Evaluation phase.
        print("Final evaluation on test dataset.")
        logger.info("Final evaluation on test dataset.")
        if torch.cuda.device_count() > 1:
            model.module.load_state_dict(torch.load(os.path.join(args.output_path, "{}.bin".format(args.task_name))))
        else:
            model.load_state_dict(torch.load(os.path.join(args.output_path, "{}.bin".format(args.task_name))))

        evaluate(args, args.epochs_num, True)

        print("============over=================={}".format(args.fold_nb))
Example #23
        "--tokenizer",
        choices=["bert", "char", "word", "space"],
        default="bert",
        help="Specify the tokenizer."
        "Original Google BERT uses bert tokenizer on Chinese corpus."
        "Char tokenizer segments sentences into characters."
        "Word tokenizer supports online word segmentation based on jieba segmentor."
        "Space tokenizer segments sentences into words according to space.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build bert model.
    model = build_model(args)

    # Load pretrained model.
    pretrained_model_dict = torch.load(args.pretrained_model_path)
    model.load_state_dict(pretrained_model_dict, strict=False)

    model = GenerateModel(args, model)

    # Build tokenizer.
    tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)
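
The lookup above builds the class name from the flag value, so --tokenizer bert resolves to BertTokenizer, char to CharTokenizer, and so on; the classes must therefore be importable in the module's global namespace. An explicit mapping, like the str2tokenizer table the next example uses, makes the same dispatch more robust (class names here are inferred from the flag choices):

str2tokenizer = {"bert": BertTokenizer, "char": CharTokenizer,
                 "word": WordTokenizer, "space": SpaceTokenizer}
tokenizer = str2tokenizer[args.tokenizer](args)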
Example #24
    parser.add_argument(
        "--tokenizer",
        choices=["bert", "char", "space"],
        default="bert",
        help="Specify the tokenizer."
        "Original Google BERT uses bert tokenizer on Chinese corpus."
        "Char tokenizer segments sentences into characters."
        "Space tokenizer segments sentences into words according to space.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    model = GenerateModel(args)

    # Load pretrained model.
    pretrained_model_dict = torch.load(args.pretrained_model_path)
    model.load_state_dict(pretrained_model_dict, strict=False)

    # Build tokenizer.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    def top_k_top_p_filtering(logits, top_k, top_p):
        top_k = min(top_k, logits.size(-1))  # Safety check
        if top_k > 0:
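
The snippet is cut off at the top-k branch. For reference, here is a self-contained sketch of the standard top-k/top-p (nucleus) filtering for a 1-D logits tensor, in the spirit of the widely circulated recipe; the function name and defaults are assumptions, not the project's code:

import torch


def top_k_top_p_filtering_sketch(logits, top_k=0, top_p=0.0,
                                 filter_value=-float("inf")):
    top_k = min(top_k, logits.size(-1))  # Safety check.
    if top_k > 0:
        # Mask everything below the k-th largest logit.
        threshold = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < threshold] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            torch.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens once the cumulative probability exceeds top_p,
        # always keeping the most probable token.
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        logits[sorted_indices[sorted_mask]] = filter_value
    return logits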
Example #25
File: run_mrc.py  Project: hccngu/UER-py
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/QA_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path",
                        default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/google_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=64,
                        help="Batch size.")
    parser.add_argument("--seq_length",
                        type=int,
                        default=100,
                        help="Sequence length.")
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument("--embedding",
                        choices=["bert", "word"],
                        default="bert",
                        help="Emebdding type.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                                   "cnn", "gatedcnn", "attn", \
                                                   "rcnn", "crnn", "gpt"], \
                                                   default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")

    # Subword options.
    parser.add_argument("--subword_type",
                        choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path",
                        type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder",
                        choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num",
                        type=int,
                        default=2,
                        help="The number of subencoder layers.")

    # Tokenizer options.
    parser.add_argument(
        "--tokenizer",
        choices=["bert", "char", "space"],
        default="char",
        help="Specify the tokenizer."
        "Original Google BERT uses bert tokenizer on Chinese corpus."
        "Char tokenizer segments sentences into characters."
        "Space tokenizer segments sentences into words according to space.")

    # Optimizer options.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=3e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=3,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    args.target = "bert"
    bert_model = build_model(args)
    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        bert_model.load_state_dict(torch.load(args.pretrained_model_path),
                                   strict=False)
    else:
        # Initialize with normal distribution.
        for n, p in list(bert_model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)

    # Build QA model.
    model = BertQuestionAnswering(args, bert_model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    # Dataset loader.
    def batch_loader(batch_size, input_ids, mask_ids, start_positions,
                     end_positions):
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i * batch_size:(i + 1) * batch_size, :]
            mask_ids_batch = mask_ids[i * batch_size:(i + 1) * batch_size, :]
            start_positions_batch = start_positions[i * batch_size:(i + 1) *
                                                    batch_size]
            end_positions_batch = end_positions[i * batch_size:(i + 1) *
                                                batch_size]
            yield input_ids_batch, mask_ids_batch, start_positions_batch, end_positions_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num // batch_size *
                                        batch_size:, :]
            mask_ids_batch = mask_ids[instances_num // batch_size *
                                      batch_size:, :]
            start_positions_batch = start_positions[instances_num //
                                                    batch_size * batch_size:]
            end_positions_batch = end_positions[instances_num // batch_size *
                                                batch_size:]
            yield input_ids_batch, mask_ids_batch, start_positions_batch, end_positions_batch

    # Build tokenizer.
    tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

    # Read examples.
    def read_examples(path):
        examples = []
        with open(path, 'r', encoding='utf-8') as fp:
            all_dict = json.loads(fp.read())
            v1 = all_dict["data"]
            for i in range(len(v1)):
                data_dict = v1[i]
                v2 = data_dict["paragraphs"]

                for j in range(len(v2)):
                    para_dict = v2[j]
                    context = para_dict["context"]
                    v3 = para_dict["qas"]

                    for m in range(len(v3)):
                        qas_dict = v3[m]
                        question = qas_dict["question"]
                        question_id = qas_dict["id"]
                        v4 = qas_dict["answers"]

                        answers = []
                        start_positions = []
                        end_positions = []

                        for n in range(len(v4)):
                            ans_dict = v4[n]
                            answer = ans_dict["text"]
                            start_position = ans_dict["answer_start"]
                            end_position = start_position + len(answer)

                            answers.append(answer)
                            start_positions.append(start_position)
                            end_positions.append(end_position)

                        examples.append(
                            (context, question, question_id, start_positions,
                             end_positions, answers))

        return examples

    def convert_examples_to_dataset(examples, args):
        dataset = []
        print("The number of questions in the dataset", len(examples))
        for i in range(len(examples)):
            context = examples[i][0]
            question = examples[i][1]
            q_len = len(question)
            question_id = examples[i][2]

            start_positions_true = examples[i][3][0]  # To be revised: only the first answer is used.
            end_positions_true = examples[i][4][0]

            answers = examples[i][5]
            max_context_length = args.seq_length - q_len - 3
            # Divide the context into overlapping spans.
            _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
                "DocSpan", ["start", "length"])
            doc_spans = []
            start_offset = 0
            while start_offset < len(context):
                length = len(context) - start_offset
                if length > max_context_length:
                    length = max_context_length
                doc_spans.append(_DocSpan(start=start_offset, length=length))
                if start_offset + length == len(context):
                    break
                start_offset += min(length, args.doc_stride)

            for (doc_span_index, doc_span) in enumerate(doc_spans):
                doc_span_start = doc_span.start
                span_context = context[doc_span_start:doc_span_start +
                                       doc_span.length]
                # Convert the answer start/end to positions in the padded token sequence.
                start_positions = start_positions_true - doc_span_start + q_len + 2
                end_positions = end_positions_true - doc_span_start + q_len + 2
                # Skip spans that do not contain the answer.
                if start_positions < q_len + 2 or start_positions > doc_span.length + q_len + 2 or end_positions < q_len + 2 or end_positions > doc_span.length + q_len + 2:
                    continue

                tokens_a = [vocab.get(t) for t in tokenizer.tokenize(question)]
                tokens_a = [CLS_ID] + tokens_a + [SEP_ID]
                tokens_b = [
                    vocab.get(t) for t in tokenizer.tokenize(span_context)
                ]
                tokens_b = tokens_b + [SEP_ID]
                tokens = tokens_a + tokens_b
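                # Segment ids (stored in `mask`) follow this codebase's
                # convention: 1 marks the question segment, 2 the context
                # segment, and 0 padding.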
                mask = [1] * len(tokens_a) + [2] * len(tokens_b)

                while len(tokens) < args.seq_length:
                    tokens.append(0)
                    mask.append(0)

                dataset.append(
                    (tokens, mask, start_positions, end_positions, answers,
                     question_id, q_len, doc_span_index, doc_span_start))
        return dataset

    # Evaluation function.
    def evaluate(args, is_test):
        # Scoring helper functions.
        def mixed_segmentation(in_str, rm_punc=False):
            in_str = str(in_str).lower().strip()
            segs_out = []
            temp_str = ""
            sp_char = [
                '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',',
                '。', ':', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·',
                '、', '「', '」', '(', ')', '-', '~', '『', '』'
            ]
            for char in in_str:
                if rm_punc and char in sp_char:
                    continue
                if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
                    if temp_str != "":
                        ss = nltk.word_tokenize(temp_str)
                        segs_out.extend(ss)
                        temp_str = ""
                    segs_out.append(char)
                else:
                    temp_str += char

            # Handle the trailing non-CJK segment.
            if temp_str != "":
                ss = nltk.word_tokenize(temp_str)
                segs_out.extend(ss)

            return segs_out
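        # For instance (illustrative), mixed_segmentation("中国的capital是Beijing")
        # returns ['中', '国', '的', 'capital', '是', 'beijing']: CJK characters
        # are split one by one and Latin-script runs are tokenized with nltk.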

        # remove punctuation
        def remove_punctuation(in_str):
            in_str = str(in_str).lower().strip()
            sp_char = [
                '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',',
                '。', ':', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·',
                '、', '「', '」', '(', ')', '-', '~', '『', '』'
            ]
            out_segs = []
            for char in in_str:
                if char in sp_char:
                    continue
                else:
                    out_segs.append(char)
            return ''.join(out_segs)

        # find longest common string
        def find_lcs(s1, s2):
            m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
            mmax = 0
            p = 0
            for i in range(len(s1)):
                for j in range(len(s2)):
                    if s1[i] == s2[j]:
                        m[i + 1][j + 1] = m[i][j] + 1
                        if m[i + 1][j + 1] > mmax:
                            mmax = m[i + 1][j + 1]
                            p = i + 1
            return s1[p - mmax:p], mmax
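        # Longest common substring, e.g. find_lcs("abcde", "xbcdy") == ("bcd", 3).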

        def calc_f1_score(answers, prediction):
            f1_scores = []
            for i in range(len(answers)):
                ans = answers[i]
                ans_segs = mixed_segmentation(ans, rm_punc=True)
                prediction_segs = mixed_segmentation(prediction, rm_punc=True)
                lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
                if lcs_len == 0:
                    f1_scores.append(0)
                else:
                    precision = 1.0 * lcs_len / len(prediction_segs)
                    recall = 1.0 * lcs_len / len(ans_segs)
                    f1 = (2 * precision * recall) / (precision + recall)
                    f1_scores.append(f1)
            return max(f1_scores)

        def calc_em_score(answers, prediction):
            em = 0
            for i in range(len(answers)):
                ans = answers[i]
                ans_ = remove_punctuation(ans)
                prediction_ = remove_punctuation(prediction)
                if ans_ == prediction_:
                    em = 1
                    break
            return em

        def is_max_score(score_list):
            score_max = -100
            index_max = 0
            best_start_prediction = 0
            best_end_prediction = 0
            for i in range(len(score_list)):
                if score_max <= score_list[i][3]:
                    score_max = score_list[i][3]
                    index_max = score_list[i][0]
                    best_start_prediction = score_list[i][1]
                    best_end_prediction = score_list[i][2]
            return index_max, best_start_prediction, best_end_prediction

        if is_test:
            examples = read_examples(args.test_path)
        else:
            examples = read_examples(args.dev_path)
        dataset = convert_examples_to_dataset(examples, args)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        mask_ids = torch.LongTensor([sample[1] for sample in dataset])
        start_positions = torch.LongTensor([sample[2] for sample in dataset])
        end_positions = torch.LongTensor([sample[3] for sample in dataset])

        batch_size = args.batch_size
        instances_num = input_ids.size()[0]

        if is_test:
            print("The number of evaluation instances: ", instances_num)
        model.eval()
        start_logits_all = []
        end_logits_all = []
        start_pred_all = []
        end_pred_all = []
        for i, (input_ids_batch, mask_ids_batch, start_positions_batch,
                end_positions_batch) in enumerate(
                    batch_loader(batch_size, input_ids, mask_ids,
                                 start_positions, end_positions)):
            model.zero_grad()
            input_ids_batch = input_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            start_positions_batch = start_positions_batch.to(device)
            end_positions_batch = end_positions_batch.to(device)

            with torch.no_grad():
                loss, start_logits, end_logits = model(input_ids_batch,
                                                       mask_ids_batch,
                                                       start_positions_batch,
                                                       end_positions_batch)

            start_logits = nn.Softmax(dim=1)(start_logits)
            end_logits = nn.Softmax(dim=1)(end_logits)

            start_pred = torch.argmax(start_logits, dim=1)
            end_pred = torch.argmax(end_logits, dim=1)

            start_pred = start_pred.cpu().numpy().tolist()
            end_pred = end_pred.cpu().numpy().tolist()

            start_logits = start_logits.cpu().numpy().tolist()
            end_logits = end_logits.cpu().numpy().tolist()

            start_logits_max = []
            end_logits_max = []
            for j in range(len(start_pred)):
                start_logits_max.append(start_logits[j][start_pred[j]])
                end_logits_max.append(end_logits[j][end_pred[j]])

            start_logits_all += start_logits_max
            end_logits_all += end_logits_max
            start_pred_all += start_pred
            end_pred_all += end_pred

        assert len(start_pred_all) == len(dataset)
        assert len(start_logits_all) == len(dataset)

        # Cluster predictions by question id and choose the best answer among its doc_spans.
        order = -1
        pred_list = []
        templist = []
        for i in range(len(dataset)):
            qid = dataset[i][5]
            q_len = dataset[i][6]
            span_index = dataset[i][7]
            doc_span_start = dataset[i][8]

            score1 = float(start_logits_all[i])
            score2 = float(end_logits_all[i])
            score = (score1 + score2) / 2

            pre_start_pred = start_pred_all[i] + doc_span_start - q_len - 2
            pre_end_pred = end_pred_all[i] + doc_span_start - q_len - 2

            if qid == order:
                templist.append(
                    (span_index, pre_start_pred, pre_end_pred, score))
            else:
                order = qid
                if i > 0:
                    span_index_max, best_start_prediction, best_end_prediction = is_max_score(
                        templist)
                    pred_list.append((span_index_max, best_start_prediction,
                                      best_end_prediction))
                templist = []
                templist.append(
                    (span_index, pre_start_pred, pre_end_pred, score))
        span_index_max, best_start_prediction, best_end_prediction = is_max_score(
            templist)
        pred_list.append(
            (span_index_max, best_start_prediction, best_end_prediction))

        assert len(pred_list) == len(examples)

        # Score the best prediction of each question.
        f1 = 0
        em = 0
        total_count = len(examples)
        skip_count = 0
        for i in range(len(examples)):
            question_id = examples[i][2]
            answers = examples[i][5]
            span_index = pred_list[i][0]
            start_prediction = pred_list[i][1]
            end_prediction = pred_list[i][2]

            # Skip invalid predictions where the end precedes the start.
            if end_prediction <= start_prediction:
                skip_count += 1
                continue

            prediction = examples[i][0][start_prediction:end_prediction]

            f1 += calc_f1_score(answers, prediction)
            em += calc_em_score(answers, prediction)

        f1_score = 100.0 * f1 / total_count
        em_score = 100.0 * em / total_count
        avg = (f1_score + em_score) * 0.5
        print("Avg: {:.4f},F1:{:.4f},EM:{:.4f},Total:{},Skip:{}".format(
            avg, f1_score, em_score, total_count, skip_count))
        return avg

    # Training phase.
    print("Start training.")
    batch_size = args.batch_size
    print("Batch size: ", batch_size)
    examples = read_examples(args.train_path)
    trainset = convert_examples_to_dataset(examples, args)
    random.shuffle(trainset)
    instances_num = len(trainset)

    input_ids = torch.LongTensor([sample[0] for sample in trainset])
    mask_ids = torch.LongTensor([sample[1] for sample in trainset])
    start_positions = torch.LongTensor([sample[2] for sample in trainset])
    end_positions = torch.LongTensor([sample[3] for sample in trainset])

    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("The number of training instances:", instances_num)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
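    # In this codebase 'gamma' and 'beta' name the LayerNorm scale and bias
    # parameters; together with biases they are excluded from weight decay.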
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup,
                         t_total=train_steps)

    total_loss = 0.
    result = 0.0
    best_result = 0.0

    for epoch in range(1, args.epochs_num + 1):
        model.train()

        for i, (input_ids_batch, mask_ids_batch, start_positions_batch,
                end_positions_batch) in enumerate(
                    batch_loader(batch_size, input_ids, mask_ids,
                                 start_positions, end_positions)):
            model.zero_grad()
            input_ids_batch = input_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            start_positions_batch = start_positions_batch.to(device)
            end_positions_batch = end_positions_batch.to(device)

            loss, _, _ = model(input_ids_batch, mask_ids_batch,
                               start_positions_batch, end_positions_batch)
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.
            loss.backward()
            optimizer.step()
        result = evaluate(args, False)
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path)

    # Evaluation phase.
    if args.test_path is not None:
        print("Test set evaluation.")
        model = load_model(model, args.output_model_path)
        evaluate(args, True)
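
# The batch_loader generator used above is defined earlier in the script and
# omitted from this excerpt; a minimal sketch consistent with how it is called
# here (and with the loaders defined in the later examples) would be:
def batch_loader(batch_size, *tensors):
    # Yield consecutive mini-batches, plus a final smaller batch for the tail.
    instances_num = tensors[0].size(0)
    for i in range(instances_num // batch_size):
        yield tuple(t[i * batch_size:(i + 1) * batch_size] for t in tensors)
    if instances_num % batch_size != 0:
        yield tuple(t[instances_num // batch_size * batch_size:] for t in tensors)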
Example #26
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/classifier_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path",
                        type=str,
                        required=True,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/bert_base_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=64,
                        help="Batch size.")
    parser.add_argument("--seq_length",
                        type=int,
                        default=128,
                        help="Sequence length.")
    parser.add_argument("--embedding",
                        choices=["bert", "word"],
                        default="bert",
                        help="Emebdding type.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                              "cnn", "gatedcnn", "attn", "synt", \
                                              "rcnn", "crnn", "gpt", "bilstm"], \
                                              default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")
    parser.add_argument("--pooling",
                        choices=["mean", "max", "first", "last"],
                        default="first",
                        help="Pooling type.")
    parser.add_argument("--factorized_embedding_parameterization",
                        action="store_true",
                        help="Factorized embedding parameterization.")
    parser.add_argument("--parameter_sharing",
                        action="store_true",
                        help="Parameter sharing.")

    # Tokenizer options.
    parser.add_argument(
        "--tokenizer",
        choices=["bert", "char", "space"],
        default="bert",
        help="Specify the tokenizer."
        "Original Google BERT uses bert tokenizer on Chinese corpus."
        "Char tokenizer segments sentences into characters."
        "Space tokenizer segments sentences into words according to space.")

    # Optimizer options.
    parser.add_argument("--soft_targets",
                        action='store_true',
                        help="Train model with logits.")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")
    parser.add_argument(
        "--fp16",
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit."
    )
    parser.add_argument(
        "--fp16_opt_level",
        choices=["O0", "O1", "O2", "O3"],
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=3,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Count the number of labels.
    args.labels_num = count_labels_num(args.train_path)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build classification model.
    model = Classifier(args)

    # Load or initialize parameters.
    load_or_initialize_parameters(args, model)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(args.device)
    args.model = model

    # Build tokenizer.
    args.tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

    # Training phase.
    trainset = read_dataset(args, args.train_path)
    random.shuffle(trainset)
    instances_num = len(trainset)
    batch_size = args.batch_size

    src = torch.LongTensor([example[0] for example in trainset])
    tgt = torch.LongTensor([example[1] for example in trainset])
    seg = torch.LongTensor([example[2] for example in trainset])
    if args.soft_targets:
        soft_tgt = torch.FloatTensor([example[3] for example in trainset])
    else:
        soft_tgt = None

    args.train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    optimizer, scheduler = build_optimizer(args, model)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
        args.amp = amp

    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    total_loss, result, best_result = 0., 0., 0.

    print("Start training.")

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        for i, (src_batch, tgt_batch, seg_batch, soft_tgt_batch) in enumerate(
                batch_loader(batch_size, src, tgt, seg, soft_tgt)):
            loss = train_model(args, model, optimizer, scheduler, src_batch,
                               tgt_batch, seg_batch, soft_tgt_batch)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.

        result = evaluate(args, read_dataset(args, args.dev_path))
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path)

    # Evaluation phase.
    if args.test_path is not None:
        print("Test set evaluation.")
        if torch.cuda.device_count() > 1:
            model.module.load_state_dict(torch.load(args.output_model_path))
        else:
            model.load_state_dict(torch.load(args.output_model_path))
        evaluate(args, read_dataset(args, args.test_path), True)
Example #27
                        action="store_true",
                        help="Factorized embedding parameterization.")
    parser.add_argument("--parameter_sharing",
                        action="store_true",
                        help="Parameter sharing.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Build tokenizer.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    if args.target == "mt":
        tgt_vocab = Vocab()
        tgt_vocab.load(args.tgt_vocab_path)
        args.tgt_vocab = tgt_vocab

    model = S2sGenerationModel(args)
    # Load pretrained model.
    pretrained_model_dict = torch.load(args.pretrained_model_path)
    model.load_state_dict(pretrained_model_dict, strict=True)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Example #28
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/tagger_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path",
                        default="./models/google_vocab.txt",
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path",
                        type=str,
                        required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/google_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch_size.")
    parser.add_argument("--seq_length",
                        default=128,
                        type=int,
                        help="Sequence length.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                                   "cnn", "gatedcnn", "attn", \
                                                   "rcnn", "crnn", "gpt", "bilstm"], \
                                                   default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")

    # Subword options.
    parser.add_argument("--subword_type",
                        choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path",
                        type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder",
                        choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num",
                        type=int,
                        default=2,
                        help="The number of subencoder layers.")

    # Optimizer options.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=3,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Find tagging labels.
    labels_map = {"NULL": 0, "O": 1}  # ID for padding and non-entity.
    with open(args.train_path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id == 0:
                continue
            line = line.strip().split()
            if len(line) != 2:
                continue
            if line[1] not in labels_map:
                labels_map[line[1]] = len(labels_map)

    print("Labels: ", labels_map)
    print("Label Num: ", len(labels_map))
    args.labels_num = len(labels_map)

    # Build forbidden (bad) and encouraged (good) label-transition pairs.
    args.bad_pairs = []
    args.good_pairs = []
    for key1, value1 in labels_map.items():
        key1 = key1.strip().split('-')
        if len(key1) < 1 or len(key1) > 2:
            print("Error label: ", key1)
            exit()
        for key2, value2 in labels_map.items():
            key2 = key2.strip().split('-')
            if len(key2) == 1:
                continue
            if len(key1) == 1 and len(key2) == 2:
                if key2[0] == 'I':
                    args.bad_pairs.append([value1, value2])
                continue
            # p(B-X -> I-Y) = 0
            if key1[1] != key2[1] and key1[0] == 'B' and key2[0] == 'I':
                args.bad_pairs.append([value1, value2])
            # p(I-X -> I-Y) = 0
            if key1[1] != key2[1] and key1[0] == 'I' and key2[0] == 'I':
                args.bad_pairs.append([value1, value2])
            # p(B-X -> I-X) = 10
            if key1[1] == key2[1] and key1[0] == 'B' and key2[0] == 'I':
                args.good_pairs.append([value1, value2])
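    # For example, with labels {B-LOC, I-LOC, B-PER, I-PER, O}, bad_pairs
    # forbids transitions such as B-LOC -> I-PER and O -> I-LOC, while
    # good_pairs encourages B-LOC -> I-LOC.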

    print("Bad pairs: ", args.bad_pairs)
    print("Good pairs: ", args.good_pairs)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build bert model.
    # A pseudo target is added.
    args.target = "bert"
    model = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model.load_state_dict(torch.load(args.pretrained_model_path),
                              strict=False)
    else:
        # Initialize with normal distribution.
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)

    # Some other parameters
    args.lstm_hidden = args.hidden_size
    args.lstm_layers = 2
    args.lstm_dropout = 0.1
    if torch.cuda.is_available():
        args.use_cuda = True

    # Build sequence labeling model.
    model = CCKSTagger(args, model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)
        model = model.module
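        # Note: taking .module immediately discards the DataParallel wrapper,
        # so despite the message above training effectively runs on one GPU.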

    model = model.to(device)

    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids):
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i * batch_size:(i + 1) * batch_size, :]
            label_ids_batch = label_ids[i * batch_size:(i + 1) * batch_size, :]
            mask_ids_batch = mask_ids[i * batch_size:(i + 1) * batch_size, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num // batch_size *
                                        batch_size:, :]
            label_ids_batch = label_ids[instances_num // batch_size *
                                        batch_size:, :]
            mask_ids_batch = mask_ids[instances_num // batch_size *
                                      batch_size:, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch

    # Read dataset.
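    # Expected input format (inferred from the parsing below): a header line,
    # then one "token label" pair per line, with blank or malformed lines
    # marking sentence boundaries.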
    def read_dataset(path):
        dataset = []
        with open(path, mode="r", encoding="utf-8") as f:
            tokens, labels = [], []
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                line = line.strip().split()
                if len(line) != 2:
                    if len(labels) == 0:
                        continue
                    assert len(tokens) == len(labels)
                    tokens = [vocab.get(t) for t in tokens]
                    labels = [labels_map[l] for l in labels]
                    mask = [1] * len(tokens)
                    if len(tokens) > args.seq_length:
                        tokens = tokens[:args.seq_length]
                        labels = labels[:args.seq_length]
                        mask = mask[:args.seq_length]
                    while len(tokens) < args.seq_length:
                        tokens.append(0)
                        labels.append(0)
                        mask.append(0)
                    dataset.append([tokens, labels, mask])

                    tokens, labels = [], []
                    continue
                tokens.append(line[0])
                labels.append(line[1])

        return dataset

    # Evaluation function.
    def evaluate(args, is_test):
        if is_test:
            dataset = read_dataset(args.test_path)
        else:
            dataset = read_dataset(args.dev_path)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])

        instances_num = input_ids.size(0)
        batch_size = args.batch_size

        if is_test:
            print("Batch size: ", batch_size)
            print("The number of test instances:", instances_num)

        correct = 0
        gold_entities_num = 0
        pred_entities_num = 0

        confusion = torch.zeros(len(labels_map),
                                len(labels_map),
                                dtype=torch.long)

        model.eval()

        for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(
                batch_loader(batch_size, input_ids, label_ids, mask_ids)):
            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            # Decode the best label path with the CRF layer.
            feats = model(input_ids_batch, label_ids_batch, mask_ids_batch)
            path_score, best_path = model.crf(feats, mask_ids_batch.byte())
            pred = best_path.contiguous().view(-1)
            gold = label_ids_batch.contiguous().view(-1)

            # Gold.
            for j in range(gold.size()[0]):
                if (j > 0 and gold[j - 1].item() <= 1
                        and gold[j].item() > 1) or (j == 0
                                                    and gold[j].item() > 1):
                    gold_entities_num += 1

            # Predict.
            for j in range(pred.size()[0]):
                if (j > 0 and pred[j - 1].item() <= 1 and pred[j].item() > 1
                        and gold[j].item() != 0) or (j == 0
                                                     and pred[j].item() > 1):
                    pred_entities_num += 1

            pred_entities_pos = []
            gold_entities_pos = []
            start, end = 0, 0

            # Correct.
            for j in range(gold.size()[0]):
                if (j > 0 and gold[j - 1].item() <= 1
                        and gold[j].item() > 1) or (j == 0
                                                    and gold[j].item() > 1):
                    start = j
                    for k in range(j, gold.size()[0]):
                        if gold[k].item() <= 1:
                            end = k - 1
                            break
                    else:
                        end = gold.size()[0] - 1
                    gold_entities_pos.append((start, end))

            # Predict.
            for j in range(pred.size()[0]):
                if (j > 0 and pred[j - 1].item() <= 1
                        and pred[j].item() > 1) or (j == 0
                                                    and pred[j].item() > 1):
                    start = j
                    for k in range(j, pred.size()[0]):
                        if pred[k].item() <= 1:
                            end = k - 1
                            break
                    else:
                        end = pred.size()[0] - 1
                    pred_entities_pos.append((start, end))

            for entity in pred_entities_pos:
                if entity not in gold_entities_pos:
                    continue
                for j in range(entity[0], entity[1] + 1):
                    if gold[j].item() != pred[j].item():
                        break
                else:
                    correct += 1

        print("Report precision, recall, and f1:")
        p = correct / pred_entities_num
        r = correct / gold_entities_num
        f1 = 2 * p * r / (p + r)
        print("{:.3f}, {:.3f}, {:.3f}".format(p, r, f1))

        return f1

    # Training phase.
    print("Start training.")
    instances = read_dataset(args.train_path)

    input_ids = torch.LongTensor([ins[0] for ins in instances])
    label_ids = torch.LongTensor([ins[1] for ins in instances])
    mask_ids = torch.LongTensor([ins[2] for ins in instances])

    instances_num = input_ids.size(0)
    batch_size = args.batch_size
    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup,
                         t_total=train_steps)

    total_loss = 0.
    f1 = 0.0
    best_f1 = 0.0

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(
                batch_loader(batch_size, input_ids, label_ids, mask_ids)):
            model.zero_grad()

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            """ loss, _, _, _ = model(input_ids_batch, label_ids_batch, mask_ids_batch)
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(epoch, i+1, total_loss / args.report_steps))
                total_loss = 0. """
            """ print("mask1:", mask_ids_batch)
            print("label1:", label_ids_batch) """
            feats = model(input_ids_batch, label_ids_batch, mask_ids_batch)
            """ print("feats:", feats) """
            loss = model.loss(feats, mask_ids_batch, label_ids_batch)
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Loss: {:.3f}".format(
                    epoch, i + 1, loss))

            loss.backward()
            optimizer.step()

        f1 = evaluate(args, False)
        if f1 > best_f1:
            best_f1 = f1
            save_model(model, args.output_model_path)

    # Evaluation phase.
    print("Start evaluation.")
    """ if torch.cuda.device_count() > 1:
        model.module.load_state_dict(torch.load(args.output_model_path))
    else:
        model.load_state_dict(torch.load(args.output_model_path)) """
    model.load_state_dict(torch.load(args.output_model_path))

    evaluate(args, True)
    """
Example #29
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path",
                        default=None,
                        type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path",
                        default="./models/ner_model.bin",
                        type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path",
                        type=str,
                        required=True,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path",
                        type=str,
                        required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path",
                        type=str,
                        required=True,
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, help="Path of the testset.")
    parser.add_argument("--config_path",
                        default="./models/bert_base_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch_size.")
    parser.add_argument("--seq_length",
                        default=128,
                        type=int,
                        help="Sequence length.")
    parser.add_argument("--embedding",
                        choices=["bert", "word"],
                        default="bert",
                        help="Emebdding type.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                                   "cnn", "gatedcnn", "attn", \
                                                   "rcnn", "crnn", "gpt", "bilstm"], \
                                                   default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional",
                        action="store_true",
                        help="Specific to recurrent model.")

    # Subword options.
    parser.add_argument("--subword_type",
                        choices=["none", "char"],
                        default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path",
                        type=str,
                        default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder",
                        choices=["avg", "lstm", "gru", "cnn"],
                        default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num",
                        type=int,
                        default=2,
                        help="The number of subencoder layers.")

    # Optimizer options.
    parser.add_argument("--learning_rate",
                        type=float,
                        default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout.")
    parser.add_argument("--epochs_num",
                        type=int,
                        default=3,
                        help="Number of epochs.")
    parser.add_argument("--report_steps",
                        type=int,
                        default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    labels_map = {"[PAD]": 0}
    begin_ids = []

    # Find tagging labels.
    with open(args.train_path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id == 0:
                continue
            labels = line.strip().split("\t")[1].split()
            for l in labels:
                if l not in labels_map:
                    if l.startswith("B") or l.startswith("S"):
                        begin_ids.append(len(labels_map))
                    labels_map[l] = len(labels_map)

    print("Labels: ", labels_map)
    args.labels_num = len(labels_map)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build bert model.
    # A pseudo target is added.
    args.target = "bert"
    model = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model.load_state_dict(torch.load(args.pretrained_model_path),
                              strict=False)
    else:
        # Initialize with normal distribution.
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)

    # Build sequence labeling model.
    model = BertTagger(args, model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids):
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i * batch_size:(i + 1) * batch_size, :]
            label_ids_batch = label_ids[i * batch_size:(i + 1) * batch_size, :]
            mask_ids_batch = mask_ids[i * batch_size:(i + 1) * batch_size, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num // batch_size *
                                        batch_size:, :]
            label_ids_batch = label_ids[instances_num // batch_size *
                                        batch_size:, :]
            mask_ids_batch = mask_ids[instances_num // batch_size *
                                      batch_size:, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch

    # Read dataset.
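    # Expected input format (inferred from the parsing below): a header line,
    # then one sentence per line as "space-joined tokens<TAB>space-joined labels".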
    def read_dataset(path):
        dataset = []
        with open(path, mode="r", encoding="utf-8") as f:
            f.readline()
            tokens, labels = [], []
            for line_id, line in enumerate(f):
                tokens, labels = line.strip().split("\t")
                tokens = [vocab.get(t) for t in tokens.split(" ")]
                labels = [labels_map[l] for l in labels.split(" ")]
                mask = [1] * len(tokens)
                if len(tokens) > args.seq_length:
                    tokens = tokens[:args.seq_length]
                    labels = labels[:args.seq_length]
                    mask = mask[:args.seq_length]
                while len(tokens) < args.seq_length:
                    tokens.append(0)
                    labels.append(0)
                    mask.append(0)
                dataset.append([tokens, labels, mask])

        return dataset

    # Evaluation function.
    def evaluate(args, is_test):
        if is_test:
            dataset = read_dataset(args.test_path)
        else:
            dataset = read_dataset(args.dev_path)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])

        instances_num = input_ids.size(0)
        batch_size = args.batch_size

        if is_test:
            print("Batch size: ", batch_size)
            print("The number of test instances:", instances_num)

        correct = 0
        gold_entities_num = 0
        pred_entities_num = 0

        confusion = torch.zeros(len(labels_map),
                                len(labels_map),
                                dtype=torch.long)

        model.eval()

        for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(
                batch_loader(batch_size, input_ids, label_ids, mask_ids)):
            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            loss, _, pred, gold = model(input_ids_batch, label_ids_batch,
                                        mask_ids_batch)

            for j in range(gold.size()[0]):
                if gold[j].item() in begin_ids:
                    gold_entities_num += 1

            for j in range(pred.size()[0]):
                if pred[j].item(
                ) in begin_ids and gold[j].item() != labels_map["[PAD]"]:
                    pred_entities_num += 1

            pred_entities_pos = []
            gold_entities_pos = []
            start, end = 0, 0

            for j in range(gold.size()[0]):
                if gold[j].item() in begin_ids:
                    start = j
                    for k in range(j + 1, gold.size()[0]):
                        if gold[k].item(
                        ) == labels_map["[PAD]"] or gold[k].item(
                        ) == labels_map["O"] or gold[k].item() in begin_ids:
                            end = k - 1
                            break
                    else:
                        end = gold.size()[0] - 1
                    gold_entities_pos.append((start, end))

            for j in range(pred.size()[0]):
                if pred[j].item(
                ) in begin_ids and gold[j].item() != labels_map["[PAD]"]:
                    start = j
                    for k in range(j + 1, pred.size()[0]):
                        if pred[k].item(
                        ) == labels_map["[PAD]"] or pred[k].item(
                        ) == labels_map["O"] or pred[k].item() in begin_ids:
                            end = k - 1
                            break
                    else:
                        end = pred.size()[0] - 1
                    pred_entities_pos.append((start, end))

            for entity in pred_entities_pos:
                if entity not in gold_entities_pos:
                    continue
                for j in range(entity[0], entity[1] + 1):
                    if gold[j].item() != pred[j].item():
                        break
                else:
                    correct += 1

        print("Report precision, recall, and f1:")
        p = correct / pred_entities_num
        r = correct / gold_entities_num
        f1 = 2 * p * r / (p + r)
        print("{:.3f}, {:.3f}, {:.3f}".format(p, r, f1))

        return f1

    # Training phase.
    print("Start training.")
    instances = read_dataset(args.train_path)

    input_ids = torch.LongTensor([ins[0] for ins in instances])
    label_ids = torch.LongTensor([ins[1] for ins in instances])
    mask_ids = torch.LongTensor([ins[2] for ins in instances])

    instances_num = input_ids.size(0)
    batch_size = args.batch_size
    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=train_steps * args.warmup,
                                     t_total=train_steps)
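    # AdamW with correct_bias=False plus WarmupLinearSchedule mirrors the
    # original BERT recipe in the pytorch-transformers API; scheduler.step()
    # is then called once per optimizer step in the loop below.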

    total_loss = 0.
    f1 = 0.0
    best_f1 = 0.0

    for epoch in range(1, args.epochs_num + 1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(
                batch_loader(batch_size, input_ids, label_ids, mask_ids)):
            model.zero_grad()

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)

            loss, _, _, _ = model(input_ids_batch, label_ids_batch,
                                  mask_ids_batch)
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".
                      format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.

            loss.backward()
            optimizer.step()
            scheduler.step()

        f1 = evaluate(args, False)
        if f1 > best_f1:
            best_f1 = f1
            save_model(model, args.output_model_path)

    # Evaluation phase.
    if args.test_path is not None:
        print("Test set evaluation.")
        model = load_model(model, args.output_model_path)
        evaluate(args, True)
Example #30
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path", default=None, type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path", default="./models/classifier_model.bin", type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path", default="./models/google_vocab.txt", type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path", type=str, required=True,
                        help="Path of the trainset.")
    parser.add_argument("--dev_path", type=str, required=True,
                        help="Path of the devset.") 
    parser.add_argument("--test_path", type=str, required=True,
                        help="Path of the testset.")
    parser.add_argument("--config_path", default="./models/google_config.json", type=str,
                        help="Path of the config file.")
    # parser.add_argument("--BM25_path", type=str, required=True,
    #                     help="Path of the BM25.")
    # parser.add_argument("--BM25Type_path", type=str, required=True,
    #                     help="Path of the BM25Type_path.")
    # parser.add_argument("--BM25_test_path", type=str, required=True,
    #                     help="Path of the BM25Test_path.")
    # parser.add_argument("--Stopword_path", type=str, required=True,
    #                     help="Path of the Stopword.")

    # Model options.
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size.")
    parser.add_argument("--seq_length", type=int, default=256,
                        help="Sequence length.")
    parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
                                                   "cnn", "gatedcnn", "attn", \
                                                   "rcnn", "crnn", "gpt", "bilstm"], \
                                                   default="bert", help="Encoder type.")
    parser.add_argument("--bidirectional", action="store_true", help="Specific to recurrent model.")
    parser.add_argument("--pooling", choices=["mean", "max", "first", "last"], default="first",
                        help="Pooling type.")

    # Subword options.
    parser.add_argument("--subword_type", choices=["none", "char"], default="none",
                        help="Subword feature type.")
    parser.add_argument("--sub_vocab_path", type=str, default="models/sub_vocab.txt",
                        help="Path of the subword vocabulary file.")
    parser.add_argument("--subencoder", choices=["avg", "lstm", "gru", "cnn"], default="avg",
                        help="Subencoder type.")
    parser.add_argument("--sub_layers_num", type=int, default=2, help="The number of subencoder layers.")

    # Tokenizer options.
    parser.add_argument("--tokenizer", choices=["bert", "char", "word", "space"], default="bert",
                        help="Specify the tokenizer." 
                             "Original Google BERT uses bert tokenizer on Chinese corpus."
                             "Char tokenizer segments sentences into characters."
                             "Word tokenizer supports online word segmentation based on jieba segmentor."
                             "Space tokenizer segments sentences into words according to space."
                             )

    # Optimizer options.
    parser.add_argument("--learning_rate", type=float, default=2e-5,
                        help="Learning rate.")
    parser.add_argument("--warmup", type=float, default=0.1,
                        help="Warm up value.")

    # Training options.
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="Dropout.")
    parser.add_argument("--epochs_num", type=int, default=5,
                        help="Number of epochs.")
    parser.add_argument("--report_steps", type=int, default=100,
                        help="Specific steps to print prompt.")
    parser.add_argument("--seed", type=int, default=7,
                        help="Random seed.")

    # Evaluation options.
    parser.add_argument("--mean_reciprocal_rank", action="store_true", help="Evaluation metrics for DBQA dataset.")

    # kg
    parser.add_argument("--kg_name", required=False, help="KG name or path")
    parser.add_argument("--workers_num", type=int, default=1, help="number of process for loading dataset")
    parser.add_argument("--no_vm", action="store_true", help="Disable the visible_matrix")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Count the number of labels.
    labels_set = set()
    columns = {}
    with open(args.train_path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            try:
                line = line.strip().split("\t")
                if line_id == 0:
                    for i, column_name in enumerate(line):
                        columns[column_name] = i
                    continue
                label = int(line[columns["label"]])
                labels_set.add(label)
            except (IndexError, KeyError, ValueError):
                # Skip malformed lines.
                pass

    args.labels_num = len(labels_set)

    # Load vocabulary.
    vocab = Vocab()
    vocab.load(args.vocab_path)
    args.vocab = vocab

    # Build bert model.
    # A pseudo target is added.
    args.target = "bert"
    model = build_model(args)
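    # The pseudo "bert" target only lets build_model construct a complete
    # pretraining model; its heads go unused once the encoder is wrapped in
    # BertClassifier below.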

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model.load_state_dict(torch.load(args.pretrained_model_path), strict=False)  
    else:
        # Initialize with normal distribution.
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)
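    # In this codebase the layer-normalization parameters are named 'gamma'
    # and 'beta'; they keep their default initialization and are deliberately
    # excluded from the normal-distribution initialization above.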
    
    # Build classification model.
    model = BertClassifier(args, model)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)
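    # nn.DataParallel splits each input batch along dim 0 across the visible
    # GPUs and gathers the per-GPU outputs back on the default device.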
    
    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms):
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i*batch_size: (i+1)*batch_size, :]
            label_ids_batch = label_ids[i*batch_size: (i+1)*batch_size]
            mask_ids_batch = mask_ids[i*batch_size: (i+1)*batch_size, :]
            pos_ids_batch = pos_ids[i*batch_size: (i+1)*batch_size, :]
            vms_batch = vms[i*batch_size: (i+1)*batch_size]
            yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num//batch_size*batch_size:, :]
            label_ids_batch = label_ids[instances_num//batch_size*batch_size:]
            mask_ids_batch = mask_ids[instances_num//batch_size*batch_size:, :]
            pos_ids_batch = pos_ids[instances_num//batch_size*batch_size:, :]
            vms_batch = vms[instances_num//batch_size*batch_size:]

            yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch
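    # Example: with 70 instances and batch_size 32, batch_loader yields two
    # full batches of 32 followed by a tail batch of 6, so no instance is
    # dropped during evaluation.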

    # Build knowledge graph.
    if args.kg_name == 'none':
        spo_files = []
    else:
        spo_files = [args.kg_name]
    kg = KnowledgeGraph(spo_files=spo_files, predicate=True)
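    # The KnowledgeGraph looks up entities in the given .spo triple files and,
    # with predicate=True, injects "predicate + object" strings into matched
    # sentences, producing the position ids and visible matrices used below.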

    def read_dataset(path, workers_num=1):
        print("Loading sentences from {}".format(path))
        sentences = []
        with open(path, mode='r', encoding="utf-8") as f:
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                sentences.append(line)
        sentence_num = len(sentences)

        print("There are {} sentence in total. We use {} processes to inject knowledge into sentences.".format(sentence_num, workers_num))
        if workers_num > 1:
            params = []
            sentence_per_block = int(sentence_num / workers_num) + 1
            for i in range(workers_num):
                params.append((i, sentences[i*sentence_per_block: (i+1)*sentence_per_block], columns, kg, vocab, args))
            pool = Pool(workers_num)
            res = pool.map(add_knowledge_worker, params)
            pool.close()
            pool.join()
            dataset = [sample for block in res for sample in block]
        else:
            params = (0, sentences, columns, kg, vocab, args)
            dataset = add_knowledge_worker(params)

        return dataset
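    # Each sample returned by add_knowledge_worker is assumed to be a tuple of
    # (token_ids, label_id, mask_ids, pos_ids, visible_matrix); evaluate()
    # below indexes samples with exactly this layout.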

    # Evaluation function.
    def evaluate(args, is_test, metrics='f1'):
        if is_test:
            dataset = read_dataset(args.test_path, workers_num=args.workers_num)
        else:
            dataset = read_dataset(args.dev_path, workers_num=args.workers_num)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])
        pos_ids = torch.LongTensor([sample[3] for sample in dataset])
        vms = [sample[4] for sample in dataset]

        batch_size = args.batch_size
        instances_num = input_ids.size()[0]
        if is_test:
            print("The number of evaluation instances: ", instances_num)


        model.eval()
        if not args.mean_reciprocal_rank:
            print("Computing F1.")
            ls_gold = []
            ls_pred = []
            for i, (input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms)):
                vms_batch = torch.LongTensor(vms_batch)

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vms_batch = vms_batch.to(device)
                with torch.no_grad():
                    try:
                        loss, logits = model(input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch)
                    except Exception:
                        # Dump the offending batch for debugging, then re-raise;
                        # otherwise logits would be undefined below.
                        print(input_ids_batch)
                        print(input_ids_batch.size())
                        print(vms_batch)
                        print(vms_batch.size())
                        raise

                logits = nn.Softmax(dim=1)(logits)
                pred = torch.argmax(logits, dim=1)
               
                gold = label_ids_batch
                
                print("Pred")
                print(pred)
                print("Gold")
                print(gold)
                for (p,g) in zip(pred.cpu().numpy(),gold.cpu().numpy()):
                    ls_gold.append(g)
                    ls_pred.append(int(p))
              
            print(ls_gold)
            print(ls_pred)
            acc = accuracy_score(ls_gold, ls_pred)
            print("Accuracy:")
            print(acc)

            f1 = f1_score(ls_gold, ls_pred, average='macro')
            p = precision_score(ls_gold, ls_pred, average='macro')
            r = recall_score(ls_gold, ls_pred, average='macro')
            print("F1 Score:")
            print(f1)
            print("Precision:")
            print(p)
            print("Recall:")
            print(r)
            return acc

        else:
            print("Computing MRR.")
            rank_pos = []
            for i, (input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms)):

                vms_batch = torch.LongTensor(vms_batch)

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vms_batch = vms_batch.to(device)

                with torch.no_grad():
                    loss, logits = model(input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch)
                logits = nn.Softmax(dim=1)(logits)
                
                for l, g in zip(logits, label_ids_batch):
                    sort_label = torch.sort(l, descending=True)
                    for (idx, p) in enumerate(sort_label[1]):
                        if p == g:
                            rank_pos.append(idx)
                            break
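            # MRR = (1/N) * sum over queries of 1/rank(gold label), with ranks
            # starting at 1. For example, gold ranks 1, 2, and 4 give
            # (1/1 + 1/2 + 1/4) / 3 ≈ 0.583.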
            # Average the reciprocal ranks over all evaluated queries.
            MRR = 0
            for rank in rank_pos:
                MRR += 1 / (rank + 1)
            MRR = MRR / len(rank_pos)
            print("MRR:", MRR)
            return MRR
           
    
    def BM25_readfile():
        # Read the BM25 training data (standard questions, already word-segmented).
        BM25_train_data = []
        with open(args.BM25_path, 'r', encoding='utf-8') as f:
            for line in f:
                BM25_train_data.append(line.replace("\r", "").replace("\n", ""))
        # Read the stop-word list.
        stopWords = []
        with open(args.Stopword_path, 'r', encoding='utf-8') as f:
            for data in f:
                stopWords.append(data.strip())
        return stopWords, BM25_train_data
    
    def BM25():
        stopWords, BM25_train_data = BM25_readfile()
        predict = []
        similarity_ls = []
        # Skip the header line; each remaining line is a tab-separated sentence pair.
        for sentence in BM25_train_data[1:]:
            S1, S2 = sentence.strip().split('\t')[:2]
            # Remove stop words before scoring.
            S1 = " ".join(w for w in S1.split() if w not in stopWords)
            S2 = " ".join(w for w in S2.split() if w not in stopWords)

            sim = similarity.ssim(S1, S2, model='bm25')
            similarity_ls.append(sim)
            # Label a pair as a match when the BM25 score exceeds the threshold of 11.
            predict.append(1 if sim > 11 else 0)
        return predict, similarity_ls
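    # Illustrative usage, assuming args.BM25_path and args.Stopword_path are set:
    #   predictions, scores = BM25()
    # The cut-off of 11 on the BM25 score appears to be tuned for this dataset.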
      
     
    # Evaluation phase.
    print("Final evaluation on the test dataset.")

    if torch.cuda.device_count() > 1:
        model.module.load_state_dict(torch.load(args.output_model_path))
    else:
        model.load_state_dict(torch.load(args.output_model_path))
    accuracy = evaluate(args, True, metrics='Acc')

    print("Accuracy:")
    print(accuracy)