Example #1
def main():
    # Load the run configuration (paths, seed, model name, hyper-parameters) from JSON.
    with open('config.json', 'r') as f:
        config = json.load(f)

    set_seed(config["seed"])

    if not os.path.exists(config["output_dir"]):
        os.makedirs(config["output_dir"])
    if not os.path.exists(config["save_dir"]):
        os.makedirs(config["save_dir"])

    # model_config = transformers.BertConfig.from_pretrained(config["model_name"])
    # tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
    tokenizer = BertTokenizer.from_pretrained(config["model_name"])
    model = BertForClassification(config["model_name"])
    # model = AutoModelForMultipleChoice.from_pretrained(config["model_name"])

    model.cuda()

    processor = DataProcessor(config["data_dir"])

    train_examples = processor.get_train_examples()
    train_dataset = processor.get_dataset(train_examples, tokenizer,
                                          config["max_length"])

    valid_examples = processor.get_dev_examples()
    valid_dataset = processor.get_dataset(valid_examples, tokenizer,
                                          config["max_length"])

    test_examples = processor.get_test_examples()
    test_dataset = processor.get_dataset(test_examples, tokenizer,
                                         config["max_length"])

    train(config, model, train_dataset, valid_dataset)
    result = evaluate(config, model, test_dataset)
    print(result[:2])
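Example #1 relies on a `set_seed` helper defined elsewhere; a minimal sketch of such a helper, assuming only Python's `random`, NumPy and PyTorch need to be seeded, could look like this:

import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    # Seed every RNG the training loop is likely to touch so runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)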
Example #2
class BertComputation(RequestHandler):
    """
    Request handler that computes embeddings for English sentences
    """

    enable_cuda = torch.cuda.is_available()
    CLS = '[CLS]'
    SEP = '[SEP]'
    MSK = '[MASK]'
    # Maximum input length in word-piece tokens (128 is an assumed default).
    max_seq_len = 128
    pretrained_weights = 'bert-base-uncased'
    tokenizer = BertTokenizer.from_pretrained(pretrained_weights)
    # Load the pretrained encoder weights; a freshly constructed BertModel(config=BertConfig())
    # would be randomly initialised and give meaningless embeddings.
    model = BertModel.from_pretrained(pretrained_weights)
    model.eval()
    if enable_cuda:
        model.cuda(torch.device('cuda'))

    def process_request(self, data):
        tokens = []
        mask_ids = []
        seg_ids = []
        seq_len = 0
        for txt in data:
            utt = txt.strip().lower()
            toks = self.tokenizer.tokenize(utt)
            if len(toks) > self.max_seq_len - 2:
                toks = toks[:self.max_seq_len - 2]
            toks.insert(0, self.CLS)
            toks.append(self.SEP)  # SEP goes at the very end of the sequence
            seq_len = max(len(toks), seq_len)
            mask_ids.append([1] * len(toks) + [0] * (seq_len - len(toks)))
            seg_ids.append([0] * seq_len)
            tokens.append(toks + [self.MSK] * (seq_len - len(toks)))
        input_ids = []
        for i in range(len(tokens)):
            for _ in range(seq_len - len(tokens[i])):
                tokens[i].append(self.MSK)
                mask_ids[i].append(0)
                seg_ids[i].append(0)
            input_ids.append(self.tokenizer.convert_tokens_to_ids(tokens[i]))
        return (torch.tensor(input_ids), torch.tensor(mask_ids),
                torch.tensor(seg_ids))

    def run_inference(self, model_input):
        if self.enable_cuda:
            input_ids, mask_ids, seg_ids = [x.to('cuda') for x in model_input]
        else:
            input_ids, mask_ids, seg_ids = model_input
        # Output index 1 is the pooled [CLS] representation of each sentence.
        cls_emb = self.model(input_ids=input_ids,
                             attention_mask=mask_ids,
                             token_type_ids=seg_ids)[1]
        return cls_emb

    def process_response(self, model_output_item):
        return encode_pickle(model_output_item.cpu().detach().numpy())
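The hand-rolled padding in process_request can also be expressed with the tokenizer's batch call; below is a minimal standalone sketch (assuming the same bert-base-uncased checkpoint) that produces the same three tensors. Note the tokenizer pads with [PAD] rather than the [MASK] padding used above, but the attention mask zeroes out those positions either way.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
batch = ["bert computes sentence embeddings", "a shorter text"]

# padding=True pads to the longest sequence in the batch, mirroring process_request.
enc = tokenizer(batch, padding=True, truncation=True, max_length=128,
                return_tensors='pt', return_token_type_ids=True)
input_ids, mask_ids, seg_ids = enc['input_ids'], enc['attention_mask'], enc['token_type_ids']
print(input_ids.shape, mask_ids.shape, seg_ids.shape)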
Example #3
def test_tokenizer_batch_encode(tokenizer: BertTokenizer, sentence: List[str]):
    first_sentence, second_sentence = ["I love china", "I hate china"], [
        "I also love china", "I also hate china"
    ]
    output: BatchEncoding = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=first_sentence,
        padding="max_length",
        max_length=200,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
        return_special_tokens_mask=True,
        return_length=True)
    assert output.input_ids.shape == (2, 200)
Example #4
def test_tokenizer_call(tokenizer: BertTokenizer, sentence: List[str]):
    # Works for encoding either a single text or a batch of texts
    first_sentence, second_sentence = ["I love china", "I hate china"], [
        "I also love china", "I also hate china"
    ]
    output: BatchEncoding = tokenizer(text=first_sentence,
                                      text_pair=second_sentence,
                                      padding=True,
                                      max_length=20,
                                      return_tensors='pt',
                                      return_token_type_ids=True,
                                      return_attention_mask=True,
                                      return_special_tokens_mask=True,
                                      return_length=True)
    assert output.input_ids.shape == (2, 10)
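The two assertions in Examples #3 and #4 differ only because of the padding strategy: padding="max_length" pads every row out to max_length (200), while padding=True pads to the longest sequence in the batch (10 word-piece tokens for the sentence pairs in Example #4). A small sketch that makes the difference visible, assuming a bert-base-uncased style tokenizer like the fixture used here:

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained('bert-base-uncased')
text, pair = "I love china", "I also love china"

fixed = tok(text, pair, padding='max_length', max_length=20, return_tensors='pt')
dynamic = tok([text], [pair], padding=True, return_tensors='pt')
print(fixed['input_ids'].shape)    # torch.Size([1, 20]): padded out to max_length
print(dynamic['input_ids'].shape)  # torch.Size([1, 10]): [CLS] + 3 + [SEP] + 4 + [SEP]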
Example #5
def main(args):
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)

    logger.addHandler(
        logging.FileHandler(
            filename=os.path.join(args.destdir, "preprocess.log"),
        )
    )
    logger.info(args)

    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    target = not args.only_source

    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert (
            not args.srcdict or not args.tgtdict
        ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {train_path(lang) for lang in [args.source_lang, args.target_lang]},
                src=True,
            )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers, avoid_tokenize=False):
        if vocab is not None:
            print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        else:
            print('| Using None Dictionary and only string split is performed.')

        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                # TODO: worker > 1 is not working for map dataset
                if args.input_mapping is True:
                    raise NotImplementedError("Worker > 1 is not implemented for map dataset yet.")
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                        avoid_tokenize,
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
            vocab_size=len(vocab) if vocab is not None else -1,
        )
        merge_result(
            Binarizer.binarize(
                input_file, vocab, lambda t: ds.add_item(t), offset=0, end=offsets[1], avoid_tokenize=avoid_tokenize,
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        if vocab is not None:
            unk = vocab.unk_word if hasattr(vocab, 'unk_word') else vocab.unk_token
        else:
            unk = ""
        logger.info(
            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                unk,
            )
        )

    def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers):
        nseq = [0]

        def merge_result(worker_result):
            nseq[0] += worker_result["nseq"]

        input_file = input_prefix
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (
                        args,
                        input_file,
                        utils.parse_alignment,
                        prefix,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, None, "bin"), impl=args.dataset_impl
        )

        merge_result(
            Binarizer.binarize_alignments(
                input_file,
                utils.parse_alignment,
                lambda t: ds.add_item(t),
                offset=0,
                end=offsets[1],
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        logger.info("[alignments] {}: parsed {} alignments".format(input_file, nseq[0]))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1, avoid_tokenize=False):
        # Tag the file prefixes with the tokenizer family so the BERT-, BART- and
        # ELECTRA-tokenized variants of a split do not overwrite each other (in the
        # transformers version this code targets, ElectraTokenizer subclasses
        # BertTokenizer, hence the extra isinstance check).
        output_prefix += '.bert' if isinstance(vocab, BertTokenizer) and not isinstance(vocab, ElectraTokenizer) else ''
        input_prefix += '.bert' if isinstance(vocab, BertTokenizer) and not isinstance(vocab, ElectraTokenizer) else ''
        output_prefix += '.bart' if isinstance(vocab, BartTokenizer) else ''
        input_prefix += '.bart' if isinstance(vocab, BartTokenizer) else ''
        output_prefix += '.electra' if isinstance(vocab, ElectraTokenizer) else ''
        input_prefix += '.electra' if isinstance(vocab, ElectraTokenizer) else ''

        if args.dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers, avoid_tokenize=avoid_tokenize)
            # map prefix
            map_input_prefix = input_prefix + '.map'
            map_output_prefix = output_prefix + '.map'
            # if mapping files exist, also binarize them (string split only, no tokenization)
            if os.path.exists("{}.{}".format(map_input_prefix, lang)) and args.input_mapping is True:
                make_binary_dataset(None, map_input_prefix, map_output_prefix, lang, num_workers, avoid_tokenize=True)

    def make_all(lang, vocab):
        if args.trainpref:
            make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers, avoid_tokenize=args.avoid_tokenize_extras)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(
                    vocab, validpref, outprefix, lang, num_workers=args.workers
                )
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers)

    def make_all_alignments():
        if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
            make_binary_alignment_dataset(
                args.trainpref + "." + args.align_suffix,
                "train.align",
                num_workers=args.workers,
            )
        if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
            make_binary_alignment_dataset(
                args.validpref + "." + args.align_suffix,
                "valid.align",
                num_workers=args.workers,
            )
        if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
            make_binary_alignment_dataset(
                args.testpref + "." + args.align_suffix,
                "test.align",
                num_workers=args.workers,
            )

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)
    if args.bert_model_name:
        berttokenizer = BertTokenizer.from_pretrained(args.bert_model_name, do_lower_case=False)
        make_all(args.source_lang, berttokenizer)
    if args.bart_model_name:
        barttokenizer = BartTokenizer.from_pretrained(args.bart_model_name, do_lower_case=False)
        make_all(args.source_lang, barttokenizer)
    # if args.electra_model_name:
    #     electratokenizer = ElectraTokenizer.from_pretrained(args.electra_model_name)
    #     make_all(args.source_lang, electratokenizer)
    if args.align_suffix:
        make_all_alignments()

    logger.info("Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding="utf-8") as align_file:
            with open(src_file_name, "r", encoding="utf-8") as src_file:
                with open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)

        with open(
            os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
            ),
            "w",
            encoding="utf-8",
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
Example #6
import torch
from transformers.models.bert import BertModel, BertTokenizer

model_name = '/data/project/learn_code/data/chinese-bert-wwm-ext/'
# Load the tokenizer that matches the model
tokenizer = BertTokenizer.from_pretrained(model_name)
# Load the model
model = BertModel.from_pretrained(model_name)
# Input text
# input_text = "Here is some text to encode"
input_text = "今天天气很好啊,你好吗"
# Use the tokenizer to convert the text into token ids
input_ids = tokenizer.encode(input_text, add_special_tokens=True)
print(len(input_ids))
# input_ids is a plain Python list that starts with the [CLS] id (101) and ends with the [SEP] id (102)
input_ids = torch.tensor([input_ids])
# Get the output of BERT's last hidden layer
print(input_ids.shape)
with torch.no_grad():
    last_hidden_states = model(input_ids)[0]
    print(last_hidden_states)
    print(last_hidden_states.shape)
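The same encoding can be done in a single call through the tokenizer's __call__ interface, which adds the special tokens, builds the attention mask, and returns PyTorch tensors directly; a minimal sketch using the same local checkpoint path as above:

import torch
from transformers.models.bert import BertModel, BertTokenizer

model_name = '/data/project/learn_code/data/chinese-bert-wwm-ext/'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

inputs = tokenizer("今天天气很好啊,你好吗", return_tensors='pt')
with torch.no_grad():
    # Index 0 of the output is the sequence of last-layer hidden states.
    last_hidden_states = model(**inputs)[0]
print(last_hidden_states.shape)  # (1, sequence_length, hidden_size)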
Example #7
def tokenizer(model_name) -> BertTokenizer:
    return BertTokenizer.from_pretrained(model_name)
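If this helper is meant to serve as a pytest fixture for tests like Examples #3 and #4, a minimal sketch of how it could be wired up; the model_name fixture and its 'bert-base-uncased' value are assumptions for illustration:

import pytest
from transformers import BertTokenizer


@pytest.fixture
def model_name() -> str:
    # Hypothetical fixture supplying a checkpoint name; swap in a local path if needed.
    return "bert-base-uncased"


@pytest.fixture
def tokenizer(model_name) -> BertTokenizer:
    return BertTokenizer.from_pretrained(model_name)


def test_tokenizer_loads(tokenizer: BertTokenizer):
    assert tokenizer.cls_token == "[CLS]"
    assert tokenizer.sep_token == "[SEP]"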
Example #8
    parser.add_argument('--test',
                        action="store_true",
                        default=False,
                        help='test')
    parser.add_argument('--predict',
                        action="store_true",
                        default=False,
                        help='predict')
    args = parser.parse_args()

    config = Config()

    never_split_token = [str(num) for num in range(30000)]

    if args.train:
        tokenizer = BertTokenizer.from_pretrained(
            config.pretrained_model_folder, do_basic_tokenize=True)
        tokenizer.save_pretrained(config.model_folder)
        train_data_loader, test_data_loader = prepare_qqsim_train_test_org_data_loader(
            './data/qq_number/gaiic_track3_round1_train_20210228.tsv',
            tokenizer=tokenizer,
            batch_size=config.train_batch_size,
            test_scale=0.1,
        )

        train(config, train_data_loader, test_data_loader)

    elif args.test:
        tokenizer = BertTokenizer.from_pretrained(
            config.pretrained_model_folder, do_basic_tokenize=True)
        train_data_loader, test_data_loader = prepare_qqsim_train_test_org_data_loader(
            './data/qq_number/gaiic_track3_round1_train_20210228.tsv',
Example #9
        self.init_weights()

    def forward(self, input_ids, token_type_ids, output_type="pooler"):
        """
        :param input_ids:
        :param token_type_ids:
        :param output_type:  "seq2seq" or "pooler"
        :return:
        """
        sequence_output, pooled_output = self.bert(input_ids, token_type_ids)

        if output_type == "pooler":
            return pooled_output
        else:
            prediction_scores, _ = self.cls(sequence_output,
                                            pooled_output)  # [b,s,V]
            return prediction_scores


if __name__ == '__main__':
    bert_wwm_pt_path = "/data/project/learn_code/data/chinese-bert-wwm-ext"
    config = bert_wwm_pt_path + "/config.json"
    tmp_state_dict = torch.load(bert_wwm_pt_path + "/pytorch_model.bin",
                                map_location="cpu")
    tokenizer = BertTokenizer.from_pretrained(bert_wwm_pt_path)
    bjModel = BojoneModel.from_pretrained(
        pretrained_model_name_or_path=bert_wwm_pt_path,
        config=config,
        state_dict=tmp_state_dict,
        local_files_only=False)
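    # A minimal usage sketch (an assumption, not part of the original snippet): run one
    # sentence through the loaded model and take the pooled [CLS] output.
    inputs = tokenizer("今天天气很好", return_tensors="pt")
    with torch.no_grad():
        pooled = bjModel(inputs["input_ids"],
                         inputs["token_type_ids"],
                         output_type="pooler")
    print(pooled.shape)  # expected: (1, hidden_size)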