Example #1
0
def train(training_data, dev_data, args):
    """Train an SRCNN model on *training_data*, evaluating on *dev_data* each epoch.

    Args:
        training_data: dataset iterated in batches of 2 via ``data.DataLoader``;
            each item is unpacked as ``(y, x)`` (target first, input second).
        dev_data: held-out dataset with the same item layout, evaluated once per
            epoch without any optimizer updates.
        args: options object; this function reads ``args.lr`` (Adam learning
            rate), ``args.ep`` (number of epochs) and ``args.o`` (forwarded to
            ``save_model``).

    Returns:
        The trained model. It is wrapped in ``nn.DataParallel`` when more than
        one CUDA device is visible, and is left in eval mode after the final
        dev pass.
    """
    training_gen = data.DataLoader(training_data, batch_size=2)
    dev_gen = data.DataLoader(dev_data, batch_size=2)
    device = torch.device('cuda' if cuda.is_available() else 'cpu')
    print('Initializing model')
    model = SRCNN()
    loss = RMSE()
    if cuda.device_count() > 1:
        print('Using %d CUDA devices' % cuda.device_count())
        model = nn.DataParallel(
            model, device_ids=list(range(cuda.device_count())))
    model.to(device)
    loss.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    def _run_epoch(loader, opt=True):
        """Make one pass over *loader* and return the mean per-batch loss.

        When *opt* is true the optimizer is stepped after every batch. When it
        is false the pass is evaluation-only: the model is switched to eval
        mode and autograd is disabled, so no computation graph is built.
        """
        model.train(opt)
        total = 0.0
        n_batches = 0
        with torch.set_grad_enabled(opt):
            for y, x in loader:
                y, x = y.to(device), x.to(device)
                batch_loss = loss(model(x), y)
                total += batch_loss.item()
                n_batches += 1
                if opt:
                    optimizer.zero_grad()
                    batch_loss.backward()
                    optimizer.step()
        # torch.cuda.synchronize() raises when CUDA is unavailable, and the
        # device above deliberately falls back to CPU, so only sync on GPU.
        if device.type == 'cuda':
            cuda.synchronize()
        # Report the mean rather than the sum so train and dev losses are
        # comparable even though the two splits differ in size.
        return total / max(n_batches, 1)

    print('Training')
    for ep in range(args.ep):
        train_loss = _run_epoch(training_gen)
        dev_loss = _run_epoch(dev_gen, opt=False)
        print_flush('Epoch %d: Train %.4f Dev %.4f' %
                    (ep, train_loss, dev_loss))
        # Periodic checkpoint, plus one after the last epoch so the final
        # weights are never lost when args.ep is not a multiple of 50.
        if ep % 50 == 0 or ep == args.ep - 1:
            save_model(model, args.o)
    return model
Example #2
0
                    help="use CUDA to speed up computation")
opt = parser.parse_args()

# check cuda
if opt.cuda and not torch.cuda.is_available():
    raise RuntimeError("No GPU found, please run without --cuda")


def open_image(path):
    """Load the image file at *path*, convert it to RGB and return ``TF.to_tensor`` of it.

    The ``with`` block closes the underlying file deterministically instead of
    relying on Pillow's lazy file handling.
    """
    # NOTE(review): TF is presumably torchvision.transforms.functional — confirm
    # against the file's imports.
    with Image.open(path) as img:
        return TF.to_tensor(img.convert("RGB"))


# load LR image; unsqueeze(0) adds a leading batch dimension of size 1
image = open_image(opt.filename).unsqueeze(0)

# load model -- the guard above already aborts when --cuda is requested without
# an available GPU, so opt.cuda alone decides the device here.
device = torch.device("cuda" if opt.cuda else "cpu")
map_location = "cuda:0" if opt.cuda else device
checkpoint = load_checkpoint(opt.checkpoint, map_location)
# init model
model = SRCNN()
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()
# process input
with torch.no_grad():
    image = image.to(device)
    output = model(image)
# save result next to the input: "photo.png" -> "photo_srcnn_x3.png"
root, ext = os.path.splitext(opt.filename)
output_name = root + "_srcnn_x3" + ext
# output[0] drops the batch dimension added at load time
torchvision.utils.save_image(output[0], output_name)

print("Output HR image saved to '{output}'".format(output=output_name))