def main(args):

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print ("Loading model: " + modelpath)
    print ("Loading weights: " + weightspath)


    model = Net(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    #model.load_state_dict(torch.load(args.state))
    #model.load_state_dict(torch.load(weightspath)) #not working if missing key

    def load_my_state_dict(model, state_dict):  # custom function to load the model when the checkpoint does not contain all of the model's keys
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print ("Model and weights LOADED successfully")

    model.eval()

    if(not os.path.exists(args.datadir)):
        print ("Error: datadir could not be loaded")


    loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
        num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()
            #labels = labels.cuda()

        inputs = Variable(images)
        #targets = Variable(labels)
        with torch.no_grad():
            outputs = model(inputs)

        label = outputs[0].max(0)[1].byte().cpu().data
        label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
        #print (numpy.unique(label.numpy()))  #debug

        
        filenameSave = "./save_results/" + filename[0].split("leftImg8bit/")[1]
        
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        #image_transform(label.byte()).save(filenameSave)
        label_cityscapes.save(filenameSave)

        print (step, filenameSave)
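The script above (and the variants that follow) reads its configuration from an argparse namespace. As a rough sketch, the parser might look like the following; the flag names mirror the attributes used in these examples (args.loadDir, args.datadir, ...), while the defaults are purely illustrative assumptions.

# Hypothetical argument parser for these evaluation scripts; not part of the original listing.
from argparse import ArgumentParser

def parse_args():
    parser = ArgumentParser()
    parser.add_argument('--loadDir', default='../trained_models/')   # folder holding model definition + weights
    parser.add_argument('--loadModel', default='model.py')           # model definition file (only printed)
    parser.add_argument('--loadWeights', default='model_best.pth')   # checkpoint file name
    parser.add_argument('--datadir', default='/datasets/cityscapes/')
    parser.add_argument('--subset', default='val')                    # which split to evaluate
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--cpu', action='store_true')                 # run on CPU instead of CUDA
    parser.add_argument('--visualize', action='store_true')           # stream results to visdom
    return parser.parse_args()

if __name__ == '__main__':
    main(parse_args())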
Example #2
def main(args):

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(
            model, state_dict
    ):  # custom function to load the model when the checkpoint does not contain all of the model's keys
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = Image.open(args.test_img).convert('RGB')
    image = image.resize((480, 480))  # resize to the network input size
    device = torch.device("cpu" if args.cpu else "cuda")  # keep the input on the same device as the model
    images = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():  # inference only, no gradients needed
        output = model(images)
    label = output[0].max(0)[1].byte().cpu().data
    label = label.numpy()

    mask = get_color_pallete(label, 'ade20k')
    outname = args.test_img.split('.')[0] + '.png'
    mask.save(os.path.join('./', outname))
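Example #3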
def main(args):

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    #model.load_state_dict(torch.load(args.state))
    #model.load_state_dict(torch.load(weightspath)) #not working if missing key

    def load_my_state_dict(
            model, state_dict
    ):  # custom function to load the model when the checkpoint does not contain all of the model's keys
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(KITTI(args.datadir,
                              input_transform_cityscapes,
                              target_transform_cityscapes,
                              subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)
    # input_transform_cityscapes = Compose([ Resize((512, 1024), Image.BILINEAR), ToTensor(),
    # Normalize([.485, .456, .406], [.229, .224, .225]),])
    # with open(image_path_city('/home/gongyiqun', '4.png'), 'rb') as f:
    #     images = load_image(f).convert('RGB')
    #     images = input_transform_cityscapes(images)

    # For the visualizer:
    # launch "python3.6 -m visdom.server -port 8097" in another terminal
    # and open localhost:8097 to see the results
    if (args.visualize):
        vis = visdom.Visdom()

    for step, (images, filename) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()
            #labels = labels.cuda()

        inputs = Variable(images)
        #targets = Variable(labels)
        with torch.no_grad():
            outputs = model(inputs)

        label = outputs[0].max(0)[1].byte().cpu().data
        #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
        label_color = Colorize()(label.unsqueeze(0))

        filenameSave = "./save_color(KITTI)/" + filename[0].split(
            "leftImg8bit/")[1]
        # filenameSave = "./save_color/"+"Others"
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        #image_transform(label.byte()).save(filenameSave)
        label_save = ToPILImage()(label_color)
        label_save = label_save.resize((1242, 375),
                                       Image.BILINEAR)  # For KITTI only
        label_save.save(filenameSave)

        if (args.visualize):
            vis.image(label_color.numpy())
        print(step, filenameSave)
Example #4
def main(args):
    weightspath = args.loadDir + args.loadWeights

    print("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(
        model, state_dict
    ):  # custom function to load the model when the checkpoint does not contain all of the model's keys
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))  # load the weights
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    # read the dataset directory
    loader = DataLoader(KITTI(args.datadir,
                              input_transform_cityscapes,
                              subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For the visualizer:
    # launch "python3.6 -m visdom.server -port 8097" in another terminal
    # and open localhost:8097 to see the results

    for step, (images, filename) in enumerate(loader):  # iterate over the images

        time_start = time.perf_counter()

        if (not args.cpu):
            images = images.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(
                inputs)  # outputs has shape [1, 20, 512, 1024]: per-pixel confidence for each of the 20 classes

        time_elapsed = (time.perf_counter() - time_start)

        time_start = time.perf_counter()

        instance, confidence, confidence_color, label_color = PostProcess(
            20, 1226, 370)(outputs)  # this post-processing is very slow and should be simplified

        time_process = (time.perf_counter() - time_start)

        # out_json = {"instance":instance,
        #        "confidence":confidence
        # }

        yaml = YAML()
        code = yaml.load(inp)  # 'inp' is a YAML template assumed to be defined elsewhere in the module
        code['instance'] = instance  # fill the template with the results
        code['confidence'] = confidence
        yaml.indent(mapping=6, sequence=4, offset=2)  # configure the YAML indentation style

        # set the output path names
        filenameSave = "/home/gongyiqun/project/output/08/" + filename[
            0].split(args.subset)[1]
        filenameSave_bel = filenameSave.split(".png")[0] + "_bel.png"
        filenameSave_inf = filenameSave.split(".png")[0] + ".yaml"
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)

        with open(filenameSave_inf, 'w') as f:
            yaml.dump(code, f)
            # json.dump(out, f)

        # label_save = ToPILImage()(label_color)
        # label_save = label_save.resize((1242, 375), Image.BILINEAR)  # For KITTI only
        # label_save.save(filenameSave)
        # cv2.imwrite(filenameSave, label_color)
        # cv2.imwrite(filenameSave_bel, confidence_color)

        print(step, filenameSave, time_elapsed, time_process)
Example #5
class Segmentation(Node):
    def __init__(self):
        super().__init__('segmentation_publisher')
        self.bridge = CvBridge()
        self.pub_seg = self.create_publisher(Image, '/LEDNet/segmented_image', 10)  # 10 = QoS history depth
        ##self.pub_seg = self.create_publisher(Image, '/amsl/demo/segmentation', 10)
        self.sub = self.create_subscription(CompressedImage, '/usb_cam/image_raw/compressed', self.callback, 10)

        # Please adjust "weightspath" if you use a different workspace.
        self.weightspath = '/home/amsl/ros2_ws_2/src/segmentation_publisher/model_001/model_best.pth'
        self.model = Net(NUM_CLASSES)
        self.model = torch.nn.DataParallel(self.model)

        self.model = self.model.cuda()

        self.model = load_state(self.model, torch.load(self.weightspath))

        self.model.eval()
        print("Ready")
        self.count = 0

    def callback(self, oimg):

        try:
            # If you want to save images and labels, uncomment the lines marked No.1 to No.4.
            #No.1 #write_image_name = "image_" + str(self.count) + ".jpg"
            #No.2 #write_label_name = "label_" + str(self.count) + ".jpg"

            oimg_b = bytes(oimg.data)
            np_arr = np.frombuffer(oimg_b, np.uint8)  # decode the raw message bytes into a uint8 array
            img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)

            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            #No.3 #cv2.imwrite("/home/amsl/images/output/seg_pub/image/" + write_image_name, img)
            
            img_size = img.shape
            image = PIL_Image.fromarray(img)
            image = image.resize((1024,512),PIL_Image.NEAREST)

            image = ToTensor()(image)
            image = image.unsqueeze(0)  # add a batch dimension

            image = image.cuda()
            
            input_image = Variable(image)
            
            with torch.no_grad():
                output_image = self.model(input_image)
            
            label = output_image[0].max(0)[1].byte().cpu().data
            label_color = Colorize()(label.unsqueeze(0))
            label_pub = ToPILImage()(label_color)
            label_pub = label_pub.resize((img_size[1],img_size[0]),PIL_Image.NEAREST)
            label_pub = np.asarray(label_pub)
            
            #show label.
            #plt.imshow(label_pub)
            #plt.pause(0.001)
            
            #No.4 #cv2.imwrite("/home/amsl/images/output/seg_pub/label/" + write_label_name, label_pub)

            #self.pub_seg.publish(self.bridge.cv2_to_imgmsg(label_pub, "bgr8"))
            self.pub_seg.publish(self.bridge.cv2_to_imgmsg(label_pub, "rgb8"))
            print("published") 
            self.count += 1
        
        except CvBridgeError as e:
            print(e)
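The class above only defines the node. A typical rclpy entry point for it would look like the following sketch, which is not part of the original listing.

# Hypothetical entry point for the Segmentation node above.
import rclpy

def main(args=None):
    rclpy.init(args=args)
    node = Segmentation()
    try:
        rclpy.spin(node)  # keep processing incoming CompressedImage messages
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()

if __name__ == '__main__':
    main()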
Example #6
def main(args):

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print ("Loading model: " + modelpath)
    print ("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    #model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = torch.nn.DataParallel(model).cuda()

    def load_my_state_dict(model, state_dict):  # custom function to load the model when the checkpoint does not contain all of the model's keys
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                if name.startswith("module."):
                    own_state[name.split("module.")[-1]].copy_(param)
                else:
                    print(name, " not loaded")
                    continue
            else:
                own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath, map_location=lambda storage, loc: storage))
    print ("Model and weights LOADED successfully")


    model.eval()

    if(not os.path.exists(args.datadir)):
        print ("Error: datadir could not be loaded")


    loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset), num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)


    iouEvalVal = iouEval(NUM_CLASSES)

    start = time.time()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()
            labels = labels.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, labels)

        filenameSave = filename[0].split("leftImg8bit/")[1] 

        print (step, filenameSave)


    iouVal, iou_classes = iouEvalVal.getIoU()

    iou_classes_str = []
    for i in range(iou_classes.size(0)):
        iouStr = getColorEntry(iou_classes[i])+'{:0.2f}'.format(iou_classes[i]*100) + '\033[0m'
        iou_classes_str.append(iouStr)

    print("---------------------------------------")
    print("Took ", time.time()-start, "seconds")
    print("=======================================")
    #print("TOTAL IOU: ", iou * 100, "%")
    print("Per-Class IoU:")
    print(iou_classes_str[0], "Road")
    print(iou_classes_str[1], "sidewalk")
    print(iou_classes_str[2], "building")
    print(iou_classes_str[3], "wall")
    print(iou_classes_str[4], "fence")
    print(iou_classes_str[5], "pole")
    print(iou_classes_str[6], "traffic light")
    print(iou_classes_str[7], "traffic sign")
    print(iou_classes_str[8], "vegetation")
    print(iou_classes_str[9], "terrain")
    print(iou_classes_str[10], "sky")
    print(iou_classes_str[11], "person")
    print(iou_classes_str[12], "rider")
    print(iou_classes_str[13], "car")
    print(iou_classes_str[14], "truck")
    print(iou_classes_str[15], "bus")
    print(iou_classes_str[16], "train")
    print(iou_classes_str[17], "motorcycle")
    print(iou_classes_str[18], "bicycle")
    print("=======================================")
    iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m'
    print ("MEAN IoU: ", iouStr, "%")