def closure():
    """Compute the total loss for ``fixed_latent`` and backpropagate it.

    Gradients held by ``optimizer`` are cleared, the model is evaluated on
    ``fixed_latent``, and the loss is assembled as::

        energy + (dirichlet + neumann) * args.weight_bound

    where ``energy`` is the sum of the constitutive and continuity
    constraint terms.  When ``args.nonlinear`` is set, ``args.alpha1`` and
    ``args.alpha2`` are additionally passed to the constitutive term.

    Presumably intended to be handed to ``optimizer.step(closure)`` --
    confirm against the caller.

    Returns:
        The loss, after ``backward()`` has been called on it.
    """
    optimizer.zero_grad()
    prediction = model(fixed_latent)

    # Both branches share the same call; the nonlinear case only appends
    # two extra positional arguments to the constitutive term.
    nonlinear_args = (args.alpha1, args.alpha2) if args.nonlinear else ()
    constitutive = constitutive_constraint(
        perm_tensor, prediction, sobel_filter, *nonlinear_args)
    energy = constitutive + continuity_constraint(prediction, sobel_filter)

    loss_dirichlet, loss_neumann = boundary_condition(prediction)
    loss = energy + (loss_dirichlet + loss_neumann) * args.weight_bound
    loss.backward()

    if not args.verbose:
        return loss

    # `epoch` is read from the enclosing (module) scope.
    # `:6f` is a minimum-width spec; it prints the same text as `.6f`.
    message = (
        f'epoch {epoch}: loss {loss.item():6f}, '
        f'energy {energy.item():.6f}, diri {loss_dirichlet.item():.6f}, '
        f'neum {loss_neumann.item():.6f}'
    )
    print(message)
    return loss
def test(epoch):
    """Evaluate the model on ``test_loader``; print and optionally log metrics.

    For every test batch the loss is computed as::

        constitutive + continuity + (dirichlet + neumann) * args.weight_bound

    and the squared error against ``target`` is summed over the last two
    dimensions.  After the loop the function reports:

    * the batch-averaged test loss,
    * the relative L2 error ``sqrt(sum(err^2) / sum(target^2))`` averaged
      over dimension 0 of the concatenated batches,
    * the R2 score ``1 - sum(err^2) / y_test_variation``.

    On the last batch, predictions are plotted when ``epoch`` is a multiple
    of ``args.plot_freq`` or equals ``args.epochs``.  Metrics are appended to
    ``logger`` when ``epoch`` is a multiple of ``args.log_freq``.

    Args:
        epoch: Current epoch number, used for printing and for the
            plot/log frequency checks.
    """
    model.eval()
    n_batches = len(test_loader)  # loop-invariant; hoisted out of the loop
    loss_test = 0.
    relative_l2, err2 = [], []
    for batch_idx, (input, target) in enumerate(test_loader):
        input, target = input.to(device), target.to(device)
        output = model(input)
        loss_pde = constitutive_constraint(input, output, sobel_filter) \
            + continuity_constraint(output, sobel_filter)
        loss_dirichlet, loss_neumann = boundary_condition(output)
        loss_boundary = loss_dirichlet + loss_neumann
        loss = loss_pde + loss_boundary * args.weight_bound
        loss_test += loss.item()

        # sum over H, W --> (B, C)
        err2_sum = torch.sum((output - target) ** 2, [-1, -2])
        relative_l2.append(
            torch.sqrt(err2_sum / (target ** 2).sum([-1, -2])))
        err2.append(err2_sum)

        # plot predictions (last batch only)
        should_plot = epoch % args.plot_freq == 0 or epoch == args.epochs
        if should_plot and batch_idx == n_batches - 1:
            n_samples = 6 if epoch == args.epochs else 2
            idx = torch.randperm(input.size(0))[:n_samples]
            samples_output = output.data.cpu()[idx].numpy()
            samples_target = target.data.cpu()[idx].numpy()
            # Iterate over the samples actually drawn: the last batch may
            # hold fewer than `n_samples` rows, and indexing up to
            # `n_samples` would then raise IndexError.
            for i in range(len(idx)):
                print('epoch {}: plotting prediction {}'.format(epoch, i))
                plot_prediction_det(args.pred_dir, samples_target[i],
                                    samples_output[i], epoch, i,
                                    plot_fn=args.plot_fn)

    loss_test /= n_batches
    relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
    r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation
    print(f"Epoch: {epoch}, test r2-score: {r2_score}")
    print(f"Epoch: {epoch}, test relative-l2: {relative_l2}")
    # Report the *test* loss here (previously printed the training-loop
    # global `loss_train`).  The remaining terms are from the last batch.
    print(f'Epoch {epoch}: test loss: {loss_test:.6f}, '
          f'loss_pde: {loss_pde.item():.6f}, '
          f'dirichlet {loss_dirichlet.item():.6f}, '
          f'nuemann {loss_neumann.item():.6f}')

    if epoch % args.log_freq == 0:
        logger['loss_test'].append(loss_test)
        logger['r2_test'].append(r2_score)
        logger['nrmse_test'].append(relative_l2)
for epoch in range(start_epoch, args.epochs + 1): model.train() # if epoch == 30: # print('begin finding lr') # logs,losses = find_lr(model, train_loader, optimizer, loss_fn, # args.weight_bound, init_value=1e-8, final_value=10., beta=0.98) # plt.plot(logs[10:-5], losses[10:-5]) # plt.savefig(args.train_dir + '/find_lr.png') # sys.exit(0) loss_train, mse = 0., 0. for batch_idx, (input, ) in enumerate(train_loader, start=1): input = input.to(device) model.zero_grad() output = model(input) loss_pde = constitutive_constraint(input, output, sobel_filter) \ + continuity_constraint(output, sobel_filter) loss_dirichlet, loss_neumann = boundary_condition(output) loss_boundary = loss_dirichlet + loss_neumann loss = loss_pde + loss_boundary * args.weight_bound loss.backward() # lr scheduling step = (epoch - 1) * len(train_loader) + batch_idx pct = step / total_steps lr = scheduler.step(pct) adjust_learning_rate(optimizer, lr) optimizer.step() loss_train += loss.item() loss_train /= batch_idx print(f'Epoch {epoch}, lr {lr:.6f}')
def test(epoch):
    """Evaluate the probabilistic model on ``test_loader`` and report metrics.

    Per batch, the prediction is the mean of 20 samples from
    ``model.sample`` when ``epoch`` is a multiple of 10, otherwise the first
    value returned by ``model.generate``.  The batch loss is::

        (residual + boundary * args.weight_bound) * args.beta + neg_entropy

    After the loop the function prints the batch-averaged test loss, the
    relative L2 error and the R2 score.  The residual, boundary and entropy
    values in the summary line are those of the last batch.  Metrics are
    appended to ``logger`` when ``epoch`` is a multiple of ``args.log_freq``.

    On the first batch, predictions and samples are plotted when ``epoch``
    is a multiple of ``args.plot_freq`` or of ``args.epochs``.

    Args:
        epoch: Current epoch number, used for printing and for the
            sampling/plot/log frequency checks.
    """
    model.eval()
    loss_test = 0.
    relative_l2, err2 = [], []
    for batch_idx, (input, target) in enumerate(test_loader):
        input, target = input.to(device), target.to(device)
        # every 10 epochs evaluate the mean accurately
        if epoch % 10 == 0:
            output_samples = model.sample(input, n_samples=20, temperature=1.0)
            output = output_samples.mean(0)
        else:
            # evaluate with one output sample
            output, _ = model.generate(input)
        residual_norm = constitutive_constraint(input, output, sobel_filter) \
            + continuity_constraint(output, sobel_filter)
        loss_dirichlet, loss_neumann = boundary_condition(output)
        loss_boundary = loss_dirichlet + loss_neumann
        loss_pde = residual_norm + loss_boundary * args.weight_bound
        # evaluate predictive entropy: E_p(y|x) [log p(y|x)]
        # NOTE(review): `log_likeihood` is never assigned in this function,
        # so this reads a module-level value (presumably left over from the
        # training loop) rather than one computed on the test batch.
        # Confirm whether the second value from `model.generate` was meant.
        neg_entropy = log_likeihood.mean() / math.log(2.) / n_out_pixels
        loss = loss_pde * args.beta + neg_entropy
        loss_test += loss.item()

        err2_sum = torch.sum((output - target) ** 2, [-1, -2])
        relative_l2.append(
            torch.sqrt(err2_sum / (target ** 2).sum([-1, -2])))
        err2.append(err2_sum)

        # plot predictions (first batch only)
        should_plot = (epoch % args.plot_freq == 0
                       or epoch % args.epochs == 0)
        if should_plot and batch_idx == 0:
            n_samples = 6 if epoch == args.epochs else 2
            idx = np.random.permutation(input.size(0))[:n_samples]
            samples_target = target.data.cpu()[idx].numpy()
            # Iterate over the indices actually drawn: a batch smaller than
            # `n_samples` would otherwise raise IndexError on `idx[i]`.
            for i, sample_idx in enumerate(idx):
                print('epoch {}: plotting prediction {}'.format(epoch, i))
                # Keep the batch dimension by indexing with a list.
                single_input = input[[sample_idx]]
                pred_mean, pred_var = model.predict(single_input)
                plot_prediction_bayes2(args.pred_dir, samples_target[i],
                                       pred_mean[0], pred_var[0], epoch,
                                       sample_idx, plot_fn='imshow',
                                       cmap='jet', same_scale=False)
                # plot samples p(y|x)
                print(sample_idx)
                print(single_input.shape)
                samples_pred = model.sample(single_input, n_samples=15)[:, 0]
                samples = torch.cat((target[[sample_idx]], samples_pred), 0)
                save_samples(args.pred_dir, samples, epoch, sample_idx,
                             'samples', nrow=4, heatmap=True, cmap='jet')

    loss_test /= (batch_idx + 1)
    relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
    r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation
    print(f"Epoch {epoch}: test r2-score: {r2_score}")
    print(f"Epoch {epoch}: test relative l2: {relative_l2}")
    print(f'Epoch {epoch}: test loss: {loss_test:.6f}, '
          f'residual: {residual_norm.item():.6f}, '
          f'boundary {loss_boundary.item():.6f}, '
          f'neg entropy {neg_entropy.item():.6f}')

    if epoch % args.log_freq == 0:
        logger['loss_test'].append(loss_test)
        logger['r2_test'].append(r2_score)
        logger['nrmse_test'].append(relative_l2)
        logger['entropy_test'].append(-neg_entropy.item())