def gt_state_and_latents(
    self,
    params: hk.Params,
    rng: jnp.ndarray,
    inputs: Dict[str, jnp.ndarray],
    seq_length: int,
    is_training: bool = False,
    unroll_direction: str = "forward",
    **kwargs: Dict[str, Any]
) -> Tuple[jnp.ndarray, jnp.ndarray, Union[distrax.Distribution, jnp.ndarray]]:
    """Computes the ground state and matching latents.

    Latents are inferred from the first ``self.num_inference_steps`` frames and
    then unrolled forward over the remaining
    ``images.shape[1] - self.num_inference_steps`` frames. Only forward
    unrolling is supported.

    Args:
      params: The model parameters.
      rng: Random key; split into 6 keys for ``self._models_core``.
      inputs: The input batch; images and ground-truth state are read from it
        with ``utils.extract_image`` and ``utils.extract_gt_state``.
      seq_length: Number of ground-truth state steps to return.
      is_training: Unused; the core model is always run with
        ``is_training=False`` here.
      unroll_direction: Must be ``"forward"``.
      **kwargs: Unused.

    Returns:
      A tuple ``(gt_state, z_out, z_in)``. ``gt_state`` is the ground-truth
      state sliced to steps ``1 .. seq_length`` (inclusive). ``z_in`` and
      ``z_out`` are the second and third outputs of ``self._models_core``.

    Raises:
      ValueError: If ``unroll_direction`` is not ``"forward"``.
    """
    # An explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would silently return forward results for any
    # requested direction.
    if unroll_direction != "forward":
        raise ValueError(
            "This model can only be unrolled forward, got "
            f"unroll_direction={unroll_direction!r}.")
    images = utils.extract_image(inputs)
    gt_state = utils.extract_gt_state(inputs)
    image_data = images[:, :self.num_inference_steps]
    # NOTE(review): the ground truth starts at index 1, which presumably pairs
    # with `include_z0=False` below under the assumption that
    # num_inference_steps == 1 -- confirm.
    gt_state = gt_state[:, 1:seq_length + 1]
    _, z_in, z_out = self._models_core(
        params=params,
        keys=jnr.split(rng, 6),
        image_data=image_data,
        is_training=False,
        num_steps_forward=images.shape[1] - self.num_inference_steps,
        num_steps_backward=0,
        include_z0=False,
    )
    return gt_state, z_out, z_in
def gt_state_and_latents(
    self,
    params: hk.Params,
    rng: jnp.ndarray,
    inputs: Dict[str, jnp.ndarray],
    seq_length: int,
    is_training: bool = False,
    unroll_direction: str = "forward",
    **kwargs: Dict[str, Any]
) -> Tuple[jnp.ndarray, jnp.ndarray, Union[distrax.Distribution, jnp.ndarray]]:
    """Computes the ground state and matching latents.

    Forward: latents are inferred from the first ``self.num_inference_steps``
    frames. Backward: they are inferred from the ``self.num_inference_steps``
    frames ending at ``seq_length``. In both cases the model core is run with
    ``use_mean=True``, ``is_training=False`` and ``include_z0=True``. The
    forward/backward step counts are chosen so that they sum to
    ``seq_length - 1``.

    Args:
      params: The model parameters.
      rng: Random key; split into 6 keys for ``self._models_core``.
      inputs: The input batch; images and ground-truth state are read from it.
      seq_length: Number of ground-truth state steps to return.
      is_training: Unused; the core model always runs with
        ``is_training=False`` here.
      unroll_direction: Either ``"forward"`` or ``"backward"``.
      **kwargs: Unused.

    Returns:
      ``(gt_state, z, z0)`` when ``self.has_latent_transform`` is set,
      otherwise ``(gt_state, z, q_z)``. ``q_z``, ``z0`` and ``z`` are the
      second, fourth and fifth outputs of ``self._models_core``.

    Raises:
      ValueError: If ``"backward"`` is requested but the model cannot run
        backwards.
    """
    assert unroll_direction in ("forward", "backward")
    if unroll_direction == "backward" and not self.can_run_backwards:
        raise ValueError("This model can not be unrolled backwards.")

    images = utils.extract_image(inputs)
    all_gt_state = utils.extract_gt_state(inputs)

    if unroll_direction == "forward":
        image_data = images[:, :self.num_inference_steps]
        # A backwards-capable model unrolls `inferred_index` steps back to the
        # first frame, so its ground truth starts at frame 0. Otherwise only
        # the forward unroll is used and the ground truth starts at the
        # inferred index.
        steps_back = self.inferred_index if self.can_run_backwards else 0
        first_gt = 0 if self.can_run_backwards else self.inferred_index
        steps_fwd = seq_length - steps_back - 1
        gt_state = all_gt_state[:, first_gt:first_gt + seq_length]
    elif unroll_direction == "backward":
        first_frame = seq_length - self.num_inference_steps
        image_data = images[:, first_frame:seq_length]
        steps_fwd = self.num_inference_steps - self.inferred_index - 1
        steps_back = seq_length - steps_fwd - 1
        gt_state = all_gt_state[:, :seq_length]
    else:
        raise NotImplementedError()

    core_outputs = self._models_core(
        params=params,
        keys=jnr.split(rng, 6),
        image_data=image_data,
        use_mean=True,
        is_training=False,
        num_steps_forward=steps_fwd,
        num_steps_backward=steps_back,
        include_z0=True,
    )
    _, q_z, _, z0, z, _ = core_outputs
    third = z0 if self.has_latent_transform else q_z
    return gt_state, z, third
def training_objectives(
    self,
    params: hk.Params,
    state: hk.State,
    rng: jnp.ndarray,
    inputs: jnp.ndarray,
    step: jnp.ndarray,
    is_training: bool = True,
    use_mean_for_eval_stats: bool = True
) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
    """Computes the training objective and any supporting stats.

    The loss is the negative log-likelihood (``stats["neg_log_p_x"]``) of
    the target frames under the model's predicted distribution.

    Args:
      params: The model parameters.
      state: Unused by this model.
      rng: Random key. Split into 6 keys for the model core; the unsplit key
        is forwarded to the evaluation statistics.
      inputs: The input batch; images are read with ``utils.extract_image``.
      step: Unused by this model.
      is_training: When ``False``, evaluation-only statistics are added.
      use_mean_for_eval_stats: Forwarded as ``use_mean`` to
        ``self.reconstruct`` for the evaluation statistics.

    Returns:
      ``(loss, (new_state, stats, other_stats))``, where ``new_state`` and
      ``other_stats`` are always empty dicts for this model.
    """
    images = utils.extract_image(inputs)
    image_data, target_data, unroll_kwargs = self.train_data_split(images)
    p_x, _, _ = self._models_core(
        params=params,
        keys=jnr.split(rng, 6),
        image_data=image_data,
        is_training=is_training,
        **unroll_kwargs)

    learned_sigma = self.decoder_kwargs.get("learned_sigma", False)
    stats = metrics.training_statistics(
        p_x=p_x,
        targets=target_data,
        rescale_by=self.rescale_by,
        p_x_learned_sigma=learned_sigma)
    # The objective is just the negative log-likelihood (e.g. the L2 loss).
    stats["loss"] = stats["neg_log_p_x"]

    if not is_training:
        # `use_mean` must be configurable (some tests set it to False).
        eval_reconstruct = functools.partial(
            self.reconstruct, use_mean=use_mean_for_eval_stats)
        # NOTE(review): reconstruction_skip=1 lines up with `self.reconstruct`
        # only if it predicts frames starting at index 1, i.e. presumably when
        # num_inference_steps == 1 -- confirm.
        eval_stats = metrics.evaluation_only_statistics(
            reconstruct_func=eval_reconstruct,
            params=params,
            inputs=inputs,
            rng=rng,
            rescale_by=self.rescale_by,
            can_run_backwards=self.can_run_backwards,
            train_sequence_length=self.train_sequence_length,
            reconstruction_skip=1,
            p_x_learned_sigma=learned_sigma)
        stats.update(eval_stats)

    return stats["loss"], ({}, stats, {})
def init(
    self,
    rng: jnp.ndarray,
    inputs_or_shape: Union[jnp.ndarray, Mapping[str, jnp.ndarray], Sequence[int]],
) -> Tuple[utils.Params, hk.State]:
    """Initializes the whole model parameters and state.

    Args:
      rng: Random key forwarded to ``self._init``.
      inputs_or_shape: Either a shape, i.e. a tuple or list whose first entry
        is an ``int`` (zero images of that shape are used), or inputs from
        which the images are read with ``utils.extract_image``.

    Returns:
      The result of the jitted ``self._init(rng, images)``.
    """
    looks_like_shape = (
        isinstance(inputs_or_shape, (tuple, list))
        and isinstance(inputs_or_shape[0], int))
    images = (jnp.zeros(inputs_or_shape) if looks_like_shape
              else utils.extract_image(inputs_or_shape))
    # `_init` is jit-compiled lazily on first use and cached on the instance.
    if self._jit_init is None:
        self._jit_init = jax.jit(self._init)
    return self._jit_init(rng, images)
def _eval_batch_vpt(self, params, state, rng_key, batch):
    """Computes valid-prediction-time (VPT) scores for one evaluation batch.

    Each unroll direction is evaluated: ``"forward"``, plus ``"backward"``
    when ``self.model.can_run_backwards``. For each direction, the
    extrapolated reconstruction and aligned ground truth are obtained from
    ``self._reconstruct_and_align``.

    A normalized MSE is then computed: the squared error averaged over axes
    3-5, divided by the mean squared ground truth over the same axes. For each
    index along axis 1, the VPT is the first time index (axis 2) at which this
    normalized MSE exceeds ``self.config.evaluation_vpt.vpt_threshold``. If it
    never does, the VPT is the full time length.

    Args:
      params: Unused here.
      state: Unused here.
      rng_key: Random key forwarded to ``self._reconstruct_and_align``.
      batch: Evaluation batch; images are read with ``utils.extract_image``.

    Returns:
      A dict containing ``stats[prefix]`` (the median ``"vpt_abs"`` and the
      median divided by the sequence length, ``"vpt_rel"``) for each
      direction. It also holds ``"vpt_abs"`` and ``"vpt_rel"`` averaged over
      directions.
    """
    full_trajectory = utils.extract_image(batch)
    if self.model.can_run_backwards:
        prefixes = ("forward", "backward")
    else:
        prefixes = ("forward",)
    threshold = self.config.evaluation_vpt.vpt_threshold

    stats = dict()
    vpt_abs_scores = []
    vpt_rel_scores = []
    seq_length = None
    for prefix in prefixes:
        reconstruction, gt_images = self._reconstruct_and_align(
            rng_key, full_trajectory, prefix, "extrapolation")
        seq_length = gt_images.shape[2]
        # Reduction over axes 3-5 (presumably height, width, channels --
        # confirm) leaves one normalized error per (axis 0, axis 1, time).
        mse_norm = (
            np.mean((gt_images - reconstruction) ** 2, axis=(3, 4, 5))
            / np.mean(gt_images ** 2, axis=(3, 4, 5)))

        vpt_scores = []
        for i in range(mse_norm.shape[1]):
            # np.argwhere lists indices in row-major order, so the first row is
            # the earliest exceeding time step within the first axis-0 slice
            # that exceeds the threshold at all.
            # NOTE(review): if axis 0 has size > 1 (e.g. a device axis --
            # confirm), the other axis-0 slices do not contribute here.
            exceeding = np.argwhere(mse_norm[:, i:i + 1, :] > threshold)
            if exceeding.shape[0] > 0:
                vpt_scores.append(exceeding[0][2])
            else:
                vpt_scores.append(mse_norm.shape[-1])

        median_vpt = np.median(vpt_scores)
        vpt_abs_scores.append(median_vpt)
        vpt_rel_scores.append(median_vpt / seq_length)
        scores = utils.to_numpy({
            "vpt_abs": vpt_abs_scores[-1],
            "vpt_rel": vpt_rel_scores[-1],
        })
        stats[prefix] = utils.filter_only_scalar_stats(scores)

    stats["vpt_abs"] = utils.to_numpy(np.mean(vpt_abs_scores))
    stats["vpt_rel"] = utils.to_numpy(np.mean(vpt_rel_scores))
    logging.info("vpt_abs: %s, seq_length: %d", str(vpt_abs_scores), seq_length)
    return stats
def reconstruct(
    self,
    params: utils.Params,
    inputs: jnp.ndarray,
    rng: jnp.ndarray,
    forward: bool,
    use_mean: bool = True,
) -> distrax.Distribution:
    """Reconstructs the input sequence.

    Latents are inferred from the first ``self.num_inference_steps`` frames and
    then unrolled forward for ``images.shape[1] - self.num_inference_steps``
    steps, excluding the initial latent (``include_z0=False``).

    Args:
      params: The model parameters.
      inputs: Inputs from which images are read with ``utils.extract_image``.
      rng: Random key; split into 6 keys for ``self._models_core``.
      forward: Must be ``True``; this model only runs forward.
      use_mean: Accepted for interface compatibility; not used here.

    Returns:
      The first output of ``self._models_core``.

    Raises:
      ValueError: If ``forward`` is ``False``.
    """
    if not forward:
        raise ValueError("This model can not run backwards.")
    images = utils.extract_image(inputs)
    num_inferred = self.num_inference_steps
    core_outputs = self._models_core(
        params=params,
        keys=jnr.split(rng, 6),
        image_data=images[:, :num_inferred],
        is_training=False,
        num_steps_forward=images.shape[1] - num_inferred,
        num_steps_backward=0,
        include_z0=False,
    )
    return core_outputs[0]
def reconstruct(
    self,
    params: utils.Params,
    inputs: jnp.ndarray,
    rng: Optional[jnp.ndarray],
    forward: bool,
    use_mean: bool = True,
) -> distrax.Distribution:
    """Reconstructs the input sequence by inferring latents and unrolling them.

    Forward reconstructions infer from the first ``self.num_inference_steps``
    frames; backward ones infer from the last ``self.num_inference_steps``
    frames. The forward/backward step counts intentionally match the split
    used for the training statistics. Models that cannot run backwards never
    unroll backwards.

    Args:
      params: The model parameters.
      inputs: Inputs from which images are read with ``utils.extract_image``.
      rng: Random key; split into 6 keys for ``self._models_core``.
        NOTE(review): annotated ``Optional``, but ``None`` is passed straight to
        ``jnr.split`` -- confirm whether ``None`` is supported.
      forward: Whether to reconstruct in the forward direction.
      use_mean: Forwarded to ``self._models_core``.

    Returns:
      The first output of ``self._models_core``.

    Raises:
      ValueError: If a backward reconstruction is requested from a model that
        cannot run backwards.
    """
    if not (self.can_run_backwards or forward):
        raise ValueError("This model can not be run backwards.")

    images = utils.extract_image(inputs)
    num_frames = images.shape[1]
    if forward:
        steps_back = self.inferred_index
        steps_fwd = num_frames - steps_back - 1
        image_data = images[:, :self.num_inference_steps]
    else:
        steps_fwd = self.num_inference_steps - self.inferred_index - 1
        steps_back = num_frames - steps_fwd - 1
        image_data = images[:, -self.num_inference_steps:]
    if not self.can_run_backwards:
        steps_back = 0

    core_outputs = self._models_core(
        params=params,
        keys=jnr.split(rng, 6),
        image_data=image_data,
        use_mean=use_mean,
        is_training=False,
        num_steps_forward=steps_fwd,
        num_steps_backward=steps_back,
        include_z0=True,
    )
    return core_outputs[0]
def evaluation_only_statistics(
    reconstruct_func: _ReconstructFunc,
    params: hk.Params,
    inputs: jnp.ndarray,
    rng: jnp.ndarray,
    rescale_by: str,
    can_run_backwards: bool,
    train_sequence_length: int,
    reconstruction_skip: int,
    p_x_learned_sigma: bool = False,
) -> Dict[str, jnp.ndarray]:
    """Computes various statistics we track only during evaluation.

    The model fully reconstructs the trajectory forward, and also backward
    when ``can_run_backwards``. ``training_statistics`` is computed on three
    windows of each reconstruction:

      * ``"train"``: ``train_sequence_length - reconstruction_skip`` steps,
        taken from the start for forward and from the end for backward,
      * ``"extrapolation"``: the remaining steps,
      * ``"full"``: all steps.

    Args:
      reconstruct_func: Called as ``reconstruct_func(params, trajectory, rng,
        forward)``; must return a ``distrax.Normal``.
      params: Model parameters forwarded to ``reconstruct_func``.
      inputs: Inputs from which images are read with ``utils.extract_image``.
      rng: Random key forwarded to ``reconstruct_func``.
      rescale_by: Forwarded to ``training_statistics``.
      can_run_backwards: Whether to also evaluate backward reconstructions and
        the ``"combined_*"`` averages.
      train_sequence_length: Sequence length used during training.
      reconstruction_skip: Number of frames dropped from the start (forward)
        or end (backward) of the trajectory to form the targets.
      p_x_learned_sigma: Forwarded to ``training_statistics``.

    Returns:
      A dict keyed ``"{prefix}_{suffix}_{stat}"``. ``prefix`` is
      ``"forward"``, plus ``"backward"`` and ``"combined"`` (the mean of the
      two) when ``can_run_backwards``. ``suffix`` is ``"train"``,
      ``"extrapolation"`` or ``"full"``.
    """

    def take_last_steps(x, n):
        # `x[:, -n:]` returns the WHOLE sequence when n == 0 (since -0 == 0);
        # an explicit start index gives an empty slice instead. Identical to
        # `x[:, -n:]` for every n > 0.
        return x[:, max(x.shape[1] - n, 0):]

    def drop_last_steps(x, n):
        # `x[:, :-n]` returns an EMPTY sequence when n == 0; an explicit stop
        # index keeps every step instead. Identical to `x[:, :-n]` for n > 0.
        return x[:, :max(x.shape[1] - n, 0)]

    full_trajectory = utils.extract_image(inputs)
    prefixes = ("forward", "backward") if can_run_backwards else ("forward",)
    full_forward_targets = jax.tree_map(lambda x: x[:, reconstruction_skip:],
                                        full_trajectory)
    full_backward_targets = jax.tree_map(
        lambda x: x[:, :x.shape[1] - reconstruction_skip], full_trajectory)
    train_targets_length = train_sequence_length - reconstruction_skip
    full_targets_length = full_forward_targets.shape[1]

    # Window selectors for the "train" and "extrapolation" suffixes; the
    # "full" suffix uses the whole (possibly trimmed) sequence.
    windows = {
        ("forward", "train"): lambda x: x[:, :train_targets_length],
        ("forward", "extrapolation"): lambda x: x[:, train_targets_length:],
        ("backward", "train"):
            lambda x: take_last_steps(x, train_targets_length),
        ("backward", "extrapolation"):
            lambda x: drop_last_steps(x, train_targets_length),
    }

    stats = dict()
    keys = ()
    for prefix in prefixes:
        is_forward = prefix == "forward"
        # Fully unroll the model and reconstruct the whole sequence.
        full_prediction = reconstruct_func(params, full_trajectory, rng,
                                           is_forward)
        assert isinstance(full_prediction, distrax.Normal)
        full_targets = (full_forward_targets if is_forward
                        else full_backward_targets)

        # In cases where the model can run backwards it is possible to
        # reconstruct parts which were intended to be skipped, so trim the
        # prediction to the targets' length.
        if full_prediction.mean().shape[1] > full_targets_length:
            if is_forward:
                full_prediction = jax.tree_map(
                    lambda x: take_last_steps(x, full_targets_length),
                    full_prediction)
            else:
                full_prediction = jax.tree_map(
                    lambda x: x[:, :full_targets_length], full_prediction)

        for suffix in ("train", "extrapolation", "full"):
            window = windows.get((prefix, suffix))
            if window is None:
                predict, targets = full_prediction, full_targets
            else:
                predict, targets = jax.tree_map(
                    window, (full_prediction, full_targets))
            train_stats = training_statistics(
                predict, targets, rescale_by,
                p_x_learned_sigma=p_x_learned_sigma)
            for key, value in train_stats.items():
                stats[prefix + "_" + suffix + "_" + key] = value
            # Remember the stat names (from the latest call) for the combined
            # metrics below.
            keys = tuple(train_stats.keys())

    # Make a combined metric averaging forward and backward.
    if can_run_backwards:
        for suffix in ("train", "extrapolation", "full"):
            for key in keys:
                forward_value = stats["forward_" + suffix + "_" + key]
                backward_value = stats["backward_" + suffix + "_" + key]
                stats["combined_" + suffix + "_" + key] = (
                    (forward_value + backward_value) / 2)
    return stats
def training_objectives(
    self,
    params: utils.Params,
    state: hk.State,
    rng: jnp.ndarray,
    inputs: jnp.ndarray,
    step: jnp.ndarray,
    is_training: bool = True,
    use_mean_for_eval_stats: bool = True
) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
    """Computes the training objective and any supporting stats.

    Args:
      params: The model parameters. For the GECO objective,
        ``params["GECO"]["geco_lambda_var"]`` is read.
      state: The model state. For the GECO objective,
        ``state["GECO"]["geco_constraint_ema"]`` is read.
      rng: Random key. Split into 6 keys for the model core; the unsplit key
        is reused for the evaluation statistics.
      inputs: The input batch; images are read with ``utils.extract_image``.
      step: Current step, forwarded to ``metrics.elbo_objective``.
      is_training: When ``False``, evaluation-only statistics are added.
      use_mean_for_eval_stats: Forwarded as ``use_mean`` to
        ``self.reconstruct`` for the evaluation statistics.

    Returns:
      ``(loss, (new_state, stats, other_stats))``. ``new_state`` holds the
      updated GECO state (empty for other objectives), converted to the type
      of ``state``. ``other_stats`` contains ``x_reconstruct`` (the mean of
      ``p_x``) and ``z_stats``.

    Raises:
      ValueError: If ``self.objective_type`` is not ``"GECO"``, ``"ELBO"`` or
        ``"NON-PROB"``.
      NotImplementedError: If, when not training, ``self.training_data_split``
        is not ``"overlap_by_one"``, ``"no_overlap"`` or
        ``"include_inference"``.
    """
    # Split all rng keys
    keys = jnr.split(rng, 6)

    # Process training data
    images = utils.extract_image(inputs)
    image_data, target_data, unroll_kwargs = self.train_data_split(images)

    p_x, q_z, prior, _, _, dyn_stats = self._models_core(
        params=params,
        keys=keys,
        image_data=image_data,
        use_mean=False,
        is_training=is_training,
        **unroll_kwargs)

    # Note: we reuse the rng key used to sample the latent variable here
    # so that it can be reused to evaluate a (non-analytical) KL at that sample.
    stats = metrics.training_statistics(
        p_x=p_x,
        targets=target_data,
        rescale_by=self.rescale_by,
        rng=keys[1],
        q_z=q_z,
        prior=prior,
        p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False))
    stats.update(dyn_stats)

    # Compute other (non-reported) statistics
    z_stats = dict()
    other_stats = dict(x_reconstruct=p_x.mean(), z_stats=z_stats)

    # The loss computation and GECO state update
    new_state = dict()
    if self.objective_type == "GECO":
        geco_stats = metrics.geco_objective(
            l2_loss=stats["l2"],
            kl=stats["kl"],
            alpha=self.geco_alpha,
            kappa=self.geco_kappa,
            constraint_ema=state["GECO"]["geco_constraint_ema"],
            lambda_var=params["GECO"]["geco_lambda_var"],
            is_training=is_training)
        new_state["GECO"] = dict(
            geco_constraint_ema=geco_stats["geco_constraint_ema"])
        stats.update(geco_stats)
    elif self.objective_type == "ELBO":
        elbo_stats = metrics.elbo_objective(
            neg_log_p_x=stats["neg_log_p_x"],
            kl=stats["kl"],
            final_beta=self.elbo_beta_final,
            beta_delay=self.elbo_beta_delay,
            step=step)
        stats.update(elbo_stats)
    elif self.objective_type == "NON-PROB":
        stats["loss"] = stats["neg_log_p_x"]
    else:
        raise ValueError(
            f"Unsupported objective_type: {self.objective_type!r}; expected "
            "'GECO', 'ELBO' or 'NON-PROB'.")

    if not is_training:
        # Number of leading frames that `reconstruct` does not predict.
        if self.training_data_split == "overlap_by_one":
            reconstruction_skip = self.num_inference_steps - 1
        elif self.training_data_split == "no_overlap":
            reconstruction_skip = self.num_inference_steps
        elif self.training_data_split == "include_inference":
            reconstruction_skip = 0
        else:
            raise NotImplementedError(
                "Unsupported training_data_split: "
                f"{self.training_data_split!r}.")
        # We intentionally reuse the same rng as the training, in order to be
        # able to run tests and verify that the evaluation and reconstruction
        # work correctly.
        # We need to be able to set `use_mean = False` for some of the tests.
        stats.update(
            metrics.evaluation_only_statistics(
                reconstruct_func=functools.partial(
                    self.reconstruct, use_mean=use_mean_for_eval_stats),
                params=params,
                inputs=inputs,
                rng=rng,
                rescale_by=self.rescale_by,
                can_run_backwards=self.can_run_backwards,
                train_sequence_length=self.train_sequence_length,
                reconstruction_skip=reconstruction_skip,
                p_x_learned_sigma=self.decoder_kwargs.get(
                    "learned_sigma", False)))

    # Make new state the same type as state
    new_state = utils.convert_to_pytype(new_state, state)
    return stats["loss"], (new_state, stats, other_stats)