def __init__(self, t: PreTrainedTokenizer, args, file_path: str, block_size=512):
    assert os.path.isfile(file_path)
    logger.info("Creating features from dataset file at %s", file_path)

    # -------------------------- CHANGES START
    bert_tokenizer = os.path.join(args.tokenizer_name, "vocab.txt")
    if os.path.exists(bert_tokenizer):
        logger.info("Loading BERT tokenizer")
        from tokenizers import BertWordPieceTokenizer
        tokenizer = BertWordPieceTokenizer(bert_tokenizer,
                                           handle_chinese_chars=False,
                                           lowercase=False)
        tokenizer.enable_truncation(512)
    else:
        from tokenizers import ByteLevelBPETokenizer
        from tokenizers.processors import BertProcessing
        logger.info("Loading RoBERTa tokenizer")
        tokenizer = ByteLevelBPETokenizer(
            os.path.join(args.tokenizer_name, "vocab.json"),
            os.path.join(args.tokenizer_name, "merges.txt"),
        )
        tokenizer._tokenizer.post_processor = BertProcessing(
            ("</s>", tokenizer.token_to_id("</s>")),
            ("<s>", tokenizer.token_to_id("<s>")),
        )
        tokenizer.enable_truncation(max_length=512)

    logger.info("Reading file %s", file_path)
    with open(file_path, encoding="utf-8") as f:
        lines = [line for line in f.read().splitlines()
                 if (len(line) > 0 and not line.isspace())]

    logger.info("Running tokenization")
    self.examples = tokenizer.encode_batch(lines)
class CustomDataset:
    def __init__(self, sentences, bert_path, padding=140):
        self.sentences = sentences
        self.tokenizer = BertWordPieceTokenizer(f'{bert_path}/vocab.txt', lowercase=True)
        self.padding = padding

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        s = self.sentences[idx]  # ['[CLS]', *self.sentences[idx], '[SEP]']
        to_ignore_none = lambda x: x if x is not None else 0
        to_id = lambda x: to_ignore_none(self.tokenizer.token_to_id(x))
        n_pads = self.padding - len(s)
        x = list(map(to_id, s))
        assert len(x) == len(s)
        x = x + [0 for _ in range(n_pads)]
        return torch.tensor(x), n_pads  # , torch.tensor([])
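# Minimal usage sketch for CustomDataset above (not from the source). The vocab
# directory and the example sentences are hypothetical placeholders; sentences
# are lists of WordPiece tokens, since __getitem__ maps each token to its id.
if __name__ == "__main__":
    sentences = [["hello", "world"], ["another", "short", "sentence"]]
    dataset = CustomDataset(sentences, bert_path="./bert-base-uncased", padding=140)
    x, n_pads = dataset[0]
    print(x.shape, n_pads)  # -> torch.Size([140]) 138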
def train_tokenizer(filename, params):
    """Train a BertWordPieceTokenizer with the specified params and save it."""
    # Get tokenization params
    save_location = params["tokenizer_path"]
    max_length = params["max_length"]
    min_freq = params["min_freq"]
    vocabsize = params["vocab_size"]

    tokenizer = BertWordPieceTokenizer()
    tokenizer.do_lower_case = False
    special_tokens = ["[S]", "[PAD]", "[/S]", "[UNK]", "[MASK]", "[SEP]", "[CLS]"]
    tokenizer.train(files=[filename],
                    vocab_size=vocabsize,
                    min_frequency=min_freq,
                    special_tokens=special_tokens)
    tokenizer._tokenizer.post_processor = BertProcessing(
        ("[SEP]", tokenizer.token_to_id("[SEP]")),
        ("[CLS]", tokenizer.token_to_id("[CLS]")),
    )
    tokenizer.enable_truncation(max_length=max_length)

    print("Saving tokenizer ...")
    if not os.path.exists(save_location):
        os.makedirs(save_location)
    tokenizer.save(save_location)
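# Hedged usage sketch for train_tokenizer above. The parameter keys come from
# the function itself; the corpus path and the concrete values are illustrative
# placeholders, not values taken from the source.
if __name__ == "__main__":
    params = {
        "tokenizer_path": "trained_tokenizer/",  # output directory (placeholder)
        "max_length": 128,
        "min_freq": 2,
        "vocab_size": 30000,
    }
    train_tokenizer("corpus.txt", params)  # corpus.txt: one sentence per line (placeholder)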
def model_predict(inp, model=[]):
    # Called by default-python environment.
    # inp -- default is a string, but you can also specify
    #        the type in "input_type.py".
    # model is optional and the return value of load_model.
    # Should return JSON.

    # predict all tokens
    text = inp.pred
    tokenizer = BertTokenizer.from_pretrained('models/src/models/vocab_swebert.txt',
                                              do_lower_case=False)
    # input_ids = tokenizer(text.lower())["input_ids"]
    # tokenizer = BertTokenizer.from_pretrained('src/models/vocab_swebert.txt',
    #                                           lowercase=True, strip_accents=False)

    bert_word_piece_tokenizer = BertWordPieceTokenizer("models/src/models/vocab_swebert.txt",
                                                       lowercase=True,
                                                       strip_accents=False)
    output = bert_word_piece_tokenizer.encode(text)
    tokens = output.tokens
    indexed_tokens = output.ids
    input_ids = indexed_tokens
    print(tokens)

    # mask one of the tokens
    masked_index = inp.msk_ind
    tokens[masked_index] = '[MASK]'
    print(tokens)
    # input_ids[masked_index] = tokenizer.convert_tokens_to_ids('[MASK]')
    indexed_tokens[masked_index] = bert_word_piece_tokenizer.token_to_id('[MASK]')
    print(input_ids)

    # do predictions
    with torch.no_grad():  # deactivate the autograd engine to reduce memory usage and speed up
        outputs = model(torch.tensor([input_ids]))
        predictions = outputs[0]

    predicted_index_top5 = torch.argsort(predictions[0, masked_index], descending=True)[:5]
    predicted_token = tokenizer.convert_ids_to_tokens(predicted_index_top5)  # predicted_index_top5
    print(predicted_token)

    return {"result": predicted_token}
class Tokenizer(object):
    def __init__(self, vocab_path, do_lower_case=True):
        if BertWordPieceTokenizer:
            self.tokenizer = BertWordPieceTokenizer(
                vocab_path,
                lowercase=do_lower_case,
            )
        else:
            self.tokenizer = tokenization.FullTokenizer(
                vocab_path, do_lower_case=do_lower_case)
        self._do_lower_case = do_lower_case

    def tokenize(self, input_text):
        if BertWordPieceTokenizer:
            return self.tokenizer.encode(input_text, add_special_tokens=False).tokens
        else:
            return self.tokenizer.tokenize(input_text)

    def encode(self, input_text, add_special_tokens=False):
        input_tokens = self.tokenize(input_text)
        if add_special_tokens:
            input_tokens = ['[CLS]'] + input_tokens + ['[SEP]']
        input_token_ids = self.convert_tokens_to_ids(input_tokens)
        return input_token_ids

    def padded_to_ids(self, input_text, max_length):
        if len(input_text) > max_length:
            return input_text[:max_length]
        else:
            return input_text + [0] * (max_length - len(input_text))

    def convert_tokens_to_ids(self, input_tokens):
        if BertWordPieceTokenizer:
            token_ids = [
                self.tokenizer.token_to_id(token) for token in input_tokens
            ]
        else:
            token_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)
        return token_ids

    def customize_tokenize(self, input_text):
        temp_x = ""
        for c in input_text:
            if self._is_cjk_character(c) or self._is_punctuation(
                    c) or self._is_space(c) or self._is_control(c):
                temp_x += " " + c + " "
            else:
                temp_x += c
        return temp_x.split()

    def convert_ids_to_tokens(self, input_ids):
        if BertWordPieceTokenizer:
            input_tokens = [
                self.tokenizer.id_to_token(ids) for ids in input_ids
            ]
        else:
            input_tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        return input_tokens

    def decode(self, input_tokens):
        text, flag = '', False
        for i, token in enumerate(input_tokens):
            if token[:2] == '##':
                text += token[2:]
            elif len(token) == 1 and self._is_cjk_character(token):
                text += token
            elif len(token) == 1 and self._is_punctuation(token):
                text += token
                text += ' '
            elif i > 0 and self._is_cjk_character(text[-1]):
                text += token
            else:
                text += ' '
                text += token
        text = re.sub(' +', ' ', text)
        text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
        punctuation = self._cjk_punctuation() + '+-/={(<['
        punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
        punctuation_regex = '(%s) ' % punctuation_regex
        text = re.sub(punctuation_regex, '\\1', text)
        text = re.sub('(\d\.) (\d)', '\\1\\2', text)
        return text.strip()

    @staticmethod
    def stem(token):
        """ """
        if token[:2] == '##':
            return token[2:]
        else:
            return token

    @staticmethod
    def _is_space(ch):
        """ """
        return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
            unicodedata.category(ch) == 'Zs'

    @staticmethod
    def _is_punctuation(ch):
        """ """
        code = ord(ch)
        return 33 <= code <= 47 or \
            58 <= code <= 64 or \
            91 <= code <= 96 or \
            123 <= code <= 126 or \
            unicodedata.category(ch).startswith('P')

    @staticmethod
    def _cjk_punctuation():
        return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'

    @staticmethod
    def _is_cjk_character(ch):
        """
        reference: https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        """
        code = ord(ch)
        return 0x4E00 <= code <= 0x9FFF or \
            0x3400 <= code <= 0x4DBF or \
            0x20000 <= code <= 0x2A6DF or \
            0x2A700 <= code <= 0x2B73F or \
            0x2B740 <= code <= 0x2B81F or \
            0x2B820 <= code <= 0x2CEAF or \
            0xF900 <= code <= 0xFAFF or \
            0x2F800 <= code <= 0x2FA1F

    @staticmethod
    def _is_control(ch):
        """ """
        return unicodedata.category(ch) in ('Cc', 'Cf')

    @staticmethod
    def _is_special(ch):
        """ """
        return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')

    def rematch(self, text, tokens):
        if is_py2:
            text = unicode(text)

        if self._do_lower_case:
            text = text.lower()

        normalized_text, char_mapping = '', []
        for i, ch in enumerate(text):
            if self._do_lower_case:
                ch = unicodedata.normalize('NFD', ch)
                ch = ''.join(
                    [c for c in ch if unicodedata.category(c) != 'Mn'])
            ch = ''.join([
                c for c in ch
                if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
            ])
            normalized_text += ch
            char_mapping.extend([i] * len(ch))

        text, token_mapping, offset = normalized_text, [], 0
        for token in tokens:
            if self._is_special(token):
                token_mapping.append([])
            else:
                token = self.stem(token)
                start = text[offset:].index(token) + offset
                end = start + len(token)
                token_mapping.append(char_mapping[start:end])
                offset = end

        return token_mapping
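# Quick usage sketch for the Tokenizer wrapper above, assuming a BERT-style
# vocab.txt is available locally (the path is a placeholder).
tok = Tokenizer("vocab.txt", do_lower_case=True)
tokens = tok.tokenize("BERT tokenization, step by step.")
ids = tok.encode("BERT tokenization, step by step.", add_special_tokens=True)
padded = tok.padded_to_ids(ids, max_length=16)   # pad or truncate to a fixed length
print(tokens)
print(tok.decode(tokens))                        # rough detokenization of WordPiece output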
class BertClassifierDataBuilder(ClassifierDataBuilder):
    def __init__(self, config):
        super().__init__(config)
        self.preprocessor = BertWordPieceTokenizer(config.vocab_file,
                                                   lowercase=config.lower_case)
        self.cls_id = self.preprocessor.token_to_id('[CLS]')
        self.sep_id = self.preprocessor.token_to_id('[SEP]')
        self.max_seq_length = config.max_seq_length

    def build_one_input_ids(self, seq1_codes, seq2_codes):
        seq1_ids = seq1_codes.ids
        seq2_ids = seq2_codes.ids
        seq1_tokens = seq1_codes.tokens
        seq2_tokens = seq2_codes.tokens

        # Truncate the sequence pair in place to the maximum length.
        # This is a simple heuristic which will always truncate the longer
        # sequence one token at a time. This makes more sense than
        # truncating an equal percent of tokens from each, since if one
        # sequence is very short then each token that's truncated likely
        # contains more information than a longer sequence.
        while True:
            total_length = len(seq1_tokens) + len(seq2_tokens)
            if total_length <= self.max_seq_length - 3:
                # Account for [CLS], [SEP], [SEP] with "- 3"
                # logger.info('truncation finished.')
                break
            if len(seq1_tokens) > len(seq2_tokens):
                seq1_tokens.pop()
                seq1_ids.pop()
            else:
                seq2_tokens.pop()
                seq2_ids.pop()

        first_part_ids = [self.cls_id] + seq1_ids + [self.sep_id]
        second_part_ids = seq2_ids + [self.sep_id]
        input_ids = first_part_ids + second_part_ids
        segment_ids = [0] * len(first_part_ids) + [1] * len(second_part_ids)

        # pad to max_seq_length
        input_len = len(input_ids)
        input_ids += [0] * (self.max_seq_length - input_len)
        segment_ids += [0] * (self.max_seq_length - input_len)
        return input_ids, segment_ids, seq1_tokens, seq2_tokens

    def build_ids(self, seq1_list, seq2_codes):
        part1_ids = []
        part2_ids = []
        seq1_tokens = []
        seq2_tokens = []
        for s1 in seq1_list:
            s1_codes = self.preprocessor.encode(s1, add_special_tokens=False)
            one_output = self.build_one_input_ids(s1_codes, seq2_codes)
            p1_ids, sp2_ids, s1_tokens, s2_tokens = one_output
            part1_ids.extend(p1_ids)
            part2_ids.extend(sp2_ids)
            seq1_tokens.extend(s1_tokens)
            seq2_tokens.extend(s2_tokens)
        return part1_ids, part2_ids, seq1_tokens, seq2_tokens

    def set_ids(self, feature_dict, one_output):
        input_ids, segment_ids, seq1_tokens, seq2_tokens = one_output
        feature_dict['input_ids'] = input_ids
        feature_dict['segment_ids'] = segment_ids
        feature_dict['seq1_tokens'] = seq1_tokens
        feature_dict['seq2_tokens'] = seq2_tokens
        return feature_dict

    def input_to_feature(self, one_input):
        if len(one_input) == 3:
            eid, seq1, seq2 = one_input
            label = None
        elif len(one_input) == 4:
            eid, seq1, seq2, label = one_input
        else:
            raise ValueError('number of inputs not valid error: {}'.format(
                len(one_input)))

        seq2_codes = self.preprocessor.encode(seq2, add_special_tokens=False)
        ans_cls = self.may_process_label(label, None)
        feature_dict = {
            'feature_id': eid,
            'label': label,
            'cls': ans_cls,
            'seq1': seq1,
            'seq2': seq2,
        }

        if isinstance(seq1, list):
            seq1_list = seq1      # for race multiple choice
        else:
            seq1_list = [seq1]    # for boolq, mnli, qqp

        one_output = self.build_ids(seq1_list, seq2_codes)
        feature_dict = self.set_ids(feature_dict, one_output)

        self.num_examples += 1
        feature_id = '{}_{}'.format(self.num_examples,
                                    feature_dict['feature_id'])
        feature_dict['feature_id'] = feature_id
        feature = self.feature(**feature_dict)
        yield feature

    @staticmethod
    def two_seq_str_fn(feat):
        seq1_str = [
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format('seq1_idx', 'token',
                                                 'input_idx', 'input_id')
        ]
        seq1_str.extend([
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format(q_idx, q_token, q_idx + 1,
                                                 feat.input_ids[q_idx + 1])
            for q_idx, q_token in enumerate(feat.seq1_tokens)
        ])

        seq2_str = [
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format('seq2_idx', 'token',
                                                 'input_idx', 'input_id')
        ]
        seq1_len = len(feat.seq1_tokens)
        seq2_str.extend([
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format(
                c_idx, c_token, c_idx + 2 + seq1_len,
                feat.input_ids[c_idx + 2 + seq1_len])
            for c_idx, c_token in enumerate(feat.seq2_tokens)
        ])
        return seq1_str, seq2_str
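# Standalone sketch (not from the source) of the [CLS] seq1 [SEP] seq2 [SEP]
# layout that BertClassifierDataBuilder.build_one_input_ids produces, using the
# tokenizers library directly. The vocab path and the sentences are placeholders.
from tokenizers import BertWordPieceTokenizer

wp = BertWordPieceTokenizer("vocab.txt", lowercase=True)
a = wp.encode("Is the premise true?", add_special_tokens=False)
b = wp.encode("The premise is supported by the passage.", add_special_tokens=False)
cls_id, sep_id = wp.token_to_id("[CLS]"), wp.token_to_id("[SEP]")
input_ids = [cls_id] + a.ids + [sep_id] + b.ids + [sep_id]
segment_ids = [0] * (len(a.ids) + 2) + [1] * (len(b.ids) + 1)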
class Reader(object):
    def __init__(self,
                 bert_model: str,
                 tokenizer: BaseTokenizer = None,
                 cls: str = "[CLS]",
                 sep: str = "[SEP]",
                 threshold=6):
        self.tokenizer: BaseTokenizer = tokenizer
        self.cls = cls
        self.sep = sep
        if self.tokenizer is None:
            vocab_path: str = "tokenization/" + bert_model + ".txt"
            self.tokenizer = BertWordPieceTokenizer(
                vocab_path, lowercase="-cased" not in bert_model)
        self.threshold = threshold

        self.subword_alphabet: Optional[Alphabet] = None
        self.label_alphabet: Optional[Alphabet] = None

        self.train: Optional[List[SentInst]] = None
        self.dev: Optional[List[SentInst]] = None
        self.test: Optional[List[SentInst]] = None

    def _read_file(self, filename: str, mode: str = 'train') -> List[SentInst]:
        sent_list = []
        max_len = 0
        num_thresh = 0
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line == "":  # last few blank lines
                    break

                raw_tokens = line.split(' ')
                tokens = raw_tokens
                chars = [list(t) for t in raw_tokens]

                entities = next(f).strip()
                if entities == "":  # no entities
                    sent_inst = SentInst(tokens, chars, [])
                else:
                    entity_list = []
                    entities = entities.split("|")
                    for item in entities:
                        pointers, label = item.split()
                        pointers = pointers.split(",")
                        if int(pointers[1]) > len(tokens):
                            pdb.set_trace()
                        span_len = int(pointers[1]) - int(pointers[0])
                        if span_len < 0:
                            print("Warning! span_len < 0")
                            continue
                        if span_len > max_len:
                            max_len = span_len
                        if span_len > self.threshold:
                            num_thresh += 1

                        new_entity = (int(pointers[0]), int(pointers[1]), label)
                        # may be duplicate entities in some datasets
                        if (mode == 'train' and new_entity not in entity_list) or (mode != 'train'):
                            entity_list.append(new_entity)

                    # assert len(entity_list) == len(set(entity_list))  # check duplicate
                    sent_inst = SentInst(tokens, chars, entity_list)
                assert next(f).strip() == ""  # separating line

                sent_list.append(sent_inst)
        print("Max length: {}".format(max_len))
        print("Threshold {}: {}".format(self.threshold, num_thresh))
        return sent_list

    def _gen_dic(self) -> None:
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        vocab = [
            self.tokenizer.id_to_token(idx)
            for idx in range(self.tokenizer.get_vocab_size())
        ]
        self.subword_alphabet = Alphabet(vocab, 0)
        self.label_alphabet = Alphabet(label_set, 0)

    @staticmethod
    def _pad_batches(input_ids_batches: List[List[List[int]]],
                     first_subtokens_batches: List[List[List[int]]]) \
            -> Tuple[List[List[List[int]]],
                     List[List[List[int]]],
                     List[List[List[bool]]]]:
        padded_input_ids_batches = []
        input_mask_batches = []
        mask_batches = []

        all_batches = list(zip(input_ids_batches, first_subtokens_batches))
        for input_ids_batch, first_subtokens_batch in all_batches:
            batch_len = len(input_ids_batch)
            max_subtokens_num = max(
                [len(input_ids) for input_ids in input_ids_batch])
            max_sent_len = max([
                len(first_subtokens)
                for first_subtokens in first_subtokens_batch
            ])

            padded_input_ids_batch = []
            input_mask_batch = []
            mask_batch = []

            for i in range(batch_len):
                subtokens_num = len(input_ids_batch[i])
                sent_len = len(first_subtokens_batch[i])

                padded_subtoken_vec = input_ids_batch[i].copy()
                padded_subtoken_vec.extend([0] * (max_subtokens_num - subtokens_num))
                input_mask = [1] * subtokens_num + [0] * (max_subtokens_num - subtokens_num)
                mask = [True] * sent_len + [False] * (max_sent_len - sent_len)

                padded_input_ids_batch.append(padded_subtoken_vec)
                input_mask_batch.append(input_mask)
                mask_batch.append(mask)

            padded_input_ids_batches.append(padded_input_ids_batch)
            input_mask_batches.append(input_mask_batch)
            mask_batches.append(mask_batch)

        return padded_input_ids_batches, input_mask_batches, mask_batches

    def get_batches(self, sentences: List[SentInst], batch_size: int) -> Tuple:
        subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        first_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        last_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        label_dic_dic = defaultdict(lambda: defaultdict(list))

        this_input_ids_batches = []
        this_first_subtokens_batches = []
        this_last_subtokens_batches = []
        this_label_batches = []

        for sentInst in sentences:
            subtoken_vec = []
            first_subtoken_vec = []
            last_subtoken_vec = []
            subtoken_vec.append(self.tokenizer.token_to_id(self.cls))
            for t in sentInst.tokens:
                encoding = self.tokenizer.encode(t)
                # keep only non-special subtokens of each word
                ids = [
                    v for v, mask in zip(encoding.ids,
                                         encoding.special_tokens_mask)
                    if mask == 0
                ]
                first_subtoken_vec.append(len(subtoken_vec))
                subtoken_vec.extend(ids)
                last_subtoken_vec.append(len(subtoken_vec))
            subtoken_vec.append(self.tokenizer.token_to_id(self.sep))

            label_list = [(u[0], u[1], self.label_alphabet.get_index(u[2]))
                          for u in sentInst.entities]

            subtoken_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(subtoken_vec)
            first_subtoken_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(first_subtoken_vec)
            last_subtoken_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(last_subtoken_vec)
            label_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(label_list)

        input_ids_batches = []
        first_subtokens_batches = []
        last_subtokens_batches = []
        label_batches = []
        for length1 in sorted(subtoken_dic_dic.keys(), reverse=True):
            for length2 in sorted(subtoken_dic_dic[length1].keys(), reverse=True):
                input_ids_batches.extend(subtoken_dic_dic[length1][length2])
                first_subtokens_batches.extend(first_subtoken_dic_dic[length1][length2])
                last_subtokens_batches.extend(last_subtoken_dic_dic[length1][length2])
                label_batches.extend(label_dic_dic[length1][length2])

        for i in range(0, len(input_ids_batches), batch_size):
            this_input_ids_batches.append(input_ids_batches[i:i + batch_size])
        for i in range(0, len(first_subtokens_batches), batch_size):
            this_first_subtokens_batches.append(first_subtokens_batches[i:i + batch_size])
        for i in range(0, len(last_subtokens_batches), batch_size):
            this_last_subtokens_batches.append(last_subtokens_batches[i:i + batch_size])
        for i in range(0, len(label_batches), batch_size):
            this_label_batches.append(label_batches[i:i + batch_size])

        this_input_ids_batches, this_input_mask_batches, this_mask_batches \
            = self._pad_batches(this_input_ids_batches, this_first_subtokens_batches)

        return (this_input_ids_batches, this_input_mask_batches,
                this_first_subtokens_batches, this_last_subtokens_batches,
                this_label_batches, this_mask_batches)

    def to_batch(self, batch_size: int) -> Tuple:
        ret_list = []
        for sent_list in [self.train, self.dev, self.test]:
            ret_list.append(self.get_batches(sent_list, batch_size))
        return tuple(ret_list)

    def read_all_data(self, file_path: str, train_file: str, dev_file: str,
                      test_file: str) -> None:
        self.train = self._read_file(file_path + train_file)
        self.dev = self._read_file(file_path + dev_file, mode='dev')
        self.test = self._read_file(file_path + test_file, mode='test')
        self._gen_dic()

    def debug_single_sample(self, subtoken: List[int],
                            label_list: List[Tuple[int, int, int]]) -> None:
        print(" ".join(
            [self.subword_alphabet.get_instance(t) for t in subtoken]))
        for label in label_list:
            print(label[0], label[1],
                  self.label_alphabet.get_instance(label[2]))
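# Hypothetical driver for the Reader above; the data directory, file names and
# batch size are placeholders, and Reader expects the vocab at
# "tokenization/<bert_model>.txt".
reader = Reader("bert-base-cased", threshold=6)
reader.read_all_data("data/", "train.txt", "dev.txt", "test.txt")
train_batches, dev_batches, test_batches = reader.to_batch(batch_size=32)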
def numerize(vocab_path, input_path, bin_path):
    tokenizer = BertWordPieceTokenizer(vocab_path,
                                       unk_token=UNK_TOKEN,
                                       sep_token=SEP_TOKEN,
                                       cls_token=CLS_TOKEN,
                                       pad_token=PAD_TOKEN,
                                       mask_token=MASK_TOKEN,
                                       lowercase=False,
                                       strip_accents=False)
    sentences = []
    with open(input_path, 'r') as f:
        batch_stream = []
        for i, line in enumerate(f):
            batch_stream.append(line)
            if i % 1000 == 0:
                res = tokenizer.encode_batch(batch_stream)
                batch_stream = []
                # flatten the list
                for s in res:
                    sentences.extend(s.ids[1:])
            if i % 100000 == 0:
                print(f'processed {i} lines')
        # encode whatever is left over after the last full batch
        if batch_stream:
            for s in tokenizer.encode_batch(batch_stream):
                sentences.extend(s.ids[1:])

    print('convert the data to numpy')
    # convert data to numpy format in uint16
    if tokenizer.get_vocab_size() < 1 << 16:
        sentences = np.uint16(sentences)
    else:
        assert tokenizer.get_vocab_size() < 1 << 31
        sentences = np.int32(sentences)

    # save special tokens for later processing
    sep_index = tokenizer.token_to_id(SEP_TOKEN)
    cls_index = tokenizer.token_to_id(CLS_TOKEN)
    unk_index = tokenizer.token_to_id(UNK_TOKEN)
    mask_index = tokenizer.token_to_id(MASK_TOKEN)
    pad_index = tokenizer.token_to_id(PAD_TOKEN)

    # sanity check
    assert sep_index == SEP_INDEX
    assert cls_index == CLS_INDEX
    assert unk_index == UNK_INDEX
    assert pad_index == PAD_INDEX
    assert mask_index == MASK_INDEX

    print('collect statistics')
    # collect some statistics of the dataset
    n_unks = (sentences == unk_index).sum()
    n_toks = len(sentences)
    p_unks = n_unks * 100. / n_toks
    n_seqs = (sentences == sep_index).sum()
    print(
        f'| {n_seqs} sentences - {n_toks} tokens - {p_unks:.2f}% unknown words'
    )

    # bundle the ids and special-token indices and save them
    data = {
        'sentences': sentences,
        'sep_index': sep_index,
        'cls_index': cls_index,
        'unk_index': unk_index,
        'pad_index': pad_index,
        'mask_index': mask_index
    }
    torch.save(data, bin_path, pickle_protocol=4)
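# Illustrative call to numerize above; paths are placeholders, and the
# UNK/SEP/CLS/PAD/MASK token and index constants are assumed to be defined
# elsewhere in the module (they are referenced but not shown here).
if __name__ == "__main__":
    numerize(vocab_path="vocab.txt",
             input_path="corpus.txt",
             bin_path="corpus.bin.pt")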
class Tweets(Dataset):
    def __init__(self, device='cpu', pad=150, test=False, N=4):
        self.samples = []
        self.pad = pad
        self.tokenizer = BertWordPieceTokenizer(
            "./data/bert-base-uncased-vocab.txt",
            lowercase=True,
            clean_text=True)
        self.tokenizer.enable_padding(max_length=pad - 1)  # -1 for sentiment token
        self.tokenizer.add_special_tokens(['[POS]'])
        self.tokenizer.add_special_tokens(['[NEG]'])
        self.tokenizer.add_special_tokens(['[NEU]'])
        self.vocab = self.tokenizer.get_vocab()
        self.sent_t = {
            'positive': self.tokenizer.token_to_id('[POS]'),
            'negative': self.tokenizer.token_to_id('[NEG]'),
            'neutral': self.tokenizer.token_to_id('[NEU]')
        }

        self.pos_set = {'UNK': 0}
        all_pos = load('help/tagsets/upenn_tagset.pickle').keys()
        for i, p in enumerate(all_pos):
            self.pos_set[p] = i + 1
        self.tweet_tokenizer = TweetTokenizer()

        data = None
        if test is True:
            data = pd.read_csv(TEST_PATH).values
            for row in data:
                tid, tweet, sentiment = tuple(row)

                pos_membership = [0] * len(tweet)
                pos_tokens = self.tweet_tokenizer.tokenize(tweet)
                pos = nltk.pos_tag(pos_tokens)
                offset = 0
                for i, token in enumerate(pos_tokens):
                    start = tweet.find(token, offset)
                    end = start + len(token)
                    if pos[i][1] in self.pos_set:
                        pos_membership[start:end] = [self.pos_set[pos[i][1]]] * len(token)
                    offset += len(token)

                tokens = self.tokenizer.encode(tweet)
                word_to_index = tokens.ids
                offsets = tokens.offsets
                token_pos = [0] * len(word_to_index)

                # get pos info
                for i, (s, e) in enumerate(offsets):
                    if word_to_index[i] == 0 or word_to_index[i] == 101 or word_to_index[i] == 102:
                        pass
                    elif s != e:
                        sub = pos_membership[s:e]
                        token_pos[i] = max(set(sub), key=sub.count)

                token_pos = [0] + token_pos
                word_to_index = [self.sent_t[sentiment]] + word_to_index
                offsets = [(0, 0)] + offsets
                offsets = np.array([[off[0], off[1]] for off in offsets])
                word_to_index = np.array(word_to_index)
                token_pos = np.array(token_pos)

                self.samples.append({
                    'tid': tid,
                    'sentiment': sentiment,
                    'tweet': word_to_index,
                    'offsets': offsets,
                    'raw_tweet': tweet,
                    'pos': token_pos
                })
        else:
            data = pd.read_csv(TRAIN_PATH).values
            if N > 0:
                data = augment_n(data, N=N)
            for row in data:
                tid, tweet, selection, sentiment = tuple(row)

                char_membership = [0] * len(tweet)
                pos_membership = [0] * len(tweet)
                si = tweet.find(selection)
                if si < 0:
                    char_membership[0:] = [1] * len(char_membership)
                else:
                    char_membership[si:si + len(selection)] = [1] * len(selection)

                pos_tokens = self.tweet_tokenizer.tokenize(tweet)
                pos = nltk.pos_tag(pos_tokens)
                offset = 0
                for i, token in enumerate(pos_tokens):
                    start = tweet.find(token, offset)
                    end = start + len(token)
                    if pos[i][1] in self.pos_set:
                        pos_membership[start:end] = [self.pos_set[pos[i][1]]] * len(token)
                    offset += len(token)

                tokens = self.tokenizer.encode(tweet)
                word_to_index = tokens.ids
                offsets = tokens.offsets
                token_membership = [0] * len(word_to_index)
                token_pos = [0] * len(word_to_index)

                # Inclusive indices
                start = None
                end = None
                for i, (s, e) in enumerate(offsets):
                    if word_to_index[i] == 0 or word_to_index[i] == 101 or word_to_index[i] == 102:
                        token_membership[i] = -1
                    elif sum(char_membership[s:e]) > 0:
                        token_membership[i] = 1
                        if start is None:
                            start = i + 1
                        end = i + 1

                # get pos info
                for i, (s, e) in enumerate(offsets):
                    if word_to_index[i] == 0 or word_to_index[i] == 101 or word_to_index[i] == 102:
                        pass
                    elif s != e:
                        sub = pos_membership[s:e]
                        token_pos[i] = max(set(sub), key=sub.count)

                if start is None:
                    print("Data Point Error")
                    print(tweet)
                    print(selection)
                    continue

                # token_membership = torch.LongTensor(token_membership).to(device)
                word_to_index = [self.sent_t[sentiment]] + word_to_index
                token_membership = [-1] + token_membership
                offsets = [(0, 0)] + offsets
                token_pos = [0] + token_pos
                offsets = np.array([[off[0], off[1]] for off in offsets])
                word_to_index = np.array(word_to_index)
                token_membership = np.array(token_membership).astype('float')
                token_pos = np.array(token_pos)

                if tid is None:
                    raise Exception('None field detected')
                if sentiment is None:
                    raise Exception('None field detected')
                if word_to_index is None:
                    raise Exception('None field detected')
                if token_membership is None:
                    raise Exception('None field detected')
                if selection is None:
                    raise Exception('None field detected')
                if tweet is None:
                    raise Exception('None field detected')
                if start is None:
                    raise Exception('None field detected')
                if end is None:
                    raise Exception('None field detected')
                if offsets is None:
                    raise Exception('None field detected')

                self.samples.append({
                    'tid': tid,
                    'sentiment': sentiment,
                    'tweet': word_to_index,
                    'selection': token_membership,
                    'raw_selection': selection,
                    'raw_tweet': tweet,
                    'start': start,
                    'end': end,
                    'offsets': offsets,
                    'pos': token_pos
                })

    def get_splits(self, val_size=.3):
        N = len(self.samples)
        indices = np.random.permutation(N)
        split = int(N * (1 - val_size))
        train_indices = indices[0:split]
        valid_indices = indices[split:]
        return train_indices, valid_indices

    def k_folds(self, k=5):
        N = len(self.samples)
        indices = np.random.permutation(N)
        return np.array_split(indices, k)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        try:
            return self.samples[idx]
        except TypeError:
            pass
        return [self.samples[i] for i in idx]
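# Minimal usage sketch for the Tweets dataset above. It assumes TRAIN_PATH /
# TEST_PATH point at the tweet-sentiment CSVs and that the NLTK tagger data and
# the BERT vocab file are available locally; N=0 disables augmentation.
train_ds = Tweets(test=False, N=0)
train_idx, valid_idx = train_ds.get_splits(val_size=0.3)
folds = train_ds.k_folds(k=5)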
class BertQaDataBuilder(QaDataBuilder):
    def __init__(self, config):
        super().__init__(config)
        self.preprocessor = BertWordPieceTokenizer(config.vocab_file,
                                                   lowercase=config.lower_case)
        self.cls_id = self.preprocessor.token_to_id('[CLS]')
        self.sep_id = self.preprocessor.token_to_id('[SEP]')
        self.max_seq_length = config.max_seq_length
        self.max_ctx_tokens = 0  # updated in input_to_feature

    def get_max_ctx_tokens(self, q_len):
        return self.max_seq_length - q_len - 3  # 1 [CLS], 2 [SEP]

    def get_ctx_offset(self, q_len):
        return q_len + 2  # +2 for [CLS], [SEP] since q is before ctx

    def input_to_feature(self, one_input):
        if len(one_input) == 3:
            qid, question, context = one_input
            label = None
        elif len(one_input) == 4:
            qid, question, context, label = one_input
        else:
            raise ValueError('number of inputs not valid error: {}'.format(
                len(one_input)))

        q_codes = self.preprocessor.encode(question, add_special_tokens=False)
        ctx_codes = self.preprocessor.encode(context, add_special_tokens=False)
        q_ids = q_codes.ids
        ctx_ids = ctx_codes.ids
        ctx_tokens = ctx_codes.tokens
        ctx_spans = ctx_codes.offsets

        label_info = self.may_process_label(label, (context, ctx_codes))
        ans_cls, ans_start, ans_end = label_info
        feature_dict = {
            'feature_id': qid,
            'question': question,
            'context': context,
            'question_tokens': q_codes.tokens,
            'label': label,
            'cls': ans_cls,
            'answer_start': ans_start,
            'answer_end': ans_end,
        }

        q_len = len(q_codes.tokens)
        ctx_token_len = len(ctx_codes.tokens)
        max_ctx_tokens = self.get_max_ctx_tokens(q_len)
        context_valid_spans = get_valid_windows(ctx_token_len, max_ctx_tokens,
                                                self.config.context_stride)
        win_offset = self.get_ctx_offset(q_len)
        for win_span in context_valid_spans:
            win_start, win_end = win_span
            win_ctx_ids = ctx_ids[win_start:win_end]
            feature_dict = self.build_ids(feature_dict, q_ids, win_ctx_ids)

            win_ctx_tokens = ctx_tokens[win_start:win_end]
            win_ctx_spans = ctx_spans[win_start:win_end]

            cls, answer_start, answer_end = self.adjust_label(
                feature_dict, win_offset, win_span)
            if feature_dict['label'] is not None and cls is None:
                # has label, but no valid answer_span in current window
                continue

            self.num_examples += 1
            feature_id = '{}_{}'.format(self.num_examples,
                                        feature_dict['feature_id'])
            feature_dict['feature_id'] = feature_id
            feature_dict['context_tokens'] = win_ctx_tokens
            feature_dict['context_spans'] = win_ctx_spans
            feature_dict['answer_start'] = answer_start
            feature_dict['answer_end'] = answer_end
            feature = self.feature(**feature_dict)
            yield feature

    def build_ids(self, feature_dict, q_ids, win_ctx_ids):
        # for BERT, first put cls, then put q and ctx
        first_part_ids = [self.cls_id] + q_ids + [self.sep_id]
        second_part_ids = win_ctx_ids + [self.sep_id]
        input_ids = first_part_ids + second_part_ids
        segment_ids = [0] * len(first_part_ids) + [1] * len(second_part_ids)

        # pad to max_seq_length
        input_len = len(input_ids)
        input_ids += [0] * (self.max_seq_length - input_len)
        segment_ids += [0] * (self.max_seq_length - input_len)
        feature_dict['input_ids'] = input_ids
        feature_dict['segment_ids'] = segment_ids
        return feature_dict

    @staticmethod
    def two_seq_str_fn(feat):
        q_str = [
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format('q_idx', 'token', 'input_idx',
                                                 'input_id')
        ]
        q_str.extend([
            '|{:>5}|{:>15}|{:>10}|{:>10}'.format(q_idx, q_token, q_idx + 1,
                                                 feat.input_ids[q_idx + 1])
            for q_idx, q_token in enumerate(feat.question_tokens)
        ])

        ctx_str = [
            '|{:>5}|{:>15}|{:>15}|{:>10}|{:>10}'.format(
                'c_idx', 'token', 'span', 'input_idx', 'input_id')
        ]
        q_len = len(feat.question_tokens)
        ctx_str.extend([
            '|{:>5}|{:>15}|{:>15}|{:>10}|{:>10}'.format(
                c_idx, c_token, str(feat.context_spans[c_idx]),
                c_idx + 2 + q_len, feat.input_ids[c_idx + 2 + q_len])
            for c_idx, c_token in enumerate(feat.context_tokens)
        ])
        return q_str, ctx_str
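# get_valid_windows is called above but not defined in this snippet. A minimal
# sliding-window helper consistent with how it is used (token count, window
# size, stride) might look like this; treat it as an assumption, not the
# original implementation.
def get_valid_windows(num_tokens, max_window_tokens, stride):
    """Return (start, end) spans covering num_tokens with a sliding window."""
    spans = []
    start = 0
    while True:
        end = min(start + max_window_tokens, num_tokens)
        spans.append((start, end))
        if end >= num_tokens:
            break
        start += stride
    return spans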
    dec_seq_len=512)

checkpoint = torch.load(
    'checkpoints/amadeus-performer-2020-11-25-00.20.57-300.pt')
model.eval(True)
# model.load_state_dict(torch.load('models/amadeus-performer-2020-11-06-12.47.52.pt'))
model.load_state_dict(checkpoint['model_state_dict'])
model.cuda()

run = True
sentences = []
while run:
    try:
        sentence = input('> ')
        if sentence in ['quit', 'exit']:
            run = False
            continue

        sentences.append(tokenizer.encode(sentence))
        if len(sentences) > 3:
            sentences = sentences[-3:]

        input_seq = torch.tensor(Encoding.merge(sentences[:]).ids).cuda()
        start_tokens = torch.tensor([tokenizer.token_to_id('[CLS]')]).cuda()
        out = model.generate(input_seq=input_seq,
                             start_tokens=start_tokens,
                             eos_token=tokenizer.token_to_id('[SEP]'))
        response = tokenizer.decode(out.tolist())
        sentences.append(tokenizer.encode(response))
        print(response)
    except KeyboardInterrupt:
        run = False