Exemple #1
0
def testwithdehaze():
    """Evaluate the dehaze -> segmentation pipeline on ``data/val_dehaze``.

    Loads AOD-Net dehazing weights from ``./weight/dehaze.pth`` and the
    segmentation network selected by ``args.choose_net`` from ``args.weight``
    (both on CPU), runs every validation sample through the two networks in
    sequence and prints the running pixel accuracy / mean IoU per sample and
    the averages at the end.

    Raises:
        ValueError: if ``args.choose_net`` is not 'Unet', 'Enet' or 'Segnet'.
    """
    model_dehaze = aodnet.AODnet()
    model_dehaze.load_state_dict(
        torch.load("./weight/dehaze.pth", map_location='cpu'))

    if args.choose_net == "Unet":
        model_segmentation = my_unet.UNet(3, 1)
    elif args.choose_net == "Enet":
        model_segmentation = enet.ENet(num_classes=1)
    elif args.choose_net == "Segnet":
        model_segmentation = segnet.SegNet(3, 1)
    else:
        # Fail loudly instead of an UnboundLocalError a few lines later.
        raise ValueError("Unsupported network: {}".format(args.choose_net))
    model_segmentation.load_state_dict(
        torch.load(args.weight, map_location='cpu'))
    liver_dataset = LiverDataset_three("data/val_dehaze",
                                       transform=x_transform,
                                       target_transform=y_transform)
    dataloaders = DataLoader(liver_dataset)  # batch_size defaults to 1
    model_dehaze.eval()
    model_segmentation.eval()

    metric = SegmentationMetric(2)

    mean_acc, mean_miou = [], []
    with torch.no_grad():
        for src, _rain, mask in dataloaders:
            dehazed = model_dehaze(src)
            pred = torch.squeeze(model_segmentation(dehazed)).numpy()
            y_label = torch.squeeze(mask).numpy()

            # Threshold the raw network output and score raw labels, as
            # test() does. The previous code first multiplied the prediction
            # by 127.5 and the label by 255 (display scaling), which moved
            # the effective threshold to ~0.004 and produced label value 255,
            # outside the 2-class metric's range.
            # NOTE(review): assumes the mask tensor holds 0/1 values after
            # y_transform — confirm.
            img_y = (pred > 0.5).astype(int)
            y_label = y_label.astype(int)

            metric.addBatch(img_y, y_label)
            acc = metric.pixelAccuracy()
            mIoU = metric.meanIntersectionOverUnion()
            mean_acc.append(acc)
            mean_miou.append(mIoU)
            print(acc, mIoU)
    print("average acc:%0.6f  average miou:%0.6f" %
          (np.mean(mean_acc), np.mean(mean_miou)))
    def test_save_loss_graphs_no_class_weight(self):
        """Dump the SegNetLoss computational graph for every training depth.

        A random input batch and random integer labels in [0, 12) are drawn
        once. For each depth, a 12-class SegNet wrapped in SegNetLoss (no
        class weights) is built and its loss graph is produced. Every
        ``encdec<e>`` link must be present on the predictor. The graph is
        then written to ``tests/SegNet_xt_depth-<n_encdec>_<depth>.dot`` and
        rendered to PNG with Graphviz ``dot``, unless that .dot file already
        exists.
        """
        inputs = Variable(
            np.random.uniform(-1, 1, self.x_shape).astype(np.float32))
        label_shape = (self.x_shape[0], self.x_shape[2], self.x_shape[3])
        labels = Variable(
            np.random.randint(0, 12, label_shape).astype(np.int32))
        expected_links = ['encdec{}'.format(e)
                          for e in range(1, self.n_encdec + 1)]

        for depth in six.moves.range(1, self.n_encdec + 1):
            predictor = segnet.SegNet(n_encdec=self.n_encdec,
                                      n_classes=12,
                                      in_channel=self.x_shape[1])
            model = segnet.SegNetLoss(predictor,
                                      class_weight=None,
                                      train_depth=depth)
            loss = model(inputs, labels)
            graph_src = build_computational_graph(
                [loss],
                variable_style=_var_style,
                function_style=_func_style).dump()
            for link_name in expected_links:
                self.assertTrue(link_name in model.predictor._children)

            dot_path = 'tests/SegNet_xt_depth-{}_{}.dot'.format(
                self.n_encdec, depth)
            if not os.path.exists(dot_path):
                with open(dot_path, 'w') as fh:
                    fh.write(graph_src)
                png_path = dot_path.replace('.dot', '.png')
                subprocess.call('dot -Tpng {} -o {}'.format(dot_path, png_path),
                                shell=True)
def main():
    """Train a Keras U-Net or SegNet and record its training history.

    ``args.nClasses == 2`` selects the 'seg' data lists and weight file name;
    any other class count selects 'parse'. The best weights (according to
    the ``val_acc`` monitor) are written to ``./weights/`` and the history
    plots/logs to ``./results/``.

    Raises:
        ValueError: if ``args.model`` is neither 'unet' nor 'segnet'.
    """
    nClasses = args.nClasses
    train_batch_size = 16
    val_batch_size = 16
    epochs = 50
    img_height = 256
    img_width = 256
    # NOTE(review): hard-coded sample counts; presumably they equal the
    # number of entries in the train/test list files — keep them in sync.
    num_train_samples = 12706
    num_val_samples = 5000
    root_path = '../../datasets/segmentation/'
    mode = 'seg' if nClasses == 2 else 'parse'
    train_file = './data/{}_train.txt'.format(mode)
    val_file = './data/{}_test.txt'.format(mode)
    if args.model == 'unet':
        model = unet.Unet(nClasses,
                          input_height=img_height,
                          input_width=img_width)
    elif args.model == 'segnet':
        model = segnet.SegNet(nClasses,
                              input_height=img_height,
                              input_width=img_width)
    else:
        raise ValueError(
            'Does not support {}, only supports unet and segnet now'.format(
                args.model))

    model.compile(loss='categorical_crossentropy',
                  optimizer=Adam(lr=1e-4),
                  metrics=['accuracy'])
    model.summary()

    train = segdata_generator.generator(root_path, train_file,
                                        train_batch_size, nClasses, img_height,
                                        img_width)

    val = segdata_generator.generator(root_path,
                                      val_file,
                                      val_batch_size,
                                      nClasses,
                                      img_height,
                                      img_width,
                                      train=False)

    # The checkpoint is written into ./weights/ and the history into
    # ./results/. Create both up front; previously only ./results was
    # created, so saving weights could fail on a fresh checkout.
    for out_dir in ('./results', './weights'):
        os.makedirs(out_dir, exist_ok=True)
    # Same names as before: '<model>_seg_weights.h5' / '<model>_parse_weights.h5'.
    save_file = './weights/{}_{}_weights.h5'.format(args.model, mode)
    checkpoint = ModelCheckpoint(save_file,
                                 monitor='val_acc',
                                 save_best_only=True,
                                 save_weights_only=True,
                                 verbose=1)
    history = model.fit_generator(
        train,
        steps_per_epoch=num_train_samples // train_batch_size,
        validation_data=val,
        validation_steps=num_val_samples // val_batch_size,
        epochs=epochs,
        callbacks=[checkpoint],
    )
    plot_history(history, './results/', args.model)
    save_history(history, './results/', args.model)
Exemple #4
0
def trainwithdehaze():
    """Jointly train AOD-Net dehazing and a segmentation network.

    The segmentation network is selected by ``args.choose_net``. Both models
    are moved to ``device`` and passed to ``train_dehaze_model`` for 6 epochs,
    together with:

    * an MSE loss and an Adam optimizer for dehazing,
    * a BCE loss and an Adam optimizer for segmentation,
    * a shuffled DataLoader over ``data/train_dehaze/`` with batch size
      ``args.batch_size``.

    Raises:
        ValueError: if ``args.choose_net`` is not 'Unet', 'Enet' or 'Segnet'.
    """
    model_dehaze = aodnet.AODnet().to(device)
    if args.choose_net == "Unet":
        model_segmentation = my_unet.UNet(3, 1).to(device)
    elif args.choose_net == "Enet":
        model_segmentation = enet.ENet(num_classes=1).to(device)
    elif args.choose_net == "Segnet":
        model_segmentation = segnet.SegNet(3, 1).to(device)
    else:
        # Fail loudly instead of an UnboundLocalError further down.
        raise ValueError("Unsupported network: {}".format(args.choose_net))

    batch_size = args.batch_size
    # Dehazing loss and optimizer.
    criterion_dehaze = torch.nn.MSELoss()
    optimizer_dehaze = optim.Adam(model_dehaze.parameters())

    # Segmentation loss and optimizer.
    # NOTE(review): BCELoss expects probabilities in [0, 1], so the
    # segmentation networks are presumably sigmoid-terminated — confirm.
    criterion_segmentation = torch.nn.BCELoss()
    optimizer_segmentation = optim.Adam(model_segmentation.parameters())

    # Load the dataset. DataLoader batches the dataset's samples into
    # tensors; shuffle=True reshuffles every epoch, and num_workers=4 loads
    # batches in 4 worker processes.
    dataset_dehaze = LiverDataset_three("data/train_dehaze/",
                                        transform=x_transform,
                                        target_transform=y_transform)
    dataloader_dehaze = DataLoader(dataset_dehaze,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=4)
    train_dehaze_model(model_dehaze,
                       model_segmentation,
                       criterion_dehaze,
                       criterion_segmentation,
                       optimizer_dehaze,
                       optimizer_segmentation,
                       dataloader_dehaze,
                       num_epochs=6)
Exemple #5
0
def load_model(model_name, noc):
    """Build an (untrained) segmentation model by name.

    Args:
        model_name: one of 'fcn', 'segnet', 'pspnet', 'unet', 'segfast',
            'segfast_basic', 'segfast_mobile', 'segfast_v2_3', 'segfast_v2_5'.
        noc: number of output classes passed to the constructor.

    Returns:
        The freshly constructed model.

    Raises:
        ValueError: if ``model_name`` is not recognised (previously this
            surfaced as an UnboundLocalError).
    """
    # Lambdas keep construction lazy: only the requested model is built.
    builders = {
        'fcn': lambda: fcn.FCN8s(noc),
        'segnet': lambda: segnet.SegNet(3, noc),
        'pspnet': lambda: pspnet.PSPNet(noc),
        'unet': lambda: unet.UNet(noc),
        'segfast': lambda: segfast.SegFast(64, noc),
        'segfast_basic': lambda: segfast_basic.SegFast_Basic(64, noc),
        'segfast_mobile': lambda: segfast_mobile.SegFast_Mobile(noc),
        'segfast_v2_3': lambda: segfast_v2.SegFast_V2(64, noc, 3),
        'segfast_v2_5': lambda: segfast_v2.SegFast_V2(64, noc, 5),
    }
    try:
        build = builders[model_name]
    except KeyError:
        raise ValueError('Unknown model name: {!r}; expected one of {}'.format(
            model_name, ', '.join(sorted(builders)))) from None
    return build()
def predict_segmentation():
    """Predict binary masks for the test list, showing and saving each one.

    Builds the Keras model named by ``args.model``, loads
    ``./results/<model>_weights.h5``, and for every sample yielded by
    ``generator`` displays the colourised argmax mask and writes it to
    ``./output/<i>.jpg``.

    Raises:
        ValueError: if ``args.model`` is neither 'unet' nor 'segnet'.
    """
    import os

    n_classes = 2
    images_path = '/home/deep/datasets/'
    val_file = './data/seg_test.txt'
    input_height = 256
    input_width = 256

    if args.model == 'unet':
        m = unet.Unet(n_classes,
                      input_height=input_height,
                      input_width=input_width)
    elif args.model == 'segnet':
        m = segnet.SegNet(n_classes,
                          input_height=input_height,
                          input_width=input_width)
    else:
        raise ValueError('Do not support {}'.format(args.model))

    m.load_weights("./results/{}_weights.h5".format(args.model))
    m.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

    # Row c is the colour painted for class c (black / white).
    colors = np.array([[0, 0, 0], [255, 255, 255]])
    # cv2.imwrite returns False instead of raising when the target folder is
    # missing, so make sure it exists.
    os.makedirs('./output', exist_ok=True)
    # NOTE(review): if ``generator`` cycles forever (as Keras training
    # generators usually do), this loop never terminates — confirm.
    for i, (x, _) in enumerate(generator(images_path, val_file, 1, n_classes,
                                         input_height, input_width)):
        pr = m.predict(x)[0]
        pr = pr.reshape((input_height, input_width, n_classes)).argmax(axis=2)
        # Fancy indexing maps every pixel's class index to its colour at once.
        seg_img = colors[pr].astype('uint8')
        cv2.imshow('test', seg_img)
        cv2.imwrite('./output/{}.jpg'.format(i), seg_img)
        cv2.waitKey(30)
    def test_save_normal_graphs(self):
        """Dump the plain SegNet forward graph for every depth.

        One random input batch is drawn. For each depth a SegNet is built and
        run at that depth, and every ``encdec<e>`` child link must exist. The
        forward graph is written to
        ``tests/SegNet_x_depth-<n_encdec>_<depth>.dot`` and rendered to PNG
        with Graphviz ``dot``; depths whose .dot file already exists are
        skipped.
        """
        inputs = Variable(
            np.random.uniform(-1, 1, self.x_shape).astype(np.float32))
        expected_links = ['encdec{}'.format(e)
                          for e in range(1, self.n_encdec + 1)]

        for depth in six.moves.range(1, self.n_encdec + 1):
            model = segnet.SegNet(n_encdec=self.n_encdec,
                                  in_channel=self.x_shape[1])
            out = model(inputs, depth)
            graph_src = build_computational_graph(
                [out],
                variable_style=_var_style,
                function_style=_func_style).dump()
            for link_name in expected_links:
                self.assertTrue(link_name in model._children)

            dot_path = 'tests/SegNet_x_depth-{}_{}.dot'.format(
                self.n_encdec, depth)
            if os.path.exists(dot_path):
                continue
            with open(dot_path, 'w') as fh:
                fh.write(graph_src)
            png_path = dot_path.replace('.dot', '.png')
            subprocess.call('dot -Tpng {} -o {}'.format(dot_path, png_path),
                            shell=True)
    def test_remove_link(self):
        """Check that pruning links leaves only the target depth's parameters.

        For each depth, a SegNetLoss is attached to a shared MomentumSGD
        optimizer, and every ``encdec<d>`` link other than the target depth
        is removed from the predictor. ``conv_cls`` is removed as well when
        depth > 1. Every remaining parameter name must then contain
        ``encdec<depth>``; at depth 1 it may contain ``conv_cls`` instead.
        """
        optimizer = optimizers.MomentumSGD(lr=0.01)
        for depth in six.moves.range(1, self.n_encdec + 1):
            net = segnet.SegNet(self.n_encdec, self.n_classes,
                                self.x_shape[1], self.n_mid)
            model = segnet.SegNetLoss(net,
                                      class_weight=None,
                                      train_depth=depth)
            optimizer.setup(model)

            # Drop every link that is not trained at this depth.
            predictor = model.predictor
            if depth > 1:
                predictor.remove_link('conv_cls')
            other_depths = [d for d in range(1, self.n_encdec + 1)
                            if d != depth]
            for other in other_depths:
                predictor.remove_link('encdec{}'.format(other))

            target = 'encdec{}'.format(depth)
            for param_name, _ in model.namedparams():
                if depth > 1:
                    belongs = target in param_name
                else:
                    belongs = target in param_name or 'conv_cls' in param_name
                self.assertTrue(belongs)
Exemple #9
0
def _build_test_model(name):
    """Construct the network registered under ``name`` and move it to ``device``.

    Raises:
        ValueError: if ``name`` is not a known network.
    """
    ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
    # Lambdas keep construction lazy: only the requested network is built.
    builders = {
        "Unet": lambda: my_unet.UNet(3, 1),
        "My_Unet": lambda: my_unet.My_Unet2(3, 1),
        "Enet": lambda: enet.ENet(num_classes=13),
        "Segnet": lambda: segnet.SegNet(3, 1),
        "CascadNet": lambda: my_cascadenet.CascadeNet(3, 1),
        "my_drsnet_A": lambda: my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1),
        "my_drsnet_B": lambda: my_drsnet.MultiscaleSENetB(in_ch=3, out_ch=1),
        "my_drsnet_C": lambda: my_drsnet.MultiscaleSENetC(in_ch=3, out_ch=1),
        "my_drsnet_A_direct_skip":
            lambda: my_drsnet.MultiscaleSENetA_direct_skip(in_ch=3, out_ch=1),
        "SEResNet": lambda: my_drsnet.SEResNet18(in_ch=3, out_ch=1),
        "resnext_unet": lambda: resnext_unet.resnext50(in_ch=3, out_ch=1),
        "resnet50_unet":
            lambda: resnet50_unet.UNetWithResnet50Encoder(in_ch=3, out_ch=1),
        "unet_res34": lambda: unet_res34.Resnet_Unet(in_ch=3, out_ch=1),
        "dfanet": lambda: dfanet.DFANet(ch_cfg, 3, 1),
        "cgnet": lambda: cgnet.Context_Guided_Network(1),
        "lednet": lambda: lednet.Net(num_classes=1),
        "bisenet": lambda: bisenet.BiSeNet(1, 'resnet18'),
        "espnet": lambda: espnet.ESPNet(classes=1),
        "pspnet": lambda: pspnet.PSPNet(1),
        "fddwnet": lambda: fddwnet.Net(classes=1),
        "contextnet": lambda: contextnet.ContextNet(classes=1),
        "linknet": lambda: linknet.LinkNet(classes=1),
        "edanet": lambda: edanet.EDANet(classes=1),
        "erfnet": lambda: erfnet.ERFNet(classes=1),
    }
    try:
        builder = builders[name]
    except KeyError:
        raise ValueError("Unsupported network: {}".format(name)) from None
    return builder().to(device)


def test():
    """Evaluate ``args.choose_net`` on ``data/val_camvid`` and report speed.

    Prints the model's parameter count (millions) and operation count
    (billions) from ``profile``. For every validation image it then prints
    the running class pixel accuracy and mIoU (13-class
    ``SegmentationMetric``) and plots the prediction stacked on the label.
    At the end it prints the averaged metrics and the total / per-image
    inference time and FPS.

    Raises:
        ValueError: if ``args.choose_net`` is not a known network.
    """
    model = _build_test_model(args.choose_net)
    dsize = (1, 3, 128, 192)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs, ), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000**2), total_ops / (1000**3)))

    model.load_state_dict(torch.load(args.weight, map_location=device))
    liver_dataset = LiverDataset("data/val_camvid",
                                 transform=x_transform,
                                 target_transform=y_transform)
    dataloaders = DataLoader(liver_dataset)  # batch_size defaults to 1
    model.eval()

    metric = SegmentationMetric(13)
    # 1: multi-channel output, per-pixel argmax over channels.
    # Otherwise: single-channel output binarised with ``threshold``.
    multiclass = 1
    # CascadNet uses a 0.8 cut-off; every other network uses 0.5.
    threshold = 0.8 if args.choose_net == "CascadNet" else 0.5
    mean_acc, mean_miou = [], []

    alltime = 0.0
    num_images = 0
    with torch.no_grad():
        for x, y_label in dataloaders:
            x = x.to(device)
            # CUDA kernels run asynchronously: synchronise around the forward
            # pass so the interval measures the computation, not its launch.
            if x.is_cuda:
                torch.cuda.synchronize()
            start = time.time()
            y = model(x)
            if x.is_cuda:
                torch.cuda.synchronize()
            alltime += time.time() - start
            num_images += x.size(0)

            if multiclass == 1:
                # Argmax over raw scores equals argmax after sigmoid (it is
                # monotonic) but avoids ties where sigmoid saturates to 1.0.
                y = torch.argmax(y.cpu().squeeze(0), dim=0).numpy()
                print(y.max(), y.min())
                y_label = torch.squeeze(y_label).numpy()
            else:
                y = torch.squeeze(y.cpu()).numpy()
                y_label = torch.squeeze(y_label).numpy()
                y = (y > threshold)

            img_y = y.astype(int).squeeze()
            print(y_label.shape, img_y.shape)
            image = np.concatenate((img_y, y_label))

            y_label = y_label.astype(int)
            metric.addBatch(img_y, y_label)
            acc = metric.classPixelAccuracy()
            mIoU = metric.meanIntersectionOverUnion()
            mean_acc.append(acc[1])
            mean_miou.append(mIoU)
            print(acc, mIoU)
            plt.imshow(image * 5)
            plt.pause(0.1)
            plt.show()
    # When benchmarking speed, disable the acc/mIoU and plotting code above.

    print("Took ", alltime, "seconds")
    # Per-image time uses the number of images actually processed instead of
    # a hard-coded dataset size.
    if num_images and alltime > 0:
        per_image = alltime / num_images
        print("Took", per_image, "s/perimage")
        print("FPS", 1 / per_image)
    print("average acc:%0.6f  average miou:%0.6f" %
          (np.mean(mean_acc), np.mean(mean_miou)))
Exemple #10
0
def test_img(src_path, label_path):
    """Run every trained binary segmentation network on one image.

    The image at ``src_path`` is resized to 128x192 (width x height). Each
    network listed below is loaded with its checkpoint from ``./weight/`` and
    run on that image; its output is thresholded at 0.5 and written as a
    0/255 mask PNG to ``./data/result/``. The resized source image and the
    scaled label are saved alongside as ``_src.png`` and ``_label.png``.

    Args:
        src_path: path of the input image.
        label_path: path of the ground-truth mask for that image.

    Returns:
        0, kept for compatibility with existing callers.
    """
    result_dir = "./data/result/"
    ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
    # (model factory, checkpoint path, output file name) per network.
    # This also fixes the former ``my_drsnetmy_drsnet`` typo that made the
    # whole function fail with NameError.
    networks = [
        (lambda: enet.ENet(num_classes=1),
         "./weight/enet_weight.pth", "enet_predict.png"),
        (lambda: segnet.SegNet(3, 1),
         "./weight/segnet_weight.pth", "segnet_predict.png"),
        (lambda: my_drsnet.MultiscaleSENetA(3, 1),
         "./weight/my_drsnet_A_weight.pth", "my_drsnet_A_predict.png"),
        (lambda: my_drsnet.MultiscaleSENetB(3, 1),
         "./weight/my_drsnet_B_weight.pth", "my_drsnet_B_predict.png"),
        (lambda: my_drsnet.MultiscaleSENetC(3, 1),
         "./weight/my_drsnet_C_weight.pth", "my_drsnet_C_predict.png"),
        (lambda: my_drsnet.MultiscaleSENetA_direct_skip(3, 1),
         "./weight/my_drsnet_A_direct_skip_weight.pth",
         "my_drsnet_A_direct_skip_predict.png"),
        (lambda: my_drsnet.SEResNet18(in_ch=3, out_ch=1),
         "./weight/SEResNet18_weight.pth", "y_SEResNet18_predict.png"),
        (lambda: dfanet.DFANet(ch_cfg, 3, 1),
         "./weight/dfanet.pth", "dfanet_predict.png"),
        (lambda: cgnet.Context_Guided_Network(1),
         "./weight/cgnet.pth", "cgnet_predict.png"),
        (lambda: lednet.Net(num_classes=1),
         "./weight/lednet.pth", "lednet_predict.png"),
        (lambda: bisenet.BiSeNet(1, 'resnet18'),
         "./weight/bisenet.pth", "bisenet_predict.png"),
        (lambda: fddwnet.Net(classes=1),
         "./weight/fddwnet.pth", "fddwnet_predict.png"),
        (lambda: contextnet.ContextNet(classes=1),
         "./weight/contextnet.pth", "contextnet_predict.png"),
        (lambda: linknet.LinkNet(classes=1),
         "./weight/linknet.pth", "linknet_predict.png"),
        (lambda: edanet.EDANet(classes=1),
         "./weight/edanet.pth", "edanet_predict.png"),
        (lambda: erfnet.ERFNet(classes=1),
         "./weight/erfnet.pth", "erfnet_predict.png"),
    ]

    # PIL's resize takes (width, height): 128 wide x 192 high, matching the
    # (192, 128) reshape of the network outputs below.
    src = Image.open(src_path).resize((128, 192))
    batch = torch.unsqueeze(x_transform(src).to(device), 0)

    label = Image.open(label_path).resize((128, 192))
    # NOTE(review): assumes label pixels are 0/1; on a uint8 array holding
    # 0/255 this multiplication wraps around — confirm the mask format.
    label = np.array(label) * 255

    masks = []
    with torch.no_grad():
        for build, weight_path, out_name in networks:
            model = build().to(device)
            model.load_state_dict(torch.load(weight_path, map_location=device))
            model.eval()
            output = model(batch).cpu().numpy().reshape(192, 128)
            # uint8 0/255 mask: the dtype cv2.imwrite writes directly.
            masks.append((out_name, (output > 0.5).astype(np.uint8) * 255))
            del model  # release each network before building the next one

    src.save(result_dir + "_src.png")
    cv2.imwrite(result_dir + "_label.png", label)
    for out_name, mask in masks:
        cv2.imwrite(result_dir + out_name, mask)

    return 0
Exemple #11
0
def _build_model(name):
    """Instantiate the segmentation network selected by ``args.choose_net``.

    Args:
        name: network identifier, e.g. ``"Unet"``, ``"Segnet"``, ``"erfnet"``.

    Returns:
        The newly constructed model (not yet moved to ``device``).

    Raises:
        ValueError: if *name* is not a supported network identifier.
    """
    # Lambdas defer construction so that only the requested network is built.
    builders = {
        "Unet": lambda: my_unet.UNet(3, 1),
        "My_Unet": lambda: my_unet.My_Unet2(3, 1),
        "Enet": lambda: enet.ENet(num_classes=13),
        "Segnet": lambda: segnet.SegNet(3, 13),
        "CascadNet": lambda: my_cascadenet.CascadeNet(3, 1),
        "my_drsnet_A": lambda: my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1),
        "my_drsnet_B": lambda: my_drsnet.MultiscaleSENetB(in_ch=3, out_ch=1),
        "my_drsnet_C": lambda: my_drsnet.MultiscaleSENetC(in_ch=3, out_ch=1),
        "my_drsnet_A_direct_skip":
            lambda: my_drsnet.MultiscaleSENetA_direct_skip(in_ch=3, out_ch=1),
        "SEResNet": lambda: my_drsnet.SEResNet18(in_ch=3, out_ch=1),
        "resnext_unet": lambda: resnext_unet.resnext50(in_ch=3, out_ch=1),
        "resnet50_unet":
            lambda: resnet50_unet.UNetWithResnet50Encoder(in_ch=3, out_ch=1),
        "unet_nest": lambda: unet_nest.UNet_Nested(3, 2),
        "unet_res34": lambda: unet_res34.Resnet_Unet(3, 1),
        "trangle_net": lambda: mytrangle_net.trangle_net(3, 1),
        "dfanet": lambda: dfanet.DFANet(
            [[8, 48, 96], [240, 144, 288], [240, 144, 288]], 3, 1),
        "lednet": lambda: lednet.Net(num_classes=1),
        "cgnet": lambda: cgnet.Context_Guided_Network(classes=1),
        "pspnet": lambda: pspnet.PSPNet(1),
        "bisenet": lambda: bisenet.BiSeNet(1, 'resnet18'),
        "espnet": lambda: espnet.ESPNet(classes=1),
        "fddwnet": lambda: fddwnet.Net(classes=1),
        "contextnet": lambda: contextnet.ContextNet(classes=1),
        "linknet": lambda: linknet.LinkNet(classes=1),
        "edanet": lambda: edanet.EDANet(classes=1),
        "erfnet": lambda: erfnet.ERFNet(classes=1),
    }
    builder = builders.get(name)
    if builder is None:
        raise ValueError("Unsupported choose_net {!r}; expected one of: {}".format(
            name, ", ".join(sorted(builders))))
    return builder()


def train():
    """Build the network named by ``args.choose_net`` and train it.

    Reads its settings from the module-level ``args`` (``choose_net``,
    ``batch_size``, ``lr``, and ``weight`` when loading pretrained weights)
    and ``device``. Prints the model's parameter count (millions) and the
    operation count reported by ``profile`` (billions) for a 1x3x128x192
    input. It then trains on ``data/train_camvid/`` using SGD with a
    cosine-annealed learning rate.

    Raises:
        ValueError: if ``args.choose_net`` is not a supported network.
    """
    model = _build_model(args.choose_net).to(device)

    # Pretrained-weight mode (edit to switch):
    #   0: train from scratch
    #   1: load args.weight into the network selected above
    #   2: replace the selected network with a fresh MultiscaleSENetA and
    #      load only those entries of args.weight whose keys it has
    loadpretrained = 0
    if loadpretrained == 1:
        model.load_state_dict(torch.load(args.weight, map_location=device))
    elif loadpretrained == 2:
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
        model_dict = model.state_dict()
        pretrained_dict = torch.load(args.weight, map_location=device)
        model_dict.update(
            {k: v for k, v in pretrained_dict.items() if k in model_dict})
        model.load_state_dict(model_dict)

    # Model size: parameters in millions | ops in billions (per profile()).
    inputs = torch.randn((1, 3, 128, 192)).to(device)
    total_ops, total_params = profile(model, (inputs, ), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000**2), total_ops / (1000**3)))

    batch_size = args.batch_size
    liver_dataset = LiverDataset("data/train_camvid/",
                                 transform=x_transform,
                                 target_transform=y_transform)
    len_img = len(liver_dataset)
    # Reshuffle every epoch; 24 worker processes load batches in parallel.
    dataloader = DataLoader(liver_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=24)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=0.0001)

    # T_max spans 10 epochs' worth of batches. This assumes the training
    # loop steps the scheduler once per batch -- TODO confirm. It is clamped
    # to 1 because a dataset smaller than one batch would otherwise give
    # T_max == 0, and CosineAnnealingLR divides by T_max.
    steps_per_epoch = len_img // batch_size
    cosine_lr_scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, 10 * steps_per_epoch), eta_min=0.00001)

    multiclass = 1
    if multiclass == 1:
        # Per-class CamVid weights for CrossEntropyLoss. Class 0 has weight 0,
        # so pixels labelled 0 do not contribute to the loss. The weights can
        # be recomputed with weighing(dataloader, 13, c=1.02);
        # LovaszLossSoftmax and MSELoss were earlier alternatives.
        # NOTE(review): 13 weights require a 13-channel output. Of the
        # networks above, only "Enet" and "Segnet" are built with 13 classes.
        class_weights = np.array([
            0., 6.3005947, 4.31063664, 34.09234699, 50.49834979, 3.88280945,
            50.49834979, 8.91626081, 47.58477105, 29.41289083, 18.95706775,
            37.84558871, 39.3477858
        ])
        class_weights = torch.from_numpy(class_weights).float().to(device)
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
        train_modelmulticlasses(model, criterion, optimizer, dataloader,
                                cosine_lr_scheduler)
    else:
        # Single-channel output trained with BCELoss. LovaszLossHinge and
        # FocalLoss were earlier alternatives.
        criterion = torch.nn.BCELoss()
        train_model(model, criterion, optimizer, dataloader,
                    cosine_lr_scheduler)
    # NOTE(review): from here on the code does not continue the PyTorch
    # training above. It reads args.nClasses / args.model and builds a model
    # via unet.Unet / segnet.SegNet with input_height/input_width keywords.
    # It then calls load_weights()/compile(), which looks like Keras, and
    # starts an evaluation loop that is cut off mid-statement at the end of
    # this section. It appears to be a separate evaluation routine spliced in
    # here -- confirm and move it into its own function.
    n_classes = args.nClasses
    images_path = '../../datasets/segmentation/'
    # Two-class runs use the "seg" test list/weights; any other class count
    # uses the "parse" ones.
    val_file = './data/seg_test.txt' if n_classes == 2 else './data/parse_test.txt'
    weights_file = './weights/{}_seg_weights.h5'.format(args.model) if n_classes == 2 \
        else './weights/{}_parse_weights.h5'.format(args.model)
    input_height = 256
    input_width = 256

    if args.model == 'unet':
        m = unet.Unet(n_classes,
                      input_height=input_height,
                      input_width=input_width)
    elif args.model == 'segnet':
        m = segnet.SegNet(n_classes,
                          input_height=input_height,
                          input_width=input_width)
    else:
        raise ValueError('Do not support {}'.format(args.model))

    # The placeholder in weights_file was already filled above, so this
    # .format() call is effectively a no-op.
    m.load_weights(weights_file.format(args.model))
    m.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

    print('Start evaluating..')
    pbdr = tqdm(total=5000)
    # Accumulators of length n_classes - 1 (presumably one slot per
    # non-background class -- confirm against the truncated loop below).
    iou = [0. for _ in range(1, n_classes)]
    count = [0. for _ in range(1, n_classes)]
    for x, y in generator(images_path,
                          val_file,
Exemple #13
0
 def __init__(self, args):
     """Create ``self.model`` as a SegNet configured from *args*.

     ``args.in_ch``, ``args.out_ch`` and ``args.base_kernel`` are passed to
     ``segnet.SegNet`` positionally, in that order.
     """
     in_ch, out_ch, base_kernel = args.in_ch, args.out_ch, args.base_kernel
     self.model = segnet.SegNet(in_ch, out_ch, base_kernel)
    def test_backward(self):
        """Check that SegNetLoss(train_depth=d) updates only encoder-decoder d.

        For every depth d in 1..n_encdec:
          - build a fresh SegNet wrapped in SegNetLoss;
          - remove conv_cls (when d > 1) and every other encdec link;
          - run one MomentumSGD update on a batch from get_xt();
          - assert that only encdec{d}'s encoder weight changed, plus conv_cls
            when d == 1.

        The computational graph is written to tests/*.dot and rendered with
        ``dot`` only if that file does not exist yet. The parameter-name check
        at the end runs either way.
        """
        opt = optimizers.MomentumSGD(lr=0.01)
        # Update each depth
        for depth in six.moves.range(1, self.n_encdec + 1):
            model = segnet.SegNet(self.n_encdec, self.n_classes,
                                  self.x_shape[1], self.n_mid)
            model = segnet.SegNetLoss(model,
                                      class_weight=None,
                                      train_depth=depth)
            opt.setup(model)

            # Deregister non-target links: conv_cls unless training depth 1,
            # and every encdec other than the current depth.
            if depth > 1:
                model.predictor.remove_link('conv_cls')
            for d in range(1, self.n_encdec + 1):
                if d != depth:
                    model.predictor.remove_link('encdec{}'.format(d))

            # Snapshot the initial weights as deep copies so later updates to
            # the weights cannot alter the snapshot.
            prev_params = {
                'conv_cls': copy.deepcopy(model.predictor.conv_cls.W.data)
            }
            for d in range(1, self.n_encdec + 1):
                name = '/encdec{}/enc/W'.format(d)
                encdec = getattr(model.predictor, 'encdec{}'.format(d))
                prev_params[name] = copy.deepcopy(encdec.enc.W.data)
                self.assertIsNot(prev_params[name], encdec.enc.W.data)

            # One optimisation step.
            # NOTE(review): scaling loss.data does not by itself scale the
            # gradient that backward() seeds on a scalar loss -- confirm the
            # intent of the 1E20 factor.
            x, t = self.get_xt()
            loss = model(x, t)
            loss.data *= 1E20
            model.cleargrads()
            loss.backward()
            opt.update()

            for d in range(1, self.n_encdec + 1):
                # The weight only in the target layer should be updated
                c = self.assertFalse if d == depth else self.assertTrue
                encdec = getattr(opt.target.predictor, 'encdec{}'.format(d))
                self.assertTrue(hasattr(encdec, 'enc'))
                self.assertTrue(hasattr(encdec.enc, 'W'))
                key = '/encdec{}/enc/W'.format(d)
                self.assertIn(key, prev_params)
                c(np.array_equal(encdec.enc.W.data, prev_params[key]),
                  msg='depth:{} d:{} diff:{}'.format(
                      depth, d,
                      np.sum(encdec.enc.W.data - prev_params[key])))
            if depth == 1:
                # The weight in the last layer should be updated
                self.assertFalse(
                    np.allclose(model.predictor.conv_cls.W.data,
                                prev_params['conv_cls']))

            cg = build_computational_graph([loss],
                                           variable_style=_var_style,
                                           function_style=_func_style).dump()

            # Write and render the graph only if it is not on disk yet. Do not
            # `continue` here: the parameter check below must also run when
            # the file already exists (e.g. on every re-run of the suite).
            fn = 'tests/SegNet_bw_depth-{}_{}.dot'.format(self.n_encdec, depth)
            if not os.path.exists(fn):
                with open(fn, 'w') as f:
                    f.write(cg)
                subprocess.call('dot -Tpng {} -o {}'.format(
                    fn, fn.replace('.dot', '.png')),
                                shell=True)

            # After the removals, every remaining encdec parameter must belong
            # to the depth being trained.
            for name, param in model.namedparams():
                encdec_depth = re.search('encdec([0-9]+)', name)
                if encdec_depth:
                    ed = int(encdec_depth.groups()[0])
                    self.assertEqual(ed, depth)