# Example 1 (score: 0)
def model_inference(args, lres, pde_layer):
    """Run full-grid inference with a trained UNet encoder + implicit decoder.

    Args:
        args: namespace of CLI options (ckpt, nt/nz/nx, downsampling factors,
            model sizes, eval resolutions, eval_pseudo_batch_size, ...).
        lres: low-resolution input grid; fed to the UNet with a batch dim
            prepended, so assumed channels-first [C, T, Z, X] -- TODO confirm.
        pde_layer: PDELayer instance used to evaluate values + PDE residues.

    Returns:
        dict of evaluated field/residue grids from evaluate_feat_grid.
    """
    # select inference device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # construct model
    print(f"Loading model parameters from {args.ckpt}...")
    igres = (int(args.nt / args.downsamp_t),
             int(args.nz / args.downsamp_xz),
             int(args.nx / args.downsamp_xz),)
    unet = UNet3d(in_features=4, out_features=args.lat_dims, igres=igres,
                  nf=args.unet_nf, mf=args.unet_mf)
    imnet = ImNet(dim=3, in_features=args.lat_dims, out_features=4, nf=args.imnet_nf,
                  activation=NONLINEARITIES[args.nonlin])

    # load model params. map_location remaps tensors to the selected device,
    # so a checkpoint saved on GPU loads on a CPU-only machine (and vice versa).
    resume_dict = torch.load(args.ckpt, map_location=device)
    unet.load_state_dict(resume_dict["unet_state_dict"])
    imnet.load_state_dict(resume_dict["imnet_state_dict"])

    unet.to(device)
    imnet.to(device)
    unet.eval()
    imnet.eval()

    # encode the low-res input into a latent feature grid
    latent_grid = unet(torch.tensor(lres, dtype=torch.float32)[None].to(device))
    latent_grid = latent_grid.permute(0, 2, 3, 4, 1)  # [batch, T, Z, X, C]

    # evaluation grid extents; t is normalized by the training temporal res
    t_max = float(args.eval_tres / args.nt)
    z_max = 1
    x_max = 4

    # layout query points for the desired slices; eps keeps the queries
    # strictly inside the open interval so grid interpolation stays in-bounds
    eps = 1e-6
    t_seq = torch.linspace(eps, t_max - eps, args.eval_tres)  # temporal sequences
    z_seq = torch.linspace(eps, z_max - eps, args.eval_zres)  # z sequences
    x_seq = torch.linspace(eps, x_max - eps, args.eval_xres)  # x sequences

    mins = torch.zeros(3, dtype=torch.float32, device=device)
    maxs = torch.tensor([t_max, z_max, x_max], dtype=torch.float32, device=device)

    # define lambda function for pde_layer
    fwd_fn = lambda points: query_local_implicit_grid(imnet, latent_grid, points, mins, maxs)

    # update pde layer and compute predicted values + pde residues
    pde_layer.update_forward_method(fwd_fn)

    res_dict = evaluate_feat_grid(pde_layer, latent_grid, t_seq, z_seq, x_seq, mins, maxs,
                                  args.eval_pseudo_batch_size)

    return res_dict
    def test_query_local_implicit_grid(self, batch_size, npts, n_dim, n_in,
                                       n_out, n_filter, latent_grid_size, xmin,
                                       xmax):
        """Shape check: querying a random latent grid at random points."""
        # random query points and a random latent grid of shape
        # [b, n1, ..., nd, c]
        points = torch.rand(batch_size, npts, n_dim)
        net = implicit_net.ImNet(dim=n_dim,
                                 in_features=n_in,
                                 out_features=n_out,
                                 nf=n_filter)
        grid_shape = (batch_size,) + (latent_grid_size,) * n_dim + (n_in,)
        grid = torch.rand(*grid_shape)

        result = lig.query_local_implicit_grid(net, grid, points, xmin, xmax)

        expected = [batch_size, npts, n_out]
        np.testing.assert_allclose(result.shape, expected, atol=1e-4)
# Example 3 (score: 0)
    def test_local_implicit_grid_with_pde_layer_diff_eqn(self, pde_dict):
        """integration test for diffusion equation."""
        # setup parameters
        nb = 8       # batch size
        res = 16     # grid resolution
        nc = 32      # number of latent channels
        nf = 16      # number of filters in neural net
        npts = 1024  # number of query points

        dim_in = len(pde_dict['in_vars'].split(','))
        dim_out = len(pde_dict['out_vars'].split(','))

        # local implicit grid as the forward function
        latent = torch.rand(nb, *([res] * 3), nc)
        pts = torch.rand(nb, npts, dim_in)
        net = ImNet(dim=dim_in,
                    in_features=nc,
                    out_features=dim_out,
                    nf=nf)

        def fwd(query):
            return query_local_implicit_grid(net, latent, query, 0., 1.)

        # setup pde layer: attach each equation, then the forward method
        layer = PDELayer(in_vars=pde_dict['in_vars'],
                         out_vars=pde_dict['out_vars'])
        for eqn, name in zip(pde_dict['eqn_strs'], pde_dict['eqn_names']):
            layer.add_equation(eqn, name)
        layer.update_forward_method(fwd)
        val, residues = layer(pts)

        # values depend on the random net weights, so only shapes are checked
        np.testing.assert_allclose(val.shape, [nb, npts, dim_out])
        for residue in residues.values():
            np.testing.assert_allclose(residue.shape, [nb, npts, 1])
# Example 4 (score: 0)
def train(args, unet, imnet, train_loader, epoch, global_step, device,
          logger, writer, optimizer, pde_layer):
    """Train the unet encoder + imnet decoder for one epoch.

    Args:
        args: CLI options (alpha_reg, alpha_pde, clip_grad, use_apex,
            log_interval, nprocs, rank, reg_loss_type, ...).
        unet: encoder producing a latent grid from the input grid.
            ``unet.module`` is accessed below, so both models are assumed to
            be DDP/DataParallel-wrapped -- TODO confirm.
        imnet: implicit decoder queried at continuous coordinates.
        train_loader: yields (input_grid, point_coord, point_value) batches.
        epoch: current epoch number (for logging only).
        global_step: running step counter; incremented once per batch.
        device: torch device to move batches onto.
        logger: text logger (rank 0 only writes).
        writer: tensorboard SummaryWriter (rank 0 only writes).
        optimizer: optimizer over the unet+imnet parameters.
        pde_layer: PDELayer producing predictions and PDE residues.

    Returns:
        Total epoch loss divided by the number of samples seen.
    """
    unet.train()
    imnet.train()
    tot_loss = 0
    count = 0
    xmin = torch.zeros(3, dtype=torch.float32).to(device)
    xmax = torch.ones(3, dtype=torch.float32).to(device)
    loss_func = loss_functional(args.reg_loss_type)
    for batch_idx, data_tensors in enumerate(train_loader):
        # send tensors to device
        data_tensors = [t.to(device) for t in data_tensors]
        input_grid, point_coord, point_value = data_tensors
        optimizer.zero_grad()
        latent_grid = unet(input_grid)  # [batch, N, C, T, X, Y]
        # permute such that C is the last channel for local implicit grid query
        latent_grid = latent_grid.permute(0, 2, 3, 4, 1)  # [batch, N, T, X, Y, C]

        # define lambda function for pde_layer
        fwd_fn = lambda points: query_local_implicit_grid(imnet, latent_grid, points, xmin, xmax)

        # update pde layer and compute predicted values + pde residues
        pde_layer.update_forward_method(fwd_fn)
        pred_value, residue_dict = pde_layer(point_coord, return_residue=True)

        # function value regression loss
        reg_loss = loss_func(pred_value, point_value)

        # pde residue loss: regress every residue towards zero
        pde_tensors = torch.stack(list(residue_dict.values()), dim=0)
        pde_loss = loss_func(pde_tensors, torch.zeros_like(pde_tensors))
        loss = args.alpha_reg * reg_loss + args.alpha_pde * pde_loss

        if args.use_apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # gradient clipping on the wrapped modules' parameters
        torch.nn.utils.clip_grad_value_(unet.module.parameters(), args.clip_grad)
        torch.nn.utils.clip_grad_value_(imnet.module.parameters(), args.clip_grad)

        optimizer.step()

        tot_loss += loss.item()
        count += input_grid.size()[0]
        if batch_idx % args.log_interval == 0 and args.rank == 0:
            # .item() detaches the scalars so logging does not keep the
            # autograd graph alive until the writer flushes
            loss_val = loss.item()
            reg_val = reg_loss.item()
            pde_val = pde_loss.item()
            # logger log
            logger.info(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss Sum: {:.6f}\t"
                "Loss Reg: {:.6f}\tLoss Pde: {:.6f}".format(
                    epoch, batch_idx * len(input_grid) * args.nprocs,
                    len(train_loader) * len(input_grid) * args.nprocs,
                    100. * batch_idx / len(train_loader), loss_val,
                    args.alpha_reg * reg_val, args.alpha_pde * pde_val))
            # tensorboard log
            writer.add_scalar('train/reg_loss_unweighted',
                              reg_val, global_step=int(global_step))
            writer.add_scalar('train/pde_loss_unweighted',
                              pde_val, global_step=int(global_step))
            writer.add_scalar('train/sum_loss', loss_val, global_step=int(global_step))
            writer.add_scalars('train/losses_weighted',
                               {"reg_loss": args.alpha_reg * reg_val,
                                "pde_loss": args.alpha_pde * pde_val,
                                "sum_loss": loss_val}, global_step=int(global_step))

        global_step += 1
    tot_loss /= count
    return tot_loss
# Example 5 (score: 0)
def eval(args, unet, imnet, eval_loader, epoch, global_step, device,
         logger, writer, optimizer, pde_layer):
    """Eval function. Used for evaluating entire slices and comparing to GT.

    NOTE(review): shadows the builtin ``eval``; name kept for callers.

    Args:
        args: CLI options (pseudo_batch_size, rank, ...).
        unet: trained encoder producing the latent grid.
        imnet: trained implicit decoder.
        eval_loader: yields (hres_grid, lres_grid, _, _); only the first
            batch is evaluated.
        epoch: unused here; kept so train/eval share a call signature.
        global_step: step tag for tensorboard images.
        device: torch device.
        logger: unused here; kept for signature parity with train().
        writer: tensorboard SummaryWriter (rank 0 only writes images).
        optimizer: unused here; kept for signature parity with train().
        pde_layer: PDELayer producing predictions and PDE residues.
    """
    unet.eval()
    imnet.eval()
    phys_channels = ["p", "b", "u", "w"]
    phys2id = dict(zip(phys_channels, range(len(phys_channels))))
    xmin = torch.zeros(3, dtype=torch.float32).to(device)
    xmax = torch.ones(3, dtype=torch.float32).to(device)

    # only need the first batch
    data_tensors = next(iter(eval_loader))
    # send tensors to device
    data_tensors = [t.to(device) for t in data_tensors]
    hres_grid, lres_grid, _, _ = data_tensors
    latent_grid = unet(lres_grid)  # [batch, C, T, Z, X]
    nb, nc, nt, nz, nx = hres_grid.shape

    # permute such that C is the last channel for local implicit grid query
    latent_grid = latent_grid.permute(0, 2, 3, 4, 1)  # [batch, T, Z, X, C]

    # define lambda function for pde_layer
    fwd_fn = lambda points: query_local_implicit_grid(imnet, latent_grid, points, xmin, xmax)

    # update pde layer and compute predicted values + pde residues
    pde_layer.update_forward_method(fwd_fn)

    # layout query points for the desired slices; eps keeps queries strictly
    # inside (0, 1). Only every (nt // 8)-th time slice is evaluated.
    eps = 1e-6
    t_skip = nt // 8  # hoisted: also selects the matching ground-truth slices
    t_seq = torch.linspace(eps, 1 - eps, nt)[::t_skip]  # temporal sequences
    z_seq = torch.linspace(eps, 1 - eps, nz)  # z sequences
    x_seq = torch.linspace(eps, 1 - eps, nx)  # x sequences

    query_coord = torch.stack(torch.meshgrid(t_seq, z_seq, x_seq), axis=-1)  # [nt, nz, nx, 3]
    query_coord = query_coord.reshape([-1, 3]).to(device)  # [nt*nz*nx, 3]
    n_query = query_coord.shape[0]

    res_dict = defaultdict(list)

    # evaluate in pseudo-batches to bound peak memory
    n_iters = int(np.ceil(n_query / args.pseudo_batch_size))
    for idx in range(n_iters):
        sid = idx * args.pseudo_batch_size
        eid = min(sid + args.pseudo_batch_size, n_query)
        query_coord_batch = query_coord[sid:eid]
        # broadcast the same query points across the batch dimension
        query_coord_batch = query_coord_batch[None].expand(nb, eid - sid, 3)  # [nb, eid-sid, 3]

        pred_value, residue_dict = pde_layer(query_coord_batch, return_residue=True)
        pred_value = pred_value.detach()
        for key in residue_dict.keys():
            residue_dict[key] = residue_dict[key].detach()
        for chan_id, name in enumerate(phys_channels):
            res_dict[name].append(pred_value[..., chan_id])  # [b, pb]
        for name, val in residue_dict.items():
            res_dict[name].append(val[..., 0])   # [b, pb]

    # concatenate the pseudo-batches and reshape back into slice grids
    for key in res_dict.keys():
        res_dict[key] = (torch.cat(res_dict[key], dim=1)
                         .reshape([nb, len(t_seq), len(z_seq), len(x_seq)]))

    # log the imgs sample-by-sample
    if args.rank == 0:
        for samp_id in range(nb):
            for key in res_dict.keys():
                field = res_dict[key][samp_id]  # [nt, nz, nx]
                # add predicted slices
                images = utils.batch_colorize_scalar_tensors(field)  # [nt, nz, nx, 3]

                writer.add_images('sample_{}/{}/predicted'.format(samp_id, key), images,
                    dataformats='NHWC', global_step=int(global_step))
                # add ground truth slices (only for phys channels)
                if key in phys_channels:
                    gt_fields = hres_grid[samp_id, phys2id[key], ::t_skip]  # [nt, nz, nx]
                    gt_images = utils.batch_colorize_scalar_tensors(gt_fields)  # [nt, nz, nx, 3]

                    writer.add_images('sample_{}/{}/ground_truth'.format(samp_id, key), gt_images,
                        dataformats='NHWC', global_step=int(global_step))