# NOTE: standard imports for the snippets below; project-specific names (SegNet,
# LaneTestDataset, LaneClsDataset, Evaluator, and the path / image-size /
# hyperparameter constants such as OUT_PATH, CKPT_PATH, IMG_W, IMG_H, LR,
# BATCH_SIZE, MAX_EPOCH, SAVE_INTERVAL, MODEL_CKPT_DIR, FIGURE_DIR) are assumed
# to be defined or imported elsewhere in the project.
import os
from collections import defaultdict

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm


def main():
    os.makedirs(OUT_PATH, exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seg_model = SegNet(3, 2)
    # map_location makes the checkpoint loadable on CPU-only machines as well
    seg_model.load_state_dict(torch.load(CKPT_PATH, map_location=device))
    seg_model.to(device)

    seg_model.eval()

    test_set = LaneTestDataset(list_path='./test.tsv',
                               dir_path='./data_road',
                               img_shape=(IMG_W, IMG_H))
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1)
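    # batch_size=1 so each image's original file path can be reused when saving the overlay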

    with torch.no_grad():
        for image, image_path in tqdm(test_loader):
            image = image.to(device)
            output = seg_model(image)
            # sigmoid is monotonic, so it does not change the argmax; kept for consistency
            output = torch.sigmoid(output)
            # (1, 2, H, W) scores -> (H, W) map of predicted class indices
            mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
            # (1, 3, H, W) tensor -> (H, W, 3) image rescaled back to the 0-255 range
            image = image.squeeze(0).cpu().numpy().transpose((1, 2, 0)) * 255
            # overlay: set the last channel to 255 wherever class 0 is predicted
            image[..., 2] = np.where(mask == 0, 255, image[..., 2])

            # cast to uint8 so cv2.imwrite stores a valid 8-bit image
            cv2.imwrite(
                os.path.join(OUT_PATH, os.path.basename(image_path[0])),
                image.astype(np.uint8))
Example #2
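
# The training example below instantiates BCEFocalLoss() without defining it. The class
# here is a minimal illustrative stand-in (an assumption, not the project's actual
# implementation): a binary focal loss over sigmoid probabilities and a {0, 1} target
# tensor of the same shape, matching the criterion(output, label.long()) call below.
class BCEFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, probs, targets):
        targets = targets.float()
        # p_t: probability assigned to the true class of every pixel
        pt = (probs * targets + (1.0 - probs) * (1.0 - targets)).clamp(min=1e-6, max=1.0)
        # alpha_t: class-balancing weight
        alpha_t = self.alpha * targets + (1.0 - self.alpha) * (1.0 - targets)
        loss = -alpha_t * (1.0 - pt) ** self.gamma * pt.log()
        return loss.mean() if self.reduction == 'mean' else loss.sum()
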
def main():
    # get model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # fcn_model = FCNs(pretrained_net=VGGNet(pretrained=True, requires_grad=True))
    seg_model = SegNet(3, 2)
    # seg_model.load_weights('./vgg16-397923af.pth')
    # criterion = nn.BCELoss()
    criterion = BCEFocalLoss()
    optimizer = optim.Adam(seg_model.parameters(), lr=LR, weight_decay=0.0001)
    evaluator = Evaluator(num_class=2)

    seg_model.to(device)
    criterion.to(device)

    # get dataloader
    train_set = LaneClsDataset(list_path='train.tsv',
                               dir_path='data_road',
                               img_shape=(IMG_W, IMG_H))
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=8)

    val_set = LaneClsDataset(list_path='val.tsv',
                             dir_path='data_road',
                             img_shape=(IMG_W, IMG_H))
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)

    # info records
    loss_dict = defaultdict(list)
    px_acc_dict = defaultdict(list)
    mean_px_acc_dict = defaultdict(list)
    mean_iou_dict = defaultdict(list)
    freq_iou_dict = defaultdict(list)

    for epoch_idx in range(1, MAX_EPOCH + 1):
        # train stage
        seg_model.train()
        evaluator.reset()
        train_loss = 0.0
        for batch_idx, (image, label) in enumerate(train_loader):

            # step the learning-rate schedule on the global iteration index
            lr = lr_func((epoch_idx - 1) * len(train_loader) + batch_idx, LR)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            image = image.to(device)
            # print(label.shape)
            # label = label.reshape(BATCH_SIZE, 288, 800)
            label = label.to(device)
            optimizer.zero_grad()
            output = seg_model(image)
            output = torch.sigmoid(output)
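            # the loss is computed on sigmoid probabilities rather than raw logits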

            loss = criterion(output, label.long())
            loss.backward()
            optimizer.step()

            # accumulate metrics on hard predictions; labels are one-hot along dim 1
            evaluator.add_batch(torch.argmax(output, dim=1).cpu().numpy(),
                                torch.argmax(label, dim=1).cpu().numpy())
            train_loss += loss.item()
            print("[Train][Epoch] {}/{}, [Batch] {}/{}, [lr] {:.6f}, [Loss] {:.6f}".format(
                epoch_idx, MAX_EPOCH, batch_idx + 1, len(train_loader), lr, loss.item()))
        loss_dict['train'].append(train_loss/len(train_loader))
        px_acc = evaluator.Pixel_Accuracy() * 100
        px_acc_dict['train'].append(px_acc)
        mean_px_acc = evaluator.Pixel_Accuracy_Class() * 100
        mean_px_acc_dict['train'].append(mean_px_acc)
        mean_iou = evaluator.Mean_Intersection_over_Union() * 100
        mean_iou_dict['train'].append(mean_iou)
        freq_iou = evaluator.Frequency_Weighted_Intersection_over_Union() * 100
        freq_iou_dict['train'].append(freq_iou)
        print("[Train][Epoch] {}/{}, [PA] {:.2f}%, [MeanPA] {:.2f}%, [MeanIOU] {:.2f}%, ""[FreqIOU] {:.2f}%".format(
            epoch_idx,
            MAX_EPOCH,
            px_acc,
            mean_px_acc,
            mean_iou,
            freq_iou))

        evaluator.reset()
        # validate stage
        seg_model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for image, label in val_loader:
                image, label = image.to(device), label.to(device)
                output = seg_model(image)
                output = torch.sigmoid(output)
                loss = criterion(output, label.long())
                val_loss += loss.item()
                evaluator.add_batch(torch.argmax(output, dim=1).cpu().numpy(),
                                    torch.argmax(label, dim=1).cpu().numpy())
            val_loss /= len(val_loader)
            loss_dict['val'].append(val_loss)
            px_acc = evaluator.Pixel_Accuracy() * 100
            px_acc_dict['val'].append(px_acc)
            mean_px_acc = evaluator.Pixel_Accuracy_Class() * 100
            mean_px_acc_dict['val'].append(mean_px_acc)
            mean_iou = evaluator.Mean_Intersection_over_Union() * 100
            mean_iou_dict['val'].append(mean_iou)
            freq_iou = evaluator.Frequency_Weighted_Intersection_over_Union() * 100
            freq_iou_dict['val'].append(freq_iou)
            print("[Val][Epoch] {}/{}, [Loss] {:.6f}, [PA] {:.2f}%, [MeanPA] {:.2f}%, "
                  "[MeanIOU] {:.2f}%, ""[FreqIOU] {:.2f}%".format(epoch_idx,
                                                                  MAX_EPOCH,
                                                                  val_loss,
                                                                  px_acc,
                                                                  mean_px_acc,
                                                                  mean_iou,
                                                                  freq_iou))

        # save model checkpoints
        if epoch_idx % SAVE_INTERVAL == 0 or epoch_idx == MAX_EPOCH:
            os.makedirs(MODEL_CKPT_DIR, exist_ok=True)
            ckpt_save_path = os.path.join(MODEL_CKPT_DIR, 'epoch_{}.pth'.format(epoch_idx))
            torch.save(seg_model.state_dict(), ckpt_save_path)
            print("[Epoch] {}/{}, 模型权重保存至{}".format(epoch_idx, MAX_EPOCH, ckpt_save_path))

    # draw figures
    os.makedirs(FIGURE_DIR, exist_ok=True)
    draw_figure(loss_dict, title='Loss', ylabel='loss', filename='loss.png')
    draw_figure(px_acc_dict, title='Pixel Accuracy', ylabel='pa', filename='pixel_accuracy.png')
    draw_figure(mean_px_acc_dict, title='Mean Pixel Accuracy', ylabel='mean_pa', filename='mean_pixel_accuracy.png')
    draw_figure(mean_iou_dict, title='Mean IoU', ylabel='mean_iou', filename='mean_iou.png')
    draw_figure(freq_iou_dict, title='Freq Weighted IoU', ylabel='freq_weighted_iou', filename='freq_weighted_iou.png')
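

# lr_func and draw_figure are referenced above but not defined in these snippets. The
# sketches below are illustrative stand-ins (assumptions, not the project's actual
# helpers): a warmup-plus-step-decay schedule and a matplotlib curve plot saved under
# FIGURE_DIR, matching how the two functions are called in the training loop.
import matplotlib.pyplot as plt


def lr_func(global_step, base_lr, warmup_steps=100, decay_rate=0.95, decay_every=500):
    # linear warmup followed by exponential step decay
    if global_step < warmup_steps:
        return base_lr * (global_step + 1) / warmup_steps
    return base_lr * decay_rate ** ((global_step - warmup_steps) // decay_every)


def draw_figure(data_dict, title, ylabel, filename):
    # plot one curve per key (e.g. 'train' / 'val') over epochs and save under FIGURE_DIR
    plt.figure()
    for name, values in data_dict.items():
        plt.plot(range(1, len(values) + 1), values, label=name)
    plt.title(title)
    plt.xlabel('epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(os.path.join(FIGURE_DIR, filename))
    plt.close()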