def test(epoch):
    """Plot and save the model's prediction for the fixed latent input.

    Does nothing unless ``epoch`` is a multiple of ``args.epochs`` or of
    ``args.test_freq``. Otherwise runs ``model`` on ``fixed_latent`` and
    plots the first output sample against ``output_arr``. The animation
    helper is used when ``args.animate`` is set. The sample is then saved to
    ``run_dir + '/epoch<epoch>.npy'``.

    Args:
        epoch: current epoch number.
    """
    if epoch % args.epochs != 0 and epoch % args.test_freq != 0:
        return

    import torch  # local import: needed only for the no_grad guard below

    # Inference only: the result is converted to NumPy, so no autograd
    # graph needs to be recorded for this forward pass.
    with torch.no_grad():
        output = to_numpy(model(fixed_latent))

    plot_kwargs = dict(plot_fn='imshow', cmap=args.cmap,
                       same_scale=args.same_scale)
    if args.animate:
        # Frame index of this epoch within the animation sequence.
        i_plot = epoch // args.test_freq
        plot_prediction_det_animate2(run_dir, output_arr, output[0], epoch,
                                     args.idx, i_plot, **plot_kwargs)
    else:
        plot_prediction_det(run_dir, output_arr, output[0], epoch, args.idx,
                            **plot_kwargs)
    np.save(run_dir + f'/epoch{epoch}.npy', output[0])
 def test(epoch):
     """Evaluate ``net_u`` on the full grid, plot it against ``output_arr``, save it.

     Does nothing unless ``epoch`` is a multiple of ``args.epochs`` or of
     ``args.test_freq``. ``net_u`` is put in eval mode for the forward pass,
     and its previous train/eval mode is restored afterwards. The stacked
     prediction is saved to ``run_dir + '/epoch<epoch>.npy'``.

     Args:
         epoch: current epoch number.
     """
     if epoch % args.epochs != 0 and epoch % args.test_freq != 0:
         return
     # Evaluation points on the regular grid, normalized by the grid size.
     # NOTE(review): x spans ngrids[0] points but is divided by ngrids[1]
     # (y vice versa), and the 'xy'-indexed meshgrid (shape
     # (ngrids[1], ngrids[0])) is later reshaped to ngrids; this is only
     # self-consistent for square grids -- confirm for non-square ngrids.
     xx, yy = np.meshgrid(np.arange(ngrids[0]), np.arange(ngrids[1]))
     x_test = xx.flatten()[:, None] / ngrids[1]
     y_test = yy.flatten()[:, None] / ngrids[0]
     x_test, y_test = to_tensor_gpu(x_test, y_test)
     # Remember the current mode so evaluation does not leave net_u in eval
     # mode for whatever runs after this function.
     was_training = net_u.training
     net_u.eval()
     try:
         # NOTE(review): input gradients are presumably only needed if net_u
         # differentiates w.r.t. its inputs -- confirm before removing.
         x_test.requires_grad = True
         y_test.requires_grad = True
         y_pred = net_u(torch.cat((y_test, x_test), 1))
     finally:
         net_u.train(was_training)
     target = output_arr
     # net_u output channels 0-2 are u, flux_y, flux_x.
     u_pred, u_y, u_x = (
         y_pred[:, c].detach().cpu().numpy().reshape(*ngrids) for c in range(3))
     # Reordered to (u, flux_x, flux_y); presumably target's channel order -- confirm.
     prediction = np.stack((u_pred, u_x, u_y))
     plot_kwargs = dict(plot_fn='imshow', cmap=args.cmap,
                        same_scale=args.same_scale)
     if args.animate:
         # Frame index of this epoch within the animation sequence.
         i_plot = epoch // args.test_freq
         plot_prediction_det_animate2(run_dir, target, prediction, epoch,
                                      args.idx, i_plot, **plot_kwargs)
     else:
         plot_prediction_det(run_dir, target, prediction, epoch, args.idx,
                             **plot_kwargs)
     np.save(run_dir + f'/epoch{epoch}.npy', prediction)
# Example No. 3
def test(epoch):
    """Evaluate ``model`` on ``test_loader``, plot predictions and log metrics.

    Sums the squared error over the whole test set, then derives RMSE
    (normalized by ``n_out_pixels_test``) and the R^2 score (using
    ``test_stats['y_var']``). The R^2 score is printed. Both metrics are
    appended to ``logger`` every ``args.log_freq`` epochs. Every
    ``args.plot_freq`` epochs, random samples from the first batch are
    plotted: 6 on the final epoch, otherwise 2, capped at the batch size.

    Args:
        epoch: current epoch number.
    """
    model.eval()
    mse = 0.
    # Evaluation needs no gradients; avoid recording autograd graphs.
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(test_loader):
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)
            # reduction='sum' is the non-deprecated spelling of size_average=False.
            mse += F.mse_loss(output, target, reduction='sum').item()

            # plot predictions
            if epoch % args.plot_freq == 0 and batch_idx == 0:
                n_samples = 6 if epoch == args.epochs else 2
                idx = torch.randperm(inputs.size(0))[:n_samples]
                samples_output = output.cpu()[idx].numpy()
                samples_target = target.cpu()[idx].numpy()
                # The batch may contain fewer than n_samples items.
                for i in range(len(idx)):
                    print('epoch {}: plotting prediction {}'.format(epoch, i))
                    plot_prediction_det(args.pred_dir,
                                        samples_target[i],
                                        samples_output[i],
                                        epoch,
                                        i,
                                        plot_fn=args.plot_fn)

    rmse_test = np.sqrt(mse / n_out_pixels_test)
    r2_score = 1 - mse / test_stats['y_var']
    print("epoch: {}, test r2-score:  {:.6f}".format(epoch, r2_score))

    if epoch % args.log_freq == 0:
        logger['r2_test'].append(r2_score)
        logger['rmse_test'].append(rmse_test)
# Example No. 4
    def test(epoch):
        """Evaluate ``model`` on ``test_loader`` with physics-constrained losses.

        For each batch, the test loss is the PDE residual (constitutive plus
        continuity constraints) plus the boundary loss weighted by
        ``args.weight_bound``. Per-sample, per-channel squared error and
        relative L2 error against ``target`` are also collected. Random
        samples from the last batch are plotted every ``args.plot_freq``
        epochs and on the final epoch. R^2, relative L2 and loss summaries
        are printed, and appended to ``logger`` every ``args.log_freq``
        epochs.

        Args:
            epoch: current epoch number.
        """
        model.eval()
        loss_test = 0.
        relative_l2, err2 = [], []
        for batch_idx, (inputs, target) in enumerate(test_loader):
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)
            loss_pde = constitutive_constraint(inputs, output, sobel_filter) \
                + continuity_constraint(output, sobel_filter)
            loss_dirichlet, loss_neumann = boundary_condition(output)
            loss_boundary = loss_dirichlet + loss_neumann
            loss = loss_pde + loss_boundary * args.weight_bound
            loss_test += loss.item()
            # Metrics need no gradients. Detaching keeps the tensors stored in
            # relative_l2 / err2 from holding every batch's autograd graph.
            output_d = output.detach()
            # sum over H, W --> (B, C)
            err2_sum = torch.sum((output_d - target) ** 2, [-1, -2])
            relative_l2.append(torch.sqrt(err2_sum /
                                          (target ** 2).sum([-1, -2])))
            err2.append(err2_sum)
            # plot predictions
            if (epoch % args.plot_freq == 0 or epoch == args.epochs) and \
                    batch_idx == len(test_loader) - 1:
                n_samples = 6 if epoch == args.epochs else 2
                idx = torch.randperm(inputs.size(0))[:n_samples]
                samples_output = output_d.cpu()[idx].numpy()
                samples_target = target.cpu()[idx].numpy()
                # The last batch may contain fewer than n_samples items.
                for i in range(len(idx)):
                    print('epoch {}: plotting prediction {}'.format(epoch, i))
                    plot_prediction_det(args.pred_dir,
                                        samples_target[i],
                                        samples_output[i],
                                        epoch,
                                        i,
                                        plot_fn=args.plot_fn)

        loss_test /= (batch_idx + 1)
        relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
        r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation
        print(f"Epoch: {epoch}, test r2-score:  {r2_score}")
        print(f"Epoch: {epoch}, test relative-l2:  {relative_l2}")
        # Reports the averaged test loss; the pde/boundary terms are those of
        # the last batch only.
        print(f'Epoch {epoch}: test loss: {loss_test:.6f}, loss_pde: {loss_pde.item():.6f}, '
              f'dirichlet {loss_dirichlet:.6f}, neumann {loss_neumann.item():.6f}')

        if epoch % args.log_freq == 0:
            logger['loss_test'].append(loss_test)
            logger['r2_test'].append(r2_score)
            logger['nrmse_test'].append(relative_l2)
    def test(epoch):
        """Evaluate ``model`` on ``test_loader`` with an MSE loss and log metrics.

        Per batch, computes the mean-squared-error loss plus per-sample,
        per-channel squared error and relative L2 error (summed over the two
        trailing spatial axes). Random samples from the last batch are
        plotted every ``args.plot_freq`` epochs and on the final epoch. The
        R^2 score (using ``y_test_variation``) and the mean relative L2 are
        printed, and logged every ``args.log_freq`` epochs.

        Args:
            epoch: current epoch number.
        """
        model.eval()
        loss_test = 0.
        relative_l2, err2 = [], []
        # Evaluation needs no gradients. This also stops the per-batch metric
        # tensors collected below from keeping each batch's autograd graph.
        with torch.no_grad():
            for batch_idx, (inputs, target) in enumerate(test_loader):
                inputs, target = inputs.to(device), target.to(device)
                output = model(inputs)
                loss = F.mse_loss(output, target)
                loss_test += loss.item()
                # sum over H, W --> (B, C)
                err2_sum = torch.sum((output - target) ** 2, [-1, -2])
                relative_l2.append(torch.sqrt(err2_sum /
                                              (target ** 2).sum([-1, -2])))
                err2.append(err2_sum)
                # plot predictions
                if (epoch % args.plot_freq == 0 or epoch == args.epochs) and \
                        batch_idx == len(test_loader) - 1:
                    n_samples = 6 if epoch == args.epochs else 2
                    idx = torch.randperm(inputs.size(0))[:n_samples]
                    samples_output = output.cpu()[idx].numpy()
                    samples_target = target.cpu()[idx].numpy()
                    # The last batch may contain fewer than n_samples items.
                    for i in range(len(idx)):
                        print('epoch {}: plotting prediction {}'.format(epoch, i))
                        plot_prediction_det(args.pred_dir,
                                            samples_target[i],
                                            samples_output[i],
                                            epoch,
                                            i,
                                            plot_fn=args.plot_fn)

        loss_test /= (batch_idx + 1)
        relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
        r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation
        print(
            f"Epoch: {epoch}, test r2-score:  {r2_score}, relative-l2:  {relative_l2}"
        )
        if epoch % args.log_freq == 0:
            logger['loss_test'].append(loss_test)
            logger['r2_test'].append(r2_score)
            logger['nrmse_test'].append(relative_l2)