def main():
    os.makedirs(OUT_PATH, exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seg_model = SegNet(3, 2)
    if device == 'cuda':
        seg_model.to(device)
    seg_model.load_state_dict(torch.load(CKPT_PATH, map_location=device))

    seg_model.eval()

    test_set = LaneTestDataset(list_path='./test.tsv',
                               dir_path='./data_road',
                               img_shape=(IMG_W, IMG_H))
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1)

    with torch.no_grad():
        for image, image_path in tqdm(test_loader):
            image = image.to(device)
            output = seg_model(image)
            output = torch.sigmoid(output)
            # (1, 2, H, W) logits -> per-pixel class indices as an (H, W) map
            mask = torch.argmax(output, dim=1).cpu().numpy().transpose(
                (1, 2, 0))
            mask = mask.reshape(IMG_H, IMG_W)
            # back to an (H, W, 3) image in the 0-255 range, then flag class-0 pixels
            # in channel 2 (red, assuming BGR ordering for cv2.imwrite)
            image = image.cpu().numpy().reshape(3, IMG_H, IMG_W).transpose(
                (1, 2, 0)) * 255
            image[..., 2] = np.where(mask == 0, 255, image[..., 2])

            cv2.imwrite(
                os.path.join(OUT_PATH, os.path.basename(image_path[0])), image)
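For context, a minimal sketch of the imports and module-level constants this inference snippet appears to assume; the values are placeholders, and SegNet / LaneTestDataset come from the project's own modules:

import os

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# placeholder values -- the real definitions sit at module level in the original script
IMG_W, IMG_H = 800, 288             # assumed network input size
OUT_PATH = './test_output'          # directory for the overlaid predictions
CKPT_PATH = './ckpt/epoch_200.pth'  # trained SegNet checkpoint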
def build_model(model_name, num_classes):
    if model_name == 'SQNet':
        return SQNet(classes=num_classes)
    elif model_name == 'LinkNet':
        return LinkNet(classes=num_classes)
    elif model_name == 'SegNet':
        return SegNet(classes=num_classes)
    elif model_name == 'UNet':
        return UNet(classes=num_classes)
    elif model_name == 'ENet':
        return ENet(classes=num_classes)
    elif model_name == 'ERFNet':
        return ERFNet(classes=num_classes)
    elif model_name == 'CGNet':
        return CGNet(classes=num_classes)
    elif model_name == 'EDANet':
        return EDANet(classes=num_classes)
    elif model_name == 'ESNet':
        return ESNet(classes=num_classes)
    elif model_name == 'ESPNet':
        return ESPNet(classes=num_classes)
    elif model_name == 'LEDNet':
        return LEDNet(classes=num_classes)
    elif model_name == 'ESPNet_v2':
        return EESPNet_Seg(classes=num_classes)
    elif model_name == 'ContextNet':
        return ContextNet(classes=num_classes)
    elif model_name == 'FastSCNN':
        return FastSCNN(classes=num_classes)
    elif model_name == 'DABNet':
        return DABNet(classes=num_classes)
    elif model_name == 'FSSNet':
        return FSSNet(classes=num_classes)
    elif model_name == 'FPENet':
        return FPENet(classes=num_classes)
    else:
        raise NotImplementedError(f'unknown model: {model_name}')
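An illustrative call into the factory above (the segmentation model classes are assumed to be importable from the project's own modules):

# illustrative only: a 19-class ERFNet via the name-based factory
model = build_model('ERFNet', num_classes=19)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count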
Example #3
def build_model(model_name, num_classes):
    # for deeplabv3
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    if model_name == 'SQNet':
        return SQNet(classes=num_classes)
    elif model_name == 'LinkNet':
        return LinkNet(classes=num_classes)
    elif model_name == 'SegNet':
        return SegNet(classes=num_classes)
    elif model_name == 'UNet':
        return UNet(classes=num_classes)
    elif model_name == 'ENet':
        return ENet(classes=num_classes)
    elif model_name == 'ERFNet':
        return ERFNet(classes=num_classes)
    elif model_name == 'CGNet':
        return CGNet(classes=num_classes)
    elif model_name == 'EDANet':
        return EDANet(classes=num_classes)
    elif model_name == 'ESNet':
        return ESNet(classes=num_classes)
    elif model_name == 'ESPNet':
        return ESPNet(classes=num_classes)
    elif model_name == 'LEDNet':
        return LEDNet(classes=num_classes)
    elif model_name == 'ESPNet_v2':
        return EESPNet_Seg(classes=num_classes)
    elif model_name == 'ContextNet':
        return ContextNet(classes=num_classes)
    elif model_name == 'FastSCNN':
        return FastSCNN(classes=num_classes)
    elif model_name == 'DABNet':
        return DABNet(classes=num_classes)
    elif model_name == 'FSSNet':
        return FSSNet(classes=num_classes)
    elif model_name == 'FPENet':
        return FPENet(classes=num_classes)
    elif model_name == 'FCN':
        return FCN32VGG(classes=num_classes)
    elif model_name in model_map:
        return model_map[model_name](num_classes, output_stride=8)
    else:
        raise NotImplementedError(f'unknown model: {model_name}')
Example #4
    def __init__(self, encoder_dim, grid_dims, Generate1_dims, Generate2_dims):
        super(GeneratorVanilla, self).__init__()

        self.encoder = SegNet(input_channels=encoder_dim[0],
                              output_channels=encoder_dim[1])
        init_weights(self.encoder, init_type="kaiming")
        self.N = grid_dims[0] * grid_dims[1]
        self.G1 = PointGeneration(Generate1_dims)

        init_weights(self.G1, init_type="xavier")
        self.G2 = PointGeneration(Generate2_dims)
        init_weights(self.G2, init_type="xavier")
        # self.reconstruct = nn.Tanh()

        self.P0 = PointProjection()
        self.P1 = PointProjection()
Example #5
def build_model(model_name,
                num_classes,
                backbone='resnet18',
                pretrained=False,
                out_stride=32,
                mult_grid=False):
    if model_name == 'FCN':
        model = FCN(num_classes=num_classes)
    elif model_name == 'FCN_ResNet':
        model = FCN_ResNet(num_classes=num_classes,
                           backbone=backbone,
                           out_stride=out_stride,
                           mult_grid=mult_grid)
    elif model_name == 'SegNet':
        model = SegNet(classes=num_classes)
    elif model_name == 'UNet':
        model = UNet(num_classes=num_classes)
    elif model_name == 'BiSeNet':
        model = BiSeNet(num_classes=num_classes, backbone=backbone)
    elif model_name == 'BiSeNetV2':
        model = BiSeNetV2(num_classes=num_classes)
    elif model_name == 'HRNet':
        model = HighResolutionNet(num_classes=num_classes)
    elif model_name == 'Deeplabv3plus_res101':
        model = DeepLabv3_plus(nInputChannels=3,
                               n_classes=num_classes,
                               os=out_stride,
                               pretrained=True)
    elif model_name == "DDRNet":
        model = DDRNet(pretrained=True, num_classes=num_classes)
    elif model_name == 'PSPNet_res50':
        model = PSPNet(layers=50,
                       bins=(1, 2, 3, 6),
                       dropout=0.1,
                       num_classes=num_classes,
                       zoom_factor=8,
                       use_ppm=True,
                       pretrained=True)
    elif model_name == 'PSPNet_res101':
        model = PSPNet(layers=101,
                       bins=(1, 2, 3, 6),
                       dropout=0.1,
                       num_classes=num_classes,
                       zoom_factor=8,
                       use_ppm=True,
                       pretrained=True)
    # elif model_name == 'PSANet50':
    #     return PSANet(layers=50, dropout=0.1, classes=num_classes, zoom_factor=8, use_psa=True, psa_type=2, compact=compact,
    #                shrink_factor=shrink_factor, mask_h=mask_h, mask_w=mask_w, psa_softmax=True, pretrained=True)
    else:
        raise NotImplementedError(f'unknown model: {model_name}')

    if pretrained:
        checkpoint = model_zoo.load_url(model_urls[backbone])
        model_dict = model.state_dict()
        # print(model_dict)
        # keep only the checkpoint weights whose keys (prefixed with 'backbone.') exist in the current model
        pretrained_dict = {
            'backbone.' + k: v
            for k, v in checkpoint.items() if 'backbone.' + k in model_dict
        }
        # update the current model's state dict with the matched pretrained weights
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    return model
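An illustrative call into this variant, assuming model_urls (not shown in this excerpt) maps backbone names such as 'resnet18' to torchvision-style checkpoint URLs:

# illustrative only: FCN with a ResNet-18 backbone initialised from model_urls['resnet18']
model = build_model('FCN_ResNet', num_classes=19,
                    backbone='resnet18', pretrained=True, out_stride=32)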
Example #6
import matplotlib.pyplot as plt
from model.Unet import Unet
from model.SegNet import SegNet
from model.PSPNet import PSPNet
from model.Deeplab import Deeplab
from model.DeconvNet import DeconvNet
from model.ENet import ENet
from dataloader import DataLoader
from utils import COLOR_DICT, show_label, show_predict_image
from matplotlib.pyplot import savefig

unet_model = Unet((512, 512, 3), 12)
weight_path = 'Log/Unet/weight.h5'
unet_model.load_weights(weight_path)

segnet_model = SegNet((512, 512, 3), 12)
weight_path = 'Log/SegNet/weight.h5'
segnet_model.load_weights(weight_path)

pspnet_model = PSPNet((480, 480, 3), 12)
weight_path = 'Log/PSPNet/weight.h5'
pspnet_model.load_weights(weight_path)

deeplab_model = Deeplab((512, 512, 3), 12)
weight_path = 'Log/Deeplab/weight.h5'
deeplab_model.load_weights(weight_path)

deconvnet_model = DeconvNet((512, 512, 3), 12)
weight_path = 'Log/DeconvNet/weight.h5'
deconvnet_model.load_weights(weight_path)
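A minimal prediction sketch for the Keras models loaded above, assuming Unet(...) returns a tf.keras Model (as the load_weights calls suggest); only standard Keras/NumPy calls are used, since DataLoader and the plotting helpers are project-specific:

import numpy as np

# illustrative only: push one random 512x512 RGB batch through the loaded U-Net
dummy = np.random.rand(1, 512, 512, 3).astype('float32')
pred = unet_model.predict(dummy)        # per-pixel scores for the 12 classes
mask = np.argmax(pred, axis=-1)         # class index per pixel
print(pred.shape, mask.shape)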
Example #7
def seg_model(args):
    if args.network == "Unet":
        model = Unet(args.in_channel,
                     args.n_class,
                     channel_reduction=args.Ulikenet_channel_reduction,
                     aux=args.aux)
    elif args.network == "AttUnet":
        model = AttUnet(args.in_channel,
                        args.n_class,
                        channel_reduction=args.Ulikenet_channel_reduction,
                        aux=args.aux)
    elif args.network == "SegNet":
        model = SegNet(args.in_channel, args.n_class)
    elif args.network == "PSPNet":
        model = PSPNet(args.n_class,
                       args.backbone,
                       aux=args.aux,
                       pretrained_base=args.pretrained,
                       dilated=args.dilated,
                       deep_stem=args.deep_stem)
    elif args.network == "DeepLabV3":
        model = DeepLabV3(args.n_class,
                          args.backbone,
                          aux=args.aux,
                          pretrained_base=args.pretrained,
                          dilated=args.dilated,
                          deep_stem=args.deep_stem)
    elif args.network == "DANet":
        model = DANet(args.n_class,
                      args.backbone,
                      aux=args.aux,
                      pretrained_base=args.pretrained,
                      dilated=args.dilated,
                      deep_stem=args.deep_stem)
    elif args.network == "CPFNet":
        model = CPFNet(args.n_class,
                       args.backbone,
                       aux=args.aux,
                       pretrained_base=args.pretrained,
                       dilated=args.dilated,
                       deep_stem=args.deep_stem)
    elif args.network == "AG_Net":
        model = AG_Net(args.n_class)
    elif args.network == "CENet":
        model = CE_Net_(args.n_class)
    elif args.network == "ResUnet":
        model = ResUnet(args.n_class,
                        args.backbone,
                        aux=args.aux,
                        pretrained_base=args.pretrained,
                        dilated=args.dilated,
                        deep_stem=args.deep_stem)
    elif args.network == "EMANet":  # 增强 PPM + EMAU + sematic flow
        model = EMANet(args.n_class,
                       args.backbone,
                       aux=args.aux,
                       pretrained_base=args.pretrained,
                       dilated=args.dilated,
                       deep_stem=args.deep_stem,
                       crop_size=args.crop_size)
    elif args.network == "DeepLabV3Plus":
        model = DeepLabV3Plus(args.n_class)
    elif args.network == "EfficientFCN":
        model = EfficientFCN(args.n_class,
                             args.backbone,
                             aux=args.aux,
                             pretrained_base=args.pretrained,
                             dilated=False,
                             deep_stem=args.deep_stem)
    elif args.network == "EMUPNet":
        model = EMUPNet(args.n_class,
                        args.crop_size,
                        args.backbone,
                        pretrained_base=args.pretrained,
                        deep_stem=args.deep_stem)
    elif args.network == "CaCNet":
        model = CaCNet(args.n_class,
                       args.backbone,
                       aux=args.aux,
                       pretrained_base=args.pretrained,
                       dilated=args.dilated,
                       deep_stem=args.deep_stem)
    elif args.network == "Border_ResUnet":
        model = Border_ResUnet(args.n_class,
                               args.backbone,
                               aux=args.aux,
                               pretrained_base=args.pretrained,
                               dilated=args.dilated,
                               deep_stem=args.deep_stem)
    elif args.network == "TANet":
        model = TANet(args.n_class,
                      args.crop_size,
                      args.backbone,
                      aux=args.aux,
                      pretrained_base=args.pretrained,
                      dilated=args.dilated,
                      deep_stem=args.deep_stem)
    elif args.network == "CMSINet":
        model = CMSINet(args.n_class,
                        args.backbone,
                        aux=args.aux,
                        pretrained_base=args.pretrained,
                        dilated=args.dilated,
                        deep_stem=args.deep_stem)
    elif args.network == "class_gcn_Net":
        model = class_gcn_Net(args.n_class,
                              args.backbone,
                              aux=args.aux,
                              pretrained_base=args.pretrained,
                              dilated=args.dilated,
                              deep_stem=args.deep_stem)
    elif args.network == "EfficientEMUPNet":
        model = EfficientEMUPNet(args.n_class,
                                 args.backbone,
                                 aux=args.aux,
                                 pretrained_base=args.pretrained,
                                 dilated=False,
                                 deep_stem=args.deep_stem)
    elif args.network == "CG_EMUPNet":
        model = CG_EMUPNet(args.n_class,
                           args.crop_size,
                           args.backbone,
                           pretrained_base=args.pretrained,
                           deep_stem=args.deep_stem)
    elif args.network == "DF_ResUnet":
        model = DF_ResUnet(args.n_class,
                           args.backbone,
                           aux=args.aux,
                           pretrained_base=args.pretrained,
                           dilated=args.dilated,
                           deep_stem=args.deep_stem)
    elif args.network == "GloRe_Net":
        model = GloRe_Net(args.n_class,
                          args.backbone,
                          aux=args.aux,
                          pretrained_base=args.pretrained,
                          dilated=args.dilated,
                          deep_stem=args.deep_stem)
    elif args.network == "BiNet":
        model = BiNet(args.n_class,
                      args.backbone,
                      aux=args.aux,
                      pretrained_base=args.pretrained,
                      dilated=args.dilated,
                      deep_stem=args.deep_stem)
    elif args.network == "BiNet_baseline":
        model = BiNet_baseline(args.n_class,
                               args.backbone,
                               aux=args.aux,
                               pretrained_base=args.pretrained,
                               dilated=args.dilated,
                               deep_stem=args.deep_stem)
    elif args.network == "channel_gcn_Net":
        model = channel_gcn_Net(args.n_class,
                                args.backbone,
                                aux=args.aux,
                                pretrained_base=args.pretrained,
                                dilated=args.dilated,
                                deep_stem=args.deep_stem)
    elif args.network == "shuffle_Unet":
        model = shuffle_Unet(args.in_channel,
                             args.n_class,
                             channel_reduction=args.Ulikenet_channel_reduction,
                             aux=args.aux)
    elif args.network == "CSNet":
        model = CSNet(args.in_channel, args.n_class)
    elif args.network == "Flaw_Unet":
        model = Flaw_Unet(args.in_channel, args.n_class)
    else:
        raise NotImplementedError(f"not implemented {args.network} model")

    return model
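An illustrative way to drive the dispatcher above without a full argparse setup; every field mirrors an attribute the function reads, and the values are placeholders:

from argparse import Namespace

# illustrative only: a plain SegNet through the dispatcher (most fields are unused on this branch)
args = Namespace(network='SegNet', in_channel=3, n_class=2,
                 backbone='resnet50', aux=False, pretrained=False,
                 dilated=True, deep_stem=False, crop_size=512,
                 Ulikenet_channel_reduction=4)
model = seg_model(args)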
Example #8
def main():
    # get model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # fcn_model = FCNs(pretrained_net=VGGNet(pretrained=True, requires_grad=True))
    seg_model = SegNet(3, 2)
    # seg_model.load_weights('./vgg16-397923af.pth')
    # criterion = nn.BCELoss()
    criterion = BCEFocalLoss()
    optimizer = optim.Adam(seg_model.parameters(), lr=LR, weight_decay=0.0001)
    evaluator = Evaluator(num_class=2)

    if device == 'cuda':
        seg_model.to(device)
        criterion.to(device)

    # get dataloader
    train_set = LaneClsDataset(list_path='train.tsv',
                               dir_path='data_road',
                               img_shape=(IMG_W, IMG_H))
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=8)

    val_set = LaneClsDataset(list_path='val.tsv',
                             dir_path='data_road',
                             img_shape=(IMG_W, IMG_H))
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)

    # info records
    loss_dict = defaultdict(list)
    px_acc_dict = defaultdict(list)
    mean_px_acc_dict = defaultdict(list)
    mean_iou_dict = defaultdict(list)
    freq_iou_dict = defaultdict(list)

    for epoch_idx in range(1, MAX_EPOCH + 1):
        # train stage
        seg_model.train()
        evaluator.reset()
        train_loss = 0.0
        for batch_idx, (image, label) in enumerate(train_loader):

            # global training step drives the lr schedule (the hard-coded 88 assumes 88 batches per epoch)
            lr = lr_func((epoch_idx - 1) * 88 + batch_idx, LR)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            image = image.to(device)
            # print(label.shape)
            # label = label.reshape(BATCH_SIZE, 288, 800)
            label = label.to(device)
            optimizer.zero_grad()
            output = seg_model(image)
            output = torch.sigmoid(output)

            loss = criterion(output, label.long())
            loss.backward()

            evaluator.add_batch(torch.argmax(output, dim=1).cpu().numpy(),
                                torch.argmax(label, dim=1).cpu().numpy())
            train_loss += loss.item()
            print("[Train][Epoch] {}/{}, [Batch] {}/{}, [lr] {:.6f},[Loss] {:.6f}".format(epoch_idx,
                                                                              MAX_EPOCH,
                                                                              batch_idx+1,
                                                                              len(train_loader),
                                                                              lr,
                                                                              loss.item()))
            optimizer.step()
        loss_dict['train'].append(train_loss/len(train_loader))
        px_acc = evaluator.Pixel_Accuracy() * 100
        px_acc_dict['train'].append(px_acc)
        mean_px_acc = evaluator.Pixel_Accuracy_Class() * 100
        mean_px_acc_dict['train'].append(mean_px_acc)
        mean_iou = evaluator.Mean_Intersection_over_Union() * 100
        mean_iou_dict['train'].append(mean_iou)
        freq_iou = evaluator.Frequency_Weighted_Intersection_over_Union() * 100
        freq_iou_dict['train'].append(freq_iou)
        print("[Train][Epoch] {}/{}, [PA] {:.2f}%, [MeanPA] {:.2f}%, [MeanIOU] {:.2f}%, ""[FreqIOU] {:.2f}%".format(
            epoch_idx,
            MAX_EPOCH,
            px_acc,
            mean_px_acc,
            mean_iou,
            freq_iou))

        evaluator.reset()
        # validate stage
        seg_model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for image, label in val_loader:
                image, label = image.to(device), label.to(device)
                output = seg_model(image)
                output = torch.sigmoid(output)
                loss = criterion(output, label.long())
                val_loss += loss.item()
                evaluator.add_batch(torch.argmax(output, dim=1).cpu().numpy(),
                                    torch.argmax(label, dim=1).cpu().numpy())
            val_loss /= len(val_loader)
            loss_dict['val'].append(val_loss)
            px_acc = evaluator.Pixel_Accuracy() * 100
            px_acc_dict['val'].append(px_acc)
            mean_px_acc = evaluator.Pixel_Accuracy_Class() * 100
            mean_px_acc_dict['val'].append(mean_px_acc)
            mean_iou = evaluator.Mean_Intersection_over_Union() * 100
            mean_iou_dict['val'].append(mean_iou)
            freq_iou = evaluator.Frequency_Weighted_Intersection_over_Union() * 100
            freq_iou_dict['val'].append(freq_iou)
            print("[Val][Epoch] {}/{}, [Loss] {:.6f}, [PA] {:.2f}%, [MeanPA] {:.2f}%, "
                  "[MeanIOU] {:.2f}%, ""[FreqIOU] {:.2f}%".format(epoch_idx,
                                                                  MAX_EPOCH,
                                                                  val_loss,
                                                                  px_acc,
                                                                  mean_px_acc,
                                                                  mean_iou,
                                                                  freq_iou))

        # save model checkpoints
        if epoch_idx % SAVE_INTERVAL == 0 or epoch_idx == MAX_EPOCH:
            os.makedirs(MODEL_CKPT_DIR, exist_ok=True)
            ckpt_save_path = os.path.join(MODEL_CKPT_DIR, 'epoch_{}.pth'.format(epoch_idx))
            torch.save(seg_model.state_dict(), ckpt_save_path)
            print("[Epoch] {}/{}, 模型权重保存至{}".format(epoch_idx, MAX_EPOCH, ckpt_save_path))

    # draw figures
    os.makedirs(FIGURE_DIR, exist_ok=True)
    draw_figure(loss_dict, title='Loss', ylabel='loss', filename='loss.png')
    draw_figure(px_acc_dict, title='Pixel Accuracy', ylabel='pa', filename='pixel_accuracy.png')
    draw_figure(mean_px_acc_dict, title='Mean Pixel Accuracy', ylabel='mean_pa', filename='mean_pixel_accuracy.png')
    draw_figure(mean_iou_dict, title='Mean IoU', ylabel='mean_iou', filename='mean_iou.png')
    draw_figure(freq_iou_dict, title='Freq Weighted IoU', ylabel='freq_weighted_iou', filename='freq_weighted_iou.png')
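The snippet relies on a project-local draw_figure helper; a minimal matplotlib sketch matching the call signature used above (one curve per 'train'/'val' split), assuming FIGURE_DIR is the output directory, could look like this:

import os
import matplotlib.pyplot as plt

def draw_figure(value_dict, title, ylabel, filename):
    """Hypothetical re-implementation of the plotting helper used above."""
    plt.figure()
    for split, values in value_dict.items():
        plt.plot(range(1, len(values) + 1), values, label=split)
    plt.title(title)
    plt.xlabel('epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(os.path.join(FIGURE_DIR, filename))
    plt.close()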
Example #9
def main(args):
    # load data
    starter_time = time.time()
    kwargs = {'num_workers': 4, 'pin_memory': True}

    print("loading train data ...")
    trainset = PointCloudDataset(args.train_json)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    print("loading test data ...")
    testset = PointCloudDataset(args.test_json)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)
    print("Initialize cache={}".format(time.time() - starter_time))

    im_encoder = SegNet(input_channels=3, output_channels=3)
    pointVAE = PointVAE(args=args)
    #net = GeneratorVAE(
    #	encoder_dim=(3, 3),
    #	grid_dims=(32, 32, 1),
    #	Generate1_dims=259,
    #	Generate2_dims=1091,
    #	Generate3_dims=1219,
    #	args=args,
    #)
    net = GeneratorVAE(
        im_encoder=im_encoder,
        pointVAE=pointVAE,
        encoder_dim=(3, 3),
        grid_dims=(32, 32, 1),
        Generate1_dims=259,
        Generate2_dims=1091,
        Generate3_dims=1219,
        args=args,
    )
    #init_weights(net, init_type="xavier")

    logger = logging.getLogger()
    file_log_handler = logging.FileHandler(args.log_dir + args.log_filename)
    logger.addHandler(file_log_handler)

    stderr_log_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stderr_log_handler)

    logger.setLevel('INFO')
    formatter = logging.Formatter()
    file_log_handler.setFormatter(formatter)
    stderr_log_handler.setFormatter(formatter)
    logger.info(args)

    criterion_I = MaskedL1().to(args.device)
    criterion_PTC = ChamfersDistance().to(args.device)

    optimizer_image = torch.optim.Adam(
        im_encoder.parameters(),
        lr=args.lr_image,
        betas=(args.adam_beta1, 0.999),
        weight_decay=args.weight_decay,
    )
    optimizer_VAE = torch.optim.Adam(
        pointVAE.parameters(),
        lr=args.lr_vae,
        betas=(args.adam_beta1, 0.999),
        weight_decay=args.weight_decay,
    )
    optimizer = torch.optim.Adam(
        net.parameters(),
        lr=args.lr,
        betas=(args.adam_beta1, 0.999),
        weight_decay=args.weight_decay,
    )

    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=args.lr_decay_step,
        gamma=args.lr_decay_rate,
    )

    # train and test

    runner = TrainTester(
        net=net,
        criterion_I=criterion_I,
        criterion_PTC=criterion_PTC,
        optimizer=optimizer,
        optimizer_image=optimizer_image,
        optimizer_VAE=optimizer_VAE,
        lr_scheduler=lr_scheduler,
        logger=logger,
        args=args,
    )

    if args.train:
        runner.run(
            train_loader=train_loader,
            test_loader=test_loader,
        )
        logger.info('Training Done!')

    if args.test:
        runner.test(
            epoch=args.total_epochs + 1,
            loader=test_loader,
        )
        logger.info('Testing Done!')