예제 #1
0
def main():
    """Parse CLI options, build MemNet, and run the SGD training loop.

    Options come from the module-level ``parser``. ``opt`` and ``model`` are
    published as module globals (presumably read by ``train`` /
    ``save_checkpoint`` -- verify against those helpers).

    The run can optionally resume from ``opt.resume`` or initialise weights
    from ``opt.pretrained``. It then trains for epochs ``opt.start_epoch`` to
    ``opt.nEpochs`` inclusive and calls ``save_checkpoint`` after every epoch.

    Raises:
        Exception: if ``--cuda`` was requested but no CUDA device is available.
    """
    global opt, model
    opt = parser.parse_args()
    print(opt)

    cuda = opt.cuda
    if cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        # Restrict which devices CUDA can see; this only takes effect if set
        # before CUDA is initialised in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

    # A fresh random seed is drawn on every run (any seed already on ``opt``
    # is overwritten) and printed so it appears in the run log.
    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    # Let cuDNN auto-tune convolution algorithms. This pays off when input
    # sizes stay fixed (presumably true for the HDF5 training patches -- confirm).
    cudnn.benchmark = True

    print("===> Loading datasets")
    train_set = DatasetFromHdf5("./data/SuperResolution/train_291_31_x234.h5")
    training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

    print("===> Building model")
    model = MemNet(1, 64, 6, 6)
    # Summed (not averaged) squared error.
    # NOTE(review): ``size_average`` is deprecated in newer PyTorch, where
    # ``reduction='sum'`` is the equivalent spelling.
    criterion = nn.MSELoss(size_average=False)

    print("===> Setting GPU")
    if cuda:
        # Multi-GPU data parallelism. The wrapper prefixes state_dict keys
        # with "module.".
        model = torch.nn.DataParallel(model).cuda()
        criterion = criterion.cuda()

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            # NOTE(review): checkpoint["model"] is treated as a state dict here,
            # whereas the pretrained branch below treats weights['model'] as a
            # full module -- confirm which format save_checkpoint writes.
            model.load_state_dict(checkpoint["model"])
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # optionally copy weights from a checkpoint
    if opt.pretrained:
        if os.path.isfile(opt.pretrained):
            print("=> loading model '{}'".format(opt.pretrained))
            weights = torch.load(opt.pretrained)
            # weights['model'] is a module object here, so take its state_dict.
            model.load_state_dict(weights['model'].state_dict())
        else:
            print("=> no model found at '{}'".format(opt.pretrained))

    print("===> Setting Optimizer")
    optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)

    print("===> Training")
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        train(training_data_loader, optimizer, model, criterion, epoch)
        save_checkpoint(model, epoch)
예제 #2
0
def main():
    """Instantiate MemNet(1, 64, 6, 6) and print its module summary."""
    print("===> Building model")
    net = MemNet(1, 64, 6, 6)
    print(net)
예제 #3
0
import torch
from torch.autograd import Variable

#from memnet1 import MemNet
from memnet1 import MemNet
from visualize_net import  make_dot

# Trace MemNet on a dummy input and visualise its autograd graph.
# The dummy batch's channel dimension must equal MemNet's first argument
# (the network's input channel count), so both read the same value.
in_channels = 1
x = Variable(torch.randn(1, in_channels, 31, 31))
# NOTE(review): 65 memory blocks differs from the MemNet(1, 64, 6, 6)
# configuration used elsewhere in this project -- confirm this is intended.
model = MemNet(in_channels, 64, 65, 6)
y = model(x)
# Build the graph for the output (presumably a graphviz object) and display it.
g = make_dot(y)
g.view()
예제 #4
0
    img[:, :, 2] = ycbcr[:, :, 2]
    img = Image.fromarray(img, "YCbCr").convert("RGB")
    return img


opt = parser.parse_args()
cuda = opt.cuda

if cuda:
    print("=> use gpu id: '{}'".format(opt.gpus))
    # Restrict which devices CUDA can see; only effective before CUDA init.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    if not torch.cuda.is_available():
        raise Exception(
            "No GPU found or Wrong gpu id, please run without --cuda")

model = MemNet(1, 64, 6, 6)
# map_location keeps every deserialized tensor on the CPU. A checkpoint saved
# from a GPU training run can then be loaded even without --cuda. The model
# itself is built on CPU, so its placement is unchanged.
checkpoint = torch.load(opt.model, map_location=lambda storage, loc: storage)
# The model was trained on multiple GPUs (DataParallel), so its saved keys
# carry a "module." prefix. convert_state_dict() strips that prefix so the
# plain, unwrapped model can load them.
state = convert_state_dict(checkpoint['model'])
model.load_state_dict(state)

# NOTE(review): imread(..., mode=...) matches the scipy.misc.imread API, which
# was removed in newer SciPy releases -- confirm which imread is in scope.
im_gt_ycbcr = imread("data/SuperResolution/Set5/" + opt.image + ".bmp",
                     mode="YCbCr")
im_b_ycbcr = imread("data/SuperResolution/Set5/" + opt.image + "_scale_" +
                    str(opt.scale) + ".bmp",
                    mode="YCbCr")

# Keep only channel 0 (Y, luminance) of each YCbCr image, as floats.
im_gt_y = im_gt_ycbcr[:, :, 0].astype(float)
im_b_y = im_b_ycbcr[:, :, 0].astype(float)

# Baseline PSNR of the low-resolution input against ground truth.
# shave_border presumably excludes an opt.scale-pixel border.
psnr_bicubic = PSNR(im_gt_y, im_b_y, shave_border=opt.scale)