Example #1
def convert(dataset_path: Path, split, output_path, coco_path):
    all_datapoints: List[Datapoint] = []

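    # Image ids regarded as safe for training; presumably these are ids that do
    # not appear in any val/test split of the combined referring-expression datasets.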
    safe_ids = set()
    with open("combined_ref_exp_safe_train_ids.txt", "r") as f:
        for line in f:
            safe_ids.add(int(line.strip()))

    with open(f"{coco_path}/annotations/instances_train2014.json", "r") as f:
        coco_annotations = json.load(f)
    coco_anns = coco_annotations["annotations"]
    annid2cocoann = {item["id"]: item for item in coco_anns}

    for dataset_name in [
            "refcoco/refs(unc).p", "refcoco+/refs(unc).p",
            "refcocog/refs(umd).p"
    ]:
        d_name = dataset_name.split("/")[0]

        with open(dataset_path / dataset_name, "rb") as f:
            data = pickle.load(f)

        for item in tqdm(data):
            if item["split"] != split:
                continue
            if item["split"] == "train" and item["image_id"] not in safe_ids:
                continue
            for s in item["sentences"]:
                refexp = s["sent"]
                target_bbox = annid2cocoann[item["ann_id"]]["bbox"]
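                # COCO boxes are [x, y, w, h]; convert to [x1, y1, x2, y2] for the GIoU-friendly format.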
                converted_bbox = [
                    target_bbox[0],
                    target_bbox[1],
                    target_bbox[2] + target_bbox[0],
                    target_bbox[3] + target_bbox[1],
                ]

                _, _, root_spans, neg_spans = get_root_and_nouns(refexp)
                cur_datapoint = Datapoint(
                    image_id=item["image_id"],
                    dataset_name=d_name,
                    original_id=item["ann_id"],
                    caption=refexp,
                    annotations=[],
                    tokens_negative=consolidate_spans(neg_spans, refexp),
                )

                cur_obj = Annotation(
                    area=annid2cocoann[item["ann_id"]]["area"],
                    iscrowd=annid2cocoann[item["ann_id"]]["iscrowd"],
                    category_id=item["category_id"],
                    bbox=target_bbox,
                    giou_friendly_bbox=converted_bbox,
                    tokens_positive=consolidate_spans(root_spans, refexp),
                )
                cur_datapoint.annotations.append(cur_obj)
                all_datapoints.append(cur_datapoint)

    with open(output_path / "refexp_dict.pkl", "wb") as f:
        pickle.dump(all_datapoints, f)
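
A minimal invocation sketch for the converter above (the paths are hypothetical placeholders; pathlib.Path is assumed to be imported):

convert(dataset_path=Path("refer/data"), split="train",
        output_path=Path("output"), coco_path="coco")
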
Example #2
def convert(split, data_path, img_path, sg_path, output_path, imid2data):

    with open(data_path / f"{split}_balanced_questions.json", "r") as f:
        data = json.load(f)
    with open(sg_path / f"{split}_sceneGraphs.json", "r") as f:
        sg_data = json.load(f)

    img2ann = defaultdict(dict)
    for k, v in data.items():
        img2ann[v["imageId"]][k] = v

    # Add missing annotations by inspecting the semantic field
    regexp = re.compile(r"([0-9]+)")
    regexp2 = re.compile(r"([A-Za-z]+)")
    count = 0

    for k, v in img2ann.items():
        for ann_id, annotations in v.items():
            expected_boxes = []
            for item in annotations["semantic"]:
                if item["operation"] == "select":
                    if len(regexp.findall(item["argument"])) > 0:
                        expected_boxes.append(
                            (regexp2.findall(item["argument"])[0].strip(),
                             regexp.findall(item["argument"])[0]))
            question_boxes = list(
                annotations["annotations"]["question"].values())

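            # If a selected object was not grounded in the question annotations, anchor its box
            # to the first occurrence of the object name in the question text.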
            for name, box_id in expected_boxes:
                if box_id not in question_boxes:
                    count += 1
                    beg = annotations["question"].find(name)
                    end = beg + len(name)
                    annotations["annotations"]["question"][(beg, end)] = box_id

    # Add annotations for the questions where there is a box for the answer but not for the question (what/where/who questions)
    for k, v in img2ann.items():
        for ann_id, ann in v.items():
            question_objects = list(ann["annotations"]["question"].values())
            answer_objects = list(ann["annotations"]["answer"].values())
            if len(set(answer_objects) - set(question_objects)) > 0:

                for box_id in answer_objects:
                    if box_id not in question_objects:

                        if ann["question"].find("What") > -1:
                            beg = ann["question"].find("What")
                            end = beg + len("What")
                        elif ann["question"].find("what") > -1:
                            beg = ann["question"].find("what")
                            end = beg + len("what")
                        elif ann["question"].find("Who") > -1:
                            beg = ann["question"].find("Who")
                            end = beg + len("Who")
                        elif ann["question"].find("who") > -1:
                            beg = ann["question"].find("who")
                            end = beg + len("who")
                        elif ann["question"].find("Where") > -1:
                            beg = ann["question"].find("Where")
                            end = beg + len("Where")
                        elif ann["question"].find("where") > -1:
                            beg = ann["question"].find("where")
                            end = beg + len("where")
                        else:
                            continue

                        ann["annotations"]["question"][(beg, end)] = box_id

    all_datapoints: List[Datapoint] = []
    d_name = "gqa"

    for k, v in tqdm(img2ann.items()):
        for ann_id, annotation in v.items():
            question = annotation["question"]
            cur_datapoint = Datapoint(
                image_id=k,
                dataset_name="gqa",
                original_id=ann_id,
                caption=question,
                annotations=[],
                tokens_negative=[(0, len(question))],
            )

            if len(annotation["annotations"]["question"]) > 0:

                for text_tok_id, box_anno_id in annotation["annotations"][
                        "question"].items():
                    target_bbox = sg_data[k]["objects"][box_anno_id]
                    x, y, h, w = target_bbox["x"], target_bbox[
                        "y"], target_bbox["h"], target_bbox["w"]
                    target_bbox = [x, y, w, h]
                    converted_bbox = [
                        target_bbox[0],
                        target_bbox[1],
                        target_bbox[2] + target_bbox[0],
                        target_bbox[3] + target_bbox[1],
                    ]

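                    # GQA keys question groundings by word index (sometimes an "i:j" word
                    # range); convert these to character offsets by summing the lengths of
                    # the preceding words plus one separating space per word.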
                    if isinstance(text_tok_id, str):
                        words = question.split()
                        if ":" in text_tok_id:
                            parts = text_tok_id.split(":")
                            first, last = int(parts[0]), int(parts[1]) - 1
                            beg = sum(len(w) for w in words[:first]) + first
                            end = sum(len(w) for w in words[:last]) + last + len(words[last])
                        else:
                            idx = int(text_tok_id)
                            beg = sum(len(w) for w in words[:idx]) + idx
                            end = beg + len(words[idx])
                    else:
                        beg, end = text_tok_id

                    cleaned_span = consolidate_spans([(beg, end)], question)

                    cur_ann = Annotation(
                        area=h * w,
                        iscrowd=0,
                        category_id=1,
                        bbox=target_bbox,
                        giou_friendly_bbox=converted_bbox,
                        tokens_positive=cleaned_span,
                    )
                    cur_datapoint.annotations.append(cur_ann)
            all_datapoints.append(cur_datapoint)

    with open(output_path / "gqa_dict.pkl", "wb") as f:
        pickle.dump(all_datapoints, f)
Example #3
def convert(split, data_path, sg_path, output_path, imid2data, type,
            coco_path):

    if split == "train" and type == "all":
        data = {}
        for i in tqdm(range(10)):
            fname = data_path / f"train_all_questions/train_all_questions_{i}.json"
            with open(fname, "r") as f:
                data.update(json.load(f))
        print(len(data))
    else:
        with open(data_path / f"{split}_{type}_questions.json", "r") as f:
            data = json.load(f)

    if split in ["train", "val"]:
        with open(sg_path / f"{split}_sceneGraphs.json", "r") as f:
            sg_data = json.load(f)

    img2ann = defaultdict(dict)
    for k, v in data.items():
        img2ann[v["imageId"]][k] = v

    if split in ["train", "val", "testdev"]:

        # Add missing annotations by inspecting the semantic field
        regexp = re.compile(r"([0-9]+)")
        regexp2 = re.compile(r"([A-Za-z]+)")
        count = 0

        for k, v in img2ann.items():
            for ann_id, annotations in v.items():
                expected_boxes = []
                for item in annotations["semantic"]:
                    if item["operation"] == "select":
                        if len(regexp.findall(item["argument"])) > 0:
                            expected_boxes.append(
                                (regexp2.findall(item["argument"])[0].strip(),
                                 regexp.findall(item["argument"])[0]))
                question_boxes = [
                    v
                    for k, v in annotations["annotations"]["question"].items()
                ]

                for name, box_id in expected_boxes:
                    if box_id not in question_boxes:
                        count += 1
                        beg = annotations["question"].find(name)
                        end = beg + len(name)
                        annotations["annotations"]["question"][(beg,
                                                                end)] = box_id

        # Add annotations for the questions where there is a box for the answer but not for the question (what/where/who questions)
        for k, v in img2ann.items():
            for ann_id, ann in v.items():
                question_objects = [
                    vv for kk, vv in ann["annotations"]["question"].items()
                ]
                answer_objects = [
                    vv for kk, vv in ann["annotations"]["answer"].items()
                ]
                if len(set(answer_objects) - set(question_objects)) > 0:

                    for box_id in answer_objects:
                        if box_id not in question_objects:

                            if ann["question"].find("What") > -1:
                                beg = ann["question"].find("What")
                                end = beg + len("What")
                            elif ann["question"].find("what") > -1:
                                beg = ann["question"].find("what")
                                end = beg + len("what")
                            elif ann["question"].find("Who") > -1:
                                beg = ann["question"].find("Who")
                                end = beg + len("Who")
                            elif ann["question"].find("who") > -1:
                                beg = ann["question"].find("who")
                                end = beg + len("who")
                            elif ann["question"].find("Where") > -1:
                                beg = ann["question"].find("Where")
                                end = beg + len("Where")
                            elif ann["question"].find("where") > -1:
                                beg = ann["question"].find("where")
                                end = beg + len("where")
                            else:
                                continue

                            ann["annotations"]["question"][(beg, end)] = box_id

    print(f"Dumping {split}...")
    next_img_id = 0
    next_id = 0

    annotations = []
    images = []

    d_name = "gqa"

    if split in ["testdev", "test", "challenge", "submission"]:
        with open(f"{coco_path}/annotations/image_info_test2015.json",
                  "r") as f:
            iminfo = json.load(f)
            imid2data = {x["id"]: x for x in iminfo["images"]}

    for k, v in tqdm(img2ann.items()):

        for ann_id, annotation in v.items():
            question = annotation["question"]
            questionId = ann_id

            filename = f"{k}.jpg"
            if split in ["submission"]:
                cur_img = {
                    "file_name": filename,
                    "height": 400,
                    "width": 800,
                    "id": next_img_id,
                    "original_id": k,
                    "caption": question,
                    "tokens_negative": [(0, len(question))],
                    "dataset_name": d_name,
                    "question_type": None,
                    "answer": None,
                    "questionId": questionId,
                }

            elif split in ["test", "challenge", "submission"]:
                cur_img = {
                    "file_name": filename,
                    "height": imid2data[int(k.strip("n"))]["height"],
                    "width": imid2data[int(k.strip("n"))]["width"],
                    "id": next_img_id,
                    "original_id": k,
                    "caption": question,
                    "tokens_negative": [(0, len(question))],
                    "dataset_name": d_name,
                    "question_type": None,
                    "answer": None,
                    "questionId": questionId,
                }

            elif split == "testdev":
                cur_img = {
                    "file_name": filename,
                    "height": imid2data[int(k.strip("n"))]["height"],
                    "width": imid2data[int(k.strip("n"))]["width"],
                    "id": next_img_id,
                    "original_id": k,
                    "caption": question,
                    "tokens_negative": [(0, len(question))],
                    "dataset_name": d_name,
                    "question_type": annotation["types"]["semantic"],
                    "answer": annotation["answer"],
                    "questionId": questionId,
                }
            else:
                cur_img = {
                    "file_name": filename,
                    "height": imid2data[int(k)]["height"],
                    "width": imid2data[int(k)]["width"],
                    "id": next_img_id,
                    "original_id": k,
                    "caption": question,
                    "tokens_negative": [(0, len(question))],
                    "dataset_name": d_name,
                    "question_type": annotation["types"]["semantic"],
                    "answer": annotation["answer"],
                    "questionId": questionId,
                }

            if (split not in ["testdev", "test", "challenge", "submission"]
                    and len(annotation["annotations"]["question"]) > 0):

                for text_tok_id, box_anno_id in annotation["annotations"][
                        "question"].items():
                    target_bbox = sg_data[k]["objects"][box_anno_id]
                    x, y, h, w = target_bbox["x"], target_bbox[
                        "y"], target_bbox["h"], target_bbox["w"]
                    target_bbox = [x, y, w, h]

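                    # Same word-index to character-offset conversion as in Example #2.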
                    if isinstance(text_tok_id, str):
                        words = question.split()
                        if ":" in text_tok_id:
                            parts = text_tok_id.split(":")
                            first, last = int(parts[0]), int(parts[1]) - 1
                            beg = sum(len(w) for w in words[:first]) + first
                            end = sum(len(w) for w in words[:last]) + last + len(words[last])
                        else:
                            idx = int(text_tok_id)
                            beg = sum(len(w) for w in words[:idx]) + idx
                            end = beg + len(words[idx])
                    else:
                        beg, end = text_tok_id

                    cleaned_span = consolidate_spans([(beg, end)], question)

                    cur_obj = {
                        "area": h * w,
                        "iscrowd": 0,
                        "category_id": 1,
                        "bbox": target_bbox,
                        "tokens_positive": cleaned_span,
                        "image_id": next_img_id,
                        "id": next_id,
                    }

                    next_id += 1
                    annotations.append(cur_obj)

            next_img_id += 1
            images.append(cur_img)

    ds = {
        "info": [],
        "licenses": [],
        "images": images,
        "annotations": annotations,
        "categories": []
    }
    with open(output_path / f"finetune_gqa_{split}_{type}.json",
              "w") as j_file:
        json.dump(ds, j_file)
    return next_img_id, next_id
Example #4
def preprocess_region(region):
    filtered_region = {
        "caption": simplify_punctuation(normalize_whitespace(region["phrase"])),
        "original_image_id": region["image_id"],
        "original_region_id": region["region_id"],
        "boxes": [],
        "tokens_positive": [],
        "found_objects": False,
    }
    if len(filtered_region["caption"]) < 3:
        raise PreprocessError("caption too short, skipping" +
                              filtered_region["caption"])
    _, _, root_spans, negative_spans = get_root_and_nouns(
        filtered_region["caption"].lower(), False)

    # Filter objects that have multiple synsets, they are likely to be spurious
    obj_synsets = set(
        [o["synsets"][0] for o in region["objects"] if len(o["synsets"]) == 1])
    synsets_count = Counter([s["synset_name"] for s in region["synsets"]])
    # Filter synsets that occur multiple times, since we don't have mapping to objects
    all_synsets = set([
        s["synset_name"] for s in region["synsets"]
        if synsets_count[s["synset_name"]] == 1
    ])
    authorized_synsets = obj_synsets.intersection(all_synsets)
    syn2span: Dict[str, Tuple[int, int]] = {
        s["synset_name"]: (s["entity_idx_start"], s["entity_idx_end"])
        for s in region["synsets"] if s["synset_name"] in authorized_synsets
    }

    synlist, spanlist = [], []
    for k, s in syn2span.items():
        synlist.append(k)
        spanlist.append([s])

    # the spans positions may have been altered by the whitespace removal, so we recompute here
    spanlist, new_caption = get_canonical_spans(spanlist,
                                                region["phrase"],
                                                whitespace_only=True)
    if new_caption.lower().strip() != filtered_region["caption"].lower().strip():
        raise PreprocessError(
            f"Inconsistent whitespace removal: '{new_caption}' vs '{filtered_region['caption']}'"
        )

    assert len(synlist) == len(spanlist)
    syn2span = {k: v[0] for k, v in zip(synlist, spanlist)}

    root_objs = []
    other_objs: Dict[Tuple[int, int], List[List[int]]] = {}
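    # Attach each object box to the root span of the caption if its synset mention intersects
    # the root, otherwise to the span of its own mention (also recorded as a negative span).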
    for obj in region["objects"]:
        if len(obj["synsets"]
               ) == 1 and obj["synsets"][0] in authorized_synsets:
            cur_span = syn2span[obj["synsets"][0]]
            if span_intersect_spanlist(cur_span, root_spans):
                root_objs.append(obj_to_box(obj))
                filtered_region["found_objects"] = True
            else:
                if cur_span not in other_objs:
                    other_objs[cur_span] = []
                    negative_spans.append(cur_span)
                other_objs[cur_span].append(obj_to_box(obj))
                filtered_region["found_objects"] = True

    if len(root_objs) == 0:
        # If we don't have a box for the root of the sentence, we use the box of the region itself.
        root_objs.append(region_to_box(region))

    dedup_root_objs = combine_boxes(root_objs)
    filtered_region["boxes"] += dedup_root_objs
    root_spans = consolidate_spans(root_spans, filtered_region["caption"])
    filtered_region["tokens_positive"] += [
        root_spans for _ in range(len(dedup_root_objs))
    ]

    for span, objs in other_objs.items():
        dedup_objs = combine_boxes(objs)
        filtered_region["boxes"] += dedup_objs
        cur_spans = consolidate_spans([span], filtered_region["caption"])
        filtered_region["tokens_positive"] += [
            cur_spans for _ in range(len(dedup_objs))
        ]

    filtered_region["tokens_negative"] = consolidate_spans(
        negative_spans, filtered_region["caption"])
    return filtered_region
Example #5
def deduplicate_regions(regions, iou_threshold=0.5):
    """This functions accepts pre-processed region descriptions for a given image, and removes regions that are redundant.
    Two regions are deemed redundant if 1) the text is closely matching 2) the IOU between region boxes is > iou_threshold
    A cleaned description is returned.
    """
    def helper_merge(regions):
        if len(regions) <= 1:
            return regions
        uf = UnionFind(len(regions))
        for r in regions:
            spans, txt2 = get_canonical_spans(r["tokens_positive"],
                                              r["caption"])
            if txt != txt2:
                raise PreprocessError(
                    f"inconsistent canonicalization fct. Mismatch: '{txt}' and '{txt2}'"
                )
            r["cano_tokens"] = spans

        for r1 in range(len(regions)):
            for r2 in range(r1 + 1, len(regions)):
                compatible = True
                assert len(regions[r1]["boxes"]) == len(
                    regions[r1]["cano_tokens"])
                assert len(regions[r2]["boxes"]) == len(
                    regions[r2]["cano_tokens"])
                ious = box_iou_helper(regions[r1]["boxes"],
                                      regions[r2]["boxes"])
                for b1 in range(len(regions[r1]["cano_tokens"])):
                    for b2 in range(len(regions[r2]["cano_tokens"])):
                        if (len(regions[r1]["cano_tokens"][b1]) == 0
                                or len(regions[r2]["cano_tokens"][b2])
                                == 0) or (spanlist_intersect_spanlist(
                                    regions[r1]["cano_tokens"][b1],
                                    regions[r2]["cano_tokens"][b2])
                                          and ious[b1][b2] < iou_threshold):
                            compatible = False
                            break
                    if not compatible:
                        break
                if compatible:
                    uf.unite(r1, r2)
        compo2regions = defaultdict(list)
        for i, r in enumerate(regions):
            compo2regions[uf.find(i)].append(r)

        final_regions = []
        for reg_list in compo2regions.values():
            if len(reg_list) == 1:
                final_regions.append(reg_list[0])
            else:
                # We pick as representative of this cluster the region with the most boxes
                sorted_regions = sorted([(len(r["boxes"]), i)
                                         for i, r in enumerate(reg_list)],
                                        reverse=True)
                reg_ids = [sr[1] for sr in sorted_regions]
                # We need to put the boxes and token spans in buckets
                cano_spans_buckets = []
                orig_spans_buckets = []
                boxes_buckets = []
                for idx in reg_ids:
                    for b in range(len(reg_list[idx]["boxes"])):
                        # find the bucket
                        bucket = -1
                        for j in range(len(cano_spans_buckets)):
                            if spanlist_intersect_spanlist(
                                    reg_list[idx]["cano_tokens"][b],
                                    cano_spans_buckets[j]):
                                bucket = j
                                break
                        if bucket == -1:
                            # bucket not found, creating one.
                            if idx != reg_ids[0]:
                                # This shouldn't happen: it would mean another region has token
                                # spans that aren't covered by the main region, and creating a new
                                # span would require finding it in the main region's sentence (and
                                # updating the negative tokens). We give up on the merge instead.
                                return regions

                            bucket = len(orig_spans_buckets)
                            orig_spans_buckets.append(
                                reg_list[idx]["tokens_positive"][b])
                            cano_spans_buckets.append(
                                reg_list[idx]["cano_tokens"][b])
                            boxes_buckets.append([reg_list[idx]["boxes"][b]])
                        else:
                            boxes_buckets[bucket].append(
                                reg_list[idx]["boxes"][b])
                assert len(orig_spans_buckets) == len(boxes_buckets)
                merged_region = deepcopy(reg_list[reg_ids[0]])
                merged_region["tokens_positive"] = []
                merged_region["boxes"] = []
                for i in range(len(boxes_buckets)):
                    dedup_objs = combine_boxes(boxes_buckets[i],
                                               iou_threshold=0.5)
                    merged_region["boxes"] += dedup_objs
                    merged_region["tokens_positive"] += [
                        orig_spans_buckets[i] for _ in range(len(dedup_objs))
                    ]
                final_regions.append(merged_region)
        for r in final_regions:
            del r["cano_tokens"]
        return final_regions

    txt2region = defaultdict(list)
    for r in regions:
        txt2region[normalize_sentence(r["caption"])].append(r)

    stupid_sentence_set = set(["wall", "side", "building"])
    final_regions = []
    for txt, regions in txt2region.items():
        # Edge case, we remove the sentences like "the wall on the side of the building" which are uninformative and have spurious boxes
        if "wall" in txt and set(
                txt.strip().split(" ")).issubset(stupid_sentence_set):
            continue
        if len(regions) == 1:
            final_regions.append(deepcopy(regions[0]))
        else:
            # print(txt)

            regions_with_boxes = [r for r in regions if r["found_objects"]]
            all_boxes = sum([r["boxes"] for r in regions_with_boxes], [])
            # print("regions with boxes", len(regions_with_boxes))

            regions_without_boxes = []
            for r in regions:
                if not r["found_objects"]:
                    # we consider that one of the regions with boxes will be better suited and drop this one
                    # if there is a positive iou. Otherwise, we have to keep it
                    if len(regions_with_boxes) == 0 or box_iou_helper(
                            all_boxes, r["boxes"]).max().item() < 0.1:
                        regions_without_boxes.append(r)

            # print("regions without boxes", len(regions_without_boxes))

            try:
                new_regions_with_boxes = helper_merge(regions_with_boxes)
            except PreprocessError as e:
                print("skipping", e)
                # Ouch, hit a cornercase, we give up on the merge
                new_regions_with_boxes = regions_with_boxes
            try:
                new_regions_without_boxes = helper_merge(regions_without_boxes)
            except PreprocessError as e:
                print("skipping", e)
                # Ouch, hit a cornercase, we give up on the merge
                new_regions_without_boxes = regions_without_boxes

            # now collapse into one big region. We do it only when the captions are exactly matching, otherwise it's a nightmare to recompute spans
            capt2region = defaultdict(list)
            for r in new_regions_with_boxes + new_regions_without_boxes:
                capt2region[r["caption"]].append(r)
            for capt, reg_list in capt2region.items():
                all_boxes = sum([r["boxes"] for r in reg_list], [])
                all_tokens = sum([r["tokens_positive"] for r in reg_list], [])
                compo2boxes, compo2id = get_boxes_equiv(all_boxes,
                                                        iou_threshold=0.75)
                final_boxes = []
                final_tokens = []
                if compo2boxes is not None:
                    for compo in compo2boxes.keys():
                        box_list = compo2boxes[compo]
                        id_list = compo2id[compo]
                        final_boxes.append(
                            xyxy_to_xywh(torch.stack(box_list,
                                                     0).mean(0)).tolist())
                        final_tokens.append(
                            consolidate_spans(
                                sum([all_tokens[i] for i in id_list], []),
                                capt))
                else:
                    final_boxes = all_boxes
                    final_tokens = all_tokens

                merged_region = {
                    "caption": capt,
                    "original_image_id": reg_list[0]["original_image_id"],
                    "original_region_id": reg_list[0]["original_region_id"],
                    "boxes": final_boxes,
                    "tokens_positive": final_tokens,
                    "tokens_negative": consolidate_spans(
                        sum([r["tokens_negative"] for r in reg_list], []), capt),
                    "found_objects": False,
                }
                final_regions.append(merged_region)

    return final_regions
Example #6
def convert(dataset_path: Path, split: str, output_path, coco_path, next_img_id: int = 0, next_id: int = 0):
    """Do the heavy lifting on the given split (eg 'train')"""

    print(f"Exporting {split}...")

    with open(f"{coco_path}/annotations/instances_train2014.json", "r") as f:
        coco_annotations = json.load(f)
    coco_images = coco_annotations["images"]
    coco_anns = coco_annotations["annotations"]
    annid2cocoann = {item["id"]: item for item in coco_anns}
    imgid2cocoimgs = {item["id"]: item for item in coco_images}

    categories = coco_annotations["categories"]
    annotations = []
    images = []

    for dataset_name in ["refcoco/refs(unc).p", "refcoco+/refs(unc).p", "refcocog/refs(umd).p"]:
        d_name = dataset_name.split("/")[0]

        with open(dataset_path / dataset_name, "rb") as f:
            data = pickle.load(f)

        for item in data:
            if item["split"] != split:
                continue

            for s in item["sentences"]:
                refexp = s["sent"]
                _, _, root_spans, neg_spans = get_root_and_nouns(refexp)
                root_spans = consolidate_spans(root_spans, refexp)
                neg_spans = consolidate_spans(neg_spans, refexp)

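                # RefCOCO file names carry a trailing annotation id; drop it to recover the original COCO image file name.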
                filename = "_".join(item["file_name"].split("_")[:-1]) + ".jpg"
                cur_img = {
                    "file_name": filename,
                    "height": imgid2cocoimgs[item["image_id"]]["height"],
                    "width": imgid2cocoimgs[item["image_id"]]["width"],
                    "id": next_img_id,
                    "original_id": item["image_id"],
                    "caption": refexp,
                    "dataset_name": d_name,
                    "tokens_negative": neg_spans,
                }

                cur_obj = {
                    "area": annid2cocoann[item["ann_id"]]["area"],
                    "iscrowd": annid2cocoann[item["ann_id"]]["iscrowd"],
                    "image_id": next_img_id,
                    "category_id": item["category_id"],
                    "id": next_id,
                    "bbox": annid2cocoann[item["ann_id"]]["bbox"],
                    # "segmentation": annid2cocoann[item['ann_id']]['segmentation'],
                    "original_id": item["ann_id"],
                    "tokens_positive": root_spans,
                }
                next_id += 1
                annotations.append(cur_obj)
                next_img_id += 1
                images.append(cur_img)

    ds = {
        "info": coco_annotations["info"],
        "licenses": coco_annotations["licenses"],
        "images": images,
        "annotations": annotations,
        "categories": coco_annotations["categories"],
    }
    with open(output_path / f"final_refexp_val.json", "w") as j_file:
        json.dump(ds, j_file)
    return next_img_id, next_id
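
A hedged call sketch (paths are hypothetical placeholders); the returned counters let a subsequent call continue the same image/annotation id sequence:

next_img_id, next_id = convert(Path("refer/data"), "val", Path("output"), "coco")
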
Example #7
def get_refexp_groups(
        im2datapoint: Dict[str, List[Datapoint]]) -> List[Datapoint]:
    """This functions accepts a dictionary that contains all the datapoints from a given id.
    These datapoints are assumed to come from the same image subset (vg or coco)

    For each image, given the list of datapoints, we try to combine several datapoints together.
    The combination simply concatenates the captions for the combined datapoints, as well as the list of boxes.
    For a combination to be deemed acceptable, we require that the boxes are not overlapping too much.
    This ensures that only one part of the combined caption is referring to a particular object in the image.
    To achieve this combination, we use a greedy graph-coloring algorithm.

    This function returns a flat list of all the combined datapoints that were created.
    """
    combined_datapoints: List[Datapoint] = []

    for image_id, all_datapoints in tqdm(im2datapoint.items()):
        # get all the referring expressions for this image
        refexps = [datapoint.caption for datapoint in all_datapoints]

        # Create a graph where there is an edge between two datapoints iff they are NOT compatible
        adj_list = {i: [] for i in range(len(refexps))}

        # Get the list of all boxes (in "giou_friendly" format, aka [top_left_x, top_left_y, bottom_right_x, bottom_right_y]) for each datapoint
        all_boxes = []
        for datapoint in all_datapoints:
            if len(datapoint.annotations) > 0:
                all_boxes.append(
                    torch.stack([
                        torch.as_tensor(ann.giou_friendly_bbox)
                        for ann in datapoint.annotations
                    ]))
            else:
                all_boxes.append(torch.zeros(0, 4))

        # To find which referring expressions to combine into a single instance, we apply a graph coloring step.
        # First we build a graph of refexps such that nodes correspond to refexps and an edge occurs between
        # two nodes when the max IoU between ANY pair of their boxes is > 0.5. This implies they are both referring
        # to the same box and hence should not be combined into one example.
        for i in range(len(all_datapoints)):
            for j in range(i + 1, len(all_datapoints)):
                iou = box_iou(all_boxes[i], all_boxes[j])
                if iou.numel() > 0 and torch.max(iou).item() > 0.5:
                    adj_list[i].append(j)
                    adj_list[j].append(i)

        # Here we build the colored graph corresponding to the adjacency list given by adj_list
        colored_graph: Dict[int, int] = {}  # Color of each vertex
        nodes_degree = [(len(v), k) for k, v in adj_list.items()]
        nodes_sorted = sorted(nodes_degree, reverse=True)
        global_colors = [0]  # Colors used so far
        color_size = defaultdict(int)  # total length of the captions assigned to each color

        def get_color(admissible_color_set, new_length):
            admissible_color_list = sorted(list(admissible_color_set))
            for color in admissible_color_list:
                if color_size[color] + new_length + 2 <= 250:
                    return color
            return None

        # Loop over all nodes and color with the lowest color that is compatible
        # We add the constraint that the sum of the lengths of all the captions assigned to a given color is less than 250 (our max sequence length)
        for _, node in nodes_sorted:
            used_colors = set()
            # Gather the colors of the neighbours
            for adj_node in adj_list[node]:
                if adj_node in colored_graph:
                    used_colors.add(colored_graph[adj_node])
            if len(used_colors) < 1:
                # Neighbours are uncolored, we take the smallest color
                curr_color = get_color(global_colors,
                                       len(all_datapoints[node].caption))
            else:
                # Find the smallest unused color
                curr_color = get_color(
                    set(global_colors) - set(used_colors),
                    len(all_datapoints[node].caption))
            if curr_color is None:
                # Couldn't find a suitable color, creating one
                global_colors.append(max(global_colors) + 1)
                curr_color = global_colors[-1]
            colored_graph[node] = curr_color
            color_size[curr_color] += len(all_datapoints[node].caption)

        # Collect the datapoints that all have the same color
        color2datapoints: Dict[int, List[Datapoint]] = defaultdict(list)
        for node, color in colored_graph.items():
            color2datapoints[color].append(all_datapoints[node])

        # Make sure we have a valid coloring by checking that adjacent nodes have different colors
        for k, v in adj_list.items():
            for node in v:
                assert colored_graph[k] != colored_graph[node]

        for cur_datapoint_list in color2datapoints.values():
            if len(cur_datapoint_list) == 0:
                continue
            # collect the captions, adding a trailing punctuation mark if there is not one already
            all_captions = [
                set_last_char(datapoint.caption)
                for datapoint in cur_datapoint_list
            ]
            combined_caption = " ".join(all_captions) + " "

            # compute the combined (offsetted) negative span
            cur_offset = 0
            combined_tokens_negative: List[Tuple[int, int]] = []
            for i, datapoint in enumerate(cur_datapoint_list):
                combined_tokens_negative += shift_spans(
                    datapoint.tokens_negative, cur_offset)
                cur_offset += len(all_captions[i]) + 1  # 1 for space
            assert cur_offset == len(combined_caption)

            cur_combined_datapoint = Datapoint(
                image_id=image_id,
                dataset_name="mixed",
                tokens_negative=consolidate_spans(combined_tokens_negative,
                                                  combined_caption),
                original_id=-1,
                caption=combined_caption.rstrip(),
                annotations=[],
            )

            # compute the offsetted positive span and append all annotations
            cur_offset = 0
            for data_id, datapoint in enumerate(cur_datapoint_list):
                for ann_id, ann in enumerate(datapoint.annotations):
                    new_annotation = Annotation(
                        area=ann.area,
                        iscrowd=ann.iscrowd,
                        category_id=ann.category_id,
                        bbox=ann.bbox,
                        giou_friendly_bbox=[],  # We don't need that anymore
                        tokens_positive=consolidate_spans(
                            shift_spans(ann.tokens_positive, cur_offset),
                            combined_caption),
                    )
                    cur_combined_datapoint.annotations.append(new_annotation)
                cur_offset += len(all_captions[data_id]) + 1
            assert cur_offset == len(combined_caption)

            combined_datapoints.append(cur_combined_datapoint)

    return combined_datapoints
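
A hedged sketch of how this grouping step might consume the pickle produced in Example #1; building im2datapoint this way is an assumption about the surrounding pipeline (output_path, pickle, and defaultdict as in the earlier examples), not code from the source:

with open(output_path / "refexp_dict.pkl", "rb") as f:
    all_datapoints = pickle.load(f)
im2datapoint = defaultdict(list)
for dp in all_datapoints:
    im2datapoint[dp.image_id].append(dp)
combined = get_refexp_groups(im2datapoint)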