def __init__(self, t: PreTrainedTokenizer, args, file_path: str, block_size=512):
        assert os.path.isfile(file_path)
        logger.info("Creating features from dataset file at %s", file_path)
        
        # -------------------------- CHANGES START
        bert_vocab_path = os.path.join(args.tokenizer_name, "vocab.txt")
        if os.path.exists(bert_vocab_path):
            logger.info("Loading BERT tokenizer")
            from tokenizers import BertWordPieceTokenizer
            tokenizer = BertWordPieceTokenizer(bert_vocab_path, handle_chinese_chars=False, lowercase=False)
            tokenizer.enable_truncation(max_length=512)
        else:
            from tokenizers import ByteLevelBPETokenizer
            from tokenizers.processors import BertProcessing
            logger.info("Loading RoBERTa tokenizer")
            
            tokenizer = ByteLevelBPETokenizer(
                os.path.join(args.tokenizer_name, "vocab.json"),
                os.path.join(args.tokenizer_name, "merges.txt")
            )
            tokenizer._tokenizer.post_processor = BertProcessing(
                ("</s>", tokenizer.token_to_id("</s>")),
                ("<s>", tokenizer.token_to_id("<s>")),
            )
            tokenizer.enable_truncation(max_length=512)

        logger.info("Reading file %s", file_path)
        with open(file_path, encoding="utf-8") as f:
            lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

        logger.info("Running tokenization")
        self.examples = tokenizer.encode_batch(lines)
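
A minimal sketch of how the encoded examples above might be consumed downstream; the wrapper class and its name are illustrative, assuming each element of self.examples is a tokenizers.Encoding returned by encode_batch():

import torch
from torch.utils.data import Dataset

class EncodedLineDataset(Dataset):
    """Illustrative wrapper: one item per encoded line."""

    def __init__(self, examples):
        self.examples = examples  # list of tokenizers.Encoding objects

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        # .ids holds the token ids after truncation / post-processing
        return torch.tensor(self.examples[i].ids, dtype=torch.long)
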
Example #2
class CustomDataset:
    def __init__(self, sentences, bert_path, padding=140):
        self.sentences = sentences
        self.tokenizer = BertWordPieceTokenizer(f'{bert_path}/vocab.txt',
                                                lowercase=True)
        self.padding = padding

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        s = self.sentences[idx]  #['[CLS]', *self.sentences[idx], '[SEP]']

        to_ignore_none = lambda x: x if x is not None else 0
        to_id = lambda x: to_ignore_none(self.tokenizer.token_to_id(x))

        n_pads = self.padding - len(s)
        x = list(map(to_id, s))
        assert (len(x) == len(s))
        x = x + [0 for _ in range(n_pads)]
        return torch.tensor(x), n_pads  #, torch.tensor([])
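
A possible usage sketch for CustomDataset; the vocab directory and the pre-tokenized sentences are placeholders, since __getitem__ maps token_to_id over already-split WordPiece tokens:

from torch.utils.data import DataLoader

sentences = [['[CLS]', 'hello', 'world', '[SEP]'],
             ['[CLS]', 'another', 'example', '[SEP]']]
dataset = CustomDataset(sentences, bert_path='path/to/bert', padding=140)
loader = DataLoader(dataset, batch_size=2)
for ids, n_pads in loader:
    print(ids.shape, n_pads)  # torch.Size([2, 140]) and the pad counts per sample
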
Example #3
def train_tokenizer(filename, params):
    """
    Train a BertWordPieceTokenizer with the specified params and save it
    """
    # Get tokenization params
    save_location = params["tokenizer_path"]
    max_length = params["max_length"]
    min_freq = params["min_freq"]
    vocabsize = params["vocab_size"]

    # `lowercase` must be passed to the constructor; assigning `do_lower_case`
    # afterwards has no effect on a BertWordPieceTokenizer.
    tokenizer = BertWordPieceTokenizer(lowercase=False)
    special_tokens = ["[S]", "[PAD]", "[/S]", "[UNK]", "[MASK]", "[SEP]", "[CLS]"]
    tokenizer.train(files=[filename], vocab_size=vocabsize, min_frequency=min_freq, special_tokens=special_tokens)

    tokenizer._tokenizer.post_processor = BertProcessing(("[SEP]", tokenizer.token_to_id("[SEP]")), ("[CLS]", tokenizer.token_to_id("[CLS]")),)
    tokenizer.enable_truncation(max_length=max_length)

    print("Saving tokenizer ...")
    if not os.path.exists(save_location):
        os.makedirs(save_location)
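    # Note: recent versions of the `tokenizers` library expect a file path in save();
    # save_model(save_location) is the directory-based call that writes vocab.txt.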
    tokenizer.save(save_location)
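
An illustrative call, with made-up parameter values mirroring the keys read at the top of train_tokenizer:

params = {
    "tokenizer_path": "tokenizer_out",
    "max_length": 128,
    "min_freq": 2,
    "vocab_size": 30000,
}
train_tokenizer("corpus.txt", params)
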
Example #4
def model_predict(inp, model=[]):
    # Called by the default Python environment.
    # inp -- a string by default, but the type can be specified
    # in "input_type.py".
    # model is optional and is the return value of load_model.
    # Should return JSON.
    
    # predict all tokens
    text = inp.pred
    tokenizer = BertTokenizer.from_pretrained('models/src/models/vocab_swebert.txt', do_lower_case=False)
#     input_ids = tokenizer(text.lower())["input_ids"]
#     tokenizer = BertTokenizer.from_pretrained('src/models/vocab_swebert.txt', lowercase=True, strip_accents=False)
    bert_word_piece_tokenizer = BertWordPieceTokenizer("models/src/models/vocab_swebert.txt", lowercase=True, strip_accents=False)
    output = bert_word_piece_tokenizer.encode(text)
    tokens = output.tokens
    indexed_tokens = output.ids
    input_ids = indexed_tokens
    print(tokens)
    
    # mask one of the tokens
    masked_index = inp.msk_ind
    tokens[masked_index] = '[MASK]'
    print(tokens)
#     input_ids[masked_index] = tokenizer.convert_tokens_to_ids('[MASK]')
    indexed_tokens[masked_index] = bert_word_piece_tokenizer.token_to_id('[MASK]')
    print(input_ids)

    # do predictions
    with torch.no_grad():  # deactivate the autograd engine to reduce memory usage and speed up inference
        outputs = model(torch.tensor([input_ids]))
    predictions = outputs[0]
    
    predicted_index_top5 = torch.argsort(predictions[0, masked_index], descending=True)[:5]
    predicted_token = tokenizer.convert_ids_to_tokens(predicted_index_top5.tolist())
#   predicted_index_top5
    print(predicted_token)
    return {"result": predicted_token}
Example #5
class Tokenizer(object):
    def __init__(self, vocab_path, do_lower_case=True):
        if BertWordPieceTokenizer:
            self.tokenizer = BertWordPieceTokenizer(
                vocab_path,
                lowercase=do_lower_case,
            )
        else:
            self.tokenizer = tokenization.FullTokenizer(
                vocab_path, do_lower_case=do_lower_case)
        self._do_lower_case = do_lower_case

    def tokenize(self, input_text):
        if BertWordPieceTokenizer:
            return self.tokenizer.encode(input_text,
                                         add_special_tokens=False).tokens
        else:
            return self.tokenizer.tokenize(input_text)

    def encode(self, input_text, add_special_tokens=False):
        input_tokens = self.tokenize(input_text)
        if add_special_tokens:
            input_tokens = ['[CLS]'] + input_tokens + ['[SEP]']
        input_token_ids = self.convert_tokens_to_ids(input_tokens)
        return input_token_ids

    def padded_to_ids(self, input_text, max_length):
        if len(input_text) > max_length:
            return input_text[:max_length]
        else:
            return input_text + [0] * (max_length - len(input_text))

    def convert_tokens_to_ids(self, input_tokens):
        if BertWordPieceTokenizer:
            token_ids = [
                self.tokenizer.token_to_id(token) for token in input_tokens
            ]
        else:
            token_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)
        return token_ids

    def customize_tokenize(self, input_text):
        temp_x = ""
        for c in input_text:
            if self._is_cjk_character(c) or self._is_punctuation(
                    c) or self._is_space(c) or self._is_control(c):
                temp_x += " " + c + " "
            else:
                temp_x += c
        return temp_x.split()

    def convert_ids_to_tokens(self, input_ids):
        if BertWordPieceTokenizer:
            input_tokens = [
                self.tokenizer.id_to_token(ids) for ids in input_ids
            ]
        else:
            input_tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        return input_tokens

    def decode(self, input_tokens):
        text, flag = '', False
        for i, token in enumerate(input_tokens):
            if token[:2] == '##':
                text += token[2:]
            elif len(token) == 1 and self._is_cjk_character(token):
                text += token
            elif len(token) == 1 and self._is_punctuation(token):
                text += token
                text += ' '
            elif i > 0 and self._is_cjk_character(text[-1]):
                text += token
            else:
                text += ' '
                text += token
        text = re.sub(' +', ' ', text)
        text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
        punctuation = self._cjk_punctuation() + '+-/={(<['
        punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
        punctuation_regex = '(%s) ' % punctuation_regex
        text = re.sub(punctuation_regex, '\\1', text)
        text = re.sub(r'(\d\.) (\d)', r'\1\2', text)

        return text.strip()

    @staticmethod
    def stem(token):
        """
    """
        if token[:2] == '##':
            return token[2:]
        else:
            return token

    @staticmethod
    def _is_space(ch):
        """
    """
        return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
            unicodedata.category(ch) == 'Zs'

    @staticmethod
    def _is_punctuation(ch):
        """
    """
        code = ord(ch)
        return 33 <= code <= 47 or \
            58 <= code <= 64 or \
            91 <= code <= 96 or \
            123 <= code <= 126 or \
            unicodedata.category(ch).startswith('P')

    @staticmethod
    def _cjk_punctuation():
        return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'

    @staticmethod
    def _is_cjk_character(ch):
        """
    reference:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    """
        code = ord(ch)
        return 0x4E00 <= code <= 0x9FFF or \
            0x3400 <= code <= 0x4DBF or \
            0x20000 <= code <= 0x2A6DF or \
            0x2A700 <= code <= 0x2B73F or \
            0x2B740 <= code <= 0x2B81F or \
            0x2B820 <= code <= 0x2CEAF or \
            0xF900 <= code <= 0xFAFF or \
            0x2F800 <= code <= 0x2FA1F

    @staticmethod
    def _is_control(ch):
        """
    """
        return unicodedata.category(ch) in ('Cc', 'Cf')

    @staticmethod
    def _is_special(ch):
        """
    """
        return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')

    def rematch(self, text, tokens):
        if is_py2:
            text = unicode(text)

        if self._do_lower_case:
            text = text.lower()

        normalized_text, char_mapping = '', []
        for i, ch in enumerate(text):
            if self._do_lower_case:
                ch = unicodedata.normalize('NFD', ch)
                ch = ''.join(
                    [c for c in ch if unicodedata.category(c) != 'Mn'])
            ch = ''.join([
                c for c in ch
                if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
            ])
            normalized_text += ch
            char_mapping.extend([i] * len(ch))

        text, token_mapping, offset = normalized_text, [], 0
        for token in tokens:
            if self._is_special(token):
                token_mapping.append([])
            else:
                token = self.stem(token)
                start = text[offset:].index(token) + offset
                end = start + len(token)
                token_mapping.append(char_mapping[start:end])
                offset = end

        return token_mapping
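
A brief usage sketch of the wrapper above; 'vocab.txt' is a placeholder path:

tok = Tokenizer('vocab.txt', do_lower_case=True)
tokens = tok.tokenize('Hello, world!')
ids = tok.padded_to_ids(tok.encode('Hello, world!', add_special_tokens=True), max_length=16)
spans = tok.rematch('Hello, world!', tokens)  # character indices covered by each token
print(tok.decode(tokens))
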
Example #6
class BertClassifierDataBuilder(ClassifierDataBuilder):
    def __init__(self, config):
        super().__init__(config)
        self.preprocessor = BertWordPieceTokenizer(config.vocab_file,
                                                   lowercase=config.lower_case)
        self.cls_id = self.preprocessor.token_to_id('[CLS]')
        self.sep_id = self.preprocessor.token_to_id('[SEP]')
        self.max_seq_length = config.max_seq_length

    def build_one_input_ids(self, seq1_codes, seq2_codes):
        seq1_ids = seq1_codes.ids
        seq2_ids = seq2_codes.ids
        seq1_tokens = seq1_codes.tokens
        seq2_tokens = seq2_codes.tokens
        """Truncates a sequence pair in place to the maximum length."""
        # This is a simple heuristic which will always truncate the longer
        # sequence one token at a time. This makes more sense than
        # truncating an equal percent of tokens from each, since if one
        # sequence is very short then each token that's truncated likely
        # contains more information than a longer sequence.
        while True:
            total_length = len(seq1_tokens) + len(seq2_tokens)
            if total_length <= self.max_seq_length - 3:
                # Account for [CLS], [SEP], [SEP] with "- 3"
                # logger.info('truncation finished.')
                break
            if len(seq1_tokens) > len(seq2_tokens):
                seq1_tokens.pop()
                seq1_ids.pop()
            else:
                seq2_tokens.pop()
                seq2_ids.pop()

        first_part_ids = [self.cls_id] + seq1_ids + [self.sep_id]
        second_part_ids = seq2_ids + [self.sep_id]
        input_ids = first_part_ids + second_part_ids
        segment_ids = [0] * len(first_part_ids) + [1] * len(second_part_ids)
        # pad to max_seq_length
        input_len = len(input_ids)
        input_ids += [0] * (self.max_seq_length - input_len)
        segment_ids += [0] * (self.max_seq_length - input_len)
        return input_ids, segment_ids, seq1_tokens, seq2_tokens

    def build_ids(self, seq1_list, seq2_codes):
        part1_ids = []
        part2_ids = []
        seq1_tokens = []
        seq2_tokens = []
        for s1 in seq1_list:
            s1_codes = self.preprocessor.encode(s1, add_special_tokens=False)
            one_output = self.build_one_input_ids(s1_codes, seq2_codes)
            p1_ids, sp2_ids, s1_tokens, s2_tokens = one_output
            part1_ids.extend(p1_ids)
            part2_ids.extend(sp2_ids)
            seq1_tokens.extend(s1_tokens)
            seq2_tokens.extend(s2_tokens)
        return part1_ids, part2_ids, seq1_tokens, seq2_tokens

    def set_ids(self, feature_dict, one_output):
        input_ids, segment_ids, seq1_tokens, seq2_tokens = one_output
        feature_dict['input_ids'] = input_ids
        feature_dict['segment_ids'] = segment_ids
        feature_dict['seq1_tokens'] = seq1_tokens
        feature_dict['seq2_tokens'] = seq2_tokens
        return feature_dict

    def input_to_feature(self, one_input):
        if len(one_input) == 3:
            eid, seq1, seq2 = one_input
            label = None
        elif len(one_input) == 4:
            eid, seq1, seq2, label = one_input
        else:
            raise ValueError('number of inputs not valid error: {}'.format(
                len(one_input)))

        seq2_codes = self.preprocessor.encode(seq2, add_special_tokens=False)
        ans_cls = self.may_process_label(label, None)
        feature_dict = {
            'feature_id': eid,
            'label': label,
            'cls': ans_cls,
            'seq1': seq1,
            'seq2': seq2,
        }
        if isinstance(seq1, list):  # for race multiple choice
            seq1_list = seq1
        else:
            seq1_list = [seq1]  # for boolq, mnli, qqp
        one_output = self.build_ids(seq1_list, seq2_codes)
        feature_dict = self.set_ids(feature_dict, one_output)
        self.num_examples += 1
        feature_id = '{}_{}'.format(self.num_examples,
                                    feature_dict['feature_id'])
        feature_dict['feature_id'] = feature_id
        feature = self.feature(**feature_dict)
        yield feature

    @staticmethod
    def two_seq_str_fn(feat):
        seq1_str = [
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format('seq1_idx', 'token',
                                                 'input_idx', 'input_id')
        ]
        seq1_str.extend([
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format(q_idx, q_token, q_idx + 1,
                                                 feat.input_ids[q_idx + 1])
            for q_idx, q_token in enumerate(feat.seq1_tokens)
        ])

        seq2_str = [
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format('seq2_idx', 'token',
                                                 'input_idx', 'input_id')
        ]
        seq1_len = len(feat.seq1_tokens)
        seq2_str.extend([
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format(
                c_idx, c_token, c_idx + 2 + seq1_len,
                feat.input_ids[c_idx + 2 + seq1_len])
            for c_idx, c_token in enumerate(feat.seq2_tokens)
        ])

        return seq1_str, seq2_str
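
For reference, a sketch of the id layout build_one_input_ids produces (token ids invented; only the layout matters):

# input_ids   = [CLS, a1, a2, SEP, b1, b2, b3, SEP, 0, 0, ...]
# segment_ids = [0,   0,  0,  0,   1,  1,  1,  1,   0, 0, ...]
# Segment 0 covers [CLS] + seq1 + [SEP]; segment 1 covers seq2 + [SEP];
# both vectors are then zero-padded out to max_seq_length.
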
Example #7
class Reader(object):
    def __init__(self,
                 bert_model: str,
                 tokenizer: BaseTokenizer = None,
                 cls: str = "[CLS]",
                 sep: str = "[SEP]",
                 threshold=6):

        self.tokenizer: BaseTokenizer = tokenizer
        self.cls = cls
        self.sep = sep
        if self.tokenizer is None:
            vocab_path: str = "tokenization/" + bert_model + ".txt"
            self.tokenizer = BertWordPieceTokenizer(vocab_path,
                                                    lowercase="-cased"
                                                    not in bert_model)

        self.threshold = threshold
        self.subword_alphabet: Optional[Alphabet] = None
        self.label_alphabet: Optional[Alphabet] = None

        self.train: Optional[List[SentInst]] = None
        self.dev: Optional[List[SentInst]] = None
        self.test: Optional[List[SentInst]] = None

    def _read_file(self, filename: str, mode: str = 'train') -> List[SentInst]:
        sent_list = []
        max_len = 0
        num_thresh = 0
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line == "":  # last few blank lines
                    break

                raw_tokens = line.split(' ')
                tokens = raw_tokens
                chars = [list(t) for t in raw_tokens]

                entities = next(f).strip()
                if entities == "":  # no entities
                    sent_inst = SentInst(tokens, chars, [])
                else:
                    entity_list = []
                    entities = entities.split("|")
                    for item in entities:
                        pointers, label = item.split()
                        pointers = pointers.split(",")
                        if int(pointers[1]) > len(tokens):
                            pdb.set_trace()
                        span_len = int(pointers[1]) - int(pointers[0])
                        if span_len < 0:
                            print("Warning! span_len < 0")
                            continue
                        if span_len > max_len:
                            max_len = span_len
                        if span_len > self.threshold:
                            num_thresh += 1

                        new_entity = (int(pointers[0]), int(pointers[1]),
                                      label)
                        # may be duplicate entities in some datasets
                        if (mode == 'train' and new_entity
                                not in entity_list) or (mode != 'train'):
                            entity_list.append(new_entity)

                    # assert len(entity_list) == len(set(entity_list)) # check duplicate
                    sent_inst = SentInst(tokens, chars, entity_list)
                assert next(f).strip() == ""  # separating line

                sent_list.append(sent_inst)
        print("Max length: {}".format(max_len))
        print("Threshold {}: {}".format(self.threshold, num_thresh))
        return sent_list

    def _gen_dic(self) -> None:
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        vocab = [
            self.tokenizer.id_to_token(idx)
            for idx in range(self.tokenizer.get_vocab_size())
        ]
        self.subword_alphabet = Alphabet(vocab, 0)
        self.label_alphabet = Alphabet(label_set, 0)

    @staticmethod
    def _pad_batches(input_ids_batches: List[List[List[int]]],
                     first_subtokens_batches: List[List[List[int]]]) \
            -> Tuple[List[List[List[int]]],
                     List[List[List[int]]],
                     List[List[List[bool]]]]:

        padded_input_ids_batches = []
        input_mask_batches = []
        mask_batches = []

        all_batches = list(zip(input_ids_batches, first_subtokens_batches))
        for input_ids_batch, first_subtokens_batch in all_batches:

            batch_len = len(input_ids_batch)
            max_subtokens_num = max(
                [len(input_ids) for input_ids in input_ids_batch])
            max_sent_len = max([
                len(first_subtokens)
                for first_subtokens in first_subtokens_batch
            ])

            padded_input_ids_batch = []
            input_mask_batch = []
            mask_batch = []

            for i in range(batch_len):

                subtokens_num = len(input_ids_batch[i])
                sent_len = len(first_subtokens_batch[i])

                padded_subtoken_vec = input_ids_batch[i].copy()
                padded_subtoken_vec.extend([0] *
                                           (max_subtokens_num - subtokens_num))
                input_mask = [1] * subtokens_num + [0] * (max_subtokens_num -
                                                          subtokens_num)
                mask = [True] * sent_len + [False] * (max_sent_len - sent_len)

                padded_input_ids_batch.append(padded_subtoken_vec)
                input_mask_batch.append(input_mask)
                mask_batch.append(mask)

            padded_input_ids_batches.append(padded_input_ids_batch)
            input_mask_batches.append(input_mask_batch)
            mask_batches.append(mask_batch)

        return padded_input_ids_batches, input_mask_batches, mask_batches

    def get_batches(self, sentences: List[SentInst], batch_size: int) -> Tuple:
        subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        first_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        last_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        label_dic_dic = defaultdict(lambda: defaultdict(list))

        this_input_ids_batches = []
        this_first_subtokens_batches = []
        this_last_subtokens_batches = []
        this_label_batches = []

        for sentInst in sentences:
            subtoken_vec = []
            first_subtoken_vec = []
            last_subtoken_vec = []
            subtoken_vec.append(self.tokenizer.token_to_id(self.cls))
            for t in sentInst.tokens:
                encoding = self.tokenizer.encode(t)
                ids = [
                    v for v, mask in zip(encoding.ids,
                                         encoding.special_tokens_mask)
                    if mask == 0
                ]
                first_subtoken_vec.append(len(subtoken_vec))
                subtoken_vec.extend(ids)
                last_subtoken_vec.append(len(subtoken_vec))
            subtoken_vec.append(self.tokenizer.token_to_id(self.sep))

            label_list = [(u[0], u[1], self.label_alphabet.get_index(u[2]))
                          for u in sentInst.entities]

            subtoken_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(subtoken_vec)
            first_subtoken_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(first_subtoken_vec)
            last_subtoken_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(last_subtoken_vec)
            label_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(label_list)

        input_ids_batches = []
        first_subtokens_batches = []
        last_subtokens_batches = []
        label_batches = []
        for length1 in sorted(subtoken_dic_dic.keys(), reverse=True):
            for length2 in sorted(subtoken_dic_dic[length1].keys(),
                                  reverse=True):
                input_ids_batches.extend(subtoken_dic_dic[length1][length2])
                first_subtokens_batches.extend(
                    first_subtoken_dic_dic[length1][length2])
                last_subtokens_batches.extend(
                    last_subtoken_dic_dic[length1][length2])
                label_batches.extend(label_dic_dic[length1][length2])

        for i in range(0, len(input_ids_batches), batch_size):
            this_input_ids_batches.append(input_ids_batches[i:i + batch_size])
        for i in range(0, len(first_subtokens_batches), batch_size):
            this_first_subtokens_batches.append(
                first_subtokens_batches[i:i + batch_size])
        for i in range(0, len(last_subtokens_batches), batch_size):
            this_last_subtokens_batches.append(
                last_subtokens_batches[i:i + batch_size])
        for i in range(0, len(label_batches), batch_size):
            this_label_batches.append(label_batches[i:i + batch_size])

        this_input_ids_batches, this_input_mask_batches, this_mask_batches \
            = self._pad_batches(this_input_ids_batches, this_first_subtokens_batches)

        return (this_input_ids_batches, this_input_mask_batches,
                this_first_subtokens_batches, this_last_subtokens_batches,
                this_label_batches, this_mask_batches)

    def to_batch(self, batch_size: int) -> Tuple:
        ret_list = []
        for sent_list in [self.train, self.dev, self.test]:
            ret_list.append(self.get_batches(sent_list, batch_size))
        return tuple(ret_list)

    def read_all_data(self, file_path: str, train_file: str, dev_file: str,
                      test_file: str) -> None:
        self.train = self._read_file(file_path + train_file)
        self.dev = self._read_file(file_path + dev_file, mode='dev')
        self.test = self._read_file(file_path + test_file, mode='test')
        self._gen_dic()

    def debug_single_sample(self, subtoken: List[int],
                            label_list: List[Tuple[int, int, int]]) -> None:
        print(" ".join(
            [self.subword_alphabet.get_instance(t) for t in subtoken]))
        for label in label_list:
            print(label[0], label[1],
                  self.label_alphabet.get_instance(label[2]))
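
An illustrative driver for the Reader above; the file names are placeholders and a matching tokenization/<bert_model>.txt vocab file is assumed to exist:

reader = Reader(bert_model='bert-base-cased')
reader.read_all_data('data/', 'train.txt', 'dev.txt', 'test.txt')
train_batches, dev_batches, test_batches = reader.to_batch(batch_size=32)
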
Example #8
File: numerize.py  Project: alexa/ramen
def numerize(vocab_path, input_path, bin_path):
    tokenizer = BertWordPieceTokenizer(vocab_path,
                                       unk_token=UNK_TOKEN,
                                       sep_token=SEP_TOKEN,
                                       cls_token=CLS_TOKEN,
                                       pad_token=PAD_TOKEN,
                                       mask_token=MASK_TOKEN,
                                       lowercase=False,
                                       strip_accents=False)
    sentences = []
    with open(input_path, 'r') as f:
        batch_stream = []
        for i, line in enumerate(f):
            batch_stream.append(line)
            if (i + 1) % 1000 == 0:
                res = tokenizer.encode_batch(batch_stream)
                batch_stream = []
                # flatten the list (dropping each sentence's leading [CLS] id)
                for s in res:
                    sentences.extend(s.ids[1:])
            if (i + 1) % 100000 == 0:
                print(f'processed {i + 1} lines')
        # encode whatever is left in the final partial batch
        if batch_stream:
            for s in tokenizer.encode_batch(batch_stream):
                sentences.extend(s.ids[1:])

    print('convert the data to numpy')

    # convert data to numpy format in uint16
    if tokenizer.get_vocab_size() < 1 << 16:
        sentences = np.uint16(sentences)
    else:
        assert tokenizer.get_vocab_size() < 1 << 31
        sentences = np.int32(sentences)

    # save special tokens for later processing
    sep_index = tokenizer.token_to_id(SEP_TOKEN)
    cls_index = tokenizer.token_to_id(CLS_TOKEN)
    unk_index = tokenizer.token_to_id(UNK_TOKEN)
    mask_index = tokenizer.token_to_id(MASK_TOKEN)
    pad_index = tokenizer.token_to_id(PAD_TOKEN)

    # sanity check
    assert sep_index == SEP_INDEX
    assert cls_index == CLS_INDEX
    assert unk_index == UNK_INDEX
    assert pad_index == PAD_INDEX
    assert mask_index == MASK_INDEX

    print('collect statistics')
    # collect some statistics of the dataset
    n_unks = (sentences == unk_index).sum()
    n_toks = len(sentences)
    p_unks = n_unks * 100. / n_toks
    n_seqs = (sentences == sep_index).sum()
    print(
        f'| {n_seqs} sentences - {n_toks} tokens - {p_unks:.2f}% unknown words'
    )

    # print some statistics
    data = {
        'sentences': sentences,
        'sep_index': sep_index,
        'cls_index': cls_index,
        'unk_index': unk_index,
        'pad_index': pad_index,
        'mask_index': mask_index
    }

    torch.save(data, bin_path, pickle_protocol=4)
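
An illustrative call and round-trip; the paths are placeholders, and the module-level *_TOKEN / *_INDEX constants are assumed to be defined alongside this function:

numerize('vocab.txt', 'corpus.txt', 'corpus.bin')
data = torch.load('corpus.bin')
print(len(data['sentences']), data['sep_index'])
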
Example #9
class Tweets(Dataset):
    def __init__(self, device='cpu', pad=150, test=False, N=4):
        self.samples = []
        self.pad = pad

        self.tokenizer = BertWordPieceTokenizer(
            "./data/bert-base-uncased-vocab.txt",
            lowercase=True,
            clean_text=True)

        self.tokenizer.enable_padding(max_length=pad -
                                      1)  # -1 for sentiment token

        self.tokenizer.add_special_tokens(['[POS]'])
        self.tokenizer.add_special_tokens(['[NEG]'])
        self.tokenizer.add_special_tokens(['[NEU]'])
        self.vocab = self.tokenizer.get_vocab()

        self.sent_t = {
            'positive': self.tokenizer.token_to_id('[POS]'),
            'negative': self.tokenizer.token_to_id('[NEG]'),
            'neutral': self.tokenizer.token_to_id('[NEU]')
        }

        self.pos_set = {'UNK': 0}
        all_pos = load('help/tagsets/upenn_tagset.pickle').keys()

        for i, p in enumerate(all_pos):
            self.pos_set[p] = i + 1

        self.tweet_tokenizer = TweetTokenizer()

        data = None
        if test is True:
            data = pd.read_csv(TEST_PATH).values
            for row in data:
                tid, tweet, sentiment = tuple(row)

                pos_membership = [0] * len(tweet)

                pos_tokens = self.tweet_tokenizer.tokenize(tweet)
                pos = nltk.pos_tag(pos_tokens)
                offset = 0

                for i, token in enumerate(pos_tokens):
                    start = tweet.find(token, offset)
                    end = start + len(token)
                    if pos[i][1] in self.pos_set:
                        pos_membership[start:end] = [self.pos_set[pos[i][1]]
                                                     ] * len(token)
                    offset += len(token)

                tokens = self.tokenizer.encode(tweet)
                word_to_index = tokens.ids
                offsets = tokens.offsets

                token_pos = [0] * len(word_to_index)
                # get pos info
                for i, (s, e) in enumerate(offsets):
                    if word_to_index[i] == 0 or word_to_index[
                            i] == 101 or word_to_index[i] == 102:
                        pass
                    elif s != e:
                        sub = pos_membership[s:e]
                        token_pos[i] = max(set(sub), key=sub.count)

                token_pos = [0] + token_pos
                word_to_index = [self.sent_t[sentiment]] + word_to_index
                offsets = [(0, 0)] + offsets
                offsets = np.array([[off[0], off[1]] for off in offsets])
                word_to_index = np.array(word_to_index)
                token_pos = np.array(token_pos)

                self.samples.append({
                    'tid': tid,
                    'sentiment': sentiment,
                    'tweet': word_to_index,
                    'offsets': offsets,
                    'raw_tweet': tweet,
                    'pos': token_pos
                })

        else:

            data = pd.read_csv(TRAIN_PATH).values
            if N > 0:
                data = augment_n(data, N=N)

            for row in data:
                tid, tweet, selection, sentiment = tuple(row)

                char_membership = [0] * len(tweet)
                pos_membership = [0] * len(tweet)
                si = tweet.find(selection)
                if si < 0:
                    char_membership[0:] = [1] * len(char_membership)
                else:
                    char_membership[si:si +
                                    len(selection)] = [1] * len(selection)

                pos_tokens = self.tweet_tokenizer.tokenize(tweet)
                pos = nltk.pos_tag(pos_tokens)
                offset = 0

                for i, token in enumerate(pos_tokens):
                    start = tweet.find(token, offset)
                    end = start + len(token)
                    if pos[i][1] in self.pos_set:
                        pos_membership[start:end] = [self.pos_set[pos[i][1]]
                                                     ] * len(token)
                    offset += len(token)

                tokens = self.tokenizer.encode(tweet)
                word_to_index = tokens.ids
                offsets = tokens.offsets

                token_membership = [0] * len(word_to_index)
                token_pos = [0] * len(word_to_index)

                # Inclusive indices
                start = None
                end = None
                for i, (s, e) in enumerate(offsets):
                    if word_to_index[i] == 0 or word_to_index[
                            i] == 101 or word_to_index[i] == 102:
                        token_membership[i] = -1
                    elif sum(char_membership[s:e]) > 0:
                        token_membership[i] = 1
                        if start is None:
                            start = i + 1
                        end = i + 1

                # get pos info
                for i, (s, e) in enumerate(offsets):
                    if word_to_index[i] == 0 or word_to_index[
                            i] == 101 or word_to_index[i] == 102:
                        pass
                    elif s != e:
                        sub = pos_membership[s:e]
                        token_pos[i] = max(set(sub), key=sub.count)

                if start is None:
                    print("Data Point Error")
                    print(tweet)
                    print(selection)
                    continue
                # token_membership = torch.LongTensor(token_membership).to(device)
                word_to_index = [self.sent_t[sentiment]] + word_to_index
                token_membership = [-1] + token_membership
                offsets = [(0, 0)] + offsets
                token_pos = [0] + token_pos

                offsets = np.array([[off[0], off[1]] for off in offsets])
                word_to_index = np.array(word_to_index)
                token_membership = np.array(token_membership).astype('float')
                token_pos = np.array(token_pos)

                if tid is None:
                    raise Exception('None field detected')
                if sentiment is None:
                    raise Exception('None field detected')
                if word_to_index is None:
                    raise Exception('None field detected')
                if token_membership is None:
                    raise Exception('None field detected')
                if selection is None:
                    raise Exception('None field detected')
                if tweet is None:
                    raise Exception('None field detected')
                if start is None:
                    raise Exception('None field detected')
                if end is None:
                    raise Exception('None field detected')
                if offsets is None:
                    raise Exception('None field detected')

                self.samples.append({
                    'tid': tid,
                    'sentiment': sentiment,
                    'tweet': word_to_index,
                    'selection': token_membership,
                    'raw_selection': selection,
                    'raw_tweet': tweet,
                    'start': start,
                    'end': end,
                    'offsets': offsets,
                    'pos': token_pos
                })

    def get_splits(self, val_size=.3):
        N = len(self.samples)
        indices = np.random.permutation(N)
        split = int(N * (1 - val_size))
        train_indices = indices[0:split]
        valid_indices = indices[split:]
        return train_indices, valid_indices

    def k_folds(self, k=5):
        N = len(self.samples)
        indices = np.random.permutation(N)
        return np.array_split(indices, k)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        try:
            return self.samples[idx]
        except TypeError:
            pass
        return [self.samples[i] for i in idx]
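
A possible usage sketch; it assumes TRAIN_PATH and the vocab file referenced in __init__ are available locally:

dataset = Tweets(pad=150, test=False, N=0)
train_idx, valid_idx = dataset.get_splits(val_size=0.3)
sample = dataset[int(train_idx[0])]
print(sample['sentiment'], sample['tweet'].shape, sample['start'], sample['end'])
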
Example #10
class BertQaDataBuilder(QaDataBuilder):
    def __init__(self, config):
        super().__init__(config)
        self.preprocessor = BertWordPieceTokenizer(config.vocab_file,
                                                   lowercase=config.lower_case)
        self.cls_id = self.preprocessor.token_to_id('[CLS]')
        self.sep_id = self.preprocessor.token_to_id('[SEP]')
        self.max_seq_length = config.max_seq_length
        self.max_ctx_tokens = 0  # updated in input_to_feature

    def get_max_ctx_tokens(self, q_len):
        return self.max_seq_length - q_len - 3  # 1 [CLS], 2 [SEP]

    def get_ctx_offset(self, q_len):
        return q_len + 2  # +2 for [CLS], [SEP] since q is before ctx

    def input_to_feature(self, one_input):
        if len(one_input) == 3:
            qid, question, context = one_input
            label = None
        elif len(one_input) == 4:
            qid, question, context, label = one_input
        else:
            raise ValueError('number of inputs not valid error: {}'.format(
                len(one_input)))

        q_codes = self.preprocessor.encode(question, add_special_tokens=False)
        ctx_codes = self.preprocessor.encode(context, add_special_tokens=False)
        q_ids = q_codes.ids
        ctx_ids = ctx_codes.ids
        ctx_tokens = ctx_codes.tokens
        ctx_spans = ctx_codes.offsets

        label_info = self.may_process_label(label, (context, ctx_codes))
        ans_cls, ans_start, ans_end = label_info
        feature_dict = {
            'feature_id': qid,
            'question': question,
            'context': context,
            'question_tokens': q_codes.tokens,
            'label': label,
            'cls': ans_cls,
            'answer_start': ans_start,
            'answer_end': ans_end,
        }
        q_len = len(q_codes.tokens)
        ctx_token_len = len(ctx_codes.tokens)
        max_ctx_tokens = self.get_max_ctx_tokens(q_len)
        context_valid_spans = get_valid_windows(ctx_token_len, max_ctx_tokens,
                                                self.config.context_stride)
        win_offset = self.get_ctx_offset(q_len)
        for win_span in context_valid_spans:
            win_start, win_end = win_span
            win_ctx_ids = ctx_ids[win_start:win_end]
            feature_dict = self.build_ids(feature_dict, q_ids, win_ctx_ids)
            win_ctx_tokens = ctx_tokens[win_start:win_end]
            win_ctx_spans = ctx_spans[win_start:win_end]

            cls, answer_start, answer_end = self.adjust_label(
                feature_dict, win_offset, win_span)
            if feature_dict['label'] is not None and cls is None:
                # has label, but no valid answer_span in current window
                continue
            self.num_examples += 1
            feature_id = '{}_{}'.format(self.num_examples,
                                        feature_dict['feature_id'])
            feature_dict['feature_id'] = feature_id
            feature_dict['context_tokens'] = win_ctx_tokens
            feature_dict['context_spans'] = win_ctx_spans
            feature_dict['answer_start'] = answer_start
            feature_dict['answer_end'] = answer_end
            feature = self.feature(**feature_dict)
            yield feature

    def build_ids(self, feature_dict, q_ids, win_ctx_ids):
        # for BERT, first put cls, then put q and ctx
        first_part_ids = [self.cls_id] + q_ids + [self.sep_id]
        second_part_ids = win_ctx_ids + [self.sep_id]
        input_ids = first_part_ids + second_part_ids

        segment_ids = [0] * len(first_part_ids) + [1] * len(second_part_ids)
        # pad to max_seq_length
        input_len = len(input_ids)
        input_ids += [0] * (self.max_seq_length - input_len)
        segment_ids += [0] * (self.max_seq_length - input_len)

        feature_dict['input_ids'] = input_ids
        feature_dict['segment_ids'] = segment_ids
        return feature_dict

    @staticmethod
    def two_seq_str_fn(feat):
        q_str = [
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format('q_idx', 'token', 'input_idx',
                                                 'input_id')
        ]
        q_str.extend([
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format(q_idx, q_token, q_idx + 1,
                                                 feat.input_ids[q_idx + 1])
            for q_idx, q_token in enumerate(feat.question_tokens)
        ])

        ctx_str = [
            '|{:>5}|{:>15}|{:>15}|{:>10}|{:>10}'.format(
                'c_idx', 'token', 'span', 'input_idx', 'input_id')
        ]
        q_len = len(feat.question_tokens)
        ctx_str.extend([
            '|{:>5}|{:>15}|{:>15}|{:>10}|{:>10}'.format(
                c_idx, c_token, str(feat.context_spans[c_idx]),
                c_idx + 2 + q_len, feat.input_ids[c_idx + 2 + q_len])
            for c_idx, c_token in enumerate(feat.context_tokens)
        ])

        return q_str, ctx_str
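
A comment sketch of the sliding-window behaviour driven by get_valid_windows (the helper is defined elsewhere; the numbers here are invented):

# A 300-token context with max_ctx_tokens = 180 and context_stride = 128 is split
# into overlapping windows such as (0, 180) and (128, 300); each window is packed
# as [CLS] question [SEP] window [SEP] and emitted as its own feature, with the
# answer span re-based onto that window by adjust_label.
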
Example #11
                dec_seq_len=512)
checkpoint = torch.load(
    'checkpoints/amadeus-performer-2020-11-25-00.20.57-300.pt')
model.eval(True)
# model.load_state_dict(torch.load('models/amadeus-performer-2020-11-06-12.47.52.pt'))
model.load_state_dict(checkpoint['model_state_dict'])
model.cuda()

run = True

sentences = []

while run:
    try:
        sentence = input('> ')
        if sentence in ['quit', 'exit']:
            run = False
            continue
        sentences.append(tokenizer.encode(sentence))
        if len(sentences) > 3:
            sentences = sentences[-3:]
        input_seq = torch.tensor(Encoding.merge(sentences[:]).ids).cuda()
        start_tokens = torch.tensor([tokenizer.token_to_id('[CLS]')]).cuda()
        out = model.generate(input_seq=input_seq,
                             start_tokens=start_tokens,
                             eos_token=tokenizer.token_to_id('[SEP]'))
        response = tokenizer.decode(out.tolist())
        sentences.append(tokenizer.encode(response))
        print(response)
    except KeyboardInterrupt:
        run = False