Example No. 1
        adjust_learning_rate(optimizer, epoch)

        for batch_id, train_data in enumerate(train_dataloader):
            cloud, gt = train_data
            optimizer.zero_grad()
            cloud = cloud.to(device)
            gt = gt.to(device)
            # Both quarter-scale targets are built with the same 0.25 factor.
            gt_quarter_1 = F.interpolate(gt, scale_factor=0.25, recompute_scale_factor=True)
            gt_quarter_2 = F.interpolate(gt, scale_factor=0.25, recompute_scale_factor=True)

            if train_phrase == 1:
                decloud_1, feat_extra_1 = net(cloud)
                rec_loss1 = loss_rec1(decloud_1, gt)
                perceptual_loss = loss_network(decloud_1, gt)
                lap_loss = loss_lap(decloud_1, gt)
                psnr = to_psnr(decloud_1, gt)
                psnr_list.extend(psnr)

            if train_phrase == 2:
                decloud_1, feat_extra_1 = net(cloud)
                decloud_2, feat_extra_2 = G2(decloud_1)
                rec_loss1 = (loss_rec1(decloud_2, gt) + loss_rec1(decloud_1, gt)) / 2.0
                rec_loss2 = loss_rec2(decloud_2, gt)
                perceptual_loss = loss_network(decloud_2, gt)
                lap_loss = loss_lap(decloud_2, gt)
                psnr = to_psnr(decloud_2, gt)
                psnr_list.extend(psnr)

            if train_phrase == 3:
                decloud_1, feat_extra_1 = net(F.interpolate(cloud, scale_factor=0.25, recompute_scale_factor=True))
                decloud_2, feat, feat_extra_2 = G2(decloud_1)
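Every loop in these examples relies on a to_psnr helper that returns one PSNR value per image in the batch (note psnr_list.extend, not append). The source's implementation is not shown; a minimal sketch, assuming tensors shaped (N, C, H, W) and normalized to [0, 1]:

import torch
import torch.nn.functional as F

def to_psnr(pred, gt):
    # Per-image MSE: average over channels and spatial dims, keep the batch dim.
    mse = F.mse_loss(pred, gt, reduction='none')
    mse_per_image = mse.reshape(mse.shape[0], -1).mean(dim=1)
    # PSNR in dB with peak value 1.0 (intensities assumed in [0, 1]).
    return [10.0 * torch.log10(1.0 / m).item() for m in mse_per_image]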
Example No. 2
        # --- Zero the parameter gradients --- #
        optimizer.zero_grad()

        # --- Forward + Backward + Optimize --- #
        net.train()
        dehaze = net(haze)

        smooth_loss = F.smooth_l1_loss(dehaze, gt)
        perceptual_loss = loss_network(dehaze, gt)
        loss = smooth_loss + lambda_loss*perceptual_loss

        loss.backward()
        optimizer.step()

        # --- To calculate average PSNR --- #
        psnr_list.extend(to_psnr(dehaze, gt))

        if not (batch_id % 100):
            print('Epoch: {0}, Iteration: {1}'.format(epoch, batch_id))

    # --- Calculate the average training PSNR in one epoch --- #
    train_psnr = sum(psnr_list) / len(psnr_list)

    # --- Save the network parameters --- #
    torch.save(net.state_dict(), '{}_haze_{}_{}'.format(category, network_height, network_width))

    # --- Use the evaluation model in testing --- #
    net.eval()

    val_psnr, val_ssim = validation(net, val_data_loader, device, category)
    one_epoch_time = time.time() - start_time
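Example 2 (and Example 3 after it) balances a pixel loss against a perceptual loss computed by loss_network. A common construction compares frozen VGG-16 feature maps; a sketch under that assumption (the layer indices, the torchvision 0.13+ weights API, and the L1 distance are choices made here, not the source's exact configuration):

import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class LossNetwork(nn.Module):
    # Perceptual loss: L1 distance between frozen VGG-16 feature maps.
    def __init__(self, layers=(3, 8, 15)):  # relu1_2, relu2_2, relu3_3
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features.eval()
        for p in vgg.parameters():
            p.requires_grad = False  # the loss network is never trained
        self.vgg = vgg
        self.layers = set(layers)

    def forward(self, pred, gt):
        loss, x, y = 0.0, pred, gt
        for i, layer in enumerate(self.vgg):
            x, y = layer(x), layer(y)
            if i in self.layers:
                loss = loss + F.l1_loss(x, y)
            if i >= max(self.layers):
                break  # no need to run deeper layers
        return loss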
Example No. 3
        # --- Forward + Backward + Optimize --- #
        net.train()
        pred_image, zy_in = net(input_image)

        smooth_loss = F.smooth_l1_loss(pred_image, gt)
        perceptual_loss = loss_network(pred_image, gt)
        gp_loss = 0
        if lambgp != 0 and use_GP_inlblphase:
            gp_loss = gp_struct.compute_gploss(zy_in, imgid, batch_id, 1)
        loss = smooth_loss + lambda_loss * perceptual_loss + lambgp * gp_loss

        loss.backward()
        optimizer.step()

        # --- To calculate average PSNR --- #
        psnr_list.extend(to_psnr(pred_image, gt))

        if not (batch_id % 100):
            print('Epoch: {0}, Iteration: {1}'.format(epoch, batch_id))

    # --- Calculate the average training PSNR in one epoch --- #
    train_psnr = sum(psnr_list) / len(psnr_list)

    # --- Save the network parameters --- #
    torch.save(net.state_dict(), './{}/{}'.format(exp_name, category))

    # --- Use the evaluation model in testing --- #
    net.eval()

    val_psnr, val_ssim = validation(net, val_data_loader, device, category,
                                    exp_name)
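Both loops hand the model to a validation routine after switching to eval mode. What that routine typically does is sketched below; the signature and the to_ssim helper are illustrative (the source's validation also takes category and exp_name, presumably for saving outputs):

import torch

def validation(net, val_data_loader, device):
    # Average PSNR/SSIM over the validation set, without gradient tracking.
    psnr_list, ssim_list = [], []
    net.eval()
    with torch.no_grad():
        for haze, gt, _ in val_data_loader:
            haze, gt = haze.to(device), gt.to(device)
            dehaze = net(haze)
            psnr_list.extend(to_psnr(dehaze, gt))
            ssim_list.extend(to_ssim(dehaze, gt))  # hypothetical SSIM helper
    return sum(psnr_list) / len(psnr_list), sum(ssim_list) / len(ssim_list)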
Example No. 4
                total_loss = smooth_loss + lambda_loss * perceptual_loss
            else:
                for i in range(args.levels):
                    _, _, hi, wi = dehaze[i].size()
                    gt_img = F.interpolate(gt, size=[hi, wi])
                    smooth_loss = F.smooth_l1_loss(dehaze[i], gt_img)
                    perceptual_loss = loss_network(dehaze[i], gt_img)
                    loss.append(smooth_loss + lambda_loss * perceptual_loss)
                    total_loss += loss[i]

            total_loss.backward()
            optimizer.step()

            if not (batch_id % 400):
                print('Epoch: {0}, Iteration: {1}'.format(epoch, batch_id))
                print('total_loss = {0}'.format(total_loss))
                print("PSNR: ", end=" ")
                print('coarse: ', to_psnr(coarse_out, gt))
                print('fine : ', to_psnr(dehaze, gt))

    # --- Save a checkpoint every 10 epochs --- #
    epoch += 1
    if epoch % 10 == 0:
        state = {'net': net.state_dict(), 'epoch': epoch}
        # Both branches of the original fine_share check saved to the same
        # path, so a single save is equivalent.
        torch.save(state,
                   './checkpoint/checkpoint_K12_epoch_{}'.format(epoch))
Example No. 5
        # --- Forward + Backward + Optimize --- #
        net.train()
        _, J, T, A, I = net(haze)
        #s, v = get_SV_from_HSV(J)
        #CAP_loss = F.smooth_l1_loss(s, v)
        Rec_Loss1 = F.smooth_l1_loss(J, gt)
        Rec_Loss2 = F.smooth_l1_loss(I, haze)

        #perceptual_loss = loss_network(dehaze, gt)
        loss = Rec_Loss1 + Rec_Loss2

        loss.backward()
        optimizer.step()

        # --- To calculate average PSNR --- #
        psnr_list.extend(to_psnr(J, gt))

        #if not (batch_id % 100):
        print('Epoch: {}, Iteration: {}, Loss: {}, Rec_Loss1: {}, Rec_Loss2: {}'
              .format(epoch, batch_id, loss, Rec_Loss1, Rec_Loss2))

    # --- Calculate the average training PSNR in one epoch --- #
    train_psnr = sum(psnr_list) / len(psnr_list)

    # --- Save the network parameters --- #
    torch.save(net.state_dict(), '/output/haze_current{}'.format(epoch))

    # --- Use the evaluation model in testing --- #
    net.eval()
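The checkpoint block in Example 4 stores the state dict together with the epoch under the 'net' and 'epoch' keys. Restoring such a checkpoint is the mirror image; a brief sketch (path illustrative, net as in the example):

import torch

checkpoint = torch.load('./checkpoint/checkpoint_K12_epoch_10', map_location='cpu')
net.load_state_dict(checkpoint['net'])   # weights stored under the 'net' key
start_epoch = checkpoint['epoch']        # resume the epoch counter from here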
Example No. 6
def train():
    device_ids = [0]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("CUDA visible devices: " + str(torch.cuda.device_count()))
    print("CUDA Device Name: " + str(torch.cuda.get_device_name(device)))

    # Initialize loss and model
    loss = ms_Loss().to(device)
    net = AWNet(4, 3, block=[3, 3, 3, 4, 4]).to(device)
    net = nn.DataParallel(net, device_ids=device_ids)
    new_lr = trainConfig.learning_rate[0]

    # Reload
    if trainConfig.pretrain:
        net.load_state_dict(
            torch.load(
                '{}/best_4channel.pkl'.format(trainConfig.save_best),
                map_location=device)["model_state"])
        print('weight loaded.')
    else:
        print('no weight loaded.')
    pytorch_total_params = sum(
        p.numel() for p in net.parameters() if p.requires_grad)
    print("Total_params: {}".format(pytorch_total_params))

    # optimizer and scheduler
    optimizer = torch.optim.Adam(
        net.parameters(), lr=new_lr, betas=(0.9, 0.999))

    # Dataloaders
    train_dataset = LoadData(
        trainConfig.data_dir, TRAIN_SIZE, dslr_scale=1, test=False)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=trainConfig.batch_size,
        shuffle=True,
        num_workers=32,
        pin_memory=True,
        drop_last=True)

    test_dataset = LoadData(
        trainConfig.data_dir, TEST_SIZE, dslr_scale=1, test=True)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=8,
        shuffle=False,
        num_workers=18,
        pin_memory=True,
        drop_last=False)

    print('Train loader length: {}'.format(len(train_loader)))

    pre_psnr, pre_ssim = validation(net, test_loader, device, save_tag=True)
    print('previous PSNR: {:.4f}, previous SSIM: {:.4f}'.format(
        pre_psnr, pre_ssim))
    iteration = 0
    for epoch in range(trainConfig.epoch):
        psnr_list = []
        start_time = time.time()
        if epoch > 0:
            new_lr = adjust_learning_rate_step(
                optimizer, epoch, trainConfig.epoch, trainConfig.learning_rate)
        for batch_id, data in enumerate(train_loader):
            x, target, _ = data
            x = x.to(device)
            target = target.to(device)
            pred, _ = net(x)

            optimizer.zero_grad()

            total_loss, losses = loss(pred, target)
            total_loss.backward()
            optimizer.step()

            iteration += 1
            if trainConfig.print_loss:
                print("epoch:{}/{} | Loss: {:.4f} ".format(
                    epoch, trainConfig.epoch, total_loss.item()))
            if not (batch_id % 1000):
                print('Epoch:{0}, Iteration:{1}'.format(epoch, batch_id))

            psnr_list.extend(to_psnr(pred[0], target))

        train_psnr = sum(psnr_list) / len(psnr_list)
        state = {
            "model_state": net.state_dict(),
            "lr": new_lr,
        }
        print('saved checkpoint')
        torch.save(state, '{}/four_channel_epoch_{}.pkl'.format(
            trainConfig.checkpoints, epoch))

        one_epoch_time = time.time() - start_time
        print('time: {}, train psnr: {}'.format(one_epoch_time, train_psnr))
        val_psnr, val_ssim = validation(
            net, test_loader, device, save_tag=True)
        print_log(epoch + 1, trainConfig.epoch, one_epoch_time, train_psnr,
                  val_psnr, val_ssim, 'multi_loss')

        if val_psnr >= pre_psnr:
            state = {
                "model_state": net.state_dict(),
                "lr": new_lr,
            }

            print('saved best weight')
            torch.save(state, '{}/best_4channel.pkl'.format(
                trainConfig.save_best))
            pre_psnr = val_psnr
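Example 6 lowers the learning rate once per epoch via adjust_learning_rate_step(optimizer, epoch, trainConfig.epoch, trainConfig.learning_rate), where learning_rate is a list of rates. The exact schedule is not shown; a plausible sketch, assuming equal-length stages with one rate per stage:

def adjust_learning_rate_step(optimizer, epoch, num_epochs, learning_rate):
    # Piecewise-constant schedule: split training into len(learning_rate)
    # stages and apply that stage's rate to every parameter group.
    stage_length = max(1, num_epochs // len(learning_rate))
    stage = min(epoch // stage_length, len(learning_rate) - 1)
    new_lr = learning_rate[stage]
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr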
Example No. 7
        haze_gt = haze_gt.to(device)
        gt_quarter_1 = F.interpolate(gt,
                                     scale_factor=0.25,
                                     recompute_scale_factor=True)
        gt_quarter_2 = F.interpolate(gt,
                                     scale_factor=0.25,
                                     recompute_scale_factor=True)

        # --- Forward + Backward + Optimize --- #

        if train_phrase == 1:
            dehaze_1, feat_extra_1 = net(haze)
            rec_loss1 = loss_rec1(dehaze_1, gt)
            perceptual_loss = loss_network(dehaze_1, gt)
            lap_loss = loss_lap(dehaze_1, gt)
            psnr = to_psnr(dehaze_1, gt)
            psnr_list.extend(to_psnr(dehaze_1, gt))
            train_info = to_psnr(dehaze_1, gt)
        if train_phrase == 2:
            dehaze_1, feat_extra_1 = net(haze)
            dehaze_2, feat, feat_extra_2 = G2(dehaze_1)
            rec_loss1 = (loss_rec1(dehaze_2, gt) +
                         loss_rec1(dehaze_1, gt)) / 2.0
            rec_loss2 = loss_rec2(dehaze_2, gt)
            perceptual_loss = loss_network(dehaze_2, gt)
            lap_loss = loss_lap(dehaze_2, gt)
            psnr = to_psnr(dehaze_2, gt)
            psnr_list.extend(to_psnr(dehaze_2, gt))
            train_info = to_psnr(dehaze_2, gt)
        if train_phrase == 3:
            dehaze_1, feat_extra_1 = net(
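The snippet cuts off inside the phase-3 branch, but Example 8's phase-3 inference path suggests how it continues: the first generator runs on a quarter-resolution input, the refined result is upsampled by 4x, and a full-resolution generator fuses it with the original input. A sketch under that assumption (net, G2, G3, and haze as in the example):

import torch.nn.functional as F

haze_quarter = F.interpolate(haze, scale_factor=0.25, recompute_scale_factor=True)
dehaze_1, feat_extra_1 = net(haze_quarter)        # coarse pass at 1/4 resolution
dehaze_2, feat, feat_extra_2 = G2(dehaze_1)       # quarter-resolution refinement
dehaze_2_full = F.interpolate(dehaze_2, scale_factor=4,
                              recompute_scale_factor=True)
dehaze_3 = G3(haze, dehaze_2_full, feat)          # full-resolution reconstruction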
Example No. 8
def main(test_img, test_gt, test_phrase, test_epoch):
    device_ids = [Id for Id in range(torch.cuda.device_count())]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    test_data_dir = test_img
    test_data_gt = test_gt
    test_batch_size = 1
    network_height = 3
    network_width = 6
    num_dense_layer = 4
    growth_rate = 16
    crop_size = [1600, 1200]
    test_data = data(test_data_dir, test_data_gt)
    test_data_loader = DataLoader(test_data, batch_size=test_batch_size)

    def save_image(dehaze, image_name, category):
        #dehaze_images = torch.split(dehaze, 1, dim=0)
        batch_num = len(dehaze)

        for ind in range(batch_num):
            utils.save_image(
                dehaze[ind],
                '{}_results/{}'.format(category, image_name[ind][:-3] + 'png'))

    if test_phrase == 1:
        G1 = Generate_quarter(height=network_height,
                              width=network_width,
                              num_dense_layer=num_dense_layer,
                              growth_rate=growth_rate)
        G1 = G1.to(device)
        G1 = nn.DataParallel(G1, device_ids=device_ids)
        G1.load_state_dict(
            torch.load('./checkpoint/1_' + str(test_epoch) + '.tar'))
        G1.eval()
        psnr = []
        net_time = 0.
        net_count = 0.
        for batch_id, test_data in enumerate(test_data_loader):
            with torch.no_grad():
                haze, gt, image_name = test_data
                #haze = F.interpolate(haze, scale_factor = 0.25)
                haze = haze.to(device)
                gt = gt.to(device)
                start_time = time.time()
                dehaze, _ = G1(haze)
                end_time = time.time() - start_time
                net_time += end_time
                net_count += 1
                test_info = to_psnr_test(dehaze, gt)
                psnr.append(sum(test_info) / len(test_info))
                print(sum(test_info) / len(test_info))
            # --- Save image --- #
            save_image(dehaze, image_name, 'NH')
        test_psnr = sum(psnr) / len(psnr)
        print('Test PSNR:' + str(test_psnr))
        print('net time is {0:.4f}'.format(net_time / net_count))

    if test_phrase == 2:
        G1 = Generate_quarter(height=network_height,
                              width=network_width,
                              num_dense_layer=num_dense_layer,
                              growth_rate=growth_rate)
        G1 = G1.to(device)
        G1 = nn.DataParallel(G1, device_ids=device_ids)
        #G1.load_state_dict(torch.load('./checkpoint/1.tar'))
        G2 = Generate_quarter_refine(height=network_height,
                                     width=network_width,
                                     num_dense_layer=num_dense_layer,
                                     growth_rate=growth_rate)
        G2 = G2.to(device)
        G2 = nn.DataParallel(G2, device_ids=device_ids)
        G1.load_state_dict(
            torch.load('./checkpoint/2-' + str(test_epoch) + '_G1.tar'))

        G2.load_state_dict(
            torch.load('./checkpoint/2_' + str(test_epoch) + '_G2.tar'))
        G1.eval()
        G2.eval()
        psnr = []
        net_time = 0.
        net_count = 0.
        for batch_id, test_data in enumerate(test_data_loader):
            with torch.no_grad():
                haze, gt, image_name = test_data
                #haze = F.interpolate(haze, scale_factor = 0.25,recompute_scale_factor=True)
                haze = haze.to(device)
                gt = gt.to(device)
                start_time = time.time()
                dehaze_1, feat1 = G1(haze)
                dehaze, _, _ = G2(dehaze_1)
                end_time = time.time() - start_time
                net_time += end_time
                net_count += 1
                test_info = to_psnr_test(dehaze, gt)
                psnr.append(sum(test_info) / len(test_info))
                print(sum(test_info) / len(test_info))
            # --- Save image --- #
            save_image(dehaze, image_name, 'NH')
        test_psnr = sum(psnr) / len(psnr)
        print('Test PSNR:' + str(test_psnr))
        print('net time is {0:.4f}'.format(net_time / net_count))

    if test_phrase == 3:
        G1 = Generate_quarter(height=network_height,
                              width=network_width,
                              num_dense_layer=num_dense_layer,
                              growth_rate=growth_rate)
        G1 = G1.to(device)
        G1 = nn.DataParallel(G1, device_ids=device_ids)
        G1.load_state_dict(
            torch.load('./checkpoint/3-' + str(test_epoch) + '_G1.tar'))
        G2 = Generate_quarter_refine(height=network_height,
                                     width=network_width,
                                     num_dense_layer=num_dense_layer,
                                     growth_rate=growth_rate)
        G2 = G2.to(device)
        G2 = nn.DataParallel(G2, device_ids=device_ids)
        G2.load_state_dict(
            torch.load('./checkpoint/3_' + str(test_epoch) + '_G2.tar'))
        G3 = Generate(height=network_height,
                      width=network_width,
                      num_dense_layer=num_dense_layer,
                      growth_rate=growth_rate)
        G3 = G3.to(device)
        G3 = nn.DataParallel(G3, device_ids=device_ids)
        G3.load_state_dict(
            torch.load('./checkpoint/33_' + str(test_epoch) + '_G3.tar'))
        G1.eval()
        G2.eval()
        G3.eval()
        psnr = []
        net_time = 0.
        net_count = 0.
        for batch_id, test_data in enumerate(test_data_loader):
            with torch.no_grad():
                haze, gt, image_name = test_data
                haze = haze.to(device)
                gt = gt.to(device)
                start_time = time.time()
                dehaze_1, feat1 = G1(
                    F.interpolate(haze,
                                  scale_factor=0.25,
                                  recompute_scale_factor=True))
                dehaze_2, feat, feat2 = G2(dehaze_1)
                dehaze = G3(
                    haze,
                    F.interpolate(dehaze_2,
                                  scale_factor=4,
                                  recompute_scale_factor=True), feat)
                end_time = time.time() - start_time
                net_time += end_time
                net_count += 1
                test_info = to_psnr(dehaze, gt)
                psnr.append(sum(test_info) / len(test_info))
                print(sum(test_info) / len(test_info))
            # --- Save image --- #
            save_image(dehaze, image_name, 'NH')
        test_psnr = sum(psnr) / len(psnr)
        print('Test PSNR:' + str(test_psnr))
        print('net time is {0:.4f}'.format(net_time / net_count))
    return test_psnr
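A minimal invocation of this test entry point could look as follows (paths and epoch are illustrative):

test_psnr = main(test_img='./data/test/hazy',
                 test_gt='./data/test/gt',
                 test_phrase=3,
                 test_epoch=20)
print('Final test PSNR: {:.2f} dB'.format(test_psnr))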