Ejemplo n.º 1
0
    def __init__(self, args, writer, device):
        """Build the meta-learner around a UNet and set up losses/optimizer.

        :param args: parsed configuration. Attributes read here:
            ``update_lr``, ``meta_lr``, ``n_way``, ``k_spt``, ``k_qry``,
            ``task_num``, ``update_step``, ``update_step_test``, ``imgc``,
            ``output_channel``, ``load``, ``optimizer`` and (optionally)
            ``weight_decay``, which defaults to 0 when absent.
        :param writer: summary writer, stored on the instance.
        :param device: torch device that the parameter snapshot
            (``self.net_param``) is copied to.
        :raises ValueError: if ``args.optimizer`` is neither "adam" nor
            "rmsprop".
        """
        super(Meta, self).__init__()
        self.device = device
        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.n_way = args.n_way
        self.k_spt = args.k_spt
        self.k_qry = args.k_qry
        self.task_num = args.task_num
        self.update_step = args.update_step
        self.update_step_test = args.update_step_test
        self.writer = writer
        self.in_channels = args.imgc
        self.out_channels = args.output_channel
        # Used by the RMSprop branch below. It was previously read from an
        # attribute that was never assigned, which raised AttributeError.
        self.weight_decay = getattr(args, "weight_decay", 0)
        self.epsilon = 1e-10

        self.net = UNet(in_channels=args.imgc,
                        out_channels=args.output_channel)

        logging.info(f'Network:\n'
                     f'\t{self.net.in_channels} input channels\n'
                     f'\t{self.net.out_channels} output channels (classes)')

        if args.load:
            # Load into self.net. The old code referenced an undefined
            # ``net``, which raised NameError.
            self.net.load_state_dict(
                torch.load(args.load, map_location=device))
            logging.info(f'Model loaded from {args.load}')

        # Detached copies of the network parameters, moved to ``device``.
        # The snapshot is taken AFTER an optional checkpoint load, so it
        # reflects the loaded weights rather than the random initialisation.
        # NOTE(review): self.net itself is not moved to ``device`` here.
        self.net_param = [
            param.clone().data.to(device) for param in self.net.parameters()
        ]

        # define loss
        self.mse_loss_fn = torch.nn.MSELoss(reduction='none')
        self.ssim_loss_fc = pytorch_msssim.SSIM(window_size=7)
        self.contour_loss = Contour_loss(K=5)

        # define optimizer for meta learning
        if args.optimizer == "adam":
            self.meta_optim = optim.Adam(self.net.parameters(),
                                         lr=self.meta_lr)
        elif args.optimizer == "rmsprop":
            self.meta_optim = optim.RMSprop(self.net.parameters(),
                                            lr=self.meta_lr,
                                            weight_decay=self.weight_decay)
        else:
            raise ValueError("Wrong Optimzer !")
Ejemplo n.º 2
0
def main():
    """Evaluate a UNet on the synthesis "test" split.

    All configuration is read from the module-level ``args``. A summary
    writer logs to ``args.dir_checkpoint``. On Ctrl-C the process exits
    immediately with status 0.
    """
    # set logging and writer
    writer = SummaryWriter(log_dir=args.dir_checkpoint)
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    synthesis_test = Synthesis_Image(args.dataset, mode='test',
                                     batchsz=args.batchsize,
                                     test_data=args.test_data)

    # define loss
    mse_loss_fn = torch.nn.MSELoss(reduction='none')
    ssim_loss_fc = pytorch_msssim.SSIM(window_size=7)
    contour_loss = Contour_loss(K=5)

    # define net
    net = UNet(in_channels=args.imgc, out_channels=args.output_channel)
    net.to(device)

    if args.load:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        net.load_state_dict(torch.load(args.load, map_location=device))
        print('Model loaded from {}'.format(args.load))

    try:
        test(net, mse_loss_fn, ssim_loss_fc, contour_loss, synthesis_test,
             device, writer)
    except KeyboardInterrupt:
        # Exit cleanly; fall back to os._exit if SystemExit is intercepted.
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Ejemplo n.º 3
0
    def __init__(self, args, conv=common.default_conv):
        """Assemble the EDSR head/body/tail stacks plus an auxiliary UNet.

        :param args: configuration providing ``n_resblocks``, ``n_feats``,
            ``scale`` (only the first entry is used), ``rgb_range``,
            ``n_colors`` and ``res_scale``.
        :param conv: factory called as ``conv(in_ch, out_ch, kernel_size)``
            for every convolution layer.
        """
        super(EDSR, self).__init__()

        num_blocks = args.n_resblocks
        width = args.n_feats
        ksize = 3
        upscale = args.scale[0]
        # One ReLU instance, shared by every residual block.
        relu = nn.ReLU(True)

        # Key into the module-level ``url`` table; None when it has no entry.
        key = 'r{}f{}x{}'.format(num_blocks, width, upscale)
        self.url = url[key] if key in url else None

        self.sub_mean = common.MeanShift(args.rgb_range)
        self.add_mean = common.MeanShift(args.rgb_range, sign=1)

        # head: colour channels -> feature maps
        self.head = nn.Sequential(conv(args.n_colors, width, ksize))

        # body: residual blocks followed by a single plain convolution
        body_layers = []
        for _ in range(num_blocks):
            body_layers.append(
                common.ResBlock(conv, width, ksize, act=relu,
                                res_scale=args.res_scale))
        body_layers.append(conv(width, width, ksize))

        # tail: upsample, then project features back to colour channels
        tail_layers = [
            common.Upsampler(conv, upscale, width, act=False),
            conv(width, args.n_colors, ksize),
        ]

        self.body = nn.Sequential(*body_layers)
        self.tail = nn.Sequential(*tail_layers)
        self.unet = UNet(args.n_colors, args.n_colors)
Ejemplo n.º 4
0
    # Prepare the output directory layout: <result_path>/<Task_name>/{models,logs}.
    config.result_path = os.path.join(config.result_path, config.Task_name)
    print(config.result_path)
    config.model_path = os.path.join(config.result_path, 'models')
    config.log_dir = os.path.join(config.result_path, 'logs')
    # Create each directory independently. Previously the sub-directories
    # were only created when result_path itself was missing, so an existing
    # result_path without models/ or logs/ broke later writes.
    for directory in (config.result_path, config.model_path, config.log_dir):
        os.makedirs(directory, exist_ok=True)

    # Record the training configuration, one "key: value" line per attribute.
    with open(os.path.join(config.result_path, 'config.txt'), 'w') as f:
        for key in config.__dict__:
            print('%s: %s' % (key, getattr(config, key)), file=f)

    # Record the training progress: ensure the record file exists.
    # Append mode keeps any previous content.
    config.record_file = os.path.join(config.result_path, 'record.txt')
    with open(config.record_file, 'a'):
        pass

    # Pick the device: CUDA if available, otherwise CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)
    # Build the network: single-channel (1) input images, 1 output class.
    train_net = UNet(n_channels=1, n_classes=1)
    # Copy the network to the device.
    train_net.to(device)
    train(train_net, device, config)
Ejemplo n.º 5
0
import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet

if __name__ == '__main__':
    # Run a trained single-channel UNet over every test JPEG and save a
    # 0/255 mask next to each input as "<name>_res.jpg".
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet(n_channels=1, n_classes=1)
    net.to(device=device)
    net.load_state_dict(torch.load('model.pth', map_location=device))
    net.eval()
    tests_path = glob.glob('D:/Research/Dataset/3DOH50K/testset/*.jpg')
    for test_path in tests_path:
        stem = os.path.splitext(test_path)[0]
        # Skip masks written by a previous run; the glob also matches them.
        if stem.endswith('_res'):
            continue
        # splitext strips only the extension. split('.')[0] truncated paths
        # that contain dots elsewhere (e.g. "./data/a.jpg" -> "").
        save_res_path = stem + '_res.jpg'
        img = cv2.imread(test_path)
        if img is None:
            # cv2.imread returns None instead of raising on unreadable files.
            print(f'Skipping unreadable image: {test_path}')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # (H, W) -> (1, 1, H, W): a batch of one single-channel image.
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred = net(img_tensor)
        # First image of the batch, first output channel -> (H, W) array.
        pred = np.array(pred.data.cpu()[0])[0]
        # Threshold at 0.5 into a 0/255 mask.
        pred[pred >= 0.5] = 255
        pred[pred < 0.5] = 0
        cv2.imwrite(save_res_path, pred)
Ejemplo n.º 6
0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = nn.CrossEntropyLoss()
    patch_size = 64
    batch_size = 8
    num_class = 13
    save_dir = "./results"
    # Load the dataset (test split, train=False).
    data_dir = "/home/cym/Datasets/StData-12/F3_block/"
    dataset = F3DS(data_dir, ptsize=patch_size, train=False)
    test_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=False)
    hw = dataset.hw
    # Zero-filled 2-D accumulator sized from dataset.hw.
    origin_mask = np.zeros((hw[0], hw[1]))

    net = UNet(n_channels=1, n_classes=num_class, bilinear=False)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    net.load_state_dict(
        torch.load('models2/best_model.pth', map_location=device))
    net.to(device=device)
    net.eval()
    print("net prepare done")
    # Create every output sub-directory. Previously they were only created
    # when save_dir itself did not exist yet.
    for sub_dir in ("img", "label", "predlabel"):
        os.makedirs(f"{save_dir}/{sub_dir}", exist_ok=True)

    all_images_num = 0.
    all_acc = 0.

    img_idx = 0
    for batch_idx, (image, label, hys, wxs) in enumerate(test_loader):
Ejemplo n.º 7
0
    # Load the test set.
    # TestData_dataset = TestData_Loader('D:/Research/Dataset/3DOH50K/testset/')
    # tests_path = 'D:/Research/Dataset/3DOH50K/testset/masks/'

    TestData_dataset = TestData_Loader('D:/Research/Dataset/3DOH50K/notset/')
    tests_path = 'D:/Research/Dataset/3DOH50K/notset/masks_test/'
    # Report the number of samples in the dataset.
    print("数据个数: ", len(TestData_dataset))
    test_loader = torch.utils.data.DataLoader(dataset=TestData_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    # Path of the saved state_dict to evaluate.
    out_dir = "model_BCE_bs16.pkl"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model. map_location lets a checkpoint saved on GPU load on a
    # CPU-only host.
    model = UNet(n_channels=3, n_classes=1).to(device=device)
    model.load_state_dict(torch.load(out_dir, map_location=device))
    model.eval()

    batchcnt = 0

    with torch.no_grad():
        for image in test_loader:
            image = image.to(device=device, dtype=torch.float32)
            pred = model(image)
            # 写入batch中的每个数据
            cnt = 0
            for mask in pred:

                # ii = (image[cnt] * 255.0).to(device=device, dtype=torch.uint8)
                # imShow = np.array(ii.data.cpu())
Ejemplo n.º 8
0
            predShow = predShow.reshape(512, 512, 1)

            # NOTE(review): cv2.waitKey() with no argument blocks until a key
            # is pressed, so this pauses training on every batch. It looks
            # like debug-only code.
            cv2.imshow("image", imShow)
            cv2.imshow("mask", maskShow)
            cv2.imshow("pred", predShow)
            cv2.waitKey()

            # Compute the loss.
            loss = criterion(pred, label)
            loss_value = loss.item()
            print('Loss/train: ', loss_value)
            # Save the parameters whenever the loss reaches a new minimum.
            # Track a Python float: storing the loss tensor itself kept its
            # whole autograd graph alive until the next improvement.
            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(net.state_dict(), 'model.pth')
            # Update the parameters.
            # NOTE(review): optimizer.zero_grad() is not visible in this
            # excerpt; confirm it is called earlier in the loop.
            loss.backward()
            optimizer.step()


if __name__ == '__main__':
    # Train a UNet (3 input channels, 1 output class) on the images under
    # data/train/, using CUDA when available and the CPU otherwise.
    data_path = 'data/train/'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet(n_channels=3, n_classes=1)
    net.to(device=device)
    train_net(net, device, data_path)
Ejemplo n.º 9
0
                                               shuffle=True)
    # TrainData_dataset = TrainData_Loader('D:/Research/Dataset/3DOH50K/notset/')
    # print("数据个数: ", len(TrainData_dataset))
    # train_loader = torch.utils.data.DataLoader(
    #     dataset=TrainData_dataset,
    #     batch_size=3,
    #     shuffle=False
    # )

    # Number of batches per epoch.
    print(len(train_loader))

    # Checkpoint file name (not used within this excerpt; presumably the
    # save target later in the loop -- confirm).
    out_dir = "model_Dice_ep40_bs16.pkl"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load model: 3 input channels, 1 output class, moved to the device.
    model = UNet(n_channels=3, n_classes=1).to(device=device)

    # Define the optimizer (gradient-descent algorithm).
    optimizer = optim.Adam(model.parameters(), lr=3e-4)
    # optimizer = optim.RMSprop(model.parameters(), lr=1e-3, weight_decay=1e-8, momentum=0.9)

    # Define the loss; the commented-out lines are alternatives.
    # criterion = nn.L1Loss(reduction='mean')
    # criterion = nn.BCEWithLogitsLoss()
    criterion = DiceLoss()
    # Lowest loss seen so far (starts at +infinity).
    best_loss = float('inf')

    # Training.
    epochs = 40
    for epoch in range(epochs):
        model.train()
def predict(in_channel, model_path, data_path, light=False):
    """Run a trained UNet over the validation images and save binary masks.

    Relative image paths are read from ``<data_path>/valid.pkl`` and joined
    onto ``data_path``. Each prediction is thresholded at 0.5 into a 0/255
    mask. It is written to the input path with "train" replaced by
    "valid_predict".

    Args:
        in_channel: number of input channels the network expects. 1 reads
            images as grayscale; any other value reads them as RGB,
            downsamples by 2 and normalises to [-1, 1].
        model_path: path of the saved ``state_dict``.
        data_path: directory containing ``valid.pkl`` and the images.
        light: build ``UNet_light`` instead of ``UNet``.

    Raises:
        ValueError: if an image path does not contain "train". The output
            path would then equal the input path and overwrite the image.
        FileNotFoundError: if an image cannot be read.
    """
    # Pick the device: CUDA if available, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network with a single output class.
    net_cls = UNet_light if light else UNet
    net = net_cls(n_channels=in_channel, n_classes=1)
    # net = Unet_v2(in_channels=in_channel, n_classes=1)

    net.to(device=device)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()

    # NOTE(review): pickle.load can execute arbitrary code; only use this
    # with a trusted valid.pkl.
    with open(os.path.join(data_path, 'valid.pkl'), "rb") as f:
        valid = pickle.load(f)

    tests_path = [os.path.join(data_path, path) for path in valid]
    for test_path in tqdm(tests_path):
        save_res_path = test_path.replace("train", "valid_predict")
        if save_res_path == test_path:
            raise ValueError(
                f"Refusing to overwrite input image {test_path}: "
                "path does not contain 'train'.")
        img = cv2.imread(test_path)
        if img is None:
            # cv2.imread signals failure by returning None.
            raise FileNotFoundError(f"Could not read image: {test_path}")

        img_shape = img.shape

        if in_channel == 1:
            # Grayscale (H, W) uint8 -> (1, H, W) float tensor in [0, 1].
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = transforms.ToTensor()(img)
        else:
            # RGB -> half-resolution (3, H/2, W/2) tensor normalised to [-1, 1].
            # NOTE(review): only this branch downsamples, and the mask is
            # saved at the reduced size; confirm this matches training.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(img)
            transform = transforms.Compose([
                transforms.Resize((img_shape[0] // 2, img_shape[1] // 2)),
                transforms.ToTensor(),  # scale to [0, 1], channels first
                transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                     std=(0.5, 0.5, 0.5)),  # to [-1, 1]
            ])
            img = transform(pil_image)
        # Add the batch dimension for BOTH branches. The grayscale branch
        # used to skip this and fed a 3-D tensor to the network.
        img = img.unsqueeze(0)

        img = img.to(device=device, dtype=torch.float32)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred = net(img)
        # First image of the batch, first output channel -> (H, W) array.
        pred = np.array(pred.data.cpu()[0])[0]
        # Threshold at 0.5 into a 0/255 mask.
        pred[pred >= 0.5] = 255
        pred[pred < 0.5] = 0

        # cv2.imwrite returns False (no exception) if the directory is
        # missing, so create it first.
        out_dir = os.path.dirname(save_res_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        cv2.imwrite(save_res_path, pred)
Ejemplo n.º 11
0
import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet

if __name__ == "__main__":
    # Pick the device: CUDA if available, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: single-channel images, 1 output class.
    net = UNet(n_channels=1, n_classes=1)
    # Copy the network to the device.
    net.to(device=device)
    # Load the trained parameters.
    net.load_state_dict(
        torch.load('/Users/manmi/Documents/GitHub/unet/best_model.pth',
                   map_location=device))
    # Switch to evaluation mode.
    net.eval()
    # Collect all test image paths.
    tests_path = glob.glob(
        '/Users/manmi/Documents/GitHub/unet/data/test/*.png')
    # Iterate over every image.
    for test_path in tests_path:
        # Output path: strip only the extension. split('.') broke on
        # directories containing dots.
        save_res_path = os.path.splitext(test_path)[0] + '_res.png'
        # Read the image; cv2.imread returns BGR channel order.
        img = cv2.imread(test_path)
        # Convert to grayscale. The data is BGR, so BGR2GRAY is the correct
        # code; RGB2GRAY swapped the red/blue weights.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
def main(
    model_name,
    from_file=True,
    to_file=False,
    validation=True,
    test=("precision", "recall", "f1", "accuracy", "jaccard"),
    concat=False,
    plot=True,
    threshold=None,
):
    """
    Evaluate a saved UNet. Optionally find the best decision threshold on the
    validation set, then compute test metrics at that threshold.

    Parameters
    ----------
    model_name : str
        Which model to do things with. This is assumed to be both the name of the directory in which parameters are stored, and the name of the parameters file.
    from_file : bool
        Whether data should be loaded from a file; otherwise it will be generated.
        The file should:
            - be in the same directory as the model parameters
            - be called "data.npz"
            - contain 4 arrays: "val_predictions", "val_labels", "test_predictions" and "test_labels".
        Labels are expected to be floats and will be thresholded.
        Predictions are expected to be raw (not probabilities).
    to_file : bool
        If data is generated (not loaded from a file), whether to save to a file in the model directory, according to the form described in from_file.
        Also controls whether the validation precision/recall/F1 summary is saved to "prec_rec_f1.txt".
    validation : bool
        Whether to go through validation steps (to find the best threshold).
    test : sequence of str
        Which metrics to use for testing (evaluate the model with a given threshold).
        The results are stored in a txt file called "test_results.txt"
        ("test_concat_results.txt" when concat is True) in the model directory.
        If empty then testing is skipped.
    concat : bool
        During testing, whether to compute each metric once, on the concatenation of the whole test set.
    plot : bool
        Whether to show plots during run.
    threshold : float or None
        Threshold used for testing when validation is False.
        Ignored when validation is True; the best validation threshold is used instead.

    Raises
    ------
    ValueError
        If testing is requested with validation disabled and no threshold given.
    """
    # Fail fast, before any expensive work. Previously this situation
    # raised a NameError on best_threshold at test time.
    if test and not validation and threshold is None:
        raise ValueError(
            "A threshold must be given when testing without validation.")
    best_threshold = threshold

    model_dir = os.path.join(dir_models, model_name)
    params_file = os.path.join(model_dir, model_name)
    new_section = "=" * 50

    print("Importing model parameters from {}".format(params_file))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    model = model.to(device)
    model.load_state_dict(
        torch.load(params_file, map_location=torch.device("cpu")))

    # File where data are stored, or will be if they aren't already
    data_file = os.path.join(model_dir, "data.npz")
    print(new_section)
    if from_file:
        print("Loading data")
        # Look arrays up by name instead of relying on archive order, and
        # close the .npz file once they have been read.
        with np.load(data_file) as arrays:
            val_predictions = arrays["val_predictions"]
            val_labels = arrays["val_labels"]
            test_predictions = arrays["test_predictions"]
            test_labels = arrays["test_labels"]
    else:
        print("Generating data")
        dir_data_validation = os.path.join(dir_data, "validation")
        dir_data_test = os.path.join(dir_data, "test")

        _, validation_dl, test_dl = load_data(
            dir_data_validation=dir_data_validation,
            dir_data_test=dir_data_test,
            prop_noPV_training=0,  # Has no impact
            min_rescale_images=0,  # Has no impact
            batch_size=100,  # All of them
        )

        model.eval()
        with torch.no_grad():
            # Get images and labels from both DataLoaders (first batch only)
            val_images, val_labels = next(iter(validation_dl))
            test_images, test_labels = next(iter(test_dl))
            val_images = val_images.to(device, dtype=torch.float32)
            test_images = test_images.to(device, dtype=torch.float32)
            # Make predictions (predictions are not probabilities at this stage)
            print("Running model on data")
            val_predictions = model(val_images)
            test_predictions = model(test_images)
            # Convert to numpy arrays for computing
            val_predictions = np.squeeze(val_predictions.cpu().numpy())
            val_labels = np.squeeze(val_labels.cpu().numpy())
            test_predictions = np.squeeze(test_predictions.cpu().numpy())
            test_labels = np.squeeze(test_labels.cpu().numpy())
            # Save to file as numpy arrays
            if to_file:
                print("Saving results to file")
                np.savez_compressed(
                    data_file,
                    val_predictions=val_predictions,
                    val_labels=val_labels,
                    test_predictions=test_predictions,
                    test_labels=test_labels,
                )

    # Binarise the float labels.
    threshold_true = 0.5
    val_labels = np.where(val_labels > threshold_true, 1, 0)
    test_labels = np.where(test_labels > threshold_true, 1, 0)

    if validation:
        n_thresholds = 101
        print(new_section)
        print("Validation starting")
        precision, recall, f1_scores, best_threshold = find_best_threshold(
            val_predictions,
            val_labels,
            n_thresholds,
            concat=concat,
            plot=plot)
        print(f"Found best threshold to be {best_threshold:.4f}")
        if to_file:
            # summary_stats yields (lower, mid, upper) rows.
            precision_lower, precision_mid, precision_upper = summary_stats(
                precision)
            f1_lower, f1_mid, f1_upper = summary_stats(f1_scores)
            _, recall_mid, _ = summary_stats(recall)
            results_summary = np.c_[np.linspace(0, 1, n_thresholds),
                                    precision_lower, precision_mid,
                                    precision_upper, recall_mid, f1_lower,
                                    f1_mid, f1_upper]
            results_file = os.path.join(model_dir, "prec_rec_f1.txt")
            print("Saving results to {}".format(results_file))
            np.savetxt(
                results_file,
                results_summary,
                delimiter=" ",
                header=
                f"Threshold: {best_threshold:.3f}\nthresholds precision_lower  precision_mid  precision_upper  recall_mid  f1_lower  f1_mid  f1_upper"
            )

    if test:
        print(new_section)
        print("Testing starting with metrics:")
        print(", ".join(test))
        results = test_model(test_predictions, test_labels, best_threshold,
                             concat, *test)
        print(results)
        summary_type = "median"
        results_file = os.path.join(
            model_dir,
            "test_{}results.txt".format("concat_" if concat else ""))
        if concat:
            results_summary = np.transpose(results)
            print("Results:")
        else:
            results_summary = np.transpose(
                summary_stats(results, type=summary_type))
            print(f"Summary statistics are based on the {summary_type}")
            print("Results (lower, mid-point, upper):")
        print(f"\tBest threshold = {best_threshold:.4f}")
        for i, measure in enumerate(test):
            print("\t{}: {}".format(measure, results_summary[i, :]))
        print(new_section)
        print("Saving results to {}".format(results_file))
        np.savetxt(
            results_file,
            results_summary,
            fmt="%.4f",
            delimiter=" ",
            header=f"Threshold: {best_threshold:.3f}\n{'  '.join(test)}",
        )

    print("\n")
Ejemplo n.º 13
0
def main():
    """Train a UNet on the synthesis dataset.

    All configuration is read from the module-level ``args``. On Ctrl-C the
    current weights are saved to ``args.dir_checkpoint + 'INTERRUPTED.pth'``
    before the process exits.

    :raises ValueError: if ``args.optimizer`` is neither "adam" nor "rmsprop".
    """
    # set logging and writer
    writer = SummaryWriter(
        log_dir=args.dir_checkpoint + "run",
        comment=
        f'Learning Rate_{args.meta_lr}_Batch size_{args.batchsize}_Image Scale_{args.imgsz}'
    )
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # define net
    net = UNet(in_channels=args.imgc, out_channels=args.output_channel)
    net.to(device)

    if args.load:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        net.load_state_dict(torch.load(args.load, map_location=device))
        print('Model loaded from {}'.format(args.load))

    # define loss
    mse_loss_fn = torch.nn.MSELoss(reduction='none')
    ssim_loss_fc = pytorch_msssim.SSIM(window_size=7)
    contour_loss = Contour_loss(K=5)

    # define optimizer for meta learning
    if args.optimizer == "adam":
        meta_optim = optim.Adam(net.parameters(), lr=args.meta_lr)
    elif args.optimizer == "rmsprop":
        meta_optim = optim.RMSprop(net.parameters(),
                                   lr=args.meta_lr,
                                   weight_decay=args.weight_decay)
    else:
        raise ValueError("Wrong Optimzer !")

    # define step scheduler
    scheduler = optim.lr_scheduler.StepLR(meta_optim,
                                          step_size=args.step_size,
                                          gamma=args.step_adjust)

    # define global loss (initial "best" value handed to train)
    best_model_loss = 10000000

    # batch(batch set) of meta training set for each tasks and for meta testing
    synthesis_train = Synthesis_Image(args.dataset,
                                      mode='train_normal',
                                      batchsz=args.batchsize)
    synthesis_val = Synthesis_Image(args.dataset,
                                    mode='val_normal',
                                    batchsz=args.batchsize)
    try:
        train(net, synthesis_train, synthesis_val, device, best_model_loss,
              writer, scheduler, mse_loss_fn, ssim_loss_fc, contour_loss,
              meta_optim)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), args.dir_checkpoint + 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        # Exit cleanly; fall back to os._exit if SystemExit is intercepted.
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Ejemplo n.º 14
0
        "gtFine",
        categories,
        transform=transforms.Compose([
            Resize(_IMAGE_SIZE_),
            Normalize(),
            ToTensor(),
            #TODO: Apply random color changes
            #TODO: Apply random spatial changes (rotation, flip etc)
        ]))
    # Shuffled training loader: batches of 8, 4 worker processes.
    trainloader = DataLoader(cityscapes_dataset,
                             batch_size=8,
                             shuffle=True,
                             num_workers=4)

    # One output channel per category.
    model = UNet(n_classes=len(categories),
                 in_channels=_NUM_CHANNELS_,
                 writer=writer)
    # NOTE(review): ">= 1" wraps the model in DataParallel even with a single
    # GPU ("> 1" is the usual check). Changing it would alter the saved
    # state_dict layout, so confirm before touching it.
    # NOTE(review): no .to(device)/.cuda() call on the model is visible in
    # this excerpt; confirm it is moved to the GPU elsewhere.
    if torch.cuda.device_count() >= 1:
        print("Training model on ", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    # softmax = nn.Softmax2d()
    # criterion = nn.BCELoss()
    # criterion = nn.CrossEntropyLoss()
    # BCE on raw logits (the sigmoid is applied inside the loss).
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

    # Network training
    epoch_data = {}
    # Legacy tensor type matching the available hardware; presumably used to
    # cast input batches later -- confirm.
    float_type = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.FloatTensor
Ejemplo n.º 15
0
dir_train = args.dir_train
dir_test = args.dir_test
bs = args.batchsize
# Data Loader
# dataset_train = img_seg_ldr(data_dir=dir_train)
# train_loader = DataLoader(dataset_train, batch_size=bs, shuffle=True)
dataset_test = img_seg_ldr(data_dir=dir_test)
test_loader = DataLoader(dataset_test, batch_size=1, shuffle=True)
# Use a CUDA (NVIDIA) GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model: 3 input channels, 4 output classes. Only the selected network is
# built. An unknown name now fails here with a clear message instead of a
# later NameError on `net`.
if model == "unet":
    net = UNet(n_channels=3, n_classes=4).to(device)
elif model == "unet3":
    net = UNet3(n_channels=3, n_classes=4).to(device)
elif model == "unet2":
    net = UNet2(n_channels=3, n_classes=4).to(device)
elif model == "resunet":
    net = Unet_Resnet(in_channels=3).to(device)
elif model == "unetpp":
    net = UNetpp(in_ch=3, out_ch=4).to(device)
elif model == "denseunet":
    net = FCDenseNet57(n_classes=4).to(device)
else:
    raise ValueError(f"Unknown model: {model!r}")

# Loss function: class-weighted cross-entropy (weights 1, 2, 2, 1),
# summed rather than averaged.
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor([1.0, 2.0, 2.0, 1.0], device=device),
    reduction='sum')
Ejemplo n.º 16
0
def main(
    num_epochs: int = 80,
    learning_rate: float = 1e-3,
    optimizer_type: str = "ADAM",
    loss: str = "BCE",
    use_scheduler: bool = True,
    milestones_scheduler: list = None,
    gamma_scheduler: float = 0.1,
    batch_size: int = 32,
    dir_data_training: str = "../data/train",
    dir_data_validation: str = "../data/validation",
    prop_noPV_training: float = 0.5,
    min_rescale_images: float = 0.6,
    file_losses: str = "losses.txt",
    saving_frequency: int = 2,
    weight_for_positive_class: float = 1.0,
    save_model_parameters: bool = False,
    load_model_parameters: bool = False,
    dir_for_model_parameters: str = "../saved_models",
    filename_model_parameters_to_load: str = None,
    filename_model_parameters_to_save: str = None,
):
    """
    Main training function with tunable parameters.

    Parameters
    ----------
    num_epochs : int, optional
        Number of epochs to train. The default is 80. With 0 (or fewer) no
        training happens and both error lists are returned empty.
    learning_rate : float, optional
        Learning rate of the Optimizer. The default is 1e-3.
    optimizer_type : str, optional
        Can be "ADAM" or "SGD". The default is "ADAM".
    loss : str, optional
        Can be "BCE" or "L1". The default is "BCE".
    use_scheduler : bool
        If True, use a MultiStepLR. You should set the next two parameters if used.
        The default is True.
    milestones_scheduler : list, optional
        List of epochs at which to adapt the learning rate. None (the default)
        means [50].
    gamma_scheduler : float
        Value by which to multiply the learning rate at each of the previously
        defined milestone epochs. The default is 0.1.
    batch_size : int, optional
        Number of samples per batch in the Dataloaders. The default is 32.
    dir_data_training : str, optional
        Directory where the folders "images/", "labels/" and "noPV/" are for the training set.
    dir_data_validation : str, optional
        Directory where the folders "images/", "labels/" and "noPV/" are for the validation set.
    prop_noPV_training : float, optional
        Proportion of noPV images to add compared to the total amount of PV images in the train set. The default is 0.5.
    min_rescale_images : float, optional
        Minimum proportion of the image to keep for the RandomResizedCrop transform.
        The default is 0.6.
    file_losses : str, optional
        Name of the file where the train and test losses are written during training.
        The default is "losses.txt".
    saving_frequency : int, optional
        Frequency (in number of epochs) at which to write the train and
        test losses in the file.
        Use a small value when training is likely to be interrupted,
        to avoid losing too much data.
        The default is 2.
    weight_for_positive_class : float, optional
        Weight for the positive class in the Binary Cross entropy loss.
        The default is 1.0.
    save_model_parameters : bool, optional
        If True saves the model at the end of training. The default is False.
    load_model_parameters : bool, optional
        If True loads defined parameters in the model before training.
        The default is False.
    dir_for_model_parameters : str, optional
        Directory where saved parameters are stored.
        The default is "../saved_models".
    filename_model_parameters_to_load : str, optional
        Filename of the parameters to load before training.
        Must be specified if load_model_parameters is True.
        The default is None.
    filename_model_parameters_to_save : str, optional
        Filename of the parameters to save after training.
        Must be specified if save_model_parameters is True.
        The default is None.

    Returns
    -------
    model : torch.nn.Module
        Model after training.
    avg_train_error : list of float
        List of Train errors or losses after each epoch.
    avg_validation_error : list of float
        List of Validation errors or losses after each epoch.

    Raises
    ------
    ValueError
        If loading/saving is requested without the corresponding filename.
    NotImplementedError
        If ``loss`` or ``optimizer_type`` is not supported.
    """
    # Avoid a shared mutable default; None means the documented [50].
    if milestones_scheduler is None:
        milestones_scheduler = [50]

    # Validate filenames up front, so a missing save name cannot waste a
    # whole training run (os.path.join(dir, None) raised TypeError).
    if load_model_parameters and filename_model_parameters_to_load is None:
        raise ValueError("filename_model_parameters_to_load must be set "
                         "when load_model_parameters is True.")
    if save_model_parameters and filename_model_parameters_to_save is None:
        raise ValueError("filename_model_parameters_to_save must be set "
                         "when save_model_parameters is True.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("GPU is {}available.".format(
        "" if torch.cuda.is_available() else "NOT "))

    # Instantiate the dataLoaders (the test loader is not used here).
    roof_dataloader_train, roof_dataloader_validation, _ = load_data(
        prop_noPV_training,
        min_rescale_images,
        batch_size,
        dir_data_training,
        dir_data_validation,
    )

    if loss == "BCE":
        # Create Binary cross entropy loss weighted according to positive pixels.
        # pos_weight > 1 increases recall.
        # pos_weight < 1 increases precision.
        pos_weight = torch.tensor([weight_for_positive_class]).to(device)
        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    elif loss == "L1":
        criterion = torch.nn.L1Loss()
    else:
        raise NotImplementedError(f"{loss} is not implemented.")

    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    model = model.to(device)

    # If we're not starting from scratch
    if load_model_parameters:
        path_model_parameters_to_load = os.path.join(
            dir_for_model_parameters, filename_model_parameters_to_load)
        # map_location lets GPU-saved parameters load on a CPU-only host.
        model.load_state_dict(
            torch.load(path_model_parameters_to_load, map_location=device))

    # Defined even when no training happens. Previously num_epochs <= 0
    # raised UnboundLocalError at the print/return below.
    avg_train_error, avg_validation_error = [], []

    # If we're training or retraining a model
    if num_epochs > 0:
        if optimizer_type == "ADAM":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        elif optimizer_type == "SGD":
            optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
        else:
            # Report the requested type. The old message referenced the
            # not-yet-assigned `optimizer` and raised NameError instead.
            raise NotImplementedError(
                f"{optimizer_type} is not implemented.")
        scheduler = None
        if use_scheduler:
            scheduler = MultiStepLR(optimizer,
                                    milestones=milestones_scheduler,
                                    gamma=gamma_scheduler)

        avg_train_error, avg_validation_error = train(
            model,
            criterion,
            roof_dataloader_train,
            roof_dataloader_validation,
            optimizer,
            use_scheduler,
            scheduler,
            num_epochs,
            device,
            file_losses,
            saving_frequency,
        )

        if save_model_parameters:
            path_model_parameters_to_save = os.path.join(
                dir_for_model_parameters, filename_model_parameters_to_save)
            torch.save(model.state_dict(), path_model_parameters_to_save)

    print(avg_train_error, avg_validation_error)

    return model, avg_train_error, avg_validation_error
Ejemplo n.º 17
0

if __name__ == "__main__":
    # Restrict CUDA to GPU 0; this must happen before CUDA is queried below.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    # Use CUDA when available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hyper-parameters.
    patch_size, batch_size, num_class, epochs = 44, 64, 13, 250

    # Training split of the F3 block dataset, shuffled each epoch.
    data_dir = "/home/cym/Datasets/StData-12/F3_block/"
    dataset = F3DS(data_dir, ptsize=patch_size, train=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=batch_size, shuffle=True)

    # UNet with one input channel and num_class output classes,
    # initialised via weight_init, then moved to the chosen device.
    net = UNet(n_channels=1, n_classes=num_class, bilinear=False)
    net.apply(weight_init)
    net.to(device=device)

    # Summary writer logs to ./logs2; train_net receives save_dir ./models2.
    writer = SummaryWriter('./logs2')
    train_net(net, train_loader, device, writer, epochs, save_dir="./models2")



Ejemplo n.º 18
0
def set_model(sample,
              device,
              args,
              train=True,
              experiment_name=None,
              finetune=False,
              model_path=None):
    """Create the model and loss criterion selected by ``args.model``.

    :param sample: one input sample; its shape supplies the number of input
        channels (``shape[0]`` for the UNet family, ``shape[1]`` for ConvLSTM).
    :param device: forwarded to the criterion and used as ``map_location``
        when loading a fine-tuning checkpoint.
    :param args: experiment options; ``args.model`` must be one of
        ``'unet'``, ``'attn_unet'``, ``'suc_unet'`` or ``'convlstm'``.
    :param train: unused; kept for interface compatibility.
    :param experiment_name: forwarded to the criterion.
    :param finetune: if true, load ``model_path`` into the model with
        ``strict=False`` (missing/unexpected keys are tolerated).
    :param model_path: checkpoint path, used only when ``finetune`` is true.
    :return: ``(model, criterion)``.
    :raises ValueError: if ``args.model`` is not a supported model name.
    """
    num_classes = 2

    if args.model in ('unet', 'attn_unet', 'suc_unet'):
        # The three UNet variants share one constructor signature; only the
        # class that is actually needed is looked up.
        if args.model == 'unet':
            model_cls = UNet
        elif args.model == 'attn_unet':
            model_cls = AttentionUNet
        else:
            model_cls = SuccessiveUNet
        model = model_cls(n_channels=sample.shape[0],
                          n_classes=num_classes,
                          n_blocks=args.n_blocks,
                          start_channels=args.start_channels,
                          pos_loc=args.pos_loc,
                          pos_dim=args.pos_dim,
                          bilinear=args.bilinear,
                          batch_size=args.batch_size)
    elif args.model == 'convlstm':
        # NOTE(review): channels come from sample.shape[1] here, presumably
        # because ConvLSTM samples carry a leading time axis — confirm.
        model = EncoderForecaster(input_channels=sample.shape[1],
                                  hidden_dim=args.hidden_dim,
                                  num_classes=num_classes)
    else:
        # Without this, an unknown name would leave `model` unbound.
        raise ValueError(f"Unsupported model: {args.model!r}")

    # Every supported model is trained with the same criterion.
    criterion = NIMSCrossEntropyLoss(args=args,
                                     device=device,
                                     num_classes=num_classes,
                                     use_weights=args.cross_entropy_weight,
                                     experiment_name=experiment_name)

    if finetune:
        # map_location lets a checkpoint saved on another device load here.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint, strict=False)

    return model, criterion
Ejemplo n.º 19
0
# Entry-point script: parses CLI options and, with "-v", shows the predicted
# mask for one training sample using weights from "unet.pkl". The other
# branch (presumably training via `train`) is missing from this copy.
import argparse
import os
import torch
from model.unet_model import UNet
from train import train

# NOTE(review): runs at import time and needs a CUDA device — confirm this
# script is never imported/run on a CPU-only machine.
torch.cuda.set_device(0)
parser = argparse.ArgumentParser()
unet = UNet()
# NOTE(review): not referenced anywhere in the visible part of this script.
use_CUDA = True

if __name__ == "__main__":

    parser.add_argument("-v", "--visual", action="store_true")
    parser.add_argument("-l", "--lr", type=float, default=1e-5)
    parser.add_argument("-e", "--epochs", type=int, default=10)
    parser.add_argument("-b", "--batch", type=int, default=40)
    # NOTE(review): type=bool turns any non-empty string into True, so
    # "-r False" still enables retraining; action="store_true" is likely meant.
    parser.add_argument("-r", "--retrain", type=bool, default=False)

    args = parser.parse_args()

    if args.visual:
        # Visualise the model's prediction on one shuffled training sample.
        from visual import show_pred_mask
        from loader import train_loader

        trl = train_loader(1, shuffle=True)
        img, msk = next(trl)
        # NOTE(review): no map_location — a checkpoint saved on GPU will need
        # CUDA to load.
        unet.load_state_dict(torch.load("unet.pkl"))
        show_pred_mask(unet, img, msk)

    else:
        # NOTE(review): this branch has no body here; the snippet appears to
        # be truncated.
Ejemplo n.º 20
0
class Meta(nn.Module):
    """
    Meta Learner
    """
    def __init__(self, args, writer, device):
        """Build the meta-learner: the UNet, its losses and the meta-optimiser.

        :param args: hyper-parameter namespace. Attributes read here:
            ``update_lr``, ``meta_lr``, ``n_way``, ``k_spt``, ``k_qry``,
            ``task_num``, ``update_step``, ``update_step_test``, ``imgc``,
            ``output_channel``, ``load`` (checkpoint path, falsy to skip),
            ``optimizer`` ("adam" or "rmsprop") and the optional
            ``weight_decay`` (used by RMSprop only, defaults to 0).
        :param writer: stored as ``self.writer``; not used in this method.
        :param device: device that holds the parameter snapshot
            ``self.net_param`` and onto which a checkpoint is loaded.
        :raises ValueError: if ``args.optimizer`` is not "adam" or "rmsprop".
        """
        super(Meta, self).__init__()
        self.device = device
        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.n_way = args.n_way
        self.k_spt = args.k_spt
        self.k_qry = args.k_qry
        self.task_num = args.task_num
        self.update_step = args.update_step
        self.update_step_test = args.update_step_test
        self.writer = writer
        self.in_channels = args.imgc
        self.out_channels = args.output_channel
        self.epsilon = 1e-10
        # RMSprop weight decay; optional in `args` (0 disables it).
        self.weight_decay = getattr(args, "weight_decay", 0)

        self.net = UNet(in_channels=args.imgc,
                        out_channels=args.output_channel)

        logging.info(f'Network:\n'
                     f'\t{self.net.in_channels} input channels\n'
                     f'\t{self.net.out_channels} output channels (classes)')

        if args.load:
            self.net.load_state_dict(torch.load(args.load, map_location=device))
            logging.info(f'Model loaded from {args.load}')

        # Snapshot of the meta-parameters. forward() restores the network from
        # it, so it must be taken after any checkpoint has been loaded.
        self.net_param = [param.clone().data.to(device)
                          for param in self.net.parameters()]

        # Loss terms: per-element MSE, SSIM and a contour-based weight map.
        self.mse_loss_fn = torch.nn.MSELoss(reduction='none')
        self.ssim_loss_fc = pytorch_msssim.SSIM(window_size=7)
        self.contour_loss = Contour_loss(K=5)

        # Optimiser for the outer (meta) update.
        if args.optimizer == "adam":
            self.meta_optim = optim.Adam(self.net.parameters(),
                                         lr=self.meta_lr)
        elif args.optimizer == "rmsprop":
            self.meta_optim = optim.RMSprop(self.net.parameters(),
                                            lr=self.meta_lr,
                                            weight_decay=self.weight_decay)
        else:
            raise ValueError("Wrong Optimzer !")

    def forward(self, sun_imgs_x_spt, sun_imgs_y_spt, sun_imgs_x_qry,
                sun_imgs_y_qry, step):
        """

        :param x_spt:   [b, setsz, c_, h, w]
        :param y_spt:   [b, setsz]
        :param x_qry:   [b, querysz, c_, h, w]
        :param y_qry:   [b, querysz]
        :return:
        """
        task_num, setsz, c_, h, w = sun_imgs_x_spt.size()
        querysz = sun_imgs_x_qry.size(1)

        loss_q = [0 for _ in range(self.update_step + 1)
                  ]  # losses_q[i] is the loss on step i

        # set epsilon
        epsilon = 1e-10
        alpha = 0.12
        self.net = self.net.to(self.device)
        self.net.train()
        for i in range(task_num):

            reflection_pred = self.net(sun_imgs_x_spt[i])
            reflection_pred = reflection_pred + epsilon
            reflection_pred = torch.clamp(
                reflection_pred, 0.1,
                5.0)  # Peter: may be we can try to remove this item
            restoration_imgs_pred = torch.zeros(
                *sun_imgs_x_spt[i][:, :, :, :].shape).to(self.device)
            restoration_imgs_pred[:, 0, :, :] = torch.div(
                sun_imgs_x_spt[i][:, 0, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 1, :, :] = torch.div(
                sun_imgs_x_spt[i][:, 1, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 2, :, :] = torch.div(
                sun_imgs_x_spt[i][:, 2, :, :], reflection_pred[:, 0, :, :])

            weight_map = self.contour_loss.forward(sun_imgs_y_spt[i])
            restoration_imgs_pred = torch.mul(restoration_imgs_pred,
                                              weight_map)
            gt_imgs = torch.mul(sun_imgs_y_spt[i], weight_map)

            mse_loss = self.mse_loss_fn(restoration_imgs_pred, gt_imgs)
            ssim_loss = self.ssim_loss_fc(restoration_imgs_pred, gt_imgs)
            loss = torch.mean(mse_loss) + alpha * (
                1 - torch.mean(ssim_loss))  #+ REGULARIZATION * reg_loss

            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(
                map(lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, self.net.parameters())))

            # =====================================================

            reflection_pred = self.net(sun_imgs_x_qry[i])
            reflection_pred = reflection_pred + epsilon
            reflection_pred = torch.clamp(
                reflection_pred, 0.1,
                5.0)  # Peter: may be we can try to remove this item
            restoration_imgs_pred = torch.zeros(
                *sun_imgs_x_qry[i][:, :, :, :].shape).to(self.device)
            restoration_imgs_pred[:, 0, :, :] = torch.div(
                sun_imgs_x_qry[i][:, 0, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 1, :, :] = torch.div(
                sun_imgs_x_qry[i][:, 1, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 2, :, :] = torch.div(
                sun_imgs_x_qry[i][:, 2, :, :], reflection_pred[:, 0, :, :])

            weight_map = self.contour_loss.forward(sun_imgs_y_qry[i])
            restoration_imgs_pred = torch.mul(restoration_imgs_pred,
                                              weight_map)
            gt_imgs = torch.mul(sun_imgs_y_qry[i], weight_map)

            mse_loss = self.mse_loss_fn(restoration_imgs_pred, gt_imgs)
            ssim_loss = self.ssim_loss_fc(restoration_imgs_pred, gt_imgs)
            loss = torch.mean(mse_loss) + alpha * (
                1 - torch.mean(ssim_loss))  #+ REGULARIZATION * reg_loss
            loss_q[0] += loss

            for each_fastweight, (name,
                                  param) in zip(fast_weights,
                                                self.net.named_parameters()):
                param.data = each_fastweight.data

            # =====================================================

            reflection_pred = self.net(sun_imgs_x_qry[i])
            reflection_pred = reflection_pred + epsilon
            reflection_pred = torch.clamp(
                reflection_pred, 0.1,
                5.0)  # Peter: may be we can try to remove this item
            restoration_imgs_pred = torch.zeros(
                *sun_imgs_x_qry[i][:, :, :, :].shape).to(self.device)
            restoration_imgs_pred[:, 0, :, :] = torch.div(
                sun_imgs_x_qry[i][:, 0, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 1, :, :] = torch.div(
                sun_imgs_x_qry[i][:, 1, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 2, :, :] = torch.div(
                sun_imgs_x_qry[i][:, 2, :, :], reflection_pred[:, 0, :, :])

            weight_map = self.contour_loss.forward(sun_imgs_y_qry[i])
            restoration_imgs_pred = torch.mul(restoration_imgs_pred,
                                              weight_map)
            gt_imgs = torch.mul(sun_imgs_y_qry[i], weight_map)

            mse_loss = self.mse_loss_fn(restoration_imgs_pred, gt_imgs)
            ssim_loss = self.ssim_loss_fc(restoration_imgs_pred, gt_imgs)
            loss = torch.mean(mse_loss) + alpha * (
                1 - torch.mean(ssim_loss))  #+ REGULARIZATION * reg_loss
            loss_q[1] += loss

            del grad, fast_weights, loss, mse_loss, ssim_loss, weight_map

            # =====================================================

            for k in range(1, self.update_step):
                reflection_pred = self.net(sun_imgs_x_spt[i])
                reflection_pred = reflection_pred + epsilon
                reflection_pred = torch.clamp(
                    reflection_pred, 0.1,
                    5.0)  # Peter: may be we can try to remove this item
                restoration_imgs_pred = torch.zeros(
                    *sun_imgs_x_spt[i][:, :, :, :].shape).to(self.device)
                restoration_imgs_pred[:, 0, :, :] = torch.div(
                    sun_imgs_x_spt[i][:, 0, :, :], reflection_pred[:, 0, :, :])
                restoration_imgs_pred[:, 1, :, :] = torch.div(
                    sun_imgs_x_spt[i][:, 1, :, :], reflection_pred[:, 0, :, :])
                restoration_imgs_pred[:, 2, :, :] = torch.div(
                    sun_imgs_x_spt[i][:, 2, :, :], reflection_pred[:, 0, :, :])

                weight_map = self.contour_loss.forward(sun_imgs_y_spt[i])
                restoration_imgs_pred = torch.mul(restoration_imgs_pred,
                                                  weight_map)
                gt_imgs = torch.mul(sun_imgs_y_spt[i], weight_map)

                mse_loss = self.mse_loss_fn(restoration_imgs_pred, gt_imgs)
                ssim_loss = self.ssim_loss_fc(restoration_imgs_pred, gt_imgs)
                loss = torch.mean(mse_loss) + alpha * (
                    1 - torch.mean(ssim_loss))  #+ REGULARIZATION * reg_loss

                grad = torch.autograd.grad(loss, self.net.parameters())
                fast_weights = list(
                    map(lambda p: p[1] - self.update_lr * p[0],
                        zip(grad, self.net.parameters())))

                for each_fastweight, (name, param) in zip(
                        fast_weights, self.net.named_parameters()):
                    param.data = each_fastweight.data

                reflection_pred = self.net(sun_imgs_x_qry[i])
                reflection_pred = reflection_pred + epsilon
                reflection_pred = torch.clamp(
                    reflection_pred, 0.1,
                    5.0)  # Peter: may be we can try to remove this item
                restoration_imgs_pred = torch.zeros(
                    *sun_imgs_x_qry[i][:, :, :, :].shape).to(self.device)
                restoration_imgs_pred[:, 0, :, :] = torch.div(
                    sun_imgs_x_qry[i][:, 0, :, :], reflection_pred[:, 0, :, :])
                restoration_imgs_pred[:, 1, :, :] = torch.div(
                    sun_imgs_x_qry[i][:, 1, :, :], reflection_pred[:, 0, :, :])
                restoration_imgs_pred[:, 2, :, :] = torch.div(
                    sun_imgs_x_qry[i][:, 2, :, :], reflection_pred[:, 0, :, :])

                weight_map = self.contour_loss.forward(sun_imgs_y_qry[i])
                restoration_imgs_pred = torch.mul(restoration_imgs_pred,
                                                  weight_map)
                gt_imgs = torch.mul(sun_imgs_y_qry[i], weight_map)

                mse_loss = self.mse_loss_fn(restoration_imgs_pred, gt_imgs)
                ssim_loss = self.ssim_loss_fc(restoration_imgs_pred, gt_imgs)
                loss = torch.mean(mse_loss) + alpha * (
                    1 - torch.mean(ssim_loss))  #+ REGULARIZATION * reg_loss
                loss_q[k + 1] += loss

                del grad, fast_weights, loss, mse_loss, ssim_loss, weight_map

        for net_param, (name, param) in zip(self.net_param,
                                            self.net.named_parameters()):
            param.data = net_param

        # end of all tasks
        # sum over all losses on query set across all tasks
        loss_q = torch.sum(torch.stack(loss_q)) / task_num

        # optimize theta parameters
        self.meta_optim.zero_grad()
        loss_q.backward()

        # optimize
        self.meta_optim.step()

        self.net_param = []
        for param in self.net.parameters():
            self.net_param.append(param.clone().data)

        torch.cuda.empty_cache()

        return loss_q

    def finetunning(self, sun_imgs_x_spt, sun_imgs_y_spt, sun_imgs_x_qry,
                    sun_imgs_y_qry, step):
        """

        :param x_spt:   [setsz, c_, h, w]
        :param y_spt:   [setsz]
        :param x_qry:   [setsz, c_, h, w]
        :param y_qry:   [querysz]
        :return:
        """
        setsz, c_, h, w = sun_imgs_x_spt.size()
        querysz = sun_imgs_x_qry.size(1)

        loss_q = [0 for _ in range(self.update_step + 1)
                  ]  # losses_q[i] is the loss on step i

        # set epsilon
        epsilon = 1e-10
        alpha = 0.12
        self.net = self.net.to(self.device)
        self.net.train()

        # 1. run the i-th task and compute loss for k=0
        reflection_pred = self.net(sun_imgs_x_spt)
        reflection_pred = reflection_pred + epsilon
        reflection_pred = torch.clamp(
            reflection_pred, 0.1,
            5.0)  # Peter: may be we can try to remove this item
        restoration_imgs_pred = torch.zeros(
            *sun_imgs_x_spt[:, :, :, :].shape).to(self.device)
        restoration_imgs_pred[:,
                              0, :, :] = torch.div(sun_imgs_x_spt[:, 0, :, :],
                                                   reflection_pred[:, 0, :, :])
        restoration_imgs_pred[:,
                              1, :, :] = torch.div(sun_imgs_x_spt[:, 1, :, :],
                                                   reflection_pred[:, 0, :, :])
        restoration_imgs_pred[:,
                              2, :, :] = torch.div(sun_imgs_x_spt[:, 2, :, :],
                                                   reflection_pred[:, 0, :, :])

        weight_map = self.contour_loss.forward(sun_imgs_y_spt)
        restoration_imgs_pred = torch.mul(restoration_imgs_pred, weight_map)
        gt_imgs = torch.mul(sun_imgs_y_spt, weight_map)

        mse_loss = self.mse_loss_fn(restoration_imgs_pred, gt_imgs)
        ssim_loss = self.ssim_loss_fc(restoration_imgs_pred, gt_imgs)
        loss = torch.mean(mse_loss) + alpha * (1 - torch.mean(ssim_loss)
                                               )  #+ REGULARIZATION * reg_loss

        grad = torch.autograd.grad(loss, self.net.parameters())
        fast_weights = list(
            map(lambda p: p[1] - self.update_lr * p[0],
                zip(grad, self.net.parameters())))

        with torch.no_grad():

            # 1. run the i-th task and compute loss for k=0
            reflection_pred = self.net(sun_imgs_x_qry)
            reflection_pred = reflection_pred + epsilon
            reflection_pred = torch.clamp(
                reflection_pred, 0.1,
                5.0)  # Peter: may be we can try to remove this item
            restoration_imgs_pred = torch.zeros(
                *sun_imgs_x_qry[:, :, :, :].shape).to(self.device)
            restoration_imgs_pred[:, 0, :, :] = torch.div(
                sun_imgs_x_qry[:, 0, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 1, :, :] = torch.div(
                sun_imgs_x_qry[:, 1, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 2, :, :] = torch.div(
                sun_imgs_x_qry[:, 2, :, :], reflection_pred[:, 0, :, :])

            weight_map = self.contour_loss.forward(sun_imgs_y_qry)
            restoration_imgs_pred = torch.mul(restoration_imgs_pred,
                                              weight_map)
            gt_imgs = torch.mul(sun_imgs_y_qry, weight_map)

            mse_loss = self.mse_loss_fn(restoration_imgs_pred, gt_imgs)
            ssim_loss = self.ssim_loss_fc(restoration_imgs_pred, gt_imgs)
            loss = torch.mean(mse_loss) + alpha * (
                1 - torch.mean(ssim_loss))  #+ REGULARIZATION * reg_loss
            loss_q[0] += loss

        for each_fastweight, (name, param) in zip(fast_weights,
                                                  self.net.named_parameters()):
            param = each_fastweight

        with torch.no_grad():

            # 1. run the i-th task and compute loss for k=0
            # not use original net, use copy one
            reflection_pred = self.net(sun_imgs_x_qry)
            reflection_pred = reflection_pred + epsilon
            reflection_pred = torch.clamp(
                reflection_pred, 0.1,
                5.0)  # Peter: may be we can try to remove this item
            restoration_imgs_pred = torch.zeros(
                *sun_imgs_x_qry[:, :, :, :].shape).to(self.device)
            restoration_imgs_pred[:, 0, :, :] = torch.div(
                sun_imgs_x_qry[:, 0, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 1, :, :] = torch.div(
                sun_imgs_x_qry[:, 1, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 2, :, :] = torch.div(
                sun_imgs_x_qry[:, 2, :, :], reflection_pred[:, 0, :, :])

            weight_map = self.contour_loss.forward(sun_imgs_y_qry)
            restoration_imgs_pred = torch.mul(restoration_imgs_pred,
                                              weight_map)
            gt_imgs = torch.mul(sun_imgs_y_qry, weight_map)

            mse_loss = self.mse_loss_fn(restoration_imgs_pred, gt_imgs)
            ssim_loss = self.ssim_loss_fc(restoration_imgs_pred, gt_imgs)
            loss = torch.mean(mse_loss) + alpha * (
                1 - torch.mean(ssim_loss))  #+ REGULARIZATION * reg_loss
            loss_q[1] += loss

        del grad, fast_weights, loss, mse_loss, ssim_loss, weight_map

        for k in range(1, self.update_step_test):

            reflection_pred = self.net(sun_imgs_x_spt)
            reflection_pred = reflection_pred + epsilon
            reflection_pred = torch.clamp(
                reflection_pred, 0.1,
                5.0)  # Peter: may be we can try to remove this item
            restoration_imgs_pred = torch.zeros(
                *sun_imgs_x_spt[:, :, :, :].shape).to(self.device)
            restoration_imgs_pred[:, 0, :, :] = torch.div(
                sun_imgs_x_spt[:, 0, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 1, :, :] = torch.div(
                sun_imgs_x_spt[:, 1, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred[:, 2, :, :] = torch.div(
                sun_imgs_x_spt[:, 2, :, :], reflection_pred[:, 0, :, :])
            restoration_imgs_pred = torch.clamp(reflection_pred, 0, 1)

            weight_map = self.contour_loss.forward(sun_imgs_y_spt)
            restoration_imgs_pred = torch.mul(restoration_imgs_pred,
                                              weight_map)
            gt_imgs = torch.mul(sun_imgs_y_spt, weight_map)

            mse_loss = self.mse_loss_fn(restoration_imgs_pred, gt_imgs)
            ssim_loss = self.ssim_loss_fc(restoration_imgs_pred, gt_imgs)
            loss = torch.mean(mse_loss) + alpha * (
                1 - torch.mean(ssim_loss))  #+ REGULARIZATION * reg_loss

            grad = torch.autograd.grad(loss, self.parameters())
            fast_weights = list(
                map(lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, self.net.parameters())))

            # this is the loss and accuracy before first update
            for each_fastweight, (name,
                                  param) in zip(fast_weights,
                                                self.net.named_parameters()):
                param = each_fastweight

            with torch.no_grad():
                reflection_pred = self.net(sun_imgs_x_qry)
                reflection_pred = reflection_pred + epsilon
                reflection_pred = torch.clamp(
                    reflection_pred, 0.1,
                    5.0)  # Peter: may be we can try to remove this item
                restoration_imgs_pred = torch.zeros(
                    *sun_imgs_x_qry[:, :, :, :].shape).to(self.device)
                restoration_imgs_pred[:, 0, :, :] = torch.div(
                    sun_imgs_x_qry[:, 0, :, :], reflection_pred[:, 0, :, :])
                restoration_imgs_pred[:, 1, :, :] = torch.div(
                    sun_imgs_x_qry[:, 1, :, :], reflection_pred[:, 0, :, :])
                restoration_imgs_pred[:, 2, :, :] = torch.div(
                    sun_imgs_x_qry[:, 2, :, :], reflection_pred[:, 0, :, :])
                restoration_imgs_pred = torch.clamp(reflection_pred, 0, 1)

                weight_map = self.contour_loss.forward(sun_imgs_y_qry)
                restoration_imgs_pred = torch.mul(restoration_imgs_pred,
                                                  weight_map)
                gt_imgs = torch.mul(sun_imgs_y_qry, weight_map)

                mse_loss = self.mse_loss_fn(restoration_imgs_pred, gt_imgs)
                ssim_loss = self.ssim_loss_fc(restoration_imgs_pred, gt_imgs)
                loss = torch.mean(mse_loss) + alpha * (
                    1 - torch.mean(ssim_loss))  #+ REGULARIZATION * reg_loss
                loss_q[k + 1] += loss

            del grad, fast_weights, loss, mse_loss, ssim_loss, weight_map

        # this is the loss and accuracy before first update
        for net_param, (name, param) in zip(self.net_param,
                                            self.net.named_parameters()):
            param = net_param

        # end of all tasks
        # sum over all losses on query set across all tasks
        loss_q = torch.sum(torch.stack(loss_q))

        torch.cuda.empty_cache()

        return loss_q