Esempio n. 1
0
    def __init__(self, args, conv=common.default_conv):
        """Build a network from a truncated EDSR backbone plus new fine-tuning layers.

        An ``EDSR`` is constructed with ``args.n_resblocks`` reduced by
        ``args.n_resblocks_ft``; its ``sub_mean``, ``head`` and ``add_mean``
        modules are reused, together with its ``body`` minus the last layer.
        ``args.n_resblocks_ft`` new residual blocks (``body_ft``) and a new
        three-conv tail (``tail_ft``) are added on top.

        Args:
            args: option namespace; reads ``n_resblocks``, ``n_resblocks_ft``,
                ``n_feats``, ``res_scale`` and ``n_colors``. ``args.n_resblocks``
                is lowered only while the backbone is built and is always
                restored, even if building the backbone raises.
            conv: convolution factory called as
                ``conv(in_channels, out_channels, kernel_size)``.
        """
        super(HRST_CNN, self).__init__()
        n_feats = args.n_feats
        kernel_size = 3
        act = nn.ReLU(True)

        # New residual blocks stacked after the borrowed backbone, closed by a conv.
        body_ft = [
            common.ResBlock(conv,
                            n_feats,
                            kernel_size,
                            act=act,
                            res_scale=args.res_scale)
            for _ in range(args.n_resblocks_ft)
        ]
        body_ft.append(conv(n_feats, n_feats, kernel_size))

        tail_ft = [
            conv(n_feats, n_feats, kernel_size),
            conv(n_feats, n_feats, kernel_size),
            conv(n_feats, args.n_colors, kernel_size)
        ]

        # Build the backbone with the reduced block count (presumably EDSR reads
        # args.n_resblocks). The finally-clause guarantees the caller's args is
        # never left mutated, even if EDSR construction fails.
        n_resblocks = args.n_resblocks
        args.n_resblocks = n_resblocks - args.n_resblocks_ft
        try:
            premodel = EDSR(args)
        finally:
            args.n_resblocks = n_resblocks

        self.sub_mean = premodel.sub_mean
        self.head = premodel.head
        # Drop the backbone body's last layer; body_ft ends with its own conv.
        body_child = list(premodel.body.children())
        body_child.pop()
        self.body = nn.Sequential(*body_child)
        self.body_ft = nn.Sequential(*body_ft)
        self.tail_ft = nn.Sequential(*tail_ft)
        self.add_mean = premodel.add_mean
# Scratch script: build a super-resolution model and print a torchsummary
# layer table for a (channels, H, W) input.
# Each `model = ...` below overwrites the previous one, so only the LAST
# assignment (currently RDN) is summarised. The earlier constructions still
# execute (and presumably allocate their parameters) before being discarded.
# NOTE(review): comment out the models you do not want to inspect, as is
# already done for SAN and RCAN.
from model.ssl import EDSR_Zoom
from option import args
model = EDSR_Zoom(args)

from model.edsr import EDSR
model = EDSR(args)
#
from model.rdn import RDN
model = RDN(args)
#
# from model.san import SAN
# model = SAN(args)
#
# from model.rcan import RCAN
# model = RCAN(args)

# model.cuda() requires an available CUDA device.
model = model.cuda()
from torchsummary import summary
# Summary input: 3-channel 32x32 patch.
channels = 3
H = 32
W = 32
summary(model, input_size=(channels, H, W))
Esempio n. 3
0
def get_EDSR(save_path="../../checkpoints/edsr/edsr_baseline_x2-1bc95232.pt"):
    """Build an EDSR from the module-level ``args`` and load pretrained weights.

    Args:
        save_path: path of the state-dict checkpoint to load, relative to the
            current working directory when not absolute. Defaults to the x2
            baseline checkpoint this function previously hard-coded.

    Returns:
        The ``EDSR`` instance with the checkpoint's weights loaded.
    """
    edsr = EDSR(args)
    # Load tensors onto the CPU so a checkpoint saved from a GPU also loads on
    # CPU-only hosts; load_state_dict copies the values into the model's own
    # parameters regardless of where they were loaded.
    load_stat = torch.load(save_path, map_location="cpu")
    edsr.load_state_dict(load_stat)
    return edsr
Esempio n. 4
0
    def __init__(self, args, conv=common.default_conv):
        """Build a network on an EDSR backbone with extra fine-tuning stacks.

        An ``EDSR`` is constructed while ``args.n_resblocks`` is temporarily
        reduced by ``args.n_resblocks_ft``. Its ``sub_mean``, ``head``,
        ``body`` and ``add_mean`` modules are reused. New ``body_ft``,
        ``tail_ft1`` and ``tail_ft2`` stacks are created, whose channel
        widths depend on whether ``args.normal_lr == 'lr'``.

        Args:
            args: option namespace; reads ``n_resblocks``, ``n_resblocks_ft``,
                ``n_feats``, ``scale``, ``n_colors``, ``normal_lr`` and
                ``res_scale``, and is stored on ``self.args``.
                ``args.n_resblocks`` is mutated during construction and
                restored on the last line. It is NOT restored if an exception
                is raised in between.
            conv: convolution factory passed to ``ResBlock`` and
                ``common.Upsampler``.
        """
        super(NHR, self).__init__()
        n_resblocks = args.n_resblocks
        # Lower the block count for the EDSR backbone built below; restored at the end.
        args.n_resblocks = args.n_resblocks - args.n_resblocks_ft
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]
        act = nn.ReLU(True)
        n_color = args.n_colors
        # True only when the option string is exactly 'lr'.
        self.normal_lr = args.normal_lr == 'lr'
        self.args = args
        if self.normal_lr:
            # Fine-tuning blocks run at n_feats + 4 channels and project back to n_feats.
            # NOTE(review): presumably 4 extra channels are concatenated onto the
            # backbone features in forward(); confirm there.
            body_ft = [
                ResBlock(conv,
                         n_feats + 4,
                         n_feats + 4,
                         kernel_size,
                         act=act,
                         res_scale=args.res_scale)
                for _ in range(args.n_resblocks_ft)
            ]
            body_ft.append(conv2d(n_feats + 4, n_feats, kernel_size, act=True))

            # Upsample by `scale`, then widen to n_feats + 4 for tail_ft2.
            tail_ft1 = [
                common.Upsampler(conv, scale, n_feats, act=True),
                conv2d(n_feats, n_feats + 4, kernel_size, act=True),
            ]
            # Final projection to n_color output channels (no activation on the last conv).
            tail_ft2 = [
                conv2d(n_feats + 4, n_feats + 4, kernel_size, act=True),
                conv2d(n_feats + 4, n_feats + 4, kernel_size, act=True),
                conv2d(n_feats + 4, n_feats + 4, kernel_size, act=True),
                conv2d(n_feats + 4, n_color, kernel_size, act=False)
            ]
        else:
            # NOTE(review): exactly two ResBlocks are hard-coded here, so in this
            # branch args.n_resblocks_ft only affects the backbone depth.
            body_ft = [
                ResBlock(conv,
                         n_feats,
                         n_feats,
                         kernel_size,
                         act=act,
                         res_scale=args.res_scale),
                ResBlock(conv,
                         n_feats,
                         n_feats,
                         kernel_size,
                         act=act,
                         res_scale=args.res_scale)
            ]
            # Leftover from an n_feats + 4 variant of the list above:
            # ResBlock(conv, n_feats+4, n_feats+4, kernel_size, act=act, res_scale=args.res_scale)
            #]
            body_ft.append(conv2d(n_feats, n_feats, kernel_size, act=True))

            # Unlike the other branch, tail_ft1 ends with n_feats channels while
            # tail_ft2 expects n_feats + 4.
            # NOTE(review): presumably forward() concatenates 4 channels between
            # the two tails in this mode; confirm.
            tail_ft1 = [
                common.Upsampler(conv, scale, n_feats, act=True),
                conv2d(n_feats, n_feats, kernel_size, act=True),
            ]
            tail_ft2 = [
                conv2d(n_feats + 4, n_feats + 4, kernel_size, act=True),
                conv2d(n_feats + 4, n_feats + 4, kernel_size, act=True),
                conv2d(n_feats + 4, n_feats + 4, kernel_size, act=True),
                conv2d(n_feats + 4, n_color, kernel_size, act=False)
            ]

        premodel = EDSR(args)
        self.sub_mean = premodel.sub_mean
        self.head = premodel.head
        body = premodel.body
        body_child = list(body.children())
        body_child.pop()
        # NOTE(review): body_child (the body minus its last layer) is computed
        # but never used; the FULL premodel.body is kept below. Confirm whether
        # a trimmed nn.Sequential(*body_child) was intended.
        self.body = premodel.body
        self.body_ft = nn.Sequential(*body_ft)
        self.tail_ft1 = nn.Sequential(*tail_ft1)
        self.tail_ft2 = nn.Sequential(*tail_ft2)
        self.add_mean = premodel.add_mean
        # Restore the caller's value that was lowered at the top.
        args.n_resblocks = n_resblocks
Esempio n. 5
0
def eval_path(args):
    """Run EDSR inference over every sequence folder under ``args.test_lr``.

    For each sub-directory of ``args.test_lr``, all ``*.png`` frames are
    passed through the model. Results are written with ``cv2.imwrite`` to
    ``args.outputs_dir/<sub-directory>/<file name>``. Sub-directories whose
    output folder already holds exactly 100 entries are treated as finished
    and skipped.

    Args:
        args: option namespace; reads ``outputs_dir``, ``seed``, ``gpus``,
            ``resume``, ``test_lr``, ``batch_size`` and ``workers`` (plus
            whatever ``EDSR`` reads). ``args.resume`` may name a checkpoint
            file or a directory. For a directory, the last ``*.pth`` in sorted
            order is used and written back to ``args.resume``.
    """
    os.makedirs(args.outputs_dir, exist_ok=True)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    device_ids = list(range(args.gpus))
    model = EDSR(args)
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()

    if args.resume:
        if os.path.isdir(args.resume):
            # Use the last checkpoint in the directory (sorted by file name).
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']
            # The checkpoint holds the unwrapped model's keys, so ADD the
            # `module.` prefix that the nn.DataParallel wrapper expects.
            new_state_dict = OrderedDict(
                ('module.' + k, v) for k, v in state_dict.items())
            model.load_state_dict(new_state_dict)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    model.eval()

    lr_list = []
    for one in sorted(os.listdir(args.test_lr)):
        dst_dir = os.path.join(args.outputs_dir, one)
        # An output folder with exactly 100 entries is considered complete.
        if os.path.exists(dst_dir) and len(os.listdir(dst_dir)) == 100:
            continue
        lr_list.extend(sorted(glob(os.path.join(args.test_lr, one, '*.png'))))

    data_set = EvalDataset(lr_list)
    eval_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             num_workers=args.workers)

    # The loader keeps the final partial batch (drop_last defaults to False)
    # and every image is counted via t.update, so the total is the full length.
    with tqdm(total=len(data_set)) as t:
        for inputs, names in eval_loader:
            inputs = inputs.cuda()
            with torch.no_grad():
                outputs = model(inputs).data.float().cpu().clamp_(0, 255).numpy()
            for img, file in zip(outputs, names):
                # Reverse the channel order (presumably RGB -> BGR for cv2),
                # move channels last, then round and store as 8-bit; values
                # were already clamped to [0, 255] above.
                img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))
                img = img.round().astype(np.uint8)

                # Mirror the input layout: <outputs_dir>/<sequence>/<frame>,
                # using os.path so separators work on every platform.
                seq_name = os.path.basename(os.path.dirname(file))
                dst_dir = os.path.join(args.outputs_dir, seq_name)
                os.makedirs(dst_dir, exist_ok=True)
                dst_name = os.path.join(dst_dir, os.path.basename(file))

                cv2.imwrite(dst_name, img)
            t.update(len(names))
Esempio n. 6
0
def main():
    """Train (or only test) a six-network model in two consecutive steps.

    Reads the global option objects ``opt`` and ``args`` and rebinds the
    global ``model`` to an ``nn.ModuleList`` laid out as:
    0 GDN, 1 GHR (EDSR), 2 GLR, 3 DLR, 4 DHR, 5 GNO.

    GHR is initialised from ``../experiment/model/EDSR_baseline_x{scale}.pt``.
    The whole model may be resumed from ``model_total_{scale}.pth`` in the
    working directory. With ``opt.test_only`` set, only ``test`` runs.
    Otherwise step 1 (``train(..., False)``) and then step 2
    (``train(..., True)`` after setting ``opt.lr = 1e-4``) each run over
    epochs ``opt.start_epoch .. opt.epochs``.

    Raises:
        Exception: if ``opt.cuda`` is set but no CUDA device is available.
    """
    global opt, model
    if opt.cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpus)
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # A fresh random seed per run (recorded on opt and printed).
    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if opt.cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True
    # NOTE(review): the scale is read from the global `args`, not `opt`;
    # confirm both option objects agree.
    scale = int(args.scale[0])
    print("===> Loading datasets")

    # Two 400-image training loaders; the second uses offset_train = 400.
    # NOTE(review): presumably that selects the next 400 images; the exact
    # semantics live in data.Data.
    opt.n_train = 400
    loader = data.Data(opt)
    opt_high = copy.deepcopy(opt)
    opt_high.offset_train = 400
    opt_high.n_train = 400

    loader_high = data.Data(opt_high)

    training_data_loader = loader.loader_train
    training_high_loader = loader_high.loader_train
    test_data_loader = loader.loader_test

    print("===> Building model")
    GLR = _NetG_DOWN(stride=2)  #EDSR(args)
    GHR = EDSR(
        args)  #_NetG_UP()#Generator(G_input_dim, num_filters, G_output_dim)
    GDN = _NetG_DOWN(stride=1)  #EDSR(args)
    DLR = _NetD(
        stride=1
    )  # True)# _NetD(3)#Generator(G_input_dim, num_filters, G_output_dim)
    DHR = _NetD(stride=2)  #Generator(G_input_dim, num_filters, G_output_dim)
    GNO = _NetG_DOWN(stride=1)  #EDSR(args)

    # Initialise the EDSR generator from the pretrained baseline weights.
    Loaded = torch.load(
        '../experiment/model/EDSR_baseline_x{}.pt'.format(scale))
    GHR.load_state_dict(Loaded)

    # Index order matters: callers address sub-networks by position.
    model = nn.ModuleList()

    model.append(GDN)  #DN
    model.append(GHR)
    model.append(GLR)  #LR
    model.append(DLR)
    model.append(DHR)
    model.append(GNO)  #

    print(model)

    # NOTE(review): redundant, already set above.
    cudnn.benchmark = True
    # optionally resume from a checkpoint
    # NOTE(review): this unconditionally overrides any opt.resume value.
    opt.resume = 'model_total_{}.pth'.format(scale)
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            # The checkpoint stores the whole module object under "model".
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # model[4] = _NetD_(4, True)#, True, 4)

    print("===> Setting GPU")
    if opt.cuda:
        model = model.cuda()

    print("===> Setting Optimizer")
    # optimizer = optim.Adam(model.parameters(), lr=opt.lr)#, momentum=opt.momentum, weight_decay=opt.weight_decay)

    if opt.test_only:
        print('===> Testing')
        test(test_data_loader, model, opt.start_epoch)
        return

    print("===> Training Step 1.")
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        train(training_data_loader, training_high_loader, model, epoch, False)
        save_checkpoint(model, epoch, scale)
        test(test_data_loader, model, epoch)

    print("===> Training Step 2.")
    opt.lr = 1e-4
    # NOTE(review): step 2 reuses the same epoch numbers as step 1, so
    # save_checkpoint(model, epoch, scale) may overwrite step-1 checkpoints;
    # confirm against save_checkpoint.
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        train(training_data_loader, training_high_loader, model, epoch, True)
        save_checkpoint(model, epoch, scale)
        test(test_data_loader, model, epoch)
Esempio n. 7
0
def main():
    """Train (or only test) a six-network model, resuming across two steps.

    Reads the global option objects ``opt`` and ``args`` and rebinds the
    global ``model`` to an ``nn.ModuleList`` laid out as:
    0 GDN, 1 GHR (EDSR), 2 GLR, 3 DLR, 4 DHR, 5 GNO.

    ``optG`` optimises indices 0, 1, 2 and 5; ``optD`` optimises 3 and 4.
    The model and both optimizers may be resumed from
    ``model_total_{scale}.pth`` in the working directory. Epochs up to
    ``opt.epochs`` belong to step 1 (after which ``backup.pt`` is written).
    Epochs ``opt.epochs + 1 .. 2 * opt.epochs`` belong to step 2. The step
    is chosen from the resumed ``opt.start_epoch``.

    Raises:
        Exception: if ``opt.cuda`` is set but no CUDA device is available.
    """
    global opt, model
    if opt.cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpus)
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # A fresh random seed per run (recorded on opt and printed).
    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if opt.cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True
    # NOTE(review): the scale is read from the global `args`, not `opt`;
    # confirm both option objects agree.
    scale = int(args.scale[0])
    print("===> Loading datasets")

    # Two 400-image training loaders; the second uses offset_train = 400
    # (presumably the next 400 images; semantics live in data.Data).
    opt.n_train = 400
    loader = data.Data(opt)
    opt_high = copy.deepcopy(opt)
    opt_high.offset_train = 400
    opt_high.n_train = 400

    loader_high = data.Data(opt_high)

    training_data_loader = loader.loader_train
    training_high_loader = loader_high.loader_train
    test_data_loader = loader.loader_test

    print("===> Building model")
    GLR = _NetG_DOWN(stride=2)
    GHR = EDSR(args)
    GDN = _NetG_DOWN(stride=1)
    DLR = _NetD(stride=1)
    DHR = _NetD(stride=2)
    GNO = _NetG_DOWN(stride=1)

    # Initialise the EDSR generator from the pretrained baseline weights.
    Loaded = torch.load(
        '../experiment/model/EDSR_baseline_x{}.pt'.format(scale))
    GHR.load_state_dict(Loaded)

    # Index order matters: the optimizers below address sub-networks by position.
    model = nn.ModuleList([GDN, GHR, GLR, DLR, DHR, GNO])

    print("===> Setting GPU")
    if opt.cuda:
        model = model.cuda()

    # Generators (0, 1, 2, 5) and discriminators (3, 4) get separate optimizers.
    optG = torch.optim.Adam(
        list(model[0].parameters()) + list(model[1].parameters()) +
        list(model[2].parameters()) + list(model[5].parameters()),
        lr=opt.lr,
        weight_decay=0)
    optD = torch.optim.Adam(list(model[3].parameters()) +
                            list(model[4].parameters()),
                            lr=opt.lr,
                            weight_decay=0)

    # Optionally resume from a checkpoint.
    # NOTE(review): this unconditionally overrides any opt.resume value.
    opt.resume = 'model_total_{}.pth'.format(scale)
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            # First epoch still to run (the checkpoint presumably records the
            # last completed epoch).
            opt.start_epoch = checkpoint["epoch"] + 1

            optG.load_state_dict(checkpoint['optimizer'][0])
            optD.load_state_dict(checkpoint['optimizer'][1])
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    step = 2 if opt.start_epoch > opt.epochs else 1

    optimizer = [optG, optD]

    if opt.test_only:
        print('===> Testing')
        test(test_data_loader, model, opt.start_epoch)
        return

    if step == 1:
        print("===> Training Step 1.")
        for epoch in range(opt.start_epoch, opt.epochs + 1):
            train(training_data_loader, training_high_loader, model, optimizer,
                  epoch, False)
            save_checkpoint(model, optimizer, epoch, scale)
            test(test_data_loader, model, epoch)
        torch.save(model.state_dict(), 'backup.pt')
    elif step == 2:
        print("===> Training Step 2.")
        # NOTE(review): optG/optD were built with (or restored to) the earlier
        # learning rate; this only takes effect if train() re-reads opt.lr.
        opt.lr = 1e-4
        # opt.start_epoch is already checkpoint epoch + 1; starting from
        # opt.start_epoch + 1 (as before) silently skipped one epoch on every
        # resume into step 2.
        for epoch in range(opt.start_epoch, opt.epochs * 2 + 1):
            train(training_data_loader, training_high_loader, model, optimizer,
                  epoch, True)
            save_checkpoint(model, optimizer, epoch, scale)
            test(test_data_loader, model, epoch)
Esempio n. 8
0
def main(args):
    """Train EDSR on paired LR/HR PNG frames, checkpointing after every epoch.

    Frames are gathered from each sub-directory of ``args.data_lr`` and from
    the same-named sub-directory of ``args.data_hr``, both sorted by file
    name. HR folders that do not contain exactly 100 frames are printed.
    Checkpoints are written to ``args.checkpoint`` as dicts with keys
    ``'state_dict'`` (unwrapped model weights), ``'epoch'`` and ``'lr'``.

    Args:
        args: option namespace; reads ``data_lr``, ``data_hr``,
            ``patch_size``, ``scale``, ``batch_size``, ``workers``, ``seed``,
            ``gpus``, ``start_epoch``, ``resume``, ``lr``, ``weight_decay``,
            ``epochs`` and ``checkpoint``. ``args.resume`` may be a checkpoint
            file or a directory (the last ``*.pth`` in sorted order is used).
            ``args.lr`` is overwritten by the checkpoint's ``'lr'`` when present.
    """
    print("===> Loading datasets")
    lr_list = []
    hr_list = []
    for one in sorted(os.listdir(args.data_lr)):
        lr_list.extend(sorted(glob(os.path.join(args.data_lr, one, '*.png'))))
        hr_tmp = sorted(glob(os.path.join(args.data_hr, one, '*.png')))
        # Report HR folders without the expected 100 frames.
        # NOTE(review): presumably DatasetLoader pairs LR/HR by list index, in
        # which case a short folder misaligns every later pair.
        if len(hr_tmp) != 100:
            print(one)
        hr_list.extend(hr_tmp)

    data_set = DatasetLoader(lr_list, hr_list, args.patch_size, args.scale)
    data_len = len(data_set)
    train_loader = DataLoader(data_set,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True)

    print("===> Building model")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    device_ids = list(range(args.gpus))
    model = EDSR(args)
    criterion = nn.L1Loss(reduction='sum')

    print("===> Setting GPU")
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    criterion = criterion.cuda()

    start_epoch = args.start_epoch
    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isdir(args.resume):
            # Use the last checkpoint in the directory (sorted by file name).
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            start_epoch = checkpoint['epoch'] + 1
            # Checkpoints hold the unwrapped model's keys (saved from
            # model.module below), so ADD the `module.` prefix that the
            # nn.DataParallel wrapper expects.
            new_state_dict = OrderedDict(
                ('module.' + k, v) for k, v in checkpoint['state_dict'].items())
            model.load_state_dict(new_state_dict)

            # A learning rate stored in the checkpoint takes precedence over
            # the command-line value.
            args.lr = checkpoint.get('lr', args.lr)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.start_epoch != 0:
        # An explicitly set start_epoch overrides the checkpoint's epoch.
        start_epoch = args.start_epoch

    print("===> Setting Optimizer")
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr,
                           weight_decay=args.weight_decay,
                           betas=(0.9, 0.999),
                           eps=1e-08)

    print("===> Training")
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(optimizer, epoch)

        losses, psnrs = one_epoch_train_logger(model, optimizer, criterion,
                                               data_len, train_loader, epoch,
                                               args.epochs, args.batch_size,
                                               optimizer.param_groups[0]["lr"])

        # Save a checkpoint for this epoch; the name embeds the average
        # loss and PSNR reported by the training logger.
        model_out_path = os.path.join(
            args.checkpoint, "model_epoch_%04d_edsr_loss_%.3f_psnr_%.3f.pth" %
            (epoch, losses.avg, psnrs.avg))
        os.makedirs(args.checkpoint, exist_ok=True)
        torch.save(
            {
                'state_dict': model.module.state_dict(),
                "epoch": epoch,
                'lr': optimizer.param_groups[0]["lr"]
            }, model_out_path)
Esempio n. 9
0
                                  shuffle=True)
# Test loader: fixed order so evaluation results are reproducible.
testing_data_loader = DataLoader(dataset=test_set,
                                 num_workers=threads,
                                 batch_size=testBatchSize,
                                 shuffle=False)

if args.verbose:
    print("{} training images / {} testing images".format(
        len(train_set), len(test_set)))
    print("===> dataset loaded !")
    print('===> Building model')

# EDSR variant taking (upscale factor, input channels, target channels)
# positionally; residual scaling 0.1, no batch norm.
model = EDSR(upscale_factor,
             input_channels,
             target_channels,
             n_resblocks=n_resblocks,
             n_feats=n_patch_features,
             res_scale=.1,
             bn=None).to(device)

criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=lr)

if args.verbose:
    # model.parameters() is a generator, so printing it only shows an object
    # repr; report the total number of parameter elements instead.
    print("{} parameters".format(sum(p.numel() for p in model.parameters())))
    print(model)
    print("==> Model built")


def train(epoch):
    epoch_loss = 0