Пример #1
0
def main():
    """Train the SRCNN model, or evaluate a saved model when ``opt.test`` is set.

    Parses the command line into the module-level ``opt``, builds the
    training, validation and test data loaders, and then either:

    * test mode (``opt.test``): loads the model stored at
      ``model/<opt.model>``, runs ``test`` on the test loader, prints the
      measured throughput and returns; or
    * training mode: for each of ``opt.epochs`` epochs runs ``train`` and
      ``validate``, and calls ``checkpoint`` every 10th epoch.

    Raises:
        RuntimeError: if ``--cuda`` was requested but no CUDA device is
            available.
    """
    global opt
    opt = parser.parse_args()
    # Normalise every entry of the parsed GPU id list to int.
    opt.gpuids = list(map(int, opt.gpuids))

    print(opt)

    if opt.cuda and not torch.cuda.is_available():
        raise RuntimeError("No GPU found, please run without --cuda")
    cudnn.benchmark = True

    train_set = get_training_set(opt.upscale_factor, opt.add_noise,
                                 opt.noise_std)
    validation_set = get_validation_set(opt.upscale_factor)
    test_set = get_test_set(opt.upscale_factor)
    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=opt.threads,
                                      batch_size=opt.batch_size,
                                      shuffle=True)
    validating_data_loader = DataLoader(dataset=validation_set,
                                        num_workers=opt.threads,
                                        batch_size=opt.test_batch_size,
                                        shuffle=False)
    testing_data_loader = DataLoader(dataset=test_set,
                                     num_workers=opt.threads,
                                     batch_size=opt.test_batch_size,
                                     shuffle=False)

    model = SRCNN()
    criterion = nn.MSELoss()

    if opt.cuda:
        # The first listed GPU hosts the parameters and gathers the outputs.
        torch.cuda.set_device(opt.gpuids[0])
        with torch.cuda.device(opt.gpuids[0]):
            model = model.cuda()
            criterion = criterion.cuda()
        model = nn.DataParallel(model,
                                device_ids=opt.gpuids,
                                output_device=opt.gpuids[0])

    if opt.test:
        model_name = join("model", opt.model)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # Without --cuda the tensors are remapped to the CPU so a checkpoint
        # written on a GPU machine can still be loaded; with --cuda the
        # default placement (map_location=None) is kept.
        model = torch.load(model_name,
                           map_location=None if opt.cuda else "cpu")
        # Wrap for multi-GPU execution only when GPUs are actually in use,
        # mirroring the training path above; this also avoids indexing
        # opt.gpuids on a CPU-only run.
        if opt.cuda:
            model = nn.DataParallel(model,
                                    device_ids=opt.gpuids,
                                    output_device=opt.gpuids[0])
        start_time = time.time()
        test(model, criterion, testing_data_loader)
        elapsed_time = time.time() - start_time
        # Throughput is based on the real number of test samples rather than
        # a hard-coded count.
        print("===> average {:.2f} image/sec for processing".format(
            len(test_set) / elapsed_time))
        return

    # The optimizer is only needed for training, so it is created after the
    # test-mode early return.
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    for epoch in range(1, opt.epochs + 1):
        train(model, criterion, epoch, optimizer, training_data_loader)
        validate(model, criterion, validating_data_loader)
        if epoch % 10 == 0:
            checkpoint(model, epoch)
Пример #2
0
import argparse
# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser(description='predictionCNN Example')
parser.add_argument('--cuda', action='store_true', default=False,
                    help='run the model and the loss on the GPU')
# The weights are loaded unconditionally below, so a missing path is reported
# by argparse up front instead of failing later inside torch.load(None).
parser.add_argument('--weight_path', type=str, required=True,
                    help='path to the saved SRCNN state_dict')
parser.add_argument('--save_dir', type=str, default=None)
opt = parser.parse_args()

# Evaluation data: one sample per batch, visited in dataset order.
test_set = DatasetFromFolderEval(image_dir='./data/General-100/test',
                                 scale_factor=4)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

model = SRCNN()
criterion = nn.MSELoss()
if opt.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

# Restore the trained parameters, mapping the stored tensors onto the device
# selected by --cuda.
model.load_state_dict(
    torch.load(opt.weight_path, map_location='cuda' if opt.cuda else 'cpu'))

# Inference only: switch to eval mode and disable gradient tracking.
model.eval()
# Running totals for the loop below. The code that updates them, and the
# meaning of the `_b` variants, lies outside this excerpt.
# NOTE(review): `_b` presumably refers to a baseline - confirm.
total_loss, total_psnr = 0, 0
total_loss_b, total_psnr_b = 0, 0
with torch.no_grad():
    for batch in test_loader:
        # Each batch is indexed as (inputs, targets).
        inputs, targets = batch[0], batch[1]
        if opt.cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()
Пример #3
0
# Held-out data for evaluation; `train_set` is created earlier in the script.
test_set = get_test_set(opt.upscale_factor)

# Training batches are reshuffled on every pass over the loader.
training_data_loader = DataLoader(
    dataset=train_set,
    num_workers=opt.threads,
    batch_size=opt.batch_size,
    shuffle=True,
)
# Test batches keep the dataset order.
testing_data_loader = DataLoader(
    dataset=test_set,
    num_workers=opt.threads,
    batch_size=opt.test_batch_size,
    shuffle=False,
)

# Network and the mean-squared-error objective it is trained with.
srcnn = SRCNN()
criterion = nn.MSELoss()

if use_cuda:
    # Select the requested GPU, then move the model and the loss onto it.
    torch.cuda.set_device(opt.gpuid)
    srcnn.cuda()
    criterion = criterion.cuda()

# Plain SGD over all model parameters.
optimizer = optim.SGD(srcnn.parameters(), lr=opt.lr)
# Alternative optimizer, kept disabled:
# optimizer = optim.Adam(srcnn.parameters(), lr=opt.lr)


def train(epoch):
    epoch_loss = 0
    for iteration, batch in enumerate(training_data_loader, 1):
        input, target = Variable(batch[0]), Variable(batch[1])
        if use_cuda:
            input = input.cuda()
            target = target.cuda()

        optimizer.zero_grad()