Example #1
import os
import pickle


def load_metrics_obj():
    # Read the cached metrics pickle from results/, if present; return Nones otherwise.
    metrics_path = os.path.join('results', 'metrics.pickle')
    if not os.path.exists(metrics_path):
        return (None, None, None, None)

    with open(metrics_path, 'rb') as metrics_handle:
        metrics_obj = pickle.load(metrics_handle)
        
    return (metrics_obj['token_pairs'],
            metrics_obj['decoded_pairs'],
            metrics_obj['jaccard_similarities'],
            metrics_obj['levenshtein_distances'])
        
token_pairs, decoded_pairs, jaccard_similarities, levenshtein_distances = load_metrics_obj()

if not token_pairs:
    token_pairs = [([tokenizer.id_to_token(x) for x in ocr_tokens[i]], [tokenizer.id_to_token(x) for x in gs_tokens[i]]) for i in range(len(ocr_tokens))]
    save_metrics_obj(token_pairs, decoded_pairs, jaccard_similarities, levenshtein_distances)
    
if not decoded_pairs:
    decoded_pairs = [(tokenizer.decode(ocr_tokens[i]), tokenizer.decode(gs_tokens[i])) for i in range(len(ocr_tokens))]
    save_metrics_obj(token_pairs, decoded_pairs, jaccard_similarities, levenshtein_distances)
    
all_pairs = len(token_pairs)
if not jaccard_similarities:
    jaccard_similarities = []
    for i, token_pair in enumerate(token_pairs):
        jaccard_similarities.append(calculate_jaccard_similarity(token_pair[0], token_pair[1]))
    
    save_metrics_obj(token_pairs, decoded_pairs, jaccard_similarities, levenshtein_distances)
    
if not levenshtein_distances:
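The excerpt stops before the Levenshtein branch, and save_metrics_obj is called but never shown. Below is a minimal sketch of those missing pieces, not taken from the original code: a save helper that mirrors the pickle layout read by load_metrics_obj, plus a hypothetical calculate_levenshtein_distance helper analogous to calculate_jaccard_similarity.

import os
import pickle


def save_metrics_obj(token_pairs, decoded_pairs, jaccard_similarities,
                     levenshtein_distances):
    # Persist the four metric collections under the same keys that load_metrics_obj reads.
    os.makedirs('results', exist_ok=True)
    metrics_obj = {
        'token_pairs': token_pairs,
        'decoded_pairs': decoded_pairs,
        'jaccard_similarities': jaccard_similarities,
        'levenshtein_distances': levenshtein_distances,
    }
    with open(os.path.join('results', 'metrics.pickle'), 'wb') as metrics_handle:
        pickle.dump(metrics_obj, metrics_handle)


def calculate_levenshtein_distance(source, target):
    # Classic dynamic-programming edit distance between two strings (hypothetical helper).
    previous = list(range(len(target) + 1))
    for i, source_char in enumerate(source, start=1):
        current = [i]
        for j, target_char in enumerate(target, start=1):
            cost = 0 if source_char == target_char else 1
            current.append(min(previous[j] + 1,          # deletion
                               current[j - 1] + 1,       # insertion
                               previous[j - 1] + cost))  # substitution
        previous = current
    return previous[-1]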
Example #2
class Tokenizer(object):
    def __init__(self, vocab_path, do_lower_case=True):
        if BertWordPieceTokenizer:
            self.tokenizer = BertWordPieceTokenizer(
                vocab_path,
                lowercase=do_lower_case,
            )
        else:
            self.tokenizer = tokenization.FullTokenizer(
                vocab_path, do_lower_case=do_lower_case)
        self._do_lower_case = do_lower_case

    def tokenize(self, input_text):
        if BertWordPieceTokenizer:
            return self.tokenizer.encode(input_text,
                                         add_special_tokens=False).tokens
        else:
            return self.tokenizer.tokenize(input_text)

    def encode(self, input_text, add_special_tokens=False):
        input_tokens = self.tokenize(input_text)
        if add_special_tokens:
            input_tokens = ['[CLS]'] + input_tokens + ['[SEP]']
        input_token_ids = self.convert_tokens_to_ids(input_tokens)
        return input_token_ids

    def padded_to_ids(self, input_text, max_length):
        if len(input_text) > max_length:
            return input_text[:max_length]
        else:
            return input_text + [0] * (max_length - len(input_text))

    def convert_tokens_to_ids(self, input_tokens):
        if BertWordPieceTokenizer:
            token_ids = [
                self.tokenizer.token_to_id(token) for token in input_tokens
            ]
        else:
            token_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)
        return token_ids

    def customize_tokenize(self, input_text):
        temp_x = ""
        for c in input_text:
            if self._is_cjk_character(c) or self._is_punctuation(
                    c) or self._is_space(c) or self._is_control(c):
                temp_x += " " + c + " "
            else:
                temp_x += c
        return temp_x.split()

    def convert_ids_to_tokens(self, input_ids):
        if BertWordPieceTokenizer:
            input_tokens = [
                self.tokenizer.id_to_token(ids) for ids in input_ids
            ]
        else:
            input_tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        return input_tokens

    def decode(self, input_tokens):
        text = ''
        for i, token in enumerate(input_tokens):
            if token[:2] == '##':
                text += token[2:]
            elif len(token) == 1 and self._is_cjk_character(token):
                text += token
            elif len(token) == 1 and self._is_punctuation(token):
                text += token
                text += ' '
            elif i > 0 and self._is_cjk_character(text[-1]):
                text += token
            else:
                text += ' '
                text += token
        text = re.sub(r' +', ' ', text)
        text = re.sub(r"' (re|m|s|t|ve|d|ll) ", r"'\1 ", text)
        punctuation = self._cjk_punctuation() + '+-/={(<['
        punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
        punctuation_regex = '(%s) ' % punctuation_regex
        text = re.sub(punctuation_regex, r'\1', text)
        text = re.sub(r'(\d\.) (\d)', r'\1\2', text)

        return text.strip()

    @staticmethod
    def stem(token):
        """
    """
        if token[:2] == '##':
            return token[2:]
        else:
            return token

    @staticmethod
    def _is_space(ch):
        """
    """
        return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
            unicodedata.category(ch) == 'Zs'

    @staticmethod
    def _is_punctuation(ch):
        """
    """
        code = ord(ch)
        return 33 <= code <= 47 or \
            58 <= code <= 64 or \
            91 <= code <= 96 or \
            123 <= code <= 126 or \
            unicodedata.category(ch).startswith('P')

    @staticmethod
    def _cjk_punctuation():
        return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'

    @staticmethod
    def _is_cjk_character(ch):
        """
    reference:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    """
        code = ord(ch)
        return 0x4E00 <= code <= 0x9FFF or \
            0x3400 <= code <= 0x4DBF or \
            0x20000 <= code <= 0x2A6DF or \
            0x2A700 <= code <= 0x2B73F or \
            0x2B740 <= code <= 0x2B81F or \
            0x2B820 <= code <= 0x2CEAF or \
            0xF900 <= code <= 0xFAFF or \
            0x2F800 <= code <= 0x2FA1F

    @staticmethod
    def _is_control(ch):
        """
    """
        return unicodedata.category(ch) in ('Cc', 'Cf')

    @staticmethod
    def _is_special(ch):
        """
    """
        return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')

    def rematch(self, text, tokens):
        if is_py2:
            text = unicode(text)

        if self._do_lower_case:
            text = text.lower()

        normalized_text, char_mapping = '', []
        for i, ch in enumerate(text):
            if self._do_lower_case:
                ch = unicodedata.normalize('NFD', ch)
                ch = ''.join(
                    [c for c in ch if unicodedata.category(c) != 'Mn'])
            ch = ''.join([
                c for c in ch
                if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
            ])
            normalized_text += ch
            char_mapping.extend([i] * len(ch))

        text, token_mapping, offset = normalized_text, [], 0
        for token in tokens:
            if self._is_special(token):
                token_mapping.append([])
            else:
                token = self.stem(token)
                start = text[offset:].index(token) + offset
                end = start + len(token)
                token_mapping.append(char_mapping[start:end])
                offset = end

        return token_mapping
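A short usage sketch for the Tokenizer class above; it is not part of the original example. It assumes a hypothetical WordPiece vocabulary file vocab.txt, that the tokenizers package is installed so the BertWordPieceTokenizer branch is taken, and that the module-level names the excerpt relies on (re, unicodedata, BertWordPieceTokenizer, is_py2) are in scope as in the original project.

tokenizer = Tokenizer('vocab.txt', do_lower_case=True)

text = 'OCR post-correction works.'
tokens = tokenizer.tokenize(text)                   # WordPiece tokens, no [CLS]/[SEP]
token_ids = tokenizer.encode(text, add_special_tokens=True)
padded_ids = tokenizer.padded_to_ids(token_ids, max_length=16)  # right-pad with 0 (or truncate) to 16 ids

# decode() re-attaches '##' continuations and restores spacing around punctuation
print(tokenizer.decode(tokens))

# rematch() maps each token to the character indices it covers in the original text
print(tokenizer.rematch(text, tokens))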
Example #3
class Reader(object):
    def __init__(self,
                 bert_model: str,
                 tokenizer: BaseTokenizer = None,
                 cls: str = "[CLS]",
                 sep: str = "[SEP]",
                 threshold=6):

        self.tokenizer: BaseTokenizer = tokenizer
        self.cls = cls
        self.sep = sep
        if self.tokenizer is None:
            vocab_path: str = "tokenization/" + bert_model + ".txt"
            self.tokenizer = BertWordPieceTokenizer(
                vocab_path, lowercase="-cased" not in bert_model)

        self.threshold = threshold
        self.subword_alphabet: Optional[Alphabet] = None
        self.label_alphabet: Optional[Alphabet] = None

        self.train: Optional[List[SentInst]] = None
        self.dev: Optional[List[SentInst]] = None
        self.test: Optional[List[SentInst]] = None

    def _read_file(self, filename: str, mode: str = 'train') -> List[SentInst]:
        sent_list = []
        max_len = 0
        num_thresh = 0
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line == "":  # last few blank lines
                    break

                raw_tokens = line.split(' ')
                tokens = raw_tokens
                chars = [list(t) for t in raw_tokens]

                entities = next(f).strip()
                if entities == "":  # no entities
                    sent_inst = SentInst(tokens, chars, [])
                else:
                    entity_list = []
                    entities = entities.split("|")
                    for item in entities:
                        pointers, label = item.split()
                        pointers = pointers.split(",")
                        if int(pointers[1]) > len(tokens):
                            pdb.set_trace()
                        span_len = int(pointers[1]) - int(pointers[0])
                        if span_len < 0:
                            print("Warning! span_len < 0")
                            continue
                        if span_len > max_len:
                            max_len = span_len
                        if span_len > self.threshold:
                            num_thresh += 1

                        new_entity = (int(pointers[0]), int(pointers[1]),
                                      label)
                        # may be duplicate entities in some datasets
                        if (mode == 'train' and new_entity
                                not in entity_list) or (mode != 'train'):
                            entity_list.append(new_entity)

                    # assert len(entity_list) == len(set(entity_list)) # check duplicate
                    sent_inst = SentInst(tokens, chars, entity_list)
                assert next(f).strip() == ""  # separating line

                sent_list.append(sent_inst)
        print("Max length: {}".format(max_len))
        print("Threshold {}: {}".format(self.threshold, num_thresh))
        return sent_list

    def _gen_dic(self) -> None:
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        vocab = [
            self.tokenizer.id_to_token(idx)
            for idx in range(self.tokenizer.get_vocab_size())
        ]
        self.subword_alphabet = Alphabet(vocab, 0)
        self.label_alphabet = Alphabet(label_set, 0)

    @staticmethod
    def _pad_batches(input_ids_batches: List[List[List[int]]],
                     first_subtokens_batches: List[List[List[int]]]) \
            -> Tuple[List[List[List[int]]],
                     List[List[List[int]]],
                     List[List[List[bool]]]]:

        padded_input_ids_batches = []
        input_mask_batches = []
        mask_batches = []

        all_batches = list(zip(input_ids_batches, first_subtokens_batches))
        for input_ids_batch, first_subtokens_batch in all_batches:

            batch_len = len(input_ids_batch)
            max_subtokens_num = max(
                [len(input_ids) for input_ids in input_ids_batch])
            max_sent_len = max([
                len(first_subtokens)
                for first_subtokens in first_subtokens_batch
            ])

            padded_input_ids_batch = []
            input_mask_batch = []
            mask_batch = []

            for i in range(batch_len):

                subtokens_num = len(input_ids_batch[i])
                sent_len = len(first_subtokens_batch[i])

                padded_subtoken_vec = input_ids_batch[i].copy()
                padded_subtoken_vec.extend([0] *
                                           (max_subtokens_num - subtokens_num))
                input_mask = [1] * subtokens_num + [0] * (max_subtokens_num -
                                                          subtokens_num)
                mask = [True] * sent_len + [False] * (max_sent_len - sent_len)

                padded_input_ids_batch.append(padded_subtoken_vec)
                input_mask_batch.append(input_mask)
                mask_batch.append(mask)

            padded_input_ids_batches.append(padded_input_ids_batch)
            input_mask_batches.append(input_mask_batch)
            mask_batches.append(mask_batch)

        return padded_input_ids_batches, input_mask_batches, mask_batches

    def get_batches(self, sentences: List[SentInst], batch_size: int) -> Tuple:
        subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        first_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        last_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        label_dic_dic = defaultdict(lambda: defaultdict(list))

        this_input_ids_batches = []
        this_first_subtokens_batches = []
        this_last_subtokens_batches = []
        this_label_batches = []

        for sentInst in sentences:
            subtoken_vec = []
            first_subtoken_vec = []
            last_subtoken_vec = []
            subtoken_vec.append(self.tokenizer.token_to_id(self.cls))
            for t in sentInst.tokens:
                encoding = self.tokenizer.encode(t)
                ids = [
                    v for v, mask in zip(encoding.ids,
                                         encoding.special_tokens_mask)
                    if mask == 0
                ]
                first_subtoken_vec.append(len(subtoken_vec))
                subtoken_vec.extend(ids)
                last_subtoken_vec.append(len(subtoken_vec))
            subtoken_vec.append(self.tokenizer.token_to_id(self.sep))

            label_list = [(u[0], u[1], self.label_alphabet.get_index(u[2]))
                          for u in sentInst.entities]

            subtoken_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(subtoken_vec)
            first_subtoken_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(first_subtoken_vec)
            last_subtoken_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(last_subtoken_vec)
            label_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(label_list)

        input_ids_batches = []
        first_subtokens_batches = []
        last_subtokens_batches = []
        label_batches = []
        for length1 in sorted(subtoken_dic_dic.keys(), reverse=True):
            for length2 in sorted(subtoken_dic_dic[length1].keys(),
                                  reverse=True):
                input_ids_batches.extend(subtoken_dic_dic[length1][length2])
                first_subtokens_batches.extend(
                    first_subtoken_dic_dic[length1][length2])
                last_subtokens_batches.extend(
                    last_subtoken_dic_dic[length1][length2])
                label_batches.extend(label_dic_dic[length1][length2])

        for i in range(0, len(input_ids_batches), batch_size):
            this_input_ids_batches.append(input_ids_batches[i:i + batch_size])
        for i in range(0, len(first_subtokens_batches), batch_size):
            this_first_subtokens_batches.append(
                first_subtokens_batches[i:i + batch_size])
        for i in range(0, len(last_subtokens_batches), batch_size):
            this_last_subtokens_batches.append(
                last_subtokens_batches[i:i + batch_size])
        for i in range(0, len(label_batches), batch_size):
            this_label_batches.append(label_batches[i:i + batch_size])

        this_input_ids_batches, this_input_mask_batches, this_mask_batches \
            = self._pad_batches(this_input_ids_batches, this_first_subtokens_batches)

        return (this_input_ids_batches, this_input_mask_batches,
                this_first_subtokens_batches, this_last_subtokens_batches,
                this_label_batches, this_mask_batches)

    def to_batch(self, batch_size: int) -> Tuple:
        ret_list = []
        for sent_list in [self.train, self.dev, self.test]:
            ret_list.append(self.get_batches(sent_list, batch_size))
        return tuple(ret_list)

    def read_all_data(self, file_path: str, train_file: str, dev_file: str,
                      test_file: str) -> None:
        self.train = self._read_file(file_path + train_file)
        self.dev = self._read_file(file_path + dev_file, mode='dev')
        self.test = self._read_file(file_path + test_file, mode='test')
        self._gen_dic()

    def debug_single_sample(self, subtoken: List[int],
                            label_list: List[Tuple[int, int, int]]) -> None:
        print(" ".join(
            [self.subword_alphabet.get_instance(t) for t in subtoken]))
        for label in label_list:
            print(label[0], label[1],
                  self.label_alphabet.get_instance(label[2]))
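A hedged usage sketch for the Reader class; none of it comes from the original source. The bert_model name, directory, and file names are hypothetical, the vocabulary is expected at tokenization/<bert_model>.txt as in __init__, and the data files are assumed to follow the three-line format _read_file expects (a token line, an entity line such as "0,2 PER|3,5 ORG" or an empty line, then a blank separator line).

reader = Reader(bert_model='bert-base-cased')   # reads tokenization/bert-base-cased.txt
reader.read_all_data('data/genia/', 'train.txt', 'dev.txt', 'test.txt')

train_batches, dev_batches, test_batches = reader.to_batch(batch_size=32)

# Each of the three entries is the 6-tuple returned by get_batches():
# (input_ids, input_mask, first_subtoken_positions, last_subtoken_positions, labels, mask)
input_ids, input_mask, first_pos, last_pos, labels, mask = train_batches

# Print the subtokens and labels of the first sentence in the first batch.
reader.debug_single_sample(input_ids[0][0], labels[0][0])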