Exemplo n.º 1
0
def main():
    """Train ``Net`` on ``TrainDatasetFromFolder`` using the module-level ``opt``.

    Steps, in order: optional GPU selection, seeding of the torch RNGs with a
    freshly drawn random seed, dataset/model/loss construction, optional
    restore from ``opt.resume`` and/or ``opt.pretrained``, optimizer setup,
    and finally one ``train`` + ``save_checkpoint`` call per epoch.

    Raises:
        Exception: if ``opt.cuda`` is set but ``torch.cuda.is_available()``
            reports no device.
    """
    cuda = opt.cuda
    if cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

    # A new seed is drawn on every run and stored on opt so it gets printed.
    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True

    print("===> Loading datasets")
    train_set = TrainDatasetFromFolder()
    training_data_loader = DataLoader(
        dataset=train_set,
        num_workers=opt.threads,
        batch_size=opt.batchSize,
        shuffle=True,
    )

    print("===> Building model")
    model = Net()
    criterion = nn.MSELoss(reduction='sum')

    print("===> Setting GPU")
    if cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            # Load onto the CPU so checkpoints written on a GPU machine can
            # also be restored without CUDA; load_state_dict then copies the
            # values into the model's own parameters on whatever device it
            # currently lives on.
            checkpoint = torch.load(opt.resume, map_location="cpu")
            opt.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # optionally copy weights from a checkpoint
    if opt.pretrained:
        if os.path.isfile(opt.pretrained):
            print("=> loading model '{}'".format(opt.pretrained))
            # Same reasoning as above: the weights are copied, so loading
            # them on the CPU first is device-independent.
            weights = torch.load(opt.pretrained, map_location="cpu")
            model.load_state_dict(weights['model'].state_dict())
        else:
            print("=> no model found at '{}'".format(opt.pretrained))

    print("===> Setting Optimizer")
    # Three parameter groups; the recon group runs at a tenth of the base
    # learning rate, the other two inherit lr=opt.lr.
    optimizer = optim.Adam([
        {'params': model.extr.parameters()},
        {'params': model.mapping.parameters()},
        {'params': model.recon.parameters(), 'lr': opt.lr * 0.1},
    ], lr=opt.lr)

    print("===> Training")
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        train(training_data_loader, optimizer, model, criterion, epoch)
        save_checkpoint(model, epoch)
Exemplo n.º 2
0
def main():
    """Fine-tune the pre-trained model and plot progress after every epoch.

    Parses the command line into the module-level ``par``, loads the model
    stored in ``./checkpoint/pre_model.pth`` into the module-level ``model``,
    optionally restores weights from ``par.resume``, then for each epoch calls
    ``train``, ``save_checkpoint`` and ``draw``.

    Raises:
        Exception: if ``par.cuda`` is set but no CUDA device is available.
    """
    global par, model
    par = parser.parse_args()
    print(par)

    print("===> 建立模型")
    # Start from the saved pre-trained model rather than a fresh Net().
    # Loaded on the CPU so the file can be read without CUDA; the model is
    # moved to the GPU below when requested.
    model = torch.load("./checkpoint/pre_model.pth", map_location="cpu")["model"]
    criterion = nn.MSELoss(reduction='sum')  # loss function, reduction='sum'
    print("===> 加载数据集")
    train_set = TrainDatasetFromFolder("./data/train_set.h5")
    train_loader = DataLoader(dataset=train_set, num_workers=1,
                              batch_size=par.batch_size, shuffle=True)
    print("===> 设置 GPU")
    cuda = par.cuda
    if cuda:
        if torch.cuda.is_available():
            model.cuda()
            criterion.cuda()
        else:
            raise Exception("没有可用的显卡设备")

    # optionally resume from a checkpoint
    if par.resume:
        if os.path.isfile(par.resume):
            checkpoint = torch.load(par.resume, map_location="cpu")
            # NOTE(review): training restarts at the stored epoch itself; if
            # the checkpoint records the last *completed* epoch this repeats
            # it -- confirm whether `+ 1` is intended.
            par.start_epoch = checkpoint['epoch']
            # Fixed: the method is state_dict(); the previous `statedict()`
            # raised AttributeError on every resume.
            model.load_state_dict(checkpoint["model"].state_dict())

    print("===> 设置 优化器")
    optimizer = optim.SGD(model.parameters(), lr=par.lr,
                          momentum=par.momentum,
                          weight_decay=par.weight_decay)

    print("===> 进行训练")
    plt.figure(figsize=(8, 6), dpi=80)
    draw_list = []
    for epoch in range(par.start_epoch, par.nEpochs + 1):
        # train() returns the updated list of values to plot.
        draw_list = train(train_loader, optimizer, model, criterion, epoch,
                          draw_list)
        save_checkpoint(model, epoch)
        draw(range(1, len(draw_list) + 1), 10, draw_list, 10,
             {"EPOCH:": epoch,
              "LR:": round(optimizer.param_groups[0]["lr"], 4)})
    plt.show()
Exemplo n.º 3
0
def main(run):
    """Train an MSE, an L1 and an SSIM model together and record their losses.

    Builds the result directory for this run, seeds the RNGs from
    ``opt.seed``, creates the training/validation loaders, constructs three
    identical networks (``vdsr`` or ``edsr`` depending on ``opt.model``),
    wraps each in a ``Train_op`` and repeatedly calls ``interchange_im``.
    Checkpoints are written periodically and once more after the last epoch;
    the collected losses are saved as a text file in the run directory.

    Args:
        run: Identifier of this run; ``str(run)`` names the sub-directory
            created under the result folder.

    Raises:
        Exception: if ``opt.cuda`` is set but no CUDA device is available.
    """
    print(opt)
    num_gpus = torch.cuda.device_count()
    print("available gpus are", num_gpus)
    num_cpus = multiprocessing.cpu_count()
    print("available cpus are", num_cpus)

    save_dir = './result/{}/{}{}x{:.0e}/FULLTRAIN_SUPER_R{}_RAN{}/'.format(
        opt.train_dir, opt.model, opt.upscale_factor, opt.lr,
        opt.inter_frequency, opt.ranlevel)
    res_dir = save_dir + str(run) + '/'
    # res_dir lives inside save_dir, so one call creates both; exist_ok
    # avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(res_dir, exist_ok=True)

    cuda = opt.cuda
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    torch.manual_seed(opt.seed)
    torch.backends.cudnn.enabled = True
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    np.random.seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    print('===> Loading datasets')

    is_vdsr = opt.model == 'vdsr'

    train_dir = "./data/{}/".format(opt.train_dir)
    train_set = TrainDatasetFromFolder(
        train_dir,
        is_gray=True,
        random_scale=True,
        crop_size=opt.upscale_factor * opt.patchsize,
        rotate=True,
        fliplr=True,
        fliptb=True,
        scale_factor=opt.upscale_factor,
        bic_inp=is_vdsr)
    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=opt.threads,
                                      batch_size=opt.batchSize,
                                      shuffle=True)

    # The validation set is read from the same directory as the training set.
    val_set = TestDatasetFromFolder(
        train_dir,
        is_gray=True,
        scale_factor=opt.upscale_factor,
        bic_inp=is_vdsr)
    validating_data_loader = DataLoader(dataset=val_set,
                                        num_workers=opt.threads,
                                        batch_size=opt.testBatchSize,
                                        shuffle=False)

    def build_net():
        # One fresh network of the architecture selected by opt.model.
        if is_vdsr:
            return vdsr.Net(num_channels=1, base_filter=64, num_residuals=18)
        return edsr.Net(num_channels=1,
                        base_filter=64,
                        num_residuals=18,
                        scale=opt.upscale_factor)

    # Construction order (mse, l1, ssim) is kept so seeded initialisation
    # matches the previous behaviour.
    model_mse = build_net()
    model_l1 = build_net()
    model_ssim = build_net()

    if opt.resume:

        def resume_path(loss_name):
            # Path of the final-epoch checkpoint for the given loss name.
            return res_dir + str(opt.alpha) + "model_{}_{}_epoch_{}.pth".format(
                loss_name, opt.upscale_factor, opt.nEpochs)

        model_mse = load_model(resume_path('mse'))
        model_l1 = load_model(resume_path('l1'))
        model_ssim = load_model(resume_path('ssim'))

    print('===> Building criterions')
    criterion_mse = nn.MSELoss()
    criterion_l1 = nn.L1Loss()
    criterion_ssim = ssim_loss.SSIM(size_average=False)

    print('===> Building optimizers')
    optimizer_mse = optim.Adam(model_mse.parameters(), lr=opt.lr)
    optimizer_l1 = optim.Adam(model_l1.parameters(), lr=opt.lr)
    optimizer_ssim = optim.Adam(model_ssim.parameters(), lr=opt.lr)

    m1 = Train_op(model_mse, optimizer_mse, criterion_mse, 'mse', opt.lr,
                  opt.upscale_factor, opt.cuda)
    m2 = Train_op(model_l1, optimizer_l1, criterion_l1, 'l1', opt.lr,
                  opt.upscale_factor, opt.cuda)
    m3 = Train_op(model_ssim, optimizer_ssim, criterion_ssim, 'ssim', opt.lr,
                  opt.upscale_factor, opt.cuda)

    models = [m1, m2, m3]

    x_label = []
    y_label = []

    # Checkpoint roughly ten times per training. Clamped to at least 1 so
    # that opt.nEpochs < 10 no longer triggers a modulo-by-zero.
    checkpoint_every = max(1, opt.nEpochs // 10)

    print('===> start training')

    for epoch in range(0, opt.nEpochs + 1, opt.inter_frequency):
        tick_time = time.time()
        print('running epoch {}'.format(epoch))

        # Step decay: the rate is multiplied by decay_rate every opt.step epochs.
        lr = opt.lr * (opt.decay_rate**(epoch // opt.step))
        for m in models:
            m.update_lr(lr)
            print('epoch {}, learning rate is {}'.format(epoch, lr))

        update_loss = interchange_im(models, training_data_loader,
                                     validating_data_loader, opt.alpha)

        print('evaluated loss is', update_loss)
        x_label.append(epoch)
        y_label.append(update_loss)

        print('this epoch cost {} seconds.'.format(time.time() - tick_time))

        if epoch % checkpoint_every == 0:
            for m in models:
                m.checkpoint(epoch, res_dir, prefix=str(opt.alpha))

    # Always checkpoint the last epoch that was actually run.
    for m in models:
        m.checkpoint(epoch, res_dir, prefix=str(opt.alpha))

    # save obtained losses: first row holds the epochs, the rows below the
    # transposed loss records.
    x_label = np.asarray(x_label)
    y_label = np.asarray(y_label).transpose()
    output = np.insert(y_label, 0, x_label, axis=0)

    if len(models) > 1:
        np.savetxt(res_dir + str(opt.alpha) + opt.save_file,
                   output,
                   fmt='%3.5f')
    else:
        np.savetxt(res_dir + 'loss_' + models[0].name + '.txt',
                   output,
                   fmt='%3.5f')
Exemplo n.º 4
0
# Evaluate on the test set
import PSNR
import torch
import random
import torch.utils.data as Data
from torchvision.transforms import ToTensor, ToPILImage
from dataset import TrainDatasetFromFolder

# Converters between tensors and PIL images.
tran_im = ToPILImage()
tran_ten = ToTensor()

test_set = TrainDatasetFromFolder(
    "./data/test_set_s3.h5")  # the h5 dataset builder lives in data/; you can create your own
test_loader = Data.DataLoader(dataset=test_set,
                              num_workers=1,
                              batch_size=40,
                              shuffle=False)
net = torch.load("./checkpoint/model_epoch_100.pth")["model"]

net.cpu()
# Pick a random sample that is guaranteed to exist; the previous hard-coded
# randint(0, 100) raised IndexError for datasets with fewer than 101 items.
i = random.randrange(len(test_set))

# Each dataset item is a pair; both elements are converted to PIL images below.
aa, bb = test_loader.dataset[i]

lim1 = tran_im(aa)
label = tran_im(bb)

lim1.show()

# PSNR of the first element of the pair against the second.
print("CUBIC_PSNR:", PSNR.psnr(lim1, label))
Exemplo n.º 5
0
def get_training_set(dir, options):
    """Create the training dataset located under *dir*.

    Args:
        dir: Base directory; ``/train_HR`` and ``/train_LR`` are appended to
            it to form the first two constructor arguments.
        options: Passed through unchanged as the third constructor argument.

    Returns:
        The constructed ``TrainDatasetFromFolder`` instance.
    """
    hr_root = dir + '/train_HR'
    lr_root = dir + '/train_LR'
    return TrainDatasetFromFolder(hr_root, lr_root, options)