예제 #1
0
def wave_pml(model_output, gt):
    """Loss terms for a wave equation with Dirichlet/Neumann boundary data.

    The PDE residual is ``u_tt - (1 / squared_slowness) * (u_11 + u_22)``.
    Here the subscripts index the input coordinates of ``x``; coordinate 0 is
    presumably time.

    Args:
        model_output: dict with ``'model_in'`` (coordinates ``x``, per the
            inline comment ``(meta_batch_size, num_points, 3)``) and
            ``'model_out'`` (field ``y``, ``(meta_batch_size, num_points, 1)``).
        gt: dict with ``'source_boundary_values'``, ``'squared_slowness'`` and
            the boolean ``'dirichlet_mask'`` selecting boundary points.

    Returns:
        dict of scalar tensors ``'dirichlet'``, ``'neumann'`` and
        ``'diff_constraint_hom'``.
    """
    source_boundary_values = gt['source_boundary_values']
    x = model_output['model_in']  # (meta_batch_size, num_points, 3)
    y = model_output['model_out']  # (meta_batch_size, num_points, 1)
    squared_slowness = gt['squared_slowness']
    dirichlet_mask = gt['dirichlet_mask']
    # Size of dim 1 (points per sample); used to scale the boundary terms.
    num_points = x.shape[1]

    du, _ = diff_operators.jacobian(y, x)
    # Derivative w.r.t. input coordinate 0 (presumably time).
    dudt = du[..., 0]

    if torch.all(dirichlet_mask):
        # Only boundary points: there is no interior PDE residual. Build the
        # zero on y's device/dtype so every returned loss lives alongside y.
        diff_constraint_hom = y.new_zeros(1)
    else:
        # Second derivatives: jacobian of the first output channel's gradient
        # (assumes jacobian lays out (..., out, in) -- confirm).
        hess, _ = diff_operators.jacobian(du[..., 0, :], x)
        lap = hess[..., 1, 1, None] + hess[..., 2, 2, None]
        dudt2 = hess[..., 0, 0, None]
        diff_constraint_hom = dudt2 - 1 / squared_slowness * lap

    # On boundary points: match the prescribed values and penalise du/dt.
    dirichlet = y[dirichlet_mask] - source_boundary_values[dirichlet_mask]
    neumann = dudt[dirichlet_mask]

    return {
        'dirichlet': torch.abs(dirichlet).sum() * num_points / 1e1,
        'neumann': torch.abs(neumann).sum() * num_points / 1e2,
        'diff_constraint_hom': torch.abs(diff_constraint_hom).sum()
    }
예제 #2
0
    def forward(self, model_input):
        """Evaluate the network on ``model_input['coords']``.

        Args:
            model_input: dict of tensors. After the optional
                ``self.input_processing_fn`` it must contain ``'coords'``.
                It must also contain ``self.grad_var`` when ``self.use_grad``
                is set.

        Returns:
            dict with ``'model_in'`` (the processed, gradient-tracking inputs)
            and ``'model_out'`` (``{'output': tensor}``).
        """
        # Detached copies that require grad, so derivatives w.r.t. the inputs
        # can be taken without touching the caller's tensors.
        input_dict = {
            key: value.clone().detach().requires_grad_(True)
            for key, value in model_input.items()
        }

        # Without a processing function, the raw (grad-tracking) inputs are
        # used directly.
        if self.input_processing_fn is not None:
            input_dict_transformed = self.input_processing_fn(input_dict)
        else:
            input_dict_transformed = input_dict
        coords = input_dict_transformed['coords']

        if self.nl != 'sine':
            # Non-sine networks receive coords passed through self.pe first
            # (presumably a positional encoding).
            output = self.net(self.pe(coords))
        else:
            output = self.net(coords)

        if self.use_grad:
            # Replace the output by its jacobian w.r.t. self.grad_var, keeping
            # index 0 of axis 2 (presumably the first output channel --
            # confirm against jacobian's layout).
            output = jacobian(output,
                              input_dict_transformed[self.grad_var])[0][:, :, 0]

        return {
            'model_in': input_dict_transformed,
            'model_out': {
                'output': output
            }
        }
예제 #3
0
    def hji_MultiVehicleCollision(model_output, gt):
        """HJI PDE residual loss for the multi-vehicle collision problem.

        Penalises ``dudt - ham`` on interior points and ``y - boundary`` on
        points selected by ``dirichlet_mask``. ``ham`` combines an ego
        vehicle's term with one term per other vehicle.

        Free variables ``num_pos_states``, ``alpha_angle``, ``velocity``,
        ``omega_max``, ``numEvaders``, ``alpha_time`` and ``minWith`` are read
        from an enclosing scope that is not visible here.

        Args:
            model_output: dict with ``'model_in'`` (coordinates ``x``) and
                ``'model_out'`` (network output ``y``).
            gt: dict with ``'source_boundary_values'`` and the boolean
                ``'dirichlet_mask'``.

        Returns:
            dict of scalar tensors ``'dirichlet'`` and ``'diff_constraint_hom'``.
        """
        source_boundary_values = gt['source_boundary_values']
        x = model_output['model_in']  # (meta_batch_size, num_points, 4)
        y = model_output['model_out']  # (meta_batch_size, num_points, 1)
        dirichlet_mask = gt['dirichlet_mask']
        # NOTE(review): per the shape comment above, dim 1 is num_points, not
        # the batch size; it only scales the Dirichlet term below.
        batch_size = x.shape[1]

        if torch.all(dirichlet_mask):
            # Every point is a boundary point: no PDE residual (CPU placeholder).
            diff_constraint_hom = torch.Tensor([0])
        else:
            du, status = diff_operators.jacobian(y, x)
            # Input coordinate 0 is treated as time; the rest are state coords.
            dudt = du[..., 0, 0]
            dudx = du[..., 0, 1:]

            # Scale the costate for theta appropriately to align with the range of [-pi, pi]
            # dudx[..., num_pos_states:] holds every heading costate (ego and
            # others). dudx is a view of du, so this write also modifies du.
            dudx[...,
                 num_pos_states:] = dudx[..., num_pos_states:] / alpha_angle

            # Compute the hamiltonian for the ego vehicle.
            # x is offset by one from dudx because x[..., 0] is time. The ego
            # heading x[..., num_pos_states + 1] therefore pairs with costate
            # dudx[..., num_pos_states]. Its turn-rate term enters with
            # -omega_max * |.|.
            ham = velocity * (
                torch.cos(alpha_angle * x[..., num_pos_states + 1]) *
                dudx[..., 0] + torch.sin(
                    alpha_angle * x[..., num_pos_states + 1]) * dudx[..., 1]
            ) - omega_max * torch.abs(dudx[..., num_pos_states])

            # Hamiltonian effect due to other vehicles.
            # Vehicle i's position costates sit at dudx[2(i+1)] and
            # dudx[2(i+1)+1] (after the ego's 0, 1). Its heading is
            # x[num_pos_states + 2 + i], and the heading costate is
            # dudx[num_pos_states + 1 + i]. Note the turn-rate term here enters
            # with +omega_max * |.| (opposite sign to the ego's).
            for i in range(numEvaders):
                theta_index = num_pos_states + 1 + i + 1
                xcostate_index = 2 * (i + 1)
                ycostate_index = 2 * (i + 1) + 1
                thetacostate_index = num_pos_states + 1 + i
                ham_local = velocity * (
                    torch.cos(alpha_angle * x[..., theta_index]) *
                    dudx[..., xcostate_index] +
                    torch.sin(alpha_angle * x[..., theta_index]) *
                    dudx[..., ycostate_index]) + omega_max * torch.abs(
                        dudx[..., thetacostate_index])
                ham = ham + ham_local

            # Effect of time factor
            ham = ham * alpha_time

            # If we are computing BRT then take min with zero
            if minWith == 'zero':
                ham = torch.clamp(ham, max=0.0)

            diff_constraint_hom = dudt - ham
            if minWith == 'target':
                # Elementwise max of the PDE residual (given a trailing axis to
                # match y) and the distance of y from the boundary values.
                diff_constraint_hom = torch.max(
                    diff_constraint_hom[:, :, None],
                    y - source_boundary_values)

        dirichlet = y[dirichlet_mask] - source_boundary_values[dirichlet_mask]

        # A factor of 15e2 to make loss roughly equal
        return {
            'dirichlet': torch.abs(dirichlet).sum() * batch_size / 15e2,
            'diff_constraint_hom': torch.abs(diff_constraint_hom).sum()
        }
예제 #4
0
    def test_backward(self):
        """Compare the session's hand-built backward graph with autograd.

        Four random 1x1 CUDA inputs are fed to ``self.test_sess.backward``.
        The result must match the gradient that ``jacobian`` computes for the
        forward output w.r.t. the first input.
        """
        inputs = [torch.randn(1, 1).cuda() for _ in range(4)]
        # Gradient produced by our own backward graph.
        our_grad = self.test_sess.backward(inputs).squeeze()

        # Reference: promote the first input to a Parameter, run the forward
        # graph, and let autograd differentiate the output w.r.t. it.
        inputs[0] = torch.nn.Parameter(inputs[0])
        forward_out = self.test_sess(inputs)
        reference_grad = jacobian(forward_out[None, None, :],
                                  inputs[0])[0].squeeze()

        self.assertTrue(torch.allclose(our_grad, reference_grad))
예제 #5
0
    def hji_air3D(model_output, gt):
        """HJI PDE residual loss for the Air3D pursuit/evasion system.

        Builds the Hamiltonian of the dynamics written out in the comments
        below. It penalises ``dudt - ham`` on interior points and
        ``y - boundary`` on points selected by ``dirichlet_mask``.

        Free variables ``alpha_angle``, ``omega_max``, ``velocity`` and
        ``minWith`` are read from an enclosing scope that is not visible here.

        Args:
            model_output: dict with ``'model_in'`` (coordinates ``x``) and
                ``'model_out'`` (network output ``y``).
            gt: dict with ``'source_boundary_values'`` and the boolean
                ``'dirichlet_mask'``.

        Returns:
            dict of scalar tensors ``'dirichlet'`` and ``'diff_constraint_hom'``.
        """
        source_boundary_values = gt['source_boundary_values']
        x = model_output['model_in']  # (meta_batch_size, num_points, 4)
        y = model_output['model_out']  # (meta_batch_size, num_points, 1)
        dirichlet_mask = gt['dirichlet_mask']
        # NOTE(review): per the shape comment above, dim 1 is num_points, not
        # the batch size; it only scales the Dirichlet term below.
        batch_size = x.shape[1]

        du, status = diff_operators.jacobian(y, x)
        # Input coordinate 0 is treated as time; the rest are (x, y, psi).
        dudt = du[..., 0, 0]
        dudx = du[..., 0, 1:]

        # `* 1.0` yields a new tensor rather than a view of x.
        x_theta = x[..., 3] * 1.0

        # Scale the costate for theta appropriately to align with the range of [-pi, pi]
        # dudx is a view of du, so this write also modifies du in place.
        dudx[..., 2] = dudx[..., 2] / alpha_angle
        # Scale the coordinates
        x_theta = alpha_angle * x_theta

        # Air3D dynamics
        # \dot x    = -v_a + v_b \cos \psi + a y
        # \dot y    = v_b \sin \psi - a x
        # \dot \psi = b - a
        # The code below uses v_a = v_b = velocity. Spatial coordinates are
        # x[..., 1] (x) and x[..., 2] (y).

        # Compute the hamiltonian for the ego vehicle.
        # Control a multiplies (p_x * y - p_y * x - p_psi) and enters with
        # +omega_max * |.|.
        ham = omega_max * torch.abs(
            dudx[..., 0] * x[..., 2] - dudx[..., 1] * x[..., 1] -
            dudx[..., 2])  # Control component
        # Disturbance b multiplies p_psi and enters with -omega_max * |.|.
        ham = ham - omega_max * torch.abs(dudx[...,
                                               2])  # Disturbance component
        ham = ham + (velocity * (torch.cos(x_theta) - 1.0) * dudx[..., 0]) + (
            velocity * torch.sin(x_theta) * dudx[..., 1])  # Constant component

        # If we are computing BRT then take min with zero
        if minWith == 'zero':
            ham = torch.clamp(ham, max=0.0)

        if torch.all(dirichlet_mask):
            # Every point is a boundary point: no PDE residual (CPU placeholder).
            diff_constraint_hom = torch.Tensor([0])
        else:
            diff_constraint_hom = dudt - ham
            if minWith == 'target':
                # Elementwise max of the PDE residual (given a trailing axis to
                # match y) and the distance of y from the boundary values.
                diff_constraint_hom = torch.max(
                    diff_constraint_hom[:, :, None],
                    y - source_boundary_values)

        dirichlet = y[dirichlet_mask] - source_boundary_values[dirichlet_mask]

        # A factor of 15e2 to make loss roughly equal
        return {
            'dirichlet': torch.abs(dirichlet).sum() * batch_size / 15e2,
            'diff_constraint_hom': torch.abs(diff_constraint_hom).sum()
        }
예제 #6
0
def check_backward():
    """Check each RadianceNet's direct output against an autograd gradient.

    For the 'sigma' and 'rgb' models, the output of a first forward pass is
    compared with ``jacobian`` of the output produced after
    ``set_mode('integral')``, taken w.r.t. the ``'t'`` input. Prints 'Passed'
    or 'Failed!' per model. Requires a CUDA device.
    """
    sampler = None
    input_processing_fn = None
    # Built in the same order they are checked: sigma first, then rgb.
    models = {
        'sigma':
        modules.RadianceNet(input_processing_fn=input_processing_fn,
                            sampler=sampler,
                            input_name=['ray_samples']),
        'rgb':
        modules.RadianceNet(input_processing_fn=input_processing_fn,
                            sampler=sampler),
    }

    for name, model in models.items():
        print(f'Checking gradients for {name} model')

        t = torch.rand(128, 64, 1).cuda()
        ray_dirs = torch.rand(128, 64, 3).cuda()
        ray_origins = torch.rand(128, 64, 3).cuda()
        orientations = torch.rand(128, 64, 6).cuda()
        # Ray inputs shared by both passes; only 't' differs between them.
        ray_inputs = {
            'ray_directions': ray_dirs,
            'ray_origins': ray_origins,
            'ray_orientations': orientations,
        }

        # Pass 1: the model's own output is the gradient under test.
        our_grad = model({'t': t, **ray_inputs})['model_out']['output'].squeeze()

        # Pass 2: integral mode with 't' as a Parameter, differentiated by
        # autograd.
        t_param = torch.nn.Parameter(t)
        model.set_mode('integral')
        integral_out = model({'t': t_param, **ray_inputs})
        pytorch_grad = jacobian(integral_out['model_out']['output'],
                                integral_out['model_in']['t'])[0].squeeze()

        matches = torch.allclose(our_grad, pytorch_grad, atol=1e-6)
        print('Passed' if matches else 'Failed!')
예제 #7
0
def helmholtz_pml(model_output, gt):
    """Loss terms for the Helmholtz equation with a PML absorbing layer.

    Field values and PML coefficients are combined with
    ``modules.compl_mul`` / ``modules.compl_div``. They appear to be stored
    as (real, imag) channel pairs -- presumably; confirm against ``modules``.
    The residual assembled below is::

        d/dx_0(A * du/dx_0) + d/dx_1(B * du/dx_1) + C * s2 * k**2 * u

    where ``A = ey / ex``, ``B = ex / ey``, ``C = ex * ey``, ``s2`` is the
    squared slowness, ``k`` is the wavenumber and ``x_i = x[..., i]``.

    Args:
        model_output: dict with ``'model_in'`` (coordinates ``x``) and
            ``'model_out'`` (field ``y``). When ``gt`` contains
            ``'pretrain'``, the last channel of ``y`` is a predicted
            squared-slowness offset.
        gt: dict with ``'source_boundary_values'``, ``'wavenumber'`` and
            ``'squared_slowness'``. It may also contain ``'pretrain'``, plus
            ``'rec_boundary_values'`` for full waveform inversion.

    Returns:
        dict of scalar tensors ``'diff_constraint_on'``,
        ``'diff_constraint_off'`` and ``'data_term'``.

    Raises:
        KeyError: if full waveform inversion is selected (every
            ``gt['pretrain']`` equals -1) but ``'rec_boundary_values'`` is
            missing from ``gt``.
    """
    source_boundary_values = gt['source_boundary_values']

    wavenumber = gt['wavenumber'].float()
    x = model_output['model_in']  # (meta_batch_size, num_points, 2)
    y = model_output['model_out']  # (meta_batch_size, num_points, 2)
    # Tile the slowness along the last axis, y.shape[-1] // 2 times.
    squared_slowness = gt['squared_slowness'].repeat(1, 1, y.shape[-1] // 2)
    # Size of dim 1 (points per sample); scales the source and data terms.
    num_points = x.shape[1]

    full_waveform_inversion = False
    if 'pretrain' in gt:
        # The last output channel predicts (squared slowness - 1).
        pred_squared_slowness = y[:, :, -1] + 1.
        if torch.all(gt['pretrain'] == -1):
            full_waveform_inversion = True
            # Clamp so the predicted squared slowness is at least 0.001.
            pred_squared_slowness = torch.clamp(y[:, :, -1], min=-0.999) + 1.
            squared_slowness_init = torch.stack(
                (torch.ones_like(pred_squared_slowness),
                 torch.zeros_like(pred_squared_slowness)),
                dim=-1)
            squared_slowness = torch.stack(
                (pred_squared_slowness,
                 torch.zeros_like(pred_squared_slowness)),
                dim=-1)
            # Where |x_0| > 0.75 or |x_1| > 0.75 the slowness is pinned to
            # the initial value (1, 0) instead of the prediction.
            squared_slowness = torch.where(
                (torch.abs(x[..., 0, None]) > 0.75) |
                (torch.abs(x[..., 1, None]) > 0.75), squared_slowness_init,
                squared_slowness)
        # Strip the slowness channel; the remaining channels are the field.
        y = y[:, :, :-1]

    du, _ = diff_operators.jacobian(y, x)
    dudx1 = du[..., 0]
    dudx2 = du[..., 1]

    a0 = 5.0

    # The PML covers -1 <= x_i <= -1 + Lpml and 1 - Lpml <= x_i <= 1. Each
    # dist_* is the non-negative depth into the corresponding side.
    Lpml = 0.5
    dist_west = -torch.clamp(x[..., 0] + (1.0 - Lpml), max=0)
    dist_east = torch.clamp(x[..., 0] - (1.0 - Lpml), min=0)
    dist_south = -torch.clamp(x[..., 1] + (1.0 - Lpml), max=0)
    dist_north = torch.clamp(x[..., 1] - (1.0 - Lpml), min=0)

    # Quadratic damping profile, scaled by the wavenumber and a0.
    sx = wavenumber * a0 * ((dist_west / Lpml)**2 +
                            (dist_east / Lpml)**2)[..., None]
    sy = wavenumber * a0 * ((dist_north / Lpml)**2 +
                            (dist_south / Lpml)**2)[..., None]

    # Stretch factors stored as the pair (1, -s / k).
    ex = torch.cat((torch.ones_like(sx), -sx / wavenumber), dim=-1)
    ey = torch.cat((torch.ones_like(sy), -sy / wavenumber), dim=-1)

    A = modules.compl_div(ey, ex).repeat(1, 1, dudx1.shape[-1] // 2)
    B = modules.compl_div(ex, ey).repeat(1, 1, dudx1.shape[-1] // 2)
    C = modules.compl_mul(ex, ey).repeat(1, 1, dudx1.shape[-1] // 2)

    a, _ = diff_operators.jacobian(modules.compl_mul(A, dudx1), x)
    b, _ = diff_operators.jacobian(modules.compl_mul(B, dudx2), x)

    a = a[..., 0]
    b = b[..., 1]
    c = modules.compl_mul(modules.compl_mul(C, squared_slowness),
                          wavenumber**2 * y)

    diff_constraint_hom = a + b + c
    # Residual minus the source wherever a source is present...
    diff_constraint_on = torch.where(
        source_boundary_values != 0.,
        diff_constraint_hom - source_boundary_values,
        torch.zeros_like(diff_constraint_hom))
    # ...and the plain homogeneous residual everywhere else.
    diff_constraint_off = torch.where(source_boundary_values == 0.,
                                      diff_constraint_hom,
                                      torch.zeros_like(diff_constraint_hom))
    if full_waveform_inversion:
        # Looked up only here, so a missing key fails with a clear KeyError.
        rec_boundary_values = gt['rec_boundary_values']
        residual = y - rec_boundary_values
        # Only points with non-zero receiver data contribute. The zeros
        # follow residual's device/dtype instead of assuming CUDA.
        data_term = torch.where(rec_boundary_values != 0, residual,
                                torch.zeros_like(residual))
    else:
        data_term = y.new_zeros(1)

        if 'pretrain' in gt:
            # Pretraining: fit the predicted squared slowness to the given one.
            data_term = pred_squared_slowness - squared_slowness[..., 0]

    return {
        'diff_constraint_on':
        torch.abs(diff_constraint_on).sum() * num_points / 1e3,
        'diff_constraint_off': torch.abs(diff_constraint_off).sum(),
        'data_term': torch.abs(data_term).sum() * num_points
    }