Example #1
0
def mdetr_resnet101_refcocog(pretrained=False, return_postprocessor=False):
    """Build MDETR with an R101 backbone, fine-tuned on RefCOCOg.

    The model has 6 encoder and 6 decoder layers. The released RefCOCOg
    checkpoint achieves 81.64 val accuracy.

    Args:
        pretrained: If true, download the RefCOCOg checkpoint, mapped onto
            the CPU, and load its ``"model"`` entry into the network.
        return_postprocessor: If true, also return a ``PostProcess`` instance.

    Returns:
        The model, or a ``(model, PostProcess())`` tuple when
        ``return_postprocessor`` is true.
    """
    detr = _make_detr("resnet101")
    if pretrained:
        weights_url = "https://zenodo.org/record/4721981/files/refcocog_resnet101_checkpoint.pth"
        state = torch.hub.load_state_dict_from_url(
            url=weights_url, map_location="cpu", check_hash=True
        )
        detr.load_state_dict(state["model"])
    return (detr, PostProcess()) if return_postprocessor else detr
Example #2
0
def mdetr_efficientnetB3_refcoco(pretrained=False, return_postprocessor=False):
    """Build MDETR with an EfficientNet-B3 (ENB3) backbone, fine-tuned on RefCOCO.

    The backbone is built from ``"timm_tf_efficientnet_b3_ns"``, and the model
    has 6 encoder and 6 decoder layers. The released RefCOCO checkpoint
    achieves 86.75 val accuracy.

    Args:
        pretrained: If true, fetch the RefCOCO EB3 checkpoint onto the CPU
            and load its ``"model"`` weights.
        return_postprocessor: If true, return ``(model, PostProcess())``
            instead of the bare model.
    """
    network = _make_detr("timm_tf_efficientnet_b3_ns")
    if pretrained:
        checkpoint_url = "https://zenodo.org/record/4721981/files/refcoco_EB3_checkpoint.pth"
        loaded = torch.hub.load_state_dict_from_url(
            url=checkpoint_url,
            map_location="cpu",
            check_hash=True,
        )
        network.load_state_dict(loaded["model"])
    if not return_postprocessor:
        return network
    return network, PostProcess()
Example #3
0
def mdetr_efficientnetB3(pretrained=False, return_postprocessor=False):
    """Build the pre-trained MDETR model with an EfficientNet-B3 (ENB3) backbone.

    The model has 6 encoder and 6 decoder layers. It was pre-trained on the
    combined aligned dataset of 1.3 million images paired with text.
    No task-specific fine-tuning is included.

    Args:
        pretrained: If true, download the pre-training EB3 checkpoint (CPU
            mapped) and load its ``"model"`` entry.
        return_postprocessor: If true, also hand back a ``PostProcess``.

    Returns:
        ``model`` or ``(model, PostProcess())``.
    """
    hub_options = {"map_location": "cpu", "check_hash": True}

    mdetr = _make_detr("timm_tf_efficientnet_b3_ns")
    if pretrained:
        source = "https://zenodo.org/record/4721981/files/pretrained_EB3_checkpoint.pth"
        mdetr.load_state_dict(
            torch.hub.load_state_dict_from_url(url=source, **hub_options)["model"]
        )

    if return_postprocessor:
        postprocessor = PostProcess()
        return mdetr, postprocessor
    return mdetr
Example #4
0
def mdetr_resnet101_phrasecut(pretrained=False,
                              threshold=0.5,
                              return_postprocessor=False):
    """Build MDETR with an R101 backbone and a mask head, trained on PhraseCut.

    The model has 6 encoder and 6 decoder layers. It is constructed with
    ``mask=True`` and ``contrastive_align_loss=False``. The released PhraseCut
    checkpoint achieves 53.1 M-IoU on the test set.

    Args:
        pretrained: If true, download the PhraseCut checkpoint onto the CPU
            and load its ``"model"`` weights.
        threshold: Passed as ``threshold`` to ``PostProcessSegm``. It is only
            used when ``return_postprocessor`` is true.
        return_postprocessor: If true, also return the postprocessors.

    Returns:
        The model alone, or ``(model, [PostProcess(), PostProcessSegm(...)])``.
    """
    segm_model = _make_detr("resnet101", mask=True, contrastive_align_loss=False)
    if pretrained:
        phrasecut_url = "https://zenodo.org/record/4721981/files/phrasecut_resnet101_checkpoint.pth"
        ckpt = torch.hub.load_state_dict_from_url(
            url=phrasecut_url, map_location="cpu", check_hash=True
        )
        segm_model.load_state_dict(ckpt["model"])
    if not return_postprocessor:
        return segm_model
    # Box postprocessor first, then the mask postprocessor.
    postprocessors = [PostProcess(), PostProcessSegm(threshold=threshold)]
    return segm_model, postprocessors
Example #5
0
def mdetr_efficientnetB5_gqa(pretrained=False, return_postprocessor=False):
    """Build MDETR with an EfficientNet-B5 (ENB5) backbone, trained on GQA.

    The model has 6 encoder and 6 decoder layers. It is built with
    ``qa_dataset="gqa"`` and ``contrastive_align_loss=False``. The released
    checkpoint scores 61.99 on test-std.

    Args:
        pretrained: If true, load the ``"model"`` weights from the GQA EB5
            checkpoint, mapped onto the CPU.
        return_postprocessor: If true, return ``(model, PostProcess())``.
    """
    gqa_model = _make_detr(
        "timm_tf_efficientnet_b5_ns",
        qa_dataset="gqa",
        contrastive_align_loss=False,
    )
    if pretrained:
        remote = "https://zenodo.org/record/4721981/files/gqa_EB5_checkpoint.pth"
        snapshot = torch.hub.load_state_dict_from_url(
            url=remote,
            map_location="cpu",
            check_hash=True,
        )
        gqa_model.load_state_dict(snapshot["model"])
    return (gqa_model, PostProcess()) if return_postprocessor else gqa_model
Example #6
0
def mdetr_clevr(pretrained=False, return_postprocessor=False):
    """Build MDETR with an R18 backbone, trained on CLEVR.

    The model has 6 encoder and 6 decoder layers. It is built with
    ``num_queries=25``, ``qa_dataset="clevr"`` and the ``"distilroberta-base"``
    text encoder. The released checkpoint achieves 99.7% accuracy.

    Args:
        pretrained: If true, download the CLEVR checkpoint onto the CPU and
            load its ``"model"`` weights.
        return_postprocessor: If true, also return a ``PostProcess`` instance.

    Returns:
        ``model`` or ``(model, PostProcess())``.
    """
    clevr_model = _make_detr(
        "resnet18",
        num_queries=25,
        qa_dataset="clevr",
        text_encoder="distilroberta-base",
    )
    if pretrained:
        clevr_url = "https://zenodo.org/record/4721981/files/clevr_checkpoint.pth"
        weights = torch.hub.load_state_dict_from_url(
            url=clevr_url, map_location="cpu", check_hash=True
        )["model"]
        clevr_model.load_state_dict(weights)
    if not return_postprocessor:
        return clevr_model
    return clevr_model, PostProcess()