Example #1
0
def train(args):
    """Train the ESPNet encoder on the road-segmentation dataset.

    Args:
        args: parsed CLI arguments; only ``args.batch_size`` is read here.
    """
    # LR schedule: halve the learning rate every 50 epochs.
    lr_step = 50
    lr_gamma = 0.5

    net = ESPNet_Encoder(3, p=2, q=3).to(device)

    optimizer = optim.Adam(net.parameters(), weight_decay=1e-5)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=lr_gamma)
    criterion = nn.BCEWithLogitsLoss()

    dataset = RoadDataset("./data_road/training/",
                          transform=x_transforms,
                          target_transform=y_transforms)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    train_model(net, criterion, optimizer, loader, scheduler)
Example #2
0
def Generate(args):
    """Run inference on every training image and decode the predictions.

    Loads ESPNet encoder weights from ``args.ckpt``, pushes each image in
    ``./data_road/training/image/`` through the network, and hands the
    per-pixel argmax class map to ``Decode_image``.

    Args:
        args: parsed CLI arguments; only ``args.ckpt`` (weights path) is used.
    """
    import PIL.Image as Image

    model = ESPNet_Encoder(3, p=2, q=3)
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    model.eval()

    root = './data_road/training/'
    print("Generate...")
    for filename in os.listdir(root):
        if filename != 'image':
            continue
        for filename2 in os.listdir(root + filename):
            # Bug fix: skip notebook checkpoint entries instead of `break`,
            # which silently stopped processing every file listed after the
            # '.ipynb_checkpoints' directory.
            if filename2 == '.ipynb_checkpoints':
                continue
            imgroot = os.path.join(root, 'image', filename2)
            name = filename2[:-4]  # strip 4-char extension (e.g. ".png")
            print("Image Processing: ", name)
            img = Image.open(imgroot)
            img = x_transforms(img)
            # Add the batch dimension. Generalizes the original hard-coded
            # view(1, 3, 352, 1216), which assumed a fixed input size.
            img = img.unsqueeze(0)
            with torch.no_grad():
                output = model(img)
                output = torch.softmax(output, dim=1)
                # Per-pixel argmax over the class channel -> (N, h, w).
                # Equivalent to the original transpose/reshape dance for the
                # N == 1 case used here (the old reshape was wrong for N > 1).
                pred = output.argmax(dim=1)
                pred = pred.squeeze(0)
                Decode_image(pred, name)
def main():
    """Run joint instance detection (Mask R-CNN) and road segmentation
    (ESPNet) on a video file and display the combined result per frame.

    Exits when the video ends or the user presses Esc / 'q' / 'Q'.
    """
    # Mask R-CNN detector trained on Cityscapes (paths are hard-coded;
    # the original argparse-driven variant is preserved in spirit).
    model = init_detector(
        '../configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py',
        '../checkpoints/mask_rcnn_r50_fpn_1x_city_20190727-9b3c56a5.pth',
        device=torch.device('cuda', 0))

    # ESPNet encoder for road-surface segmentation.
    modelseg = ESPNet_Encoder(3, p=2, q=3).to(device)
    modelseg.load_state_dict(
        torch.load('bdd_weights_20_ESPNET_road.pth', map_location='cpu'))
    modelseg.eval()

    camera = cv2.VideoCapture('umgt.avi')
    print('Press "Esc", "q" or "Q" to exit.')
    if camera.isOpened():
        while True:
            ret_val, img = camera.read()
            # Bug fix: stop cleanly when the stream ends or a frame fails to
            # decode; the original ignored ret_val and would crash passing a
            # None frame into RoadSeg / inference_detector.
            if not ret_val:
                break
            imgroad = RoadSeg(img, modelseg)
            result = inference_detector(model, img)

            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord('q') or ch == ord('Q'):
                break

            # Overlay detector output on the road-segmented frame.
            show_result(imgroad,
                        result,
                        model.CLASSES,
                        score_thr=0.5,
                        wait_time=1)
    cv2.destroyAllWindows()
Example #4
0
def oonx(model, ckpt=None):
    """Export the ESPNet encoder to ONNX as "bbd.onnx".

    Args:
        model: ignored — it is immediately replaced by a freshly built
            ESPNet_Encoder (parameter kept for interface compatibility).
        ckpt: path to the state-dict file. Defaults to ``args.ckpt``,
            preserving the original's implicit reliance on a module-level
            ``args`` (which fails with NameError if no global exists).
    """
    if ckpt is None:
        ckpt = args.ckpt  # backward-compatible fallback to the global args
    dummy_input = torch.randn(1, 3, 720, 1280).to(device)
    model = ESPNet_Encoder(3, p=2, q=3).to(device)
    model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    model.eval()
    torch.onnx.export(model, dummy_input, "bbd.onnx", verbose=True)
Example #5
0
def Model_visualization(args):
    """Print a layer-by-layer torchsummary report for the ESPNet encoder.

    Args:
        args: unused; kept for interface consistency with the other commands.
    """
    from torchsummary import summary

    net = ESPNet_Encoder(3, p=2, q=3).to(device)
    summary(net, input_size=(3, 720, 1280))