示例#1
0
def write(text_processor: TextProcessor, output_file: str, input_file: str,
          max_len: int, sample_size: int):
    """Tokenize captions from a JSON annotation file and dump them with marshal.

    Reads ``input_file`` as JSON and converts every entry of
    ``obj["annotations"]`` with ``caption_data``. Each converted item is indexed
    as ``c[0]`` (image path) and ``c[1]`` (caption text). Each caption is
    tokenized; captions longer than ``max_len`` tokens are dropped. Every
    distinct image path gets a dense integer id, and
    ``(unique_images, caption_sorted)`` is written to ``output_file`` with
    ``marshal``.

    Args:
        text_processor: object providing ``tokenize_one_sentence(str)``; the
            result's ``len`` is compared against ``max_len``.
        output_file: destination path of the marshal dump.
        input_file: path of the JSON annotation file.
        max_len: captions with more tokens than this are skipped.
        sample_size: if > 0, stop right after keeping a caption whose input
            index ``ci`` satisfies ``ci + 1 >= sample_size`` (``ci`` counts
            skipped captions too); <= 0 means no limit.

    Output:
        ``unique_images`` maps image id -> path. ``caption_sorted`` is a list
        of ``(image_id, tokenized_caption)`` ordered by ascending token count.
    """
    with open(input_file, "r") as r:
        obj = json.load(r)

    captions = [caption_data(annotation) for annotation in obj["annotations"]]
    print(len(captions))

    skipped_long_sens = 0
    image_path_dict, unique_images = dict(), dict()

    tok_captions = {}
    image_ids = {}
    for ci, c in enumerate(captions):
        if ci % 1000 == 0:
            print(ci, "/", len(captions), "->", len(tok_captions),
                  len(unique_images), end="\r")
        tok_sen = text_processor.tokenize_one_sentence(c[1])
        if len(tok_sen) > max_len:
            skipped_long_sens += 1
            continue

        path = c[0]
        image_id = image_path_dict.get(path)
        if image_id is None:
            # First occurrence of this path: assign the next dense id.
            image_id = len(unique_images)
            unique_images[image_id] = path
            image_path_dict[path] = image_id

        caption_id = len(tok_captions)
        tok_captions[caption_id] = tok_sen
        image_ids[caption_id] = image_id

        if sample_size > 0 and ci + 1 >= sample_size:
            break

    print("Skipped long sentences:", skipped_long_sens, "from", len(captions))
    tok_captions_sorted = sorted(tok_captions.items(),
                                 key=lambda item: len(item[1]))
    caption_sorted = [(image_ids[cid], toks) for cid, toks in tok_captions_sorted]
    # Without this guard, an empty result made [-1] raise IndexError before
    # the output file was written.
    longest = len(tok_captions_sorted[-1][1]) if tok_captions_sorted else 0
    print("Longest sentence", longest)
    with open(output_file, "wb") as wfp:
        marshal.dump((unique_images, caption_sorted), wfp)
    print("Dumped", len(caption_sorted), "captions from", len(unique_images),
          "unique images")
def write(text_processor: TextProcessor, output_file: str, input_file: str, max_len: int, sample_size: int, lang):
    """Tokenize ``path<TAB>caption`` lines and dump them with marshal.

    Each line of ``input_file`` is stripped and split on tabs into exactly
    ``(path, caption)``.

    If ``lang`` is not None, a caption that does not already start with
    ``"<" + lang + ">"`` is rewritten as ``"<lang> caption </s>"`` before
    tokenization. Captions longer than ``max_len`` tokens are skipped. A line
    that raises while being processed (e.g. the wrong number of tab-separated
    fields) is printed and skipped.

    The tuple ``(unique_images, caption_sorted)`` is written to
    ``output_file`` with ``marshal``:
      - ``unique_images`` maps image id -> path.
      - ``caption_sorted`` lists ``(image_id, tokenized_caption)`` sorted by
        ascending token count.

    Args:
        text_processor: object providing ``tokenize_one_sentence(str)``.
        output_file: destination path of the marshal dump.
        input_file: text file with one ``path<TAB>caption`` pair per line.
        max_len: captions with more tokens than this are skipped.
        sample_size: if > 0, stop right after keeping a caption whose line
            index ``ci`` satisfies ``ci + 1 >= sample_size``.
        lang: optional language code used to build the ``<lang>`` tag.
    """
    eos = "</s>"
    lang_tag = None if lang is None else "<" + lang + ">"

    skipped_long_sens = 0
    image_path_dict, unique_images = dict(), dict()

    tok_captions = {}
    image_ids = {}
    with open(input_file, "r") as r:
        for ci, line in enumerate(r):
            try:
                path, caption = line.strip().split("\t")
                if lang_tag is not None and not caption.startswith(lang_tag):
                    caption = " ".join([lang_tag, caption, eos])
                tok_sen = text_processor.tokenize_one_sentence(caption)
                if len(tok_sen) > max_len:
                    skipped_long_sens += 1
                    continue

                image_id = image_path_dict.get(path)
                if image_id is None:
                    # First occurrence of this path: assign the next dense id.
                    image_id = len(unique_images)
                    unique_images[image_id] = path
                    image_path_dict[path] = image_id

                caption_id = len(tok_captions)
                tok_captions[caption_id] = tok_sen
                image_ids[caption_id] = image_id

                if sample_size > 0 and ci + 1 >= sample_size:
                    break
            except Exception:
                # Best effort: report and skip bad lines. Unlike the former
                # bare ``except:``, this lets KeyboardInterrupt/SystemExit
                # propagate.
                print(line.strip())

    print("Skipped long sentences:", skipped_long_sens)
    tok_captions_sorted = sorted(tok_captions.items(), key=lambda item: len(item[1]))
    caption_sorted = [(image_ids[cid], toks) for cid, toks in tok_captions_sorted]
    # Guard the empty case, where [-1] used to raise IndexError.
    longest = len(tok_captions_sorted[-1][1]) if tok_captions_sorted else 0
    print("Longest sentence", longest)
    with open(output_file, "wb") as wfp:
        marshal.dump((unique_images, caption_sorted), wfp)
    print("Dumped", len(caption_sorted), "captions from", len(unique_images), "unique images")
示例#3
0
def write(text_processor: TextProcessor,
          output_file: str,
          input_file: str,
          root_img_dir,
          skip_check: bool = False,
          max_len: int = 256,
          ref_file=None,
          choose_relevant=True,
          only_captions=False):
    """Extract (image path, text) pairs from JSON documents, validate, and dump.

    ``input_file`` (and ``ref_file`` when given) is a JSON list of documents.
    Each document has an ``"images"`` list whose entries carry ``"img_path"``.

    Pairs are produced per document by one of three extractors:
      - ``extract_captions`` if ``only_captions``;
      - otherwise ``extract_shared_sentences`` if ``choose_relevant``;
      - otherwise ``extract_sentences``.

    Each extractor receives the document and ``ref_images``: the set of all
    ``img_path`` values in ``ref_file``, or None. Each produced item is indexed
    as ``c[0]`` (image path relative to ``root_img_dir``) and ``c[1]`` (text).

    Texts longer than ``max_len`` tokens are skipped. Unless ``skip_check`` is
    set, every new image path is opened with PIL, converted to RGB, and run
    through the resize/crop/normalize transform. Items whose image fails this
    check, or whose processing raises, are skipped and counted.

    ``(unique_images, caption_sorted)`` is written to ``output_file`` with
    ``marshal``:
      - ``unique_images`` maps image id -> path.
      - ``caption_sorted`` lists ``(image_id, tokens)`` by ascending length.
    """
    ref_images = None
    if ref_file is not None:
        with open(ref_file, "rb") as fp:
            ref_docs = json.load(fp)
        ref_images = {im["img_path"] for v in ref_docs for im in v["images"]}

    with open(input_file, "rb") as fp:
        doc_dicts = json.load(fp)
    num_captions = sum(len(v["images"]) for v in doc_dicts)
    if only_captions:
        extract = extract_captions
    elif choose_relevant:
        extract = extract_shared_sentences
    else:
        extract = extract_sentences
    captions = [c for v in doc_dicts for c in extract(v, ref_images)]
    print(num_captions, len(captions))

    # Used only to verify that an image can be fully decoded and transformed;
    # the resulting tensor is discarded.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    def _image_is_loadable(rel_path):
        """Return True if the image opens and survives ``transform`` as RGB."""
        try:
            with Image.open(os.path.join(root_img_dir, rel_path)) as im:
                # make sure not to deal with rgba or grayscale images.
                transform(im.convert("RGB"))
            return True
        except Exception:
            return False

    skipped_long_sens = 0
    skipped_errors = 0
    image_path_dict, unique_images = dict(), dict()
    # Paths that already failed validation; avoids reopening them per caption.
    bad_images = set()

    tok_captions = {}
    image_ids = {}
    for ci, c in enumerate(captions):
        if ci % 1000 == 0:
            print(ci, "/", len(captions), "->", len(tok_captions),
                  len(unique_images), end="\r")
        try:
            tok_sen = text_processor.tokenize_one_sentence(c[1])
            if len(tok_sen) > max_len:
                skipped_long_sens += 1
                continue

            path = c[0]
            image_id = image_path_dict.get(path)
            if image_id is None:
                if path in bad_images or (not skip_check and not _image_is_loadable(path)):
                    bad_images.add(path)
                    skipped_errors += 1
                    continue
                image_id = len(unique_images)
                unique_images[image_id] = path
                image_path_dict[path] = image_id
        except Exception:
            # Best effort: skip the item but count it, instead of the former
            # silent bare ``except: pass``.
            skipped_errors += 1
            continue

        caption_id = len(tok_captions)
        tok_captions[caption_id] = tok_sen
        image_ids[caption_id] = image_id

    print("Skipped long sentences:", skipped_long_sens, "from", len(captions))
    if skipped_errors:
        print("Skipped captions due to errors or unreadable images:", skipped_errors)
    tok_captions_sorted = sorted(tok_captions.items(),
                                 key=lambda item: len(item[1]))
    caption_sorted = [(image_ids[cid], toks) for cid, toks in tok_captions_sorted]
    # Guard the empty case, where [-1] used to raise IndexError.
    longest = len(tok_captions_sorted[-1][1]) if tok_captions_sorted else 0
    print("Longest sentence", longest)
    with open(output_file, "wb") as wfp:
        marshal.dump((unique_images, caption_sorted), wfp)
    print("Dumped", len(caption_sorted), "captions from", len(unique_images),
          "unique images")
示例#4
0
def write(text_processor: TextProcessor,
          output_file: str,
          src_txt_file: str,
          dst_txt_file: str = None,
          min_len: int = 1,
          max_len: int = 175,
          src_lang_tag: str = "<ar>",
          dst_lang_tag: str = "<en>",
          max_examples_per_part: int = 30000000):
    """Tokenize a parallel or monolingual text corpus and dump it with marshal.

    Parallel mode (``dst_txt_file`` given):
      - The two files are read in lockstep with ``zip``, so extra lines in the
        longer file are ignored. Pairs with an empty side are skipped.
      - Each side is wrapped as ``"<tag> line </s>"``, using ``src_lang_tag``
        and ``dst_lang_tag`` (defaults ``"<ar>"`` and ``"<en>"``).
      - Every ``" </s> "`` occurrence is replaced by a single space. The
        trailing ``</s>`` appended here is never matched, because no space
        follows it. The result is then tokenized.
      - Each side's language is looked up as
        ``text_processor.languages[text_processor.id2token(first_token)]``.
      - Pairs where both sides have ``min_len..max_len`` tokens are kept,
        sorted by target length, and written to ``output_file`` as a list of
        ``(src_tokens, dst_tokens, src_lang, dst_lang)``.

    Monolingual mode (``dst_txt_file`` is None, used for MASS training):
      - Each non-empty stripped source line is tokenized as-is. Its first
        token is looked up in ``text_processor.languages``, presumably a
        language tag already present in the text.
      - Kept ``(src_tokens, src_lang)`` examples are sorted by length and
        written in parts to ``output_file + "." + str(part_num)``.
      - A part is flushed whenever ``max_examples_per_part`` examples have
        accumulated; the remainder goes to a final part.

    Raises:
        ValueError: if ``max_examples_per_part`` is < 1.
    """
    if max_examples_per_part < 1:
        raise ValueError("max_examples_per_part must be >= 1")

    def _sorted_by_length(examples, lens):
        """Return ``examples`` values ordered by ascending ``lens`` (stable)."""
        return [examples[key]
                for key, _ in sorted(lens.items(), key=lambda item: item[1])]

    examples = {}
    lens = {}
    line_num = 0

    if dst_txt_file is not None:
        with open(src_txt_file, "r") as s_fp, open(dst_txt_file, "r") as d_fp:
            for src_line, dst_line in zip(s_fp, d_fp):
                src_line, dst_line = src_line.strip(), dst_line.strip()
                if not src_line or not dst_line:
                    continue
                src_line = " ".join([src_lang_tag, src_line, "</s>"])
                dst_line = " ".join([dst_lang_tag, dst_line, "</s>"])
                src_tok_line = text_processor.tokenize_one_sentence(
                    src_line.replace(" </s> ", " "))
                src_lang = text_processor.languages[text_processor.id2token(
                    src_tok_line[0])]
                dst_tok_line = text_processor.tokenize_one_sentence(
                    dst_line.replace(" </s> ", " "))
                dst_lang = text_processor.languages[text_processor.id2token(
                    dst_tok_line[0])]

                if (min_len <= len(src_tok_line) <= max_len
                        and min_len <= len(dst_tok_line) <= max_len):
                    examples[line_num] = (src_tok_line, dst_tok_line, src_lang,
                                          dst_lang)
                    # Parallel data is ordered by target-side length.
                    lens[line_num] = len(dst_tok_line)
                    line_num += 1

        print("Sorting")
        sorted_examples = _sorted_by_length(examples, lens)
        print("Sorted examples")

        print("Dumping")
        with open(output_file, "wb") as fw:
            marshal.dump(sorted_examples, fw)

    else:
        part_num = 0
        # Used for MASS training where we only have source sentences.
        with open(src_txt_file, "r") as s_fp:
            for src_line in s_fp:
                src_line = src_line.strip()
                if not src_line:
                    continue
                src_tok_line = text_processor.tokenize_one_sentence(src_line)
                src_lang = text_processor.languages[text_processor.id2token(
                    src_tok_line[0])]
                if min_len <= len(src_tok_line) <= max_len:
                    examples[line_num] = (src_tok_line, src_lang)
                    lens[line_num] = len(src_tok_line)
                    line_num += 1
                    if line_num % 1000 == 0:
                        print(line_num, "\r", end="")

                # Flush a part to bound memory use on very large corpora.
                if len(examples) >= max_examples_per_part:
                    print(datetime.datetime.now(), "Sorting and writing",
                          part_num)
                    with open(output_file + "." + str(part_num), "wb") as fw:
                        marshal.dump(_sorted_by_length(examples, lens), fw)
                    examples = {}
                    lens = {}
                    part_num += 1

        if examples:
            print(datetime.datetime.now(), "Sorting and writing", part_num)
            with open(output_file + "." + str(part_num), "wb") as fw:
                marshal.dump(_sorted_by_length(examples, lens), fw)