Example 1
def get_EDSR():
    # Build an EDSR model and load the pretrained x2 baseline weights.
    edsr = EDSR(args)
    save_path = "../../checkpoints/edsr/edsr_baseline_x2-1bc95232.pt"
    load_stat = torch.load(save_path)
    edsr.load_state_dict(load_stat)
    return edsr
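A minimal usage sketch for the loader above, assuming `EDSR` and the global `args` come from the EDSR-PyTorch codebase and the checkpoint file exists at the given path:

import torch

model = get_EDSR()
model.eval()
dummy_lr = torch.rand(1, 3, 48, 48) * 255  # EDSR baseline operates on 0-255 RGB
with torch.no_grad():
    sr = model(dummy_lr)
print(sr.shape)  # (1, 3, 96, 96) for the x2 model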
Example 2
def eval_path(args):
    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    device_ids = list(range(args.gpus))
    model = EDSR(args)
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()

    if args.resume:
        if os.path.isdir(args.resume):
            # pick the most recent checkpoint in the directory
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                namekey = 'module.' + k  # prepend 'module.' to match the DataParallel wrapper
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    model.eval()

    file_names = sorted(os.listdir(args.test_lr))
    lr_list = []
    for one in file_names:
        dst_dir = os.path.join(args.outputs_dir, one)
        if os.path.exists(dst_dir) and len(os.listdir(dst_dir)) == 100:
            # all 100 output frames for this clip already exist; skip it
            continue
        lr_tmp = sorted(glob(os.path.join(args.test_lr, one, '*.png')))
        lr_list.extend(lr_tmp)

    data_set = EvalDataset(lr_list)
    eval_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             num_workers=args.workers)

    with tqdm(total=(len(data_set) - len(data_set) % args.batch_size)) as t:
        for data in eval_loader:
            inputs, names = data
            inputs = inputs.cuda()
            with torch.no_grad():
                outputs = model(inputs).data.float().cpu().clamp_(0, 255).numpy()
            for img, file in zip(outputs, names):
                # reorder channels RGB -> BGR and transpose CHW -> HWC for cv2.imwrite
                img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))
                img = img.round()

                arr = file.split('/')
                dst_dir = os.path.join(args.outputs_dir, arr[-2])
                if not os.path.exists(dst_dir):
                    os.makedirs(dst_dir)
                dst_name = os.path.join(dst_dir, arr[-1])

                cv2.imwrite(dst_name, img)
            t.update(len(names))
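The `'module.' + k` remapping above compensates for the `nn.DataParallel` wrapper, whose state_dict keys carry a `module.` prefix. A small helper sketch (the name `remap_module_prefix` is hypothetical) that converts keys in either direction:

from collections import OrderedDict

def remap_module_prefix(state_dict, add_prefix=True):
    # add_prefix=True:  plain checkpoint -> DataParallel-wrapped model
    # add_prefix=False: DataParallel checkpoint -> plain model
    new_state = OrderedDict()
    for k, v in state_dict.items():
        if add_prefix and not k.startswith('module.'):
            new_state['module.' + k] = v
        elif not add_prefix and k.startswith('module.'):
            new_state[k[len('module.'):]] = v
        else:
            new_state[k] = v
    return new_state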
Example 3
def main():
    global opt, model
    if opt.cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpus)
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if opt.cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True
    scale = int(args.scale[0])
    print("===> Loading datasets")

    opt.n_train = 400
    loader = data.Data(opt)
    opt_high = copy.deepcopy(opt)
    opt_high.offset_train = 400
    opt_high.n_train = 400

    loader_high = data.Data(opt_high)

    training_data_loader = loader.loader_train
    training_high_loader = loader_high.loader_train
    test_data_loader = loader.loader_test

    print("===> Building model")
    GLR = _NetG_DOWN(stride=2)
    GHR = EDSR(args)
    GDN = _NetG_DOWN(stride=1)
    DLR = _NetD(stride=1)
    DHR = _NetD(stride=2)
    GNO = _NetG_DOWN(stride=1)

    Loaded = torch.load(
        '../experiment/model/EDSR_baseline_x{}.pt'.format(scale))
    GHR.load_state_dict(Loaded)

    model = nn.ModuleList()

    model.append(GDN)  # model[0]: DN generator
    model.append(GHR)  # model[1]: EDSR (HR generator)
    model.append(GLR)  # model[2]: LR generator
    model.append(DLR)  # model[3]: LR discriminator
    model.append(DHR)  # model[4]: HR discriminator
    model.append(GNO)  # model[5]: NO generator

    cudnn.benchmark = True

    print("===> Setting GPU")
    if opt.cuda:
        model = model.cuda()

    # one Adam optimizer for the generators (model[0], [1], [2], [5]),
    # another for the two discriminators (model[3], [4])
    optG = torch.optim.Adam(
        list(model[0].parameters()) + list(model[1].parameters()) +
        list(model[2].parameters()) + list(model[5].parameters()),
        lr=opt.lr,
        weight_decay=0)
    optD = torch.optim.Adam(list(model[3].parameters()) +
                            list(model[4].parameters()),
                            lr=opt.lr,
                            weight_decay=0)

    # optionally resume from a checkpoint
    opt.resume = 'model_total_{}.pth'.format(scale)
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1

            optG.load_state_dict(checkpoint['optimizer'][0])
            optD.load_state_dict(checkpoint['optimizer'][1])
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # opt.start_epoch = 401
    step = 2 if opt.start_epoch > opt.epochs else 1

    # model.load_state_dict(torch.load('backup.pt'))

    optimizer = [optG, optD]

    # print("===> Setting Optimizer")

    if opt.test_only:
        print('===> Testing')
        test(test_data_loader, model, opt.start_epoch)
        return

    if step == 1:
        print("===> Training Step 1.")
        for epoch in range(opt.start_epoch, opt.epochs + 1):
            train(training_data_loader, training_high_loader, model, optimizer,
                  epoch, False)
            save_checkpoint(model, optimizer, epoch, scale)
            test(test_data_loader, model, epoch)
        torch.save(model.state_dict(), 'backup.pt')
    elif step == 2:
        print("===> Training Step 2.")
        opt.lr = 1e-4
        for epoch in range(opt.start_epoch + 1, opt.epochs * 2 + 1):
            train(training_data_loader, training_high_loader, model, optimizer,
                  epoch, True)
            save_checkpoint(model, optimizer, epoch, scale)
            test(test_data_loader, model, epoch)
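`save_checkpoint` is not shown in this snippet. For the resume block above to round-trip, it has to store `epoch`, a two-element `optimizer` list of state dicts, and the model object itself; a sketch consistent with those keys (the original helper may differ):

def save_checkpoint(model, optimizer, epoch, scale):
    # Keys mirror what the resume code reads:
    # checkpoint["epoch"], checkpoint["optimizer"][0]/[1], checkpoint["model"].state_dict()
    state = {
        "epoch": epoch,
        "optimizer": [optimizer[0].state_dict(), optimizer[1].state_dict()],
        "model": model,
    }
    torch.save(state, 'model_total_{}.pth'.format(scale))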
Example 4
def main():
    global opt, model
    if opt.cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpus)
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if opt.cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True
    scale = int(args.scale[0])
    print("===> Loading datasets")

    opt.n_train = 400
    loader = data.Data(opt)
    opt_high = copy.deepcopy(opt)
    opt_high.offset_train = 400
    opt_high.n_train = 400

    loader_high = data.Data(opt_high)

    training_data_loader = loader.loader_train
    training_high_loader = loader_high.loader_train
    test_data_loader = loader.loader_test

    print("===> Building model")
    GLR = _NetG_DOWN(stride=2)
    GHR = EDSR(args)
    GDN = _NetG_DOWN(stride=1)
    DLR = _NetD(stride=1)
    DHR = _NetD(stride=2)
    GNO = _NetG_DOWN(stride=1)

    Loaded = torch.load(
        '../experiment/model/EDSR_baseline_x{}.pt'.format(scale))
    GHR.load_state_dict(Loaded)

    model = nn.ModuleList()

    model.append(GDN)  # model[0]: DN generator
    model.append(GHR)  # model[1]: EDSR (HR generator)
    model.append(GLR)  # model[2]: LR generator
    model.append(DLR)  # model[3]: LR discriminator
    model.append(DHR)  # model[4]: HR discriminator
    model.append(GNO)  # model[5]: NO generator

    print(model)

    cudnn.benchmark = True
    # optionally resume from a checkpoint
    opt.resume = 'model_total_{}.pth'.format(scale)
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # model[4] = _NetD_(4, True)#, True, 4)

    print("===> Setting GPU")
    if opt.cuda:
        model = model.cuda()

    print("===> Setting Optimizer")
    # optimizer = optim.Adam(model.parameters(), lr=opt.lr)#, momentum=opt.momentum, weight_decay=opt.weight_decay)

    if opt.test_only:
        print('===> Testing')
        test(test_data_loader, model, opt.start_epoch)
        return

    print("===> Training Step 1.")
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        train(training_data_loader, training_high_loader, model, epoch, False)
        save_checkpoint(model, epoch, scale)
        test(test_data_loader, model, epoch)

    print("===> Training Step 2.")
    opt.lr = 1e-4
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        train(training_data_loader, training_high_loader, model, epoch, True)
        save_checkpoint(model, epoch, scale)
        test(test_data_loader, model, epoch)
Example 5
def main(args):
    print("===> Loading datasets")
    file_name = sorted(os.listdir(args.data_lr))
    lr_list = []
    hr_list = []
    for one in file_name:
        lr_tmp = sorted(glob(os.path.join(args.data_lr, one, '*.png')))
        lr_list.extend(lr_tmp)
        hr_tmp = sorted(glob(os.path.join(args.data_hr, one, '*.png')))
        if len(hr_tmp) != 100:
            print(one)  # warn: this clip does not have the expected 100 HR frames
        hr_list.extend(hr_tmp)

    # lr_list = glob(os.path.join(args.data_lr, '*'))
    # hr_list = glob(os.path.join(args.data_hr, '*'))
    # lr_list = lr_list[0:max_index]
    # hr_list = hr_list[0:max_index]

    data_set = DatasetLoader(lr_list, hr_list, args.patch_size, args.scale)
    data_len = len(data_set)
    train_loader = DataLoader(data_set,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True)

    print("===> Building model")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    device_ids = list(range(args.gpus))
    model = EDSR(args)
    criterion = nn.L1Loss(reduction='sum')

    print("===> Setting GPU")
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    criterion = criterion.cuda()

    start_epoch = args.start_epoch
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isdir(args.resume):
            # pick the most recent checkpoint in the directory
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            start_epoch = checkpoint['epoch'] + 1
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                namekey = 'module.' + k  # prepend 'module.' to match the DataParallel wrapper
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)

            # if the checkpoint stores a learning rate, prefer it over the command-line value
            args.lr = checkpoint.get('lr', args.lr)

    if args.start_epoch != 0:
        # an explicitly set start_epoch overrides the epoch restored from the checkpoint
        start_epoch = args.start_epoch

    print("===> Setting Optimizer")
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr,
                           weight_decay=args.weight_decay,
                           betas=(0.9, 0.999),
                           eps=1e-08)

    # record = []
    print("===> Training")
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(optimizer, epoch)

        losses, psnrs = one_epoch_train_logger(model, optimizer, criterion,
                                               data_len, train_loader, epoch,
                                               args.epochs, args.batch_size,
                                               optimizer.param_groups[0]["lr"])

        # lr = optimizer.param_groups[0]["lr"]
        # the_lr = 1e-2
        # lr_len = 2
        # while lr + (1e-9) < the_lr:
        #     the_lr *= 0.1
        #     lr_len += 1
        # record.append([losses.avg,psnrs.avg,lr_len])

        # save model
        # if epoch+1 != args.epochs:
        #     continue
        model_out_path = os.path.join(
            args.checkpoint, "model_epoch_%04d_edsr_loss_%.3f_psnr_%.3f.pth" %
            (epoch, losses.avg, psnrs.avg))
        if not os.path.exists(args.checkpoint):
            os.makedirs(args.checkpoint)
        torch.save(
            {
                'state_dict': model.module.state_dict(),
                "epoch": epoch,
                'lr': optimizer.param_groups[0]["lr"]
            }, model_out_path)
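`adjust_lr` is not defined in this snippet; below is a plain step-decay sketch consistent with how it is called (the base rate and decay schedule are assumptions, not values from the source):

def adjust_lr(optimizer, epoch, base_lr=1e-4, step=200, gamma=0.5):
    # Hypothetical schedule: halve the learning rate every `step` epochs.
    lr = base_lr * (gamma ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr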