示例#1
0
def render_rotating_volume(volume_model, scene_origin, n_frames,
                           video_size, **render_kwargs):
  """Render frames from a camera orbiting the volume.

  The camera is placed with `scene.pose_spherical` at `phi=-30`, `radius=4`
  for `n_frames` azimuths from `np.linspace(0., 360., n_frames)`. Note that
  `linspace` includes both endpoints, so azimuths 0 and 360 are both rendered.

  Rendering runs under `torch.no_grad()` and each frame is moved to the CPU as
  soon as it is produced, so accelerator memory does not grow with `n_frames`
  and the final `.numpy()` conversion works even if the caller has not
  disabled autograd.

  Args:
    volume_model: Model forwarded to `nerf.render_rays_mip`.
    scene_origin: Object whose `.value` is passed as the render `origin`.
    n_frames: Number of azimuths (frames) to render.
    video_size: Size passed to `scene.scale_intrinsics`.
    **render_kwargs: Extra keyword arguments for `nerf.render_rays_mip`.

  Returns:
    Tuple `(rgb, depth, disparity, silhouette)` of numpy arrays with shapes
    [T, H, W, 3], [T, H, W, 1], [T, H, W, 1] and [T, H, W, 1].
  """

  print("Rendering rotating volume ...")
  render_azimuths = np.linspace(0., 360., n_frames)
  cam2worlds = [
      scene.pose_spherical(azim, phi=-30, radius=4) for azim in render_azimuths
  ]
  frames = []

  height, width, focal = scene.scale_intrinsics(video_size)

  # No gradients are needed for visualization; skipping the autograd graph
  # also guarantees the tensors can be converted with `.numpy()` below.
  with torch.no_grad():
    for cam2world in tqdm.tqdm(cam2worlds, desc="Rendering rotating volume"):
      rays = scene.camera_rays(cam2world, height, width, focal)
      rendered, _ = nerf.render_rays_mip(
          rays,
          volume_model,
          origin=scene_origin.value,
          white_bkgd=True,
          **render_kwargs)  # rgb, depth, disparity, silhouette
      rendered = torch.cat(rendered, dim=-1)  # [H, W, 6]
      # Offload right away so only one frame lives on the render device.
      frames.append(rendered.cpu())

  frames = torch.stack(frames, dim=0)  # [n_frames, H, W, 6]
  rgb, depth, disparity, silhouette = torch.split(frames, [3, 1, 1, 1], dim=-1)
  return (
      rgb.numpy(),  # [T, H, W, 3]
      depth.numpy(),  # [T, H, W, 1]
      disparity.numpy(),  # [T, H, W, 1]
      silhouette.numpy())  # [T, H, W, 1]
示例#2
0
def render_validation_view(volume_model, scene_origin, render_size, max_size,
                           **render_kwargs):
  """Render one RGB image from a fixed validation camera pose.

  The view is rendered at `min(render_size, max_size)` and, if the rendered
  height differs from `render_size`, bilinearly resized to `render_size`.

  Args:
    volume_model: Model forwarded to `nerf.render_rays_mip`.
    scene_origin: Object whose `.value` is passed as the render `origin`.
    render_size: Requested output size.
    max_size: Upper bound on the size actually rendered.
    **render_kwargs: Extra keyword arguments for `nerf.render_rays_mip`.

  Returns:
    The rendered RGB tensor in HWC layout (3 dimensions).
  """
  pose = scene.pose_spherical(theta=30, phi=-45, radius=4)
  capped_size = min(render_size, max_size)
  height, width, focal = scene.scale_intrinsics(capped_size)
  rays = scene.camera_rays(pose, height, width, focal)

  # Only the first of the four render outputs (the RGB image) is used here.
  outputs, _ = nerf.render_rays_mip(
      rays,
      volume_model,
      origin=scene_origin.value,
      white_bkgd=True,
      chunksize=2**17,
      **render_kwargs)
  rgb, _, _, _ = outputs

  if height == render_size:
    resized = rgb
  else:
    # F.interpolate expects a batched, channels-first tensor.
    batched = rgb.movedim(-1, 0).unsqueeze(0)  # HWC to 1CHW
    batched = F.interpolate(batched, render_size, mode="bilinear")
    resized = batched.squeeze(0).movedim(0, -1)  # 1CHW to HWC
  rgb = resized

  assert rgb.ndim == 3
  return rgb
示例#3
0
def main():
  """Optimize a volumetric model against a text query and log progress.

  Reads configuration from the module-level `args` and the caption list
  `queries`. The routine:
    1. Seeds the RNGs, picks a device and starts a wandb run.
    2. Loads CLIP (when `args.clip_lam` is set) plus a retrieval CLIP model.
    3. Builds the `nerf.DreamFieldsMLP` volume model, Adam optimizer and AMP
       gradient scaler.
    4. Optionally (when `args.diffusion_lam` is set) loads the GLIDE base
       model and initializes one diffusion sampling chain per view.
    5. Pre-generates camera rays for `args.n_views` poses and runs
       `args.n_iter` optimization steps combining CLIP, transmittance and
       diffusion losses.
    6. Logs metrics every iteration, images every 25 iterations, and a
       validation render plus rotating videos every 250 iterations.
  """
  # Reproducibility.
  torch.manual_seed(args.seed)
  random.seed(args.seed)
  np.random.seed(args.seed)

  # Obtain the utilized device.
  if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    has_cuda = True
  else:
    device = torch.device("cpu")
    has_cuda = False

  # Setup logging.
  query = queries[args.query_idx]
  exp_name = f"{args.exp_name_prefix}_{args.query_idx:03d}_{query}"
  wandb.init(
      entity=args.wandb_entity,
      project=args.wandb_project,
      name=exp_name,
      config=args)

  # Initialize CLIP
  if args.clip_lam:
    model, preprocess, clip_size = load_clip(args.loss_model, device)
    model.eval()

  # Reuse the loss model for retrieval when it is the same architecture and
  # has been loaded; otherwise load a separate retrieval model.
  if args.retrieve_model == args.loss_model and args.clip_lam:
    test_model, test_preprocess, test_clip_size = model, preprocess, clip_size
  else:
    test_model, test_preprocess, test_clip_size = load_clip(
        args.retrieve_model, device)
    test_model.eval()

  # Initialize the volumetric model.
  volume_model = nerf.DreamFieldsMLP(
      activation="SiLU",
      features_early=[96],  # Dense layers before residual blocks.
      features_residual=[(128, 96)] * 3,  # Resid block feature dimensions.
      features_late=[96, 4],  # Features dimensions at end.
      fourfeat=args.fourfeat,
      max_deg=args.posenc_deg,
      ipe=args.ipe,
  )
  volume_model = nn.DataParallel(volume_model)
  volume_model = volume_model.to(device)
  # The origin is only read via `.value` below; tracking it is not implemented
  # (see the `args.track_scene_origin` check in the training loop).
  scene_origin = scene.EMA(np.zeros(3, dtype=np.float64), decay=0.999)
  # Near/far bounds are centered on 4., the camera radius used for all poses
  # in this file, and extend by half the diagonal of the volume cube.
  render_kwargs = dict(
      sigma_noise_std=args.sigma_noise_std,
      near=4. - math.sqrt(3) * args.volume_extent_world / 2,
      far=4. + math.sqrt(3) * args.volume_extent_world / 2,
      mask_rad=args.volume_extent_world / 2,
      n_pts_per_ray=args.n_pts_per_ray,
      device=device,
  )

  # Instantiate the Adam optimizer.
  optimizer = torch.optim.Adam(
      volume_model.parameters(), lr=args.lr_init, eps=args.adam_eps)
  # Loss scaler paired with the autocast block in the training loop.
  scaler = torch.cuda.amp.GradScaler()

  # Embed the target caption with CLIP.
  if args.clip_lam:
    query_tok = clip.tokenize(query).to(device)
    z_clip = model.encode_text(query_tok).detach()
    z_clip = F.normalize(z_clip, dim=-1)

    # Square random crops, resized to the CLIP input size.
    clip_aug_fn = torchvision.transforms.RandomResizedCrop(
        clip_size, scale=args.crop_scale_range, ratio=(1.0, 1.0))

  if args.diffusion_lam:
    # Initialize GLIDE. Create base model.
    base_glide_model, diffusion, base_glide_options = load_diffusion(
        "base", device, has_cuda=has_cuda)
    base_glide_model.eval()

    # Embed the target caption with GLIDE.
    # NOTE(review): with denoise_augmented and n_aug > 1 the token batch is
    # n_aug * n_views, while the sampling chains below hold n_views items —
    # confirm these shapes are compatible for that configuration.
    denoise_batch_size = (
        args.n_aug * args.n_views if args.denoise_augmented else args.n_views)
    tokens = base_glide_model.tokenizer.encode(query)
    tokens, mask = base_glide_model.tokenizer.padded_tokens_and_mask(
        tokens, base_glide_options["text_ctx"])

    # Create the classifier-free guidance tokens (empty).
    uncond_tokens, uncond_mask = base_glide_model.tokenizer.padded_tokens_and_mask(
        [], base_glide_options["text_ctx"])

    # Pack the tokens together into model kwargs.
    # First half of the batch is conditional, second half unconditional.
    base_model_kwargs = dict(
        tokens=torch.tensor(
            [tokens] * denoise_batch_size +
            [uncond_tokens] * denoise_batch_size,
            device=device),
        mask=torch.tensor(
            [mask] * denoise_batch_size + [uncond_mask] * denoise_batch_size,
            dtype=torch.bool,
            device=device),
    )

    parallel_glide = nn.DataParallel(base_glide_model)

    # Create a classifier-free guidance sampling function.
    # The input batch is expected to be [x, x]; only the first half is used
    # and duplicated so the conditional and unconditional passes see the same
    # x_t. The guided eps is written back to both halves of the output.
    def base_model_fn(x_t, ts, **kwargs):
      half = x_t[:len(x_t) // 2]
      combined = torch.cat([half, half], dim=0)
      model_out = parallel_glide(combined, ts, **kwargs)
      eps, rest = model_out[:, :3], model_out[:, 3:]
      cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
      half_eps = uncond_eps + args.guidance_scale * (cond_eps - uncond_eps)
      eps = torch.cat([half_eps, half_eps], dim=0)
      return torch.cat([eps, rest], dim=1)

    # Convert images in [0, 1] to the 64x64, [-1, 1], NCHW layout used for the
    # GLIDE inputs below.
    def preprocess_glide(x, order="NHWC"):
      if order == "NHWC":
        # x is [NHWC]. Reshape to NCHW.
        x = x.movedim(-1, 1)
      x = x * 2 - 1  # Scale from [0, 1] to [-1, 1].
      if x.shape[-2:] != (64, 64):
        x = F.interpolate(x, (64, 64), mode="bilinear")
      return x

    # Inverse of the value scaling in preprocess_glide (layout is unchanged).
    def unprocess_glide(x):
      return (x + 1) / 2  # Scale from [-1, 1] to [0, 1].

    # Identity: the predicted x_start is not post-processed.
    denoised_fn = lambda x_start: x_start
    denoise_aug_fn = torchvision.transforms.RandomResizedCrop(
        64, scale=args.crop_scale_range, ratio=(1.0, 1.0))

    # NOTE: the training loop asserts args.denoise_stop_grad, so in practice
    # this is always torch.no_grad.
    glide_context_manager = (
        torch.no_grad if args.denoise_stop_grad else torch.enable_grad)

    # Initialize each chain.
    # One Gaussian-noise sample per view, all starting at the last timestep.
    diffusion_x = torch.randn((args.n_views, 3, 64, 64),
                              device=device,
                              requires_grad=False)
    diffusion_t = torch.full(
        size=(args.n_views,),
        fill_value=args.t_respace - 1,
        requires_grad=False,
        dtype=torch.long,
        device=device)

    # Training uses n_iter iterations: 1 to n_iter (inclusive).
    # Diffusion uses t_respace timesteps: t_respace-1 to 0 (inclusive).
    # For now, check they are equal.
    # TODO(jainajay): implement sampling with non-unit timesteps.
    assert args.t_respace * args.denoise_every == args.n_iter

  # Get a batch of viewing angles and pre-generate rays.
  # Azimuths are evenly spaced; radius and focal multiplier are constant;
  # elevations are drawn once per view from args.elevation_range.
  azimuths = np.arange(args.n_views) * 360. / args.n_views
  rads = np.full(args.n_views, 4.)
  focal_mults = np.full(args.n_views, 1.2)
  elevations = [
      scene.uniform_in_interval(args.elevation_range)
      for _ in range(args.n_views)
  ]
  cam2worlds = [
      scene.pose_spherical(azim, phi=elev, radius=rad)
      for azim, elev, rad in zip(azimuths, elevations, rads)
  ]
  height, width, focal = scene.scale_intrinsics(args.render_size)
  # Generate rays: 3-tuple of [n_views, H, W, n_pts_per_ray, 3 or 1].
  rays_all_views = scene.camera_rays_batched(cam2worlds, height, width,
                                             focal_mults * focal)

  pbar = tqdm.trange(1, args.n_iter + 1)
  for iteration in pbar:
    metrics = {}
    # Log images on the first iteration and on every 25th.
    visualize_images = iteration % 25 == 0 or iteration == 1

    # Set learning rate
    lr = schedule.learning_rate_decay(
        iteration,
        args.lr_init,
        args.lr_final,
        args.n_iter,
        lr_delay_steps=min(args.n_iter // 8, 2500),
        lr_delay_mult=args.lr_delay_mult)
    for g in optimizer.param_groups:
      g["lr"] = float(lr)

    # Zero the optimizer gradient.
    optimizer.zero_grad()

    # Render the volumetric model from random perspectives.
    # Views are sampled without replacement from the pre-generated set.
    batch_idx = np.random.choice(
        args.n_views, size=args.batch_size, replace=False)
    rays_batched = [r[batch_idx] for r in rays_all_views]

    # Runs the forward pass with automatic precision casting.
    with torch.cuda.amp.autocast():
      (images, depths, disparities, silhouettes), _ = nerf.render_rays_mip(
          rays_batched,
          volume_model,
          origin=scene_origin.value,
          **render_kwargs)
      assert images.ndim == 4
      assert images.shape[0] == args.batch_size
      assert images.shape[-1] == 3

      # Transmittance loss. Anneal target opacity (1 - transmittance).
      target_opacity = schedule.anneal_logarithmically(
          iteration, args.target_transmittance_anneal_iters,
          1 - args.target_transmittance0, 1 - args.target_transmittance1)

      # The area of an object on the image plane grows with the focal length
      # and shrinks with increasing camera radius. Scale target opacity
      # proportionally with the squared focal multiplier and inversely
      # proportionally with the squared camera radius.
      # The result is clipped at 1 by the elementwise minimum with ones.
      target_opacities = np.minimum(
          np.ones(args.batch_size), focal_mults[batch_idx]**2 /
          (rads[batch_idx] / 4.)**2 * target_opacity)
      taus = torch.tensor(1 - target_opacities, device=device)
      # Per-view mean over all non-batch dimensions.
      avg_transmittance = 1 - silhouettes.mean(
          dim=tuple(range(1, silhouettes.ndim)))
      # NOTE(jainajay): Using a modified, two-sided transmittance loss that
      # differs from Dream Fields. It can encourage reducing transmittance if
      # the scene becomes too sparse. The original loss would penalize
      # -torch.mean(torch.min(avg_transmittance, taus)).
      transmittance_loss = torch.mean(torch.abs(avg_transmittance - taus))

      # Data augmentation.
      # NOTE(review): this guard tests `> 0` while the consumers below test
      # truthiness of clip_lam / diffusion_lam; they differ only for negative
      # weights, in which case composite_images would be undefined.
      if (args.diffusion_lam > 0 and
          args.denoise_augmented) or args.clip_lam > 0:
        # NOTE(jainajay): this background is at the render resolution,
        #       not the resize, unlike Dream Fields.
        # Generate random backgrounds.
        bgs = augment.sample_backgrounds(
            num=args.n_aug * args.batch_size,
            res=args.render_size,
            checkerboard_nsq=args.nsq,
            min_blur_std=args.bg_blur_std_range[0],
            max_blur_std=args.bg_blur_std_range[1],
            device=device)

        # Composite renders with backgrounds.
        # The leading [None] axis broadcasts each render over n_aug backgrounds.
        bgs = bgs.view(args.n_aug, args.batch_size, *bgs.shape[1:])  # ANCHW.
        bgs = bgs.movedim(2, -1)  # Convert ANCHW to ANHWC.
        composite_images = (
            silhouettes[None] * images[None] + (1 - silhouettes[None]) * bgs)
        composite_images = composite_images.reshape(  # to A*N,H,W,C.
            args.n_aug * args.batch_size, args.render_size, args.render_size, 3)
        composite_images = composite_images.movedim(3, 1)  # NHWC to NCHW.

      # Compute GLIDE loss.
      # Sample from the base model.
      if args.diffusion_lam:
        # Preprocess rendering (scale to [-1, 1]).
        if args.denoise_augmented:
          denoise_aug_images = denoise_aug_fn(composite_images)
          inp = preprocess_glide(denoise_aug_images, order="NCHW")
        else:
          inp = silhouettes * images + 1 - silhouettes  # white bg
          inp = preprocess_glide(inp, order="NHWC")

        # Advance the sampling chains once every denoise_every iterations
        # (always on iteration 1). On other iterations the loss below reuses
        # pred_xstart from the most recent sampling step.
        if (iteration - 1) % args.denoise_every == 0:
          base_glide_model.del_cache()

          # Sampling step for every view in the cache.
          with glide_context_manager():
            assert diffusion_t.dtype == torch.long
            assert torch.all(diffusion_t == diffusion_t[0])
            metrics["diffusion/t"] = diffusion_t[0].item()

            xt = diffusion_x  # || x_hat(x_t) - render ||^2

            # Enable for loss: || x_hat(diffuse(render)) - x_hat(x_t) ||^2
            # x = diffusion.q_sample(
            #     inp, torch.tensor([diffusion_t] * denoise_batch_size,
            #     device=device))

            # Sample x_s from x_t using DDIM.
            # Based on glide-text2im/glide_text2im/gaussian_diffusion.py#L453
            assert args.batch_size == args.n_views  # Updating all chains.
            # Inputs are duplicated to match the [cond, uncond] layout of
            # base_model_kwargs consumed by base_model_fn.
            out = diffusion.p_mean_variance(
                base_model_fn,
                torch.cat([xt, xt], dim=0),
                torch.cat([diffusion_t, diffusion_t], dim=0),
                clip_denoised=True,
                denoised_fn=denoised_fn,  # TODO(jainajay): look into this,
                model_kwargs=base_model_kwargs,
            )
            assert out["pred_xstart"].shape[0] == 2 * args.batch_size
            # Both halves carry the same guided eps; keep the first half.
            pred_xstart = out["pred_xstart"][:args.batch_size]

            if iteration < args.independent_sampling_steps * args.denoise_every:
              # Early iterations: x0 is the denoiser prediction x_hat(x_t).
              #   eps = pred_eps(x_t, t, x_hat(x_t)).
              #   x_{t-1} = a * x_hat(x_t) + b * eps + sigma * noise.
              # NOTE(review): earlier comments labeled this branch "Ours" and
              # the other "GLIDE"; those labels looked swapped relative to
              # the assignments — confirm intended naming.
              x0_for_sampling = pred_xstart
            else:
              # Later iterations: x0 is the detached, preprocessed render
              # (presumably the tilde_x of the original comments).
              #   eps = pred_eps(x_t, t, tilde_x).
              #   x_{t-1} = a * tilde_x + b * eps + sigma * noise.
              x0_for_sampling = inp.detach()

            # pylint: disable=protected-access
            eps = diffusion._predict_eps_from_xstart(diffusion_x, diffusion_t,
                                                     x0_for_sampling)
            # pylint: enable=protected-access

            assert eps.shape[0] == args.batch_size

            alpha_bar = _extract_into_tensor(diffusion.alphas_cumprod,
                                             diffusion_t, xt.shape)
            metrics["diffusion/alpha_bar"] = alpha_bar.mean().item()
            alpha_bar_prev = _extract_into_tensor(diffusion.alphas_cumprod_prev,
                                                  diffusion_t, xt.shape)
            metrics["diffusion/alpha_bar_prev"] = alpha_bar_prev.mean().item()
            # Noise scale of the update; zero when args.ddim_eta == 0.
            sigma = (
                args.ddim_eta * torch.sqrt(
                    (1 - alpha_bar_prev) / (1 - alpha_bar)) *
                torch.sqrt(1 - alpha_bar / alpha_bar_prev))
            metrics["diffusion/sigma"] = sigma.mean().item()
            # Equation 12.
            # (Presumably of the DDIM paper — confirm.)
            mean_pred = (
                x0_for_sampling * torch.sqrt(alpha_bar_prev) +
                torch.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
            nonzero_mask = (
                (diffusion_t != 0).float().view(-1,
                                                *([1] * (len(xt.shape) - 1)))
            )  # No noise when t == 0.
            noise = torch.randn_like(xt)
            sample = mean_pred + nonzero_mask * sigma * noise

            # Update multiview sampling chains.
            # diffusion_x_prev is kept only for visualization below.
            diffusion_x_prev = diffusion_x
            diffusion_x = sample
            diffusion_t = diffusion_t - 1

            # Don't backprop through the denoiser (forces stop_grad True).
            assert args.denoise_stop_grad
            pred_xstart = pred_xstart.detach()

          base_glide_model.del_cache()

        # Loss: ||x_hat(x_t) - render||^2.
        # Slicing the predictions only optimizes a few views.
        diffusion_loss = F.mse_loss(pred_xstart[:args.n_optimize],
                                    inp[:args.n_optimize])

        # TODO(jainajay): Try other losses. Some possibilities:
        #   ||x_hat(render) - render||^2 (change L480)
        #   ||x_hat(x_t) - x_hat(diffuse(render))||^2
        #         (change denoising code to denoise render and x_t)
        #   ||eps - eps_hat(diffuse(render), eps)||^2
        #   ||eps_hat(x_t) - eps_hat(diffuse(render), eps)||^2
        #       only makes sense if that's the eps in x_t
        # NOTE(review): the "L480" above refers to a line number in some
        # other revision of this file and is likely stale.
        # NOTE(review): this stores the loss tensor itself rather than
        # `.item()`, unlike the other scalar metrics.
        metrics["loss/diffusion_mse"] = diffusion_loss
      else:
        diffusion_loss = torch.tensor([0.], device=device)

      # Compute the CLIP loss.
      if args.clip_lam:
        clip_aug_images = clip_aug_fn(composite_images)
        x = preprocess(clip_aug_images)  # Resize and normalize.
        z_est = model.encode_image(x)
        z_est = F.normalize(z_est, dim=-1)
        # Negative mean cosine similarity between image and text embeddings.
        clip_loss = -torch.sum(z_est * z_clip, dim=-1).mean()
      else:
        clip_loss = torch.tensor([0.], device=device)

      # Compute total loss and take an optimization step.
      loss = (
          args.clip_lam * clip_loss +
          args.transmittance_lam * transmittance_loss +
          args.diffusion_lam * diffusion_loss)

    # Backward and optimizer step go through the AMP gradient scaler.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    if args.track_scene_origin:
      raise NotImplementedError

    # Logging.
    # The model is switched to eval mode for logging/validation and back to
    # train mode at the end of the iteration.
    with torch.inference_mode():
      volume_model.eval()

      metrics["train/depths/min"] = depths.min()
      metrics["train/depths/max"] = depths.max()
      metrics["train/disparities/min"] = disparities.min()
      metrics["train/disparities/max"] = disparities.max()

      metrics.update({
          "schedule/lr": lr,
          "loss/total_loss": loss.item(),
          "loss/clip": clip_loss.item(),
          "loss/transmittance": transmittance_loss.item(),
          "train/avg_transmittance": avg_transmittance.mean().item()
      })

      # Print the current values of the losses.
      if iteration % 10 == 0:
        pbar.set_description(
            f"Iteration {iteration:05d}:" +
            f" clip_loss = {float(clip_loss.item()):1.2f}" +
            f" diffusion_loss = {float(diffusion_loss.item()):1.5f}" +
            f" avg transmittance = {float(avg_transmittance.mean().item()):1.2f}"
        )

      # Visualize the renders.
      if visualize_images:
        metrics["render/rendered"] = wandb_grid(images)
        metrics["render/silhouettes"] = wandb_grid(silhouettes)
        metrics["render/rendered_depth"] = wandb_grid(depths)

        if args.clip_lam > 0:
          metrics["render/augmented"] = wandb_grid(clip_aug_images)

        if args.diffusion_lam:
          # Show diffusion_x_prev, diffusion_x (sample), out['pred_xstart'].
          for name, val in zip(["x_t", "x_tm1", "pred_xstart"],
                               [diffusion_x_prev, diffusion_x, pred_xstart]):
            print("diffusion", name, val.shape, val.min(), val.max())
            val = unprocess_glide(val)  # [n_views, C, 64, 64]
            metrics[f"diffusion/{name}"] = wandb_grid(val)

      # Validate from a held-out view.
      # Runs on the first iteration and on every 250th.
      if iteration % 250 == 0 or iteration == 1:
        validation_view = render_validation_view(
            volume_model,
            scene_origin,
            test_clip_size,
            args.max_validation_size,
            **render_kwargs)
        assert validation_view.ndim == 3
        assert validation_view.shape[-1] == 3
        metrics["val/render"] = wandb.Image(clamp_and_detach(validation_view))

        # The render is passed as a 1CHW batch.
        rank, cosine_sim = compute_query_rank(
            test_model,
            test_preprocess,
            render=validation_view.movedim(-1, 0).unsqueeze(0),
            query=query,
            queries_r=queries,
            device=device)

        metrics["val/rank"] = rank
        metrics["val/acc"] = int(rank == 0)
        metrics["val/cosine_sim"] = cosine_sim

      # Same schedule as the validation block above.
      if iteration % 250 == 0 or iteration == 1:
        # Visualize the optimized volume by rendering from multiple viewpoints
        # that rotate around the volume's y-axis.
        video_frames = render_rotating_volume(
            volume_model,
            scene_origin=scene_origin,
            video_size=args.video_size,
            n_frames=args.video_n_frames,
            **render_kwargs)

        for name, frames in zip(["rgb", "depth", "disparity", "silhouette"],
                                video_frames):
          # frames is in THWC order.
          filename = f"/tmp/{iteration:05d}_{name}.mp4"
          # Single-channel outputs are written without the channel axis.
          if frames.shape[-1] == 1:
            media.write_video(filename, frames[Ellipsis, 0], fps=30)
          else:
            media.write_video(filename, frames, fps=30)
          print("wrote", filename,
                f"range: [{frames.min():.4f}, {frames.max():.4f}]")

          metrics[f"render/video/{name}"] = wandb.Video(
              filename, fps=30, format="mp4")

      wandb.log(metrics, iteration)

      volume_model.train()