def load_metrics_obj():
    # Load previously computed metrics from disk; return an all-None tuple
    # if no cached results exist yet.
    metrics_path = os.path.join('results', 'metrics.pickle')
    if not os.path.exists(metrics_path):
        return (None, None, None, None)
    with open(metrics_path, 'rb') as metrics_handle:
        metrics_obj = pickle.load(metrics_handle)
    return (metrics_obj['token_pairs'],
            metrics_obj['decoded_pairs'],
            metrics_obj['jaccard_similarities'],
            metrics_obj['levenshtein_distances'])


token_pairs, decoded_pairs, jaccard_similarities, levenshtein_distances = load_metrics_obj()

# Each metric is computed only if it is missing from the cache, and the cache
# is re-saved after every step so partial results are not lost.
if not token_pairs:
    token_pairs = [([tokenizer.id_to_token(x) for x in ocr_tokens[i]],
                    [tokenizer.id_to_token(x) for x in gs_tokens[i]])
                   for i in range(len(ocr_tokens))]
    save_metrics_obj(token_pairs, decoded_pairs, jaccard_similarities,
                     levenshtein_distances)

if not decoded_pairs:
    decoded_pairs = [(tokenizer.decode(ocr_tokens[i]), tokenizer.decode(gs_tokens[i]))
                     for i in range(len(ocr_tokens))]
    save_metrics_obj(token_pairs, decoded_pairs, jaccard_similarities,
                     levenshtein_distances)

all_pairs = len(token_pairs)

if not jaccard_similarities:
    jaccard_similarities = []
    for i, token_pair in enumerate(token_pairs):
        jaccard_similarities.append(
            calculate_jaccard_similarity(token_pair[0], token_pair[1]))
    save_metrics_obj(token_pairs, decoded_pairs, jaccard_similarities,
                     levenshtein_distances)

if not levenshtein_distances:

class Tokenizer(object):

    def __init__(self, vocab_path, do_lower_case=True):
        # Prefer the fast `tokenizers` WordPiece implementation when it is
        # available; otherwise fall back to the reference BERT FullTokenizer.
        if BertWordPieceTokenizer:
            self.tokenizer = BertWordPieceTokenizer(
                vocab_path,
                lowercase=do_lower_case,
            )
        else:
            self.tokenizer = tokenization.FullTokenizer(
                vocab_path, do_lower_case=do_lower_case)
        self._do_lower_case = do_lower_case

    def tokenize(self, input_text):
        if BertWordPieceTokenizer:
            return self.tokenizer.encode(input_text,
                                         add_special_tokens=False).tokens
        else:
            return self.tokenizer.tokenize(input_text)

    def encode(self, input_text, add_special_tokens=False):
        input_tokens = self.tokenize(input_text)
        if add_special_tokens:
            input_tokens = ['[CLS]'] + input_tokens + ['[SEP]']
        input_token_ids = self.convert_tokens_to_ids(input_tokens)
        return input_token_ids

    def padded_to_ids(self, input_text, max_length):
        # Truncate or right-pad a list of ids to exactly `max_length`.
        if len(input_text) > max_length:
            return input_text[:max_length]
        else:
            return input_text + [0] * (max_length - len(input_text))

    def convert_tokens_to_ids(self, input_tokens):
        if BertWordPieceTokenizer:
            token_ids = [
                self.tokenizer.token_to_id(token) for token in input_tokens
            ]
        else:
            token_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)
        return token_ids

    def customize_tokenize(self, input_text):
        # Simple whitespace tokenization that isolates CJK characters,
        # punctuation, spaces and control characters as separate tokens.
        temp_x = ""
        for c in input_text:
            if self._is_cjk_character(c) or self._is_punctuation(c) or \
                    self._is_space(c) or self._is_control(c):
                temp_x += " " + c + " "
            else:
                temp_x += c
        return temp_x.split()

    def convert_ids_to_tokens(self, input_ids):
        if BertWordPieceTokenizer:
            input_tokens = [
                self.tokenizer.id_to_token(ids) for ids in input_ids
            ]
        else:
            input_tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        return input_tokens

    def decode(self, input_tokens):
        # Reassemble WordPiece tokens into a readable string, merging '##'
        # continuations and fixing spacing around punctuation and CJK text.
        text, flag = '', False
        for i, token in enumerate(input_tokens):
            if token[:2] == '##':
                text += token[2:]
            elif len(token) == 1 and self._is_cjk_character(token):
                text += token
            elif len(token) == 1 and self._is_punctuation(token):
                text += token
                text += ' '
            elif i > 0 and self._is_cjk_character(text[-1]):
                text += token
            else:
                text += ' '
                text += token
        text = re.sub(' +', ' ', text)
        text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
        punctuation = self._cjk_punctuation() + '+-/={(<['
        punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
        punctuation_regex = '(%s) ' % punctuation_regex
        text = re.sub(punctuation_regex, '\\1', text)
        text = re.sub('(\d\.) (\d)', '\\1\\2', text)
        return text.strip()

    @staticmethod
    def stem(token):
        """Strip the leading '##' from a WordPiece subtoken, if present."""
        if token[:2] == '##':
            return token[2:]
        else:
            return token

    @staticmethod
    def _is_space(ch):
        """Whether `ch` is a whitespace character."""
        return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
            unicodedata.category(ch) == 'Zs'

    @staticmethod
    def _is_punctuation(ch):
        """Whether `ch` is a punctuation character (half- or full-width)."""
        code = ord(ch)
        return 33 <= code <= 47 or \
            58 <= code <= 64 or \
            91 <= code <= 96 or \
            123 <= code <= 126 or \
            unicodedata.category(ch).startswith('P')

    @staticmethod
    def _cjk_punctuation():
        return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'

    @staticmethod
    def _is_cjk_character(ch):
        """Whether `ch` is a CJK character.
        reference: https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        """
        code = ord(ch)
        return 0x4E00 <= code <= 0x9FFF or \
            0x3400 <= code <= 0x4DBF or \
            0x20000 <= code <= 0x2A6DF or \
            0x2A700 <= code <= 0x2B73F or \
            0x2B740 <= code <= 0x2B81F or \
            0x2B820 <= code <= 0x2CEAF or \
            0xF900 <= code <= 0xFAFF or \
            0x2F800 <= code <= 0x2FA1F

    @staticmethod
    def _is_control(ch):
        """Whether `ch` is a control or format character."""
        return unicodedata.category(ch) in ('Cc', 'Cf')

    @staticmethod
    def _is_special(ch):
        """Whether `ch` is a special token such as '[CLS]' or '[SEP]'."""
        return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')

    def rematch(self, text, tokens):
        """Map each token to the character positions it covers in `text`."""
        if is_py2:
            text = unicode(text)

        if self._do_lower_case:
            text = text.lower()

        normalized_text, char_mapping = '', []
        for i, ch in enumerate(text):
            if self._do_lower_case:
                ch = unicodedata.normalize('NFD', ch)
                ch = ''.join(
                    [c for c in ch if unicodedata.category(c) != 'Mn'])
            ch = ''.join([
                c for c in ch
                if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
            ])
            normalized_text += ch
            char_mapping.extend([i] * len(ch))

        text, token_mapping, offset = normalized_text, [], 0
        for token in tokens:
            if self._is_special(token):
                token_mapping.append([])
            else:
                token = self.stem(token)
                start = text[offset:].index(token) + offset
                end = start + len(token)
                token_mapping.append(char_mapping[start:end])
                offset = end

        return token_mapping
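

# A minimal usage sketch of the Tokenizer above. The vocabulary path
# 'vocab.txt', the example sentence and this helper's name are hypothetical;
# any WordPiece vocabulary file compatible with BERT should work.
def _tokenizer_usage_example():
    tokenizer = Tokenizer('vocab.txt', do_lower_case=True)
    text = 'OCR errors are corrected.'
    tokens = tokenizer.tokenize(text)                   # WordPiece subtokens
    ids = tokenizer.encode(text, add_special_tokens=True)
    ids = tokenizer.padded_to_ids(ids, max_length=32)   # truncate / right-pad with 0
    print(tokenizer.decode(tokens))                     # subtokens -> readable text
    print(tokenizer.rematch(text, tokens))              # token -> character positions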

class Reader(object):

    def __init__(self,
                 bert_model: str,
                 tokenizer: BaseTokenizer = None,
                 cls: str = "[CLS]",
                 sep: str = "[SEP]",
                 threshold=6):
        self.tokenizer: BaseTokenizer = tokenizer
        self.cls = cls
        self.sep = sep
        if self.tokenizer is None:
            vocab_path: str = "tokenization/" + bert_model + ".txt"
            self.tokenizer = BertWordPieceTokenizer(
                vocab_path, lowercase="-cased" not in bert_model)
        self.threshold = threshold

        self.subword_alphabet: Optional[Alphabet] = None
        self.label_alphabet: Optional[Alphabet] = None

        self.train: Optional[List[SentInst]] = None
        self.dev: Optional[List[SentInst]] = None
        self.test: Optional[List[SentInst]] = None

    def _read_file(self, filename: str, mode: str = 'train') -> List[SentInst]:
        sent_list = []
        max_len = 0
        num_thresh = 0
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line == "":  # last few blank lines
                    break

                raw_tokens = line.split(' ')
                tokens = raw_tokens
                chars = [list(t) for t in raw_tokens]

                entities = next(f).strip()
                if entities == "":  # no entities
                    sent_inst = SentInst(tokens, chars, [])
                else:
                    entity_list = []
                    entities = entities.split("|")
                    for item in entities:
                        pointers, label = item.split()
                        pointers = pointers.split(",")
                        if int(pointers[1]) > len(tokens):
                            pdb.set_trace()
                        span_len = int(pointers[1]) - int(pointers[0])
                        if span_len < 0:
                            print("Warning! span_len < 0")
                            continue
                        if span_len > max_len:
                            max_len = span_len
                        if span_len > self.threshold:
                            num_thresh += 1

                        new_entity = (int(pointers[0]), int(pointers[1]), label)
                        # may be duplicate entities in some datasets
                        if (mode == 'train' and new_entity not in entity_list) or (mode != 'train'):
                            entity_list.append(new_entity)

                    # assert len(entity_list) == len(set(entity_list)) # check duplicate
                    sent_inst = SentInst(tokens, chars, entity_list)
                assert next(f).strip() == ""  # separating line

                sent_list.append(sent_inst)
        print("Max length: {}".format(max_len))
        print("Threshold {}: {}".format(self.threshold, num_thresh))
        return sent_list

    def _gen_dic(self) -> None:
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        vocab = [
            self.tokenizer.id_to_token(idx)
            for idx in range(self.tokenizer.get_vocab_size())
        ]
        self.subword_alphabet = Alphabet(vocab, 0)
        self.label_alphabet = Alphabet(label_set, 0)

    @staticmethod
    def _pad_batches(input_ids_batches: List[List[List[int]]],
                     first_subtokens_batches: List[List[List[int]]]) \
            -> Tuple[List[List[List[int]]],
                     List[List[List[int]]],
                     List[List[List[bool]]]]:

        padded_input_ids_batches = []
        input_mask_batches = []
        mask_batches = []

        all_batches = list(zip(input_ids_batches, first_subtokens_batches))
        for input_ids_batch, first_subtokens_batch in all_batches:

            batch_len = len(input_ids_batch)
            max_subtokens_num = max(
                [len(input_ids) for input_ids in input_ids_batch])
            max_sent_len = max([
                len(first_subtokens)
                for first_subtokens in first_subtokens_batch
            ])

            padded_input_ids_batch = []
            input_mask_batch = []
            mask_batch = []

            for i in range(batch_len):
                subtokens_num = len(input_ids_batch[i])
                sent_len = len(first_subtokens_batch[i])

                # Pad subtoken ids with 0, build the attention mask at subtoken
                # level and the word-level mask at token level.
                padded_subtoken_vec = input_ids_batch[i].copy()
                padded_subtoken_vec.extend([0] * (max_subtokens_num - subtokens_num))
                input_mask = [1] * subtokens_num + [0] * (max_subtokens_num - subtokens_num)
                mask = [True] * sent_len + [False] * (max_sent_len - sent_len)

                padded_input_ids_batch.append(padded_subtoken_vec)
                input_mask_batch.append(input_mask)
                mask_batch.append(mask)

            padded_input_ids_batches.append(padded_input_ids_batch)
            input_mask_batches.append(input_mask_batch)
            mask_batches.append(mask_batch)

        return padded_input_ids_batches, input_mask_batches, mask_batches

    def get_batches(self, sentences: List[SentInst], batch_size: int) -> Tuple:
        subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        first_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        last_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        label_dic_dic = defaultdict(lambda: defaultdict(list))

        this_input_ids_batches = []
        this_first_subtokens_batches = []
        this_last_subtokens_batches = []
        this_label_batches = []

        for sentInst in sentences:
            subtoken_vec = []
            first_subtoken_vec = []
            last_subtoken_vec = []
            subtoken_vec.append(self.tokenizer.token_to_id(self.cls))
            for t in sentInst.tokens:
                encoding = self.tokenizer.encode(t)
                ids = [
                    v for v, mask in zip(encoding.ids,
                                         encoding.special_tokens_mask)
                    if mask == 0
                ]
                first_subtoken_vec.append(len(subtoken_vec))
                subtoken_vec.extend(ids)
                last_subtoken_vec.append(len(subtoken_vec))
            subtoken_vec.append(self.tokenizer.token_to_id(self.sep))

            label_list = [(u[0], u[1], self.label_alphabet.get_index(u[2]))
                          for u in sentInst.entities]

            # Bucket sentences by (token count, subtoken count) so each batch
            # contains sentences of similar length.
            subtoken_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(subtoken_vec)
            first_subtoken_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(first_subtoken_vec)
            last_subtoken_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(last_subtoken_vec)
            label_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(label_list)

        input_ids_batches = []
        first_subtokens_batches = []
        last_subtokens_batches = []
        label_batches = []
        for length1 in sorted(subtoken_dic_dic.keys(), reverse=True):
            for length2 in sorted(subtoken_dic_dic[length1].keys(), reverse=True):
                input_ids_batches.extend(subtoken_dic_dic[length1][length2])
                first_subtokens_batches.extend(first_subtoken_dic_dic[length1][length2])
                last_subtokens_batches.extend(last_subtoken_dic_dic[length1][length2])
                label_batches.extend(label_dic_dic[length1][length2])

        # Slice the length-sorted lists into chunks of `batch_size`.
        for i in range(0, len(input_ids_batches), batch_size):
            this_input_ids_batches.append(input_ids_batches[i:i + batch_size])
        for i in range(0, len(first_subtokens_batches), batch_size):
            this_first_subtokens_batches.append(first_subtokens_batches[i:i + batch_size])
        for i in range(0, len(last_subtokens_batches), batch_size):
            this_last_subtokens_batches.append(last_subtokens_batches[i:i + batch_size])
        for i in range(0, len(label_batches), batch_size):
            this_label_batches.append(label_batches[i:i + batch_size])

        this_input_ids_batches, this_input_mask_batches, this_mask_batches \
            = self._pad_batches(this_input_ids_batches, this_first_subtokens_batches)

        return (this_input_ids_batches, this_input_mask_batches,
                this_first_subtokens_batches, this_last_subtokens_batches,
                this_label_batches, this_mask_batches)

    def to_batch(self, batch_size: int) -> Tuple:
        ret_list = []
        for sent_list in [self.train, self.dev, self.test]:
            ret_list.append(self.get_batches(sent_list, batch_size))
        return tuple(ret_list)

    def read_all_data(self, file_path: str, train_file: str, dev_file: str,
                      test_file: str) -> None:
        self.train = self._read_file(file_path + train_file)
        self.dev = self._read_file(file_path + dev_file, mode='dev')
        self.test = self._read_file(file_path + test_file, mode='test')
        self._gen_dic()

    def debug_single_sample(self, subtoken: List[int],
                            label_list: List[Tuple[int, int, int]]) -> None:
        print(" ".join(
            [self.subword_alphabet.get_instance(t) for t in subtoken]))
        for label in label_list:
            print(label[0], label[1],
                  self.label_alphabet.get_instance(label[2]))
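

# A minimal usage sketch of the Reader above: read an entity-annotated dataset
# and build padded, length-bucketed batches. The directory 'data/', the file
# names, the vocabulary name 'bert-base-uncased' and this helper's name are
# hypothetical; the files must follow the format expected by `_read_file`
# (token line, entity line, blank separator line).
def _reader_usage_example():
    reader = Reader(bert_model="bert-base-uncased")  # expects tokenization/bert-base-uncased.txt
    reader.read_all_data("data/", "train.txt", "dev.txt", "test.txt")
    train_batches, dev_batches, test_batches = reader.to_batch(batch_size=32)
    # Each element of to_batch() is the 6-tuple returned by get_batches():
    # (input_ids, input_mask, first_subtoken_idx, last_subtoken_idx, labels, word_mask)
    input_ids, input_mask, first_idx, last_idx, labels, word_mask = train_batches
    reader.debug_single_sample(input_ids[0][0], labels[0][0])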