Example #1
def convert_multiple_choice_examples_to_features(examples: list,
                                                 tokenizer: BertTokenizer,
                                                 max_seq_length: int,
                                                 is_training: bool,
                                                 verbose: bool = False):
    features = []
    for idx, example in enumerate(examples):
        option_features = []
        for option in example.get_option_segments():
            context_tokens = tokenizer.tokenize(option['segment1'])
            if "segment2" in option:
                option_tokens = tokenizer.tokenize(option["segment2"])
                _truncate_seq_pair(context_tokens, option_tokens,
                                   max_seq_length - 3)
                tokens = ["[CLS]"] + context_tokens + [
                    "[SEP]"
                ] + option_tokens + ["[SEP]"]
                segment_ids = [0] * (len(context_tokens) +
                                     2) + [1] * (len(option_tokens) + 1)
            else:
                context_tokens = context_tokens[0:(max_seq_length - 2)]
                tokens = ["[CLS]"] + context_tokens + ["[SEP]"]
                segment_ids = [0] * len(tokens)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)

            padding = [0] * (max_seq_length - len(input_ids))
            input_ids += padding
            input_mask += padding
            segment_ids += padding

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            option_features.append(
                (tokens, input_ids, input_mask, segment_ids))

        label = example.label

        if idx < 5 and verbose:
            logger.info("*** Example ***")
            logger.info(f"example_id: {example.example_id}")
            for choice_idx, (tokens, input_ids, input_mask,
                             segment_ids) in enumerate(option_features):
                logger.info(f"choice: {choice_idx}")
                logger.info(f"tokens: {' '.join(tokens)}")
                logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
                logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
                logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
            if is_training:
                logger.info(f"label: {label}")

        features.append(
            MultipleChoiceFeatures(example_id=example.example_id,
                                   option_features=option_features,
                                   label=label))

    return features
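
Below is a standalone sketch (not part of the example above) of the per-option encoding the function builds: [CLS]/[SEP] placement, segment ids, attention mask and zero padding. It assumes the transformers package and the public bert-base-uncased vocabulary; the sample sentences are made up.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
max_seq_length = 16

context_tokens = tokenizer.tokenize("Where is the cat?")
option_tokens = tokenizer.tokenize("On the mat.")

tokens = ["[CLS]"] + context_tokens + ["[SEP]"] + option_tokens + ["[SEP]"]
segment_ids = [0] * (len(context_tokens) + 2) + [1] * (len(option_tokens) + 1)

input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1] * len(input_ids)

padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding

assert len(input_ids) == len(input_mask) == len(segment_ids) == max_seq_length
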
Example #2
class NemoBertTokenizer(TokenizerSpec):
    def __init__(self, pretrained_model=None,
                 vocab_file=None,
                 do_lower_case=True,
                 max_len=None,
                 do_basic_tokenize=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        if pretrained_model:
            self.tokenizer = BertTokenizer.from_pretrained(pretrained_model)
            if "uncased" not in pretrained_model:
                self.tokenizer.basic_tokenizer.do_lower_case = False
        else:
            self.tokenizer = BertTokenizer(vocab_file,
                                           do_lower_case,
                                           max_len,
                                           do_basic_tokenize,
                                           never_split)
        self.vocab_size = len(self.tokenizer.vocab)
        self.never_split = never_split

    def text_to_tokens(self, text):
        tokens = self.tokenizer.tokenize(text)
        return tokens

    def tokens_to_text(self, tokens):
        text = self.tokenizer.convert_tokens_to_string(tokens)
        return remove_spaces(handle_quotes(text.strip()))

    def token_to_id(self, token):
        return self.tokens_to_ids([token])[0]

    def tokens_to_ids(self, tokens):
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        return ids

    def ids_to_tokens(self, ids):
        tokens = self.tokenizer.convert_ids_to_tokens(ids)
        return tokens

    def text_to_ids(self, text):
        tokens = self.text_to_tokens(text)
        ids = self.tokens_to_ids(tokens)
        return ids

    def ids_to_text(self, ids):
        tokens = self.ids_to_tokens(ids)
        tokens_clean = [t for t in tokens if t not in self.never_split]
        text = self.tokens_to_text(tokens_clean)
        return text

    def pad_id(self):
        return self.tokens_to_ids(["[PAD]"])[0]

    def bos_id(self):
        return self.tokens_to_ids(["[CLS]"])[0]

    def eos_id(self):
        return self.tokens_to_ids(["[SEP]"])[0]
Example #3
def sentence_pair_processing(data: list, tokenizer: BertTokenizer, max_sequence_length=88):

    max_bert_input_length = 0
    for sentence_pair in data:
        sentence_1_tokenized = tokenizer.tokenize(sentence_pair['sentence_1'])
        sentence_2_tokenized = tokenizer.tokenize(sentence_pair['sentence_2'])
        truncate_seq_pair(sentence_1_tokenized, sentence_2_tokenized, max_sequence_length - 3)

        max_bert_input_length = max(max_bert_input_length, len(sentence_1_tokenized) + len(sentence_2_tokenized) + 3)
        sentence_pair['sentence_1_tokenized'] = sentence_1_tokenized
        sentence_pair['sentence_2_tokenized'] = sentence_2_tokenized

    # Allocate the output tensors only once the longest padded length is known.
    dataset_input_ids = torch.empty((len(data), max_bert_input_length), dtype=torch.long)
    dataset_token_type_ids = torch.empty((len(data), max_bert_input_length), dtype=torch.long)
    dataset_attention_masks = torch.empty((len(data), max_bert_input_length), dtype=torch.long)
    dataset_scores = torch.empty((len(data), 1), dtype=torch.float)

    for idx, sentence_pair in enumerate(data):
        tokens = []
        input_type_ids = []

        tokens.append("[CLS]")
        input_type_ids.append(0)
        for token in sentence_pair['sentence_1_tokenized']:
            tokens.append(token)
            input_type_ids.append(0)
        tokens.append("[SEP]")
        input_type_ids.append(0)

        for token in sentence_pair['sentence_2_tokenized']:
            tokens.append(token)
            input_type_ids.append(1)
        tokens.append("[SEP]")
        input_type_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        attention_masks = [1] * len(input_ids)
        while len(input_ids) < max_bert_input_length:
            input_ids.append(0)
            attention_masks.append(0)
            input_type_ids.append(0)

        dataset_input_ids[idx] = torch.tensor(input_ids, dtype=torch.long)
        dataset_token_type_ids[idx] = torch.tensor(input_type_ids, dtype=torch.long)
        dataset_attention_masks[idx] = torch.tensor(attention_masks, dtype=torch.long)
        if 'similarity' not in sentence_pair or sentence_pair['similarity'] is None:
            dataset_scores[idx] = torch.tensor(float('nan'), dtype=torch.float)
        else:
            dataset_scores[idx] = torch.tensor(sentence_pair['similarity'], dtype=torch.float)

    return dataset_input_ids, dataset_token_type_ids, dataset_attention_masks, dataset_scores
Example #4
class BertBPE(object):
    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--bpe-cased',
                            action='store_true',
                            help='set for cased BPE',
                            default=False)
        parser.add_argument('--bpe-vocab-file',
                            type=str,
                            help='bpe vocab file.')
        # fmt: on

    def __init__(self, args):
        try:
            from pytorch_transformers import BertTokenizer
            from pytorch_transformers.tokenization_utils import clean_up_tokenization
        except ImportError:
            raise ImportError(
                'Please install 1.0.0 version of pytorch_transformers '
                'with: pip install pytorch-transformers')

        if getattr(args, 'bpe_vocab_file', None):
            self.bert_tokenizer = BertTokenizer(
                args.bpe_vocab_file, do_lower_case=not args.bpe_cased)
        else:
            vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased'
            self.bert_tokenizer = BertTokenizer.from_pretrained(
                vocab_file_name)
        # Keep the cleanup helper regardless of which branch built the
        # tokenizer, so decode() works in both cases.
        self.clean_up_tokenization = clean_up_tokenization

    def encode(self, x: str) -> str:
        return ' '.join(self.bert_tokenizer.tokenize(x))

    def decode(self, x: str) -> str:
        return self.clean_up_tokenization(
            self.bert_tokenizer.convert_tokens_to_string(x.split(' ')))

    def is_beginning_of_word(self, x: str) -> bool:
        return not x.startswith('##')
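
A hedged usage sketch for the class above: build an argparse namespace the way add_args expects, then round-trip a string. It assumes the legacy pytorch-transformers package is installed and that the default bert-base-uncased vocabulary can be downloaded; the printed outputs are illustrative.

import argparse

parser = argparse.ArgumentParser()
BertBPE.add_args(parser)
args = parser.parse_args([])            # no --bpe-vocab-file, --bpe-cased off
bpe = BertBPE(args)                     # falls back to bert-base-uncased

encoded = bpe.encode("Hello, worldwide friends!")
print(encoded)                          # e.g. "hello , world ##wide friends !"
print(bpe.decode(encoded))              # e.g. "hello, worldwide friends!"
print(bpe.is_beginning_of_word("##wide"))   # False
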
Example #5
class BertProcessor(object):
    def __init__(self, vocab_path, do_lower_case, min_freq_words=None):
        self.tokenizer = BertTokenizer(vocab_file=vocab_path,
                                       do_lower_case=do_lower_case)

    def get_labels(self):
        return [
            '[CLS]', '[SEP]', 'O', 'B-NIHSS', 'B-Measurement', 'B-1a_LOC',
            'B-1b_LOCQuestions', 'B-1c_LOCCommands', 'B-2_BestGaze',
            'B-3_Visual', 'B-4_FacialPalsy', 'B-5_Motor', 'B-5a_LeftArm',
            'B-5b_RightArm', 'B-6_Motor', 'B-6a_LeftLeg', 'B-6b_RightLeg',
            'B-7_LimbAtaxia', 'B-8_Sensory', 'B-9_BestLanguage',
            'B-10_Dysarthria', 'B-11_ExtinctionInattention', 'I-NIHSS',
            'I-Measurement', 'I-1a_LOC', 'I-1b_LOCQuestions',
            'I-1c_LOCCommands', 'I-2_BestGaze', 'I-3_Visual',
            'I-4_FacialPalsy', 'I-5_Motor', 'I-5a_LeftArm', 'I-5b_RightArm',
            'I-6_Motor', 'I-6a_LeftLeg', 'I-6b_RightLeg', 'I-7_LimbAtaxia',
            'I-8_Sensory', 'I-9_BestLanguage', 'I-10_Dysarthria',
            'I-11_ExtinctionInattention'
        ]

    @classmethod
    def read_data(cls, input_file, quotechar=None):
        if 'pkl' in str(input_file):
            lines = tools.load_pickle(input_file)
        else:
            lines = input_file
        return lines

    def get_train(self, data_file):
        return self.read_data(data_file)

    def get_valid(self, data_file):
        return self.read_data(data_file)

    def get_test(self, data_file):
        return self.read_data(data_file)

    def create_examples(self, lines, example_type, cached_file):
        if cached_file.exists():
            tools.logger.info("Loading samples from cached files %s",
                              cached_file)
            examples = torch.load(cached_file)
        else:
            pbar = progressbar.ProgressBar(
                n_total=len(lines), desc=f'create {example_type} samples')
            examples = []
            for i, line in enumerate(lines):
                hadm_id = line['HADM_ID']
                guid = '%s-%s-%d' % (example_type, hadm_id, i)
                sentence = line['token']  # list
                sentence = [' ' if type(t) == float else t for t in sentence]
                label = line['tags']  # list
                code = line['code']  # brat entity Tcode T1 T2
                relations = line['relations']  # brat relations golden standard
                # text_a: string. The untokenized text of the first sequence. For single
                # sequence tasks, only this sequence must be specified.
                text_a = ' '.join(sentence)  # string
                text_b = None
                examples.append(
                    InputExample(guid=guid,
                                 text_a=text_a,
                                 text_b=text_b,
                                 label=label,
                                 code=code,
                                 relations=relations,
                                 hadm_id=hadm_id))
                pbar(step=i)
            tools.logger.info("Saving examples into cached file %s",
                              cached_file)
            torch.save(examples, cached_file)
        return examples

    def create_features(self, examples, max_seq_len, cached_file):
        if cached_file.exists():
            tools.logger.info('Loading features from cached file %s',
                              cached_file)
            features = torch.load(cached_file)
        else:
            label_list = self.get_labels()
            label2id = {label: i for i, label in enumerate(label_list)}
            pbar = progressbar.ProgressBar(
                n_total=len(examples),
                desc='creating the specified features of examples')
            features = []
            for example_id, example in enumerate(examples):
                hamd_id = example.hadm_id
                text_list = example.text_a.split(' ')  # string
                idx_CR = [
                    idx for idx, text in enumerate(text_list)
                    if text == '<CRLF>'
                ]
                label_list = example.label
                code_list = example.code
                relation_list = example.relations

                new_tokens = []
                new_segment_ids = []
                new_label_ids = []
                new_code = []

                new_tokens.append('[CLS]')
                new_segment_ids.append(0)
                new_label_ids.append(label2id['[CLS]'])
                new_code.append('0')

                for text, label, code in zip(text_list, label_list, code_list):
                    if text == '<CRLF>':
                        continue
                    else:
                        token_list = self.tokenizer.tokenize(text)
                        for idx, token in enumerate(token_list):
                            new_tokens.append(token)
                            new_segment_ids.append(0)
                            if idx == 0:
                                new_label_ids.append(label2id[label])
                                new_code.append(code)
                            elif label == 'O':
                                new_label_ids.append(label2id[label])
                                new_code.append(code)
                            else:
                                temp_l = 'I-' + label.split('-')[1]
                                new_label_ids.append(label2id[temp_l])
                                new_code.append(code)

                assert len(new_tokens) == len(new_segment_ids)
                assert len(new_tokens) == len(new_label_ids)
                assert len(new_tokens) == len(new_code)

                if len(new_tokens) >= max_seq_len:
                    new_tokens = new_tokens[0:(max_seq_len - 1)]
                    new_segment_ids = new_segment_ids[0:(max_seq_len - 1)]
                    new_label_ids = new_label_ids[0:(max_seq_len - 1)]
                    new_code = new_code[0:(max_seq_len - 1)]

                new_tokens.append('[SEP]')
                new_segment_ids.append(0)
                new_label_ids.append(label2id['[SEP]'])
                new_code.append('0')

                input_ids = self.tokenizer.convert_tokens_to_ids(new_tokens)
                input_mask = [1] * len(input_ids)
                input_len = len(new_label_ids)

                if len(input_ids) < max_seq_len:
                    pad_zero = [0] * (max_seq_len - len(input_ids))
                    input_ids.extend(pad_zero)
                    input_mask.extend(pad_zero)
                    new_segment_ids.extend(pad_zero)
                    new_label_ids.extend(pad_zero)
                    new_code.extend(['0'] * len(pad_zero))

                assert len(input_ids) == max_seq_len
                assert len(input_mask) == max_seq_len
                assert len(new_segment_ids) == max_seq_len
                assert len(new_label_ids) == max_seq_len
                assert len(new_code) == max_seq_len

                df_temp = pd.DataFrame({
                    'input_ids': input_ids,
                    'code': new_code
                })
                agg_fun = lambda s: (max(s['code']), s.index.tolist()[0],
                                     s.index.tolist()[-1])
                groupby_code = df_temp.groupby('code').apply(agg_fun)
                code_position = {}
                for key, start, end in groupby_code:
                    if key != '0':
                        code_position[(start - 1, end - 1)] = key
                    else:
                        continue

                if example_id < 2:
                    tools.logger.info('*** Examples: ***')
                    tools.logger.info("guid: %s" % (example.guid))
                    tools.logger.info("tokens: %s" %
                                      " ".join([str(x) for x in new_tokens]))
                    tools.logger.info("input_ids: %s" %
                                      " ".join([str(x) for x in input_ids]))
                    tools.logger.info("input_mask: %s" %
                                      " ".join([str(x) for x in input_mask]))
                    tools.logger.info(
                        "segment_ids: %s" %
                        " ".join([str(x) for x in new_segment_ids]))
                    tools.logger.info("old label name: %s " %
                                      " ".join(example.label))
                    tools.logger.info("new label ids: %s" %
                                      " ".join([str(x)
                                                for x in new_label_ids]))

                features.append(
                    InputFeature(
                        input_ids=input_ids,
                        input_mask=input_mask,
                        segment_ids=new_segment_ids,
                        label_id=new_label_ids,
                        input_len=input_len,
                        code=new_code,
                        new_tokens=new_tokens,
                        relations=relation_list,  # golden standard
                        hamd_id=hamd_id,
                        code_position=code_position))

                pbar(step=example_id)

            tools.logger.info('Saving features into cached file %s',
                              cached_file)
            torch.save(features, cached_file)
        return features

    def create_dataset(self, features, is_sorted=False):
        if is_sorted:
            tools.logger.info('sorted data by the length of input')
            features = sorted(features,
                              key=lambda x: x.input_len,
                              reverse=True)
        all_input_ids = torch.tensor([f.input_ids for f in features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in features],
                                     dtype=torch.long)
        all_input_lens = torch.tensor([f.input_len for f in features],
                                      dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                all_label_ids, all_input_lens)

        return dataset
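
A short sketch of how the TensorDataset returned by create_dataset above would typically be consumed; processor is assumed to be an instance of the BertProcessor above and features the output of its create_features.

from torch.utils.data import DataLoader, SequentialSampler

dataset = processor.create_dataset(features, is_sorted=False)
loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=8)
for input_ids, input_mask, segment_ids, label_ids, input_lens in loader:
    print(input_ids.shape)   # torch.Size([8, max_seq_len]) for a full batch
    break
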
Example #6
class BertProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def __init__(self, vocab_path, do_lower_case):
        self.tokenizer = BertTokenizer(vocab_path, do_lower_case)

    def get_train(self, data_file):
        """Gets a collection of `InputExample`s for the train set."""
        return self.read_data(data_file)

    def get_dev(self, data_file):
        """Gets a collection of `InputExample`s for the dev set."""
        return self.read_data(data_file)

    def get_test(self, lines):
        return lines

    def get_labels(self):
        """Gets the list of labels for this data set."""
        return ["0", "1"]

    @classmethod
    def read_data(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        if 'pkl' in str(input_file):
            lines = load_pickle(input_file)
        else:
            lines = input_file
        return lines

    def truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        # This is a simple heuristic which will always truncate the longer sequence
        # one token at a time. This makes more sense than truncating an equal percent
        # of tokens from each, since if one sequence is very short then each token
        # that's truncated likely contains more information than a longer sequence.
        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

    def create_examples(self, lines, example_type, cached_examples_file):
        '''
        Creates examples for data
        '''
        pbar = ProgressBar(n_total=len(lines))
        if cached_examples_file.exists():
            logger.info("Loading examples from cached file %s", cached_examples_file)
            examples = torch.load(cached_examples_file)
        else:
            examples = []
            for i, line in enumerate(lines):
                guid = '%s-%d' % (example_type, i)
                text_a = line[0]
                label = line[1]
                text_b = None
                example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
                examples.append(example)
                pbar.batch_step(step=i, info={}, bar_type='create examples')
            logger.info("Saving examples into cached file %s", cached_examples_file)
            torch.save(examples, cached_examples_file)
        return examples

    def create_features(self, examples, max_seq_len, cached_features_file):
        '''
        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0   0   0   0  0     0   0
        '''
        pbar = ProgressBar(n_total=len(examples))
        if cached_features_file.exists():
            logger.info("Loading features from cached file %s", cached_features_file)
            features = torch.load(cached_features_file)
        else:
            features = []
            for ex_id, example in enumerate(examples):
                tokens_a = self.tokenizer.tokenize(example.text_a)
                tokens_b = None
                label_id = int(example.label)

                if example.text_b:
                    tokens_b = self.tokenizer.tokenize(example.text_b)
                    # Modifies `tokens_a` and `tokens_b` in place so that the total
                    # length is less than the specified length.
                    # Account for [CLS], [SEP], [SEP] with "- 3"
                    self.truncate_seq_pair(tokens_a, tokens_b, max_length=max_seq_len - 3)
                else:
                    # Account for [CLS] and [SEP] with '-2'
                    if len(tokens_a) > max_seq_len - 2:
                        tokens_a = tokens_a[:max_seq_len - 2]
                tokens = ['[CLS]'] + tokens_a + ['[SEP]']
                segment_ids = [0] * len(tokens)
                if tokens_b:
                    tokens += tokens_b + ['[SEP]']
                    segment_ids += [1] * (len(tokens_b) + 1)

                input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
                input_mask = [1] * len(input_ids)
                padding = [0] * (max_seq_len - len(input_ids))
                input_len = len(input_ids)

                input_ids += padding
                input_mask += padding
                segment_ids += padding

                assert len(input_ids) == max_seq_len
                assert len(input_mask) == max_seq_len
                assert len(segment_ids) == max_seq_len

                if ex_id < 2:
                    logger.info("*** Example ***")
                    logger.info(f"guid: {example.guid}" % ())
                    logger.info(f"tokens: {' '.join([str(x) for x in tokens])}")
                    logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
                    logger.info(f"input_mask: {' '.join([str(x) for x in input_mask])}")
                    logger.info(f"segment_ids: {' '.join([str(x) for x in segment_ids])}")

                feature = InputFeature(input_ids=input_ids,
                                       input_mask=input_mask,
                                       segment_ids=segment_ids,
                                       label_id=label_id,
                                       input_len=input_len)
                features.append(feature)
                pbar.batch_step(step=ex_id, info={}, bar_type='create features')
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)
        return features

    def create_dataset(self, features, is_sorted=False):
        # Convert to Tensors and build dataset
        if is_sorted:
            logger.info("sorted data by th length of input")
            features = sorted(features, key=lambda x: x.input_len, reverse=True)
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return dataset
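
A worked example of the truncation heuristic documented in truncate_seq_pair above: the longer list loses tokens from its end until the pair fits. __new__ is used only to skip the vocab-loading constructor for this illustration.

tokens_a = ["the", "quick", "brown", "fox", "jumps"]
tokens_b = ["lazy", "dog"]

proc = BertProcessor.__new__(BertProcessor)   # bypass __init__; no vocab needed here
proc.truncate_seq_pair(tokens_a, tokens_b, max_length=5)

print(tokens_a)  # ['the', 'quick', 'brown']
print(tokens_b)  # ['lazy', 'dog']
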
Example #7
class BertDetector(object):
    def __init__(self,
                 bert_model_dir=config.bert_model_dir,
                 bert_model_vocab=config.bert_model_vocab,
                 threshold=0.1):
        self.name = 'bert_detector'
        self.bert_model_dir = bert_model_dir
        self.bert_model_vocab = bert_model_vocab
        self.initialized_bert_detector = False
        self.threshold = threshold

    def check_bert_detector_initialized(self):
        if not self.initialized_bert_detector:
            self.initialize_bert_detector()

    def initialize_bert_detector(self):
        t1 = time.time()
        self.bert_tokenizer = BertTokenizer(vocab_file=self.bert_model_vocab)
        self.MASK_TOKEN = "[MASK]"
        self.MASK_ID = self.bert_tokenizer.convert_tokens_to_ids(
            [self.MASK_TOKEN])[0]
        # Prepare model
        self.model = BertForMaskedLM.from_pretrained(self.bert_model_dir)
        logger.debug("Loaded model ok, path: %s, spend: %.3f s." %
                     (self.bert_model_dir, time.time() - t1))
        self.initialized_bert_detector = True

    def _convert_sentence_to_detect_features(self, sentence):
        """Loads a sentence into a list of `InputBatch`s."""
        self.check_bert_detector_initialized()
        features = []
        tokens = self.bert_tokenizer.tokenize(sentence)
        token_ids = self.bert_tokenizer.convert_tokens_to_ids(tokens)
        for idx, token_id in enumerate(token_ids):
            masked_lm_labels = [-1] * len(token_ids)
            masked_lm_labels[idx] = token_id
            features.append(
                InputFeatures(input_ids=token_ids,
                              masked_lm_labels=masked_lm_labels,
                              input_tokens=tokens,
                              id=idx,
                              token=tokens[idx]))
        return features

    def predict_token_prob(self, sentence):
        self.check_bert_detector_initialized()
        result = []
        eval_features = self._convert_sentence_to_detect_features(sentence)

        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            masked_lm_labels = torch.tensor([f.masked_lm_labels])
            outputs = self.model(input_ids, masked_lm_labels=masked_lm_labels)
            masked_lm_loss, predictions = outputs[:2]
            prob = np.exp(-masked_lm_loss.item())
            result.append([prob, f])
        return result

    def detect(self, sentence):
        """
        句子改错
        :param sentence: 句子文本
        :return: list[list], [error_word, begin_pos, end_pos, error_type]
        """
        maybe_errors = []
        for prob, f in self.predict_token_prob(sentence):
            logger.debug('prob:%s, token:%s, idx:%s' % (prob, f.token, f.id))
            if prob < self.threshold:
                maybe_errors.append([f.token, f.id, f.id + 1, ErrorType.char])
        return maybe_errors
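
Hypothetical usage of the detector above, assuming config.bert_model_dir and config.bert_model_vocab point to a local masked-LM checkpoint and vocabulary file; the printed result is illustrative.

detector = BertDetector(threshold=0.1)
errors = detector.detect("少先队员因该为老人让座")
print(errors)  # e.g. [['因', 4, 5, ErrorType.char], ...] for low-probability tokens
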
Example #8
def bert_sentence_pair_preprocessing(data: list,
                                     tokenizer: BertTokenizer,
                                     max_sequence_length=128):
    """
    Pre-processes an array of sentence pairs for input into bert. Sentence pairs are expected to be processed
    as given in data.py.

    Each sentence pair is tokenized and concatenated together by the [SEP] token for input into BERT

    :return: three tensors: [data_size, input_ids], [data_size, token_type_ids], [data_size, attention_mask]
    """

    max_bert_input_length = 0
    for sentence_pair in data:

        sentence_1_tokenized, sentence_2_tokenized = tokenizer.tokenize(
            sentence_pair['sentence_1']), tokenizer.tokenize(
                sentence_pair['sentence_2'])
        _truncate_seq_pair(sentence_1_tokenized, sentence_2_tokenized,
                           max_sequence_length - 3)  # account for [CLS] and the two [SEP] tokens

        max_bert_input_length = max(
            max_bert_input_length,
            len(sentence_1_tokenized) + len(sentence_2_tokenized) + 3)
        sentence_pair['sentence_1_tokenized'] = sentence_1_tokenized
        sentence_pair['sentence_2_tokenized'] = sentence_2_tokenized

    dataset_input_ids = torch.empty((len(data), max_bert_input_length),
                                    dtype=torch.long)
    dataset_token_type_ids = torch.empty((len(data), max_bert_input_length),
                                         dtype=torch.long)
    dataset_attention_masks = torch.empty((len(data), max_bert_input_length),
                                          dtype=torch.long)
    dataset_scores = torch.empty((len(data), 1), dtype=torch.float)

    for idx, sentence_pair in enumerate(data):
        tokens = []
        input_type_ids = []

        tokens.append("[CLS]")
        input_type_ids.append(0)
        for token in sentence_pair['sentence_1_tokenized']:
            tokens.append(token)
            input_type_ids.append(0)
        tokens.append("[SEP]")
        input_type_ids.append(0)

        for token in sentence_pair['sentence_2_tokenized']:
            tokens.append(token)
            input_type_ids.append(1)
        tokens.append("[SEP]")
        input_type_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        attention_masks = [1] * len(input_ids)
        while len(input_ids) < max_bert_input_length:
            input_ids.append(0)
            attention_masks.append(0)
            input_type_ids.append(0)

        dataset_input_ids[idx] = torch.tensor(input_ids, dtype=torch.long)
        dataset_token_type_ids[idx] = torch.tensor(input_type_ids,
                                                   dtype=torch.long)
        dataset_attention_masks[idx] = torch.tensor(attention_masks,
                                                    dtype=torch.long)
        if 'similarity' not in sentence_pair or sentence_pair[
                'similarity'] is None:
            dataset_scores[idx] = torch.tensor(float('nan'), dtype=torch.float)
        else:
            dataset_scores[idx] = torch.tensor(sentence_pair['similarity'],
                                               dtype=torch.float)

    return dataset_input_ids, dataset_token_type_ids, dataset_attention_masks, dataset_scores
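
A hedged call sketch for the preprocessing function above; _truncate_seq_pair is assumed to be in scope (the same pair-truncation helper used elsewhere on this page), and the tokenizer loads the public bert-base-uncased vocabulary.

from transformers import BertTokenizer

data = [
    {"sentence_1": "A man is playing a guitar.",
     "sentence_2": "Someone plays an instrument.",
     "similarity": 3.8},
    {"sentence_1": "The weather is nice.",
     "sentence_2": "It is raining heavily.",
     "similarity": None},
]
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
input_ids, token_type_ids, attention_masks, scores = \
    bert_sentence_pair_preprocessing(data, tokenizer, max_sequence_length=128)
print(input_ids.shape)     # torch.Size([2, max_bert_input_length])
print(scores.squeeze(1))   # tensor([3.8000, nan])
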
Example #9
if __name__ == '__main__':
    args = parser.parse_args()
    assert os.path.exists(args.bert_model), '{} does not exist'.format(args.bert_model)
    assert os.path.exists(args.bert_vocab), '{} does not exist'.format(args.bert_vocab)
    assert args.topk > 0, '{} should be positive'.format(args.topk)

    print('Initialize BERT vocabulary from {}...'.format(args.bert_vocab))
    bert_tokenizer = BertTokenizer(vocab_file=args.bert_vocab)
    print('Initialize BERT model from {}...'.format(args.bert_model))
    config = BertConfig.from_json_file('./bert-base-uncased/config.json')
    bert_model = BertForMaskedLM.from_pretrained('./bert-base-uncased/pytorch_model.bin', config = config)

    while True:
        message = input('Enter your message: ').strip()
        tokens = bert_tokenizer.tokenize(message)
        if len(tokens) == 0:
            continue
        if tokens[0] != CLS:
            tokens = [CLS] + tokens
        if tokens[-1] != SEP:
            tokens.append(SEP)
        token_idx, segment_idx, mask = to_bert_input(tokens, bert_tokenizer)
        with torch.no_grad():
            logits = bert_model(token_idx, segment_idx, mask, masked_lm_labels=None)
        logits = logits[0].squeeze(0)
        probs = torch.softmax(logits, dim=-1)

        mask_cnt = 0
        for idx, token in enumerate(tokens):
            if token == MASK:
                # Assumed completion of the truncated snippet: show the top-k
                # vocabulary candidates for each [MASK] position.
                mask_cnt += 1
                topk_prob, topk_idx = torch.topk(probs[idx], args.topk)
                topk_tokens = bert_tokenizer.convert_ids_to_tokens(topk_idx.tolist())
                print('[MASK] #%d:' % mask_cnt, list(zip(topk_tokens, topk_prob.tolist())))
Example #10
class RuleBertWordDetector(object):
    def __init__(self,
                 language_model_path=config.language_model_path,
                 word_freq_path=config.word_freq_path,
                 char_freq_path=config.char_freq_path,
                 custom_word_freq_path=config.custom_word_freq_path,
                 custom_confusion_path=config.custom_confusion_path,
                 person_name_path=config.person_name_path,
                 place_name_path=config.place_name_path,
                 stopwords_path=config.stopwords_path,
                 bert_model_dir=config.bert_model_dir,
                 bert_model_vocab=config.bert_model_vocab,
                 threshold=0.1):
        self.name = 'rule_bert_word_detector'
        self.language_model_path = language_model_path
        self.word_freq_path = word_freq_path
        self.char_freq_path = char_freq_path
        self.custom_word_freq_path = custom_word_freq_path
        self.custom_confusion_path = custom_confusion_path
        self.person_name_path = person_name_path
        self.place_name_path = place_name_path
        self.stopwords_path = stopwords_path
        self.is_char_error_detect = True
        self.is_word_error_detect = True
        self.is_redundancy_miss_error_detect = True
        self.initialized_detector = False
        self.bert_model_dir = bert_model_dir
        self.bert_model_vocab = bert_model_vocab
        self.threshold = threshold

    def initialize_detector(self):
        t1 = time.time()
        try:
            import kenlm
        except ImportError:
            raise ImportError(
                'mypycorrector dependencies are not fully installed, '
                'they are required for statistical language model.'
                'Please use "pip install kenlm" to install it.'
                'if you are Win, Please install kenlm in cgwin.')

        self.lm = kenlm.Model(self.language_model_path)
        logger.debug('Loaded language model: %s, spend: %s s' %
                     (self.language_model_path, str(time.time() - t1)))

        # word and character frequency dicts
        t2 = time.time()
        self.word_freq = self.load_word_freq_dict(self.word_freq_path)
        self.char_freq = self.load_char_freq_dict(self.char_freq_path)
        t3 = time.time()
        logger.debug(
            'Loaded word freq, char freq file: %s, size: %d, spend: %s s' %
            (self.word_freq_path, len(self.word_freq), str(t3 - t2)))
        # custom confusion set
        self.custom_confusion = self._get_custom_confusion_dict(
            self.custom_confusion_path)
        t4 = time.time()
        logger.debug('Loaded confusion file: %s, size: %d, spend: %s s' %
                     (self.custom_confusion_path, len(
                         self.custom_confusion), str(t4 - t3)))
        # custom word segmentation dictionary
        self.custom_word_freq = self.load_word_freq_dict(
            self.custom_word_freq_path)
        self.person_names = self.load_word_freq_dict(self.person_name_path)
        self.place_names = self.load_word_freq_dict(self.place_name_path)
        self.stopwords = self.load_word_freq_dict(self.stopwords_path)
        # merge the segmentation dictionary with the custom dictionaries
        self.custom_word_freq.update(self.person_names)
        self.custom_word_freq.update(self.place_names)
        self.custom_word_freq.update(self.stopwords)

        self.word_freq.update(self.custom_word_freq)
        t5 = time.time()
        logger.debug('Loaded custom word file: %s, size: %d, spend: %s s' %
                     (self.custom_confusion_path, len(
                         self.custom_word_freq), str(t5 - t4)))
        self.tokenizer = Tokenizer(dict_path=self.word_freq_path,
                                   custom_word_freq_dict=self.custom_word_freq,
                                   custom_confusion_dict=self.custom_confusion)
        # pretrained BERT model
        t6 = time.time()
        self.bert_tokenizer = BertTokenizer(vocab_file=self.bert_model_vocab)
        self.MASK_TOKEN = "[MASK]"
        self.MASK_ID = self.bert_tokenizer.convert_tokens_to_ids(
            [self.MASK_TOKEN])[0]
        # Prepare model
        self.model = BertForMaskedLM.from_pretrained(self.bert_model_dir)
        logger.debug("Loaded model ok, path: %s, spend: %.3f s." %
                     (self.bert_model_dir, time.time() - t6))
        self.initialized_detector = True

    def check_detector_initialized(self):
        if not self.initialized_detector:
            self.initialize_detector()

    def enable_char_error(self, enable=True):
        """
        is open char error detect
        :param enable:
        :return:
        """
        self.is_char_error_detect = enable

    def enable_word_error(self, enable=True):
        """
        is open word error detect
        :param enable:
        :return:
        """
        self.is_word_error_detect = enable

    def enable_redundancy_miss_error(self, enable=True):
        '''
        Enable or disable redundancy / missing-character error detection.
        :param enable:
        :return:
        '''
        self.is_redundancy_miss_error_detect = enable

    def _convert_sentence_to_detect_features(self, sentence):
        """Loads a sentence into a list of `InputBatch`s."""
        self.check_detector_initialized()
        features = []
        tokens = self.bert_tokenizer.tokenize(sentence)
        token_ids = self.bert_tokenizer.convert_tokens_to_ids(tokens)
        for idx, token_id in enumerate(token_ids):
            masked_lm_labels = [-1] * len(token_ids)
            masked_lm_labels[idx] = token_id
            features.append(
                InputFeatures(input_ids=token_ids,
                              masked_lm_labels=masked_lm_labels,
                              input_tokens=tokens,
                              id=idx,
                              token=tokens[idx]))
        return features

    # use BERT to predict possibly erroneous characters
    def predict_token_prob(self, sentence):
        self.check_detector_initialized()
        result = []
        eval_features = self._convert_sentence_to_detect_features(sentence)

        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            masked_lm_labels = torch.tensor([f.masked_lm_labels])
            outputs = self.model(input_ids, masked_lm_labels=masked_lm_labels)
            masked_lm_loss, predictions = outputs[:2]
            prob = np.exp(-masked_lm_loss.item())
            result.append([prob, f])
        return result

    @staticmethod
    def load_word_freq_dict(path):
        """
        加载切词词典
        :param path:
        :return:
        """
        word_freq = {}
        with codecs.open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('#'):
                    continue
                info = line.split()
                if len(info) < 1:
                    continue
                word = info[0]
                # frequency; defaults to 1
                freq = int(info[1]) if len(info) > 1 else 1
                word_freq[word] = freq
        return word_freq

    @staticmethod
    def load_char_freq_dict(path):
        """
        加载常用字碎片词典
        :param path:
        :return:
        """
        char_freq = {}
        with codecs.open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('#'):
                    continue
                info = line.split()
                if len(info) < 1:
                    continue
                char = info[0]
                # frequency; defaults to 1
                freq = int(info[1]) if len(info) > 1 else 1
                char_freq[char] = freq
        return char_freq

    def _get_custom_confusion_dict(self, path):
        """
        取自定义困惑集
        :param path:
        :return: dict, {variant: origin}, eg: {"交通先行": "交通限行"}
        """
        confusion = {}
        with codecs.open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('#'):
                    continue
                info = line.split()
                if len(info) < 2:
                    continue
                variant = info[0]
                origin = info[1]
                freq = int(info[2]) if len(info) > 2 else 1
                self.word_freq[origin] = freq
                confusion[variant] = origin
        return confusion

    def ngram_score(self, chars):
        """
        取n元文法得分
        :param chars: list, 以词或字切分
        :return:
        """
        self.check_detector_initialized()
        return self.lm.score(' '.join(chars), bos=False, eos=False)

    def ppl_score(self, words):
        """
        取语言模型困惑度得分,越小句子越通顺
        :param words: list, 以词或字切分
        :return:
        """
        self.check_detector_initialized()
        return self.lm.perplexity(' '.join(words))

    def word_frequency(self, word):
        """
        取词在样本中的词频
        :param word:
        :return:
        """
        self.check_detector_initialized()
        return self.word_freq.get(word, 0)

    def set_word_frequency(self, word, num):
        """
        更新在样本中的词频
        """
        self.check_detector_initialized()
        self.word_freq[word] = num
        return self.word_freq

    @staticmethod
    def _check_contain_error(maybe_err, maybe_errors):
        """
        检测错误集合(maybe_errors)是否已经包含该错误位置(maybe_err)
        :param maybe_err: [error_word, begin_pos, end_pos, error_type]
        :param maybe_errors:
        :return:
        """
        error_word_idx = 0
        begin_idx = 1
        end_idx = 2
        for err in maybe_errors:
            if maybe_err[error_word_idx] in err[error_word_idx] and maybe_err[begin_idx] >= err[begin_idx] and \
                    maybe_err[end_idx] <= err[end_idx]:
                return True
        return False

    def _add_maybe_error_item(self, maybe_err, maybe_errors):
        """
        新增错误
        :param maybe_err:
        :param maybe_errors:
        :return:
        """
        if maybe_err not in maybe_errors and not self._check_contain_error(
                maybe_err, maybe_errors):
            maybe_errors.append(maybe_err)

    @staticmethod
    def is_filter_token(token):
        result = False
        # pass blank
        if not token.strip():
            result = True
        # pass punctuation
        if token in PUNCTUATION_LIST:
            result = True
        # pass num
        if token.isdigit():
            result = True
        # pass alpha
        if is_alphabet_string(token.lower()):
            result = True
        return result

    def detect(self, sentence):
        """
        检测句子中的疑似错误信息,包括[词、位置、错误类型]
        :param sentence:
        :return: list[list], [error_word, begin_pos, end_pos, error_type]
        """
        maybe_errors = []
        if not sentence.strip():
            return maybe_errors
        # initialization
        self.check_detector_initialized()
        # text normalization
        sentence = uniform(sentence)
        # word segmentation
        tokens = self.tokenizer.tokenize(sentence)
        self.tokens = [token[0] for token in tokens]
        # print(tokens)
        # add custom confusion entries to the candidate errors
        # for confuse in self.custom_confusion:
        #     idx = sentence.find(confuse)
        #     if idx > -1:
        #         maybe_err = [confuse, idx, idx +
        #                      len(confuse), ErrorType.confusion]
        #         self._add_maybe_error_item(maybe_err, maybe_errors)

        if self.is_word_error_detect:
            # add out-of-vocabulary words to the candidate errors
            for word, begin_idx, end_idx in tokens:
                # pass filter word
                if self.is_filter_token(word):
                    continue
                # pass in dict
                if word in self.word_freq:
                    if self.is_redundancy_miss_error_detect:
                        # Multi-character words, or single characters with a
                        # frequency above 50000, can be skipped here.
                        if (len(word) == 1 and word in self.char_freq
                                and self.char_freq.get(word) < 50000):
                            maybe_err = [
                                word, begin_idx, end_idx, ErrorType.word_char
                            ]
                            self._add_maybe_error_item(maybe_err, maybe_errors)
                            continue
                        # repeated character: consider whether it is redundant
                        if len(word) == 1 and sentence[begin_idx - 1] == word:
                            maybe_err = [
                                word, begin_idx, end_idx, ErrorType.redundancy
                            ]
                            self._add_maybe_error_item(maybe_err, maybe_errors)
                            continue
                    continue
                # check single-character fragments: possibly redundant, missing, or wrong characters
                if self.is_redundancy_miss_error_detect:
                    if len(word) == 1:
                        maybe_err = [
                            word, begin_idx, end_idx, ErrorType.word_char
                        ]
                        self._add_maybe_error_item(maybe_err, maybe_errors)
                        continue
                maybe_err = [word, begin_idx, end_idx, ErrorType.word]
                self._add_maybe_error_item(maybe_err, maybe_errors)

        if self.is_char_error_detect:
            # use the language model to detect suspected erroneous characters
            try:
                for prob, f in self.predict_token_prob(sentence):
                    # logger.debug('prob:%s, token:%s, idx:%s' % (prob, f.token, f.id))
                    if prob < self.threshold:
                        maybe_err = [f.token, f.id, f.id + 1, ErrorType.char]
                        self._add_maybe_error_item(maybe_err, maybe_errors)
                # return maybe_errors
            except IndexError as ie:
                logger.warn("index error, sentence:" + sentence + str(ie))
            except Exception as e:
                logger.warn("detect error, sentence:" + sentence + str(e))
        return sorted(maybe_errors, key=lambda k: k[1], reverse=False)
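
A small sketch of the whitespace-separated "word [freq]" file format that load_word_freq_dict above expects; lines starting with '#' are skipped and a missing frequency defaults to 1. The temporary file is only for illustration.

import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
    f.write("# comment line\n交通 1000\n限行\n")
    path = f.name

print(RuleBertWordDetector.load_word_freq_dict(path))
# {'交通': 1000, '限行': 1}
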
Example #11
tokenizer = BertTokenizer(vocab_file=vocabulary, do_lower_case=False)


tokenized_texts = []
word_piece_labels = []
i_inc = 0
for word_list, label in zip(sentences, labels):
    temp_label = []
    temp_token = []

    # Add [CLS] at the front
    temp_label.append('[CLS]')
    temp_token.append('[CLS]')

    for word, lab in zip(word_list, label):
        token_list = tokenizer.tokenize(word)
        for m, token in enumerate(token_list):
            temp_token.append(token)
            if m == 0:
                temp_label.append(lab)
            else:
                temp_label.append('X')

    # Add [SEP] at the end
    temp_label.append('[SEP]')
    temp_token.append('[SEP]')

    tokenized_texts.append(temp_token)
    word_piece_labels.append(temp_label)

    if 5 > i_inc:
        print("No.%d,len:%d" % (i_inc, len(temp_token)))
        print("texts:%s" % (" ".join(temp_token)))
        i_inc += 1
Example #12
class BertProcessor(object):
    def __init__(self, vocab_path, do_lower_case, min_freq_words=None):
        self.tokenizer = BertTokenizer(vocab_file=vocab_path, do_lower_case=do_lower_case)

    def get_labels(self):
        return [
            '[CLS]', '[SEP]', 'O',
            'B-NIHSS', 'B-Measurement', 'B-TemporalConstraint', 'B-1a_LOC',
            'B-1b_LOCQuestions', 'B-1c_LOCCommands', 'B-2_BestGaze',
            'B-3_Visual', 'B-4_FacialPalsy', 'B-56_Motor', 'B-5_Motor',
            'B-5a_LeftArm', 'B-5b_RightArm', 'B-6_Motor', 'B-6a_LeftLeg',
            'B-6b_RightLeg', 'B-7_LimbAtaxia', 'B-8_Sensory',
            'B-9_BestLanguage', 'B-10_Dysarthria', 'B-11_ExtinctionInattention',
            'I-NIHSS', 'I-Measurement', 'I-TemporalConstraint', 'I-1a_LOC',
            'I-1b_LOCQuestions', 'I-1c_LOCCommands', 'I-2_BestGaze',
            'I-3_Visual', 'I-4_FacialPalsy', 'I-56_Motor', 'I-5_Motor',
            'I-5a_LeftArm', 'I-5b_RightArm', 'I-6_Motor', 'I-6a_LeftLeg',
            'I-6b_RightLeg', 'I-7_LimbAtaxia', 'I-8_Sensory',
            'I-9_BestLanguage', 'I-10_Dysarthria', 'I-11_ExtinctionInattention'
        ]

    def create_examples(self, lines, example_type):
        examples = []
        for i, line in enumerate(lines):
            hadm_id = line['HADM_ID']
            guid = '%s-%s-%d' % (example_type, hadm_id, i)
            sentence = line['token'] # list
            sentence = [' ' if type(t)==float else t for t in sentence ]
            label = line['tags']  # list
            code = line['code']
            # text_a: string. The untokenized text of the first sequence. For single
            # sequence tasks, only this sequence must be specified.
            text_a = ' '.join(sentence) # string
            text_b = None
            examples.append(InputExample(guid=guid, text_a=text_a,text_b=text_b, label=label, code=code))
        return examples

    def create_features(self, examples, max_seq_len):
        label_list = self.get_labels()
        label2id = {label:i for i, label in enumerate(label_list)}

        features = []
        for example_id, example in enumerate(examples): # examples:
            text_list = example.text_a.split(' ')  # string
            label_list = example.label
            code_list = example.code

            new_tokens = [] # tokens
            new_segment_ids =[]
            new_label_ids = []
            new_code = []

            new_tokens.append('[CLS]')
            new_segment_ids.append(0)
            new_label_ids.append(label2id['[CLS]'])
            new_code.append('0')
            
            for text, label, code in zip(text_list, label_list, code_list):
                if text == '<CRLF>':
                    continue
                else:
                    token_list = self.tokenizer.tokenize(text)
                    for idx, token in enumerate(token_list):
                        new_tokens.append(token)
                        new_segment_ids.append(0)
                        if idx == 0:
                            new_label_ids.append(label2id[label])
                            new_code.append(code)
                        elif label == 'O':
                            new_label_ids.append(label2id[label])
                            new_code.append(code)
                        else:
                            temp_l = 'I-'+label.split('-')[1]
                            new_label_ids.append(label2id[temp_l])
                            new_code.append(code)

            assert len(new_tokens) == len(new_segment_ids)
            assert len(new_tokens) == len(new_label_ids)
            assert len(new_tokens) == len(new_code)

            if len(new_tokens) >= max_seq_len :
                new_tokens = new_tokens[0:(max_seq_len-1)]
                new_segment_ids = new_segment_ids[0:(max_seq_len-1)]
                new_label_ids = new_label_ids[0:(max_seq_len-1)]
                new_code = new_code[0:(max_seq_len-1)]

            new_tokens.append('[SEP]')
            new_segment_ids.append(0)
            new_label_ids.append(label2id['[SEP]'])
            new_code.append('0')

            input_ids = self.tokenizer.convert_tokens_to_ids(new_tokens)
            input_mask = [1] * len(input_ids)
            input_len = len(new_label_ids)

            if len(input_ids) < max_seq_len:
                pad_zero = [0] * (max_seq_len - len(input_ids))
                input_ids.extend(pad_zero)
                input_mask.extend(pad_zero)
                new_segment_ids.extend(pad_zero)
                new_label_ids.extend(pad_zero)
                new_code.extend(['0']* len(pad_zero))

            assert len(input_ids) == max_seq_len
            assert len(input_mask) == max_seq_len
            assert len(new_segment_ids) == max_seq_len
            assert len(new_label_ids) == max_seq_len
            assert len(new_code) == max_seq_len

            df_temp = pd.DataFrame({'input_ids':input_ids, 'code':new_code})
            agg_fun = lambda s: ( max(s['code']), s.index.tolist()[0], s.index.tolist()[-1])
            groupby_code = df_temp.groupby('code').apply(agg_fun)
            code_position = {}
            for key, start, end in groupby_code:
                if key != '0':
                    code_position[key] = (start, end)
                else:
                    continue

            features.append(
                InputFeature(
                    input_ids = input_ids,
                    input_mask = input_mask,
                    segment_ids = new_segment_ids,
                    label_id = new_label_ids,
                    input_len = input_len,
                    code = new_code,
                    code_position = code_position
            ))

        return features
Example #13
class BertCoder(object):
    def __init__(self,
                 filename,
                 bert_filename,
                 do_lower_case=False,
                 word_boundaries=False):
        self.filename = filename
        self.bert_filename = bert_filename
        self.do_lower_case = do_lower_case
        self.do_basic_tokenize = False
        # Hack around the fact that we need to know the word boundaries
        self.word_boundaries = word_boundaries

    def __len__(self):
        return self.tokenizer.vocab_size

    def fit(self, tokens):
        # NOTE: We allow the model to use default: do_basic_tokenize.
        # This potentially splits tokens into more tokens apart from subtokens:
        # eg. Mr.Doe -> Mr . D ##oe  (Note that . is not preceded by ##)
        # We take this into account when creating the token_flags in
        # function text_to_token_flags
        self.tokenizer = BertTokenizer(
            self.bert_filename,
            # do_basic_tokenize=self.do_basic_tokenize,
            do_lower_case=self.do_lower_case)
        return self

    def text_to_token_flags(self, text):
        """Return a tuple representing which subtokens are the beginning of a
        token. This is needed for NER using BERT:

        https://arxiv.org/pdf/1810.04805.pdf:

        "We use the representation of the first sub-token as the input to the
        token-level classifier over the NER label set."

        """
        text = self.tokenizer.basic_tokenizer._run_strip_accents(text)
        token_flags = []
        if self.do_lower_case:
            actual_split = text.lower().split()
        else:
            actual_split = text.split()

        bert_tokens = []
        for token in actual_split:
            local_bert_tokens = self.tokenizer.tokenize(token) or ['[UNK]']
            token_flags.append(1)
            for more in local_bert_tokens[1:]:
                token_flags.append(0)
            bert_tokens.extend(local_bert_tokens)
        # assert len(actual_tokens) == 0, [actual_tokens, actual_split, bert_tokens]
        assert len(token_flags) == len(bert_tokens), [
            actual_split, bert_tokens
        ]
        assert sum(token_flags) == len(actual_split)
        return tuple(token_flags)

    def encode(self, tokens):
        # Sometimes tokens include whitespace!
        # for sent_tokens in tokens:
        #     for token in sent_tokens:
        #         if ' ' in token:
        #             print(token)
        # The AIS dataset has a token ". .", for example.
        sent_tokens_no_ws = [[token.replace(' ', '') for token in sent_tokens]
                             for sent_tokens in tokens]
        texts = (' '.join(sent_tokens) for sent_tokens in sent_tokens_no_ws)
        if self.word_boundaries:
            encoded = tuple(self.text_to_token_flags(text) for text in texts)
            # encoded = tuple(tuple(0 if token.startswith('##') else 1
            #                       for token in self.tokenizer.tokenize(text))
            #                 for text in texts)
        else:
            # Adds CLS and SEP
            encoded = tuple(
                tuple(self.tokenizer.encode(text, add_special_tokens=True))
                for text in texts)
        return encoded

    def decode(self, ids):
        if self.word_boundaries:
            return []
        else:
            # NOTE: we only encode a single sentence, so use [0]
            return tuple(
                tuple(
                    self.tokenizer.decode(sent_ids,
                                          clean_up_tokenization_spaces=False)
                    [0].split()) for sent_ids in ids)

    def load(self, filename):
        self.tokenizer = BertTokenizer(
            filename,
            # do_basic_tokenize=self.do_basic_tokenize,
            do_lower_case=self.do_lower_case)
        return self

    def save(self, filename):
        copyfile(self.bert_filename, filename)
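
Hedged sketch of the word-boundary flags produced by text_to_token_flags above, assuming a local BERT vocab file path; the exact flag pattern depends on how the vocabulary splits each word.

coder = BertCoder(filename=None,
                  bert_filename="bert-base-cased-vocab.txt",  # assumed local vocab file
                  word_boundaries=True).fit(tokens=None)
print(coder.text_to_token_flags("Johanson lives here"))
# e.g. (1, 0, 0, 1, 1) if "Johanson" is split into three WordPieces
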
Example #14
def tag_sent(text):
    # initialize variables
    num_tags = 24  # depends on the labelling scheme
    max_len = 45
    vocabulary = "bert_models/vocab.txt"
    bert_out_address = 'bert/model'

    tokenizer = BertTokenizer(vocab_file=vocabulary, do_lower_case=False)

    model = BertForTokenClassification.from_pretrained(bert_out_address,
                                                       num_labels=num_tags)

    f = open('se_data/tags.txt')
    lines = f.readlines()
    f.close()

    tag2idx = {}
    for line in lines:
        key = line.split()[0]
        val = line.split()[1]
        tag2idx[key.strip()] = int(val.strip())

    tag2name = {tag2idx[key]: key for key in tag2idx.keys()}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    if torch.cuda.is_available():
        model.cuda()
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

    model.eval()

    tokenized_texts = []
    word_piece_labels = []
    i_inc = 0

    temp_token = []

    # Add [CLS] at the front
    temp_token.append('[CLS]')

    for word in nltk.word_tokenize(text):
        token_list = tokenizer.tokenize(word)
        for m, token in enumerate(token_list):
            temp_token.append(token)

    # Add [SEP] at the end
    temp_token.append('[SEP]')

    tokenized_texts.append(temp_token)

    #if 5 > i_inc:
    #print("No.%d,len:%d"%(i_inc,len(temp_token)))
    #print("texts:%s"%(" ".join(temp_token)))
    #i_inc +=1

    input_ids = pad_sequences(
        [tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
        maxlen=max_len,
        dtype="long",
        truncating="post",
        padding="post")

    attention_masks = [[int(i > 0) for i in ii] for ii in input_ids]
    #attention_masks[0];

    segment_ids = [[0] * len(input_id) for input_id in input_ids]
    #segment_ids[0];

    tr_inputs = torch.tensor(input_ids).to(device)
    tr_masks = torch.tensor(attention_masks).to(device)
    tr_segs = torch.tensor(segment_ids).to(device)

    outputs = model(
        tr_inputs,
        token_type_ids=None,
        attention_mask=tr_masks,
    )

    #tr_masks = tr_masks.to('cpu').numpy()

    logits = outputs[0]

    # Get NER predict result
    logits = torch.argmax(F.log_softmax(logits, dim=2), dim=2)
    logits = logits.detach().cpu().numpy()

    #print(logits)
    #print(len(logits[0]))
    tags_t = [tag2name[t] for t in logits[0]]

    #print(nltk.word_tokenize(text))
    c = len(tokenized_texts[0])
    #print(tags_t[:c])
    return tokenized_texts[0][1:len(temp_token) -
                              1], tags_t[:c][1:len(tags_t[:c]) - 1]
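
Hypothetical invocation of tag_sent above, assuming the fine-tuned checkpoint in bert/model, the vocabulary in bert_models/vocab.txt and se_data/tags.txt all exist locally; the input sentence is made up.

tokens, tags = tag_sent("The system shall log every failed login attempt .")
for token, tag in zip(tokens, tags):
    print(token, tag)
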