Example #1
0
def train_a_epoch_spatial(training_data_loader, optimizer, model, criterion,
                          epoch, logfile):
    """Run one training epoch over the spatial stage only.

    The model's forward pass returns a 3-tuple; only its third element
    (the spatial-stage output) is trained against ``target`` here.

    Args:
        training_data_loader: iterable yielding (input, target) batches.
        optimizer: optimizer stepping the parameters being trained.
        model: network whose forward pass returns ``(_, _, output)``.
        criterion: loss applied to ``(output, target)``.
        epoch: epoch index (currently unused inside this function).
        logfile: file handle forwarded to ``print_cz`` for logging.
    """
    model.train()
    batch_time = AverageMeter()   # wall time per batch (data wait + compute)
    data_time = AverageMeter()    # time spent waiting on the data loader
    train_losses = AverageMeter()
    train_psnrs = AverageMeter()

    end = time.time()
    train_start_time = time.time()
    for iteration, batch in enumerate(training_data_loader):
        # (96) and (192)
        data, target = Variable(batch[0]).cuda(), Variable(
            batch[1], requires_grad=False).cuda()
        data_time.update(time.time() - end)
        _, _, output = model(data)  # train spatial stage
        # compute loss
        loss = criterion(output, target)

        # Per-image MSE over a center crop that shaves an 8-pixel border on
        # each side. PSNR below uses a peak of 1.0, i.e. it assumes pixel
        # values normalized to [0, 1] — TODO confirm against the data loader.
        mse = torch.sum(torch.sum(
            ((output[:, :, 8:-8, 8:-8] - target[:, :, 8:-8, 8:-8]) *
             (output[:, :, 8:-8, 8:-8] - target[:, :, 8:-8, 8:-8])).view(
                 -1, (target.shape[2] - 16), (target.shape[3] - 16)),
            dim=-1,
            keepdim=False),
                        dim=-1,
                        keepdim=False) / float(
                            (target.shape[2] - 16) * (target.shape[3] - 16))
        psnrs = 10 * torch.log10(1.0 / mse)
        train_psnrs.update((psnrs.mean()).data.item(),
                           target.shape[0] * target.shape[1])
        # update model
        optimizer.zero_grad()
        loss.backward()
        # gradient clipping; ``opt`` is a module-level options object
        nn.utils.clip_grad_norm_(model.parameters(), opt.clip)  # new
        optimizer.step()

        train_losses.update(loss.data.item(), target.shape[0])
        if iteration % 100 == 1:  # log every 100 batches (first at batch 1)
            print_cz('  Batch {:d}, loss {:.1f}, PSNR_present: {:.3f} dB, PSNR_avg: {:.3f} dB'\
                .format(iteration, loss.data.item(), psnrs.mean(), train_psnrs.avg ), f=logfile)

        # drop references promptly so GPU memory is freed before next batch
        del batch, data, output, target, loss, mse, psnrs

        batch_time.update(time.time() - end)
        end = time.time()
    train_end_time = time.time()
    print_cz('  Train Loss: {:.3f}\t PSNR: {:.3f} dB\t Time: {:.1f}\t BatchT: {:.3f}\t DataT: {:.3f}\t D/B: {:.1f}%'\
        .format(train_losses.avg, train_psnrs.avg, (train_end_time-train_start_time), batch_time.avg, data_time.avg, 100.0*(data_time.avg/batch_time.avg)), f=logfile)
Example #2
0
def test(testing_data_loader, model, criterion, epoch, logfile=None):
    """Evaluate the model on the test set, logging average PSNR and SSIM.

    Args:
        testing_data_loader: iterable yielding (input, target) batches.
        model: network whose forward pass returns ``(output, _, _)``.
        criterion: unused here; kept for a uniform call signature.
        epoch: epoch index (currently unused inside this function).
        logfile: optional file handle forwarded to ``print_cz``.

    NOTE(review): this function returns nothing, yet elsewhere in this
    file a caller unpacks ``(test_loss, test_psnr, test_ssim)`` from a
    ``test(...)`` call — verify which version the caller expects.
    """
    model.eval()
    batch_time = AverageMeter()   # wall time per batch (data wait + compute)
    data_time = AverageMeter()    # time spent waiting on the data loader
    test_psnrs = AverageMeter()
    test_ssims = AverageMeter()
    end = time.time()
    test_start_time = time.time()
    with torch.no_grad():
        for batch in testing_data_loader:
            data, target = Variable(batch[0]).cuda(), Variable(
                batch[1], requires_grad=False).cuda()
            # data, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)
            data_time.update(time.time() - end)

            output, _, _ = model(data)

            ssim_value = pytorch_ssim.ssim(output, target)
            test_ssims.update(ssim_value, target.shape[0])

            # Per-image MSE over a center crop shaving an 8-pixel border;
            # PSNR uses peak 1.0, i.e. assumes pixels in [0, 1] — TODO confirm.
            mse = torch.sum(torch.sum(
                ((output[:, :, 8:-8, 8:-8] - target[:, :, 8:-8, 8:-8]) *
                 (output[:, :, 8:-8, 8:-8] - target[:, :, 8:-8, 8:-8])).view(
                     -1, (target.shape[2] - 16), (target.shape[3] - 16)),
                dim=-1,
                keepdim=False),
                                dim=-1,
                                keepdim=False) / float((target.shape[2] - 16) *
                                                       (target.shape[3] - 16))
            psnrs = 10 * torch.log10(1.0 / mse)
            test_psnrs.update((psnrs.mean()).data.item(),
                              target.shape[0] * target.shape[1])

            # drop references promptly so GPU memory is freed before next batch
            del batch, data, output, target, mse, ssim_value, psnrs

            batch_time.update(time.time() - end)
            end = time.time()
        test_end_time = time.time()
        print_cz('  Test PSNR: {:.3f} dB \tSSIM: {:.3f}\t Time: {:.1f}\t BatchT: {:.3f}\t DataT: {:.3f}\t D/B: {:.1f}%'\
                .format(
                    test_psnrs.avg,
                    test_ssims.avg,
                    (test_end_time-test_start_time),
                    batch_time.avg,
                    data_time.avg,
                    100.0*(data_time.avg/batch_time.avg)),
                f=logfile)
Example #3
0
def main():
    """Build SWD-Net, load pretrained weights, and run test-set evaluation.

    Reads configuration via ``config.get_args()`` into the module-level
    ``opt``; expects CUDA to be available. Logs to stdout only
    (``log_file`` stays ``None``).
    """
    global opt, model
    opt = config.get_args()
    print(opt)

    log_file = None  # console-only logging in this evaluation entry point

    starting_time = time.time()
    print_cz(str=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
             f=log_file)

    cudnn.benchmark = True  # input sizes are fixed, so let cuDNN autotune
    print_cz("===> Building model", f=log_file)
    spatial_sr = spatial_stage.Spatial_Stage()
    wavelet_sr = wavelet_stage.Wavelet_Stage()
    model = swdnet.SWDNet(spatial_sr, wavelet_sr)
    # Pick pretrained weights for the requested degradation model; fail fast
    # on an unknown value instead of hitting a NameError on ``pth_file``.
    if opt.data_degradation in ('bicubic', 'Bicubic'):
        pth_file = './weights/swdnet-bicubic-dict.pth'
    elif opt.data_degradation in ('nearest', 'Nearest'):
        pth_file = './weights/swdnet-nearest-dict.pth'
    else:
        raise ValueError(
            'Unsupported data_degradation: {!r}'.format(opt.data_degradation))
    model.load_state_dict(torch.load(pth_file))

    print_cz("===> Setting GPU", f=log_file)
    if opt.job_type in ('S', 's'):
        model.cuda()  # single-GPU
    else:
        # Multi-GPU: map job_type to a device list; fail fast on an unknown
        # value instead of a NameError on ``gpu_device_ids`` below.
        if opt.job_type in ('Q', 'q'):
            gpu_device_ids = [0, 1, 2, 3]
        elif opt.job_type in ('E', 'e'):
            gpu_device_ids = [0, 1, 2, 3, 4, 5, 6, 7]
        elif opt.job_type in ('D', 'd'):
            gpu_device_ids = [0, 1]
        else:
            raise ValueError('Unknown job_type: {!r}'.format(opt.job_type))
        model = nn.DataParallel(model.cuda(), device_ids=gpu_device_ids).cuda()
    # size_average=False is deprecated; reduction='sum' is the equivalent.
    criterion = nn.L1Loss(reduction='sum')

    print_cz("===> Loading datasets", f=log_file)
    testing_data_loader = data_loader_lmdb.get_loader(
        os.path.join(config.dataset_dir, opt.data_degradation, 'test_lmdb'),
        batch_size=opt.batch_size,
        stage='test',
        num_workers=opt.num_workers)

    print_cz("===> Testing", f=log_file)
    test(testing_data_loader, model, criterion, epoch=0, logfile=None)

    print_cz(str(time.time() - starting_time), f=log_file)
Example #4
0
def train(training_data_loader,
          testing_data_loader,
          swratio,
          swratio_tmp,
          optimizer,
          optimizer_spatial,
          model,
          criterion,
          save_dir=None,
          logfile=None,
          pth_prefix=''):
    """Two-phase training driver.

    Phase 1 trains the spatial stage alone with ``optimizer_spatial``;
    phase 2 trains the full SWD-Net with ``optimizer`` and snapshots the
    model whenever test PSNR improves. Both phases run ``opt.epochs``
    epochs with the same ``adjust_learning_rate`` schedule.

    Args:
        training_data_loader: training batch iterable.
        testing_data_loader: test batch iterable.
        swratio: spatial/wavelet loss weighting passed to ``train_a_epoch``.
        swratio_tmp: secondary weighting passed to ``train_a_epoch``.
        optimizer: optimizer over all model parameters (phase 2).
        optimizer_spatial: optimizer over spatial-stage parameters (phase 1).
        model: SWD-Net instance (possibly DataParallel-wrapped).
        criterion: loss function shared by both phases.
        save_dir: directory for best-model snapshots; ``None`` disables saving.
        logfile: optional file handle forwarded to ``print_cz``.
        pth_prefix: filename prefix for saved snapshots.
    """
    print_cz("===> Training Spatial Stage", f=logfile)
    for epoch in range(opt.epochs):
        lr = adjust_learning_rate(epoch)
        for param_group in optimizer_spatial.param_groups:
            param_group["lr"] = lr
        print_cz("Epoch = {}, lr = {}".format(
            epoch, optimizer_spatial.param_groups[0]["lr"]),
                 f=logfile)
        train_a_epoch_spatial(training_data_loader=training_data_loader,
                              optimizer=optimizer_spatial,
                              model=model,
                              criterion=criterion,
                              epoch=epoch,
                              logfile=logfile)
        test_spatial(testing_data_loader=testing_data_loader,
                     model=model,
                     criterion=criterion,
                     epoch=epoch,
                     logfile=logfile)

    print_cz("===> Training SWD-Net", f=logfile)
    best_psnr = 0
    # Evaluate once before phase 2. Use an explicit epoch index instead of
    # the leaked loop variable, which would be undefined when opt.epochs == 0;
    # for opt.epochs >= 1 this is the same value the old code passed.
    test_spatial(testing_data_loader=testing_data_loader,
                 model=model,
                 criterion=criterion,
                 epoch=max(opt.epochs - 1, 0),
                 logfile=logfile)
    for epoch in range(opt.epochs):
        lr = adjust_learning_rate(epoch)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        print_cz("Epoch = {}, lr = {}".format(epoch,
                                              optimizer.param_groups[0]["lr"]),
                 f=logfile)
        train_loss, train_psnr = train_a_epoch(training_data_loader, swratio,
                                               swratio_tmp, optimizer, model,
                                               criterion, epoch, logfile)
        test_loss, test_psnr, test_ssim = test(testing_data_loader, model,
                                               criterion, epoch, logfile)
        if test_psnr > best_psnr:
            best_psnr = test_psnr
            if save_dir is not None:  # save flag
                # Snapshot the new best model; model_snapshot replaces the
                # previous best file matching ``old_file``.
                model_snapshot(
                    model,
                    new_file=
                    (pth_prefix +
                     'model-best-{}-TestL{:.1f}-TestPSNR-{:.3f}dB-TestSSIM{:.3f}-{}.pth'
                     .format(epoch, test_loss, best_psnr, test_ssim,
                             time_mark())),
                    old_file=pth_prefix + 'model-best-',
                    save_dir=save_dir + '/',
                    verbose=True)
                print_cz('*better model saved successfully*', f=logfile)
Example #5
0
def main():
    """Full training entry point: build SWD-Net, set up data/optimizers, train.

    Reads configuration via ``config.get_args()`` into the module-level
    ``opt``; expects CUDA. Writes a screen log and an args dump into the
    run folder produced by ``prepare()``.
    """
    global opt, model
    opt = config.get_args()
    print(opt)
    save_folder = prepare()
    log_file = open((save_folder + '/' + 'print_out_screen.txt'), 'w')
    try:
        # Persist the run configuration for reproducibility.
        with open(save_folder + '/args.json', 'w') as f:
            f.write(json.dumps(opt.__dict__, indent=4))

        starting_time = time.time()
        print_cz(str=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                 f=log_file)
        cudnn.benchmark = True  # fixed input sizes -> let cuDNN autotune

        print_cz("===> Building model", f=log_file)
        spatial_sr = spatial_stage.Spatial_Stage()
        wavelet_sr = wavelet_stage.Wavelet_Stage()
        model = swdnet.SWDNet(spatial_sr, wavelet_sr)

        print_cz("===> Setting GPU", f=log_file)
        if opt.job_type in ('S', 's'):
            model.cuda()  # single-GPU
        else:
            # Multi-GPU: map job_type to a device list; fail fast on an
            # unknown value instead of a NameError on ``gpu_device_ids``.
            if opt.job_type in ('Q', 'q'):
                gpu_device_ids = [0, 1, 2, 3]
            elif opt.job_type in ('E', 'e'):
                gpu_device_ids = [0, 1, 2, 3, 4, 5, 6, 7]
            elif opt.job_type in ('D', 'd'):
                gpu_device_ids = [0, 1]
            else:
                raise ValueError(
                    'Unknown job_type: {!r}'.format(opt.job_type))
            model = nn.DataParallel(model.cuda(),
                                    device_ids=gpu_device_ids).cuda()

        # size_average=False is deprecated; reduction='sum' is equivalent.
        criterion = nn.L1Loss(reduction='sum')

        print_cz("===> Loading datasets", f=log_file)
        training_data_loader = data_loader_lmdb.get_loader(
            os.path.join(config.dataset_dir, opt.data_degradation,
                         'train_lmdb'),
            batch_size=opt.batch_size,
            stage='train',
            num_workers=opt.num_workers)
        testing_data_loader = data_loader_lmdb.get_loader(
            os.path.join(config.dataset_dir, opt.data_degradation,
                         'test_lmdb'),
            batch_size=opt.batch_size,
            stage='test',
            num_workers=opt.num_workers)

        print_cz("===> Setting Optimizer", f=log_file)
        # NOTE(review): the SGD branch reads ``model.spatial`` while the Adam
        # branch reads ``model.spatial_sr`` — at most one attribute can exist
        # on SWDNet, so one branch looks broken; also neither works through a
        # DataParallel wrapper (would need ``model.module``). Verify against
        # the SWDNet definition.
        if opt.optim in ('SGD', 'sgd'):
            optimizer = optim.SGD(model.parameters(),
                                  lr=opt.lr,
                                  momentum=opt.momentum,
                                  weight_decay=opt.weight_decay)
            optimizer_spatial = optim.SGD(model.spatial.parameters(),
                                          lr=opt.lr,
                                          momentum=opt.momentum,
                                          weight_decay=opt.weight_decay)
        elif opt.optim in ('Adam', 'adam'):
            optimizer = optim.Adam(model.parameters(),
                                   lr=opt.lr,
                                   weight_decay=opt.weight_decay)
            optimizer_spatial = optim.Adam(model.spatial_sr.parameters(),
                                           lr=opt.lr,
                                           weight_decay=opt.weight_decay)
        else:
            # Previously an unknown value fell through to a NameError on
            # ``optimizer``; fail with a clear message instead.
            raise ValueError('Unknown optim: {!r}'.format(opt.optim))

        print_cz("===> Training", f=log_file)
        train(training_data_loader=training_data_loader,
              testing_data_loader=testing_data_loader,
              swratio=opt.SWratio,
              swratio_tmp=opt.SWratio_tmp,
              optimizer=optimizer,
              optimizer_spatial=optimizer_spatial,
              model=model,
              criterion=criterion,
              save_dir=save_folder,
              logfile=log_file)

        print_cz(str(time.time() - starting_time), f=log_file)
    finally:
        # Close the screen log even if setup or training raises.
        log_file.close()