# Exemplo n.º 1
# 0
def main():
    """Train SRCNN on (sharp, blurred) image pairs, logging the running loss
    every 10 batches and saving weights every 10 epochs.

    Relies on module-level names: myDataloader, SRCNN, batch_size, lr, epoch.
    """
    dataloaders = myDataloader()
    train_loader = dataloaders.getTrainLoader(batch_size)

    model = SRCNN().cuda()
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    for ep in range(epoch):
        running_loss = 0.0
        for i, (pic, blurPic, _) in enumerate(train_loader):
            pic = pic.cuda()
            blurPic = blurPic.cuda()

            optimizer.zero_grad()
            out = model(blurPic)
            loss = mse_loss(out, pic)
            loss.backward()
            optimizer.step()

            # .item() detaches the scalar; accumulating the loss tensor itself
            # would keep every iteration's autograd graph alive (GPU leak).
            running_loss += loss.item()
            if i % 10 == 9:
                # Average over the 10 batches since the last reset (the
                # original divided by 20, under-reporting the loss by 2x).
                print('[%d %d] loss: %.4f' %
                      (ep + 1, i + 1, running_loss / 10))
                running_loss = 0.0
        # Checkpoint every 10th epoch.
        if ep % 10 == 9:
            torch.save(model.state_dict(),
                       "./result/train/" + str(ep + 1) + "srcnnParms.pth")
    print("finish training")
# Exemplo n.º 2
# 0
def SRCNN2(
    args, image_file
):  # CHANGE TO INPUT THE after-resize IMAGE FILE, SO IN THE OUTPUT3, NEED TO STORE THE denoise+resize image
    """Run a pretrained SRCNN over the luminance (Y) channel of *image_file*.

    Loads weights from ./SRCNN_outputs/x{args.SR_scale}/best.pth, enhances the
    Y channel only, and recombines it with the original Cb/Cr channels.

    Args:
        args: namespace with an ``SR_scale`` attribute selecting the weights dir.
        image_file: path to an RGB image readable by PIL.

    Returns:
        A PIL image (RGB, uint8) with the super-resolved luminance.

    Raises:
        FileNotFoundError: if the expected weights file does not exist.
        KeyError: if the checkpoint contains a key absent from the model.
    """
    # 'cuda:0' — the original 'cuda: 0' (with a space) is an invalid device
    # string and raises RuntimeError on any CUDA machine.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = SRCNN().to(device)
    state_dict = model.state_dict()

    # os.path.join keeps the path portable (the original hard-coded '\\'
    # separators, which only work on Windows).
    weights_dir = os.path.join(os.getcwd(), 'SRCNN_outputs',
                               'x{}'.format(args.SR_scale))
    weights_file = os.path.join(weights_dir, 'best.pth')
    # The original tested `not weights_file` — always False for a non-empty
    # joined path — and then crashed inside torch.load anyway. Fail fast.
    if not os.path.isfile(weights_file):
        raise FileNotFoundError(weights_file + ' not exist')

    # Copy checkpoint tensors into the model parameter-by-parameter;
    # map_location keeps CPU-saved tensors on CPU regardless of save device.
    for n, p in torch.load(weights_file,
                           map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()  # model set in evaluation mode

    image = pil_image.open(image_file).convert('RGB')  # 512

    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    # SRCNN operates on the normalized luminance channel only.
    y = ycbcr[..., 0]
    y /= 255.
    y = torch.from_numpy(y).to(device)
    y = y.unsqueeze(0).unsqueeze(0)  # add batch and channel dims -> (1, 1, H, W)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)  # output2.size 510

    # psnr = calc_psnr(y, preds)
    # print('PSNR: {:.2f}'.format(psnr))

    # Tensor -> numpy, dropping the batch and channel dims added above.
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Stacking yields (3, H, W); transpose to (H, W, 3) channel-last layout
    # expected by the color-conversion helper and by PIL.
    output = np.array([preds, ycbcr[..., 1],
                       ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    return output  ## type in pil_image
# Exemplo n.º 3
# 0
from model import SRCNN
from utils import *

if __name__ == '__main__':
    # Command-line inference entry point: load pretrained SRCNN weights and
    # prepare an input image for super-resolution at the given scale.
    # NOTE(review): this snippet is truncated below image_width — the rest of
    # the pipeline is not visible here.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights_file', type=str, required=True)
    parser.add_argument('--image_file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)

    # Copy checkpoint tensors into the freshly constructed model, raising on
    # any key the model does not expect; map_location keeps tensors on CPU.
    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file,
                           map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    image = Image.open(args.image_file).convert('RGB')
    resample = image

    # Crop dimensions down to an exact multiple of the scale factor.
    image_width = (image.width // args.scale) * args.scale
# Exemplo n.º 4
# 0
        # (fragment: tail of an optimizer construction using per-parameter-group
        # learning rates — the call and earlier groups open outside this snippet;
        # this last group trains at one tenth of the base learning rate)
        'lr': args.lr * 0.1
    }],
                           lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    # drop_last discards the final partial batch so every batch is full-sized.
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)

    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    # Track the best-performing weights/epoch/PSNR seen across training.
    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        # Progress bar sized to the samples actually consumed (the remainder
        # dropped by drop_last is excluded).
        with tqdm(total=(len(train_dataset) -
                         len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
# Exemplo n.º 5
# 0
    # (fragment: the `if` header matching this `else:` is outside this snippet;
    # the visible branch runs validation only, the `else:` branch trains)
    validate(test_path, model)
else:
    train_dataset = SRCNN_dataset(train_config)
    criterion = nn.MSELoss().cuda()
    optimizer_adam = optim.Adam(model.parameters(), lr=train_config['lr'])
    # NOTE(review): train_dataset is constructed twice (see three lines up);
    # the second assignment looks redundant — confirm before removing.
    train_dataset = SRCNN_dataset(train_config)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=train_config['batch_size'])
    #==========================================================================================
    for _epoch in range(Start_epoch, End_epoch):
        loss_avg = train(train_loader,
                         model,
                         criterion,
                         optimizer_adam,
                         _epoch,
                         Writer=writer)  # the train step is wrapped in a function to keep the overall code structure clear
        save_state = {
            'epoch': _epoch,  # when saving the network, don't store only state_dict() — include the key parameters too
            'lr': train_config['lr'],
            'state': model.state_dict()
        }
        # Checkpoint every 5 epochs; filename encodes epoch and scaled loss.
        if _epoch % 5 == 0:
            if not os.path.exists('model/'):  # standard step: check whether the folder exists
                os.mkdir('model/')  # create it if missing
            save_name = 'model/E%dL%d.pkl' % (_epoch, int(10000 * loss_avg))
            torch.save(save_state, save_name)


def save_checkpoint(state, file_path='model/%filename'):
    """Serialize a training-checkpoint dict to *file_path* via torch.save."""
    # NOTE(review): the default 'model/%filename' looks like a broken
    # placeholder — '%' performs no substitution in a plain string literal, so
    # the default saves to a file literally named '%filename'. Confirm that
    # callers always pass file_path explicitly.
    torch.save(state, file_path)
# Exemplo n.º 6
# 0
                # (fragment: periodic validation inside a training-iteration
                # loop; the loop headers are outside this snippet)
                validate(i, val_dataloader, model, criterion, val_loss_meter,
                         val_psnr_meter, writer, config)

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i)
                writer.add_scalar('psnr/val_psnr', val_psnr_meter.avg, i)

                format_str = '===> Iter [{:d}/{:d}] Val_Loss: {:.6f}, Val_PSNR: {:.4f}'
                print(
                    format_str.format(i, config['training']['iterations'],
                                      val_loss_meter.avg, val_psnr_meter.avg))
                sys.stdout.flush()

                # Checkpoint whenever validation PSNR ties or beats the best
                # seen so far; scheduler/optimizer state goes along for resume.
                if val_psnr_meter.avg >= best_val_psnr:
                    best_val_psnr = val_psnr_meter.avg
                    ckpt = {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'best_val_psnr': best_val_psnr,
                        'iter': i
                    }
                    # [:-5] strips a 5-character extension from the config
                    # filename (presumably '.yaml' or '.json' — TODO confirm)
                    # to name the checkpoint folder and file.
                    path = '{}/{}/{}_{}.pth'.format(
                        config['training']['checkpoint_folder'],
                        os.path.basename(args.config)[:-5],
                        os.path.basename(args.config)[:-5], i)
                    torch.save(ckpt, path)

                # Reset meters so the next validation window starts fresh.
                val_loss_meter.reset()
                val_psnr_meter.reset()

            if i >= config['training']['iterations']:
# Exemplo n.º 7
# 0
        # (fragment: tail of a train() batch loop — the `def` and loop headers
        # are outside this snippet; standard backprop step per batch)
        optimizer.zero_grad()
        out = model(input)
        loss = criterion(out, target)

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Convert the mean per-batch loss into a PSNR figure for reporting.
    return get_psnr(epoch_loss / len(trainloader))


# Driver: run num_epoch epochs, tracking train/test PSNR per epoch and saving
# model weights after every epoch; plot both curves when training finishes.
train_stat = []
test_stat = []
with tqdm(range(num_epoch)) as bar:
    for epoch in bar:
        # Train
        train_psnr = train()
        test_psnr = test()
        bar.set_postfix({
            'epoch': epoch,
            'train_psnr': f'{train_psnr:.2f}',
            'test_psnr': f'{test_psnr:.2f}'
        })
        train_stat.append(train_psnr)
        test_stat.append(test_psnr)

        # Save model
        torch.save(model.state_dict(), f"model_{epoch}.pth")
    plt.plot(train_stat, label='train_psnr')
    plt.plot(test_stat, label='test_psnr')
    plt.legend()
    plt.show()
# Exemplo n.º 8
# 0
    # (fragment: end-of-epoch bookkeeping — the enclosing epoch loop header is
    # outside this snippet, as the `continue` below implies)
    writer.add_scalar('train/loss', epoch_loss / len(train_loader), global_step=epoch)
    writer.add_scalar('train/psnr', epoch_psnr / len(train_loader), global_step=epoch)
    print('[Epoch {}] Loss: {:.4f}, PSNR: {:.4f} dB'.format(epoch + 1, epoch_loss / len(train_loader), epoch_psnr / len(train_loader)))

    # Validate (and checkpoint) only every 1000th epoch.
    if (epoch + 1) % 1000 != 0:
        continue

    model.eval()
    val_loss, val_psnr = 0, 0
    with torch.no_grad():
        for batch in val_loader:
            inputs, targets = batch[0], batch[1]
            if opt.cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            prediction = model(inputs)
            loss = criterion(prediction, targets)
            val_loss += loss.data
            # PSNR from MSE; assumes pixel values are normalized so that the
            # peak value is 1 — TODO confirm the dataset normalization.
            val_psnr += 10 * log10(1 / loss.data)

            # batch[2][0] presumably carries the sample's name/identifier for
            # the output filename — verify against the dataset class.
            save_image(prediction, sample_dir / '{}_epoch{:05}.png'.format(batch[2][0], epoch + 1), nrow=1)

    writer.add_scalar('val/loss', val_loss / len(val_loader), global_step=epoch)
    writer.add_scalar('val/psnr', val_psnr / len(val_loader), global_step=epoch)
    print("===> Avg. Loss: {:.4f}, PSNR: {:.4f} dB".format(val_loss / len(val_loader), val_psnr / len(val_loader)))

    torch.save(model.state_dict(), str(weight_dir / 'weight_epoch{:05}.pth'.format(epoch + 1)))