Example 1
    def gen_dic(self) -> None:
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        self.sub_word_alphabet = Alphabet(self.bert_tokenizer.vocab, 0)
        self.label_alphabet = Alphabet(label_set, 0)
Example 2
    def _gen_dic(self) -> None:
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        vocab = [
            self.tokenizer.id_to_token(idx)
            for idx in range(self.tokenizer.get_vocab_size())
        ]
        self.subword_alphabet = Alphabet(vocab, 0)
        self.label_alphabet = Alphabet(label_set, 0)
Example 3
    def _gen_dic(self) -> None:
        word_set = set()
        char_set = set()
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                if sent_list is self.train:
                    for token in sentInst.chars:
                        for char in token:
                            char_set.add(char)
                    for token in sentInst.tokens:
                        if self.lowercase:
                            token = token.lower()
                        if token not in self.vocab2id:
                            word_set.add(token)
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        self.word_iv_alphabet = Alphabet(self.vocab2id,
                                         len(PREDIFINE_TOKEN_IDS))
        self.word_ooev_alphabet = Alphabet(word_set, len(PREDIFINE_TOKEN_IDS))
        self.char_alphabet = Alphabet(char_set, len(PREDIFINE_CHAR_IDS))
        self.label_alphabet = Alphabet(label_set, 0)
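The `Alphabet` helper used throughout these examples is never shown. The sketch below is a hypothetical reconstruction, inferred only from how it is used above: it is built from an iterable of instances plus an integer offset that reserves low ids for predefined symbols, and it exposes `get_index` and `get_instance`. It is an assumption, not the project's actual implementation.

from typing import Dict, Iterable, List


class Alphabet:
    # Hypothetical reconstruction of the Alphabet helper, inferred only from
    # its usage in the examples above; the real implementation may differ.
    # Maps instances to contiguous integer ids, reserving the first `offset`
    # ids for predefined symbols (padding, BOT/EOT, ...).
    def __init__(self, instances: Iterable, offset: int) -> None:
        self.offset = offset
        self.instances: List = []
        self.instance2index: Dict = {}
        for instance in instances:
            if instance not in self.instance2index:
                self.instance2index[instance] = offset + len(self.instances)
                self.instances.append(instance)

    def get_index(self, instance) -> int:
        return self.instance2index[instance]

    def get_instance(self, index: int):
        return self.instances[index - self.offset]

    def size(self) -> int:
        return self.offset + len(self.instances)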
Example 4
class Reader:
    def __init__(self) -> None:

        self.vocab2id: Dict[str, int] = {}
        self.lowercase: Optional[bool] = None

        self.word_iv_alphabet: Optional[Alphabet] = None
        self.word_ooev_alphabet: Optional[Alphabet] = None
        self.char_alphabet: Optional[Alphabet] = None
        self.label_alphabet: Optional[Alphabet] = None

        self.train: Optional[List[SentInst]] = None
        self.dev: Optional[List[SentInst]] = None
        self.test: Optional[List[SentInst]] = None

    def read_and_gen_vectors_glove(self, embed_path: str) -> None:
        token_embed = None
        ret_mat = []
        with open(GLOVE_FILE, 'r') as f:
            id = 0
            for line in f:
                s_s = line.split()
                if token_embed is None:
                    token_embed = len(s_s) - 1
                    ret_mat.append(np.zeros(token_embed).astype('float32'))
                else:
                    assert (token_embed + 1 == len(s_s))
                id += 1
                self.vocab2id[s_s[0]] = id
                ret_mat.append(np.array([float(x) for x in s_s[1:]]))

        self.lowercase = True

        ret_mat = np.array(ret_mat)
        with open(embed_path, 'wb') as f:
            pickle.dump(ret_mat, f)

    def read_and_gen_vectors_pubmed_word2vec(self, embed_path: str) -> None:
        ret_mat = []
        with open(PUBMED_WORD2VEC_FILE, 'rb') as f:
            line = f.readline().rstrip(b'\n')
            vsize, token_embed = line.split()
            vsize = int(vsize)
            token_embed = int(token_embed)
            id = 0
            ret_mat.append(np.zeros(token_embed).astype('float32'))

            for v in range(vsize):
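                # Each record: the word as raw bytes terminated by a space,
                # followed by `token_embed` float32 values; a record may start
                # with a stray b'\n', which is stripped below.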
                wchars = []
                while True:
                    c = f.read(1)
                    if c == b' ':
                        break
                    assert (c is not None)
                    wchars.append(c)
                word = b''.join(wchars)
                if word.startswith(b'\n'):
                    word = word[1:]
                id += 1
                self.vocab2id[word.decode('utf-8')] = id
                ret_mat.append(np.fromfile(f, np.float32, token_embed))
            assert (vsize + 1 == len(ret_mat))

        self.lowercase = False

        ret_mat = np.array(ret_mat)
        with open(embed_path, 'wb') as f:
            pickle.dump(ret_mat, f)

    @staticmethod
    def _read_file(filename: str, mode: str = 'train') -> List[SentInst]:
        sent_list = []
        max_len = 0
        num_thresh = 0
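        # Each sentence occupies three lines: space-separated tokens,
        # "start,end label" spans joined by '|' (empty when the sentence has
        # no entities), and a blank separator line.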
        with open(filename) as f:
            for line in f:
                line = line.strip()
                if line == "":  # last few blank lines
                    break

                raw_tokens = line.split(' ')
                tokens = raw_tokens
                chars = [list(t) for t in raw_tokens]

                entities = next(f).strip()
                if entities == "":  # no entities
                    sent_inst = SentInst(tokens, chars, [])
                else:
                    entity_list = []
                    entities = entities.split("|")
                    for item in entities:
                        pointers, label = item.split()
                        pointers = pointers.split(",")
                        if int(pointers[1]) > len(tokens):
                            pdb.set_trace()
                        span_len = int(pointers[1]) - int(pointers[0])
                        assert (span_len > 0)
                        if span_len > max_len:
                            max_len = span_len
                        if span_len > 6:
                            num_thresh += 1

                        new_entity = (int(pointers[0]), int(pointers[1]),
                                      label)
                        # may be duplicate entities in some datasets
                        if (mode == 'train' and new_entity
                                not in entity_list) or (mode != 'train'):
                            entity_list.append(new_entity)

                    # assert len(entity_list) == len(set(entity_list)) # check duplicate
                    sent_inst = SentInst(tokens, chars, entity_list)
                assert next(f).strip() == ""  # separating line

                sent_list.append(sent_inst)
        print("Max length: {}".format(max_len))
        print("Threshold 6: {}".format(num_thresh))
        return sent_list

    def _gen_dic(self) -> None:
        word_set = set()
        char_set = set()
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                if sent_list is self.train:
                    for token in sentInst.chars:
                        for char in token:
                            char_set.add(char)
                    for token in sentInst.tokens:
                        if self.lowercase:
                            token = token.lower()
                        if token not in self.vocab2id:
                            word_set.add(token)
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        self.word_iv_alphabet = Alphabet(self.vocab2id,
                                         len(PREDIFINE_TOKEN_IDS))
        self.word_ooev_alphabet = Alphabet(word_set, len(PREDIFINE_TOKEN_IDS))
        self.char_alphabet = Alphabet(char_set, len(PREDIFINE_CHAR_IDS))
        self.label_alphabet = Alphabet(label_set, 0)

    @staticmethod
    def _pad_batches(token_iv_batches: List[List[List[int]]],
                     token_ooev_batches: List[List[List[int]]],
                     char_batches: List[List[List[List[int]]]]) \
            -> Tuple[List[List[List[int]]],
                     List[List[List[int]]],
                     List[List[List[List[int]]]],
                     List[List[List[bool]]]]:

        default_token_id = PREDIFINE_TOKEN_IDS['DEFAULT']
        default_char_id = PREDIFINE_CHAR_IDS['DEFAULT']
        bot_id = PREDIFINE_CHAR_IDS['BOT']  # beginning of token
        eot_id = PREDIFINE_CHAR_IDS['EOT']  # end of token

        padded_token_iv_batches = []
        padded_token_ooev_batches = []
        padded_char_batches = []
        mask_batches = []

        all_batches = list(
            zip(token_iv_batches, token_ooev_batches, char_batches))
        for token_iv_batch, token_ooev_batch, char_batch in all_batches:

            batch_len = len(token_iv_batch)
            max_sent_len = len(token_iv_batch[0])
            max_char_len = max(
                [max([len(t) for t in char_mat]) for char_mat in char_batch])

            padded_token_iv_batch = []
            padded_token_ooev_batch = []
            padded_char_batch = []
            mask_batch = []

            for i in range(batch_len):

                sent_len = len(token_iv_batch[i])

                padded_token_iv_vec = token_iv_batch[i].copy()
                padded_token_iv_vec.extend([default_token_id] *
                                           (max_sent_len - sent_len))
                padded_token_ooev_vec = token_ooev_batch[i].copy()
                padded_token_ooev_vec.extend([default_token_id] *
                                             (max_sent_len - sent_len))
                padded_char_mat = []
                for t in char_batch[i]:
                    padded_t = list()
                    padded_t.append(bot_id)
                    padded_t.extend(t)
                    padded_t.append(eot_id)
                    padded_t.extend([default_char_id] *
                                    (max_char_len - len(t)))
                    padded_char_mat.append(padded_t)
                for _ in range(sent_len, max_sent_len):
                    padded_char_mat.append(
                        [default_char_id] *
                        (max_char_len + 2))  # max_len + bot + eot
                mask = [True] * sent_len + [False] * (max_sent_len - sent_len)

                padded_token_iv_batch.append(padded_token_iv_vec)
                padded_token_ooev_batch.append(padded_token_ooev_vec)
                padded_char_batch.append(padded_char_mat)
                mask_batch.append(mask)

            padded_token_iv_batches.append(padded_token_iv_batch)
            padded_token_ooev_batches.append(padded_token_ooev_batch)
            padded_char_batches.append(padded_char_batch)
            mask_batches.append(mask_batch)

        return padded_token_iv_batches, padded_token_ooev_batches, padded_char_batches, mask_batches

    def to_batch(self, batch_size: int) -> Tuple:
        ret_list = []

        for sent_list in [self.train, self.dev, self.test]:
            token_iv_dic = defaultdict(list)
            token_ooev_dic = defaultdict(list)
            char_dic = defaultdict(list)
            label_dic = defaultdict(list)

            this_token_iv_batches = []
            this_token_ooev_batches = []
            this_char_batches = []
            this_label_batches = []

            for sentInst in sent_list:

                token_iv_vec = []
                token_ooev_vec = []
                for t in sentInst.tokens:
                    if self.lowercase:
                        t = t.lower()
                    if t in self.vocab2id:
                        token_iv_vec.append(self.vocab2id[t])
                        token_ooev_vec.append(0)
                    else:
                        token_iv_vec.append(0)
                        token_ooev_vec.append(
                            self.word_ooev_alphabet.get_index(t))

                char_mat = [[self.char_alphabet.get_index(c) for c in t]
                            for t in sentInst.chars]
                # max_len = max([len(t) for t in sentInst.chars])
                # char_mat = [ t + [0] * (max_len - len(t)) for t in char_mat ]

                label_list = [(u[0], u[1], self.label_alphabet.get_index(u[2]))
                              for u in sentInst.entities]

                token_iv_dic[len(sentInst.tokens)].append(token_iv_vec)
                token_ooev_dic[len(sentInst.tokens)].append(token_ooev_vec)
                char_dic[len(sentInst.tokens)].append(char_mat)
                label_dic[len(sentInst.tokens)].append(label_list)

            token_iv_batches = []
            token_ooev_batches = []
            char_batches = []
            label_batches = []
            for length in sorted(token_iv_dic.keys(), reverse=True):
                token_iv_batches.extend(token_iv_dic[length])
                token_ooev_batches.extend(token_ooev_dic[length])
                char_batches.extend(char_dic[length])
                label_batches.extend(label_dic[length])

            # Split the length-sorted lists into batches of `batch_size`.
            for i in range(0, len(token_iv_batches), batch_size):
                this_token_iv_batches.append(
                    token_iv_batches[i:i + batch_size])
            for i in range(0, len(token_ooev_batches), batch_size):
                this_token_ooev_batches.append(
                    token_ooev_batches[i:i + batch_size])
            for i in range(0, len(char_batches), batch_size):
                this_char_batches.append(char_batches[i:i + batch_size])
            for i in range(0, len(label_batches), batch_size):
                this_label_batches.append(label_batches[i:i + batch_size])

            this_token_iv_batches, this_token_ooev_batches, this_char_batches, this_mask_batches \
                = self._pad_batches(this_token_iv_batches, this_token_ooev_batches, this_char_batches)

            ret_list.append(
                (this_token_iv_batches, this_token_ooev_batches,
                 this_char_batches, this_label_batches, this_mask_batches))

        return tuple(ret_list)

    def read_all_data(self, file_path: str, train_file: str, dev_file: str,
                      test_file: str) -> None:
        self.train = self._read_file(file_path + train_file)
        self.dev = self._read_file(file_path + dev_file, mode='dev')
        self.test = self._read_file(file_path + test_file, mode='test')
        self._gen_dic()

    def debug_single_sample(self, token_v: List[int],
                            char_mat: List[List[int]], char_len_vec: List[int],
                            label_list: List[Tuple[int, int, int]]) -> None:
        print(" ".join(
            [self.word_ooev_alphabet.get_instance(t) for t in token_v]))
        for t in char_mat:
            print(" ".join([self.char_alphabet.get_instance(c) for c in t]))
        print(char_len_vec)
        for label in label_list:
            print(label[0], label[1],
                  self.label_alphabet.get_instance(label[2]))
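A minimal usage sketch for this GloVe-based variant, driving the class in the order its methods require (embeddings first, so `vocab2id` and `lowercase` are set before `_gen_dic` runs). The paths, file names, and batch size are placeholders, not values taken from the original project.

# Hypothetical driver; all paths, file names, and the batch size are placeholders.
reader = Reader()
reader.read_and_gen_vectors_glove("glove.pkl")  # fills vocab2id, pickles the matrix
reader.read_all_data("data/", "train.txt", "dev.txt", "test.txt")

train_batches, dev_batches, test_batches = reader.to_batch(batch_size=32)
token_iv, token_ooev, chars, labels, masks = train_batches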
Example 5
class Reader:
    def __init__(self, bert_model: str) -> None:

        self.bert_tokenizer: BertTokenizer = BertTokenizer.from_pretrained(
            bert_model, do_lower_case='-cased' not in bert_model)

        self.sub_word_alphabet: Optional[Alphabet] = None
        self.label_alphabet: Optional[Alphabet] = None

        self.train: Optional[List[SentInst]] = None
        self.dev: Optional[List[SentInst]] = None
        self.test: Optional[List[SentInst]] = None

    @staticmethod
    def read_file(filename: str, mode: str = 'train') -> List[SentInst]:
        sent_list = []
        max_len = 0
        num_thresh = 0
        with open(filename) as f:
            for line in f:
                line = line.strip()
                if line == "":  # last few blank lines
                    break

                raw_tokens = line.split()
                tokens = raw_tokens
                chars = [list(t) for t in raw_tokens]

                entities = next(f).strip()
                if entities == "":  # no entities
                    sent_inst = SentInst(tokens, chars, [])
                else:
                    entity_list = []
                    entities = entities.split("|")
                    for item in entities:
                        pointers, label = item.split()
                        pointers = pointers.split(",")
                        if int(pointers[1]) > len(tokens):
                            pdb.set_trace()
                        span_len = int(pointers[1]) - int(pointers[0])
                        assert (span_len > 0)
                        if span_len > max_len:
                            max_len = span_len
                        if span_len > 6:
                            num_thresh += 1

                        new_entity = (int(pointers[0]), int(pointers[1]),
                                      label)
                        # may be duplicate entities in some datasets
                        if (mode == 'train' and new_entity
                                not in entity_list) or (mode != 'train'):
                            entity_list.append(new_entity)

                    # assert len(entity_list) == len(set(entity_list)) # check duplicate
                    sent_inst = SentInst(tokens, chars, entity_list)
                assert next(f).strip() == ""  # separating line

                sent_list.append(sent_inst)
        print("Max length: {}".format(max_len))
        print("Threshold 6: {}".format(num_thresh))
        return sent_list

    def gen_dic(self) -> None:
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        self.sub_word_alphabet = Alphabet(self.bert_tokenizer.vocab, 0)
        self.label_alphabet = Alphabet(label_set, 0)

    @staticmethod
    def pad_batches(input_ids_batches: List[List[List[int]]],
                    first_sub_tokens_batches: List[List[List[int]]]) \
            -> Tuple[List[List[List[int]]],
                     List[List[List[int]]],
                     List[List[List[bool]]]]:

        padded_input_ids_batches = []
        input_mask_batches = []
        mask_batches = []

        all_batches = list(zip(input_ids_batches, first_sub_tokens_batches))
        for input_ids_batch, first_sub_tokens_batch in all_batches:

            batch_len = len(input_ids_batch)
            max_sub_tokens_num = max(
                [len(input_ids) for input_ids in input_ids_batch])
            max_sent_len = max([
                len(first_sub_tokens)
                for first_sub_tokens in first_sub_tokens_batch
            ])

            padded_input_ids_batch = []
            input_mask_batch = []
            mask_batch = []

            for i in range(batch_len):

                sub_tokens_num = len(input_ids_batch[i])
                sent_len = len(first_sub_tokens_batch[i])

                padded_sub_token_vec = input_ids_batch[i].copy()
                padded_sub_token_vec.extend(
                    [0] * (max_sub_tokens_num - sub_tokens_num))
                input_mask = [1] * sub_tokens_num + [0] * (max_sub_tokens_num -
                                                           sub_tokens_num)
                mask = [True] * sent_len + [False] * (max_sent_len - sent_len)

                padded_input_ids_batch.append(padded_sub_token_vec)
                input_mask_batch.append(input_mask)
                mask_batch.append(mask)

            padded_input_ids_batches.append(padded_input_ids_batch)
            input_mask_batches.append(input_mask_batch)
            mask_batches.append(mask_batch)

        return padded_input_ids_batches, input_mask_batches, mask_batches

    def to_batch(self, batch_size: int) -> Tuple:
        bert_tokenizer = self.bert_tokenizer

        ret_list = []

        for sent_list in [self.train, self.dev, self.test]:
            sub_token_dic_dic = defaultdict(lambda: defaultdict(list))
            first_sub_token_dic_dic = defaultdict(lambda: defaultdict(list))
            label_dic_dic = defaultdict(lambda: defaultdict(list))

            this_input_ids_batches = []
            this_first_sub_tokens_batches = []
            this_label_batches = []

            for sentInst in sent_list:

                sub_token_vec = []
                first_sub_token_vec = []
                sub_token_vec.extend(
                    bert_tokenizer.convert_tokens_to_ids([CLS]))
                for t in sentInst.tokens:
                    st = bert_tokenizer.tokenize(t)
                    first_sub_token_vec.append(len(sub_token_vec))
                    sub_token_vec.extend(
                        bert_tokenizer.convert_tokens_to_ids(st))
                sub_token_vec.extend(
                    bert_tokenizer.convert_tokens_to_ids([SEP]))

                label_list = [(u[0], u[1], self.label_alphabet.get_index(u[2]))
                              for u in sentInst.entities]

                sub_token_dic_dic[len(
                    sentInst.tokens)][len(sub_token_vec)].append(sub_token_vec)
                first_sub_token_dic_dic[len(sentInst.tokens)][len(
                    sub_token_vec)].append(first_sub_token_vec)
                label_dic_dic[len(
                    sentInst.tokens)][len(sub_token_vec)].append(label_list)

            input_ids_batches = []
            first_sub_tokens_batches = []
            label_batches = []
            for length1 in sorted(sub_token_dic_dic.keys(), reverse=True):
                for length2 in sorted(sub_token_dic_dic[length1].keys(),
                                      reverse=True):
                    input_ids_batches.extend(
                        sub_token_dic_dic[length1][length2])
                    first_sub_tokens_batches.extend(
                        first_sub_token_dic_dic[length1][length2])
                    label_batches.extend(label_dic_dic[length1][length2])

            # Split the length-sorted lists into batches of `batch_size`.
            for i in range(0, len(input_ids_batches), batch_size):
                this_input_ids_batches.append(
                    input_ids_batches[i:i + batch_size])
            for i in range(0, len(first_sub_tokens_batches), batch_size):
                this_first_sub_tokens_batches.append(
                    first_sub_tokens_batches[i:i + batch_size])
            for i in range(0, len(label_batches), batch_size):
                this_label_batches.append(label_batches[i:i + batch_size])

            this_input_ids_batches, this_input_mask_batches, this_mask_batches \
                = self.pad_batches(this_input_ids_batches, this_first_sub_tokens_batches)

            ret_list.append((this_input_ids_batches, this_input_mask_batches,
                             this_first_sub_tokens_batches, this_label_batches,
                             this_mask_batches))

        return tuple(ret_list)

    def read_all_data(self, file_path: str, train_file: str, dev_file: str,
                      test_file: str) -> None:
        self.train = self.read_file(file_path + train_file)
        self.dev = self.read_file(file_path + dev_file, mode='dev')
        self.test = self.read_file(file_path + test_file, mode='test')
        self.gen_dic()

    def debug_single_sample(self, sub_token: List[int],
                            label_list: List[Tuple[int, int, int]]) -> None:
        print(" ".join(
            [self.sub_word_alphabet.get_instance(t) for t in sub_token]))
        for label in label_list:
            print(label[0], label[1],
                  self.label_alphabet.get_instance(label[2]))
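A usage sketch for this BertTokenizer-based variant, assuming `CLS` and `SEP` are the usual "[CLS]" and "[SEP]" string constants defined elsewhere in the module. The model name, paths, and batch size are placeholders.

# Hypothetical driver; model name, paths, and batch size are placeholders.
reader = Reader("bert-base-cased")
reader.read_all_data("data/", "train.txt", "dev.txt", "test.txt")

train_batches, dev_batches, test_batches = reader.to_batch(batch_size=16)
input_ids, input_mask, first_sub_tokens, labels, masks = train_batches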
Example 6
class Reader(object):
    def __init__(self,
                 bert_model: str,
                 tokenizer: Optional[BaseTokenizer] = None,
                 cls: str = "[CLS]",
                 sep: str = "[SEP]",
                 threshold: int = 6):

        self.tokenizer: BaseTokenizer = tokenizer
        self.cls = cls
        self.sep = sep
        if self.tokenizer is None:
            vocab_path: str = "tokenization/" + bert_model + ".txt"
            self.tokenizer = BertWordPieceTokenizer(vocab_path,
                                                    lowercase="-cased"
                                                    not in bert_model)

        self.threshold = threshold
        self.subword_alphabet: Optional[Alphabet] = None
        self.label_alphabet: Optional[Alphabet] = None

        self.train: Optional[List[SentInst]] = None
        self.dev: Optional[List[SentInst]] = None
        self.test: Optional[List[SentInst]] = None

    def _read_file(self, filename: str, mode: str = 'train') -> List[SentInst]:
        sent_list = []
        max_len = 0
        num_thresh = 0
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line == "":  # last few blank lines
                    break

                raw_tokens = line.split(' ')
                tokens = raw_tokens
                chars = [list(t) for t in raw_tokens]

                entities = next(f).strip()
                if entities == "":  # no entities
                    sent_inst = SentInst(tokens, chars, [])
                else:
                    entity_list = []
                    entities = entities.split("|")
                    for item in entities:
                        pointers, label = item.split()
                        pointers = pointers.split(",")
                        if int(pointers[1]) > len(tokens):
                            pdb.set_trace()
                        span_len = int(pointers[1]) - int(pointers[0])
                        if span_len < 0:
                            print("Warning! span_len < 0")
                            continue
                        if span_len > max_len:
                            max_len = span_len
                        if span_len > self.threshold:
                            num_thresh += 1

                        new_entity = (int(pointers[0]), int(pointers[1]),
                                      label)
                        # may be duplicate entities in some datasets
                        if (mode == 'train' and new_entity
                                not in entity_list) or (mode != 'train'):
                            entity_list.append(new_entity)

                    # assert len(entity_list) == len(set(entity_list)) # check duplicate
                    sent_inst = SentInst(tokens, chars, entity_list)
                assert next(f).strip() == ""  # separating line

                sent_list.append(sent_inst)
        print("Max length: {}".format(max_len))
        print("Threshold {}: {}".format(self.threshold, num_thresh))
        return sent_list

    def _gen_dic(self) -> None:
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        vocab = [
            self.tokenizer.id_to_token(idx)
            for idx in range(self.tokenizer.get_vocab_size())
        ]
        self.subword_alphabet = Alphabet(vocab, 0)
        self.label_alphabet = Alphabet(label_set, 0)

    @staticmethod
    def _pad_batches(input_ids_batches: List[List[List[int]]],
                     first_subtokens_batches: List[List[List[int]]]) \
            -> Tuple[List[List[List[int]]],
                     List[List[List[int]]],
                     List[List[List[bool]]]]:

        padded_input_ids_batches = []
        input_mask_batches = []
        mask_batches = []

        all_batches = list(zip(input_ids_batches, first_subtokens_batches))
        for input_ids_batch, first_subtokens_batch in all_batches:

            batch_len = len(input_ids_batch)
            max_subtokens_num = max(
                [len(input_ids) for input_ids in input_ids_batch])
            max_sent_len = max([
                len(first_subtokens)
                for first_subtokens in first_subtokens_batch
            ])

            padded_input_ids_batch = []
            input_mask_batch = []
            mask_batch = []

            for i in range(batch_len):

                subtokens_num = len(input_ids_batch[i])
                sent_len = len(first_subtokens_batch[i])

                padded_subtoken_vec = input_ids_batch[i].copy()
                padded_subtoken_vec.extend([0] *
                                           (max_subtokens_num - subtokens_num))
                input_mask = [1] * subtokens_num + [0] * (max_subtokens_num -
                                                          subtokens_num)
                mask = [True] * sent_len + [False] * (max_sent_len - sent_len)

                padded_input_ids_batch.append(padded_subtoken_vec)
                input_mask_batch.append(input_mask)
                mask_batch.append(mask)

            padded_input_ids_batches.append(padded_input_ids_batch)
            input_mask_batches.append(input_mask_batch)
            mask_batches.append(mask_batch)

        return padded_input_ids_batches, input_mask_batches, mask_batches

    def get_batches(self, sentences: List[SentInst], batch_size: int) -> Tuple:
        subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        first_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        last_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        label_dic_dic = defaultdict(lambda: defaultdict(list))

        this_input_ids_batches = []
        this_first_subtokens_batches = []
        this_last_subtokens_batches = []
        this_label_batches = []

        for sentInst in sentences:
            subtoken_vec = []
            first_subtoken_vec = []
            last_subtoken_vec = []
            subtoken_vec.append(self.tokenizer.token_to_id(self.cls))
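            # For each original token, record the index of its first WordPiece
            # in subtoken_vec (offset by the leading CLS id) and the index one
            # past its last WordPiece; encode() may add special tokens, which
            # are filtered out via special_tokens_mask.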
            for t in sentInst.tokens:
                encoding = self.tokenizer.encode(t)
                ids = [
                    v for v, mask in zip(encoding.ids,
                                         encoding.special_tokens_mask)
                    if mask == 0
                ]
                first_subtoken_vec.append(len(subtoken_vec))
                subtoken_vec.extend(ids)
                last_subtoken_vec.append(len(subtoken_vec))
            subtoken_vec.append(self.tokenizer.token_to_id(self.sep))

            label_list = [(u[0], u[1], self.label_alphabet.get_index(u[2]))
                          for u in sentInst.entities]

            subtoken_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(subtoken_vec)
            first_subtoken_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(first_subtoken_vec)
            last_subtoken_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(last_subtoken_vec)
            label_dic_dic[len(
                sentInst.tokens)][len(subtoken_vec)].append(label_list)

        input_ids_batches = []
        first_subtokens_batches = []
        last_subtokens_batches = []
        label_batches = []
        for length1 in sorted(subtoken_dic_dic.keys(), reverse=True):
            for length2 in sorted(subtoken_dic_dic[length1].keys(),
                                  reverse=True):
                input_ids_batches.extend(subtoken_dic_dic[length1][length2])
                first_subtokens_batches.extend(
                    first_subtoken_dic_dic[length1][length2])
                last_subtokens_batches.extend(
                    last_subtoken_dic_dic[length1][length2])
                label_batches.extend(label_dic_dic[length1][length2])

        # Split the length-sorted lists into batches of `batch_size`.
        for i in range(0, len(input_ids_batches), batch_size):
            this_input_ids_batches.append(input_ids_batches[i:i + batch_size])
        for i in range(0, len(first_subtokens_batches), batch_size):
            this_first_subtokens_batches.append(
                first_subtokens_batches[i:i + batch_size])
        for i in range(0, len(last_subtokens_batches), batch_size):
            this_last_subtokens_batches.append(
                last_subtokens_batches[i:i + batch_size])
        for i in range(0, len(label_batches), batch_size):
            this_label_batches.append(label_batches[i:i + batch_size])

        this_input_ids_batches, this_input_mask_batches, this_mask_batches \
            = self._pad_batches(this_input_ids_batches, this_first_subtokens_batches)

        return (this_input_ids_batches, this_input_mask_batches,
                this_first_subtokens_batches, this_last_subtokens_batches,
                this_label_batches, this_mask_batches)

    def to_batch(self, batch_size: int) -> Tuple:
        ret_list = []
        for sent_list in [self.train, self.dev, self.test]:
            ret_list.append(self.get_batches(sent_list, batch_size))
        return tuple(ret_list)

    def read_all_data(self, file_path: str, train_file: str, dev_file: str,
                      test_file: str) -> None:
        self.train = self._read_file(file_path + train_file)
        self.dev = self._read_file(file_path + dev_file, mode='dev')
        self.test = self._read_file(file_path + test_file, mode='test')
        self._gen_dic()

    def debug_single_sample(self, subtoken: List[int],
                            label_list: List[Tuple[int, int, int]]) -> None:
        print(" ".join(
            [self.subword_alphabet.get_instance(t) for t in subtoken]))
        for label in label_list:
            print(label[0], label[1],
                  self.label_alphabet.get_instance(label[2]))
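A usage sketch for this `tokenizers`-based variant; it assumes a WordPiece vocabulary file at `tokenization/<bert_model>.txt`, which the constructor loads when no tokenizer is passed in. The model name, paths, and batch size are placeholders.

# Hypothetical driver; model name, paths, and batch size are placeholders.
reader = Reader("bert-base-cased")  # loads tokenization/bert-base-cased.txt
reader.read_all_data("data/", "train.txt", "dev.txt", "test.txt")

train_batches, dev_batches, test_batches = reader.to_batch(batch_size=16)
(input_ids, input_mask, first_subtokens,
 last_subtokens, labels, masks) = train_batches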