Exemplo n.º 1
0
def rf_lw152(num_classes, imagenet=False, pretrained=True, **kwargs):
    """Build a light-weight RefineNet (``ResNetLW``) with a ResNet-152 layout.

    Args:
        num_classes (int): number of classes for the segmentation head.
        imagenet (bool): if True, load the ``152_imagenet`` weights
            (non-strict, so keys absent from the checkpoint are tolerated).
        pretrained (bool): consulted only when ``imagenet`` is False; if
            ``num_classes`` maps to a dataset name in ``data_info``, load the
            ``rf_lw152_<dataset>`` weights (non-strict).
        **kwargs: forwarded unchanged to ``ResNetLW``.

    Returns:
        The constructed model.
    """
    model = ResNetLW(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)
    if imagenet:
        weights_name = "152_imagenet"
        state = maybe_download(weights_name, models_urls[weights_name])
        model.load_state_dict(state, strict=False)
        return model
    if not pretrained:
        return model
    dataset_name = data_info.get(num_classes, None)
    if not dataset_name:
        # No known dataset for this class count: keep random initialisation.
        return model
    url_key = "152_{}".format(dataset_name.lower())
    state = maybe_download("rf_lw" + url_key, models_urls[url_key])
    model.load_state_dict(state, strict=False)
    return model
def rf101(num_classes, imagenet=False, pretrained=True, **kwargs):
    """Build a RefineNet with a ResNet-101 layout and optionally load weights.

    Args:
        num_classes (int): number of classes for the segmentation head.
        imagenet (bool): if True, load the ``101_imagenet`` weights
            (non-strict load).
        pretrained (bool): consulted only when ``imagenet`` is False; if
            ``num_classes`` maps to a dataset name in ``data_info``, load the
            ``rf101_<dataset>`` weights (non-strict load). When there is no
            such dataset the model is returned without loading anything.
        **kwargs: forwarded unchanged to ``RefineNet``.

    Returns:
        The constructed model.
    """
    model = RefineNet(Bottleneck, [3, 4, 23, 3],
                      num_classes=num_classes,
                      **kwargs)
    if imagenet:
        key = '101_imagenet'
        url = models_urls[key]
        model.load_state_dict(maybe_download(key, url), strict=False)
    elif pretrained:
        # Look the dataset up by the requested class count; the previous
        # hard-coded 21 ignored ``num_classes`` and would have crashed on
        # ``None.lower()`` whenever that entry was missing.
        dataset = data_info.get(num_classes, None)
        if dataset:
            bname = '101_' + dataset.lower()
            key = 'rf' + bname
            url = models_urls[bname]
            model.load_state_dict(maybe_download(key, url), strict=False)
    return model
Exemplo n.º 3
0
def model_init(model, num_layers, num_parallel, imagenet=False, pretrained=True,
               num_classes=None):
    """Initialise an already-built model's weights from a downloaded checkpoint.

    Args:
        model: the network whose weights are loaded in place.
        num_layers (int): backbone depth used to build the checkpoint key
            (e.g. ``50`` -> ``"50_imagenet"``).
        num_parallel (int): passed to ``expand_model_dict`` when adapting the
            ImageNet weights to this model's state dict.
        imagenet (bool): if True, load ``<num_layers>_imagenet`` weights,
            expanded via ``expand_model_dict``, with a strict load.
        pretrained (bool): consulted only when ``imagenet`` is False; loads the
            ``rf<num_layers>_<dataset>`` weights (non-strict) when
            ``num_classes`` maps to a dataset in ``data_info``.
        num_classes (int, optional): class count used to find the dataset in
            ``data_info``. The original code referenced this name without
            defining it (raising ``NameError``). When it is None, no
            dataset-specific weights are loaded.

    Returns:
        The same ``model`` object.
    """
    if imagenet:
        key = str(num_layers) + '_imagenet'
        url = models_urls[key]
        state_dict = maybe_download(key, url)
        model_dict = expand_model_dict(model.state_dict(), state_dict, num_parallel)
        model.load_state_dict(model_dict, strict=True)
    elif pretrained and num_classes is not None:
        dataset = data_info.get(num_classes, None)
        if dataset:
            bname = str(num_layers) + '_' + dataset.lower()
            key = 'rf' + bname
            url = models_urls[bname]
            model.load_state_dict(maybe_download(key, url), strict=False)
    return model
Exemplo n.º 4
0
def rf_lw50(num_classes, imagenet=True, pretrained=False, **kwargs):
    """Build a light-weight RefineNet (``ResNetLW``) with a ResNet-50 layout.

    Note that, unlike most sibling factories, this one defaults to loading
    the ImageNet weights.

    Args:
        num_classes (int): number of classes for the segmentation head.
        imagenet (bool): if True (default), load the ``50_imagenet`` weights
            (non-strict load).
        pretrained (bool): consulted only when ``imagenet`` is False; if
            ``num_classes`` maps to a dataset name in ``data_info``, load the
            ``rf_lw50_<dataset>`` weights (non-strict load).
        **kwargs: forwarded unchanged to ``ResNetLW``.

    Returns:
        The constructed model.
    """
    model = ResNetLW(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)
    if imagenet:
        ckpt_key = "50_imagenet"
        weights = maybe_download(ckpt_key, models_urls[ckpt_key])
        model.load_state_dict(weights, strict=False)
    elif pretrained:
        dataset_name = data_info.get(num_classes, None)
        if dataset_name:
            url_key = "50_" + dataset_name.lower()
            weights = maybe_download("rf_lw" + url_key, models_urls[url_key])
            model.load_state_dict(weights, strict=False)
    return model
Exemplo n.º 5
0
def rf_lw50(num_classes, imagenet=False, pretrained=True, **kwargs):
    """Build a ``ResNet`` with a ResNet-50 layout and optionally load weights.

    Args:
        num_classes (int): number of classes for the output head.
        imagenet (bool): if True, load the ``50_imagenet`` weights
            (non-strict). Afterwards, hand the model's state dict and the
            downloaded weights to ``use_pretrained_depth_track``. That helper
            presumably initialises an extra (depth) branch from the
            pretrained weights — confirm against its definition.
        pretrained (bool): consulted only when ``imagenet`` is False; if
            ``num_classes`` maps to a dataset name in ``data_info``, load the
            ``rf_lw50_<dataset>`` weights (non-strict load).
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        The constructed model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)
    if not imagenet:
        dataset_name = data_info.get(num_classes, None) if pretrained else None
        if dataset_name:
            url_key = '50_' + dataset_name.lower()
            weights = maybe_download('rf_lw' + url_key, models_urls[url_key])
            model.load_state_dict(weights, strict=False)
        return model

    imagenet_key = '50_imagenet'
    imagenet_weights = maybe_download(imagenet_key, models_urls[imagenet_key])
    model.load_state_dict(imagenet_weights, strict=False)
    # Must run after load_state_dict: it receives the freshly loaded state dict.
    use_pretrained_depth_track(model.state_dict(), imagenet_weights)
    return model
Exemplo n.º 6
0
def mbv2(num_classes, imagenet=False, pretrained=True, **kwargs):
    """Constructs the ``MBv2`` network and optionally loads weights.

    Args:
        num_classes (int): the number of classes for the segmentation head to output.
        imagenet (bool): if True, load the ``mbv2_imagenet`` weights
            (non-strict load).
        pretrained (bool): consulted only when ``imagenet`` is False; if
            ``num_classes`` maps to a dataset name in ``data_info``, load the
            ``rf_lwmbv2_<dataset>`` weights (non-strict load).
        **kwargs: forwarded unchanged to ``MBv2``.

    Returns:
        The constructed model.
    """
    model = MBv2(num_classes, **kwargs)
    url_key = None
    download_key = None
    if imagenet:
        url_key = "mbv2_imagenet"
        download_key = url_key
    elif pretrained:
        dataset_name = data_info.get(num_classes, None)
        if dataset_name:
            url_key = "mbv2_" + dataset_name.lower()
            download_key = "rf_lw" + url_key
    if url_key is not None:
        weights = maybe_download(download_key, models_urls[url_key])
        model.load_state_dict(weights, strict=False)
    return model
Exemplo n.º 7
0
def rf_lw152(num_classes, pretrained=True, **kwargs):
    """Build a light-weight RefineNet (``ResNetLW``) with a ResNet-152 layout.

    Args:
        num_classes (int): number of classes for the segmentation head.
        pretrained (bool): if True and ``num_classes`` maps to a dataset name
            in ``data_info``, load the ``rf_lw152_<dataset>`` weights with the
            default (strict) ``load_state_dict``.
        **kwargs: forwarded unchanged to ``ResNetLW``.

    Returns:
        The constructed model.
    """
    model = ResNetLW(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)
    dataset_name = data_info.get(num_classes, None) if pretrained else None
    if dataset_name:
        suffix = '152_' + dataset_name.lower()
        weights = maybe_download('rf_lw' + suffix, models_urls[suffix])
        model.load_state_dict(weights)
    return model
Exemplo n.º 8
0
def rf_lw152(num_classes, imagenet=False, pretrained=True, **kwargs):
    """Build a light-weight RefineNet (``ResNetLW``) with a ResNet-152 layout.

    Args:
        num_classes (int): number of classes for the segmentation head.
        imagenet (bool): if True, load the ``152_imagenet`` weights
            (non-strict load).
        pretrained (bool): consulted only when ``imagenet`` is False; if
            ``num_classes`` maps to a dataset name in ``data_info``, load the
            ``rf_lw152_<dataset>`` weights. The ``152_ear`` checkpoint is a
            dict whose ``'model'`` entry holds the weights and is loaded
            strictly after stripping a ``module.`` key prefix. All other
            datasets are loaded non-strictly as-is.
        **kwargs: forwarded unchanged to ``ResNetLW``.

    Returns:
        The constructed model.
    """
    model = ResNetLW(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)
    if imagenet:
        key = "152_imagenet"
        url = models_urls[key]
        model.load_state_dict(maybe_download(key, url), strict=False)
    elif pretrained:
        dataset = data_info.get(num_classes, None)
        if dataset:
            bname = "152_" + dataset.lower()
            key = "rf_lw" + bname
            url = models_urls[bname]
            if bname == "152_ear":
                checkpoint = maybe_download(key, url)
                # Keys presumably carry the "module." prefix added by
                # DataParallel. Strip it only where present; the previous
                # unconditional k[7:] mangled keys lacking the prefix.
                prefix = "module."
                new_state_dict = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in checkpoint['model'].items()
                }
                model.load_state_dict(new_state_dict)
            else:
                model.load_state_dict(maybe_download(key, url), strict=False)
    return model
def mbv2(num_classes, pretrained=True, **kwargs):
    """Constructs the ``MBv2`` network and optionally loads dataset weights.

    Args:
        num_classes (int): the number of classes for the segmentation head to output.
        pretrained (bool): if True and ``num_classes`` maps to a dataset name
            in ``data_info``, load the ``rf_lwmbv2_<dataset>`` weights with
            the default (strict) ``load_state_dict``.
        **kwargs: forwarded unchanged to ``MBv2``.

    Returns:
        The constructed model.
    """
    model = MBv2(num_classes, **kwargs)
    if not pretrained:
        return model
    dataset_name = data_info.get(num_classes, None)
    if not dataset_name:
        return model
    url_key = 'mbv2_{}'.format(dataset_name.lower())
    download_key = 'rf_lw' + url_key
    model.load_state_dict(maybe_download(download_key, models_urls[url_key]))
    return model
Exemplo n.º 10
0
def rf_lw101(num_classes, imagenet=False, pretrained=True, *,
             ckpt_path='/home/yangjing/code/wash-hand/light-weight-refinenet-master/ckpt/checkpoint.pth.tar',
             **kwargs):
    """Build a light-weight RefineNet (``ResNetLW``) with a ResNet-101 layout.

    Args:
        num_classes (int): number of classes for the segmentation head.
        imagenet (bool): if True, load the ``101_imagenet`` weights
            (non-strict load).
        pretrained (bool): consulted only when ``imagenet`` is False; weights
            are loaded (non-strict) only if ``num_classes`` maps to a dataset
            name in ``data_info``.
        ckpt_path (str or None): keyword-only. Local checkpoint loaded via
            ``torch.load`` on the ``pretrained`` path. The default keeps the
            previously hard-coded, machine-specific path. Pass ``None`` to
            download the ``rf_lw101_<dataset>`` weights from ``models_urls``
            instead.
        **kwargs: forwarded unchanged to ``ResNetLW``.

    Returns:
        The constructed model.
    """
    model = ResNetLW(Bottleneck, [3, 4, 23, 3],
                     num_classes=num_classes,
                     **kwargs)
    if imagenet:
        key = '101_imagenet'
        url = models_urls[key]
        model.load_state_dict(maybe_download(key, url), strict=False)
    elif pretrained:
        dataset = data_info.get(num_classes, None)
        if dataset:
            if ckpt_path is not None:
                # map_location='cpu' lets a checkpoint saved on GPU load on a
                # CPU-only machine; load_state_dict then copies the tensors
                # into the model's own parameters.
                state_dict = torch.load(ckpt_path, map_location='cpu')
            else:
                bname = '101_' + dataset.lower()
                key = 'rf_lw' + bname
                url = models_urls[bname]
                state_dict = maybe_download(key, url)
            model.load_state_dict(state_dict, strict=False)
    return model
Exemplo n.º 11
0
def rf_lw152(num_classes, imagenet=False, pretrained=True, *,
             ckpt_path='/home/yangjing/code/wash-hand/light-weight-refinenet-master/models/resnet/152_person.ckpt',
             **kwargs):
    """Build a light-weight RefineNet (``ResNetLW``) with a ResNet-152 layout.

    Args:
        num_classes (int): number of classes for the segmentation head.
        imagenet (bool): if True, load the ``152_imagenet`` weights
            (non-strict load).
        pretrained (bool): consulted only when ``imagenet`` is False; weights
            are loaded (non-strict) only if ``num_classes`` maps to a dataset
            name in ``data_info``.
        ckpt_path (str or None): keyword-only. Local checkpoint loaded via
            ``torch.load`` on the ``pretrained`` path. The default keeps the
            previously hard-coded, machine-specific path. Pass ``None`` to
            download the ``rf_lw152_<dataset>`` weights from ``models_urls``
            instead.
        **kwargs: forwarded unchanged to ``ResNetLW``.

    Returns:
        The constructed model.
    """
    model = ResNetLW(Bottleneck, [3, 8, 36, 3],
                     num_classes=num_classes,
                     **kwargs)
    if imagenet:
        key = '152_imagenet'
        url = models_urls[key]
        model.load_state_dict(maybe_download(key, url), strict=False)
    elif pretrained:
        dataset = data_info.get(num_classes, None)
        if dataset:
            if ckpt_path is not None:
                # map_location='cpu' lets a checkpoint saved on GPU load on a
                # CPU-only machine; load_state_dict then copies the tensors
                # into the model's own parameters.
                state_dict = torch.load(ckpt_path, map_location='cpu')
            else:
                bname = '152_' + dataset.lower()
                key = 'rf_lw' + bname
                url = models_urls[bname]
                state_dict = maybe_download(key, url)
            model.load_state_dict(state_dict, strict=False)
    return model
Exemplo n.º 12
0
def rf_lw50(num_classes, imagenet=False, pretrained=True, *,
            ckpt_path='/home/yangjing/code/wash-hand/light-weight-refinenet-master/snap/40_checkpoint.pth.tar',
            **kwargs):
    """Build a light-weight RefineNet (``ResNetLW``) with a ResNet-50 layout.

    Args:
        num_classes (int): number of classes for the segmentation head.
        imagenet (bool): if True, load the ``50_imagenet`` weights
            (non-strict load).
        pretrained (bool): consulted only when ``imagenet`` is False; weights
            are loaded (non-strict) only if ``num_classes`` maps to a dataset
            name in ``data_info``.
        ckpt_path (str or None): keyword-only. Local checkpoint loaded via
            ``torch.load`` on the ``pretrained`` path. The default keeps the
            previously hard-coded, machine-specific snapshot path. Pass
            ``None`` to download the ``rf_lw50_<dataset>`` weights from
            ``models_urls`` instead.
        **kwargs: forwarded unchanged to ``ResNetLW``.

    Returns:
        The constructed model.
    """
    model = ResNetLW(Bottleneck, [3, 4, 6, 3],
                     num_classes=num_classes,
                     **kwargs)
    if imagenet:
        key = '50_imagenet'
        url = models_urls[key]
        model.load_state_dict(maybe_download(key, url), strict=False)
    elif pretrained:
        dataset = data_info.get(num_classes, None)
        if dataset:
            if ckpt_path is not None:
                # map_location='cpu' lets a checkpoint saved on GPU load on a
                # CPU-only machine; load_state_dict then copies the tensors
                # into the model's own parameters.
                state_dict = torch.load(ckpt_path, map_location='cpu')
            else:
                bname = '50_' + dataset.lower()
                key = 'rf_lw' + bname
                url = models_urls[bname]
                state_dict = maybe_download(key, url)
            model.load_state_dict(state_dict, strict=False)
    return model