def make_optimization_video(directory, pattern="im_*.png", return_frames=False):
  """Concatenate images into a video."""
  paths = glob.glob(os.path.join(directory, pattern))
  paths.sort(key=lambda fname: int(
      os.path.basename(fname)[len("im_"):-len(".png")]))
  images = [media.read_image(path) for path in tqdm.tqdm(paths)]
  if images:
    # Pad frames to the same shape.
    heights, widths = zip(*[im.shape[:2] for im in images])
    max_height, max_width = max(heights), max(widths)
    for i, im in enumerate(images):
      assert im.ndim == 3
      pad_height = max_height - im.shape[0]
      pad_width = max_width - im.shape[1]
      if pad_height != 0 or pad_width != 0:
        images[i] = np.pad(im, ((0, pad_height), (0, pad_width), (0, 0)))

    # Stack and write video.
    images = np.stack(images)[Ellipsis, :3]
    if return_frames:
      return images
    media.write_video(
        os.path.join(directory, "optimization.mp4"), images, fps=30)
  else:
    print(
        f"WARNING: No images to combine in {directory} for optimization"
        " history.")

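# A minimal usage sketch for make_optimization_video above, assuming the
# module-level imports the function relies on (glob, os, numpy as np,
# mediapy as media, tqdm) are present. The directory and frame contents are
# synthetic placeholders, not outputs of this codebase.
def _demo_make_optimization_video():
  """Illustrative only: builds a tiny frame directory and collects frames."""
  import tempfile  # Local import keeps the demo self-contained.
  with tempfile.TemporaryDirectory() as directory:
    for i in range(4):
      # Frames may have different sizes; make_optimization_video pads them.
      frame = np.full((32, 48 + 4 * i, 3), 60 * i, dtype=np.uint8)
      media.write_image(os.path.join(directory, f"im_{i}.png"), frame)
    frames = make_optimization_video(directory, return_frames=True)
    print(frames.shape)  # (4, 32, 60, 3) after padding to the largest frame.
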
def test_video_read_write_odd_dimensions(self):
  video = media.moving_circle((35, 97), num_images=4, dtype=np.uint8)
  fps = 60
  bps = 40_000_000
  with tempfile.TemporaryDirectory() as directory_name:
    path = pathlib.Path(directory_name) / 'test5.mp4'
    media.write_video(path, video, fps=fps, bps=bps)
    new_video = media.read_video(path)
    self.assertEqual(new_video.dtype, np.uint8)
    self._check_similar(video, new_video, max_rms=5.0)

def test_video_non_streaming_write_read_roundtrip(self, use_generator):
  shape = (240, 320)
  num_images = 10
  fps = 40
  qp = 20
  original_video = media.to_uint8(media.moving_circle(shape, num_images))
  video = (
      image for image in original_video) if use_generator else original_video
  with tempfile.TemporaryDirectory() as directory_name:
    filename = os.path.join(directory_name, 'test.mp4')
    media.write_video(filename, video, fps=fps, qp=qp)
    new_video = media.read_video(filename)
    self._check_similar(original_video, new_video, 3.0)

def test_video_streaming_read_write(self):
  shape = (400, 400)
  num_images = 4
  fps = 60
  bps = 40_000_000
  video = media.to_uint8(media.moving_circle(shape, num_images))
  with tempfile.TemporaryDirectory() as directory_name:
    filename1 = os.path.join(directory_name, 'test1.mp4')
    filename2 = os.path.join(directory_name, 'test2.mp4')
    media.write_video(filename1, video, fps=fps, bps=bps)
    with media.VideoReader(filename1) as reader:
      with media.VideoWriter(
          filename2, reader.shape, fps=reader.fps, bps=reader.bps,
          encoded_format='yuv420p') as writer:
        for image in reader:
          writer.add_image(image)
    new_video = media.read_video(filename2)
    self._check_similar(video, new_video, 3.0)

def test_video_read_write_10bit(self):
  shape = (256, 256)
  num_images = 4
  fps = 60
  bps = 40_000_000
  horizontal_gray_ramp = media.to_type(
      np.indices(shape)[1] / shape[1], np.uint16)
  video = np.broadcast_to(horizontal_gray_ramp, (num_images, *shape))
  with tempfile.TemporaryDirectory() as directory_name:
    filename = os.path.join(directory_name, 'test3.mp4')
    media.write_video(filename, video, fps=fps, bps=bps,
                      encoded_format='yuv420p10le')
    new_video = media.read_video(filename, dtype=np.uint16,
                                 output_format='gray')
    self.assertEqual(new_video.dtype, np.uint16)
    value_1_of_10bit_encoded_in_16bits = 64
    self._check_similar(video, new_video,
                        max_rms=value_1_of_10bit_encoded_in_16bits * 0.8)

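# Note on the constant above: one 10-bit code value stored in a uint16
# container corresponds to roughly 2**(16 - 10) = 64 sixteen-bit units, so an
# RMS tolerance of 0.8 * 64 stays below a single 10-bit quantization step.
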
def log_video(writer, video, tb_key, name, step, work_unit_dir,
              save_raw=False, scale=False):
  """Save video frames to tensorboard and a file."""
  video_raw = video
  if scale:
    video = scale_depth(video)

  if writer is not None:
    logging.info('Logging video frames')
    writer.write_images(step, {f'{tb_key}/{name}': make_image_grid(video)})

  filename = f'{tb_key}_{name}_{step:05d}.mp4'
  local_path = os.path.join('/tmp', filename)
  logging.info('Writing video to %s', local_path)
  media.write_video(local_path, video, fps=30)

  wu_path = os.path.join(work_unit_dir, filename)
  logging.info('Copying video to %s', wu_path)
  gfile.copy(local_path, wu_path, overwrite=True)
  gfile.remove(local_path)

  if save_raw:
    # Save raw floating point values to scale depth properly.
    raw_filename = f'{tb_key}_{name}_{step:05d}.npy'
    raw_path = os.path.join(work_unit_dir, raw_filename)
    logging.info('Saving raw video to %s', raw_path)
    with gfile.GFile(raw_path, 'wb') as raw_f:
      onp.save(raw_f, video_raw)

  logging.info('Done logging video.')

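# A minimal usage sketch for log_video, assuming the module-level os, logging,
# gfile, mediapy (media), and onp (numpy) imports used by the function are
# available. Passing writer=None skips the tensorboard image grid and
# scale=False skips scale_depth, so only the mp4 and the raw .npy are written.
# The work-unit directory and video contents below are placeholders.
def _demo_log_video():
  """Illustrative only: logs a synthetic THWC video to a local directory."""
  work_unit_dir = '/tmp/log_video_demo'  # Hypothetical work-unit directory.
  os.makedirs(work_unit_dir, exist_ok=True)
  video = onp.random.rand(30, 64, 64, 3).astype(onp.float32)  # THWC in [0, 1].
  log_video(writer=None, video=video, tb_key='val', name='rgb', step=0,
            work_unit_dir=work_unit_dir, save_raw=True, scale=False)
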
def main():
  # Reproducibility.
  torch.manual_seed(args.seed)
  random.seed(args.seed)
  np.random.seed(args.seed)

  # Obtain the utilized device.
  if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    has_cuda = True
  else:
    device = torch.device("cpu")
    has_cuda = False

  # Set up logging.
  query = queries[args.query_idx]
  exp_name = f"{args.exp_name_prefix}_{args.query_idx:03d}_{query}"
  wandb.init(
      entity=args.wandb_entity,
      project=args.wandb_project,
      name=exp_name,
      config=args)

  # Initialize CLIP.
  if args.clip_lam:
    model, preprocess, clip_size = load_clip(args.loss_model, device)
    model.eval()
  if args.retrieve_model == args.loss_model and args.clip_lam:
    test_model, test_preprocess, test_clip_size = model, preprocess, clip_size
  else:
    test_model, test_preprocess, test_clip_size = load_clip(
        args.retrieve_model, device)
  test_model.eval()

  # Initialize the volumetric model.
  volume_model = nerf.DreamFieldsMLP(
      activation="SiLU",
      features_early=[96],  # Dense layers before residual blocks.
      features_residual=[(128, 96)] * 3,  # Residual block feature dimensions.
      features_late=[96, 4],  # Feature dimensions at the end.
      fourfeat=args.fourfeat,
      max_deg=args.posenc_deg,
      ipe=args.ipe,
  )
  volume_model = nn.DataParallel(volume_model)
  volume_model = volume_model.to(device)
  scene_origin = scene.EMA(np.zeros(3, dtype=np.float64), decay=0.999)
  render_kwargs = dict(
      sigma_noise_std=args.sigma_noise_std,
      near=4. - math.sqrt(3) * args.volume_extent_world / 2,
      far=4. + math.sqrt(3) * args.volume_extent_world / 2,
      mask_rad=args.volume_extent_world / 2,
      n_pts_per_ray=args.n_pts_per_ray,
      device=device,
  )

  # Instantiate the Adam optimizer.
  optimizer = torch.optim.Adam(
      volume_model.parameters(), lr=args.lr_init, eps=args.adam_eps)
  scaler = torch.cuda.amp.GradScaler()

  # Embed the target caption with CLIP.
  if args.clip_lam:
    query_tok = clip.tokenize(query).to(device)
    z_clip = model.encode_text(query_tok).detach()
    z_clip = F.normalize(z_clip, dim=-1)

    clip_aug_fn = torchvision.transforms.RandomResizedCrop(
        clip_size, scale=args.crop_scale_range, ratio=(1.0, 1.0))

  if args.diffusion_lam:
    # Initialize GLIDE. Create the base model.
    base_glide_model, diffusion, base_glide_options = load_diffusion(
        "base", device, has_cuda=has_cuda)
    base_glide_model.eval()

    # Embed the target caption with GLIDE.
    denoise_batch_size = (
        args.n_aug * args.n_views if args.denoise_augmented else args.n_views)
    tokens = base_glide_model.tokenizer.encode(query)
    tokens, mask = base_glide_model.tokenizer.padded_tokens_and_mask(
        tokens, base_glide_options["text_ctx"])

    # Create the classifier-free guidance tokens (empty).
    uncond_tokens, uncond_mask = base_glide_model.tokenizer.padded_tokens_and_mask(
        [], base_glide_options["text_ctx"])

    # Pack the tokens together into model kwargs.
    base_model_kwargs = dict(
        tokens=torch.tensor(
            [tokens] * denoise_batch_size +
            [uncond_tokens] * denoise_batch_size,
            device=device),
        mask=torch.tensor(
            [mask] * denoise_batch_size + [uncond_mask] * denoise_batch_size,
            dtype=torch.bool,
            device=device),
    )
    parallel_glide = nn.DataParallel(base_glide_model)

    # Create a classifier-free guidance sampling function.
    def base_model_fn(x_t, ts, **kwargs):
      half = x_t[:len(x_t) // 2]
      combined = torch.cat([half, half], dim=0)
      model_out = parallel_glide(combined, ts, **kwargs)
      eps, rest = model_out[:, :3], model_out[:, 3:]
      cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
      half_eps = uncond_eps + args.guidance_scale * (cond_eps - uncond_eps)
      eps = torch.cat([half_eps, half_eps], dim=0)
      return torch.cat([eps, rest], dim=1)

    def preprocess_glide(x, order="NHWC"):
      if order == "NHWC":
        # x is [NHWC]. Reshape to NCHW.
        x = x.movedim(-1, 1)
      x = x * 2 - 1  # Scale from [0, 1] to [-1, 1].
      if x.shape[-2:] != (64, 64):
        x = F.interpolate(x, (64, 64), mode="bilinear")
      return x

    def unprocess_glide(x):
      return (x + 1) / 2  # Scale from [-1, 1] to [0, 1].

    denoised_fn = lambda x_start: x_start
    denoise_aug_fn = torchvision.transforms.RandomResizedCrop(
        64, scale=args.crop_scale_range, ratio=(1.0, 1.0))
    glide_context_manager = (
        torch.no_grad if args.denoise_stop_grad else torch.enable_grad)

    # Initialize each chain.
    diffusion_x = torch.randn((args.n_views, 3, 64, 64),
                              device=device,
                              requires_grad=False)
    diffusion_t = torch.full(
        size=(args.n_views,),
        fill_value=args.t_respace - 1,
        requires_grad=False,
        dtype=torch.long,
        device=device)

    # Training uses n_iter iterations: 1 to n_iter (inclusive).
    # Diffusion uses t_respace timesteps: t_respace-1 to 0 (inclusive).
    # For now, check they are equal.
    # TODO(jainajay): implement sampling with non-unit timesteps.
    assert args.t_respace * args.denoise_every == args.n_iter

  # Get a batch of viewing angles and pre-generate rays.
  azimuths = np.arange(args.n_views) * 360. / args.n_views
  rads = np.full(args.n_views, 4.)
  focal_mults = np.full(args.n_views, 1.2)
  elevations = [
      scene.uniform_in_interval(args.elevation_range)
      for _ in range(args.n_views)
  ]
  cam2worlds = [
      scene.pose_spherical(azim, phi=elev, radius=rad)
      for azim, elev, rad in zip(azimuths, elevations, rads)
  ]
  height, width, focal = scene.scale_intrinsics(args.render_size)

  # Generate rays: 3-tuple of [n_views, H, W, n_pts_per_ray, 3 or 1].
  rays_all_views = scene.camera_rays_batched(cam2worlds, height, width,
                                             focal_mults * focal)

  pbar = tqdm.trange(1, args.n_iter + 1)
  for iteration in pbar:
    metrics = {}
    visualize_images = iteration % 25 == 0 or iteration == 1

    # Set learning rate.
    lr = schedule.learning_rate_decay(
        iteration,
        args.lr_init,
        args.lr_final,
        args.n_iter,
        lr_delay_steps=min(args.n_iter // 8, 2500),
        lr_delay_mult=args.lr_delay_mult)
    for g in optimizer.param_groups:
      g["lr"] = float(lr)

    # Zero the optimizer gradient.
    optimizer.zero_grad()

    # Render the volumetric model from random perspectives.
    batch_idx = np.random.choice(
        args.n_views, size=args.batch_size, replace=False)
    rays_batched = [r[batch_idx] for r in rays_all_views]

    # Run the forward pass with automatic precision casting.
    with torch.cuda.amp.autocast():
      (images, depths, disparities, silhouettes), _ = nerf.render_rays_mip(
          rays_batched, volume_model, origin=scene_origin.value,
          **render_kwargs)
      assert images.ndim == 4
      assert images.shape[0] == args.batch_size
      assert images.shape[-1] == 3

      # Transmittance loss. Anneal target opacity (1 - transmittance).
      target_opacity = schedule.anneal_logarithmically(
          iteration, args.target_transmittance_anneal_iters,
          1 - args.target_transmittance0, 1 - args.target_transmittance1)

      # The area of an object on the image plane grows with the focal length
      # and shrinks with increasing camera radius. Scale target opacity
      # proportionally with the squared focal multiplier and inversely
      # proportionally with the squared camera radius.
      target_opacities = np.minimum(
          np.ones(args.batch_size),
          focal_mults[batch_idx]**2 / (rads[batch_idx] / 4.)**2 *
          target_opacity)
      taus = torch.tensor(1 - target_opacities, device=device)
      avg_transmittance = 1 - silhouettes.mean(
          dim=tuple(range(1, silhouettes.ndim)))
      # NOTE(jainajay): Using a modified, two-sided transmittance loss that
      # differs from Dream Fields. It can encourage reducing transmittance if
      # the scene becomes too sparse. The original loss would penalize
      # -torch.mean(torch.min(avg_transmittance, taus)).
      transmittance_loss = torch.mean(torch.abs(avg_transmittance - taus))

      # Data augmentation.
      if ((args.diffusion_lam > 0 and args.denoise_augmented) or
          args.clip_lam > 0):
        # NOTE(jainajay): this background is at the render resolution,
        # not the resize, unlike Dream Fields.
        # Generate random backgrounds.
        bgs = augment.sample_backgrounds(
            num=args.n_aug * args.batch_size,
            res=args.render_size,
            checkerboard_nsq=args.nsq,
            min_blur_std=args.bg_blur_std_range[0],
            max_blur_std=args.bg_blur_std_range[1],
            device=device)

        # Composite renders with backgrounds.
        bgs = bgs.view(args.n_aug, args.batch_size, *bgs.shape[1:])  # ANCHW.
        bgs = bgs.movedim(2, -1)  # Convert ANCHW to ANHWC.
        composite_images = (
            silhouettes[None] * images[None] + (1 - silhouettes[None]) * bgs)
        composite_images = composite_images.reshape(  # To A*N,H,W,C.
            args.n_aug * args.batch_size, args.render_size, args.render_size,
            3)
        composite_images = composite_images.movedim(3, 1)  # NHWC to NCHW.

      # Compute GLIDE loss. Sample from the base model.
      if args.diffusion_lam:
        # Preprocess rendering (scale to [-1, 1]).
        if args.denoise_augmented:
          denoise_aug_images = denoise_aug_fn(composite_images)
          inp = preprocess_glide(denoise_aug_images, order="NCHW")
        else:
          inp = silhouettes * images + 1 - silhouettes  # white bg
          inp = preprocess_glide(inp, order="NHWC")

        if (iteration - 1) % args.denoise_every == 0:
          base_glide_model.del_cache()

          # Sampling step for every view in the cache.
          with glide_context_manager():
            assert diffusion_t.dtype == torch.long
            assert torch.all(diffusion_t == diffusion_t[0])
            metrics["diffusion/t"] = diffusion_t[0].item()

            xt = diffusion_x
            # || x_hat(x_t) - render ||^2
            # Enable for loss: || x_hat(diffuse(render)) - x_hat(x_t) ||^2
            # x = diffusion.q_sample(
            #     inp, torch.tensor([diffusion_t] * denoise_batch_size,
            #                       device=device))

            # Sample x_s from x_t using DDIM.
            # Based on glide-text2im/glide_text2im/gaussian_diffusion.py#L453
            assert args.batch_size == args.n_views  # Updating all chains.
            out = diffusion.p_mean_variance(
                base_model_fn,
                torch.cat([xt, xt], dim=0),
                torch.cat([diffusion_t, diffusion_t], dim=0),
                clip_denoised=True,
                denoised_fn=denoised_fn,  # TODO(jainajay): look into this.
                model_kwargs=base_model_kwargs,
            )
            assert out["pred_xstart"].shape[0] == 2 * args.batch_size
            pred_xstart = out["pred_xstart"][:args.batch_size]

            if (iteration <
                args.independent_sampling_steps * args.denoise_every):
              # Ours: eps = pred_eps(x_t, t, tilde_x).
              # Ours: x_{t-1} = a * tilde_x + b * eps + sigma * noise.
              x0_for_sampling = pred_xstart
            else:
              # GLIDE: eps = pred_eps(x_t, t, x_hat(x_t)).
              # GLIDE: x_{t-1} = a * x_hat(x_t) + b * eps + sigma * noise.
              x0_for_sampling = inp.detach()

            # pylint: disable=protected-access
            eps = diffusion._predict_eps_from_xstart(diffusion_x, diffusion_t,
                                                     x0_for_sampling)
            # pylint: enable=protected-access
            assert eps.shape[0] == args.batch_size

            alpha_bar = _extract_into_tensor(
                diffusion.alphas_cumprod, diffusion_t, xt.shape)
            metrics["diffusion/alpha_bar"] = alpha_bar.mean().item()
            alpha_bar_prev = _extract_into_tensor(
                diffusion.alphas_cumprod_prev, diffusion_t, xt.shape)
            metrics["diffusion/alpha_bar_prev"] = alpha_bar_prev.mean().item()
            sigma = (
                args.ddim_eta * torch.sqrt(
                    (1 - alpha_bar_prev) / (1 - alpha_bar)) *
                torch.sqrt(1 - alpha_bar / alpha_bar_prev))
            metrics["diffusion/sigma"] = sigma.mean().item()

            # Equation 12 (DDIM update):
            #   x_{t-1} = sqrt(alpha_bar_prev) * x_0
            #             + sqrt(1 - alpha_bar_prev - sigma^2) * eps
            #             + sigma * noise.
            mean_pred = (
                x0_for_sampling * torch.sqrt(alpha_bar_prev) +
                torch.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
            nonzero_mask = (
                (diffusion_t != 0).float().view(
                    -1, *([1] * (len(xt.shape) - 1)))
            )  # No noise when t == 0.
            noise = torch.randn_like(xt)
            sample = mean_pred + nonzero_mask * sigma * noise

            # Update multiview sampling chains.
            diffusion_x_prev = diffusion_x
            diffusion_x = sample
            diffusion_t = diffusion_t - 1

            # Don't backprop through the denoiser (forces stop_grad True).
            assert args.denoise_stop_grad
            pred_xstart = pred_xstart.detach()

          base_glide_model.del_cache()

        # Loss: ||x_hat(x_t) - render||^2.
        # Slicing the predictions only optimizes a few views.
        diffusion_loss = F.mse_loss(pred_xstart[:args.n_optimize],
                                    inp[:args.n_optimize])
        # TODO(jainajay): Try other losses. Some possibilities:
        #   ||x_hat(render) - render||^2 (change L480)
        #   ||x_hat(x_t) - x_hat(diffuse(render))||^2
        #     (change denoising code to denoise render and x_t)
        #   ||eps - eps_hat(diffuse(render), eps)||^2
        #   ||eps_hat(x_t) - eps_hat(diffuse(render), eps)||^2
        #     (only makes sense if that is the eps in x_t)
        metrics["loss/diffusion_mse"] = diffusion_loss
      else:
        diffusion_loss = torch.tensor([0.], device=device)

      # Compute the CLIP loss.
      if args.clip_lam:
        clip_aug_images = clip_aug_fn(composite_images)
        x = preprocess(clip_aug_images)  # Resize and normalize.
        z_est = model.encode_image(x)
        z_est = F.normalize(z_est, dim=-1)
        clip_loss = -torch.sum(z_est * z_clip, dim=-1).mean()
      else:
        clip_loss = torch.tensor([0.], device=device)

      # Compute the total loss.
      loss = (
          args.clip_lam * clip_loss +
          args.transmittance_lam * transmittance_loss +
          args.diffusion_lam * diffusion_loss)

    # Take an optimization step with gradient scaling for mixed precision.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    if args.track_scene_origin:
      raise NotImplementedError

    # Logging.
    with torch.inference_mode():
      volume_model.eval()

      metrics["train/depths/min"] = depths.min()
      metrics["train/depths/max"] = depths.max()
      metrics["train/disparities/min"] = disparities.min()
      metrics["train/disparities/max"] = disparities.max()
      metrics.update({
          "schedule/lr": lr,
          "loss/total_loss": loss.item(),
          "loss/clip": clip_loss.item(),
          "loss/transmittance": transmittance_loss.item(),
          "train/avg_transmittance": avg_transmittance.mean().item()
      })

      # Print the current values of the losses.
      if iteration % 10 == 0:
        pbar.set_description(
            f"Iteration {iteration:05d}:" +
            f" clip_loss = {float(clip_loss.item()):1.2f}" +
            f" diffusion_loss = {float(diffusion_loss.item()):1.5f}" +
            f" avg transmittance = "
            f"{float(avg_transmittance.mean().item()):1.2f}")

      # Visualize the renders.
      if visualize_images:
        metrics["render/rendered"] = wandb_grid(images)
        metrics["render/silhouettes"] = wandb_grid(silhouettes)
        metrics["render/rendered_depth"] = wandb_grid(depths)
        if args.clip_lam > 0:
          metrics["render/augmented"] = wandb_grid(clip_aug_images)

        if args.diffusion_lam:
          # Show diffusion_x_prev, diffusion_x (sample), out['pred_xstart'].
          for name, val in zip(["x_t", "x_tm1", "pred_xstart"],
                               [diffusion_x_prev, diffusion_x, pred_xstart]):
            print("diffusion", name, val.shape, val.min(), val.max())
            val = unprocess_glide(val)  # [n_views, C, 64, 64].
            metrics[f"diffusion/{name}"] = wandb_grid(val)

      # Validate from a held-out view.
      if iteration % 250 == 0 or iteration == 1:
        validation_view = render_validation_view(
            volume_model, scene_origin, test_clip_size,
            args.max_validation_size, **render_kwargs)
        assert validation_view.ndim == 3
        assert validation_view.shape[-1] == 3
        metrics["val/render"] = wandb.Image(clamp_and_detach(validation_view))

        rank, cosine_sim = compute_query_rank(
            test_model,
            test_preprocess,
            render=validation_view.movedim(-1, 0).unsqueeze(0),
            query=query,
            queries_r=queries,
            device=device)
        metrics["val/rank"] = rank
        metrics["val/acc"] = int(rank == 0)
        metrics["val/cosine_sim"] = cosine_sim

      if iteration % 250 == 0 or iteration == 1:
        # Visualize the optimized volume by rendering from multiple viewpoints
        # that rotate around the volume's y-axis.
        video_frames = render_rotating_volume(
            volume_model, scene_origin=scene_origin,
            video_size=args.video_size, n_frames=args.video_n_frames,
            **render_kwargs)

        for name, frames in zip(["rgb", "depth", "disparity", "silhouette"],
                                video_frames):
          # frames is in THWC order.
          filename = f"/tmp/{iteration:05d}_{name}.mp4"
          if frames.shape[-1] == 1:
            media.write_video(filename, frames[Ellipsis, 0], fps=30)
          else:
            media.write_video(filename, frames, fps=30)
          print("wrote", filename,
                f"range: [{frames.min():.4f}, {frames.max():.4f}]")
          metrics[f"render/video/{name}"] = wandb.Video(
              filename, fps=30, format="mp4")

      wandb.log(metrics, iteration)
      volume_model.train()

def write_video(video_frames, output_dir, filename, pbar=None, fps=30):
  """Write video to disk."""
  filename = os.path.join(output_dir, filename)
  media.write_video(filename, video_frames, fps=fps)
  if pbar:
    pbar.write(f"Wrote {filename}.")

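# A minimal usage sketch for the write_video helper above, assuming numpy as
# np, os, and mediapy as media are already imported at module level. The frame
# contents and output directory are placeholders.
def _demo_write_video():
  """Illustrative only: writes a short synthetic clip with write_video."""
  import tempfile  # Local import keeps the demo self-contained.
  # 30 frames of 64x64 RGB whose red channel ramps from 0 to 255 over time.
  frames = np.zeros((30, 64, 64, 3), dtype=np.uint8)
  frames[..., 0] = np.linspace(0, 255, 30, dtype=np.uint8)[:, None, None]
  with tempfile.TemporaryDirectory() as output_dir:
    write_video(frames, output_dir, "demo.mp4", fps=30)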