Exemplo n.º 1
0
def Generate(args):
    """Run the segmentation model over every image in the training folder.

    Builds ``ESPNet_Encoder(3, p=2, q=3)``, loads its weights from
    ``args.ckpt`` onto the CPU, and for each file found in
    ``./data_road/training/image`` computes a per-pixel class-index map that
    is passed to ``Decode_image`` together with the file's base name.

    Args:
        args: Object exposing ``ckpt``, the checkpoint path handed to
            ``torch.load``.
    """
    import PIL.Image as Image

    model = ESPNet_Encoder(3, p=2, q=3)
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    model.eval()

    image_dir = os.path.join('./data_road/training/', 'image')
    print("Generate...")
    if not os.path.isdir(image_dir):
        print("Image directory not found: ", image_dir)
        return

    for filename in os.listdir(image_dir):
        # os.listdir order is arbitrary, so skip the Jupyter checkpoint
        # folder instead of aborting the whole loop when it is encountered.
        if filename == '.ipynb_checkpoints':
            continue

        img_path = os.path.join(image_dir, filename)
        # Strip the extension regardless of its length.
        name = os.path.splitext(filename)[0]
        print("Image Processing: ", name)

        # Close the file handle as soon as the pixels have been converted.
        with Image.open(img_path) as raw:
            img = x_transforms(raw)
        # The network input is forced to a single 3x352x1216 batch (for FCN).
        img = img.view(1, 3, 352, 1216)

        with torch.no_grad():
            output = torch.softmax(model(img), dim=1)
            # Arg-max over the channel axis yields one class index per pixel;
            # dropping the batch axis leaves an (h, w) map.
            pred = output.argmax(dim=1).squeeze(0)
            Decode_image(pred, name)
def main():
    """Run detection and road segmentation frame by frame on a video file.

    A detector is created with ``init_detector`` from the hard-coded
    Cityscapes Mask R-CNN config/checkpoint paths on CUDA device 0, and an
    ``ESPNet_Encoder`` is loaded from ``bdd_weights_20_ESPNET_road.pth``.
    Each frame read from ``umgt.avi`` is passed to ``RoadSeg`` and
    ``inference_detector``; the results are handed to ``show_result``.
    The loop ends when the video runs out of frames or when Esc, ``q`` or
    ``Q`` is pressed.
    """
    model = init_detector(
        '../configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py',
        '../checkpoints/mask_rcnn_r50_fpn_1x_city_20190727-9b3c56a5.pth',
        device=torch.device('cuda', 0))

    # NOTE: `device` is not defined in this function; it is read from the
    # enclosing (module) scope.
    modelseg = ESPNet_Encoder(3, p=2, q=3).to(device)
    modelseg.load_state_dict(
        torch.load('bdd_weights_20_ESPNET_road.pth', map_location='cpu'))
    modelseg.eval()

    video_path = 'umgt.avi'
    camera = cv2.VideoCapture(video_path)
    print('Press "Esc", "q" or "Q" to exit.')
    try:
        if not camera.isOpened():
            print('Could not open video source:', video_path)
            return

        while True:
            ret_val, img = camera.read()
            # read() reports failure at end of stream (img is then None);
            # stop instead of feeding an empty frame to the models.
            if not ret_val:
                break

            imgroad = RoadSeg(img, modelseg)
            result = inference_detector(model, img)

            ch = cv2.waitKey(1)
            if ch in (27, ord('q'), ord('Q')):
                break

            show_result(imgroad,
                        result,
                        model.CLASSES,
                        score_thr=0.5,
                        wait_time=1)
    finally:
        # Always free the capture handle and close any OpenCV windows,
        # even if inference raised.
        camera.release()
        cv2.destroyAllWindows()
Exemplo n.º 3
0
def oonx(model):
    """Export a freshly loaded ``ESPNet_Encoder`` to ``bbd.onnx``.

    NOTE(review): the ``model`` argument is never used; a new
    ``ESPNet_Encoder(3, p=2, q=3)`` is built and loaded from ``args.ckpt``
    instead. ``args`` and ``device`` are not defined in this function and
    are read from the enclosing scope.

    Args:
        model: Ignored (kept for interface compatibility).
    """
    # Random tensor of shape (1, 3, 720, 1280) used as the example export input.
    sample = torch.randn(1, 3, 720, 1280)
    sample = sample.to(device)

    net = ESPNet_Encoder(3, p=2, q=3)
    net = net.to(device)
    weights = torch.load(args.ckpt, map_location='cpu')
    net.load_state_dict(weights)
    net.eval()

    torch.onnx.export(net, sample, "bbd.onnx", verbose=True)