Example #1
0
def get_loss():
    """Construct the loss object selected by ``args.loss``.

    Returns:
        A newly built loss for the configured ``LossType``. The ``nn``
        losses are returned directly; the ``piq`` losses are wrapped in
        ``PIQLoss``.

    Raises:
        ValueError: if ``args.loss`` equals none of the handled loss types.
    """
    requested = args.loss

    # Ordered (loss type, factory) pairs. The factories are lambdas so that
    # only the loss that was actually requested gets instantiated.
    builders = (
        (LossType.L1, lambda: nn.L1Loss()),
        (LossType.SmoothL1, lambda: nn.SmoothL1Loss(beta=0.01)),
        (LossType.L2, lambda: nn.MSELoss()),
        (LossType.SSIM, lambda: PIQLoss(piq.SSIMLoss())),
        (LossType.VIF, lambda: PIQLoss(piq.VIFLoss())),
        (LossType.LPIPS, lambda: PIQLoss(piq.LPIPS())),
        (LossType.DISTS, lambda: PIQLoss(piq.DISTS())),
    )

    for loss_type, build in builders:
        if requested == loss_type:
            return build()

    raise ValueError("Unknown loss")
Example #2
0
     lambda x, y: sk.structural_similarity(
         x,
         y,
         win_size=11,
         multichannel=True,
         gaussian_weights=True,
     ),
     'piq.ssim':
     piq.ssim,
     'kornia.SSIM-halfloss':
     kornia.SSIM(
         window_size=11,
         reduction='mean',
     ),
     'piq.SSIM-loss':
     piq.SSIMLoss(),
     'IQA.SSIM-loss':
     IQA.SSIM(),
     'vainf.SSIM':
     vainf.SSIM(data_range=1.),
     'piqa.SSIM':
     piqa.SSIM(),
 }),
 'MS-SSIM': (2, {
     'piq.ms_ssim': piq.multi_scale_ssim,
     'piq.MS_SSIM-loss': piq.MultiScaleSSIMLoss(),
     'IQA.MS_SSIM-loss': IQA.MS_SSIM(),
     'vainf.MS_SSIM': vainf.MS_SSIM(data_range=1.),
     'piqa.MS_SSIM': piqa.MS_SSIM(),
 }),
 'LPIPS': (
Example #3
0
def _stitch_blocks(blocks, num_width, full_bk_num):
    """Reassemble a row-major list of image blocks into a single tensor.

    Each run of ``num_width`` consecutive blocks is concatenated along dim 3
    to form one row; the rows are then concatenated along dim 2.
    """
    # The first row is always built; further rows start at every multiple of
    # num_width that is still below full_bk_num.
    row_starts = [0] + list(range(num_width, full_bk_num, num_width))
    rows = [
        torch.cat(blocks[start:start + num_width], 3) for start in row_starts
    ]
    return torch.cat(rows, 2)


def _save_stitched_image(blocks, num_width, full_bk_num, out_path):
    """Stitch ``blocks``, write the result to ``out_path`` and return it."""
    stitched = _stitch_blocks(blocks, num_width, full_bk_num)
    # Log the path that is really written.
    print(out_path)
    torchvision.utils.save_image(stitched, out_path)
    return stitched


def _run_validation_pass(dehaze_net, val_loader, config, epoch, use_gpu):
    """Run the network over ``val_loader``, save sample images, print metrics.

    For every validation item the per-block tensors are pushed through the
    network, stitched back into full images, written to
    ``config.sample_output_folder`` and compared against the stitched RGB
    blocks with PSNR and SSIM from ``piq``.

    The caller is expected to have put the network in eval mode and disabled
    gradient tracking.
    """
    # Per the original author's notes: img_orig -> yuv444, img_haze -> yuv420.
    for iter_val, (img_orig, img_haze, rgb, bl_num_width, bl_num_height,
                   data_path) in enumerate(val_loader):
        deep_blocks = []  # network output for each block
        haze_blocks = []  # network input for each block
        orig_blocks = []  # img_orig blocks
        rgb_blocks = []  # rgb blocks delivered by the loader
        rgb_from_deep = []  # yuv2rgb(network output)
        rgb_from_haze = []  # yuv2rgb(network input)

        for index, unit_img_orig in enumerate(img_orig):
            unit_img_haze = img_haze[index]
            unit_img_rgb = rgb[index]

            if use_gpu:
                unit_img_orig = unit_img_orig.cuda()
                unit_img_haze = unit_img_haze.cuda()
                unit_img_rgb = unit_img_rgb.cuda()

            clean_image = dehaze_net(unit_img_haze)

            deep_blocks.append(clean_image)
            haze_blocks.append(unit_img_haze)
            orig_blocks.append(unit_img_orig)
            rgb_blocks.append(unit_img_rgb)
            rgb_from_deep.append(yuv2rgb(clean_image))
            rgb_from_haze.append(yuv2rgb(unit_img_haze))

        # Derive the output name from the first source path: last "/"
        # component, cut at the first ".".
        print(data_path)
        temp_data_path = data_path[0]
        print('temp_data_path:')
        print(temp_data_path)
        orimage_name = temp_data_path.split("/")[-1]
        print(orimage_name)
        orimage_name = orimage_name.split(".")[0]
        print(orimage_name)

        num_width = int(bl_num_width[0].item())
        num_height = int(bl_num_height[0].item())
        full_bk_num = num_width * num_height

        out_prefix = (config.sample_output_folder + "Epoch:" + str(epoch) +
                      "_Index:" + str(iter_val + 1) + "_" + orimage_name)

        # File-name suffixes are kept exactly as before (including the double
        # underscore of the last one) so existing consumers still find them.
        _save_stitched_image(deep_blocks, num_width, full_bk_num,
                             out_prefix + "_yuv420_deep.bmp")
        _save_stitched_image(haze_blocks, num_width, full_bk_num,
                             out_prefix + "_yuv420_ori.bmp")
        _save_stitched_image(orig_blocks, num_width, full_bk_num,
                             out_prefix + "_yuv444.bmp")
        rgb_image_all = _save_stitched_image(rgb_blocks, num_width,
                                             full_bk_num,
                                             out_prefix + "_rgb.bmp")
        rgb_from_420_image_all_clear = _save_stitched_image(
            rgb_from_deep, num_width, full_bk_num,
            out_prefix + "_rgb_from_clean_420.bmp")
        rgb_from_420_image_all_haze = _save_stitched_image(
            rgb_from_haze, num_width, full_bk_num,
            out_prefix + "__rgb_from_haze_420.bmp")

        candidates = (
            ("haze", rgb_from_420_image_all_haze),
            ("clear", rgb_from_420_image_all_clear),
        )

        # PSNR of the network input ("haze") and output ("clear") against the
        # stitched RGB blocks.
        for label, candidate in candidates:
            psnr_index = piq.psnr(candidate,
                                  rgb_image_all,
                                  data_range=1.,
                                  reduction='none')
            print(f"PSNR {label}: {psnr_index.item():0.4f}")

        # SSIM index and the corresponding SSIM loss for the same pairs.
        for label, candidate in candidates:
            ssim_index = piq.ssim(candidate, rgb_image_all, data_range=1.)
            ssim_loss = piq.SSIMLoss(data_range=1.)(candidate, rgb_image_all)
            print(
                f"SSIM {label} index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}"
            )


def train(config):
    """Train ``net.dehaze_net`` block by block and validate after every epoch.

    The network is optimised with Adam against an MSE criterion. Each epoch
    optionally runs a training pass over the ``'train'`` split (only when
    ``config.do_valid == 0``), then always runs a validation pass over the
    ``"val"`` split that writes stitched sample images and prints PSNR/SSIM,
    and finally saves the weights to ``config.snapshots_folder + "dehazer.pth"``.

    Args:
        config: object providing the attributes read here: ``use_gpu``,
            ``block_width``, ``block_height``, ``resize``, ``bTest``,
            ``snap_train_data``, ``snapshots_train_folder``,
            ``orig_images_path``, ``train_batch_size``, ``val_batch_size``,
            ``num_workers``, ``lr``, ``weight_decay``, ``num_epochs``,
            ``do_valid``, ``display_block_iter``, ``grad_clip_norm``,
            ``snapshot_iter``, ``snapshots_folder`` and
            ``sample_output_folder``. Folder attributes are joined by plain
            string concatenation, so they must already end with a separator.

    Returns:
        None.
    """
    use_gpu = config.use_gpu
    bk_width = config.block_width
    bk_height = config.block_height
    resize = config.resize
    bTest = config.bTest

    dehaze_net = net.dehaze_net()
    if use_gpu:
        dehaze_net = dehaze_net.cuda()

    if config.snap_train_data:
        # Load onto the CPU first so a snapshot written on a GPU machine can
        # also be restored on a CPU-only run; load_state_dict copies the
        # values onto whatever device the model lives on.
        dehaze_net.load_state_dict(
            torch.load(config.snapshots_train_folder + config.snap_train_data,
                       map_location="cpu"))
    else:
        dehaze_net.apply(weights_init)
    print(dehaze_net)

    train_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                               'train', resize, bk_width,
                                               bk_height, bTest)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path, "val",
                                             resize, bk_width, bk_height,
                                             bTest)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.val_batch_size,
                                             shuffle=True,
                                             num_workers=config.num_workers,
                                             pin_memory=True)

    criterion = nn.MSELoss()
    if use_gpu:
        criterion = criterion.cuda()

    optimizer = torch.optim.Adam(dehaze_net.parameters(),
                                 lr=config.lr,
                                 weight_decay=config.weight_decay)
    dehaze_net.train()

    # The same training data is run through once per epoch.
    # save_counter counts optimisation steps across all epochs.
    save_counter = 0
    for epoch in range(config.num_epochs):
        # Original author's note (translated): img_orig and img_haze are
        # collections of image tensors that are trained together in one go;
        # it is a bit like stitching the pictures side by side, although they
        # actually live on a separate dimension.
        # The training pass is skipped entirely when do_valid is non-zero.
        if config.do_valid == 0:
            for iteration, (img_orig, img_haze, _rgb, bl_num_width,
                            bl_num_height,
                            _data_path) in enumerate(train_loader):
                # One-off debug dump for the very first batch.
                if save_counter == 0:
                    print("img_orig.size:")
                    print(len(img_orig))
                    print("bl_num_width.type:")
                    print(bl_num_width.type())
                    print("shape:")
                    print(bl_num_width.shape)

                num_width = int(bl_num_width[0].item())
                num_height = int(bl_num_height[0].item())
                full_bk_num = num_width * num_height
                # Integer logging interval, clamped to at least 1, so the
                # modulo test below is exact and never divides by zero.
                display_every = max(
                    1, full_bk_num // config.display_block_iter)

                for index, unit_img_orig in enumerate(img_orig):
                    unit_img_haze = img_haze[index]

                    # One-off debug dump for the very first block.
                    if save_counter == 0:
                        print("unit_img_orig type:")
                        print(unit_img_orig.type())
                        print("size:")
                        print(unit_img_orig.size())
                        print("shape:")
                        print(unit_img_orig.shape)

                    if use_gpu:
                        unit_img_orig = unit_img_orig.cuda()
                        unit_img_haze = unit_img_haze.cuda()

                    clean_image = dehaze_net(unit_img_haze)
                    loss = criterion(clean_image, unit_img_orig)

                    # NOTE(review): this checks NaN on the input but inf on
                    # the output; confirm whether both tensors should get
                    # both checks. Behaviour is left as it was.
                    if torch.isnan(unit_img_haze).any() or torch.isinf(
                            clean_image).any():
                        print("unit_img_haze:")
                        print(unit_img_haze.shape)
                        print(unit_img_haze)

                        print("clean_image:")
                        print(clean_image.shape)
                        print(clean_image)

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(),
                                                   config.grad_clip_norm)
                    optimizer.step()

                    # Periodic loss report.
                    if (index + 1) % display_every == 0:
                        print(f"Loss at Epoch:{epoch}_index:{index + 1}/"
                              f"{len(img_orig)}_iter:{iteration + 1}"
                              f"_Loss value:{loss.item()}")

                    # Snapshot every config.snapshot_iter optimisation steps.
                    if (save_counter + 1) % config.snapshot_iter == 0:
                        save_name = (f"Epoch:{epoch}_TrainTimes:"
                                     f"{save_counter + 1}.pth")
                        torch.save(dehaze_net.state_dict(),
                                   config.snapshots_folder + save_name)

                    save_counter += 1

        # Validation stage: switch to eval mode (so dropout / batch-norm
        # layers, if the network has any, use inference behaviour) and turn
        # off autograd so the stored output blocks do not keep their graphs.
        dehaze_net.eval()
        with torch.no_grad():
            _run_validation_pass(dehaze_net, val_loader, config, epoch,
                                 use_gpu)
        dehaze_net.train()

        torch.save(dehaze_net.state_dict(),
                   config.snapshots_folder + "dehazer.pth")
Example #4
0
def main():
    """Evaluate a series of ``piq`` metrics and losses on one image pair.

    Two bitmaps are read from ``tests/assets``, converted to tensors, divided
    by 255 and moved to the GPU when CUDA is available. Every metric below is
    then computed and printed with four decimal places. Lower-case ``piq``
    functions are used as measures; the corresponding classes are the
    loss-module counterparts.
    """

    def _read_image(path):
        # Move the last axis of the decoded array to the front, then divide
        # by 255.
        return torch.tensor(imread(path)).permute(2, 0, 1) / 255.

    img_x = _read_image('tests/assets/i01_01_5.bmp')
    img_y = _read_image('tests/assets/I01.BMP')

    # Use the GPU for faster computation when one is present.
    if torch.cuda.is_available():
        img_x, img_y = img_x.cuda(), img_y.cuda()

    # Keyword sets shared by most of the calls below.
    per_image = {'data_range': 1., 'reduction': 'none'}
    no_reduce = {'reduction': 'none'}

    # BRISQUE (needs only one image): measure, then loss module.
    # Original note: back-propagation through the loss is unavailable with
    # torch==1.5.0; use a newer torch/torchvision.
    brisque_measure = piq.brisque(img_x, **per_image)
    brisque_value = piq.BRISQUELoss(**per_image)(img_x)
    print(
        f"BRISQUE index: {brisque_measure.item():0.4f}, loss: {brisque_value.item():0.4f}"
    )

    # Content loss on VGG16 features taken from the relu3_3 layer.
    # Other feature extractors, layers and per-layer weights are possible;
    # see the class docstring.
    content_criterion = piq.ContentLoss(feature_extractor="vgg16",
                                        layers=("relu3_3", ),
                                        **no_reduce)
    content_value = content_criterion(img_x, img_y)
    print(f"ContentLoss: {content_value.item():0.4f}")

    # DISTS loss module.
    # Original note: inputs are ImageNet-normalised by default; pass
    # mean=[0.0, 0.0, 0.0] and std=[1.0, 1.0, 1.0] to disable that.
    dists_value = piq.DISTS(**no_reduce)(img_x, img_y)
    print(f"DISTS: {dists_value.item():0.4f}")

    # FSIM: measure, then loss module.
    fsim_measure = piq.fsim(img_x, img_y, **per_image)
    fsim_value = piq.FSIMLoss(**per_image)(img_x, img_y)
    print(
        f"FSIM index: {fsim_measure.item():0.4f}, loss: {fsim_value.item():0.4f}")

    # GMSD: measure, then loss module.
    # Original note: a port of the authors' MATLAB code; it should be
    # minimised in either form and usually lies in [0, 0.35].
    gmsd_measure = piq.gmsd(img_x, img_y, **per_image)
    gmsd_value = piq.GMSDLoss(**per_image)(img_x, img_y)
    print(
        f"GMSD index: {gmsd_measure.item():0.4f}, loss: {gmsd_value.item():0.4f}")

    # HaarPSI: measure, then loss module (also a MATLAB port per the
    # original note).
    haarpsi_measure = piq.haarpsi(img_x, img_y, **per_image)
    haarpsi_value = piq.HaarPSILoss(**per_image)(img_x, img_y)
    print(
        f"HaarPSI index: {haarpsi_measure.item():0.4f}, loss: {haarpsi_value.item():0.4f}"
    )

    # LPIPS loss module.
    lpips_value = piq.LPIPS(**no_reduce)(img_x, img_y)
    print(f"LPIPS: {lpips_value.item():0.4f}")

    # MDSI: measure, then loss module.
    mdsi_measure = piq.mdsi(img_x, img_y, **per_image)
    mdsi_value = piq.MDSILoss(**per_image)(img_x, img_y)
    print(
        f"MDSI index: {mdsi_measure.item():0.4f}, loss: {mdsi_value.item():0.4f}")

    # MS-SSIM: the measure uses its default reduction, the loss module is
    # asked for per-image values.
    ms_ssim_measure = piq.multi_scale_ssim(img_x, img_y, data_range=1.)
    ms_ssim_value = piq.MultiScaleSSIMLoss(**per_image)(img_x, img_y)
    print(
        f"MS-SSIM index: {ms_ssim_measure.item():0.4f}, loss: {ms_ssim_value.item():0.4f}"
    )

    # Multi-scale GMSD (chromatic variant): measure, then loss module.
    # Original notes: it should be minimised in either form; the scale
    # weights default to the paper's values and can be replaced by passing
    # four values as scale_weights; input height and width must be at least
    # 2 ** number_of_scales + 1.
    ms_gmsd_measure = piq.multi_scale_gmsd(img_x,
                                           img_y,
                                           chromatic=True,
                                           **per_image)
    ms_gmsd_value = piq.MultiScaleGMSDLoss(chromatic=True,
                                           **per_image)(img_x, img_y)
    print(
        f"MS-GMSDc index: {ms_gmsd_measure.item():0.4f}, loss: {ms_gmsd_value.item():0.4f}"
    )

    # PSNR measure.
    psnr_measure = piq.psnr(img_x, img_y, **per_image)
    print(f"PSNR index: {psnr_measure.item():0.4f}")

    # PieAPP loss module with a stride of 32.
    pieapp_value = piq.PieAPP(stride=32, **no_reduce)(img_x, img_y)
    print(f"PieAPP loss: {pieapp_value.item():0.4f}")

    # SSIM: measure, then loss module, both with their default reduction.
    ssim_measure = piq.ssim(img_x, img_y, data_range=1.)
    ssim_value = piq.SSIMLoss(data_range=1.)(img_x, img_y)
    print(
        f"SSIM index: {ssim_measure.item():0.4f}, loss: {ssim_value.item():0.4f}")

    # Style loss on VGG16 relu3_3 features (same options as the content loss).
    style_criterion = piq.StyleLoss(feature_extractor="vgg16",
                                    layers=("relu3_3", ))
    style_value = style_criterion(img_x, img_y)
    print(f"Style: {style_value.item():0.4f}")

    # Total variation (needs only one image): measure, then loss module.
    tv_measure = piq.total_variation(img_x)
    tv_value = piq.TVLoss(**no_reduce)(img_x)
    print(f"TV index: {tv_measure.item():0.4f}, loss: {tv_value.item():0.4f}")

    # VIF (pixel domain): measure, then loss class.
    vif_measure = piq.vif_p(img_x, img_y, data_range=1.)
    vif_value = piq.VIFLoss(sigma_n_sq=2.0, data_range=1.)(img_x, img_y)
    print(f"VIFp index: {vif_measure.item():0.4f}, loss: {vif_value.item():0.4f}")

    # VSI: measure, then loss module.
    vsi_measure = piq.vsi(img_x, img_y, data_range=1.)
    vsi_value = piq.VSILoss(data_range=1.)(img_x, img_y)
    print(f"VSI index: {vsi_measure.item():0.4f}, loss: {vsi_value.item():0.4f}")
Example #5
0
     'kornia.PSNR': kornia.PSNRLoss(max_val=1.),
     'piqa.PSNR': piqa.PSNR(),
 }),
 'SSIM': (2, {
     'sk.ssim': lambda x, y: sk.structural_similarity(
         x, y,
         win_size=11,
         multichannel=True,
         gaussian_weights=True,
     ),
     'piq.ssim': piq.ssim,
     'kornia.SSIM-halfloss': kornia.SSIM(
         window_size=11,
         reduction='mean',
     ),
     'piq.SSIM-loss': piq.SSIMLoss(),
     'IQA.SSIM-loss': IQA.SSIM(),
     'vainf.SSIM': vainf.SSIM(data_range=1.),
     'piqa.SSIM': piqa.SSIM(),
 }),
 'MS-SSIM': (2, {
     'piq.ms_ssim': piq.multi_scale_ssim,
     'piq.MS_SSIM-loss': piq.MultiScaleSSIMLoss(),
     'IQA.MS_SSIM-loss': IQA.MS_SSIM(),
     'vainf.MS_SSIM': vainf.MS_SSIM(data_range=1.),
     'piqa.MS_SSIM': piqa.MS_SSIM(),
 }),
 'LPIPS': (2, {
     'piq.LPIPS': piq.LPIPS(),
     'IQA.LPIPS': IQA.LPIPSvgg(),
     'piqa.LPIPS': piqa.LPIPS(network='vgg')