def __init__(self):
    """Set up the colour map and an ``rf_lw152`` segmentation network.

    Attributes set:
        cmap: array loaded from the light-weight-refinenet ``utils/cmap.npy``.
        has_cuda: result of ``torch.cuda.is_available()``.
        n_classes: fixed at 60 and passed to ``rf_lw152``.
        net: ``rf_lw152(..., pretrained=True)`` switched to eval mode and,
            when CUDA is available, moved to the GPU with ``.cuda()``.
    """
    cmap_file = '/opt/carla/PythonAPI/carla_scripts/light-weight-refinenet/utils/cmap.npy'
    self.cmap = np.load(cmap_file)
    self.has_cuda = torch.cuda.is_available()
    self.n_classes = 60
    segmenter = rf_lw152(self.n_classes, pretrained=True)
    segmenter.eval()  # nn.Module.eval() works in place, so the same object is kept
    self.net = segmenter.cuda() if self.has_cuda else segmenter
Beispiel #2
0
def create_segmenter(net, pretrained, num_classes):
    """Build an ``rf_lw*`` ResNet segmenter (depth 50, 101 or 152).

    ``net`` is compared as ``str(net)``, so both ``50`` and ``'50'`` work.
    ``pretrained`` is forwarded as the ``imagenet`` keyword.

    Raises:
        ValueError: if the requested depth is not 50, 101 or 152.
    """
    from models.resnet import rf_lw50, rf_lw101, rf_lw152

    depth = str(net)
    factories = {'50': rf_lw50, '101': rf_lw101, '152': rf_lw152}
    if depth not in factories:
        raise ValueError("{} is not supported".format(depth))
    return factories[depth](num_classes, imagenet=pretrained)
Beispiel #3
0
def get_segmenter(
    enc_backbone,
    enc_pretrained,
    num_classes,
):
    """Create Encoder-Decoder; for now only ResNet [50,101,152] Encoders are supported.

    Args:
        enc_backbone: Encoder depth, e.g. ``"50"``. Integer depths such as
            ``50`` are accepted too (normalised with ``str()``).
        enc_pretrained: Forwarded as ``imagenet=`` to the ``rf_lw*`` constructor.
        num_classes: Number of output classes.

    Returns:
        The model returned by ``rf_lw50`` / ``rf_lw101`` / ``rf_lw152``.

    Raises:
        ValueError: If the backbone depth is not 50, 101 or 152.
    """
    # Normalise once so int and str depths select the same backbone.
    depth = str(enc_backbone)
    if depth == "50":
        return rf_lw50(num_classes, imagenet=enc_pretrained)
    elif depth == "101":
        return rf_lw101(num_classes, imagenet=enc_pretrained)
    elif depth == "152":
        return rf_lw152(num_classes, imagenet=enc_pretrained)
    raise ValueError("{} is not supported".format(depth))
def create_segmenter(net, pretrained, num_classes):
    """Create Encoder; for now only ResNet [50,101,152]

    Args:
        net: Encoder depth; compared as ``str(net)`` against '50', '101', '152'.
        pretrained: Forwarded as ``imagenet=`` to the ``rf_lw*`` constructor.
        num_classes: Number of output classes.

    Returns:
        The model built by ``rf_lw50`` / ``rf_lw101`` / ``rf_lw152``.

    Raises:
        ValueError: If ``net`` is not one of the supported depths.
    """
    import sys
    # Make the parent directory (relative to the current working directory)
    # importable, but only once: appending on every call kept growing
    # sys.path with duplicate entries.
    if "../" not in sys.path:
        sys.path.append("../")
    from models.resnet import rf_lw50, rf_lw101, rf_lw152

    init_model = '../models/resnet/50_person.ckpt'
    depth = str(net)
    if depth == '50':
        return rf_lw50(num_classes, model_path=init_model, imagenet=pretrained)
    elif depth == '101':
        # NOTE(review): this branch reuses '50_person.ckpt', whose name suggests
        # weights for the 50-layer model; confirm rf_lw101 is meant to load it.
        return rf_lw101(num_classes,
                        model_path=init_model,
                        imagenet=pretrained)
    elif depth == '152':
        return rf_lw152(num_classes, imagenet=pretrained)
    raise ValueError("{} is not supported".format(depth))
Beispiel #5
0
def get_model(model_name, classes, pre_train=False, mode='train'):
    """Instantiate a segmentation network by name.

    Each architecture's module is imported lazily, only when it is selected.

    Args:
        model_name: One of 'ESpnet_2_8_decoder', 'ESpnet_2_8', 'EDAnet',
            'ERFnet', 'Enet', 'IRRnet_2_8', 'MOBILE_V2', 'RF_LW_resnet_50',
            'RF_LW_resnet_101', 'RF_LW_resnet_152', 'Bisenet', 'Basenet'.
        classes: Number of output classes, forwarded to the constructor.
        pre_train: Only used for 'ESpnet_2_8_decoder'. Falsy means no path
            is passed; otherwise it is coerced with ``os.path.join`` and
            forwarded to ``Espnet.ESPNet`` as its fourth positional argument.
        mode: Forwarded to ``Espnet.ESPNet`` and ``Irregularity_conv.ESPNet``.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: If ``model_name`` is not recognised.
    """
    if model_name == 'ESpnet_2_8_decoder':
        from models import Espnet
        if pre_train:
            # Single-argument os.path.join only coerces pre_train to a path
            # (it raises TypeError for non-path values such as True).
            pre_train_path = os.path.join(pre_train)
            model = Espnet.ESPNet(classes, 2, 8, pre_train_path, mode=mode)
        else:
            model = Espnet.ESPNet(classes, 2, 8, mode=mode)
    elif model_name == 'ESpnet_2_8':
        from models import Espnet
        model = Espnet.ESPNet_Encoder(classes, 2, 8)
    elif model_name == 'EDAnet':
        from models import EDANet
        model = EDANet.EDANet(classes)
    elif model_name == 'ERFnet':
        from models import ERFnet
        model = ERFnet.Net(classes)
    elif model_name == 'Enet':
        from models import Enet
        model = Enet.ENet(classes)
    elif model_name == 'IRRnet_2_8':
        from models import Irregularity_conv
        model = Irregularity_conv.ESPNet(classes, 2, 8, mode=mode)
    elif model_name == 'MOBILE_V2':
        from models import mobilenet
        model = mobilenet.mbv2(classes)
    elif model_name == 'RF_LW_resnet_50':
        from models import resnet
        model = resnet.rf_lw50(classes)
    elif model_name == 'RF_LW_resnet_101':
        from models import resnet
        model = resnet.rf_lw101(classes)
    elif model_name == 'RF_LW_resnet_152':
        from models import resnet
        model = resnet.rf_lw152(classes)
    elif model_name == 'Bisenet':
        from models import BiSeNet
        model = BiSeNet.BiSeNet(out_class=classes)
    elif model_name == 'Basenet':
        from models import Basenet
        model = Basenet.Basenet(classes)
    else:
        # Name the offending model so a typo in a config is easy to spot.
        raise NotImplementedError(
            "Unsupported model_name: {!r}".format(model_name))
    return model
Beispiel #6
0
def create_multiNet(net, num_classes, num_depths=10, pretrained=None, task_type=2):
    """Build ``multiNet`` (for ``net == 'multi'``) or an ``rf_lw*`` ResNet model.

    ``net`` is compared as ``str(net)``. For ``'multi'``, ``num_classes``,
    ``num_depths`` and ``task_type`` go to ``multiNet``. For '50'/'101'/'152',
    ``num_classes`` and ``pretrained`` (as ``imagenet=``) go to the matching
    ``rf_lw*`` constructor.

    Raises:
        ValueError: for any other ``net`` value.
    """
    from models.multitask import multiNet
    from models.resnet import rf_lw50, rf_lw101, rf_lw152

    choice = str(net)
    if choice == 'multi':
        return multiNet(num_classes, num_depths, task_type)

    resnet_builders = {'50': rf_lw50, '101': rf_lw101, '152': rf_lw152}
    builder = resnet_builders.get(choice)
    if builder is None:
        raise ValueError("{} is not supported".format(choice))
    return builder(num_classes, imagenet=pretrained)
Beispiel #7
0
import sys

# Raw string: the Windows path contains backslashes that are not valid escape
# sequences (e.g. "\P", "\S"). In a normal literal they survive only by
# accident and raise SyntaxWarning on Python 3.12+. The value is unchanged.
REFINE_SOURCE_DIR = r"D:\ProjectsD\ISMAR2019\code\src\SemanticSegmentation\RefineNet"
# Prepend the RefineNet checkout to the import path (presumably so that its
# `models` / `utils` packages are the ones imported afterwards).
sys.path.insert(0, REFINE_SOURCE_DIR)
    
from models.resnet import rf_lw152
from utils.helpers import prepare_img
import cv2
import numpy as np
import torch
from PIL import Image


has_cuda = torch.cuda.is_available()
n_classes = 7
result = None

# Build the network in eval mode. Only move it (and its inputs) to the GPU
# when CUDA is actually available; previously `.cuda()` was unconditional
# and the script failed on CPU-only hosts even though has_cuda was computed.
net = rf_lw152(n_classes, pretrained=True).eval()
if has_cuda:
    net = net.cuda()

with torch.no_grad():
    # Dummy forward pass on an all-zero 480x640x3 uint8 image (presumably a
    # warm-up / sanity check); only the prediction shape is printed.
    null_img = np.zeros((480, 640, 3), np.uint8)
    # assumes prepare_img returns HxWxC: transpose to CxHxW and add a batch axis
    img_inp = torch.tensor(prepare_img(null_img).transpose(2, 0, 1)[None]).float()
    if has_cuda:
        img_inp = img_inp.cuda()
    preds = net(img_inp)[0].data.cpu().numpy().transpose(1, 2, 0)
print(preds.shape)



def execute(rgb_image):
    """Preprocess ``rgb_image`` for inference (presumably with the module-level ``net``).

    NOTE(review): this definition appears truncated in this snippet. Only the
    preprocessing is visible; the forward pass, any use of ``orig_size`` and
    any assignment to the global ``result`` are not shown.
    """
    global result  # declared global; no assignment is visible in this snippet
    with torch.no_grad():
        print('python taken!')
        # Debug side effect: writes the incoming frame to ./test.png on every call.
        cv2.imwrite("test.png", rgb_image)
        # (shape[1], shape[0]); for an HxWxC array this is (width, height),
        # presumably for resizing predictions back later. TODO confirm.
        orig_size = rgb_image.shape[:2][::-1]
        # assumes prepare_img returns HxWxC: transpose to CxHxW, add a leading
        # batch axis, cast to float32. The tensor is still on the CPU here.
        img_inp = torch.tensor(prepare_img(rgb_image).transpose(2, 0, 1)[None]).float()
Beispiel #8
0
import torch
from models.resnet import rf_lw50, rf_lw152
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from utils.helpers import prepare_img
import cv2
import torch.nn as nn
import os

net = nn.DataParallel(rf_lw152(num_classes=10))
# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which forced the CUDA branch (and a failure) on CPU-only hosts and
# made the CPU branch unreachable.
if torch.cuda.is_available():
    net = net.cuda()
    net.load_state_dict(torch.load('./ckpt/resnet152.tar')['segmenter'])
else:
    # No GPU: load the checkpoint's tensors onto the CPU.
    net.load_state_dict(
        torch.load('./ckpt/resnet152.tar', map_location='cpu')['segmenter'])
net.eval()

# Presumably the per-class colour palette used to render predictions.
color_map = np.load('../utils/color_map.npy')

# Collect the full paths of every '.jpg' entry in the example directory,
# in os.listdir order (unsorted).
img_dir = '../examples/imgs/VrepYCB/augment/'
imgs = []
for img in os.listdir(img_dir):
    if not img.endswith('.jpg'):
        continue
    imgs.append(os.path.join(img_dir, img))

n_rows = len(imgs)

plt.figure(figsize=(16, 12))
idx = 1