import datetime
import json
import marshal
import os
from itertools import chain

from PIL import Image
from torchvision import transforms

# TextProcessor, caption_data, extract_captions, extract_shared_sentences and
# extract_sentences are assumed to be defined elsewhere in this repository.


def write(text_processor: TextProcessor, output_file: str, input_file: str, max_len: int, sample_size: int):
    with open(input_file, "r") as r:
        obj = json.load(r)
        annotations = obj["annotations"]
        captions = list(map(lambda annotation: caption_data(annotation), annotations))
    print(len(captions))

    skipped_long_sens = 0
    image_path_dict, unique_images = dict(), dict()
    tok_captions = {}
    image_ids = {}
    for ci, c in enumerate(captions):
        if ci % 1000 == 0:
            print(ci, "/", len(captions), "->", len(tok_captions), len(unique_images), end="\r")
        tok_sen = text_processor.tokenize_one_sentence(c[1])
        if len(tok_sen) > max_len:
            skipped_long_sens += 1
            continue

        path = c[0]
        if path not in image_path_dict:
            image_id = len(unique_images)
            unique_images[image_id] = path
            image_path_dict[path] = image_id
        else:
            image_id = image_path_dict[path]

        caption_id = len(tok_captions)
        tok_captions[caption_id] = tok_sen
        image_ids[caption_id] = image_id

        if sample_size > 0 and (ci + 1) >= sample_size:
            break

    print("Skipped long sentences:", skipped_long_sens, "from", len(captions))
    # Sort captions by token length so length-homogeneous batches can be built later.
    tok_captions_sorted = sorted(tok_captions.items(), key=lambda item: len(item[1]))
    caption_sorted = list(map(lambda e: (image_ids[e[0]], e[1]), tok_captions_sorted))
    print("Longest sentence", len(tok_captions_sorted[-1][1]))
    with open(output_file, "wb") as wfp:
        marshal.dump((unique_images, caption_sorted), wfp)
    print("Dumped", len(caption_sorted), "captions from", len(unique_images), "unique images")
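# Hypothetical invocation of the JSON variant above, shown only as a sketch.
# How the TextProcessor is constructed depends on the rest of the repository,
# so the first line below is an assumption; the file names are illustrative.
#
#     text_processor = TextProcessor(...)  # assumption: built from the repo's tokenizer files
#     write(text_processor, output_file="captions.bin",
#           input_file="annotations/captions_train2017.json",  # COCO-style "annotations" JSON
#           max_len=256, sample_size=-1)  # sample_size <= 0 keeps every caption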
def write(text_processor: TextProcessor, output_file: str, input_file: str, max_len: int, sample_size: int, lang):
    eos = "</s>"
    if lang is not None:
        lang = "<" + lang + ">"

    skipped_long_sens = 0
    image_path_dict, unique_images = dict(), dict()
    tok_captions = {}
    image_ids = {}
    with open(input_file, "r") as r:
        for ci, line in enumerate(r):
            try:
                # Each input line is expected to be "image_path<TAB>caption".
                path, caption = line.strip().split("\t")
                if lang is not None and not caption.startswith(lang):
                    caption = " ".join([lang, caption, eos])
                tok_sen = text_processor.tokenize_one_sentence(caption)
                if len(tok_sen) > max_len:
                    skipped_long_sens += 1
                    continue

                if path not in image_path_dict:
                    image_id = len(unique_images)
                    unique_images[image_id] = path
                    image_path_dict[path] = image_id
                else:
                    image_id = image_path_dict[path]

                caption_id = len(tok_captions)
                tok_captions[caption_id] = tok_sen
                image_ids[caption_id] = image_id

                if sample_size > 0 and (ci + 1) >= sample_size:
                    break
            except ValueError:
                # Malformed line (e.g. missing tab); report it and move on.
                print(line.strip())

    print("Skipped long sentences:", skipped_long_sens)
    tok_captions_sorted = sorted(tok_captions.items(), key=lambda item: len(item[1]))
    caption_sorted = list(map(lambda e: (image_ids[e[0]], e[1]), tok_captions_sorted))
    print("Longest sentence", len(tok_captions_sorted[-1][1]))
    with open(output_file, "wb") as wfp:
        marshal.dump((unique_images, caption_sorted), wfp)
    print("Dumped", len(caption_sorted), "captions from", len(unique_images), "unique images")
def write(text_processor: TextProcessor, output_file: str, input_file: str, root_img_dir, skip_check: bool = False,
          max_len: int = 256, ref_file=None, choose_relevant=True, only_captions=False):
    ref_images = None
    if ref_file is not None:
        with open(ref_file, "rb") as fp:
            doc_dicts = json.load(fp)
            ref_images = set(chain(*map(lambda v: list(map(lambda im: im["img_path"], v["images"])), doc_dicts)))

    with open(input_file, "rb") as fp:
        doc_dicts = json.load(fp)
        num_captions = sum(map(lambda v: len(v["images"]), doc_dicts))
        if only_captions:
            captions = list(chain(*map(lambda v: extract_captions(v, ref_images), doc_dicts)))
        elif choose_relevant:
            captions = list(chain(*map(lambda v: extract_shared_sentences(v, ref_images), doc_dicts)))
        else:
            captions = list(chain(*map(lambda v: extract_sentences(v, ref_images), doc_dicts)))
        print(num_captions, len(captions))

    # Standard ImageNet preprocessing: resize, center-crop, and normalize.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    skipped_long_sens = 0
    image_path_dict, unique_images = dict(), dict()
    tok_captions = {}
    image_ids = {}
    for ci, c in enumerate(captions):
        if ci % 1000 == 0:
            print(ci, "/", len(captions), "->", len(tok_captions), len(unique_images), end="\r")
        try:
            tok_sen = text_processor.tokenize_one_sentence(c[1])
            if len(tok_sen) > max_len:
                skipped_long_sens += 1
                continue

            path = c[0]
            if path not in image_path_dict:
                if not skip_check:
                    # Make sure the image opens and converts cleanly, so we do not
                    # have to deal with RGBA or grayscale images at training time.
                    with Image.open(os.path.join(root_img_dir, path)) as im:
                        _ = transform(im.convert("RGB"))
                image_id = len(unique_images)
                unique_images[image_id] = path
                image_path_dict[path] = image_id
            else:
                image_id = image_path_dict[path]

            caption_id = len(tok_captions)
            tok_captions[caption_id] = tok_sen
            image_ids[caption_id] = image_id
        except Exception:
            # Skip captions whose image is missing or cannot be decoded.
            pass

    print("Skipped long sentences:", skipped_long_sens, "from", len(captions))
    tok_captions_sorted = sorted(tok_captions.items(), key=lambda item: len(item[1]))
    caption_sorted = list(map(lambda e: (image_ids[e[0]], e[1]), tok_captions_sorted))
    print("Longest sentence", len(tok_captions_sorted[-1][1]))
    with open(output_file, "wb") as wfp:
        marshal.dump((unique_images, caption_sorted), wfp)
    print("Dumped", len(caption_sorted), "captions from", len(unique_images), "unique images")
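# The dump format written above is symmetric to read back: marshal.load returns
# the (unique_images, caption_sorted) pair, where unique_images maps
# image_id -> path and caption_sorted is a length-sorted list of
# (image_id, token_ids) tuples. A minimal sketch (the file name is illustrative):
#
#     with open("captions.bin", "rb") as fp:
#         unique_images, caption_sorted = marshal.load(fp)
#     image_id, token_ids = caption_sorted[0]
#     print(unique_images[image_id], len(token_ids))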
def write(text_processor: TextProcessor, output_file: str, src_txt_file: str, dst_txt_file: str = None,
          min_len: int = 1, max_len: int = 175):
    examples = {}
    line_num = 0
    lens = {}
    if dst_txt_file is not None:
        # Parallel data: wrap each side with (hard-coded) language tags and an
        # end-of-sentence marker before tokenizing.
        with open(src_txt_file, "r") as s_fp, open(dst_txt_file, "r") as d_fp:
            for src_line, dst_line in zip(s_fp, d_fp):
                if len(src_line.strip()) == 0 or len(dst_line.strip()) == 0:
                    continue
                src_line = " ".join(["<ar>", src_line.strip(), "</s>"])
                dst_line = " ".join(["<en>", dst_line.strip(), "</s>"])
                src_tok_line = text_processor.tokenize_one_sentence(src_line.replace(" </s> ", " "))
                src_lang = text_processor.languages[text_processor.id2token(src_tok_line[0])]
                dst_tok_line = text_processor.tokenize_one_sentence(dst_line.replace(" </s> ", " "))
                dst_lang = text_processor.languages[text_processor.id2token(dst_tok_line[0])]
                if min_len <= len(src_tok_line) <= max_len and min_len <= len(dst_tok_line) <= max_len:
                    examples[line_num] = (src_tok_line, dst_tok_line, src_lang, dst_lang)
                    lens[line_num] = len(dst_tok_line)
                    line_num += 1

        print("Sorting")
        sorted_lens = sorted(lens.items(), key=lambda item: item[1])
        sorted_examples = list(map(lambda len_item: examples[len_item[0]], sorted_lens))
        print("Sorted examples")
        print("Dumping")
        with open(output_file, "wb") as fw:
            marshal.dump(sorted_examples, fw)
    else:
        part_num = 0
        # Monolingual mode, used for MASS training where we only have source sentences.
        with open(src_txt_file, "r") as s_fp:
            for src_line in s_fp:
                if len(src_line.strip()) == 0:
                    continue
                src_tok_line = text_processor.tokenize_one_sentence(src_line.strip())
                src_lang = text_processor.languages[text_processor.id2token(src_tok_line[0])]
                if min_len <= len(src_tok_line) <= max_len:
                    examples[line_num] = (src_tok_line, src_lang)
                    lens[line_num] = len(src_tok_line)
                    line_num += 1
                    if line_num % 1000 == 0:
                        print(line_num, "\r", end="")

                # Flush to a numbered part file so memory stays bounded on very large corpora.
                if len(examples) >= 30000000:
                    print(datetime.datetime.now(), "Sorting and writing", part_num)
                    sorted_lens = sorted(lens.items(), key=lambda item: item[1])
                    sorted_examples = list(map(lambda len_item: examples[len_item[0]], sorted_lens))
                    with open(output_file + "." + str(part_num), "wb") as fw:
                        marshal.dump(sorted_examples, fw)
                    examples = {}
                    lens = {}
                    part_num += 1

        if len(examples) > 0:
            print(datetime.datetime.now(), "Sorting and writing", part_num)
            sorted_lens = sorted(lens.items(), key=lambda item: item[1])
            sorted_examples = list(map(lambda len_item: examples[len_item[0]], sorted_lens))
            with open(output_file + "." + str(part_num), "wb") as fw:
                marshal.dump(sorted_examples, fw)
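# Sketch of the two modes of the parallel/monolingual writer above; the file
# names are illustrative and the TextProcessor is assumed to be built elsewhere:
#
#     # Parallel mode: expects line-aligned source/target files; note the
#     # hard-coded <ar>/<en> tags in the function body.
#     write(text_processor, "train.batch", "train.ar", "train.en")
#
#     # Monolingual mode (MASS pre-training): omit dst_txt_file; the output is
#     # split into numbered parts ("train.batch.0", "train.batch.1", ...).
#     write(text_processor, "train.batch", "train.ar")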