def test(epoch):
    if epoch % args.epochs == 0 or epoch % args.test_freq == 0:
        output = model(fixed_latent)
        output = to_numpy(output)
        if args.animate:
            i_plot = epoch // args.test_freq
            plot_prediction_det_animate2(run_dir, output_arr, output[0], epoch,
                args.idx, i_plot, plot_fn='imshow', cmap=args.cmap,
                same_scale=args.same_scale)
        else:
            plot_prediction_det(run_dir, output_arr, output[0], epoch, args.idx,
                plot_fn='imshow', cmap=args.cmap, same_scale=args.same_scale)
        np.save(run_dir + f'/epoch{epoch}.npy', output[0])
def test(epoch):
    if epoch % args.epochs == 0 or epoch % args.test_freq == 0:
        # plot the solution on a regular grid
        xx, yy = np.meshgrid(np.arange(ngrids[0]), np.arange(ngrids[1]))
        x_test = xx.flatten()[:, None] / ngrids[1]
        y_test = yy.flatten()[:, None] / ngrids[0]
        x_test, y_test = to_tensor_gpu(x_test, y_test)
        net_u.eval()
        x_test.requires_grad = True
        y_test.requires_grad = True
        xy_test = torch.cat((y_test, x_test), 1)
        y_pred = net_u(xy_test)
        target = output_arr
        # net_u has three output channels (0-2): u, flux_y, flux_x
        u_pred = y_pred[:, 0].detach().cpu().numpy().reshape(*ngrids)
        u_y = y_pred[:, 1].detach().cpu().numpy().reshape(*ngrids)
        u_x = y_pred[:, 2].detach().cpu().numpy().reshape(*ngrids)
        prediction = np.stack((u_pred, u_x, u_y))
        # prediction = y_pred.view(*ngrids, -1).transpose(0, 1).permute(2, 1, 0).detach().cpu().numpy()
        if args.animate:
            i_plot = epoch // args.test_freq
            plot_prediction_det_animate2(run_dir, target, prediction, epoch,
                args.idx, i_plot, plot_fn='imshow', cmap=args.cmap,
                same_scale=args.same_scale)
        else:
            plot_prediction_det(run_dir, target, prediction, epoch, args.idx,
                plot_fn='imshow', cmap=args.cmap, same_scale=args.same_scale)
        np.save(run_dir + f'/epoch{epoch}.npy', prediction)
def test(epoch):
    model.eval()
    mse = 0.
    for batch_idx, (input, target) in enumerate(test_loader):
        input, target = input.to(device), target.to(device)
        output = model(input)
        # summed squared error over the batch (reduction='sum' replaces the
        # deprecated size_average=False)
        mse += F.mse_loss(output, target, reduction='sum').item()

        # plot predictions
        if epoch % args.plot_freq == 0 and batch_idx == 0:
            n_samples = 6 if epoch == args.epochs else 2
            idx = torch.randperm(input.size(0))[:n_samples]
            samples_output = output.data.cpu()[idx].numpy()
            samples_target = target.data.cpu()[idx].numpy()
            for i in range(n_samples):
                print('epoch {}: plotting prediction {}'.format(epoch, i))
                plot_prediction_det(args.pred_dir, samples_target[i],
                    samples_output[i], epoch, i, plot_fn=args.plot_fn)

    rmse_test = np.sqrt(mse / n_out_pixels_test)
    r2_score = 1 - mse / test_stats['y_var']
    print("epoch: {}, test r2-score: {:.6f}".format(epoch, r2_score))

    if epoch % args.log_freq == 0:
        logger['r2_test'].append(r2_score)
        logger['rmse_test'].append(rmse_test)
def test(epoch):
    model.eval()
    loss_test = 0.
    relative_l2, err2 = [], []
    for batch_idx, (input, target) in enumerate(test_loader):
        input, target = input.to(device), target.to(device)
        output = model(input)
        loss_pde = constitutive_constraint(input, output, sobel_filter) \
            + continuity_constraint(output, sobel_filter)
        loss_dirichlet, loss_neumann = boundary_condition(output)
        loss_boundary = loss_dirichlet + loss_neumann
        loss = loss_pde + loss_boundary * args.weight_bound
        loss_test += loss.item()
        # sum over H, W --> (B, C)
        err2_sum = torch.sum((output - target) ** 2, [-1, -2])
        relative_l2.append(torch.sqrt(err2_sum / (target ** 2).sum([-1, -2])))
        err2.append(err2_sum)

        # plot predictions
        if (epoch % args.plot_freq == 0 or epoch == args.epochs) and \
                batch_idx == len(test_loader) - 1:
            n_samples = 6 if epoch == args.epochs else 2
            idx = torch.randperm(input.size(0))[:n_samples]
            samples_output = output.data.cpu()[idx].numpy()
            samples_target = target.data.cpu()[idx].numpy()
            for i in range(n_samples):
                print('epoch {}: plotting prediction {}'.format(epoch, i))
                plot_prediction_det(args.pred_dir, samples_target[i],
                    samples_output[i], epoch, i, plot_fn=args.plot_fn)

    loss_test /= (batch_idx + 1)
    relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
    r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation
    print(f"Epoch: {epoch}, test r2-score: {r2_score}")
    print(f"Epoch: {epoch}, test relative-l2: {relative_l2}")
    print(f'Epoch {epoch}: test loss: {loss_test:.6f}, loss_pde: {loss_pde.item():.6f}, '
          f'dirichlet {loss_dirichlet.item():.6f}, neumann {loss_neumann.item():.6f}')

    if epoch % args.log_freq == 0:
        logger['loss_test'].append(loss_test)
        logger['r2_test'].append(r2_score)
        logger['nrmse_test'].append(relative_l2)
def test(epoch):
    model.eval()
    loss_test = 0.
    relative_l2, err2 = [], []
    for batch_idx, (input, target) in enumerate(test_loader):
        input, target = input.to(device), target.to(device)
        output = model(input)
        loss = F.mse_loss(output, target)
        loss_test += loss.item()
        # sum over H, W --> (B, C)
        err2_sum = torch.sum((output - target) ** 2, [-1, -2])
        relative_l2.append(torch.sqrt(err2_sum / (target ** 2).sum([-1, -2])))
        err2.append(err2_sum)

        # plot predictions
        if (epoch % args.plot_freq == 0 or epoch == args.epochs) and \
                batch_idx == len(test_loader) - 1:
            n_samples = 6 if epoch == args.epochs else 2
            idx = torch.randperm(input.size(0))[:n_samples]
            samples_output = output.data.cpu()[idx].numpy()
            samples_target = target.data.cpu()[idx].numpy()
            for i in range(n_samples):
                print('epoch {}: plotting prediction {}'.format(epoch, i))
                plot_prediction_det(args.pred_dir, samples_target[i],
                    samples_output[i], epoch, i, plot_fn=args.plot_fn)

    loss_test /= (batch_idx + 1)
    relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
    r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation
    print(f"Epoch: {epoch}, test r2-score: {r2_score}, relative-l2: {relative_l2}")

    if epoch % args.log_freq == 0:
        logger['loss_test'].append(loss_test)
        logger['r2_test'].append(r2_score)
        logger['nrmse_test'].append(relative_l2)
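

# Hedged sketch, not part of the original module: the test() variants above all
# rely on helpers and a driver loop defined elsewhere in the training script
# (args, model, train(), to_numpy(), ...). The definitions below are assumptions
# for illustration only -- a minimal to_numpy() and a plausible epoch loop --
# not the code base's actual implementation.

import torch

def to_numpy(x):
    # assumed helper: detach from the autograd graph, move to host, convert to NumPy
    return x.detach().cpu().numpy()

if __name__ == '__main__':
    # assumed driver: train for args.epochs epochs, evaluating periodically;
    # test() is called without torch.no_grad() because the physics-constrained
    # variants differentiate the network output inside the test pass.
    for epoch in range(1, args.epochs + 1):
        train(epoch)   # hypothetical training counterpart, one epoch of optimization
        test(epoch)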