# Shared imports for the snippets below (the local-module path for
# TextProcessor is an assumption):
import datetime
import marshal
import os

import torch
from torch.nn.utils.rnn import pad_sequence
from torchvision import transforms

from textprocessor import TextProcessor  # assumed local module


def write(text_processor: TextProcessor, output_file: str, txt_file: str, output_txt: bool = False):
    with open(txt_file, "r") as fp, open(output_file, "w") as writer:
        for line in fp:
            if len(line.strip()) == 0:
                continue
            tok_line = text_processor.tokenize_one_line(line.strip(), ignore_middle_eos=True)
            if output_txt:
                # Drop the leading language tag and trailing end-of-sentence token,
                # and rename "<unk>" so the plain-text output stays space-splittable.
                tokenized = [text_processor.id2token(tok) for tok in tok_line][1:-1]
                tokenized = list(map(lambda tok: tok if tok != "<unk>" else "unk", tokenized))
            else:
                tokenized = [str(tok) for tok in tok_line]
            writer.write(" ".join(tokenized) + "\n")
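# A minimal usage sketch for the tokenizer dump above. The TextProcessor
# constructor argument is an assumption (a directory holding a trained
# tokenizer); the file names are placeholders.
#   tp = TextProcessor("tokenizer_dir")
#   write(tp, output_file="train.ids.txt", txt_file="train.txt")                    # token ids
#   write(tp, output_file="train.tok.txt", txt_file="train.txt", output_txt=True)   # token strings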
def __init__(self, root_img_dir: str, data_bin_file: str, max_capacity: int, text_processor: TextProcessor,
             max_img_per_batch: int, lex_dict=None, ngpu=1):
    self.ngpu = ngpu
    self.lex_dict = lex_dict
    self.size_transform = transforms.Resize(256)
    self.crop = transforms.CenterCrop(224)
    self.to_tensor = transforms.ToTensor()
    self.img_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    self.pad_idx = text_processor.pad_token_id()
    self.batches = []
    self.root_img_dir = root_img_dir
    max_capacity *= 1000000
    self.image_batches = []
    self.lang_ids = set()
    self.all_captions = []

    print("Start", datetime.datetime.now())
    cur_batch, cur_imgs, cur_lex_cand_batch = [], [], []
    cur_max_len = 0
    with open(data_bin_file, "rb") as fp:
        self.unique_images, captions = marshal.load(fp)
        # The first token of every caption is its language id.
        lang_id = text_processor.id2token(captions[0][1][0])
        self.lang_ids.add(int(captions[0][1][0]))
        self.lang = text_processor.languages[lang_id] if lang_id in text_processor.languages else 0

        for caption_info in captions:
            image_id, caption = caption_info
            # Skip PNG images.
            if self.unique_images[image_id].lower().endswith(".png"):
                continue
            caption = torch.LongTensor(caption)
            cur_batch.append(caption)
            self.all_captions.append(caption)
            if self.lex_dict is not None:
                # get_lex_suggestions is assumed to be imported elsewhere in this module.
                lex_cands = get_lex_suggestions(self.lex_dict, caption, text_processor.pad_token_id())
                cur_lex_cand_batch.append(lex_cands)

            cur_imgs.append(image_id)
            cur_max_len = max(cur_max_len, len(caption))
            # Memory-capacity heuristic: cubic in the longest caption length.
            batch_capacity_size = 2 * (cur_max_len ** 3) * len(cur_batch)

            if (len(cur_imgs) > max_img_per_batch or batch_capacity_size > max_capacity) \
                    and len(cur_batch[:-1]) >= self.ngpu and len(cur_batch) > 1:
                # Flush everything but the last (overflowing) caption into a padded batch.
                batch_tensor = pad_sequence(cur_batch[:-1], batch_first=True, padding_value=self.pad_idx)
                lex_cand_batch = None
                if self.lex_dict is not None:
                    lex_cand_batch = pad_sequence(cur_lex_cand_batch[:-1], batch_first=True,
                                                  padding_value=self.pad_idx)
                    cur_lex_cand_batch = [cur_lex_cand_batch[-1]]
                pads = batch_tensor != self.pad_idx  # True where the token is not padding.
                # For each row, find the column of its first padding token
                # (defaulting to the last column when the row has no padding).
                pad_indices = [int(pads.size(1)) - 1] * int(pads.size(0))
                pindices = torch.nonzero(~pads)
                for (r, c) in pindices:
                    pad_indices[r] = min(pad_indices[r], int(c))

                self.batches.append((batch_tensor, pads, torch.LongTensor(pad_indices), lex_cand_batch))
                self.image_batches.append(cur_imgs[:-1])

                cur_batch = [cur_batch[-1]]
                cur_imgs = [cur_imgs[-1]]
                cur_max_len = len(cur_batch[0])

        if len(cur_batch) > 0:
            # Flush the final partial batch.
            batch_tensor = pad_sequence(cur_batch, batch_first=True, padding_value=self.pad_idx)
            pads = batch_tensor != self.pad_idx
            pad_indices = [int(pads.size(1)) - 1] * int(pads.size(0))
            lex_cand_batch = None
            if self.lex_dict is not None:
                lex_cand_batch = pad_sequence(cur_lex_cand_batch, batch_first=True, padding_value=self.pad_idx)
            pindices = torch.nonzero(~pads)
            for (r, c) in pindices:
                pad_indices[r] = min(pad_indices[r], int(c))

            self.batches.append((batch_tensor, pads, torch.LongTensor(pad_indices), lex_cand_batch))
            self.image_batches.append(cur_imgs)

    print("Loaded %d image batches of %d unique images and %d all captions!" % (
        len(self.batches), len(self.unique_images), len(self.all_captions)))
    print("End", datetime.datetime.now())
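# A minimal usage sketch for the dataset constructor above. The enclosing
# class name is not shown in this snippet, so "ImageCaptionDataset" is a
# hypothetical stand-in; paths and sizes are placeholders.
#   dataset = ImageCaptionDataset(root_img_dir="images/", data_bin_file="captions.bin",
#                                 max_capacity=600, text_processor=tp,
#                                 max_img_per_batch=30)
#   batch_tensor, non_pad_mask, first_pad_indices, lex_cands = dataset.batches[0]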
def write(text_processor: TextProcessor, cache_dir: str, seq_len: int, txt_file: str, sen_block_size: int = 10000):
    current_cache, cur_cache_langs = [], []
    examples = {}
    line_num, file_count = 0, 0
    text_processor.max_len = seq_len

    with open(txt_file, "r") as fp:
        for ln, line in enumerate(fp):
            if len(line.strip()) == 0:
                continue
            tok_lines = text_processor.tokenize_lines(line.strip(), blind_split=True, split_len=seq_len)
            current_cache += list(tok_lines)
            cur_cache_langs += [text_processor.languages[
                                    text_processor.id2token(tok_lines[0, 0])]] * tok_lines.shape[0]

            if len(current_cache) >= 100000:
                for tok_line, lang in zip(current_cache, cur_cache_langs):
                    # Assuming that every list has the same length due to correct padding.
                    examples[line_num] = (tok_line.tolist(), lang)
                    line_num += 1
                    if len(examples) >= sen_block_size:
                        with open(os.path.join(cache_dir, str(file_count) + ".pkl"), "wb") as fw:
                            marshal.dump(examples, fw)
                        examples, file_count = {}, file_count + 1
                current_cache, cur_cache_langs = [], []
                print(f"from {ln} actual documents, dumped {line_num} big vectors into {file_count} files")

    if len(current_cache) > 0:
        for tok_line, lang in zip(current_cache, cur_cache_langs):
            # Assuming that every list has the same length due to correct padding.
            examples[line_num] = (tok_line.tolist(), lang)
            line_num += 1
            if len(examples) >= sen_block_size:
                with open(os.path.join(cache_dir, str(file_count) + ".pkl"), "wb") as fw:
                    marshal.dump(examples, fw)
                examples, file_count = {}, file_count + 1

    # Flush the remainder only when non-empty, so no empty block file is written.
    if len(examples) > 0:
        with open(os.path.join(cache_dir, str(file_count) + ".pkl"), "wb") as fw:
            marshal.dump(examples, fw)
        examples, file_count = {}, file_count + 1

    print(f"from {ln} actual documents, dumped {line_num} big vectors into {file_count} files")

    with open(os.path.join(cache_dir, "info.txt"), "w") as fw:
        fw.write(str(sen_block_size) + "\t" + str(line_num) + "\t" + str(file_count))
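# A minimal usage sketch for the block-cache writer above. Paths are
# placeholders; cache_dir must already exist, since the function only
# creates "<n>.pkl" block files and "info.txt" inside it.
#   os.makedirs("cache", exist_ok=True)
#   write(tp, cache_dir="cache", seq_len=512, txt_file="monolingual.txt")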
def write(text_processor: TextProcessor, output_file: str, src_txt_file: str, dst_txt_file: str = None,
          min_len: int = 1, max_len: int = 175):
    examples = {}
    line_num = 0
    lens = {}
    if dst_txt_file is not None:
        with open(src_txt_file, "r") as s_fp, open(dst_txt_file, "r") as d_fp:
            for src_line, dst_line in zip(s_fp, d_fp):
                if len(src_line.strip()) == 0 or len(dst_line.strip()) == 0:
                    continue
                # Language tags are hard-coded here: Arabic source, English target.
                src_line = " ".join(["<ar>", src_line.strip(), "</s>"])
                dst_line = " ".join(["<en>", dst_line.strip(), "</s>"])
                src_tok_line = text_processor.tokenize_one_sentence(src_line.replace(" </s> ", " "))
                src_lang = text_processor.languages[text_processor.id2token(src_tok_line[0])]
                dst_tok_line = text_processor.tokenize_one_sentence(dst_line.replace(" </s> ", " "))
                dst_lang = text_processor.languages[text_processor.id2token(dst_tok_line[0])]
                if min_len <= len(src_tok_line) <= max_len and min_len <= len(dst_tok_line) <= max_len:
                    examples[line_num] = (src_tok_line, dst_tok_line, src_lang, dst_lang)
                    lens[line_num] = len(dst_tok_line)
                    line_num += 1

        print("Sorting")
        # Sort examples by target length so batches have similar lengths.
        sorted_lens = sorted(lens.items(), key=lambda item: item[1])
        sorted_examples = []
        print("Sorted examples")
        for len_item in sorted_lens:
            sorted_examples.append(examples[len_item[0]])

        print("Dumping")
        with open(output_file, "wb") as fw:
            marshal.dump(sorted_examples, fw)
    else:
        part_num = 0
        # Used for MASS training where we only have source sentences.
        with open(src_txt_file, "r") as s_fp:
            for src_line in s_fp:
                if len(src_line.strip()) == 0:
                    continue
                src_tok_line = text_processor.tokenize_one_sentence(src_line.strip())
                src_lang = text_processor.languages[text_processor.id2token(src_tok_line[0])]
                if min_len <= len(src_tok_line) <= max_len:
                    examples[line_num] = (src_tok_line, src_lang)
                    lens[line_num] = len(src_tok_line)
                    line_num += 1
                    if line_num % 1000 == 0:
                        print(line_num, "\r", end="")

                # Spill to a numbered shard once the in-memory buffer gets too large.
                if len(examples) >= 30000000:
                    print(datetime.datetime.now(), "Sorting and writing", part_num)
                    sorted_lens = sorted(lens.items(), key=lambda item: item[1])
                    sorted_examples = list(map(lambda len_item: examples[len_item[0]], sorted_lens))
                    with open(output_file + "." + str(part_num), "wb") as fw:
                        marshal.dump(sorted_examples, fw)
                    examples = {}
                    lens = {}
                    part_num += 1

        if len(examples) > 0:
            print(datetime.datetime.now(), "Sorting and writing", part_num)
            sorted_lens = sorted(lens.items(), key=lambda item: item[1])
            sorted_examples = list(map(lambda len_item: examples[len_item[0]], sorted_lens))
            with open(output_file + "." + str(part_num), "wb") as fw:
                marshal.dump(sorted_examples, fw)
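# A minimal usage sketch for the binary-dataset writer above (paths are
# placeholders). With a target file it writes a single length-sorted binary;
# with only src_txt_file it writes "<output_file>.<part>" shards for MASS.
#   write(tp, output_file="mt.bin", src_txt_file="train.ar", dst_txt_file="train.en")
#   write(tp, output_file="mass.bin", src_txt_file="train.ar")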