Example 1
    def plot_prediction_at_x(self, n_pred, plot_samples=False):
        r"""Plot `n_pred` predictions for randomly selected input from test dataset.
        - target
        - predictive mean
        - standard deviation of predictive output distribution
        - error of the above two

        Args:
            n_pred: number of candidate predictions
            plot_samples (bool): plot 15 output samples from p(y|x) for given x
        """
        save_dir = self.post_dir + '/predict_at_x'
        mkdir(save_dir)
        print('Plotting predictions at x from test dataset..................')
        np.random.seed(1)
        idx = np.random.permutation(len(self.test_loader.dataset))[:n_pred]
        for i in idx:
            print('input index: {}'.format(i))
            input, target = self.test_loader.dataset[i]
            pred_mean, pred_var = self.model.predict(
                input.unsqueeze(0).to(self.device),
                n_samples=self.n_samples,
                temperature=self.temperature)

            plot_prediction_bayes2(save_dir,
                                   target,
                                   pred_mean.squeeze(0),
                                   pred_var.squeeze(0),
                                   self.epochs,
                                   i,
                                   plot_fn=self.plot_fn)
            if plot_samples:
                samples_pred = self.model.sample(
                    input.unsqueeze(0).to(self.device),
                    n_samples=15)[:, 0]
                samples = torch.cat(
                    (target.unsqueeze(0), samples_pred.detach().cpu()), 0)
                save_samples(save_dir,
                             samples,
                             self.epochs,
                             i,
                             'samples',
                             nrow=4,
                             heatmap=True,
                             cmap='jet')
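
The example above relies on `self.model.predict` returning a per-pixel predictive mean and variance, plus the repo helpers `mkdir`, `plot_prediction_bayes2` and `save_samples`. As a rough sketch only, assuming `model.sample(x, n_samples, temperature)` returns a tensor of shape (n_samples, batch, channels, H, W) (the helper name `mc_predict` is made up for illustration, not the repository's own `predict`), such predictive moments can be estimated by Monte Carlo:

import torch

def mc_predict(model, x, n_samples=20, temperature=1.0):
    # Monte Carlo estimate of the mean and variance of p(y|x);
    # assumes model.sample returns (n_samples, batch, channels, H, W).
    with torch.no_grad():
        samples = model.sample(x, n_samples=n_samples, temperature=temperature)
        pred_mean = samples.mean(dim=0)
        pred_var = samples.var(dim=0, unbiased=False)  # plain MC variance
    return pred_mean, pred_var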
Example 2
                loss = F.mse_loss(denoised, clean, reduction='sum')
                mse += loss.item()
                psnr += cal_psnr(clean, denoised).sum().item()

            psnr = psnr / n_test_samples
            rmse = np.sqrt(mse / n_test_pixels)

            if epoch % args.plot_epochs == 0:
                print(
                    'Epoch {}: plot test denoising [input, denoised, clean, denoised - clean]'
                    .format(epoch))
                samples = torch.cat((noisy[:4], denoised[:4], clean[:4],
                                     denoised[:4] - clean[:4]))
                save_samples(args.pred_dir,
                             samples,
                             epoch,
                             'test',
                             epoch=True,
                             cmap=args.cmap)
                # fixed test batch: the same images every epoch, for comparison across epochs
                fixed_denoised = model(fixed_test_noisy)
                samples = torch.cat(
                    (fixed_test_noisy[:4].cpu(), fixed_denoised[:4].cpu(),
                     fixed_test_clean[:4],
                     fixed_denoised[:4].cpu() - fixed_test_clean[:4]))
                save_samples(args.pred_dir,
                             samples,
                             epoch,
                             'fixed_test1',
                             epoch=True,
                             cmap=args.cmap)
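
The PSNR accumulation in this example implies that `cal_psnr` returns one value per image in the batch. A minimal sketch of such a helper, assuming images of shape (batch, channels, H, W) scaled to [0, 1] (an illustration, not the repository's own implementation):

import torch

def cal_psnr(clean, denoised, max_val=1.0):
    # per-image mean squared error over channels and pixels
    mse = ((clean - denoised) ** 2).flatten(start_dim=1).mean(dim=1)
    # PSNR in dB, one value per image in the batch
    return 10.0 * torch.log10(max_val ** 2 / mse)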
Example 3
    def test(epoch):
        model.eval()
        loss_test = 0.
        # mse = 0.
        relative_l2, err2 = [], []
        for batch_idx, (input, target) in enumerate(test_loader):
            input, target = input.to(device), target.to(device)
            # every 10 epochs estimate the predictive mean more accurately
            # by averaging 20 output samples
            if epoch % 10 == 0:
                output_samples = model.sample(input,
                                              n_samples=20,
                                              temperature=1.0)
                output = output_samples.mean(0)
                # one generate() pass for the log-likelihood term used below
                _, log_likelihood = model.generate(input)
            else:
                # evaluate with a single output sample
                output, log_likelihood = model.generate(input)

            residual_norm = constitutive_constraint(input, output, sobel_filter) \
                + continuity_constraint(output, sobel_filter)
            loss_dirichlet, loss_neumann = boundary_condition(output)
            loss_boundary = loss_dirichlet + loss_neumann
            loss_pde = residual_norm + loss_boundary * args.weight_bound
            # negative predictive entropy: E_{p(y|x)}[log p(y|x)], in bits per output pixel
            neg_entropy = log_likelihood.mean() / math.log(2.) / n_out_pixels
            loss = loss_pde * args.beta + neg_entropy

            loss_test += loss.item()
            err2_sum = torch.sum((output - target)**2, [-1, -2])
            # print(err2_sum)
            relative_l2.append(torch.sqrt(err2_sum /
                                          (target**2).sum([-1, -2])))
            err2.append(err2_sum)

            # plot predictions
            if (epoch % args.plot_freq == 0
                    or epoch % args.epochs == 0) and batch_idx == 0:
                n_samples = 6 if epoch == args.epochs else 2
                idx = np.random.permutation(input.size(0))[:n_samples]
                samples_target = target.data.cpu()[idx].numpy()

                for i in range(n_samples):
                    print('epoch {}: plotting prediction {}'.format(epoch, i))
                    pred_mean, pred_var = model.predict(input[[idx[i]]])
                    plot_prediction_bayes2(args.pred_dir,
                                           samples_target[i],
                                           pred_mean[0],
                                           pred_var[0],
                                           epoch,
                                           idx[i],
                                           plot_fn='imshow',
                                           cmap='jet',
                                           same_scale=False)
                    # plot samples p(y|x)
                    print(idx[i])
                    print(input[[idx[i]]].shape)
                    samples_pred = model.sample(input[[idx[i]]],
                                                n_samples=15)[:, 0]
                    samples = torch.cat((target[[idx[i]]], samples_pred), 0)
                    # print(samples.shape)
                    save_samples(args.pred_dir,
                                 samples,
                                 epoch,
                                 idx[i],
                                 'samples',
                                 nrow=4,
                                 heatmap=True,
                                 cmap='jet')

        loss_test /= (batch_idx + 1)
        relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
        r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation

        print(f"Epoch {epoch}: test r2-score:  {r2_score}")
        print(f"Epoch {epoch}: test relative l2:  {relative_l2}")
        print(f'Epoch {epoch}: test loss: {loss_test:.6f}, residual: {residual_norm.item():.6f}, '\
                f'boundary {loss_boundary.item():.6f}, neg entropy {neg_entropy.item():.6f}')

        if epoch % args.log_freq == 0:
            logger['loss_test'].append(loss_test)
            logger['r2_test'].append(r2_score)
            logger['nrmse_test'].append(relative_l2)
            logger['entropy_test'].append(-neg_entropy.item())
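
For reference, the relative L2 error and R^2 score accumulated in this test loop can be written as a standalone function. The shapes assumed below ((batch, channels, H, W) tensors, with `y_test_variation` holding the per-channel sum of squared deviations of the test targets from their mean) are an interpretation of the snippet, not taken from the repository:

import torch

def relative_l2_and_r2(output, target, y_test_variation):
    # per-sample, per-channel squared error summed over the spatial grid
    err2 = ((output - target) ** 2).sum(dim=[-1, -2])
    # relative L2 error per sample and channel
    rel_l2 = torch.sqrt(err2 / (target ** 2).sum(dim=[-1, -2]))
    # R^2 score over the whole batch, per channel
    r2 = 1.0 - err2.sum(dim=0) / y_test_variation
    return rel_l2, r2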