Esempio n. 1
0
def detr_resnet101_panoptic(pretrained=False,
                            num_classes=250,
                            threshold=0.85,
                            return_postprocessor=False):
    """
    DETR R101 with 6 encoder and 6 decoder layers, built with a mask head
    for panoptic segmentation.

    The backbone is built with ``dilation=False``, i.e. the last ResNet
    block is NOT dilated, so this is the plain R101 variant (not DC5).
    Achieves 45.1 PQ on COCO val5k.

    Args:
        pretrained: if True, download the panoptic checkpoint (hash-checked,
            mapped to CPU) and load its ``"model"`` entry into the network.
        num_classes: number of classes forwarded to ``_make_detr``.
            NOTE(review): the published checkpoint presumably matches the
            default of 250; other values likely fail to load when
            ``pretrained=True`` — confirm.
        threshold: minimum confidence required for keeping segments in the
            prediction; forwarded to ``PostProcessPanoptic``.
        return_postprocessor: if True, also return a ``PostProcessPanoptic``
            instance configured with ``threshold``.

    Returns:
        The model, or ``(model, postprocessor)`` when
        ``return_postprocessor`` is True.
    """
    model = _make_detr("resnet101",
                       dilation=False,
                       num_classes=num_classes,
                       mask=True)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/detr/detr-r101-panoptic-40021d53.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    if not return_postprocessor:
        return model
    # Ids 0..90 are flagged as "things" and the rest as "stuff" (presumably
    # the COCO panoptic split). The map spans range(num_classes) rather than
    # a hard-coded 250 so it stays in sync with the num_classes argument;
    # for the default num_classes=250 it is exactly the original map.
    is_thing_map = {i: i <= 90 for i in range(num_classes)}
    return model, PostProcessPanoptic(is_thing_map, threshold=threshold)
Esempio n. 2
0
def detr_resnet50_dc5_panoptic(pretrained=False,
                               num_classes=250,
                               threshold=0.85,
                               return_postprocessor=False):
    """
    DETR-DC5 R50 with 6 encoder and 6 decoder layers, built with a mask
    head for panoptic segmentation.

    The last block of ResNet-50 has dilation (``dilation=True``) to increase
    output resolution.
    Achieves 44.6 PQ on COCO val5k.

    Args:
        pretrained: if True, download the panoptic checkpoint (hash-checked,
            mapped to CPU) and load its ``"model"`` entry into the network.
        num_classes: number of classes forwarded to ``_make_detr``.
            NOTE(review): the published checkpoint presumably matches the
            default of 250; other values likely fail to load when
            ``pretrained=True`` — confirm.
        threshold: minimum confidence required for keeping segments in the
            prediction; forwarded to ``PostProcessPanoptic``.
        return_postprocessor: if True, also return a ``PostProcessPanoptic``
            instance configured with ``threshold``.

    Returns:
        The model, or ``(model, postprocessor)`` when
        ``return_postprocessor`` is True.
    """
    model = _make_detr("resnet50",
                       dilation=True,
                       num_classes=num_classes,
                       mask=True)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-panoptic-da08f1b1.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    if not return_postprocessor:
        return model
    # Ids 0..90 are flagged as "things" and the rest as "stuff" (presumably
    # the COCO panoptic split). The map spans range(num_classes) rather than
    # a hard-coded 250 so it stays in sync with the num_classes argument;
    # for the default num_classes=250 it is exactly the original map.
    is_thing_map = {i: i <= 90 for i in range(num_classes)}
    return model, PostProcessPanoptic(is_thing_map, threshold=threshold)