Example #1
    def train(self):
        self.scheduler.step()
        # self.loss.step()    # edit for new
        epoch = self.scheduler.last_epoch + 1
        lr = self.scheduler.get_lr()[0]

        self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(
            epoch, Decimal(lr)))
        # self.loss.start_log() # edit for new
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        for batch, (lr, hr, _) in enumerate(self.loader_train):
            # NOTE: `lr` is the low-res batch here, shadowing the learning
            # rate read above; it is overwritten with a noisy image below
            lr, hr = self.prepare([lr, hr])

            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            # synthesize the noisy input: additive white Gaussian noise with
            # sigma = noiseL / 255 added to the clean batch
            noise = torch.FloatTensor(hr.size()).normal_(
                mean=0, std=self.args.noiseL / 255.).cuda()
            lr = hr + noise
            sr = self.model(lr, 1)  # scale fixed to 1 here instead of self.args.scale
            loss = self.loss(sr, hr) / (hr.size()[0] * 2)
            # print(loss,(hr.size()[0]*2))
            psnr_train = batch_PSNR(sr, hr, 1.)
            psnr_org = batch_PSNR(lr, hr, 1.)
            if loss.item() < self.args.skip_threshold * self.error_last:
                loss.backward()
                self.optimizer.step()
            else:
                print('Skip this batch {}! (Loss: {})'.format(
                    batch + 1,
                    loss.item()  # .data[0] in older PyTorch versions
                ))

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log(
                    '[{}/{}]\t{}\t{:.4f}\t{:.4f}\t{:.1f}+{:.1f}s'.format(
                        (batch + 1) * self.args.batch_size,
                        len(self.loader_train.dataset),
                        loss.item(),  #display_loss(batch),   # edit for new
                        psnr_train,
                        psnr_org,
                        timer_model.release(),
                        timer_data.release()))

            timer_data.tic()
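
Every example on this page calls a batch_PSNR helper that the snippets do not
define. A minimal sketch of what such a helper typically looks like, assuming
skimage's peak_signal_noise_ratio (the exact signature varies across these
repositories; Example #3, for instance, also takes a ycbcr flag):

import numpy as np
from skimage.metrics import peak_signal_noise_ratio

def batch_PSNR(img, imclean, data_range):
    # average PSNR over a batch of N x C x H x W tensors
    img = img.data.cpu().numpy().astype(np.float32)
    imclean = imclean.data.cpu().numpy().astype(np.float32)
    psnr = 0.
    for i in range(img.shape[0]):
        psnr += peak_signal_noise_ratio(imclean[i], img[i],
                                        data_range=data_range)
    return psnr / img.shape[0]
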
Example #2
def evaluate(model, dataset_val):
    """

    :param model: network model
    :param dataset_val:
    :return: psnr_val
    """
    with torch.no_grad():
        model.eval()
        # validate
        psnr_val = 0
        for k in range(len(dataset_val)):
            img_val = torch.unsqueeze(dataset_val[k], 0)
            torch.manual_seed(0)  # fixed seed keeps the validation noise reproducible
            noise = torch.FloatTensor(img_val.size()).normal_(mean=0,
                                                              std=opt.noiseL /
                                                              255.)

            imgn_val = img_val + noise
            img_val, imgn_val = img_val.cuda(), imgn_val.cuda()
            out_val = model(imgn_val)
            out_val = torch.clamp(out_val, 0., 1.)
            psnr_val += batch_PSNR(out_val, img_val, 1.)
        psnr_val /= len(dataset_val)
    return psnr_val
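
A typical call site for evaluate, assuming a training loop with a TensorBoard
writer (the epoch and writer names here are illustrative):

psnr_val = evaluate(model, dataset_val)
print('[epoch %d] PSNR_val: %.4f' % (epoch + 1, psnr_val))
writer.add_scalar('PSNR on validation data', psnr_val, epoch)
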
Example #3
    def predict_deraining(self):
        # Deraining
        self.DNet.eval()

        current_data_list = ([self.test_data, self.test_data_semi]
                             if self.train_path_semi else [self.test_data])
        for kk, current_data in enumerate(current_data_list):
            num_frame = current_data.shape[1]
            test_data_derain = torch.zeros(current_data.shape)  # c x n x p x p
            # derain the sequence in chunks of truncate_test frames
            for ii in range(ceil(num_frame / self.truncate_test)):
                start_ind = ii * self.truncate_test
                end_ind = min((ii + 1) * self.truncate_test, num_frame)
                inputs = current_data[:, start_ind:end_ind, ].cuda()  # c x truncate x p x p
                with torch.set_grad_enabled(False):
                    out = self.DNet(inputs.unsqueeze(0)).clamp_(0.0,
                                                                1.0).squeeze(0)
                test_data_derain[:, start_ind:end_ind, ] = out.cpu()

                # log sample frames: always for the semi-supervised split,
                # otherwise with probability 1/10
                if (len(current_data_list) == 2 and kk == 1) or \
                        random.randint(1, 10) == 1:
                    x1 = vutils.make_grid(inputs.permute([1, 0, 2, 3]),
                                          normalize=True,
                                          scale_each=True)
                    self.writer.add_image('Test Rainy Image', x1,
                                          self.log_im_step['test'])
                    x2 = vutils.make_grid(out.permute([1, 0, 2, 3]),
                                          normalize=True,
                                          scale_each=True)
                    self.writer.add_image('Test Derained Image', x2,
                                          self.log_im_step['test'])
                    self.log_im_step['test'] += 1

            if kk == 0:
                # evaluate on the supervised split, dropping two boundary
                # frames at each end of the sequence
                self.psnrm = batch_PSNR(
                    test_data_derain[:, 2:-2, ].permute([1, 0, 2, 3]),
                    self.test_gt[:, 2:-2, ].permute([1, 0, 2, 3]),
                    ycbcr=False)
                self.ssimm = batch_SSIM(
                    test_data_derain[:, 2:-2, ].permute([1, 0, 2, 3]),
                    self.test_gt[:, 2:-2, ].permute([1, 0, 2, 3]),
                    ycbcr=False)
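
batch_SSIM is not shown either; a sketch in the same spirit as batch_PSNR,
assuming skimage's structural_similarity (channel_axis requires skimage >=
0.19) and omitting the YCbCr conversion that the ycbcr flag implies:

import numpy as np
from skimage.metrics import structural_similarity

def batch_SSIM(img, imclean, ycbcr=False):
    # average SSIM over a batch of N x C x H x W tensors; the ycbcr branch
    # (compare luma after an RGB -> YCbCr conversion) is omitted in this sketch
    img = img.data.cpu().numpy().astype(np.float32)
    imclean = imclean.data.cpu().numpy().astype(np.float32)
    ssim = 0.
    for i in range(img.shape[0]):
        ssim += structural_similarity(imclean[i], img[i],
                                      channel_axis=0, data_range=1.)
    return ssim / img.shape[0]
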
Example #4
    def valid_step(image, label):
        # MegEngine code: F is megengine.functional, and the model predicts a
        # residual that is subtracted from the noisy input
        pred = model(image)
        pred = image - pred
        mae_iter = F.nn.l1_loss(pred, label)
        psnr_it = batch_PSNR(pred, label)
        # print(psnr_it.item())
        if world_size > 1:
            # average the metrics across all workers
            mae_iter = F.distributed.all_reduce_sum(mae_iter) / world_size
            psnr_it = F.distributed.all_reduce_sum(psnr_it) / world_size

        return mae_iter, psnr_it
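
A hedged sketch of how these per-iteration metrics might be aggregated over a
validation set; valid_queue is an assumed DataLoader-like iterable:

mae_sum, psnr_sum, iters = 0.0, 0.0, 0
for image, label in valid_queue:  # valid_queue is an assumption
    mae_iter, psnr_it = valid_step(image, label)
    mae_sum += float(mae_iter)
    psnr_sum += float(psnr_it)
    iters += 1
print('MAE: %.4f  PSNR: %.2f' % (mae_sum / iters, psnr_sum / iters))
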
Example #5
        # skip batches with too few valid disparity pixels
        if torch.sum(mask) < 256:
            continue

        # -----------------
        #  Train model
        # -----------------

        optimizer.zero_grad()

        # Generate a batch of images
        pred1, pred2, pred3, w_imgL_o, result, _ = model(
            left, right, left_g, right_g, left_o, right_o)

        # calculate PSNR
        PSNR = batch_PSNR(torch.clamp(result, 0., 1.), right_gt, 1.)

        # calculate disp ME
        ME = batch_me(pred3[mask], disp[mask])

        # loss aggregation
        g_loss = L1loss(result[:, :, :, :CROP_WIDTH // 2],
                        right_gt[:, :, :, :CROP_WIDTH // 2])
        g_loss += L1loss(pred3[mask], disp[mask])
        # g_loss += L1loss(w_imgL_o[:,:,:,:CROP_WIDTH//2], warped_gt[:,:,:,:CROP_WIDTH//2])

        # accumulate epoch statistics, skipping degenerate batches whose PSNR
        # blows up when prediction and target are nearly identical
        if PSNR < 100:
            e_loss += g_loss.item()
            e_PSNR += PSNR
            e_ME += ME
            count += 1
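
This snippet is only the body of a per-batch training loop; the e_*
accumulators and count would be initialized before the loop and reduced into
epoch averages afterwards, roughly like this (loader is an assumed name):

e_loss, e_PSNR, e_ME, count = 0.0, 0.0, 0.0, 0
for batch_data in loader:  # the loop whose body is shown above
    ...  # per-batch body from the snippet
if count > 0:
    print('epoch loss: %.4f  PSNR: %.2f  ME: %.4f' %
          (e_loss / count, e_PSNR / count, e_ME / count))
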
Example #6
def main():
    print('Loading dataset ...\n')
    dataset_train = Dataset(data_path=opt.data_path)
    loader_train = DataLoader(dataset=dataset_train,
                              num_workers=4,
                              batch_size=opt.batch_size,
                              shuffle=True)
    print("# of training samples: %d\n" % int(len(loader_train)))
    # Build model
    model = NET(input_channel=32)
    # print_network(model)

    # loss function
    criterion = SSIM()
    criterion1 = nn.L1Loss()
    # Move to GPU
    if opt.use_GPU:
        model = model.cuda()
        criterion.cuda()
        criterion1.cuda()

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    scheduler = MultiStepLR(optimizer, milestones=opt.milestone, gamma=0.2)
    # record training
    writer = SummaryWriter(opt.save_path)

    # load the lastest model
    initial_epoch = findLastCheckpoint(save_dir=opt.save_path)
    if initial_epoch > 0:
        print('resuming by loading epoch %d' % initial_epoch)
        model.load_state_dict(
            torch.load(
                os.path.join(opt.save_path,
                             'net_epoch%d.pth' % initial_epoch)))

    # start training
    step = 0
    for epoch in range(initial_epoch, opt.epochs):
        scheduler.step(epoch)  # legacy epoch argument; recent PyTorch expects scheduler.step() with no args
        for param_group in optimizer.param_groups:
            print('learning rate %f' % param_group["lr"])

        # epoch training start
        for i, (input_train, target_train) in enumerate(loader_train, 0):
            model.train()
            model.zero_grad()
            optimizer.zero_grad()

            # Variable is a no-op wrapper in PyTorch >= 0.4, kept for compatibility
            input_train, target_train = Variable(input_train), Variable(
                target_train)

            if opt.use_GPU:
                input_train, target_train = input_train.cuda(
                ), target_train.cuda()

            out_train = model(input_train)
            pixel_metric = criterion(target_train, out_train)  # SSIM
            loss2 = criterion1(target_train, out_train)  # L1
            loss = 1 - pixel_metric + loss2  # maximize SSIM while minimizing L1

            loss.backward()
            optimizer.step()

            # compute training PSNR without building a graph
            model.eval()
            with torch.no_grad():
                out_train = torch.clamp(model(input_train), 0., 1.)
            psnr_train = batch_PSNR(out_train, target_train, 1.)
            print(
                "[epoch %d][%d/%d] loss: %.4f, pixel_metric: %.4f, PSNR: %.4f"
                % (epoch + 1, i + 1, len(loader_train), loss.item(),
                   pixel_metric.item(), psnr_train))

            if step % 10 == 0:
                # Log the scalar values
                writer.add_scalar('loss', loss.item(), step)
                writer.add_scalar('PSNR on training data', psnr_train, step)
            step += 1
        # epoch training end

        # log the images
        model.eval()
        with torch.no_grad():
            out_train = torch.clamp(model(input_train), 0., 1.)
        im_target = utils.make_grid(target_train.data,
                                    nrow=8,
                                    normalize=True,
                                    scale_each=True)
        im_input = utils.make_grid(input_train.data,
                                   nrow=8,
                                   normalize=True,
                                   scale_each=True)
        im_derain = utils.make_grid(out_train.data,
                                    nrow=8,
                                    normalize=True,
                                    scale_each=True)
        writer.add_image('clean image', im_target, epoch + 1)
        writer.add_image('rainy image', im_input, epoch + 1)
        writer.add_image('deraining image', im_derain, epoch + 1)

        torch.save(model.state_dict(),
                   os.path.join(opt.save_path, 'net_latest.pth'))

        if epoch % opt.save_freq == 0:
            torch.save(
                model.state_dict(),
                os.path.join(opt.save_path, 'net_epoch%d.pth' % (epoch + 1)))
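
findLastCheckpoint is not defined in the snippet; the usual implementation
scans the save directory for net_epoch*.pth files and returns the largest
epoch number (a sketch assuming that naming scheme):

import glob
import os
import re

def findLastCheckpoint(save_dir):
    # return the highest epoch among saved net_epoch*.pth files, or 0 if none
    file_list = glob.glob(os.path.join(save_dir, 'net_epoch*.pth'))
    epochs = [int(re.findall(r'net_epoch(\d+).pth', f)[0]) for f in file_list]
    return max(epochs) if epochs else 0
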
Example #7
model_restoration = nn.DataParallel(model_restoration)

model_restoration.eval()

with torch.no_grad():
    psnr_val_raw = []
    for ii, data_val in enumerate(tqdm(test_loader), 0):
        raw_gt = data_val[0].cuda()
        raw_noisy = data_val[1].cuda()
        # variance = shot_noise * raw_noisy + read_noise; the shot and read
        # noise parameters come from the image metadata
        variance = data_val[2].cuda()
        filenames = data_val[3]
        raw_restored = model_restoration(raw_noisy, variance)
        raw_restored = torch.clamp(raw_restored, 0, 1)
        psnr_val_raw.append(utils.batch_PSNR(raw_restored, raw_gt, 1.))

        if args.save_images:
            for batch in range(len(raw_gt)):
                denoised_img = utils.unpack_raw(
                    raw_restored[batch, :, :, :].unsqueeze(0))
                denoised_img = denoised_img.permute(
                    0, 2, 3, 1).cpu().detach().numpy()[0]
                # replicate the single channel to 3 channels and scale to 8-bit
                denoised_img = np.squeeze(np.stack(
                    (denoised_img, ) * 3, -1)) * 255
                lycon.save(args.result_dir + filenames[batch][:-4] + '.png',
                           denoised_img.astype(np.uint8))

psnr_val_raw = sum(psnr_val_raw) / len(psnr_val_raw)
print("PSNR: %.2f " % (psnr_val_raw))
Example #8
def main(args):
    # Load dataset
    print('> Loading dataset ...')
    dataset_train = MemoryFriendlyLoader(origin_img_dir=args.gt_dir, edited_img_dir=args.train_dir,
                                         pathlistfile=args.filelist)
    loader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=args.batch_size, shuffle=True,
                                               num_workers=8, pin_memory=True)
    print('\t# of training samples: %d\n' % int(len(dataset_train)))

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = MAPVDNet(cuda_flag=True, alignment_model=args.pretrained_model, T=args.stages).cuda()
    model = torch.nn.DataParallel(model)

    criterion = nn.L1Loss().cuda()

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    plotx = []
    ploty = []

    checkpoint_path = args.ckp_dir + 'checkpoint_%d_%depoch.ckpt' % (args.sigma, args.start_epoch)
    if args.use_checkpoint:
        # note: the returned start_epoch is unused; training resumes from args.start_epoch
        model, optimizer, start_epoch, ploty = load_checkpoint(model, optimizer, checkpoint_path)
        model = torch.nn.DataParallel(model)
        print('checkpoint loaded!')
        plotx = list(range(len(ploty)))


    # Training
    for epoch in range(args.start_epoch, args.epochs):
        losses = 0

        # train over all data in the epoch
        for step, (x, y, path_code) in enumerate(loader_train):

            # Pre-training step
            model.train()
            model.zero_grad()
            optimizer.zero_grad()

            frames_input = x
            frame_clean = y

            frame_clean = Variable(frame_clean.cuda())
            frames_input = Variable(frames_input.cuda())

            # Evaluate model and optimize it
            x_list = model(frames_input)

            # main loss on the final stage output, plus a small auxiliary loss
            # on each intermediate stage output
            loss = criterion(x_list[-1], frame_clean)
            for i in range(1, len(x_list) - 1):
                loss += 0.0001 * criterion(x_list[i], frame_clean)

            losses += loss.item()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                # Results
                model.eval()
                psnr_train = batch_PSNR(x_list[-1], frame_clean, 1.)
                print('%s  [epoch %d][%d/%d]  loss: %f  PSNR_train: %.4fdB' % \
                    (show_time(datetime.datetime.now()), epoch + 1, step + 1, len(loader_train), losses / (step+1), psnr_train))

        # save the loss curve
        plotx.append(epoch + 1)
        ploty.append(losses / (step + 1))
        plt.plot(plotx, ploty)
        plt.savefig(args.loss_pic)
        # append this epoch's stats to the log file
        with open(args.savetxt, 'a') as f:
            f.write('epoch %d loss: %f, train_psnr: %f\n' %
                    (epoch + 1, losses / (step + 1), psnr_train))
        # save checkpoint
        if not os.path.exists(args.ckp_dir):
            os.mkdir(args.ckp_dir)
        save_checkpoint(model, optimizer, epoch + 1, ploty, args.ckp_dir + 'checkpoint_%d_%depoch.ckpt' %
                        (args.sigma, epoch + 1))
        # save the full model object
        torch.save(model, os.path.join(args.save_dir, 'denoising_%d_%d.pkl' % (args.sigma, epoch + 1)))
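
The save_checkpoint/load_checkpoint pair is not shown; a minimal sketch
consistent with how they are called above (the dict keys are assumptions):

import torch

def save_checkpoint(model, optimizer, epoch, losses, path):
    # bundle everything needed to resume training into a single file
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'losses': losses}, path)

def load_checkpoint(model, optimizer, path):
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, ckpt['epoch'], ckpt['losses']
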