def wave_pml(model_output, gt):
    """Loss terms for fitting the scalar wave equation.

    Input coordinate 0 is treated as time and coordinates 1 and 2 as space.
    The PDE residual is u_tt - (1 / squared_slowness) * (u_11 + u_22). It is
    only formed when at least one point lies outside the Dirichlet mask.

    Args:
        model_output: dict with 'model_in' (coordinates) and 'model_out'
            (predicted field).
        gt: dict with 'source_boundary_values', 'squared_slowness' and
            'dirichlet_mask'.

    Returns:
        dict with 'dirichlet' (value mismatch on masked points), 'neumann'
        (first time derivative on masked points) and 'diff_constraint_hom'
        (summed absolute PDE residual). The first two are scaled by
        model_in.shape[1].
    """
    coords = model_output['model_in']  # (meta_batch_size, num_points, 3)
    field = model_output['model_out']  # (meta_batch_size, num_points, 1)
    boundary = gt['source_boundary_values']
    slowness_sq = gt['squared_slowness']
    mask = gt['dirichlet_mask']
    num_points = coords.shape[1]

    grad, _ = diff_operators.jacobian(field, coords)
    field_t = grad[..., 0]

    if not torch.all(mask):
        # Second derivatives of the (single) output channel w.r.t. all inputs.
        hess, _ = diff_operators.jacobian(grad[..., 0, :], coords)
        laplacian = hess[..., 1, 1, None] + hess[..., 2, 2, None]
        field_tt = hess[..., 0, 0, None]
        residual = field_tt - 1 / slowness_sq * laplacian
    else:
        # Every point is constrained by the mask; no interior residual.
        residual = torch.Tensor([0])

    dirichlet = field[mask] - boundary[mask]
    neumann = field_t[mask]

    return {
        'dirichlet': torch.abs(dirichlet).sum() * num_points / 1e1,
        'neumann': torch.abs(neumann).sum() * num_points / 1e2,
        'diff_constraint_hom': torch.abs(residual).sum(),
    }
def forward(self, model_input):
    """Run the network on a dict of input tensors.

    Each tensor in ``model_input`` is cloned, detached and marked
    ``requires_grad``, so derivatives w.r.t. the inputs can be taken later.
    If ``self.input_processing_fn`` is set, it maps that dict to a new dict.
    Otherwise the detached dict is used as-is. Previously this path raised
    ``UnboundLocalError``.

    Args:
        model_input: dict of tensors; after processing it must contain
            'coords' (and ``self.grad_var`` when ``self.use_grad``).

    Returns:
        dict with 'model_in' (the processed input dict) and 'model_out'
        ({'output': tensor}).
    """
    input_dict = {
        key: value.clone().detach().requires_grad_(True)
        for key, value in model_input.items()
    }
    if self.input_processing_fn is not None:
        input_dict_transformed = self.input_processing_fn(input_dict)
    else:
        input_dict_transformed = input_dict

    coords = input_dict_transformed['coords']
    if self.nl != 'sine':
        # Non-sine networks see self.pe(coords) (presumably a positional
        # encoding) instead of the raw coordinates.
        output = self.net(self.pe(coords))
    else:
        output = self.net(coords)

    if self.use_grad:
        # Replace the output by its jacobian w.r.t. the configured input,
        # keeping index 0 along dim 2.
        output = jacobian(output, input_dict_transformed[self.grad_var])[0][:, :, 0]

    return {
        'model_in': input_dict_transformed,
        'model_out': {
            'output': output,
        },
    }
def hji_MultiVehicleCollision(model_output, gt):
    """HJI PDE loss for the multi-vehicle collision problem.

    Input column 0 is time (its derivative is used as dV/dt). The remaining
    columns are states. Gradient entries from index ``num_pos_states`` onward
    (heading costates) are divided by ``alpha_angle``. Headings are multiplied
    by ``alpha_angle`` before cos/sin.

    Relies on module-level settings: ``num_pos_states``, ``numEvaders``,
    ``alpha_angle``, ``alpha_time``, ``velocity``, ``omega_max`` and
    ``minWith``. With 'zero', the Hamiltonian is clamped at 0. With 'target',
    the residual is maxed with y - source_boundary_values.

    Returns:
        dict with 'dirichlet' (boundary mismatch scaled by
        model_in.shape[1] / 15e2) and 'diff_constraint_hom' (summed absolute
        PDE residual).
    """
    coords = model_output['model_in']  # (meta_batch_size, num_points, 4)
    value = model_output['model_out']  # (meta_batch_size, num_points, 1)
    boundary = gt['source_boundary_values']
    mask = gt['dirichlet_mask']
    num_points = coords.shape[1]

    if not torch.all(mask):
        grads, _ = diff_operators.jacobian(value, coords)
        value_dt = grads[..., 0, 0]
        costate = grads[..., 0, 1:]
        # Scale the costate for theta appropriately to align with the range of [-pi, pi]
        costate[..., num_pos_states:] = costate[..., num_pos_states:] / alpha_angle

        # Ego vehicle: its heading is input column num_pos_states + 1
        # (states are offset by one because column 0 is time).
        ego_heading = alpha_angle * coords[..., num_pos_states + 1]
        hamiltonian = velocity * (
            torch.cos(ego_heading) * costate[..., 0]
            + torch.sin(ego_heading) * costate[..., 1]
        ) - omega_max * torch.abs(costate[..., num_pos_states])

        # Contribution of every other vehicle.
        for k in range(numEvaders):
            heading = alpha_angle * coords[..., num_pos_states + 2 + k]
            p_x = costate[..., 2 * k + 2]
            p_y = costate[..., 2 * k + 3]
            p_theta = costate[..., num_pos_states + 1 + k]
            hamiltonian = hamiltonian + (velocity * (
                torch.cos(heading) * p_x + torch.sin(heading) * p_y
            ) + omega_max * torch.abs(p_theta))

        # Effect of time factor
        hamiltonian = hamiltonian * alpha_time

        if minWith == 'zero':
            # BRT computation: keep only the non-positive part.
            hamiltonian = torch.clamp(hamiltonian, max=0.0)

        residual = value_dt - hamiltonian
        if minWith == 'target':
            residual = torch.max(residual[:, :, None], value - boundary)
    else:
        residual = torch.Tensor([0])

    dirichlet = value[mask] - boundary[mask]
    # A factor of 15e2 to make loss roughly equal
    return {
        'dirichlet': torch.abs(dirichlet).sum() * num_points / 15e2,
        'diff_constraint_hom': torch.abs(residual).sum(),
    }
def test_backward(self):
    """Check test_sess.backward against PyTorch autograd on its forward pass.

    Four random (1, 1) CUDA inputs go to the session's own backward graph.
    The first input is then wrapped as a Parameter, the forward graph is run,
    and autograd's jacobian w.r.t. that input must match.
    """
    inputs = [torch.randn(1, 1).cuda() for _ in range(4)]

    # Gradient produced by the session's backward graph.
    our_grad = self.test_sess.backward(inputs).squeeze()

    # Reference gradient: differentiate the forward graph w.r.t. inputs[0].
    inputs[0] = torch.nn.Parameter(inputs[0])
    forward_out = self.test_sess(inputs)
    reference_grad = jacobian(forward_out[None, None, :], inputs[0])[0].squeeze()

    self.assertTrue(torch.allclose(our_grad, reference_grad))
def hji_air3D(model_output, gt):
    """HJI PDE loss for the Air3D two-vehicle problem.

    Dynamics (relative coordinates):
        x_dot   = -v_a + v_b cos(psi) + a y
        y_dot   = v_b sin(psi) - a x
        psi_dot = b - a

    Input column 0 is time and columns 1..3 are (x, y, psi). The psi costate
    is divided by ``alpha_angle`` and psi is multiplied by it before trig
    functions. The jacobian and Hamiltonian are only computed when some point
    lies outside the Dirichlet mask. The original computed and then discarded
    them, unlike ``hji_MultiVehicleCollision``.

    Relies on module-level settings: ``alpha_angle``, ``velocity``,
    ``omega_max`` and ``minWith``. With 'zero', the Hamiltonian is clamped at
    0. With 'target', the residual is maxed with y - source_boundary_values.

    Returns:
        dict with 'dirichlet' (boundary mismatch scaled by
        model_in.shape[1] / 15e2) and 'diff_constraint_hom' (summed absolute
        PDE residual).
    """
    source_boundary_values = gt['source_boundary_values']
    x = model_output['model_in']  # (meta_batch_size, num_points, 4)
    y = model_output['model_out']  # (meta_batch_size, num_points, 1)
    dirichlet_mask = gt['dirichlet_mask']
    batch_size = x.shape[1]

    if torch.all(dirichlet_mask):
        # Every point is a boundary point: no PDE residual is needed.
        diff_constraint_hom = torch.Tensor([0])
    else:
        du, _ = diff_operators.jacobian(y, x)
        dudt = du[..., 0, 0]
        dudx = du[..., 0, 1:]

        # Scale the costate for theta appropriately to align with the range of [-pi, pi]
        dudx[..., 2] = dudx[..., 2] / alpha_angle
        # Scale the coordinates
        x_theta = alpha_angle * x[..., 3]

        # Control component: a * (p_x * y - p_y * x - p_psi)
        ham = omega_max * torch.abs(
            dudx[..., 0] * x[..., 2] - dudx[..., 1] * x[..., 1] - dudx[..., 2])
        # Disturbance component: b * p_psi
        ham = ham - omega_max * torch.abs(dudx[..., 2])
        # Constant component: velocity terms
        ham = ham + (velocity * (torch.cos(x_theta) - 1.0) * dudx[..., 0]) + (
            velocity * torch.sin(x_theta) * dudx[..., 1])

        # If we are computing BRT then take min with zero
        if minWith == 'zero':
            ham = torch.clamp(ham, max=0.0)

        diff_constraint_hom = dudt - ham
        if minWith == 'target':
            diff_constraint_hom = torch.max(
                diff_constraint_hom[:, :, None], y - source_boundary_values)

    dirichlet = y[dirichlet_mask] - source_boundary_values[dirichlet_mask]

    # A factor of 15e2 to make loss roughly equal
    return {
        'dirichlet': torch.abs(dirichlet).sum() * batch_size / 15e2,
        'diff_constraint_hom': torch.abs(diff_constraint_hom).sum(),
    }
def check_backward():
    """Compare each RadianceNet's own output against autograd, printing the result.

    Builds a 'sigma' model (input_name=['ray_samples']) and an 'rgb' model.
    For each, random CUDA ray inputs are fed once to obtain the model's own
    output ("our_grad"). The model is then switched to 'integral' mode and
    autograd's jacobian of its output w.r.t. 't' is computed. The two are
    compared with atol=1e-6, printing 'Passed' or 'Failed!'. The models are
    left in 'integral' mode.
    """
    sampler = None
    input_processing_fn = None

    def _ray_inputs(t, directions, origins, orientations):
        # Input dict layout shared by both forward passes.
        return {
            't': t,
            'ray_directions': directions,
            'ray_origins': origins,
            'ray_orientations': orientations,
        }

    models = {
        'sigma': modules.RadianceNet(input_processing_fn=input_processing_fn,
                                     sampler=sampler,
                                     input_name=['ray_samples']),
        'rgb': modules.RadianceNet(input_processing_fn=input_processing_fn,
                                   sampler=sampler),
    }

    for name, model in models.items():
        print(f'Checking gradients for {name} model')

        t = torch.rand(128, 64, 1).cuda()
        ray_dirs = torch.rand(128, 64, 3).cuda()
        ray_origins = torch.rand(128, 64, 3).cuda()
        orientations = torch.rand(128, 64, 6).cuda()

        # Model's own result, computed in its current mode.
        our_grad = model(
            _ray_inputs(t, ray_dirs, ray_origins, orientations)
        )['model_out']['output'].squeeze()

        # Autograd reference computed from the integral-mode output.
        t = torch.nn.Parameter(t)
        model.set_mode('integral')
        model_out = model(_ray_inputs(t, ray_dirs, ray_origins, orientations))
        integral = model_out['model_out']['output']
        pytorch_grad = jacobian(integral, model_out['model_in']['t'])[0].squeeze()

        matches = torch.allclose(our_grad, pytorch_grad, atol=1e-6)
        print('Passed' if matches else 'Failed!')
def helmholtz_pml(model_output, gt):
    """Loss terms for the Helmholtz equation with a PML, optionally inverting velocity.

    The PML occupies [-1, -1 + Lpml] and [1 - Lpml, 1] along each axis, with
    Lpml = 0.5. The field appears to be stored as (real, imag) channel pairs
    (inferred from the ``modules.compl_*`` helpers and ``// 2`` repeats;
    confirm).

    When 'pretrain' is in ``gt``, the last output channel is a predicted
    squared slowness (offset by 1) and is stripped from the field. If every
    'pretrain' entry is -1, full-waveform inversion is used: the predicted
    slowness is used inside |x|, |y| <= 0.75, and the data term matches the
    field to 'rec_boundary_values' wherever those are non-zero.

    Args:
        model_output: dict with 'model_in' (coordinates) and 'model_out'.
        gt: dict with 'source_boundary_values', 'wavenumber' and
            'squared_slowness'; optionally 'pretrain' and (for full-waveform
            inversion) 'rec_boundary_values'.

    Returns:
        dict of scalar losses 'diff_constraint_on', 'diff_constraint_off'
        and 'data_term'.

    Raises:
        KeyError: full-waveform inversion is requested but
            'rec_boundary_values' is missing from ``gt``.
    """
    source_boundary_values = gt['source_boundary_values']
    wavenumber = gt['wavenumber'].float()
    x = model_output['model_in']  # (meta_batch_size, num_points, 2)
    y = model_output['model_out']  # (meta_batch_size, num_points, 2)
    squared_slowness = gt['squared_slowness'].repeat(1, 1, y.shape[-1] // 2)
    batch_size = x.shape[1]

    full_waveform_inversion = False
    if 'pretrain' in gt:
        pred_squared_slowness = y[:, :, -1] + 1.
        if torch.all(gt['pretrain'] == -1):
            full_waveform_inversion = True
            # Clamp keeps the predicted squared slowness >= 0.001.
            pred_squared_slowness = torch.clamp(y[:, :, -1], min=-0.999) + 1.
            squared_slowness_init = torch.stack(
                (torch.ones_like(pred_squared_slowness),
                 torch.zeros_like(pred_squared_slowness)), dim=-1)
            squared_slowness = torch.stack(
                (pred_squared_slowness,
                 torch.zeros_like(pred_squared_slowness)), dim=-1)
            # Outside |x|, |y| <= 0.75 use the fixed initial squared slowness.
            squared_slowness = torch.where(
                (torch.abs(x[..., 0, None]) > 0.75)
                | (torch.abs(x[..., 1, None]) > 0.75),
                squared_slowness_init, squared_slowness)
        # Drop the slowness channel so y holds only the field.
        y = y[:, :, :-1]

    du, _ = diff_operators.jacobian(y, x)
    dudx1 = du[..., 0]
    dudx2 = du[..., 1]

    a0 = 5.0

    # let pml extend from -1. to -1 + Lpml and 1 - Lpml to 1.0
    Lpml = 0.5
    dist_west = -torch.clamp(x[..., 0] + (1.0 - Lpml), max=0)
    dist_east = torch.clamp(x[..., 0] - (1.0 - Lpml), min=0)
    dist_south = -torch.clamp(x[..., 1] + (1.0 - Lpml), max=0)
    dist_north = torch.clamp(x[..., 1] - (1.0 - Lpml), min=0)

    # Quadratic damping profiles, zero outside the PML.
    sx = wavenumber * a0 * ((dist_west / Lpml) ** 2 + (dist_east / Lpml) ** 2)[..., None]
    sy = wavenumber * a0 * ((dist_north / Lpml) ** 2 + (dist_south / Lpml) ** 2)[..., None]

    ex = torch.cat((torch.ones_like(sx), -sx / wavenumber), dim=-1)
    ey = torch.cat((torch.ones_like(sy), -sy / wavenumber), dim=-1)

    A = modules.compl_div(ey, ex).repeat(1, 1, dudx1.shape[-1] // 2)
    B = modules.compl_div(ex, ey).repeat(1, 1, dudx1.shape[-1] // 2)
    C = modules.compl_mul(ex, ey).repeat(1, 1, dudx1.shape[-1] // 2)

    a, _ = diff_operators.jacobian(modules.compl_mul(A, dudx1), x)
    b, _ = diff_operators.jacobian(modules.compl_mul(B, dudx2), x)

    a = a[..., 0]
    b = b[..., 1]
    c = modules.compl_mul(modules.compl_mul(C, squared_slowness), wavenumber ** 2 * y)

    diff_constraint_hom = a + b + c
    # Residual minus source where a source is present ...
    diff_constraint_on = torch.where(
        source_boundary_values != 0.,
        diff_constraint_hom - source_boundary_values,
        torch.zeros_like(diff_constraint_hom))
    # ... and the plain residual everywhere else.
    diff_constraint_off = torch.where(
        source_boundary_values == 0.,
        diff_constraint_hom,
        torch.zeros_like(diff_constraint_hom))

    # Zero placeholder created on y's device, instead of a hard-coded .cuda()
    # that fails on CPU or on a non-default GPU.
    zero = torch.zeros(1, device=y.device)
    if full_waveform_inversion:
        rec_boundary_values = gt['rec_boundary_values']
        data_term = torch.where(rec_boundary_values != 0,
                                y - rec_boundary_values, zero)
    else:
        data_term = zero
        if 'pretrain' in gt:  # we are not trying to solve for velocity
            data_term = pred_squared_slowness - squared_slowness[..., 0]

    return {
        'diff_constraint_on': torch.abs(diff_constraint_on).sum() * batch_size / 1e3,
        'diff_constraint_off': torch.abs(diff_constraint_off).sum(),
        'data_term': torch.abs(data_term).sum() * batch_size / 1,
    }