Example #1
def make_optimization_video(directory,
                            pattern="im_*.png",
                            return_frames=False):
    """Concatenate images into a video."""
    paths = glob.glob(os.path.join(directory, pattern))
    paths.sort(key=lambda fname: int(
        os.path.basename(fname)[len("im_"):-len(".png")]))
    images = [media.read_image(path) for path in tqdm.tqdm(paths)]
    if images:
        # Pad frames to the same shape.
        heights, widths = zip(*[im.shape[:2] for im in images])
        max_height, max_width = max(heights), max(widths)
        for i, im in enumerate(images):
            assert im.ndim == 3
            pad_height = max_height - im.shape[0]
            pad_width = max_width - im.shape[1]
            if pad_height != 0 or pad_width != 0:
                images[i] = np.pad(im,
                                   ((0, pad_height), (0, pad_width), (0, 0)))

        # Stack and write video.
        images = np.stack(images)[Ellipsis, :3]

        if return_frames:
            return images

        media.write_video(os.path.join(directory, "optimization.mp4"),
                          images,
                          fps=30)
    else:
        print(
            f"WARNING: No images to combine in {directory} for optimization history."
        )
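A minimal usage sketch for the helper above (an illustrative addition, not from the original repository). It assumes the surrounding module already imports glob, os, numpy as np, tqdm, and mediapy as media; the frame directory and file names below are hypothetical.

import os
import numpy as np
import mediapy as media

# Write a few dummy frames named im_<step>.png, then combine them into a video.
frames_dir = "/tmp/optimization_frames"  # Hypothetical directory.
os.makedirs(frames_dir, exist_ok=True)
for step in range(8):
    frame = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
    media.write_image(os.path.join(frames_dir, f"im_{step}.png"), frame)

make_optimization_video(frames_dir)  # Writes <frames_dir>/optimization.mp4.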
Example #2
def test_video_read_write_odd_dimensions(self):
    video = media.moving_circle((35, 97), num_images=4, dtype=np.uint8)
    fps = 60
    bps = 40_000_000
    with tempfile.TemporaryDirectory() as directory_name:
        path = pathlib.Path(directory_name) / 'test5.mp4'
        media.write_video(path, video, fps=fps, bps=bps)
        new_video = media.read_video(path)
    self.assertEqual(new_video.dtype, np.uint8)
    self._check_similar(video, new_video, max_rms=5.0)
Example #3
def test_video_non_streaming_write_read_roundtrip(self, use_generator):
    shape = (240, 320)
    num_images = 10
    fps = 40
    qp = 20
    original_video = media.to_uint8(media.moving_circle(shape, num_images))
    video = (
        image
        for image in original_video) if use_generator else original_video
    with tempfile.TemporaryDirectory() as directory_name:
        filename = os.path.join(directory_name, 'test.mp4')
        media.write_video(filename, video, fps=fps, qp=qp)
        new_video = media.read_video(filename)
        self._check_similar(original_video, new_video, 3.0)
Example #4
    def test_video_streaming_read_write(self):
        shape = (400, 400)
        num_images = 4
        fps = 60
        bps = 40_000_000
        video = media.to_uint8(media.moving_circle(shape, num_images))
        with tempfile.TemporaryDirectory() as directory_name:
            filename1 = os.path.join(directory_name, 'test1.mp4')
            filename2 = os.path.join(directory_name, 'test2.mp4')
            media.write_video(filename1, video, fps=fps, bps=bps)

            with media.VideoReader(filename1) as reader:
                with media.VideoWriter(filename2,
                                       reader.shape,
                                       fps=reader.fps,
                                       bps=reader.bps,
                                       encoded_format='yuv420p') as writer:
                    for image in reader:
                        writer.add_image(image)

            new_video = media.read_video(filename2)
            self._check_similar(video, new_video, 3.0)
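For comparison, a minimal streaming-write-only sketch (an illustrative addition, not part of the test above): frames are generated one at a time and handed to media.VideoWriter instead of first assembling a full array. The output path is hypothetical.

import numpy as np
import mediapy as media

with media.VideoWriter('/tmp/ramp.mp4', (128, 128), fps=30) as writer:
    for t in range(60):
        # Each frame is a uniform gray image that brightens over time.
        frame = np.full((128, 128, 3), t * 4, dtype=np.uint8)
        writer.add_image(frame)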
Example #5
def test_video_read_write_10bit(self):
    shape = (256, 256)
    num_images = 4
    fps = 60
    bps = 40_000_000
    horizontal_gray_ramp = media.to_type(
        np.indices(shape)[1] / shape[1], np.uint16)
    video = np.broadcast_to(horizontal_gray_ramp, (num_images, *shape))
    with tempfile.TemporaryDirectory() as directory_name:
        filename = os.path.join(directory_name, 'test3.mp4')
        media.write_video(filename,
                          video,
                          fps=fps,
                          bps=bps,
                          encoded_format='yuv420p10le')
        new_video = media.read_video(filename,
                                     dtype=np.uint16,
                                     output_format='gray')
    self.assertEqual(new_video.dtype, np.uint16)
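    # A 10-bit sample occupies the upper bits of a 16-bit value, so one
    # 10-bit code step corresponds to 2**(16 - 10) = 64 in uint16.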
    value_1_of_10bit_encoded_in_16bits = 64
    self._check_similar(video,
                        new_video,
                        max_rms=value_1_of_10bit_encoded_in_16bits * 0.8)
Example #6
def log_video(writer,
              video,
              tb_key,
              name,
              step,
              work_unit_dir,
              save_raw=False,
              scale=False):
  """Save video frames to tensorboard and a file."""
  video_raw = video
  if scale:
    video = scale_depth(video)

  if writer is not None:
    logging.info('Logging video frames')
    writer.write_images(step, {f'{tb_key}/{name}': make_image_grid(video)})

  filename = f'{tb_key}_{name}_{step:05d}.mp4'
  local_path = os.path.join('/tmp', filename)
  logging.info('Writing video to %s', local_path)
  media.write_video(local_path, video, fps=30)

  wu_path = os.path.join(work_unit_dir, filename)
  logging.info('Copying video to %s', wu_path)
  gfile.copy(local_path, wu_path, overwrite=True)
  gfile.remove(local_path)

  if save_raw:
    # Save raw floating-point values so depth can be scaled properly.
    raw_filename = f'{tb_key}_{name}_{step:05d}.npy'
    raw_path = os.path.join(work_unit_dir, raw_filename)
    logging.info('Saving raw video to %s', raw_path)
    with gfile.GFile(raw_path, 'wb') as raw_f:
      onp.save(raw_f, video_raw)

  logging.info('Done logging video.')
Example #7
def main():
  # Reproducibility.
  torch.manual_seed(args.seed)
  random.seed(args.seed)
  np.random.seed(args.seed)

  # Obtain the utilized device.
  if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    has_cuda = True
  else:
    device = torch.device("cpu")
    has_cuda = False

  # Setup logging.
  query = queries[args.query_idx]
  exp_name = f"{args.exp_name_prefix}_{args.query_idx:03d}_{query}"
  wandb.init(
      entity=args.wandb_entity,
      project=args.wandb_project,
      name=exp_name,
      config=args)

  # Initialize CLIP
  if args.clip_lam:
    model, preprocess, clip_size = load_clip(args.loss_model, device)
    model.eval()

  if args.retrieve_model == args.loss_model and args.clip_lam:
    test_model, test_preprocess, test_clip_size = model, preprocess, clip_size
  else:
    test_model, test_preprocess, test_clip_size = load_clip(
        args.retrieve_model, device)
    test_model.eval()

  # Initialize the volumetric model.
  volume_model = nerf.DreamFieldsMLP(
      activation="SiLU",
      features_early=[96],  # Dense layers before residual blocks.
      features_residual=[(128, 96)] * 3,  # Residual block feature dimensions.
      features_late=[96, 4],  # Feature dimensions at the end.
      fourfeat=args.fourfeat,
      max_deg=args.posenc_deg,
      ipe=args.ipe,
  )
  volume_model = nn.DataParallel(volume_model)
  volume_model = volume_model.to(device)
  scene_origin = scene.EMA(np.zeros(3, dtype=np.float64), decay=0.999)
  render_kwargs = dict(
      sigma_noise_std=args.sigma_noise_std,
      near=4. - math.sqrt(3) * args.volume_extent_world / 2,
      far=4. + math.sqrt(3) * args.volume_extent_world / 2,
      mask_rad=args.volume_extent_world / 2,
      n_pts_per_ray=args.n_pts_per_ray,
      device=device,
  )

  # Instantiate the Adam optimizer.
  optimizer = torch.optim.Adam(
      volume_model.parameters(), lr=args.lr_init, eps=args.adam_eps)
  scaler = torch.cuda.amp.GradScaler()

  # Embed the target caption with CLIP.
  if args.clip_lam:
    query_tok = clip.tokenize(query).to(device)
    z_clip = model.encode_text(query_tok).detach()
    z_clip = F.normalize(z_clip, dim=-1)

    clip_aug_fn = torchvision.transforms.RandomResizedCrop(
        clip_size, scale=args.crop_scale_range, ratio=(1.0, 1.0))

  if args.diffusion_lam:
    # Initialize GLIDE. Create base model.
    base_glide_model, diffusion, base_glide_options = load_diffusion(
        "base", device, has_cuda=has_cuda)
    base_glide_model.eval()

    # Embed the target caption with GLIDE.
    denoise_batch_size = (
        args.n_aug * args.n_views if args.denoise_augmented else args.n_views)
    tokens = base_glide_model.tokenizer.encode(query)
    tokens, mask = base_glide_model.tokenizer.padded_tokens_and_mask(
        tokens, base_glide_options["text_ctx"])

    # Create the classifier-free guidance tokens (empty).
    uncond_tokens, uncond_mask = base_glide_model.tokenizer.padded_tokens_and_mask(
        [], base_glide_options["text_ctx"])

    # Pack the tokens together into model kwargs.
    base_model_kwargs = dict(
        tokens=torch.tensor(
            [tokens] * denoise_batch_size +
            [uncond_tokens] * denoise_batch_size,
            device=device),
        mask=torch.tensor(
            [mask] * denoise_batch_size + [uncond_mask] * denoise_batch_size,
            dtype=torch.bool,
            device=device),
    )

    parallel_glide = nn.DataParallel(base_glide_model)

    # Create a classifier-free guidance sampling function.
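    # The batch passed to the model is doubled: the first half is paired with
    # the caption tokens and the second half with the empty (unconditional)
    # tokens; the guided prediction below is
    # uncond_eps + guidance_scale * (cond_eps - uncond_eps).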
    def base_model_fn(x_t, ts, **kwargs):
      half = x_t[:len(x_t) // 2]
      combined = torch.cat([half, half], dim=0)
      model_out = parallel_glide(combined, ts, **kwargs)
      eps, rest = model_out[:, :3], model_out[:, 3:]
      cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
      half_eps = uncond_eps + args.guidance_scale * (cond_eps - uncond_eps)
      eps = torch.cat([half_eps, half_eps], dim=0)
      return torch.cat([eps, rest], dim=1)

    def preprocess_glide(x, order="NHWC"):
      if order == "NHWC":
        # x is [NHWC]. Reshape to NCHW.
        x = x.movedim(-1, 1)
      x = x * 2 - 1  # Scale from [0, 1] to [-1, 1].
      if x.shape[-2:] != (64, 64):
        x = F.interpolate(x, (64, 64), mode="bilinear")
      return x

    def unprocess_glide(x):
      return (x + 1) / 2  # Scale from [-1, 1] to [0, 1].

    denoised_fn = lambda x_start: x_start
    denoise_aug_fn = torchvision.transforms.RandomResizedCrop(
        64, scale=args.crop_scale_range, ratio=(1.0, 1.0))

    glide_context_manager = (
        torch.no_grad if args.denoise_stop_grad else torch.enable_grad)

    # Initialize each chain.
    diffusion_x = torch.randn((args.n_views, 3, 64, 64),
                              device=device,
                              requires_grad=False)
    diffusion_t = torch.full(
        size=(args.n_views,),
        fill_value=args.t_respace - 1,
        requires_grad=False,
        dtype=torch.long,
        device=device)

    # Training uses n_iter iterations: 1 to n_iter (inclusive).
    # Diffusion uses t_respace timesteps: t_respace-1 to 0 (inclusive).
    # For now, check they are equal.
    # TODO(jainajay): implement sampling with non-unit timesteps.
    assert args.t_respace * args.denoise_every == args.n_iter

  # Get a batch of viewing angles and pre-generate rays.
  azimuths = np.arange(args.n_views) * 360. / args.n_views
  rads = np.full(args.n_views, 4.)
  focal_mults = np.full(args.n_views, 1.2)
  elevations = [
      scene.uniform_in_interval(args.elevation_range)
      for _ in range(args.n_views)
  ]
  cam2worlds = [
      scene.pose_spherical(azim, phi=elev, radius=rad)
      for azim, elev, rad in zip(azimuths, elevations, rads)
  ]
  height, width, focal = scene.scale_intrinsics(args.render_size)
  # Generate rays: 3-tuple of [n_views, H, W, n_pts_per_ray, 3 or 1].
  rays_all_views = scene.camera_rays_batched(cam2worlds, height, width,
                                             focal_mults * focal)

  pbar = tqdm.trange(1, args.n_iter + 1)
  for iteration in pbar:
    metrics = {}
    visualize_images = iteration % 25 == 0 or iteration == 1

    # Set learning rate
    lr = schedule.learning_rate_decay(
        iteration,
        args.lr_init,
        args.lr_final,
        args.n_iter,
        lr_delay_steps=min(args.n_iter // 8, 2500),
        lr_delay_mult=args.lr_delay_mult)
    for g in optimizer.param_groups:
      g["lr"] = float(lr)

    # Zero the optimizer gradient.
    optimizer.zero_grad()

    # Render the volumetric model from random perspectives.
    batch_idx = np.random.choice(
        args.n_views, size=args.batch_size, replace=False)
    rays_batched = [r[batch_idx] for r in rays_all_views]

    # Runs the forward pass with automatic precision casting.
    with torch.cuda.amp.autocast():
      (images, depths, disparities, silhouettes), _ = nerf.render_rays_mip(
          rays_batched,
          volume_model,
          origin=scene_origin.value,
          **render_kwargs)
      assert images.ndim == 4
      assert images.shape[0] == args.batch_size
      assert images.shape[-1] == 3

      # Transmittance loss. Anneal target opacity (1 - transmittance).
      target_opacity = schedule.anneal_logarithmically(
          iteration, args.target_transmittance_anneal_iters,
          1 - args.target_transmittance0, 1 - args.target_transmittance1)

      # The area of an object on the image plane grows with the focal length
      # and shrinks with increasing camera radius. Scale target opacity
      # proportionally with the squared focal multiplier and inversely
      # proportionally with the squared camera radius.
      target_opacities = np.minimum(
          np.ones(args.batch_size), focal_mults[batch_idx]**2 /
          (rads[batch_idx] / 4.)**2 * target_opacity)
      taus = torch.tensor(1 - target_opacities, device=device)
      avg_transmittance = 1 - silhouettes.mean(
          dim=tuple(range(1, silhouettes.ndim)))
      # NOTE(jainajay): Using a modified, two-sided transmittance loss that
      # differs from Dream Fields. It can encourage reducing transmittance if
      # the scene becomes too sparse. The original loss would penalize
      # -torch.mean(torch.min(avg_transmittance, taus)).
      transmittance_loss = torch.mean(torch.abs(avg_transmittance - taus))

      # Data augmentation.
      if (args.diffusion_lam > 0 and
          args.denoise_augmented) or args.clip_lam > 0:
        # NOTE(jainajay): this background is at the render resolution,
        #       not the resized resolution, unlike Dream Fields.
        # Generate random backgrounds.
        bgs = augment.sample_backgrounds(
            num=args.n_aug * args.batch_size,
            res=args.render_size,
            checkerboard_nsq=args.nsq,
            min_blur_std=args.bg_blur_std_range[0],
            max_blur_std=args.bg_blur_std_range[1],
            device=device)

        # Composite renders with backgrounds.
        bgs = bgs.view(args.n_aug, args.batch_size, *bgs.shape[1:])  # ANCHW.
        bgs = bgs.movedim(2, -1)  # Convert ANCHW to ANHWC.
        composite_images = (
            silhouettes[None] * images[None] + (1 - silhouettes[None]) * bgs)
        composite_images = composite_images.reshape(  # to A*N,H,W,C.
            args.n_aug * args.batch_size, args.render_size, args.render_size, 3)
        composite_images = composite_images.movedim(3, 1)  # NHWC to NCHW.

      # Compute GLIDE loss.
      # Sample from the base model.
      if args.diffusion_lam:
        # Preprocess rendering (scale to [-1, 1]).
        if args.denoise_augmented:
          denoise_aug_images = denoise_aug_fn(composite_images)
          inp = preprocess_glide(denoise_aug_images, order="NCHW")
        else:
          inp = silhouettes * images + 1 - silhouettes  # white bg
          inp = preprocess_glide(inp, order="NHWC")

        if (iteration - 1) % args.denoise_every == 0:
          base_glide_model.del_cache()

          # Sampling step for every view in the cache.
          with glide_context_manager():
            assert diffusion_t.dtype == torch.long
            assert torch.all(diffusion_t == diffusion_t[0])
            metrics["diffusion/t"] = diffusion_t[0].item()

            xt = diffusion_x  # || x_hat(x_t) - render ||^2

            # Enable for loss: || x_hat(diffuse(render)) - x_hat(x_t) ||^2
            # x = diffusion.q_sample(
            #     inp, torch.tensor([diffusion_t] * denoise_batch_size,
            #     device=device))

            # Sample x_s from x_t using DDIM.
            # Based on glide-text2im/glide_text2im/gaussian_diffusion.py#L453
            assert args.batch_size == args.n_views  # Updating all chains.
            out = diffusion.p_mean_variance(
                base_model_fn,
                torch.cat([xt, xt], dim=0),
                torch.cat([diffusion_t, diffusion_t], dim=0),
                clip_denoised=True,
                denoised_fn=denoised_fn,  # TODO(jainajay): look into this.
                model_kwargs=base_model_kwargs,
            )
            assert out["pred_xstart"].shape[0] == 2 * args.batch_size
            pred_xstart = out["pred_xstart"][:args.batch_size]

            if iteration < args.independent_sampling_steps * args.denoise_every:
              # Ours: eps = pred_eps(x_t, t, tilde_x).
              # Ours: x_{t-1} = a * tilde_x + b * eps + sigma * noise.
              x0_for_sampling = pred_xstart
            else:
              # GLIDE: eps = pred_eps(x_t, t, x_hat(x_t)).
              # GLIDE: x_{t-1} = a * x_hat(x_t) + b * eps + sigma * noise.
              x0_for_sampling = inp.detach()

            # pylint: disable=protected-access
            eps = diffusion._predict_eps_from_xstart(diffusion_x, diffusion_t,
                                                     x0_for_sampling)
            # pylint: enable=protected-access

            assert eps.shape[0] == args.batch_size

            alpha_bar = _extract_into_tensor(diffusion.alphas_cumprod,
                                             diffusion_t, xt.shape)
            metrics["diffusion/alpha_bar"] = alpha_bar.mean().item()
            alpha_bar_prev = _extract_into_tensor(diffusion.alphas_cumprod_prev,
                                                  diffusion_t, xt.shape)
            metrics["diffusion/alpha_bar_prev"] = alpha_bar_prev.mean().item()
            sigma = (
                args.ddim_eta * torch.sqrt(
                    (1 - alpha_bar_prev) / (1 - alpha_bar)) *
                torch.sqrt(1 - alpha_bar / alpha_bar_prev))
            metrics["diffusion/sigma"] = sigma.mean().item()
            # Equation 12.
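            #   x_{t-1} = sqrt(alpha_bar_prev) * x_0
            #             + sqrt(1 - alpha_bar_prev - sigma**2) * eps
            #             + sigma * noise  (the noise term is dropped at t == 0).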
            mean_pred = (
                x0_for_sampling * torch.sqrt(alpha_bar_prev) +
                torch.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
            nonzero_mask = (
                (diffusion_t != 0).float().view(-1,
                                                *([1] * (len(xt.shape) - 1)))
            )  # No noise when t == 0.
            noise = torch.randn_like(xt)
            sample = mean_pred + nonzero_mask * sigma * noise

            # Update multiview sampling chains.
            diffusion_x_prev = diffusion_x
            diffusion_x = sample
            diffusion_t = diffusion_t - 1

            # Don't backprop through the denoiser (forces stop_grad True).
            assert args.denoise_stop_grad
            pred_xstart = pred_xstart.detach()

          base_glide_model.del_cache()

        # Loss: ||x_hat(x_t) - render||^2.
        # Slicing the predictions only optimizes a few views.
        diffusion_loss = F.mse_loss(pred_xstart[:args.n_optimize],
                                    inp[:args.n_optimize])

        # TODO(jainajay): Try other losses. Some possibilities:
        #   ||x_hat(render) - render||^2 (change L480)
        #   ||x_hat(x_t) - x_hat(diffuse(render))||^2
        #         (change denoising code to denoise render and x_t)
        #   ||eps - eps_hat(diffuse(render), eps)||^2
        #   ||eps_hat(x_t) - eps_hat(diffuse(render), eps)||^2
        #       only makes sense if that's the eps in x_t
        metrics["loss/diffusion_mse"] = diffusion_loss
      else:
        diffusion_loss = torch.tensor([0.], device=device)

      # Compute the CLIP loss.
      if args.clip_lam:
        clip_aug_images = clip_aug_fn(composite_images)
        x = preprocess(clip_aug_images)  # Resize and normalize.
        z_est = model.encode_image(x)
        z_est = F.normalize(z_est, dim=-1)
        clip_loss = -torch.sum(z_est * z_clip, dim=-1).mean()
      else:
        clip_loss = torch.tensor([0.], device=device)

      # Compute total loss and take an optimization step.
      loss = (
          args.clip_lam * clip_loss +
          args.transmittance_lam * transmittance_loss +
          args.diffusion_lam * diffusion_loss)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    if args.track_scene_origin:
      raise NotImplementedError

    # Logging.
    with torch.inference_mode():
      volume_model.eval()

      metrics["train/depths/min"] = depths.min()
      metrics["train/depths/max"] = depths.max()
      metrics["train/disparities/min"] = disparities.min()
      metrics["train/disparities/max"] = disparities.max()

      metrics.update({
          "schedule/lr": lr,
          "loss/total_loss": loss.item(),
          "loss/clip": clip_loss.item(),
          "loss/transmittance": transmittance_loss.item(),
          "train/avg_transmittance": avg_transmittance.mean().item()
      })

      # Print the current values of the losses.
      if iteration % 10 == 0:
        pbar.set_description(
            f"Iteration {iteration:05d}:" +
            f" clip_loss = {float(clip_loss.item()):1.2f}" +
            f" diffusion_loss = {float(diffusion_loss.item()):1.5f}" +
            f" avg transmittance = {float(avg_transmittance.mean().item()):1.2f}"
        )

      # Visualize the renders.
      if visualize_images:
        metrics["render/rendered"] = wandb_grid(images)
        metrics["render/silhouettes"] = wandb_grid(silhouettes)
        metrics["render/rendered_depth"] = wandb_grid(depths)

        if args.clip_lam > 0:
          metrics["render/augmented"] = wandb_grid(clip_aug_images)

        if args.diffusion_lam:
          # Show diffusion_x_prev, diffusion_x (sample), out['pred_xstart'].
          for name, val in zip(["x_t", "x_tm1", "pred_xstart"],
                               [diffusion_x_prev, diffusion_x, pred_xstart]):
            print("diffusion", name, val.shape, val.min(), val.max())
            val = unprocess_glide(val)  # [n_views, C, 64, 64]
            metrics[f"diffusion/{name}"] = wandb_grid(val)

      # Validate from a held-out view.
      if iteration % 250 == 0 or iteration == 1:
        validation_view = render_validation_view(
            volume_model,
            scene_origin,
            test_clip_size,
            args.max_validation_size,
            **render_kwargs)
        assert validation_view.ndim == 3
        assert validation_view.shape[-1] == 3
        metrics["val/render"] = wandb.Image(clamp_and_detach(validation_view))

        rank, cosine_sim = compute_query_rank(
            test_model,
            test_preprocess,
            render=validation_view.movedim(-1, 0).unsqueeze(0),
            query=query,
            queries_r=queries,
            device=device)

        metrics["val/rank"] = rank
        metrics["val/acc"] = int(rank == 0)
        metrics["val/cosine_sim"] = cosine_sim

      if iteration % 250 == 0 or iteration == 1:
        # Visualize the optimized volume by rendering from multiple viewpoints
        # that rotate around the volume's y-axis.
        video_frames = render_rotating_volume(
            volume_model,
            scene_origin=scene_origin,
            video_size=args.video_size,
            n_frames=args.video_n_frames,
            **render_kwargs)

        for name, frames in zip(["rgb", "depth", "disparity", "silhouette"],
                                video_frames):
          # frames is in THWC order.
          filename = f"/tmp/{iteration:05d}_{name}.mp4"
          if frames.shape[-1] == 1:
            media.write_video(filename, frames[Ellipsis, 0], fps=30)
          else:
            media.write_video(filename, frames, fps=30)
          print("wrote", filename,
                f"range: [{frames.min():.4f}, {frames.max():.4f}]")

          metrics[f"render/video/{name}"] = wandb.Video(
              filename, fps=30, format="mp4")

      wandb.log(metrics, iteration)

      volume_model.train()
Example #8
def write_video(video_frames, output_dir, filename, pbar=None, fps=30):
    """Write video to disk."""
    filename = os.path.join(output_dir, filename)
    media.write_video(filename, video_frames, fps=fps)
    if pbar:
        pbar.write(f"Wrote {filename}.")
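A minimal sketch of calling the wrapper above (an illustrative addition, not from the original repository). It assumes mediapy is imported as media; the output directory is hypothetical.

import mediapy as media

frames = media.to_uint8(media.moving_circle((64, 64), num_images=30))
write_video(frames, "/tmp", "circle.mp4")  # No pbar, so nothing is printed.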