def main(args):
    """Run the network over a Cityscapes split and save predictions as labelId images.

    Loads weights from ``args.loadDir + args.loadWeights`` into ``Net(NUM_CLASSES)``
    wrapped in ``DataParallel``, runs inference over
    ``cityscapes(args.datadir, ..., subset=args.subset)`` and writes one prediction
    per input image under ``./save_results/``, mirroring the part of the input path
    that follows ``leftImg8bit/``.

    Args:
        args: parsed command-line namespace. Fields read here: ``loadDir``,
            ``loadModel``, ``loadWeights``, ``cpu``, ``datadir``, ``subset``,
            ``num_workers``, ``batch_size``.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    # Wrap before loading: presumably the checkpoint was saved from a
    # DataParallel model (keys prefixed with "module.") -- confirm.
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        """Copy matching entries of ``state_dict`` into ``model`` in place.

        Unlike ``model.load_state_dict`` (which fails on missing keys), checkpoint
        entries without a counterpart in the model are skipped, and model entries
        absent from the checkpoint keep their current values.
        """
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    # Deserialize onto the CPU when CUDA is not being used so GPU-saved weights
    # still load; copy_ moves each tensor onto the model's device either way.
    map_location = "cpu" if args.cpu else None
    model = load_my_state_dict(model, torch.load(weightspath, map_location=map_location))
    print("Model and weights LOADED successfully")

    model.eval()

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")
        return

    loader = DataLoader(
        cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
        num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()

        with torch.no_grad():
            outputs = model(images)

        # Save every image of the batch, not only the first one.
        # Assumes outputs is (N, C, H, W) -- max over dim 0 of outputs[i]
        # picks the highest-scoring channel per pixel.
        for i in range(outputs.size(0)):
            label = outputs[i].max(0)[1].byte().cpu().data
            label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))

            filenameSave = "./save_results/" + filename[i].split("leftImg8bit/")[1]
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            label_cityscapes.save(filenameSave)

            print(step, filenameSave)
Example #2
0
def main(args):
    """Segment a single image and save a colorized mask next to it.

    Loads weights from ``args.loadDir + args.loadWeights`` into ``Net(NUM_CLASSES)``,
    runs one forward pass on ``args.test_img`` resized to 480x480, colorizes the
    per-pixel argmax with the ``'ade20k'`` palette and saves it as
    ``<test_img without extension>.png``.

    Args:
        args: parsed command-line namespace. Fields read here: ``loadDir``,
            ``loadModel``, ``loadWeights``, ``cpu``, ``test_img``.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        """Copy matching entries of ``state_dict`` into ``model``, skipping unknown keys."""
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    # Deserialize onto the CPU when CUDA is not being used so GPU-saved weights still load.
    map_location = "cpu" if args.cpu else None
    model = load_my_state_dict(model, torch.load(weightspath, map_location=map_location))
    print("Model and weights LOADED successfully")

    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = Image.open(args.test_img).convert('RGB')
    image = image.resize((480, 480))
    # Follow the same --cpu switch that decided where the model lives.
    device = torch.device("cpu" if args.cpu else "cuda")
    images = transform(image).unsqueeze(0).to(device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = model(images)
    label = output[0].max(0)[1].byte().cpu().numpy()

    mask = get_color_pallete(label, 'ade20k')
    # splitext only strips the final extension, so paths such as "./img.jpg"
    # or "a.b.jpg" keep their directory / stem.
    # NOTE(review): if test_img is already a .png, this overwrites the input.
    outname = os.path.splitext(args.test_img)[0] + '.png'
    mask.save(os.path.join('./', outname))
Example #3
0
def main(args):
    """Run the network over a KITTI split and export post-processed results as YAML.

    For every batch, times the forward pass and ``PostProcess(20, 1226, 370)``.
    It then fills the YAML template ``inp`` with the resulting ``instance`` and
    ``confidence`` values and writes it to ``/home/gongyiqun/project/output/08/``.
    The output path mirrors the part of the input path after ``args.subset``, with
    ``.png`` replaced by ``.yaml``.

    Args:
        args: parsed command-line namespace. Fields read here: ``loadDir``,
            ``loadWeights``, ``cpu``, ``datadir``, ``subset``, ``num_workers``,
            ``batch_size``.
    """
    weightspath = args.loadDir + args.loadWeights

    print("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        """Copy matching entries of ``state_dict`` into ``model``, skipping unknown keys."""
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    # Load weights; map to CPU when CUDA is not being used so GPU-saved weights still load.
    map_location = "cpu" if args.cpu else None
    model = load_my_state_dict(model, torch.load(weightspath, map_location=map_location))
    print("Model and weights LOADED successfully")

    model.eval()

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")
        return

    # Read the dataset directory.
    loader = DataLoader(KITTI(args.datadir,
                              input_transform_cityscapes,
                              subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # The YAML emitter and its indentation settings never change between
    # frames, so configure it once instead of per iteration.
    yaml = YAML()
    yaml.indent(mapping=6, sequence=4, offset=2)

    for step, (images, filename) in enumerate(loader):  # iterate over images
        # time.clock() was removed in Python 3.8; perf_counter is the
        # recommended high-resolution timer (wall-clock, not CPU time).
        time_start = time.perf_counter()

        if not args.cpu:
            images = images.cuda()

        with torch.no_grad():
            # Original author's note: output is [1, 20, 512, 1024], i.e. a
            # confidence for each of 20 classes at image resolution.
            outputs = model(images)
        if not args.cpu:
            # CUDA kernels launch asynchronously; wait so the timing covers the forward pass.
            torch.cuda.synchronize()

        time_elapsed = time.perf_counter() - time_start

        time_start = time.perf_counter()

        # Original author's note: this step is very time-consuming and needs simplifying.
        instance, confidence, confidence_color, label_color = PostProcess(
            20, 1226, 370)(outputs)

        time_process = time.perf_counter() - time_start

        # NOTE(review): `inp` is not defined in this function; presumably a
        # module-level YAML template string -- confirm. It is reloaded each
        # iteration so every frame starts from a fresh template.
        code = yaml.load(inp)
        code['instance'] = instance
        code['confidence'] = confidence

        # Build the output path names.
        filenameSave = "/home/gongyiqun/project/output/08/" + filename[
            0].split(args.subset)[1]
        filenameSave_inf = filenameSave.split(".png")[0] + ".yaml"
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)

        with open(filenameSave_inf, 'w') as f:
            yaml.dump(code, f)

        print(step, filenameSave, time_elapsed, time_process)
def main(args):
    """Run the network over a KITTI split and save colorized label maps.

    Each prediction is colorized with ``Colorize()``, resized to 1242x375 (KITTI
    resolution) and saved under ``./save_color(KITTI)/``. The output path mirrors
    the part of the input path after ``leftImg8bit/``. When ``args.visualize`` is
    set, each colorized map is also sent to a visdom server.

    Args:
        args: parsed command-line namespace. Fields read here: ``loadDir``,
            ``loadModel``, ``loadWeights``, ``cpu``, ``datadir``, ``subset``,
            ``num_workers``, ``batch_size``, ``visualize``.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        """Copy matching entries of ``state_dict`` into ``model``, skipping unknown keys.

        Used instead of ``model.load_state_dict``, which fails on missing keys.
        """
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    # Deserialize onto the CPU when CUDA is not being used so GPU-saved weights still load.
    map_location = "cpu" if args.cpu else None
    model = load_my_state_dict(model, torch.load(weightspath, map_location=map_location))
    print("Model and weights LOADED successfully")

    model.eval()

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")
        return

    loader = DataLoader(KITTI(args.datadir,
                              input_transform_cityscapes,
                              target_transform_cityscapes,
                              subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    vis = visdom.Visdom() if args.visualize else None

    for step, (images, filename) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()

        with torch.no_grad():
            outputs = model(images)

        # Save every image of the batch, not only the first one.
        for i in range(outputs.size(0)):
            label = outputs[i].max(0)[1].byte().cpu().data
            label_color = Colorize()(label.unsqueeze(0))

            filenameSave = "./save_color(KITTI)/" + filename[i].split(
                "leftImg8bit/")[1]
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            label_save = ToPILImage()(label_color)
            # NEAREST, not BILINEAR: interpolating a color-coded label map
            # blends neighbouring class colors into colors of no class.
            label_save = label_save.resize((1242, 375),
                                           Image.NEAREST)  # For KITTI only
            label_save.save(filenameSave)

            if vis is not None:
                vis.image(label_color.numpy())
            print(step, filenameSave)
Example #5
0
class Segmentation(Node):
    """ROS 2 node that segments compressed camera frames and republishes a color label image.

    Subscribes to ``/usb_cam/image_raw/compressed`` (``CompressedImage``), runs the
    segmentation network on each frame resized to 1024x512, and publishes the
    colorized prediction, resized back to the frame size, on
    ``/LEDNet/segmented_image`` as an ``rgb8`` ``Image``.
    """

    def __init__(self):
        super().__init__('segmentation_publisher')
        self.bridge = CvBridge()
        # NOTE(review): newer ROS 2 releases require an explicit QoS argument
        # (e.g. a depth such as 10) for create_publisher/create_subscription --
        # confirm against the target distro.
        self.pub_seg = self.create_publisher(Image, '/LEDNet/segmented_image')
        # Alternative topic: '/amsl/demo/segmentation'
        self.sub = self.create_subscription(CompressedImage, '/usb_cam/image_raw/compressed', self.callback)

        # Please fix "weightspath" if you use another workspace.
        self.weightspath = '/home/amsl/ros2_ws_2/src/segmentation_publisher/model_001/model_best.pth'
        self.model = Net(NUM_CLASSES)
        self.model = torch.nn.DataParallel(self.model)

        self.model = self.model.cuda()

        self.model = load_state(self.model, torch.load(self.weightspath))

        self.model.eval()
        print("Ready")
        # Frame counter; only used by the optional image-saving snippets below.
        self.count = 0

    def callback(self, oimg):
        """Segment one compressed frame and publish the colorized label image.

        Frames that OpenCV cannot decode are skipped with a message instead of
        raising inside the subscription callback.
        """
        try:
            # To save images and labels, uncomment the following snippets (No.1 to No.4).
            # No.1 # write_image_name = "image_" + str(self.count) + ".jpg"
            # No.2 # write_label_name = "label_" + str(self.count) + ".jpg"

            # frombuffer replaces the deprecated binary mode of np.fromstring.
            np_arr = np.frombuffer(bytes(oimg.data), np.uint8)
            img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
            if img is None:
                # imdecode signals a corrupt/unsupported payload by returning None.
                print("failed to decode compressed image; skipping frame")
                return

            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # No.3 # cv2.imwrite("/home/amsl/images/output/seg_pub/image/" + write_image_name, img)

            img_size = img.shape
            image = PIL_Image.fromarray(img)
            image = image.resize((1024, 512), PIL_Image.NEAREST)

            # (C, H, W) float tensor -> add a batch dimension of 1.
            image = ToTensor()(image).unsqueeze(0)

            image = image.cuda()

            with torch.no_grad():
                output_image = self.model(image)

            label = output_image[0].max(0)[1].byte().cpu().data
            label_color = Colorize()(label.unsqueeze(0))
            label_pub = ToPILImage()(label_color)
            # NEAREST keeps the colorized classes intact when scaling back.
            label_pub = label_pub.resize((img_size[1], img_size[0]), PIL_Image.NEAREST)
            label_pub = np.asarray(label_pub)

            # To display the label locally:
            # plt.imshow(label_pub)
            # plt.pause(0.001)

            # No.4 # cv2.imwrite("/home/amsl/images/output/seg_pub/label/" + write_label_name, label_pub)

            self.pub_seg.publish(self.bridge.cv2_to_imgmsg(label_pub, "rgb8"))
            print("published")
            self.count += 1

        except CvBridgeError as e:
            print(e)