def mdetr_resnet101_refcocog(pretrained=False, return_postprocessor=False):
    """
    Build MDETR with a ResNet-101 backbone (6 encoder / 6 decoder layers)
    fine-tuned on RefCOCOg; reported to reach 81.64 val accuracy.

    Args:
        pretrained: if True, download the RefCOCOg checkpoint (mapped to CPU)
            and load its ``"model"`` entry into the network.
        return_postprocessor: if True, also hand back a ``PostProcess``.

    Returns:
        The model, or the tuple ``(model, PostProcess())`` when
        ``return_postprocessor`` is True.
    """
    detr = _make_detr("resnet101")

    if pretrained:
        weights_url = "https://zenodo.org/record/4721981/files/refcocog_resnet101_checkpoint.pth"
        # NOTE(review): per the torch.hub docs, check_hash only verifies files
        # whose name embeds "-<sha256 prefix>"; this name doesn't appear to,
        # so the check may be a no-op — confirm.
        ckpt = torch.hub.load_state_dict_from_url(
            weights_url, map_location="cpu", check_hash=True
        )
        detr.load_state_dict(ckpt["model"])

    return (detr, PostProcess()) if return_postprocessor else detr
def mdetr_efficientnetB3_refcoco(pretrained=False, return_postprocessor=False):
    """
    Build MDETR with an EfficientNet-B3 backbone (6 encoder / 6 decoder
    layers) fine-tuned on RefCOCO; reported to reach 86.75 val accuracy.

    Args:
        pretrained: if True, download the RefCOCO EB3 checkpoint (mapped to
            CPU) and load its ``"model"`` entry into the network.
        return_postprocessor: if True, also hand back a ``PostProcess``.

    Returns:
        The model, or ``(model, PostProcess())`` when
        ``return_postprocessor`` is True.
    """
    net = _make_detr("timm_tf_efficientnet_b3_ns")

    if pretrained:
        # NOTE(review): check_hash only has an effect when the file name
        # carries a "-<sha256 prefix>" (torch.hub docs) — confirm it applies.
        state = torch.hub.load_state_dict_from_url(
            "https://zenodo.org/record/4721981/files/refcoco_EB3_checkpoint.pth",
            map_location="cpu",
            check_hash=True,
        )
        net.load_state_dict(state["model"])

    if not return_postprocessor:
        return net
    return net, PostProcess()
def mdetr_efficientnetB3(pretrained=False, return_postprocessor=False):
    """
    Build MDETR with an EfficientNet-B3 backbone (6 encoder / 6 decoder
    layers), pre-trained on the combined aligned dataset of 1.3 million
    images paired with text.

    Args:
        pretrained: if True, download the pre-training EB3 checkpoint
            (mapped to CPU) and load its ``"model"`` entry into the network.
        return_postprocessor: if True, also hand back a ``PostProcess``.

    Returns:
        The model, or ``(model, PostProcess())`` when
        ``return_postprocessor`` is True.
    """
    network = _make_detr("timm_tf_efficientnet_b3_ns")

    if pretrained:
        source = "https://zenodo.org/record/4721981/files/pretrained_EB3_checkpoint.pth"
        # NOTE(review): check_hash is only enforced when the file name embeds
        # "-<sha256 prefix>" (torch.hub docs); likely a no-op here — confirm.
        loaded = torch.hub.load_state_dict_from_url(
            source, map_location="cpu", check_hash=True
        )
        network.load_state_dict(loaded["model"])

    if return_postprocessor:
        return network, PostProcess()
    else:
        return network
def mdetr_resnet101_phrasecut(pretrained=False, threshold=0.5, return_postprocessor=False):
    """
    Build MDETR with a ResNet-101 backbone (6 encoder / 6 decoder layers)
    and a mask head, trained on PhraseCut; reported 53.1 M-IoU on test.

    The model is built with ``mask=True`` and ``contrastive_align_loss=False``.

    Args:
        pretrained: if True, download the PhraseCut checkpoint (mapped to
            CPU) and load its ``"model"`` entry into the network.
        threshold: forwarded to ``PostProcessSegm`` when postprocessors are
            requested.
        return_postprocessor: if True, also hand back a two-element list
            ``[PostProcess(), PostProcessSegm(threshold=threshold)]``.

    Returns:
        The model, or ``(model, [PostProcess(), PostProcessSegm(...)])``
        when ``return_postprocessor`` is True.
    """
    segmenter = _make_detr("resnet101", mask=True, contrastive_align_loss=False)

    if pretrained:
        ckpt_url = "https://zenodo.org/record/4721981/files/phrasecut_resnet101_checkpoint.pth"
        # NOTE(review): check_hash only verifies names with a "-<sha256
        # prefix>" (torch.hub docs); this one doesn't seem to — confirm.
        bundle = torch.hub.load_state_dict_from_url(
            ckpt_url, map_location="cpu", check_hash=True
        )
        segmenter.load_state_dict(bundle["model"])

    if not return_postprocessor:
        return segmenter
    # Box post-processing first, then mask post-processing.
    processors = [PostProcess(), PostProcessSegm(threshold=threshold)]
    return segmenter, processors
def mdetr_efficientnetB5_gqa(pretrained=False, return_postprocessor=False):
    """
    Build MDETR with an EfficientNet-B5 backbone (6 encoder / 6 decoder
    layers) trained on GQA; reported 61.99 on test-std.

    The model is built with ``qa_dataset="gqa"`` and
    ``contrastive_align_loss=False``.

    Args:
        pretrained: if True, download the GQA EB5 checkpoint (mapped to CPU)
            and load its ``"model"`` entry into the network.
        return_postprocessor: if True, also hand back a ``PostProcess``.

    Returns:
        The model, or ``(model, PostProcess())`` when
        ``return_postprocessor`` is True.
    """
    qa_model = _make_detr(
        "timm_tf_efficientnet_b5_ns",
        qa_dataset="gqa",
        contrastive_align_loss=False,
    )

    if pretrained:
        # NOTE(review): check_hash is a no-op unless the file name embeds
        # "-<sha256 prefix>" (torch.hub docs) — confirm intent.
        snapshot = torch.hub.load_state_dict_from_url(
            "https://zenodo.org/record/4721981/files/gqa_EB5_checkpoint.pth",
            map_location="cpu",
            check_hash=True,
        )
        qa_model.load_state_dict(snapshot["model"])

    return (qa_model, PostProcess()) if return_postprocessor else qa_model
def mdetr_clevr(pretrained=False, return_postprocessor=False):
    """
    Build MDETR with a ResNet-18 backbone (6 encoder / 6 decoder layers)
    trained on CLEVR; reported 99.7% accuracy.

    The model uses 25 queries, ``qa_dataset="clevr"`` and the
    ``"distilroberta-base"`` text encoder.

    Args:
        pretrained: if True, download the CLEVR checkpoint (mapped to CPU)
            and load its ``"model"`` entry into the network.
        return_postprocessor: if True, also hand back a ``PostProcess``.

    Returns:
        The model, or ``(model, PostProcess())`` when
        ``return_postprocessor`` is True.
    """
    clevr_model = _make_detr(
        "resnet18",
        num_queries=25,
        qa_dataset="clevr",
        text_encoder="distilroberta-base",
    )

    if pretrained:
        weights = "https://zenodo.org/record/4721981/files/clevr_checkpoint.pth"
        # NOTE(review): per torch.hub docs, check_hash needs a "-<sha256
        # prefix>" in the file name to actually verify anything — confirm.
        payload = torch.hub.load_state_dict_from_url(
            weights, map_location="cpu", check_hash=True
        )
        clevr_model.load_state_dict(payload["model"])

    if not return_postprocessor:
        return clevr_model
    return clevr_model, PostProcess()