コード例 #1
0
def validate(val_loader, model, cri, epoch, output_writers, device):
    """Evaluate ``model`` on ``val_loader`` and return the average loss.

    For every batch the k-space data and mask are passed through ``roll``
    (half the size of dims 2 and 3), the model is run, and the last element
    of its output is scored against the reference image with ``cri``,
    ``get_rmse``, ``get_psnr`` and ``get_ssim``. Image grids for batches 2
    and 3 are logged through ``output_writers``. Average scores are printed
    at the end.

    Args:
        val_loader: iterable yielding ``(kdata, mask, image, fully)`` batches.
        model: network called as ``model(kdata, mask)``.
        cri: loss callable taking ``(prediction, target)``.
        epoch: step index used when logging the reconstruction grid.
        output_writers: indexable collection of writers exposing ``add_image``.
        device: device the batch tensors are moved to.

    Returns:
        The running average of the validation loss (``AverageMeter.avg``).
    """
    loss_meter = AverageMeter()
    rmse_scores, psnr_scores, ssim_scores = [], [], []

    # Evaluation mode for the whole pass.
    model.eval()

    with torch.no_grad():
        for batch_idx, (kdata, mask, image, fully) in enumerate(val_loader):
            kdata = kdata.float().to(device)
            mask = mask.float().to(device)
            image = image.float().to(device)

            _, _, w, h = kdata.size()
            half_w = int(w / 2)
            half_h = int(h / 2)

            # ifftshift: shift data and mask by half their size on dims 2, 3
            kdata = roll(kdata, half_w, 2)
            kdata = roll(kdata, half_h, 3)
            mask = roll(mask, half_w, 2)
            mask = roll(mask, half_h, 3)

            # Forward pass; only the last output is scored.
            reconsimage = model(kdata, mask)
            final_recon = reconsimage[-1]

            loss = cri(final_recon, image)

            # Per-batch quality metrics.
            rmse = get_rmse(final_recon, image)
            psnr = get_psnr(final_recon, image)
            ssim = get_ssim(final_recon, image)
            rmse_scores.append(rmse.item())
            psnr_scores.append(psnr)
            ssim_scores.append(ssim.item())

            # Weight the running loss by the batch size.
            loss_meter.update(loss.item(), kdata.size(0))

            # Only batches 2 and 3 are visualised in tensorboard.
            if batch_idx in (2, 3):
                writer = output_writers[batch_idx]
                gt_grid = vutils.make_grid(image,
                                           normalize=True,
                                           scale_each=True)
                writer.add_image('gt image', gt_grid, 0)
                rec_grid = vutils.make_grid(final_recon,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('reconstruction image 1 ', rec_grid, epoch)

        # Report the averages over the whole loader.
        average_loss = loss_meter.avg
        print(' * Average Validation Loss {:.3f}'.format(average_loss))
        print(' * Average RMSE {:.4f}'.format(np.mean(np.asarray(rmse_scores))))
        print(' * Average PSNR {:.4f}'.format(np.mean(np.asarray(psnr_scores))))
        print(' * Average SSIM {:.4f}'.format(np.mean(np.asarray(ssim_scores))))
        return loss_meter.avg
コード例 #2
0
    def validate(self):
        """Run one pass over ``self.val_loader`` and report loss, PSNR and SSIM.

        The model is put in eval mode for the pass. PSNR is computed for every
        image; SSIM only when ``self.cmd == 'test'``. In test mode with
        ``self.result_dir`` set, a side-by-side (target | output) JPEG and an
        output-only JPEG are written per image.

        When ``self.cmd == 'train'`` a summary is logged, ``checkpoint.pth.tar``
        is written to ``self.checkpoint_dir``, copied to ``model_best.pth.tar``
        if the average PSNR beats ``self.best_psnr``, copied to a numbered file
        every 10th epoch, and the model is switched back to training mode if it
        was in training mode on entry.

        Raises:
            ValueError: if the validation loss is NaN.
        """
        batch_time = utils.AverageMeter()
        losses = utils.AverageMeter()
        psnr = utils.AverageMeter()
        ssim = utils.AverageMeter()

        # Remember the mode on entry so it can be restored after validation.
        training = self.model.training
        self.model.eval()

        phase = 'Valid' if self.cmd == 'train' else 'Test'

        # Result images are only written in test mode; create the directory
        # once instead of once per image.
        save_results = bool(self.result_dir) and self.cmd == 'test'
        if save_results:
            os.makedirs(self.result_dir, exist_ok=True)

        end = time.time()
        for batch_idx, (raws, imgs, targets, img_files, img_exposures, lbl_exposures, ratios) in tqdm.tqdm(
                enumerate(self.val_loader), total=len(self.val_loader),
                desc='{} iteration={} epoch={}'.format(phase, self.iteration, self.epoch),
                ncols=80, leave=False):
            gc.collect()
            if self.cuda:
                # `async` is a reserved keyword since Python 3.7, so the old
                # `cuda(async=True)` spelling no longer parses; `non_blocking`
                # is the PyTorch replacement for that argument.
                raws, targets = raws.cuda(), targets.cuda(non_blocking=True)

            with torch.no_grad():
                raws = Variable(raws)
                targets = Variable(targets)
                output = self.model(raws)

                # Crop the targets to the spatial size of the model output.
                targets = targets[:, :, :output.size(2), :output.size(3)]
                loss = self.criterion(output, targets)
                if np.isnan(float(loss.item())):
                    raise ValueError('loss is nan while validating')
                losses.update(loss.item(), targets.size(0))

            outputs = torch.clamp(output, 0, 1).cpu()
            targets = targets.cpu()

            for output, img, target, img_file, img_exposure, lbl_exposure, ratio in zip(outputs, imgs, targets,
                                                                                        img_files, img_exposures,
                                                                                        lbl_exposures, ratios):
                # Channel-first [0, 1] tensors -> channel-last arrays scaled by 255.
                output = output.numpy().transpose(1, 2, 0) * 255
                target = target.numpy().transpose(1, 2, 0) * 255

                if save_results:
                    # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2;
                    # this path needs an older SciPy or a port to an image library.
                    stem = os.path.basename(img_file)[:-4]
                    fname = os.path.join(self.result_dir, '{}_compare.jpg'.format(stem))
                    temp = np.concatenate((target[:, :, :], output[:, :, :]), axis=1)
                    scipy.misc.toimage(temp, high=255, low=0, cmin=0, cmax=255).save(fname)
                    fname = os.path.join(self.result_dir, '{}_single.jpg'.format(stem))
                    scipy.misc.toimage(output, high=255, low=0, cmin=0, cmax=255).save(fname)

                _psnr = utils.get_psnr(output, target)
                print("PSNR", img_file, _psnr)
                psnr.update(_psnr, 1)
                if self.cmd == 'test':
                    _ssim = utils.get_ssim(output, target)
                    print("SSIM", img_file, _ssim)
                    ssim.update(_ssim, 1)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if batch_idx % self.print_freq == 0:
                log_str = '{cmd:}: [{0}/{1}/{loss.count:}]\tepoch: {epoch:}\titer: {iteration:}\t' \
                      'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss: {loss.val:.4f} ({loss.avg:.4f})\tPSNR: {psnr.val:.2f} ({psnr.avg:.2f})\tSSIM: {ssim.val:.4f} ({ssim.avg:.4f})\t'.format(
                    batch_idx, len(self.val_loader), cmd=phase,
                    epoch=self.epoch, iteration=self.iteration,
                    batch_time=batch_time, loss=losses, psnr=psnr, ssim=ssim)
                print(log_str)
                self.print_log(log_str)

        if self.cmd == 'train':
            is_best = psnr.avg > self.best_psnr
            self.best_psnr = max(psnr.avg, self.best_psnr)

            log_str = 'Valid_summary: [{0}/{1}/{psnr.count:}] epoch: {epoch:} iter: {iteration:}\t' \
                  'BestPSNR: {best_psnr:.3f}\t' \
                  'Time: {batch_time.avg:.3f}\tLoss: {loss.avg:.4f}\tPSNR: {psnr.avg:.3f}\t'.format(
                batch_idx, len(self.val_loader), epoch=self.epoch, iteration=self.iteration,
                best_psnr=self.best_psnr, batch_time=batch_time, loss=losses, psnr=psnr)
            print(log_str)
            self.print_log(log_str)

            # Always overwrite the rolling checkpoint, then copy it to the
            # "best" and periodic snapshots as needed.
            checkpoint_file = os.path.join(self.checkpoint_dir, 'checkpoint.pth.tar')
            torch.save({
                'epoch': self.epoch,
                'iteration': self.iteration,
                'arch': self.model.__class__.__name__,
                'optim_state_dict': self.optim.state_dict(),
                'model_state_dict': self.model.state_dict(),
                'best_psnr': self.best_psnr,
                'batch_time': batch_time,
                'losses': losses,
                'psnr': psnr,
            }, checkpoint_file)
            if is_best:
                shutil.copy(checkpoint_file, os.path.join(self.checkpoint_dir, 'model_best.pth.tar'))
            if (self.epoch + 1) % 10 == 0:  # keep a snapshot every 10 epochs
                shutil.copy(checkpoint_file, os.path.join(self.checkpoint_dir, 'checkpoint-{}.pth.tar'.format(self.epoch)))

            # Restore training mode when validation interrupted a training run.
            if training:
                self.model.train()
コード例 #3
0
    def train_epoch(self):
        """Train the model for one pass over ``self.train_loader``.

        For each batch: updates ``self.iteration``, runs ``self.validate()``
        every ``self.interval_validate`` iterations, computes the loss, records
        per-image PSNR (and, when ``self.result_dir`` is set, writes a
        target | output comparison JPEG), back-propagates, steps the optimizer
        and, if configured, the learning-rate scheduler. Progress is logged
        every ``self.print_freq`` iterations and a summary is logged at the end.

        Raises:
            ValueError: if the training loss is NaN.
        """
        batch_time = utils.AverageMeter()
        data_time = utils.AverageMeter()
        losses = utils.AverageMeter()
        psnr = utils.AverageMeter()
        ssim = utils.AverageMeter()

        self.model.train()
        self.optim.zero_grad()

        end = time.time()
        for batch_idx, (raws, imgs, targets, img_files, img_exposures, lbl_exposures, ratios) in tqdm.tqdm(
                enumerate(self.train_loader), total=len(self.train_loader),
                desc='Train epoch={}, iter={}'.format(self.epoch, self.iteration), ncols=80, leave=False):
            # Global iteration counter across epochs.
            iteration = batch_idx + self.epoch * len(self.train_loader)
            data_time.update(time.time() - end)

            gc.collect()

            self.iteration = iteration

            if (self.iteration + 1) % self.interval_validate == 0:
                self.validate()

            if self.cuda:
                # `async` is a reserved keyword since Python 3.7, so the old
                # `cuda(async=True)` spelling no longer parses; `non_blocking`
                # is the PyTorch replacement for that argument.
                raws, targets = raws.cuda(), targets.cuda(non_blocking=True)
            raws, targets = Variable(raws), Variable(targets)

            outputs = self.model(raws)
            loss = self.criterion(outputs, targets)
            loss_value = loss.item()
            if np.isnan(loss_value):
                raise ValueError('loss is nan while training')

            # measure accuracy and record loss
            losses.update(loss_value, targets.size(0))

            # Detached CPU copies are used for metrics and image dumps only;
            # `loss` still holds the graph for the backward pass below.
            outputs = torch.clamp(outputs, 0, 1).data.cpu()
            targets = targets.data.cpu()
            for output, img, target, img_file, img_exposure, lbl_exposure, ratio in zip(outputs, imgs, targets,
                                                                                        img_files, img_exposures,
                                                                                        lbl_exposures, ratios):
                # Channel-first [0, 1] tensors -> channel-last arrays scaled by 255.
                output = output.numpy().transpose(1, 2, 0) * 255
                target = target.numpy().transpose(1, 2, 0) * 255
                psnr.update(utils.get_psnr(output, target), 1)
                if self.result_dir:
                    # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2;
                    # this path needs an older SciPy or a port to an image library.
                    os.makedirs(self.result_dir + '%04d' % self.epoch, exist_ok=True)
                    fname = self.result_dir + '{:04d}/{:04d}_{}.jpg'.format(self.epoch, batch_idx, os.path.basename(img_file)[:-4])
                    temp = np.concatenate((target[:, :, :], output[:, :, :]), axis=1)
                    scipy.misc.toimage(temp, high=255, low=0, cmin=0, cmax=255).save(fname)

            # backprop
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if self.iteration % self.print_freq == 0:
                log_str = 'Train: [{0}/{1}/{loss.count:}]\tepoch: {epoch:}\titer: {iteration:}\t' \
                      'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                      'Loss: {loss.val:.4f} ({loss.avg:.4f})\tPSNR: {psnr.val:.1f} ({psnr.avg:.1f})\tlr {lr:.6f}'.format(
                    batch_idx, len(self.train_loader), epoch=self.epoch, iteration=self.iteration,
                    lr=self.optim.param_groups[0]['lr'],
                    batch_time=batch_time, data_time=data_time, loss=losses, psnr=psnr)
                print(log_str, flush=True)
                self.print_log(log_str)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()  # update lr

        log_str = 'Train_summary: [{0}/{1}/{loss.count:}]\tepoch: {epoch:}\titer: {iteration:}\t' \
                      'Time: {batch_time.avg:.3f}\tData: {data_time.avg:.3f}\t' \
                      'Loss: {loss.avg:.4f}\tPSNR: {psnr.avg:.1f}\tlr {lr:.6f}'.format(
                    batch_idx, len(self.train_loader), epoch=self.epoch, iteration=self.iteration,
                    lr=self.optim.param_groups[0]['lr'],
                    batch_time=batch_time, data_time=data_time, loss=losses, psnr=psnr)
        print(log_str)
        self.print_log(log_str)
コード例 #4
0
def main():
    """Denoise one image with a saved Keras model and report the PSNR.

    Loads ``args.input`` (``.npy`` or a common image format), cuts it into
    ``args.patch``-sized patches, adds Gaussian noise with standard deviation
    ``args.sigma``, runs the model from ``args.model`` on the noisy patches and
    prints the PSNR between the clean patches and the prediction. Results are
    optionally saved to ``args.output`` (``.npz``, ``.npy`` or one image file
    per patch) and optionally browsed interactively with matplotlib.

    Raises:
        ValueError: if the input or output file extension is not supported.
        TypeError: if the image dtype is not uint8, uint16 or floating point.
        Exception: if the image is not 2-, 3- or 4-dimensional.
    """
    args = parse_args()

    image_extensions = (".png", ".jpg", ".jpeg", ".tif")

    input_name = args.input.lower()
    if input_name.endswith(".npy"):
        image = np.load(args.input)
    elif input_name.endswith(image_extensions):
        image = imread(args.input)
    else:
        raise ValueError("Unsupported input file type: {}".format(args.input))

    print("Read image {} with shape {}.".format(args.input, image.shape))

    # Floating point images are given a "bit depth" of 1, so that
    # 2 ** bit_depth - 1 is the largest pixel value for every supported dtype
    # (255, 65535 and 1 respectively).
    if image.dtype == np.uint8:
        bit_depth = 8
    elif image.dtype == np.uint16:
        bit_depth = 16
    elif np.issubdtype(image.dtype, np.floating):
        bit_depth = 1
    else:
        raise TypeError("Unsupported image dtype: {}".format(image.dtype))
    max_value = 2 ** bit_depth - 1

    psnr = get_psnr(bit_depth)

    model = keras.models.load_model(args.model, custom_objects={"psnr": psnr})

    # Transform into 4D tensor
    if image.ndim == 2:
        image = image[np.newaxis, :, :, np.newaxis]
    elif image.ndim == 3:
        image = image[np.newaxis, :, :]
    elif image.ndim != 4:
        raise Exception("Error with image dimensions: ndim should be 2 or " \
            "3, but received {}".format(image.ndim))

    patch_size = args.patch
    print("Cutting images into {}x{} patches.".format(patch_size, patch_size))
    image = get_patches(image, patch_size, 0)

    print("Adding noise with standard deviation {}.".format(args.sigma))
    noisy = NoiseGenerator.add_gaussian_noise(image, args.sigma)

    print("Making prediction on input tensor with shape {}".format(
        noisy.shape))
    pred = model.predict(noisy, batch_size=1)

    pred_psnr = psnr(image, pred)
    print("PSNR of prediction: {}".format(K.eval(pred_psnr)))

    if args.output:
        output_name = args.output.lower()
        if output_name.endswith(".npz"):
            print("Saving image, noisy image and prediction to {}".format(
                args.output))
            np.savez(args.output, image=image, noisy=noisy, pred=pred)
        elif output_name.endswith(".npy"):
            print("Saving prediction to {}".format(args.output))
            np.save(args.output, pred)
        elif output_name.endswith(image_extensions):
            print("Saving prediction to {}".format(args.output))
            # One file per patch: "<stem>_<n><ext>", numbered from 1.
            extension_dot_idx = args.output.rfind(".")
            stem = args.output[:extension_dot_idx]
            extension = args.output[extension_dot_idx:]
            for i, p in enumerate(pred):
                output = stem + "_" + str(i + 1) + extension
                # Clip to the valid pixel range before casting back. The
                # upper bound is the largest pixel value, not the bit count:
                # clipping to `bit_depth` would crush 8-bit images to 0..8.
                imsave(output, np.clip(np.squeeze(p), 0, max_value). \
                    astype(image.dtype))
        else:
            raise ValueError(
                "Unsupported output file type: {}".format(args.output))

    if args.visualize:
        while True:
            try:
                answer = input(
                    "Patch number (0-{}): ".format(pred.shape[0] - 1))
            except EOFError:
                # stdin is closed; retrying would spin forever.
                break

            if answer == "exit":
                raise SystemExit(0)

            try:
                i = int(answer)
            except ValueError:
                print("Please enter an integer patch number or 'exit'.")
                continue

            try:
                fig, axes = plt.subplots(1, 3)
                axes[0].imshow(image[i, :, :, 0], cmap="Greys_r")
                axes[0].set_title("Original")
                axes[0].set_aspect('equal', adjustable='box')
                axes[1].imshow(noisy[i, :, :, 0], cmap="Greys_r")
                axes[1].set_title("Noisy (sigma = {})".format(args.sigma))
                axes[1].set_aspect('equal', adjustable='box')
                axes[2].imshow(pred[i, :, :, 0], cmap="Greys_r")
                axes[2].set_title("Restored")
                axes[2].set_aspect('equal', adjustable='box')

                for ax in axes:
                    ax.set_xticks([])
                    ax.set_yticks([])

                plt.show()
            except Exception as err:
                # Keep the prompt alive (best effort), but report the failure
                # instead of hiding it.
                print("Could not display patch {}: {}".format(i, err))
                continue
コード例 #5
0
ファイル: train.py プロジェクト: zacharielegault/noise2noise
def main():
    """Train the U-Net denoiser on ``train.npy`` / ``valid.npy``.

    Creates a timestamped run directory (with a ``checkpoints``
    sub-directory) under ``args.log_directory``, loads the training and
    validation arrays from ``args.data_directory``, builds the logging and
    checkpoint callbacks, compiles the model with an MSE loss and a PSNR
    metric matched to the training data dtype, and fits it from
    ``NoiseGenerator`` instances. The validation generator is always built
    with ``noise2noise=False``.

    Raises:
        TypeError: if the training data dtype is not uint8, uint16 or
            floating point.
    """
    args = parse_args()

    timestr = time.strftime("%Y%m%d%H%M%S")

    log_dir_parent = args.log_directory
    log_dir = log_dir_parent + "/{}".format(timestr)
    checkpoints_dir = log_dir + "/checkpoints"
    # makedirs also creates the parent log directory and the run directory,
    # so a fresh checkout without an existing log folder does not fail.
    os.makedirs(checkpoints_dir)

    optimizer = Adam()

    train = np.load(args.data_directory + "/train.npy")
    valid = np.load(args.data_directory + "/valid.npy")

    callbacks = [
        Logger(
            filename=log_dir + "/{}.log".format(timestr),
            optimizer=optimizer,
            sigma=args.sigma,
            epochs=args.epochs,
            batch_size=args.batch_size,
            dataset_dir=args.data_directory,
            checkpoints_dir=checkpoints_dir,
            noise2noise=args.noise2noise,
            dtype=train.dtype),
        CSVLogger(log_dir + "/{}.csv".format(timestr)),
        TerminateOnNaN(),
        # Keep only checkpoints that improve the validation PSNR.
        ModelCheckpoint(
            checkpoints_dir + '/checkpoint.' + timestr + \
            '.{epoch:03d}-{val_loss:.3f}-{val_psnr:.5f}.h5',
            monitor='val_psnr',
            mode='max',
            save_best_only=True)
    ]

    # Build PSNR function; floating point data is given a bit depth of 1.
    if train.dtype == np.uint8:
        bit_depth = 8
    elif train.dtype == np.uint16:
        bit_depth = 16
    elif np.issubdtype(train.dtype, np.floating):
        bit_depth = 1
    else:
        raise TypeError(
            "Unsupported training data dtype: {}".format(train.dtype))
    psnr = get_psnr(bit_depth)

    model = unet(shape=(None, None, 1))
    model.compile(optimizer=optimizer, loss='mse', metrics=[psnr])

    # New noise for each new batch
    train_generator = NoiseGenerator(data=train,
                                     batch_size=args.batch_size,
                                     sigma=args.sigma,
                                     noise2noise=args.noise2noise)
    valid_generator = NoiseGenerator(data=valid,
                                     batch_size=args.batch_size,
                                     sigma=args.sigma,
                                     noise2noise=False)

    model.fit_generator(generator=train_generator,
                        epochs=args.epochs,
                        verbose=args.verbose,
                        callbacks=callbacks,
                        validation_data=valid_generator)