def main():
    """Train a VDSR model from command-line options, then evaluate it.

    Steps:
    1. Parse the CLI arguments.
    2. Build a shuffled ``DataLoader`` over the dataset named by ``--train_data``.
    3. Construct the ``VDSR`` model, a summed MSE loss and an SGD optimizer.
    4. Optionally restore weights and the starting epoch from ``--pretrained``.
    5. Hand everything to ``train`` and then ``eval``, both defined elsewhere
       in the project.

    Raises:
        RuntimeError: if ``--cuda`` is requested but no CUDA device is available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batchSize", type=int, default=128,
                        help="Training batch size. Default 128")
    parser.add_argument("--Epochs", type=int, default=50,
                        help="Number of epochs to train for")
    parser.add_argument("--lr", type=float, default=0.1,
                        help="Learning Rate. Default=0.1")
    parser.add_argument("--step", type=int, default=10,
                        help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=10")
    parser.add_argument("--start-epoch", default=1, type=int,
                        help="Manual epoch number")
    parser.add_argument("--cuda", action="store_true",
                        help="Use cuda?")
    parser.add_argument("--clip", type=float, default=0.4,
                        help="Clipping Gradients. Default=0.4")
    parser.add_argument("--threads", type=int, default=1,
                        help="Number of threads for data loader to use, Default: 1")
    parser.add_argument("--momentum", default=0.9, type=float,
                        help="Momentum, Default: 0.9")
    parser.add_argument("--weight-decay", default=1e-4, type=float,
                        help="Weight decay, Default: 1e-4")
    parser.add_argument("--pretrained", default='', type=str,
                        help="Path to pretrained model")
    parser.add_argument("--train_data", required=True, type=str,
                        help="Path to preprocessed train dataset")
    parser.add_argument("--test_data", default="./assets/", type=str,
                        help="Path to file containing test images")
    args = parser.parse_args()

    cuda = args.cuda
    if cuda:
        print("=> use gpu id: '{}'".format(0))
        # Restrict this process to GPU 0.
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found or Wrong gpu id, please run without --cuda")
        cudnn.benchmark = True

    # Honour the required --train_data option instead of a hard-coded path.
    train_set = prepareDataset(args.train_data)
    train_data = DataLoader(dataset=train_set, num_workers=args.threads,
                            batch_size=args.batchSize, shuffle=True)

    model = VDSR()
    # reduction="sum" is the non-deprecated spelling of size_average=False.
    criterion = nn.MSELoss(reduction="sum")
    if cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # Build the optimizer only after the model is on its final device, as the
    # torch.optim documentation recommends.
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)

    if args.pretrained:
        if os.path.isfile(args.pretrained):
            print("=> loading model '{}'".format(args.pretrained))
            # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only
            # run; load_state_dict then copies the weights onto the model's device.
            # NOTE(review): the checkpoint holds a pickled module under "model",
            # so only load files from trusted sources.
            checkpoint = torch.load(args.pretrained, map_location="cpu")
            args.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint['model'].state_dict())
        else:
            print("No model found at '{}'".format(args.pretrained))

    train(args.start_epoch, train_data, optimizer, model, criterion, args.Epochs, args)
    eval(model, args)
import math

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T

from model import VDSR

# Prefer the first GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
transform = T.ToTensor()

net = VDSR()
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load('D:/VDSR_SGD_epoch_60.pth', map_location=device)
net.load_state_dict(checkpoint['model_state_dict'])
net = net.to(device)
net.eval()

image_path = 'D:/train_data/91/000tt16.bmp'
img = cv2.imread(image_path)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None, not raising.
    raise FileNotFoundError("Could not read image: {}".format(image_path))
img_r = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
# NOTE(review): cv2.resize takes dsize as (width, height), so this halves only
# the height and keeps the full width — confirm that is intended.
img = cv2.resize(img, (img.shape[1], img.shape[0] // 2), interpolation=cv2.INTER_CUBIC)
# COLOR_BGR2YCrCb orders the channels Y, Cr, Cb.
Y, Cr, Cb = cv2.split(img)
patch = Y[200:230, 300:330]

plt.imshow(img_r)
plt.show()

# ToTensor converts the luma plane into a float tensor with a leading channel axis.
img = transform(Y)