# Example 1
def eval_path(args):
    """Run the EDSR model over every low-res test image and write SR outputs.

    Args:
        args: namespace providing ``outputs_dir``, ``seed``, ``gpus``,
            ``resume``, ``test_lr``, ``batch_size`` and ``workers``.

    Side effects:
        Creates ``args.outputs_dir`` (and one sub-directory per input
        sequence) and writes one PNG per input frame via ``cv2.imwrite``.
    """
    os.makedirs(args.outputs_dir, exist_ok=True)

    # Seed every RNG so evaluation is reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    device_ids = list(range(args.gpus))
    model = EDSR(args)
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()

    if args.resume:
        if os.path.isdir(args.resume):
            # A directory was given: resume from the lexically last .pth in it.
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location='cpu' lets a checkpoint saved on any GPU load here.
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            # The checkpoint stores the bare model's keys; prepend 'module.'
            # so they match the DataParallel-wrapped model.
            new_state_dict = OrderedDict(
                ('module.' + k, v) for k, v in state_dict.items())
            model.load_state_dict(new_state_dict)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    model.eval()

    # Collect inputs, skipping sequences whose output dir already holds all
    # 100 frames (sequences are assumed to be 100 frames long — TODO confirm).
    file_names = sorted(os.listdir(args.test_lr))
    lr_list = []
    for one in file_names:
        dst_dir = os.path.join(args.outputs_dir, one)
        if os.path.exists(dst_dir) and len(os.listdir(dst_dir)) == 100:
            continue
        lr_list.extend(sorted(glob(os.path.join(args.test_lr, one, '*.png'))))

    data_set = EvalDataset(lr_list)
    eval_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             num_workers=args.workers)

    with tqdm(total=(len(data_set) - len(data_set) % args.batch_size)) as t:
        for inputs, names in eval_loader:
            inputs = inputs.cuda()
            with torch.no_grad():
                # `.data` is deprecated; under no_grad the output is already
                # detached. Pixel values are clamped to [0, 255].
                outputs = model(inputs).float().cpu().clamp_(0, 255).numpy()
            for img, file in zip(outputs, names):
                # CHW -> HWC with channels reversed (presumably RGB -> BGR
                # for cv2.imwrite — verify against EvalDataset's loading).
                img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))
                img = img.round()

                # Mirror the <sequence>/<frame>.png layout of the input.
                arr = file.split('/')
                dst_dir = os.path.join(args.outputs_dir, arr[-2])
                os.makedirs(dst_dir, exist_ok=True)
                dst_name = os.path.join(dst_dir, arr[-1])

                cv2.imwrite(dst_name, img)
            t.update(len(names))
# Example 2
def main(args):
    """Train EDSR on paired LR/HR image folders, checkpointing every epoch.

    Args:
        args: namespace providing ``data_lr``, ``data_hr``, ``patch_size``,
            ``scale``, ``batch_size``, ``workers``, ``seed``, ``gpus``,
            ``resume``, ``start_epoch``, ``epochs``, ``lr``,
            ``weight_decay`` and ``checkpoint``.

    Side effects:
        Creates ``args.checkpoint`` and writes one ``.pth`` file per epoch
        containing the bare (un-wrapped) model weights, epoch and lr.
    """
    print("===> Loading datasets")
    file_name = sorted(os.listdir(args.data_lr))
    lr_list = []
    hr_list = []
    for one in file_name:
        lr_list.extend(sorted(glob(os.path.join(args.data_lr, one, '*.png'))))
        hr_tmp = sorted(glob(os.path.join(args.data_hr, one, '*.png')))
        # Each sequence is expected to hold exactly 100 HR frames; report
        # (but tolerate) sequences that don't.
        if len(hr_tmp) != 100:
            print(one)
        hr_list.extend(hr_tmp)

    data_set = DatasetLoader(lr_list, hr_list, args.patch_size, args.scale)
    data_len = len(data_set)
    train_loader = DataLoader(data_set,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True)

    print("===> Building model")
    # Seed every RNG so training is reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    device_ids = list(range(args.gpus))
    model = EDSR(args)
    criterion = nn.L1Loss(reduction='sum')

    print("===> Setting GPU")
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    criterion = criterion.cuda()

    start_epoch = args.start_epoch
    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isdir(args.resume):
            # A directory was given: resume from the lexically last .pth in it.
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location='cpu' lets a checkpoint saved on any GPU load here.
            checkpoint = torch.load(args.resume, map_location='cpu')

            start_epoch = checkpoint['epoch'] + 1
            state_dict = checkpoint['state_dict']
            # The checkpoint stores the bare model's keys; prepend 'module.'
            # so they match the DataParallel-wrapped model.
            new_state_dict = OrderedDict(
                ('module.' + k, v) for k, v in state_dict.items())
            model.load_state_dict(new_state_dict)

            # Prefer the learning rate stored in the checkpoint, if any.
            args.lr = checkpoint.get('lr', args.lr)

    if args.start_epoch != 0:
        # An explicit --start_epoch overrides the epoch read from the
        # checkpoint.
        start_epoch = args.start_epoch

    print("===> Setting Optimizer")
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr,
                           weight_decay=args.weight_decay,
                           betas=(0.9, 0.999),
                           eps=1e-08)

    print("===> Training")
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(optimizer, epoch)

        losses, psnrs = one_epoch_train_logger(model, optimizer, criterion,
                                               data_len, train_loader, epoch,
                                               args.epochs, args.batch_size,
                                               optimizer.param_groups[0]["lr"])

        # Save the bare (un-wrapped) weights so checkpoints load without
        # DataParallel; loaders re-add the 'module.' prefix (see above).
        model_out_path = os.path.join(
            args.checkpoint, "model_epoch_%04d_edsr_loss_%.3f_psnr_%.3f.pth" %
            (epoch, losses.avg, psnrs.avg))
        os.makedirs(args.checkpoint, exist_ok=True)
        torch.save(
            {
                'state_dict': model.module.state_dict(),
                "epoch": epoch,
                'lr': optimizer.param_groups[0]["lr"]
            }, model_out_path)