Example #1
# Imports assumed by the snippets below (not part of the original extract).
# Project-local modules (models, datasets, metrics, utils, demosaick,
# DeconvCallback, and the `log` logger) come from the surrounding repository.
import os

import numpy as np
import skimage.io
import torch as th
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


def main(args):
    model = models.LearnableDemosaick(num_filters=args.nfilters,
                                      fsize=args.fsize)
    reference_model = models.NaiveDemosaick()

    dset = datasets.DemosaickingDataset(args.dataset,
                                        transform=datasets.ToTensor())
    log.info("Validating on {} with {} images".format(args.dataset, len(dset)))

    loader = DataLoader(dset, batch_size=16, num_workers=4, shuffle=False)

    if args.cuda:
        model = model.cuda()
    model.train(False)  # inference only: put dropout / batch-norm in eval mode

    crop = args.fsize // 2
    alpha = 0.84
    l1_fn = metrics.CroppedL1Loss(crop=crop)
    msssim_fn = metrics.MSSSIM()
    psnr_fn = metrics.PSNR(crop=crop)
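    # Evaluation objective, matching training:
    #   loss = alpha * (1 - MS-SSIM) + (1 - alpha) * L1, with alpha = 0.84.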

    env = os.path.basename(args.chkpt) + "_eval"
    checkpointer = utils.Checkpointer(args.chkpt, model, None, verbose=False)
    chkpt_name, _ = checkpointer.load_latest()
    log.info("Loading checkpoint {}.".format(chkpt_name))

    callback = demosaick.DemosaickCallback(model,
                                           reference_model,
                                           len(loader),
                                           loader,
                                           env=env)

    idx = 0
    with tqdm(total=len(loader), unit=' batches') as pbar:
        pbar.set_description("Validation")
        avg = utils.Averager(["loss", "psnr", "ssim", "l1"])
        for batch_id, batch in enumerate(loader):
            mosaick, reference = batch
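            # Legacy (pre-0.4) PyTorch idiom: inputs are wrapped in Variable,
            # and scalar metrics are read back through `.data[0]`.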
            mosaick = Variable(mosaick, requires_grad=False)
            reference = Variable(reference, requires_grad=False)

            if args.cuda:
                mosaick = mosaick.cuda()
                reference = reference.cuda()

            output = model(mosaick)

            if args.save is not None:
                if not os.path.exists(args.save):
                    os.makedirs(args.save)
                for i in range(output.shape[0]):
                    im = output[i].cpu().data.numpy()
                    im = np.transpose(im, [1, 2, 0])
                    im = np.clip(im, 0, 1)
                    # Convert to 8-bit before saving; recent scikit-image
                    # versions no longer accept float images silently.
                    im = (im * 255).astype(np.uint8)
                    fname = os.path.join(args.save, "{:04d}.png".format(idx))
                    idx += 1
                    skimage.io.imsave(fname, im)

            if crop > 0:
                output = output[:, :, crop:-crop, crop:-crop]
                reference = reference[:, :, crop:-crop, crop:-crop]
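            # MS-SSIM is a similarity score in [0, 1]; 1 - MS-SSIM turns it
            # into a loss to minimize.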
            ssim_ = 1 - msssim_fn(output, reference)
            l1_ = l1_fn(output, reference)
            loss = ssim_ * alpha + (1 - alpha) * l1_
            psnr = psnr_fn(output, reference)

            avg.update("loss", loss.data[0], count=mosaick.shape[0])
            avg.update("psnr", psnr.data[0], count=mosaick.shape[0])
            avg.update("ssim", ssim_.data[0], count=mosaick.shape[0])
            avg.update("l1", l1_.data[0], count=mosaick.shape[0])
            pbar.update(1)

        logs = {
            "loss": avg["loss"],
            "psnr": avg["psnr"],
            "ssim": avg["ssim"],
            "l1": avg["l1"]
        }

        pbar.set_postfix(logs)
        callback.on_epoch_end(0, logs)

Example #2
def main(args):
    model = models.LearnableDemosaick(num_filters=args.nfilters,
                                      fsize=args.fsize)
    # model.softmax_scale[...] = 0.01
    reference_model = models.NaiveDemosaick()

    if not os.path.exists(args.output):
        os.makedirs(args.output)

    dset = datasets.DemosaickingDataset(args.dataset,
                                        transform=datasets.ToTensor())
    val_dset = datasets.DemosaickingDataset(args.val_dataset,
                                            transform=datasets.ToTensor())
    log.info("Training on {} with {} images".format(args.dataset, len(dset)))
    log.info("Validating on {} with {} images".format(args.val_dataset,
                                                      len(val_dset)))

    # log.info("Computing PCA filters")
    # vects = demosaick.get_pca_filters(dset, args.fsize)
    # model.sel_filts.data = th.from_numpy(vects)

    loader = DataLoader(dset,
                        batch_size=args.batch_size,
                        num_workers=4,
                        shuffle=True)
    val_loader = DataLoader(val_dset, batch_size=args.batch_size)

    if args.cuda:
        model = model.cuda()
        # model.softmax_scale.cuda()

    # params = [p for n, p in model.named_parameters() if n != "green_filts"]
    # optimizer = th.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True)
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    # mse_fn = metrics.CroppedMSELoss(crop=args.fsize//2)
    l1_fn = th.nn.L1Loss()
    msssim_fn = metrics.MSSSIM()
    # grad_l1_fn = metrics.CroppedGradientLoss(crop=args.fsize//2)
    # loss_fn = lambda a, b: 0.84*msssim_fn(a, b) + (1-0.84)*l1_fn(a, b)
    alpha = args.alpha
    crop = args.fsize // 2
    psnr_fn = metrics.PSNR(crop=crop)
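    # Training objective: loss = alpha * (1 - MS-SSIM) + (1 - alpha) * L1,
    # with alpha taken from the command line.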

    env = os.path.basename(args.output)
    checkpointer = utils.Checkpointer(args.output,
                                      model,
                                      optimizer,
                                      verbose=False,
                                      interval=600)
    callback = demosaick.DemosaickCallback(model,
                                           reference_model,
                                           len(loader),
                                           val_loader,
                                           env=env)

    if args.regularize:
        log.info("Using L1 weight regularization")

    if args.chkpt is not None:
        log.info("Loading checkpoint {}".format(args.chkpt))
        checkpointer.load_checkpoint(args.chkpt, ignore_optim=True)
    else:
        chkpt_name, _ = checkpointer.load_latest()
        log.info("Resuming from latest checkpoint {}.".format(chkpt_name))

    ema = utils.ExponentialMovingAverage(
        ["loss", "psnr", "ssim", "l1", "psnr_g"])
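    # Exponential moving averages smooth the noisy per-batch metrics shown
    # in the progress bar.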
    for epoch in range(args.num_epochs):
        # callback.on_epoch_end(epoch, {})

        # Training
        model.train(True)
        with tqdm(total=len(loader), unit=' batches') as pbar:
            pbar.set_description("Epoch {}/{}".format(epoch + 1,
                                                      args.num_epochs))
            # callback.on_epoch_begin(epoch)
            for batch_id, batch in enumerate(loader):
                mosaick, reference = batch
                mosaick = Variable(mosaick, requires_grad=False)
                reference = Variable(reference, requires_grad=False)

                if args.cuda:
                    mosaick = mosaick.cuda()
                    reference = reference.cuda()

                output = model(mosaick)

                optimizer.zero_grad()
                if crop > 0:
                    output = output[:, :, crop:-crop, crop:-crop]
                    reference = reference[:, :, crop:-crop, crop:-crop]

                ssim_ = 1 - msssim_fn(output, reference)
                l1_ = l1_fn(output, reference)
                loss = ssim_ * alpha + (1 - alpha) * l1_
                if args.regularize:
                    # Optional L1 penalty over all model weights.
                    reg_w = 1e-6
                    l1_reg = None
                    for p in model.parameters():
                        l1_reg = p.norm(1) if l1_reg is None else l1_reg + p.norm(1)
                    loss = loss + reg_w * l1_reg
                loss.backward()
                optimizer.step()

                psnr = psnr_fn(output, reference)
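                # Green is sampled twice as densely as red/blue in a Bayer
                # mosaic, so its PSNR is tracked separately.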
                psnr_green = psnr_fn(output[:, 1, ...], reference[:, 1, ...])

                ema.update("loss", loss.data[0])
                ema.update("psnr", psnr.data[0])
                ema.update("psnr_g", psnr_green.data[0])
                ema.update("ssim", ssim_.data[0])
                ema.update("l1", l1_.data[0])

                logs = {
                    "loss": ema["loss"],
                    "psnr": ema["psnr"],
                    "psnr_g": ema["psnr_g"],
                    "ssim": ema["ssim"],
                    "l1": ema["l1"]
                }
                pbar.set_postfix(logs)
                pbar.update(1)
                if pbar.n % args.viz_step == 0:
                    callback.on_batch_end(batch_id, logs)
                    callback.show_val_batch()
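                # Called every batch; actual writes are rate-limited by the
                # `interval` passed to utils.Checkpointer above.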
                checkpointer.periodic_checkpoint(epoch)

        # Validation
        model.train(False)
        with tqdm(total=len(val_loader), unit=' batches') as pbar:
            pbar.set_description("Epoch {}/{} (val)".format(
                epoch + 1, args.num_epochs))
            avg = utils.Averager(["loss", "psnr", "ssim", "l1"])
            for batch_id, batch in enumerate(val_loader):
                mosaick, reference = batch
                mosaick = Variable(mosaick, requires_grad=False)
                reference = Variable(reference, requires_grad=False)

                if args.cuda:
                    mosaick = mosaick.cuda()
                    reference = reference.cuda()

                output = model(mosaick)
                if crop > 0:
                    output = output[:, :, crop:-crop, crop:-crop]
                    reference = reference[:, :, crop:-crop, crop:-crop]
                ssim_ = 1 - msssim_fn(output, reference)
                l1_ = l1_fn(output, reference)
                loss = ssim_ * alpha + (1 - alpha) * l1_
                psnr = psnr_fn(output, reference)

                avg.update("loss", loss.data[0], count=mosaick.shape[0])
                avg.update("psnr", psnr.data[0], count=mosaick.shape[0])
                avg.update("ssim", ssim_.data[0], count=mosaick.shape[0])
                avg.update("l1", l1_.data[0], count=mosaick.shape[0])
                pbar.update(1)

            logs = {
                "loss": avg["loss"],
                "psnr": avg["psnr"],
                "ssim": avg["ssim"],
                "l1": avg["l1"]
            }

            pbar.set_postfix(logs)
            callback.on_epoch_end(epoch, logs)

        # save
        checkpointer.on_epoch_end(epoch)
Example #3
def main(args):
    model = models.DeconvNonlinearCG(num_stages=3)
    ref_model = models.DeconvNonlinearCG(ref=True)

    if not os.path.exists(args.output):
        os.makedirs(args.output)

    dset = datasets.DeconvDataset(args.dataset)
    val_dset = datasets.DeconvDataset(args.val_dataset, is_validate=True)

    log.info("Training on {} with {} images".format(args.dataset, len(dset)))
    log.info("Validating on {} with {} images".format(args.val_dataset,
                                                      len(val_dset)))

    if args.cuda:
        model = model.cuda()
        ref_model = ref_model.cuda()
    print("Training parameters:")
    params_to_train = []
    for n, p in model.named_parameters():
        print("  -", n)
        params_to_train.append(p)
    optimizer = th.optim.Adam(params_to_train, lr=args.lr)
    # optimizer = th.optim.SGD(model.parameters(), lr=args.lr)
    loss_fn = metrics.CroppedL1Loss(crop=16)
    psnr_fn = metrics.PSNR(crop=16)

    # The CG iteration counts used below were undefined in the original
    # snippet; the values here are assumed placeholders.
    cg_iter = 10      # solver iterations for the learned model during training
    ref_cg_iter = 30  # solver iterations used at validation time

    loader = DataLoader(dset,
                        batch_size=args.batch_size,
                        num_workers=1,
                        shuffle=True)
    val_loader = DataLoader(val_dset, batch_size=8)

    checkpointer = utils.Checkpointer(args.output,
                                      model,
                                      optimizer,
                                      verbose=True)
    callback = DeconvCallback(model,
                              ref_model,
                              val_loader,
                              args.cuda,
                              env="gapps_deconv")

    # Exponentially smoothed training curves: s <- ema * s + (1 - ema) * x.
    smooth_loss = 0
    smooth_psnr = 0
    ema = 0.9
    chkpt_name, iteration = checkpointer.load_latest()
    log.info("Resuming from latest checkpoint {}.".format(chkpt_name))
    train_iterator = iter(loader)
    best_psnr = 0.0
    first = True
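    # Open-ended training loop: runs until interrupted, validating and
    # checkpointing every 20 iterations.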
    while True:
        # Training
        # Get a batch from the dataset
        try:
            batch = next(train_iterator)
        except StopIteration:
            # Restart the loader once the dataset is exhausted.
            train_iterator = iter(loader)
            batch = next(train_iterator)
        model.train(True)

        # Setup input & reference
        blurred, reference, kernel = batch
        blurred = Variable(blurred, requires_grad=False)
        reference = Variable(reference, requires_grad=False)
        kernel = Variable(kernel, requires_grad=False)

        # Transfer data to gpu if necessary
        if args.cuda:
            blurred = blurred.cuda()
            reference = reference.cuda()
            kernel = kernel.cuda()

        # Run the model
        output = model(blurred, kernel, cg_iter)

        # Compute loss & optimize
        optimizer.zero_grad()
        loss = loss_fn(output, reference)
        loss.backward()
        optimizer.step()
        # Pin the regularizer exponents back to 2.0 (a quadratic prior),
        # undoing whatever change the optimizer step just made to them.
        model.reg_powers.data.fill_(2.0)

        # Compute PSNR
        psnr = psnr_fn(output, reference)

        # Exponential smooth of error curve
        if first:
            smooth_loss = loss.data[0]
            smooth_psnr = psnr.data[0]
            first = False
        else:
            smooth_loss = ema * smooth_loss + (1 - ema) * loss.data[0]
            smooth_psnr = ema * smooth_psnr + (1 - ema) * psnr.data[0]

        print('loss: {}, psnr: {}'.format(smooth_loss, smooth_psnr))
        model.train(False)
        ref_model.train(False)

        logs = {"loss": smooth_loss, "psnr": smooth_psnr}
        callback.on_iteration_end(iteration, logs)

        if iteration % 20 == 0:
            # Validation
            # Go through the whole validation dataset
            total_loss = 0
            total_psnr = 0
            total_ref_loss = 0
            total_ref_psnr = 0
            n_seen = 0
            for batch_id, batch in enumerate(val_loader):
                blurred, reference, kernel = batch
                blurred = Variable(blurred, requires_grad=False)
                reference = Variable(reference, requires_grad=False)
                kernel = Variable(kernel, requires_grad=False)

                if args.cuda:
                    blurred = blurred.cuda()
                    reference = reference.cuda()
                    kernel = kernel.cuda()

                # The learned model is validated with the same CG budget as
                # the reference solver, for a like-for-like comparison.
                output = model(blurred, kernel, ref_cg_iter)
                loss = loss_fn(output, reference)
                psnr = psnr_fn(output, reference)

                ref_output = ref_model(blurred, kernel, ref_cg_iter)
                ref_loss = loss_fn(ref_output, reference)
                ref_psnr = psnr_fn(ref_output, reference)

                # Weight by the actual batch size: val_loader has its own
                # batch size (8), and the final batch may be smaller.
                bs = blurred.shape[0]
                total_loss += loss.data[0] * bs
                total_psnr += psnr.data[0] * bs
                total_ref_loss += ref_loss.data[0] * bs
                total_ref_psnr += ref_psnr.data[0] * bs
                n_seen += bs

            val_loss = total_loss / n_seen
            val_psnr = total_psnr / n_seen
            ref_loss = total_ref_loss / n_seen
            ref_psnr = total_ref_psnr / n_seen

            logs = {
                "val_loss": val_loss,
                "val_psnr": val_psnr,
                "ref_loss": ref_loss,
                "ref_psnr": ref_psnr
            }

            callback.on_validation_end(iteration, logs)
            # save
            checkpointer.on_epoch_end(iteration)
            # save best
            if val_psnr > best_psnr:
                filename = 'epoch_{:03d}_best.pth.tar'.format(iteration + 1)
                checkpointer.save_checkpoint(iteration, filename)
                best_psnr = val_psnr

        iteration += 1