class Reader:
    def __init__(self) -> None:
        self.vocab2id: Dict[str, int] = {}
        self.lowercase: Optional[bool] = None

        self.word_iv_alphabet: Optional[Alphabet] = None
        self.word_ooev_alphabet: Optional[Alphabet] = None
        self.char_alphabet: Optional[Alphabet] = None
        self.label_alphabet: Optional[Alphabet] = None

        self.train: Optional[List[SentInst]] = None
        self.dev: Optional[List[SentInst]] = None
        self.test: Optional[List[SentInst]] = None

    def read_and_gen_vectors_glove(self, embed_path: str) -> None:
        token_embed = None
        ret_mat = []
        with open(GLOVE_FILE, 'r') as f:
            id = 0
            for line in f:
                s_s = line.split()
                if token_embed is None:
                    # The first line is used only to infer the embedding
                    # dimensionality and to create the zero padding row.
                    token_embed = len(s_s) - 1
                    ret_mat.append(np.zeros(token_embed).astype('float32'))
                else:
                    assert (token_embed + 1 == len(s_s))
                    id += 1
                    self.vocab2id[s_s[0]] = id
                    ret_mat.append(np.array([float(x) for x in s_s[1:]]))

        self.lowercase = True
        ret_mat = np.array(ret_mat)
        with open(embed_path, 'wb') as f:
            pickle.dump(ret_mat, f)

    def read_and_gen_vectors_pubmed_word2vec(self, embed_path: str) -> None:
        ret_mat = []
        with open(PUBMED_WORD2VEC_FILE, 'rb') as f:
            line = f.readline().rstrip(b'\n')
            vsize, token_embed = line.split()
            vsize = int(vsize)
            token_embed = int(token_embed)

            id = 0
            ret_mat.append(np.zeros(token_embed).astype('float32'))
            for v in range(vsize):
                # Read one space-terminated word followed by its binary vector.
                wchars = []
                while True:
                    c = f.read(1)
                    if c == b' ':
                        break
                    assert (c is not None)
                    wchars.append(c)
                word = b''.join(wchars)
                if word.startswith(b'\n'):
                    word = word[1:]
                id += 1
                self.vocab2id[word.decode('utf-8')] = id
                ret_mat.append(np.fromfile(f, np.float32, token_embed))
        assert (vsize + 1 == len(ret_mat))

        self.lowercase = False
        ret_mat = np.array(ret_mat)
        with open(embed_path, 'wb') as f:
            pickle.dump(ret_mat, f)

    @staticmethod
    def _read_file(filename: str, mode: str = 'train') -> List[SentInst]:
        sent_list = []
        max_len = 0
        num_thresh = 0
        with open(filename) as f:
            for line in f:
                line = line.strip()
                if line == "":  # last few blank lines
                    break

                raw_tokens = line.split(' ')
                tokens = raw_tokens
                chars = [list(t) for t in raw_tokens]

                entities = next(f).strip()
                if entities == "":  # no entities
                    sent_inst = SentInst(tokens, chars, [])
                else:
                    entity_list = []
                    entities = entities.split("|")
                    for item in entities:
                        pointers, label = item.split()
                        pointers = pointers.split(",")
                        if int(pointers[1]) > len(tokens):
                            pdb.set_trace()
                        span_len = int(pointers[1]) - int(pointers[0])
                        assert (span_len > 0)
                        if span_len > max_len:
                            max_len = span_len
                        if span_len > 6:
                            num_thresh += 1

                        new_entity = (int(pointers[0]), int(pointers[1]), label)
                        # may be duplicate entities in some datasets
                        if (mode == 'train' and new_entity not in entity_list) or (mode != 'train'):
                            entity_list.append(new_entity)

                    # assert len(entity_list) == len(set(entity_list)) # check duplicate
                    sent_inst = SentInst(tokens, chars, entity_list)
                assert next(f).strip() == ""  # separating line

                sent_list.append(sent_inst)
        print("Max length: {}".format(max_len))
        print("Threshold 6: {}".format(num_thresh))
        return sent_list

    def _gen_dic(self) -> None:
        word_set = set()
        char_set = set()
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                if sent_list is self.train:
                    # Characters and out-of-embedding-vocabulary words are
                    # collected from the training split only.
                    for token in sentInst.chars:
                        for char in token:
                            char_set.add(char)
                    for token in sentInst.tokens:
                        if self.lowercase:
                            token = token.lower()
                        if token not in self.vocab2id:
                            word_set.add(token)
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        self.word_iv_alphabet = Alphabet(self.vocab2id, len(PREDIFINE_TOKEN_IDS))
        self.word_ooev_alphabet = Alphabet(word_set, len(PREDIFINE_TOKEN_IDS))
        self.char_alphabet = Alphabet(char_set, len(PREDIFINE_CHAR_IDS))
        self.label_alphabet = Alphabet(label_set, 0)

    @staticmethod
    def _pad_batches(token_iv_batches: List[List[List[int]]],
                     token_ooev_batches: List[List[List[int]]],
                     char_batches: List[List[List[List[int]]]]) \
            -> Tuple[List[List[List[int]]],
                     List[List[List[int]]],
                     List[List[List[List[int]]]],
                     List[List[List[bool]]]]:

        default_token_id = PREDIFINE_TOKEN_IDS['DEFAULT']
        default_char_id = PREDIFINE_CHAR_IDS['DEFAULT']
        bot_id = PREDIFINE_CHAR_IDS['BOT']  # beginning of token
        eot_id = PREDIFINE_CHAR_IDS['EOT']  # end of token

        padded_token_iv_batches = []
        padded_token_ooev_batches = []
        padded_char_batches = []
        mask_batches = []

        all_batches = list(zip(token_iv_batches, token_ooev_batches, char_batches))
        for token_iv_batch, token_ooev_batch, char_batch in all_batches:

            batch_len = len(token_iv_batch)
            max_sent_len = len(token_iv_batch[0])
            max_char_len = max([max([len(t) for t in char_mat]) for char_mat in char_batch])

            padded_token_iv_batch = []
            padded_token_ooev_batch = []
            padded_char_batch = []
            mask_batch = []

            for i in range(batch_len):
                sent_len = len(token_iv_batch[i])

                padded_token_iv_vec = token_iv_batch[i].copy()
                padded_token_iv_vec.extend([default_token_id] * (max_sent_len - sent_len))

                padded_token_ooev_vec = token_ooev_batch[i].copy()
                padded_token_ooev_vec.extend([default_token_id] * (max_sent_len - sent_len))

                padded_char_mat = []
                for t in char_batch[i]:
                    padded_t = list()
                    padded_t.append(bot_id)
                    padded_t.extend(t)
                    padded_t.append(eot_id)
                    padded_t.extend([default_char_id] * (max_char_len - len(t)))
                    padded_char_mat.append(padded_t)
                for t in range(sent_len, max_sent_len):
                    padded_char_mat.append([default_char_id] * (max_char_len + 2))  # max_len + bot + eot

                mask = [True] * sent_len + [False] * (max_sent_len - sent_len)

                padded_token_iv_batch.append(padded_token_iv_vec)
                padded_token_ooev_batch.append(padded_token_ooev_vec)
                padded_char_batch.append(padded_char_mat)
                mask_batch.append(mask)

            padded_token_iv_batches.append(padded_token_iv_batch)
            padded_token_ooev_batches.append(padded_token_ooev_batch)
            padded_char_batches.append(padded_char_batch)
            mask_batches.append(mask_batch)

        return padded_token_iv_batches, padded_token_ooev_batches, padded_char_batches, mask_batches

    def to_batch(self, batch_size: int) -> Tuple:
        ret_list = []

        for sent_list in [self.train, self.dev, self.test]:
            token_iv_dic = defaultdict(list)
            token_ooev_dic = defaultdict(list)
            char_dic = defaultdict(list)
            label_dic = defaultdict(list)

            this_token_iv_batches = []
            this_token_ooev_batches = []
            this_char_batches = []
            this_label_batches = []

            for sentInst in sent_list:
                token_iv_vec = []
                token_ooev_vec = []
                for t in sentInst.tokens:
                    if self.lowercase:
                        t = t.lower()
                    # In-embedding-vocabulary tokens are indexed through vocab2id;
                    # out-of-embedding-vocabulary tokens use a separate alphabet.
                    if t in self.vocab2id:
                        token_iv_vec.append(self.vocab2id[t])
                        token_ooev_vec.append(0)
                    else:
                        token_iv_vec.append(0)
                        token_ooev_vec.append(self.word_ooev_alphabet.get_index(t))

                char_mat = [[self.char_alphabet.get_index(c) for c in t] for t in sentInst.chars]
                # max_len = max([len(t) for t in sentInst.chars])
                # char_mat = [ t + [0] * (max_len - len(t)) for t in char_mat ]

                label_list = [(u[0], u[1], self.label_alphabet.get_index(u[2]))
                              for u in sentInst.entities]

                # Bucket sentences by length so each batch has a uniform sentence length.
                token_iv_dic[len(sentInst.tokens)].append(token_iv_vec)
                token_ooev_dic[len(sentInst.tokens)].append(token_ooev_vec)
                char_dic[len(sentInst.tokens)].append(char_mat)
                label_dic[len(sentInst.tokens)].append(label_list)

            token_iv_batches = []
            token_ooev_batches = []
            char_batches = []
            label_batches = []
            for length in sorted(token_iv_dic.keys(), reverse=True):
                token_iv_batches.extend(token_iv_dic[length])
                token_ooev_batches.extend(token_ooev_dic[length])
                char_batches.extend(char_dic[length])
                label_batches.extend(label_dic[length])

            for i in range(0, len(token_iv_batches), batch_size):
                this_token_iv_batches.append(token_iv_batches[i:i + batch_size])
            for i in range(0, len(token_ooev_batches), batch_size):
                this_token_ooev_batches.append(token_ooev_batches[i:i + batch_size])
            for i in range(0, len(char_batches), batch_size):
                this_char_batches.append(char_batches[i:i + batch_size])
            for i in range(0, len(label_batches), batch_size):
                this_label_batches.append(label_batches[i:i + batch_size])

            this_token_iv_batches, this_token_ooev_batches, this_char_batches, this_mask_batches \
                = self._pad_batches(this_token_iv_batches, this_token_ooev_batches, this_char_batches)

            ret_list.append((this_token_iv_batches, this_token_ooev_batches, this_char_batches,
                             this_label_batches, this_mask_batches))

        return tuple(ret_list)

    def read_all_data(self, file_path: str, train_file: str, dev_file: str, test_file: str) -> None:
        self.train = self._read_file(file_path + train_file)
        self.dev = self._read_file(file_path + dev_file, mode='dev')
        self.test = self._read_file(file_path + test_file, mode='test')
        self._gen_dic()

    def debug_single_sample(self,
                            token_v: List[int],
                            char_mat: List[List[int]],
                            char_len_vec: List[int],
                            label_list: List[Tuple[int, int, int]]) -> None:
        print(" ".join([self.word_ooev_alphabet.get_instance(t) for t in token_v]))
        for t in char_mat:
            print(" ".join([self.char_alphabet.get_instance(c) for c in t]))
        print(char_len_vec)
        for label in label_list:
            print(label[0], label[1], self.label_alphabet.get_instance(label[2]))
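# ---------------------------------------------------------------------------
# Usage sketch for the word-embedding Reader above (not part of the original
# code). The pickle path, data directory, file names and batch size are
# hypothetical placeholders; GLOVE_FILE must point at a real GloVe text file.
# ---------------------------------------------------------------------------
def _demo_word_reader() -> None:
    reader = Reader()
    # Populate vocab2id and cache the embedding matrix before reading data,
    # since _gen_dic relies on vocab2id and lowercase being set.
    reader.read_and_gen_vectors_glove("glove_embed.pkl")
    reader.read_all_data("data/", "train.txt", "dev.txt", "test.txt")
    train_b, dev_b, test_b = reader.to_batch(batch_size=32)
    token_iv, token_ooev, chars, labels, masks = train_b
    print("# train batches:", len(token_iv))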
class Reader:
    def __init__(self, bert_model: str) -> None:
        self.bert_tokenizer: BertTokenizer = BertTokenizer.from_pretrained(
            bert_model, do_lower_case='-cased' not in bert_model)

        self.sub_word_alphabet: Alphabet = None
        self.label_alphabet: Alphabet = None

        self.train: List[SentInst] = None
        self.dev: List[SentInst] = None
        self.test: List[SentInst] = None

    @staticmethod
    def read_file(filename: str, mode: str = 'train') -> List[SentInst]:
        sent_list = []
        max_len = 0
        num_thresh = 0
        with open(filename) as f:
            for line in f:
                line = line.strip()
                if line == "":  # last few blank lines
                    break

                raw_tokens = line.split()
                tokens = raw_tokens
                chars = [list(t) for t in raw_tokens]

                entities = next(f).strip()
                if entities == "":  # no entities
                    sent_inst = SentInst(tokens, chars, [])
                else:
                    entity_list = []
                    entities = entities.split("|")
                    for item in entities:
                        pointers, label = item.split()
                        pointers = pointers.split(",")
                        if int(pointers[1]) > len(tokens):
                            pdb.set_trace()
                        span_len = int(pointers[1]) - int(pointers[0])
                        assert (span_len > 0)
                        if span_len > max_len:
                            max_len = span_len
                        if span_len > 6:
                            num_thresh += 1

                        new_entity = (int(pointers[0]), int(pointers[1]), label)
                        # may be duplicate entities in some datasets
                        if (mode == 'train' and new_entity not in entity_list) or (mode != 'train'):
                            entity_list.append(new_entity)

                    # assert len(entity_list) == len(set(entity_list)) # check duplicate
                    sent_inst = SentInst(tokens, chars, entity_list)
                assert next(f).strip() == ""  # separating line

                sent_list.append(sent_inst)
        print("Max length: {}".format(max_len))
        print("Threshold 6: {}".format(num_thresh))
        return sent_list

    def gen_dic(self) -> None:
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        self.sub_word_alphabet = Alphabet(self.bert_tokenizer.vocab, 0)
        self.label_alphabet = Alphabet(label_set, 0)

    @staticmethod
    def pad_batches(input_ids_batches: List[List[List[int]]],
                    first_sub_tokens_batches: List[List[List[int]]]) \
            -> Tuple[List[List[List[int]]],
                     List[List[List[int]]],
                     List[List[List[bool]]]]:

        padded_input_ids_batches = []
        input_mask_batches = []
        mask_batches = []

        all_batches = list(zip(input_ids_batches, first_sub_tokens_batches))
        for input_ids_batch, first_sub_tokens_batch in all_batches:

            batch_len = len(input_ids_batch)
            max_sub_tokens_num = max([len(input_ids) for input_ids in input_ids_batch])
            max_sent_len = max([len(first_sub_tokens) for first_sub_tokens in first_sub_tokens_batch])

            padded_input_ids_batch = []
            input_mask_batch = []
            mask_batch = []

            for i in range(batch_len):
                sub_tokens_num = len(input_ids_batch[i])
                sent_len = len(first_sub_tokens_batch[i])

                padded_sub_token_vec = input_ids_batch[i].copy()
                padded_sub_token_vec.extend([0] * (max_sub_tokens_num - sub_tokens_num))
                input_mask = [1] * sub_tokens_num + [0] * (max_sub_tokens_num - sub_tokens_num)
                mask = [True] * sent_len + [False] * (max_sent_len - sent_len)

                padded_input_ids_batch.append(padded_sub_token_vec)
                input_mask_batch.append(input_mask)
                mask_batch.append(mask)

            padded_input_ids_batches.append(padded_input_ids_batch)
            input_mask_batches.append(input_mask_batch)
            mask_batches.append(mask_batch)

        return padded_input_ids_batches, input_mask_batches, mask_batches

    def to_batch(self, batch_size: int) -> Tuple:
        bert_tokenizer = self.bert_tokenizer
        ret_list = []

        for sent_list in [self.train, self.dev, self.test]:
            sub_token_dic_dic = defaultdict(lambda: defaultdict(list))
            first_sub_token_dic_dic = defaultdict(lambda: defaultdict(list))
            label_dic_dic = defaultdict(lambda: defaultdict(list))

            this_input_ids_batches = []
            this_first_sub_tokens_batches = []
            this_label_batches = []

            for sentInst in sent_list:
                sub_token_vec = []
                first_sub_token_vec = []
                sub_token_vec.extend(bert_tokenizer.convert_tokens_to_ids([CLS]))
                for t in sentInst.tokens:
                    st = bert_tokenizer.tokenize(t)
                    # Record where each word's first sub-token lands.
                    first_sub_token_vec.append(len(sub_token_vec))
                    sub_token_vec.extend(bert_tokenizer.convert_tokens_to_ids(st))
                sub_token_vec.extend(bert_tokenizer.convert_tokens_to_ids([SEP]))

                label_list = [(u[0], u[1], self.label_alphabet.get_index(u[2]))
                              for u in sentInst.entities]

                # Bucket by (word length, sub-token length) so batches are uniform.
                sub_token_dic_dic[len(sentInst.tokens)][len(sub_token_vec)].append(sub_token_vec)
                first_sub_token_dic_dic[len(sentInst.tokens)][len(sub_token_vec)].append(first_sub_token_vec)
                label_dic_dic[len(sentInst.tokens)][len(sub_token_vec)].append(label_list)

            input_ids_batches = []
            first_sub_tokens_batches = []
            label_batches = []
            for length1 in sorted(sub_token_dic_dic.keys(), reverse=True):
                for length2 in sorted(sub_token_dic_dic[length1].keys(), reverse=True):
                    input_ids_batches.extend(sub_token_dic_dic[length1][length2])
                    first_sub_tokens_batches.extend(first_sub_token_dic_dic[length1][length2])
                    label_batches.extend(label_dic_dic[length1][length2])

            for i in range(0, len(input_ids_batches), batch_size):
                this_input_ids_batches.append(input_ids_batches[i:i + batch_size])
            for i in range(0, len(first_sub_tokens_batches), batch_size):
                this_first_sub_tokens_batches.append(first_sub_tokens_batches[i:i + batch_size])
            for i in range(0, len(label_batches), batch_size):
                this_label_batches.append(label_batches[i:i + batch_size])

            this_input_ids_batches, this_input_mask_batches, this_mask_batches \
                = self.pad_batches(this_input_ids_batches, this_first_sub_tokens_batches)

            ret_list.append((this_input_ids_batches, this_input_mask_batches,
                             this_first_sub_tokens_batches, this_label_batches, this_mask_batches))

        return tuple(ret_list)

    def read_all_data(self, file_path: str, train_file: str, dev_file: str, test_file: str) -> None:
        self.train = self.read_file(file_path + train_file)
        self.dev = self.read_file(file_path + dev_file, mode='dev')
        self.test = self.read_file(file_path + test_file, mode='test')
        self.gen_dic()

    def debug_single_sample(self,
                            sub_token: List[int],
                            label_list: List[Tuple[int, int, int]]) -> None:
        print(" ".join([self.sub_word_alphabet.get_instance(t) for t in sub_token]))
        for label in label_list:
            print(label[0], label[1], self.label_alphabet.get_instance(label[2]))
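# ---------------------------------------------------------------------------
# Usage sketch for the BertTokenizer-based Reader above (not part of the
# original code). The model name, data directory, file names and batch size
# are hypothetical placeholders.
# ---------------------------------------------------------------------------
def _demo_bert_reader() -> None:
    reader = Reader("bert-base-cased")
    reader.read_all_data("data/", "train.txt", "dev.txt", "test.txt")
    train_b, dev_b, test_b = reader.to_batch(batch_size=32)
    input_ids, input_mask, first_sub_tokens, labels, masks = train_b
    print("# train batches:", len(input_ids))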
class Reader(object):
    def __init__(self,
                 bert_model: str,
                 tokenizer: BaseTokenizer = None,
                 cls: str = "[CLS]",
                 sep: str = "[SEP]",
                 threshold=6):

        self.tokenizer: BaseTokenizer = tokenizer
        self.cls = cls
        self.sep = sep
        if self.tokenizer is None:
            vocab_path: str = "tokenization/" + bert_model + ".txt"
            self.tokenizer = BertWordPieceTokenizer(vocab_path,
                                                    lowercase="-cased" not in bert_model)
        self.threshold = threshold

        self.subword_alphabet: Optional[Alphabet] = None
        self.label_alphabet: Optional[Alphabet] = None

        self.train: Optional[List[SentInst]] = None
        self.dev: Optional[List[SentInst]] = None
        self.test: Optional[List[SentInst]] = None

    def _read_file(self, filename: str, mode: str = 'train') -> List[SentInst]:
        sent_list = []
        max_len = 0
        num_thresh = 0
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line == "":  # last few blank lines
                    break

                raw_tokens = line.split(' ')
                tokens = raw_tokens
                chars = [list(t) for t in raw_tokens]

                entities = next(f).strip()
                if entities == "":  # no entities
                    sent_inst = SentInst(tokens, chars, [])
                else:
                    entity_list = []
                    entities = entities.split("|")
                    for item in entities:
                        pointers, label = item.split()
                        pointers = pointers.split(",")
                        if int(pointers[1]) > len(tokens):
                            pdb.set_trace()
                        span_len = int(pointers[1]) - int(pointers[0])
                        if span_len < 0:
                            print("Warning! span_len < 0")
                            continue
                        if span_len > max_len:
                            max_len = span_len
                        if span_len > self.threshold:
                            num_thresh += 1

                        new_entity = (int(pointers[0]), int(pointers[1]), label)
                        # may be duplicate entities in some datasets
                        if (mode == 'train' and new_entity not in entity_list) or (mode != 'train'):
                            entity_list.append(new_entity)

                    # assert len(entity_list) == len(set(entity_list)) # check duplicate
                    sent_inst = SentInst(tokens, chars, entity_list)
                assert next(f).strip() == ""  # separating line

                sent_list.append(sent_inst)
        print("Max length: {}".format(max_len))
        print("Threshold {}: {}".format(self.threshold, num_thresh))
        return sent_list

    def _gen_dic(self) -> None:
        label_set = set()

        for sent_list in [self.train, self.dev, self.test]:
            num_mention = 0
            for sentInst in sent_list:
                for entity in sentInst.entities:
                    label_set.add(entity[2])
                num_mention += len(sentInst.entities)
            print("# mentions: {}".format(num_mention))

        vocab = [self.tokenizer.id_to_token(idx)
                 for idx in range(self.tokenizer.get_vocab_size())]
        self.subword_alphabet = Alphabet(vocab, 0)
        self.label_alphabet = Alphabet(label_set, 0)

    @staticmethod
    def _pad_batches(input_ids_batches: List[List[List[int]]],
                     first_subtokens_batches: List[List[List[int]]]) \
            -> Tuple[List[List[List[int]]],
                     List[List[List[int]]],
                     List[List[List[bool]]]]:

        padded_input_ids_batches = []
        input_mask_batches = []
        mask_batches = []

        all_batches = list(zip(input_ids_batches, first_subtokens_batches))
        for input_ids_batch, first_subtokens_batch in all_batches:

            batch_len = len(input_ids_batch)
            max_subtokens_num = max([len(input_ids) for input_ids in input_ids_batch])
            max_sent_len = max([len(first_subtokens) for first_subtokens in first_subtokens_batch])

            padded_input_ids_batch = []
            input_mask_batch = []
            mask_batch = []

            for i in range(batch_len):
                subtokens_num = len(input_ids_batch[i])
                sent_len = len(first_subtokens_batch[i])

                padded_subtoken_vec = input_ids_batch[i].copy()
                padded_subtoken_vec.extend([0] * (max_subtokens_num - subtokens_num))
                input_mask = [1] * subtokens_num + [0] * (max_subtokens_num - subtokens_num)
                mask = [True] * sent_len + [False] * (max_sent_len - sent_len)

                padded_input_ids_batch.append(padded_subtoken_vec)
                input_mask_batch.append(input_mask)
                mask_batch.append(mask)

            padded_input_ids_batches.append(padded_input_ids_batch)
            input_mask_batches.append(input_mask_batch)
            mask_batches.append(mask_batch)

        return padded_input_ids_batches, input_mask_batches, mask_batches

    def get_batches(self, sentences: List[SentInst], batch_size: int) -> Tuple:
        subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        first_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        last_subtoken_dic_dic = defaultdict(lambda: defaultdict(list))
        label_dic_dic = defaultdict(lambda: defaultdict(list))

        this_input_ids_batches = []
        this_first_subtokens_batches = []
        this_last_subtokens_batches = []
        this_label_batches = []

        for sentInst in sentences:
            subtoken_vec = []
            first_subtoken_vec = []
            last_subtoken_vec = []
            subtoken_vec.append(self.tokenizer.token_to_id(self.cls))
            for t in sentInst.tokens:
                encoding = self.tokenizer.encode(t)
                # Drop the special tokens the tokenizer adds around each word.
                ids = [v for v, mask in zip(encoding.ids, encoding.special_tokens_mask)
                       if mask == 0]
                # Record the first and last sub-token position of each word.
                first_subtoken_vec.append(len(subtoken_vec))
                subtoken_vec.extend(ids)
                last_subtoken_vec.append(len(subtoken_vec))
            subtoken_vec.append(self.tokenizer.token_to_id(self.sep))

            label_list = [(u[0], u[1], self.label_alphabet.get_index(u[2]))
                          for u in sentInst.entities]

            # Bucket by (word length, sub-token length) so batches are uniform.
            subtoken_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(subtoken_vec)
            first_subtoken_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(first_subtoken_vec)
            last_subtoken_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(last_subtoken_vec)
            label_dic_dic[len(sentInst.tokens)][len(subtoken_vec)].append(label_list)

        input_ids_batches = []
        first_subtokens_batches = []
        last_subtokens_batches = []
        label_batches = []
        for length1 in sorted(subtoken_dic_dic.keys(), reverse=True):
            for length2 in sorted(subtoken_dic_dic[length1].keys(), reverse=True):
                input_ids_batches.extend(subtoken_dic_dic[length1][length2])
                first_subtokens_batches.extend(first_subtoken_dic_dic[length1][length2])
                last_subtokens_batches.extend(last_subtoken_dic_dic[length1][length2])
                label_batches.extend(label_dic_dic[length1][length2])

        for i in range(0, len(input_ids_batches), batch_size):
            this_input_ids_batches.append(input_ids_batches[i:i + batch_size])
        for i in range(0, len(first_subtokens_batches), batch_size):
            this_first_subtokens_batches.append(first_subtokens_batches[i:i + batch_size])
        for i in range(0, len(last_subtokens_batches), batch_size):
            this_last_subtokens_batches.append(last_subtokens_batches[i:i + batch_size])
        for i in range(0, len(label_batches), batch_size):
            this_label_batches.append(label_batches[i:i + batch_size])

        this_input_ids_batches, this_input_mask_batches, this_mask_batches \
            = self._pad_batches(this_input_ids_batches, this_first_subtokens_batches)

        return (this_input_ids_batches, this_input_mask_batches,
                this_first_subtokens_batches, this_last_subtokens_batches,
                this_label_batches, this_mask_batches)

    def to_batch(self, batch_size: int) -> Tuple:
        ret_list = []
        for sent_list in [self.train, self.dev, self.test]:
            ret_list.append(self.get_batches(sent_list, batch_size))
        return tuple(ret_list)

    def read_all_data(self, file_path: str, train_file: str, dev_file: str, test_file: str) -> None:
        self.train = self._read_file(file_path + train_file)
        self.dev = self._read_file(file_path + dev_file, mode='dev')
        self.test = self._read_file(file_path + test_file, mode='test')
        self._gen_dic()

    def debug_single_sample(self,
                            subtoken: List[int],
                            label_list: List[Tuple[int, int, int]]) -> None:
        print(" ".join([self.subword_alphabet.get_instance(t) for t in subtoken]))
        for label in label_list:
            print(label[0], label[1], self.label_alphabet.get_instance(label[2]))
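# ---------------------------------------------------------------------------
# Usage sketch for the tokenizers-based Reader above (not part of the original
# code). The model name, data directory, file names and batch size are
# hypothetical placeholders; "tokenization/bert-base-cased.txt" must exist as
# a WordPiece vocab file for the default tokenizer branch to work.
# ---------------------------------------------------------------------------
def _demo_wordpiece_reader() -> None:
    reader = Reader("bert-base-cased")
    reader.read_all_data("data/", "train.txt", "dev.txt", "test.txt")
    train_b, dev_b, test_b = reader.to_batch(batch_size=32)
    (input_ids, input_mask, first_subtokens,
     last_subtokens, labels, masks) = train_b
    print("# train batches:", len(input_ids))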