Example #1
def main():
    # ===========================================================
    # Set train dataset & test dataset
    # ===========================================================
    print('===> Loading datasets')
    train_set = get_training_set(args.upscale_factor)
    test_set = get_test_set(args.upscale_factor)
    training_data_loader = DataLoader(dataset=train_set,
                                      batch_size=args.batchSize,
                                      shuffle=True)
    testing_data_loader = DataLoader(dataset=test_set,
                                     batch_size=args.testBatchSize,
                                     shuffle=False)

    if args.model == 'sub':
        model = SubPixelTrainer(args, training_data_loader,
                                testing_data_loader)
    elif args.model == 'srcnn':
        model = SRCNNTrainer(args, training_data_loader, testing_data_loader)
    elif args.model == 'vdsr':
        model = VDSRTrainer(args, training_data_loader, testing_data_loader)
    elif args.model == 'edsr':
        model = EDSRTrainer(args, training_data_loader, testing_data_loader)
    elif args.model == 'fsrcnn':
        model = FSRCNNTrainer(args, training_data_loader, testing_data_loader)
    elif args.model == 'drcn':
        model = DRCNTrainer(args, training_data_loader, testing_data_loader)
    elif args.model == 'srgan':
        model = SRGANTrainer(args, training_data_loader, testing_data_loader)
    elif args.model == 'dbpn':
        model = DBPNTrainer(args, training_data_loader, testing_data_loader)
    else:
        raise Exception("the model does not exist")

    model.run()
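
Each of these excerpts reads its hyperparameters from a module-level `args` object that is never shown. Below is a minimal sketch of the argparse setup they appear to assume; the flag spellings are taken from the attribute accesses above (`upscale_factor`, `batchSize`, `testBatchSize`, `model`), but the defaults and help strings are placeholders rather than the original project's values.

import argparse

# Hypothetical parser reconstructed from the attribute names used in the
# examples; spellings match the accesses above, defaults are illustrative only.
parser = argparse.ArgumentParser(description='PyTorch super-resolution example')
parser.add_argument('--upscale_factor', type=int, default=4,
                    help='super-resolution upscale factor')
parser.add_argument('--batchSize', type=int, default=64,
                    help='training batch size')
parser.add_argument('--testBatchSize', type=int, default=10,
                    help='testing batch size')
parser.add_argument('--model', type=str, default='srcnn',
                    help='sub | srcnn | vdsr | edsr | fsrcnn | drcn | srgan | dbpn')
args = parser.parse_args()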
Example #2
def main():
    # ===========================================================
    # Set train dataset & valid dataset
    # ===========================================================
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('===> Loading datasets')
    if args.dataset == 'test':
        test_set = get_test_set(args.upscale_factor)
    elif args.dataset == 'valid':
        test_set = get_valid_set(args.upscale_factor)
    else:
        raise NotImplementedError
    test_data_loader = DataLoader(dataset=test_set,
                                  batch_size=args.batchSize,
                                  shuffle=False)

    file_name = args.model + "_generator.pth" if "gan" in args.model else "model.pth"
    model_name = args.model + ("_diff" if args.diff else "")
    model_path = "/home/teven/canvas/python/super-resolution/results/models/{}/{}".format(
        model_name, file_name)
    model = torch.load(model_path, map_location=lambda storage, loc: storage)
    model = model.to(device)
    model.eval()

    avg_psnr = 0
    avg_baseline_psnr = 0
    criterion = nn.MSELoss()

    with torch.no_grad():
        for batch_num, (data, target) in enumerate(test_data_loader):
            data, target = data.to(device), target.to(device)
            prediction = model(data)
            mse = criterion(prediction, target)
            psnr = 10 * log10(1 / mse.item())
            avg_psnr += psnr
            progress_bar(batch_num, len(test_data_loader),
                         'PSNR: %.3f' % (avg_psnr / (batch_num + 1)))

            baseline = F.interpolate(data,
                                     scale_factor=args.upscale_factor,
                                     mode='bilinear',
                                     align_corners=False)
            baseline_mse = criterion(baseline, target)
            baseline_psnr = 10 * log10(1 / baseline_mse.item())
            avg_baseline_psnr += baseline_psnr
            progress_bar(batch_num, len(test_data_loader),
                         'Baseline PSNR: %.3f' % (avg_baseline_psnr / (batch_num + 1)))

    print("    Average PSNR: {:.3f} dB".format(avg_psnr /
                                               len(test_data_loader)))
    print("    Average Baseline PSNR: {:.3f} dB".format(avg_baseline_psnr /
                                                        len(test_data_loader)))
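
The loop above turns an MSE into PSNR with `10 * log10(1 / mse)`, which implicitly assumes pixel values normalized to [0, 1] so the peak signal value is 1. A small self-contained sketch of the general formula (the helper name `psnr_from_mse` and the random tensors are illustrative only):

from math import log10

import torch
import torch.nn as nn

def psnr_from_mse(mse, peak=1.0):
    # PSNR = 10 * log10(peak^2 / MSE); with peak = 1.0 this reduces to the
    # 10 * log10(1 / mse) expression used in the evaluation loop above.
    return 10 * log10(peak ** 2 / mse)

# Quick check on random tensors standing in for a prediction/target pair.
prediction = torch.rand(1, 3, 64, 64)
target = torch.rand(1, 3, 64, 64)
mse = nn.MSELoss()(prediction, target).item()
print('PSNR: %.3f dB' % psnr_from_mse(mse))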
Example #3
File: main.py  Project: tangjiafu/SRMSD
def main():
    # ===========================================================
    # Set train dataset & test dataset
    # ===========================================================
    print('===> Loading datasets')
    train_set = get_training_set(args.upscale_factor)
    test_set = get_test_set(args.upscale_factor)
    training_data_loader = DataLoader(dataset=train_set, batch_size=args.batchSize, shuffle=True)
    testing_data_loader = DataLoader(dataset=test_set, batch_size=args.testBatchSize, shuffle=False)

    model = FSRCNNTrainer(args, training_data_loader, testing_data_loader)
    print("USE ",model.device)
    model.run()
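
All of these trainer classes are used through the same implicit interface: constructed with (args, training_data_loader, testing_data_loader), exposing a .device attribute and a .run() method. The skeleton below is only an assumption inferred from that usage, not the repository's actual FSRCNNTrainer; args.nEpochs is likewise a guessed attribute name.

import torch

class TrainerSketch:
    """Illustrative skeleton of the interface used above; real trainers differ."""

    def __init__(self, args, training_loader, testing_loader):
        self.args = args
        self.training_loader = training_loader
        self.testing_loader = testing_loader
        # Mirrors the `model.device` attribute printed in Example #3.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def run(self):
        # A real trainer would build its network, optimizer, and loss here, then
        # alternate a training pass and a test pass per epoch and save checkpoints.
        for epoch in range(1, self.args.nEpochs + 1):
            print('===> Epoch {} starts'.format(epoch))

# Usage mirrors the examples above:
# TrainerSketch(args, training_data_loader, testing_data_loader).run()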
Example #4
def main():
    # ===========================================================
    # Set train dataset & test dataset
    # ===========================================================
    print('===> Loading datasets')
    train_set = get_training_set(args.upscale_factor, args.image_dir)
    test_set = get_test_set(args.upscale_factor, args.image_dir)
    training_data_loader = DataLoader(dataset=train_set,
                                      batch_size=args.batchSize,
                                      shuffle=True)
    testing_data_loader = DataLoader(dataset=test_set,
                                     batch_size=args.testBatchSize,
                                     shuffle=False)

    # ===========================================================
    # Generate Model from training data set
    # ===========================================================
    model = SRGANTrainer(args, training_data_loader, testing_data_loader)

    model.run()

def test(model, test_loader, upscale_factor, n_sample=None):
    criterion = nn.MSELoss()
    avg_psnr = 0
    with torch.no_grad():
        for batch_num, (data, target) in enumerate(test_loader):
            prediction = model(data)
            mse = criterion(prediction, target)
            psnr = 10 * log10(1 / mse.item())
            avg_psnr += psnr
            #progress_bar(batch_num, len(test_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))

    return avg_psnr / len(test_loader)
    #print("    Average PSNR: {:.4f} dB".format(avg_psnr / len(test_loader)))


dataset = get_test_set(upscale_factor, args.data_root)
test_loader = DataLoader(dataset=dataset,
                         batch_size=args.batch_size,
                         shuffle=False)
print("load dataset done")

if args.fwd_bits < 32:
    model = quant.duplicate_model_with_linearquant_nobn(
        model,
        param_bits=args.param_bits,
        fwd_bits=args.fwd_bits,
        counter=args.n_sample)
    #print(model)
    test(model, test_loader, upscale_factor, args.n_sample)
print("==============================================================")

print("============================eval==================================")
avg_psnr = test(model, test_loader, upscale_factor)
print("    Average PSNR: {:.4f} dB".format(avg_psnr))

print("==================================================================")