def model_inference(args, lres, pde_layer):
    """Run trained UNet+ImNet inference and evaluate dense output slices.

    Args:
        args: namespace with model/eval hyperparameters; fields used here:
            ckpt, nt, nz, nx, downsamp_t, downsamp_xz, lat_dims, unet_nf,
            unet_mf, imnet_nf, nonlin, eval_tres, eval_zres, eval_xres,
            eval_pseudo_batch_size.
        lres: low-resolution input grid; fed to the UNet with an added batch
            dim. NOTE(review): assumed shape [4, T', Z', X'] since the UNet is
            built with in_features=4 — confirm against the data pipeline.
        pde_layer: PDELayer used to compute predicted values + pde residues.

    Returns:
        res_dict: dict of evaluated fields as returned by evaluate_feat_grid.
    """
    # select inference device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # construct model
    print(f"Loading model parameters from {args.ckpt}...")
    igres = (int(args.nt / args.downsamp_t),
             int(args.nz / args.downsamp_xz),
             int(args.nx / args.downsamp_xz),)
    unet = UNet3d(in_features=4, out_features=args.lat_dims, igres=igres,
                  nf=args.unet_nf, mf=args.unet_mf)
    imnet = ImNet(dim=3, in_features=args.lat_dims, out_features=4,
                  nf=args.imnet_nf, activation=NONLINEARITIES[args.nonlin])

    # load model params. Bugfix: map_location remaps GPU-saved checkpoints so
    # they can also be loaded on a CPU-only machine (plain torch.load would
    # raise when CUDA is unavailable).
    resume_dict = torch.load(args.ckpt, map_location=device)
    unet.load_state_dict(resume_dict["unet_state_dict"])
    imnet.load_state_dict(resume_dict["imnet_state_dict"])
    unet.to(device)
    imnet.to(device)
    unet.eval()
    imnet.eval()

    # evaluate: compute the latent feature grid from the low-res input
    latent_grid = unet(torch.tensor(lres, dtype=torch.float32)[None].to(device))
    # channels-last layout for local implicit grid queries
    latent_grid = latent_grid.permute(0, 2, 3, 4, 1)  # [batch, T, Z, X, C]

    # create evaluation grid bounds
    t_max = float(args.eval_tres / args.nt)
    z_max = 1
    x_max = 4

    # layout query points for the desired slices; eps keeps the query points
    # strictly inside the (open) domain boundaries
    eps = 1e-6
    t_seq = torch.linspace(eps, t_max - eps, args.eval_tres)  # temporal sequences
    z_seq = torch.linspace(eps, z_max - eps, args.eval_zres)  # z sequences
    x_seq = torch.linspace(eps, x_max - eps, args.eval_xres)  # x sequences

    mins = torch.zeros(3, dtype=torch.float32, device=device)
    maxs = torch.tensor([t_max, z_max, x_max], dtype=torch.float32, device=device)

    # forward function for pde_layer: decode latent grid at arbitrary points
    fwd_fn = lambda points: query_local_implicit_grid(imnet, latent_grid, points,
                                                      mins, maxs)

    # update pde layer and compute predicted values + pde residues
    pde_layer.update_forward_method(fwd_fn)
    res_dict = evaluate_feat_grid(pde_layer, latent_grid, t_seq, z_seq, x_seq,
                                  mins, maxs, args.eval_pseudo_batch_size)

    return res_dict
def test_query_local_implicit_grid(self, batch_size, npts, n_dim, n_in, n_out,
                                   n_filter, latent_grid_size, xmin, xmax):
    """Unit test: query_local_implicit_grid output has shape [b, npts, n_out]."""
    query_pts = torch.rand(batch_size, npts, n_dim)
    model = implicit_net.ImNet(dim=n_dim, in_features=n_in, out_features=n_out,
                               nf=n_filter)
    # latent grid layout: [b, n1, ..., nd, c]
    latent_grid = torch.rand(batch_size, *([latent_grid_size] * n_dim), n_in)
    out = lig.query_local_implicit_grid(model, latent_grid, query_pts, xmin, xmax)
    # Shapes are exact integers, so compare directly; the previous
    # assert_allclose(..., atol=1e-4) applied a float tolerance to ints,
    # which was meaningless. (Also removed leftover pdb debug residue.)
    assert tuple(out.shape) == (batch_size, npts, n_out)
def test_local_implicit_grid_with_pde_layer_diff_eqn(self, pde_dict):
    """integration test for diffusion equation."""
    # problem sizes
    n_batch = 8       # batch size
    res = 16          # latent grid resolution per dimension
    n_chan = 32       # number of latent channels
    nf = 16           # number of filters in neural net
    n_query = 1024    # number of query points

    # unpack the pde specification
    in_vars = pde_dict['in_vars']
    out_vars = pde_dict['out_vars']
    eqn_strs = pde_dict['eqn_strs']
    eqn_names = pde_dict['eqn_names']
    dim_in = len(in_vars.split(','))
    dim_out = len(out_vars.split(','))

    # build a local implicit grid to serve as the forward function
    feat_grid = torch.rand(n_batch, res, res, res, n_chan)
    pts = torch.rand(n_batch, n_query, dim_in)
    net = ImNet(dim=dim_in, in_features=n_chan, out_features=dim_out, nf=nf)

    def fwd_fn(coords):
        return query_local_implicit_grid(net, feat_grid, coords, 0., 1.)

    # assemble the pde layer and attach the forward function
    pdel = PDELayer(in_vars=in_vars, out_vars=out_vars)
    for eqn_str, eqn_name in zip(eqn_strs, eqn_names):
        pdel.add_equation(eqn_str, eqn_name)
    pdel.update_forward_method(fwd_fn)

    val, res_out = pdel(pts)

    # the network weights are random, so exact values cannot be pinned down;
    # only the output/residue shapes are verified
    np.testing.assert_allclose(val.shape, [n_batch, n_query, dim_out])
    for residue in res_out.values():
        np.testing.assert_allclose(residue.shape, [n_batch, n_query, 1])
def train(args, unet, imnet, train_loader, epoch, global_step, device, logger, writer,
          optimizer, pde_layer):
    """Training function. Runs one epoch of supervised + PDE-residual training.

    Args:
        args: namespace; fields used here: reg_loss_type, use_apex, clip_grad,
            alpha_reg, alpha_pde, log_interval, rank, nprocs.
        unet: encoder producing a latent grid from the input grid.
            NOTE(review): gradient clipping below accesses unet.module /
            imnet.module, so both are assumed to be wrapped (e.g. in
            DistributedDataParallel) — confirm at the call site.
        imnet: implicit decoder queried through the local implicit grid.
        train_loader: yields (input_grid, point_coord, point_value) batches.
        epoch: current epoch number (used for logging only).
        global_step: tensorboard step counter. It is incremented locally but
            NOT returned, so the caller does not see the increments.
        device: torch device for all tensors.
        logger: text logger (used on rank 0 only).
        writer: tensorboard SummaryWriter (used on rank 0 only).
        optimizer: optimizer over unet+imnet parameters.
        pde_layer: PDELayer computing predicted values and pde residues.

    Returns:
        Average loss per sample over the epoch.
    """
    unet.train()
    imnet.train()
    tot_loss = 0
    count = 0
    # unit cube query domain for the local implicit grid
    xmin = torch.zeros(3, dtype=torch.float32).to(device)
    xmax = torch.ones(3, dtype=torch.float32).to(device)
    loss_func = loss_functional(args.reg_loss_type)
    for batch_idx, data_tensors in enumerate(train_loader):
        # send tensors to device
        data_tensors = [t.to(device) for t in data_tensors]
        input_grid, point_coord, point_value = data_tensors
        optimizer.zero_grad()
        # NOTE(review): original comment said [batch, N, C, T, X, Y] (6 dims)
        # but the permute below is 5-D; presumably [batch, C, T, Z, X] as in
        # eval() — confirm against the UNet output.
        latent_grid = unet(input_grid)
        # permute such that C is the last channel for local implicit grid query
        latent_grid = latent_grid.permute(0, 2, 3, 4, 1)  # channels-last
        # define lambda function for pde_layer; rebuilt each batch because it
        # closes over the current latent_grid
        fwd_fn = lambda points: query_local_implicit_grid(imnet, latent_grid, points,
                                                          xmin, xmax)
        # update pde layer and compute predicted values + pde residues
        pde_layer.update_forward_method(fwd_fn)
        pred_value, residue_dict = pde_layer(point_coord, return_residue=True)

        # function value regression loss
        reg_loss = loss_func(pred_value, point_value)

        # pde residue loss: all residues stacked and regressed to zero
        pde_tensors = torch.stack([d for d in residue_dict.values()], dim=0)
        pde_loss = loss_func(pde_tensors, torch.zeros_like(pde_tensors))
        loss = args.alpha_reg * reg_loss + args.alpha_pde * pde_loss

        if args.use_apex:
            # apex mixed-precision: scale the loss before backward
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # gradient clipping (on the wrapped .module parameters)
        torch.nn.utils.clip_grad_value_(unet.module.parameters(), args.clip_grad)
        torch.nn.utils.clip_grad_value_(imnet.module.parameters(), args.clip_grad)

        optimizer.step()
        tot_loss += loss.item()
        count += input_grid.size()[0]

        if batch_idx % args.log_interval == 0:
            if args.rank == 0:
                # logger log (sample counts scaled by nprocs for distributed runs)
                logger.info(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss Sum: {:.6f}\t"
                    "Loss Reg: {:.6f}\tLoss Pde: {:.6f}".format(
                        epoch, batch_idx * len(input_grid) * args.nprocs,
                        len(train_loader) * len(input_grid) * args.nprocs,
                        100. * batch_idx / len(train_loader), loss.item(),
                        args.alpha_reg * reg_loss, args.alpha_pde * pde_loss))
                # tensorboard log
                writer.add_scalar('train/reg_loss_unweighted', reg_loss,
                                  global_step=int(global_step))
                writer.add_scalar('train/pde_loss_unweighted', pde_loss,
                                  global_step=int(global_step))
                writer.add_scalar('train/sum_loss', loss,
                                  global_step=int(global_step))
                writer.add_scalars('train/losses_weighted',
                                   {"reg_loss": args.alpha_reg * reg_loss,
                                    "pde_loss": args.alpha_pde * pde_loss,
                                    "sum_loss": loss},
                                   global_step=int(global_step))
        global_step += 1
    # mean loss per sample (raises ZeroDivisionError on an empty loader)
    tot_loss /= count
    return tot_loss
def eval(args, unet, imnet, eval_loader, epoch, global_step, device, logger, writer,
         optimizer, pde_layer):
    """Eval function. Used for evaluating entire slices and comparing to GT.

    Takes only the FIRST batch from eval_loader, super-resolves it through the
    latent grid + implicit net, and logs predicted vs. ground-truth slice
    images (plus pde residue slices) to tensorboard on rank 0.

    Args:
        args: namespace; fields used here: pseudo_batch_size, rank.
        unet: encoder producing the latent grid from the low-res input.
        imnet: implicit decoder queried through the local implicit grid.
        eval_loader: yields (hres_grid, lres_grid, _, _) batches.
        epoch: unused here (kept for signature symmetry with train()).
        global_step: step index for tensorboard image logging.
        device: torch device for all tensors.
        logger: unused here (kept for signature symmetry with train()).
        writer: tensorboard SummaryWriter (used on rank 0 only).
        optimizer: unused here (kept for signature symmetry with train()).
        pde_layer: PDELayer computing predicted values and pde residues.
    """
    unet.eval()
    imnet.eval()
    # physical channel names -> channel indices in hres_grid
    phys_channels = ["p", "b", "u", "w"]
    phys2id = dict(zip(phys_channels, range(len(phys_channels))))
    # unit cube query domain for the local implicit grid
    xmin = torch.zeros(3, dtype=torch.float32).to(device)
    xmax = torch.ones(3, dtype=torch.float32).to(device)
    for data_tensors in eval_loader:
        # only need the first batch
        break
    # send tensors to device
    data_tensors = [t.to(device) for t in data_tensors]
    hres_grid, lres_grid, _, _ = data_tensors
    latent_grid = unet(lres_grid)  # [batch, C, T, Z, X]
    nb, nc, nt, nz, nx = hres_grid.shape
    # permute such that C is the last channel for local implicit grid query
    latent_grid = latent_grid.permute(0, 2, 3, 4, 1)  # [batch, T, Z, X, C]
    # define lambda function for pde_layer
    fwd_fn = lambda points: query_local_implicit_grid(imnet, latent_grid, points,
                                                      xmin, xmax)
    # update pde layer and compute predicted values + pde residues
    pde_layer.update_forward_method(fwd_fn)
    # layout query points for the desired slices; eps keeps points strictly
    # inside the domain, and the time axis is subsampled to 8 slices
    eps = 1e-6
    t_seq = torch.linspace(eps, 1-eps, nt)[::int(nt/8)]  # temporal sequences
    z_seq = torch.linspace(eps, 1-eps, nz)  # z sequences
    x_seq = torch.linspace(eps, 1-eps, nx)  # x sequences
    query_coord = torch.stack(torch.meshgrid(t_seq, z_seq, x_seq), axis=-1)  # [nt, nz, nx, 3]
    query_coord = query_coord.reshape([-1, 3]).to(device)  # [nt*nz*nx, 3]
    n_query = query_coord.shape[0]

    res_dict = defaultdict(list)

    # evaluate in pseudo-batches of query points to bound memory use
    n_iters = int(np.ceil(n_query/args.pseudo_batch_size))
    for idx in range(n_iters):
        sid = idx * args.pseudo_batch_size
        eid = min(sid+args.pseudo_batch_size, n_query)
        query_coord_batch = query_coord[sid:eid]
        # broadcast the same query points across the sample batch
        query_coord_batch = query_coord_batch[None].expand(*(nb, eid-sid, 3))  # [nb, eid-sid, 3]

        pred_value, residue_dict = pde_layer(query_coord_batch, return_residue=True)
        # detach so accumulated results don't retain the autograd graph
        pred_value = pred_value.detach()
        for key in residue_dict.keys():
            residue_dict[key] = residue_dict[key].detach()
        for name, chan_id in zip(phys_channels, range(4)):
            res_dict[name].append(pred_value[..., chan_id])  # [b, pb]
        for name, val in residue_dict.items():
            res_dict[name].append(val[..., 0])  # [b, pb]

    # stitch pseudo-batches back into full [nb, nt, nz, nx] fields
    for key in res_dict.keys():
        res_dict[key] = (torch.cat(res_dict[key], axis=1)
                         .reshape([nb, len(t_seq), len(z_seq), len(x_seq)]))

    # log the imgs sample-by-sample (rank 0 only)
    if args.rank == 0:
        for samp_id in range(nb):
            for key in res_dict.keys():
                field = res_dict[key][samp_id]  # [nt, nz, nx]
                # add predicted slices
                images = utils.batch_colorize_scalar_tensors(field)  # [nt, nz, nx, 3]
                writer.add_images('sample_{}/{}/predicted'.format(samp_id, key),
                                  images, dataformats='NHWC',
                                  global_step=int(global_step))
                # add ground truth slices (only for phys channels); same time
                # subsampling as t_seq above
                if key in phys_channels:
                    gt_fields = hres_grid[samp_id, phys2id[key], ::int(nt/8)]  # [nt, nz, nx]
                    gt_images = utils.batch_colorize_scalar_tensors(gt_fields)  # [nt, nz, nx, 3]
                    writer.add_images('sample_{}/{}/ground_truth'.format(samp_id, key),
                                      gt_images, dataformats='NHWC',
                                      global_step=int(global_step))