def closure():
    """Compute the loss for ``fixed_latent`` and backpropagate it.

    Clears the optimizer's gradients, runs the model on ``fixed_latent``,
    adds the constitutive and continuity constraint terms to the weighted
    boundary terms, calls ``backward()`` on the result and, when
    ``args.verbose`` is set, prints the individual loss components.

    Returns:
        The total loss (energy plus weighted boundary loss).
    """
    optimizer.zero_grad()
    prediction = model(fixed_latent)

    # The nonlinear case passes two extra positional coefficients to the
    # constitutive term; the linear case passes none.
    extra_coeffs = (args.alpha1, args.alpha2) if args.nonlinear else ()
    constitutive_term = constitutive_constraint(
        perm_tensor, prediction, sobel_filter, *extra_coeffs)
    energy = constitutive_term + continuity_constraint(prediction, sobel_filter)

    loss_dirichlet, loss_neumann = boundary_condition(prediction)
    loss = energy + (loss_dirichlet + loss_neumann) * args.weight_bound
    loss.backward()

    if args.verbose:
        report = (
            f'epoch {epoch}: loss {loss.item():6f}, '
            f'energy {energy.item():.6f}, diri {loss_dirichlet.item():.6f}, '
            f'neum {loss_neumann.item():.6f}'
        )
        print(report)
    return loss
Esempio n. 2
0
    def test(epoch):
        """Evaluate the model on ``test_loader`` and report/log test metrics.

        For every test batch the loss is the sum of the constitutive and
        continuity constraint terms plus the boundary terms weighted by
        ``args.weight_bound``. Squared errors against the targets are
        accumulated to compute the relative L2 error and the R2-score.
        On plotting epochs, a few random predictions from the last batch are
        plotted. Metrics are appended to ``logger`` every ``args.log_freq``
        epochs.

        Args:
            epoch: Current epoch number; controls plotting and logging.
        """
        model.eval()
        loss_test = 0.
        relative_l2, err2 = [], []
        for batch_idx, (inputs, target) in enumerate(test_loader):
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)
            loss_pde = constitutive_constraint(inputs, output, sobel_filter) \
                + continuity_constraint(output, sobel_filter)
            loss_dirichlet, loss_neumann = boundary_condition(output)
            loss_boundary = loss_dirichlet + loss_neumann
            loss = loss_pde + loss_boundary * args.weight_bound
            loss_test += loss.item()
            # sum over H, W --> (B, C)
            err2_sum = torch.sum((output - target)**2, [-1, -2])
            relative_l2.append(torch.sqrt(err2_sum /
                                          (target**2).sum([-1, -2])))
            err2.append(err2_sum)
            # plot predictions (only from the last batch of the loader)
            is_plot_epoch = (epoch % args.plot_freq == 0
                             or epoch == args.epochs)
            if is_plot_epoch and batch_idx == len(test_loader) - 1:
                n_samples = 6 if epoch == args.epochs else 2
                idx = torch.randperm(inputs.size(0))[:n_samples]
                samples_output = output.detach().cpu()[idx].numpy()
                samples_target = target.detach().cpu()[idx].numpy()
                # The last batch can hold fewer than n_samples items, so loop
                # over the indices actually drawn instead of range(n_samples).
                for i in range(len(idx)):
                    print('epoch {}: plotting prediction {}'.format(epoch, i))
                    plot_prediction_det(args.pred_dir,
                                        samples_target[i],
                                        samples_output[i],
                                        epoch,
                                        i,
                                        plot_fn=args.plot_fn)

        # average the accumulated per-batch loss over the number of batches
        loss_test /= (batch_idx + 1)
        relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
        r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation
        print(f"Epoch: {epoch}, test r2-score:  {r2_score}")
        print(f"Epoch: {epoch}, test relative-l2:  {relative_l2}")
        # Report the averaged *test* loss; the component terms shown are
        # those of the last evaluated batch.
        print(f'Epoch {epoch}: test loss: {loss_test:.6f}, loss_pde: {loss_pde.item():.6f}, '\
                f'dirichlet {loss_dirichlet.item():.6f}, nuemann {loss_neumann.item():.6f}')

        if epoch % args.log_freq == 0:
            logger['loss_test'].append(loss_test)
            logger['r2_test'].append(r2_score)
            logger['nrmse_test'].append(relative_l2)
Esempio n. 3
0
    def test(epoch):
        """Evaluate the probabilistic model on ``test_loader`` and log metrics.

        Every 10th epoch the prediction is the mean of 20 samples drawn with
        ``model.sample``; otherwise a single sample from ``model.generate``
        is used. The per-batch loss combines the PDE residual and weighted
        boundary terms (scaled by ``args.beta``) with the negative-entropy
        term computed from the log-likelihood of a generated test sample.
        On plotting epochs, predictive mean/variance and individual samples
        are saved for a few random inputs of the first batch.

        Args:
            epoch: Current epoch number; controls sampling, plotting, logging.
        """
        model.eval()
        loss_test = 0.
        # mse = 0.
        relative_l2, err2 = [], []
        for batch_idx, (inputs, target) in enumerate(test_loader):
            inputs, target = inputs.to(device), target.to(device)
            # every 10 epochs evaluate the mean accurately
            if epoch % 10 == 0:
                output_samples = model.sample(inputs,
                                              n_samples=20,
                                              temperature=1.0)
                output = output_samples.mean(0)
                # sample() yields no likelihood here, so draw one generated
                # sample of this test batch for the entropy term instead of
                # reading a stale value from the enclosing training scope.
                _, log_likelihood = model.generate(inputs)
            else:
                # evaluate with one output sample
                output, log_likelihood = model.generate(inputs)

            residual_norm = constitutive_constraint(inputs, output, sobel_filter) \
                + continuity_constraint(output, sobel_filter)
            loss_dirichlet, loss_neumann = boundary_condition(output)
            loss_boundary = loss_dirichlet + loss_neumann
            loss_pde = residual_norm + loss_boundary * args.weight_bound
            # evaluate predictive entropy: E_p(y|x) [log p(y|x)];
            # divided by ln 2 (presumably nats -> bits) and by n_out_pixels
            neg_entropy = log_likelihood.mean() / math.log(2.) / n_out_pixels
            loss = loss_pde * args.beta + neg_entropy

            loss_test += loss.item()
            # sum over the last two dimensions
            err2_sum = torch.sum((output - target)**2, [-1, -2])
            relative_l2.append(torch.sqrt(err2_sum /
                                          (target**2).sum([-1, -2])))
            err2.append(err2_sum)

            # plot predictions (only from the first batch of the loader)
            if (epoch % args.plot_freq == 0
                    or epoch % args.epochs == 0) and batch_idx == 0:
                n_samples = 6 if epoch == args.epochs else 2
                idx = np.random.permutation(inputs.size(0))[:n_samples]
                samples_target = target.detach().cpu()[idx].numpy()

                # Iterate over the indices actually drawn: the batch may hold
                # fewer than n_samples items.
                for i, sample_idx in enumerate(idx):
                    print('epoch {}: plotting prediction {}'.format(epoch, i))
                    # index with a list to keep the leading batch dimension
                    single_input = inputs[[sample_idx]]
                    pred_mean, pred_var = model.predict(single_input)
                    plot_prediction_bayes2(args.pred_dir,
                                           samples_target[i],
                                           pred_mean[0],
                                           pred_var[0],
                                           epoch,
                                           sample_idx,
                                           plot_fn='imshow',
                                           cmap='jet',
                                           same_scale=False)
                    # plot samples p(y|x): the target followed by 15 samples
                    samples_pred = model.sample(single_input,
                                                n_samples=15)[:, 0]
                    samples = torch.cat((target[[sample_idx]], samples_pred), 0)
                    save_samples(args.pred_dir,
                                 samples,
                                 epoch,
                                 sample_idx,
                                 'samples',
                                 nrow=4,
                                 heatmap=True,
                                 cmap='jet')

        # average the accumulated per-batch loss over the number of batches
        loss_test /= (batch_idx + 1)
        relative_l2 = to_numpy(torch.cat(relative_l2, 0).mean(0))
        r2_score = 1 - to_numpy(torch.cat(err2, 0).sum(0)) / y_test_variation

        print(f"Epoch {epoch}: test r2-score:  {r2_score}")
        print(f"Epoch {epoch}: test relative l2:  {relative_l2}")
        # residual/boundary/entropy values shown are those of the last batch
        print(f'Epoch {epoch}: test loss: {loss_test:.6f}, residual: {residual_norm.item():.6f}, '\
                f'boundary {loss_boundary.item():.6f}, neg entropy {neg_entropy.item():.6f}')

        if epoch % args.log_freq == 0:
            logger['loss_test'].append(loss_test)
            logger['r2_test'].append(r2_score)
            logger['nrmse_test'].append(relative_l2)
            logger['entropy_test'].append(-neg_entropy.item())
Esempio n. 4
0
    # Training loop: one pass over train_loader per epoch, with the learning
    # rate recomputed before every optimizer step.
    print(f'total steps: {total_steps}')
    for epoch in range(start_epoch, args.epochs + 1):
        model.train()
        # Disabled learning-rate finder, kept for reference:
        # if epoch == 30:
        #     print('begin finding lr')
        #     logs,losses = find_lr(model, train_loader, optimizer, loss_fn,
        #           args.weight_bound, init_value=1e-8, final_value=10., beta=0.98)
        #     plt.plot(logs[10:-5], losses[10:-5])
        #     plt.savefig(args.train_dir + '/find_lr.png')
        #     sys.exit(0)
        loss_train, mse = 0., 0.
        # The loader yields 1-tuples (inputs only); batch_idx counts from 1.
        for batch_idx, (input, ) in enumerate(train_loader, start=1):
            input = input.to(device)
            model.zero_grad()
            output = model(input)
            # Loss is built from the input and the model output alone:
            # constitutive + continuity terms, plus the weighted boundary terms.
            loss_pde = constitutive_constraint(input, output, sobel_filter) \
                + continuity_constraint(output, sobel_filter)
            loss_dirichlet, loss_neumann = boundary_condition(output)
            loss_boundary = loss_dirichlet + loss_neumann
            loss = loss_pde + loss_boundary * args.weight_bound
            loss.backward()
            # lr scheduling: the scheduler maps the fraction of total steps
            # done so far (pct) to a learning rate, which is applied to the
            # optimizer before it steps.
            step = (epoch - 1) * len(train_loader) + batch_idx
            pct = step / total_steps
            lr = scheduler.step(pct)
            adjust_learning_rate(optimizer, lr)
            optimizer.step()
            loss_train += loss.item()

        # batch_idx started at 1, so after the loop it equals the number of
        # batches: this turns the accumulated sum into a per-batch mean.
        loss_train /= batch_idx
Esempio n. 5
0
            # Run a single forward pass on the first test batch, then stop.
            # Per the message printed below, this pass serves as the
            # data-dependent initialization of Actnorm.
            # NOTE(review): the returned latent/logp/eps are not used in the
            # lines visible here — confirm the pass is needed only for its
            # side effect on the model.
            for batch_idx, (input, target) in enumerate(test_loader):
                input, target = input.to(device), target.to(device)
                model.zero_grad()
                latent, logp, eps = model(target, input)
                break
            # Mark initialization as done.
            initialized = True
            print('Finished data initialization of Actnorm')

        # Training pass over train_loader; the loader yields 1-tuples
        # (inputs only). batch_idx is 0-based here.
        for batch_idx, (input, ) in enumerate(train_loader):
            input = input.to(device)
            model.zero_grad()
            # sample output for each input
            # NOTE(review): detect_anomaly (presumably torch.autograd's) is
            # documented as a debugging aid that slows execution — confirm it
            # is meant to stay enabled for regular training.
            with autograd.detect_anomaly():
                output, log_likeihood = model.generate(input)
                # evaluate energy functional/residual norm: E_x E_p(y|x) [\beta V(y; x)]
                residual_norm = constitutive_constraint(input, output, sobel_filter) \
                    + continuity_constraint(output, sobel_filter)
                loss_dirichlet, loss_neumann = boundary_condition(output)
                loss_boundary = loss_dirichlet + loss_neumann
                loss_pde = residual_norm + loss_boundary * args.weight_bound
                # evaluate predictive entropy: E_p(y|x) [log p(y|x)], bits per pixel
                # (mean log-likelihood divided by ln 2 and by n_out_pixels)
                neg_entropy = log_likeihood.mean() / math.log(
                    2.) / n_out_pixels
                # total loss: PDE loss scaled by args.beta plus the
                # negative-entropy term
                loss = loss_pde * args.beta + neg_entropy
                loss.backward()

            # lr scheduling: the scheduler maps the fraction of total steps
            # done so far (pct) to a learning rate, which is then applied to
            # the optimizer.
            # NOTE(review): with a 0-based batch_idx the first step of epoch 1
            # is 0 and the final step is total_steps - 1 if total_steps is
            # epochs * len(train_loader) — confirm this is intended.
            step = (epoch - 1) * len(train_loader) + batch_idx
            pct = step / total_steps
            lr = scheduler.step(pct)
            adjust_learning_rate(optimizer, lr)