Example #1
0
    def gt_state_and_latents(
        self,
        params: hk.Params,
        rng: jnp.ndarray,
        inputs: Dict[str, jnp.ndarray],
        seq_length: int,
        is_training: bool = False,
        unroll_direction: str = "forward",
        **kwargs: Dict[str, Any]
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Union[distrax.Distribution,
                                               jnp.ndarray]]:
        """Computes the ground state and matching latents.

        Only forward unrolling is supported. The first
        `self.num_inference_steps` frames are given to `self._models_core`,
        which is then unrolled forwards over the remaining frames.

        Args:
          params: model parameters forwarded to `self._models_core`.
          rng: PRNG key; split into 6 sub-keys for the core.
          inputs: raw inputs; images and ground-truth state are extracted with
            `utils.extract_image` / `utils.extract_gt_state`.
          seq_length: number of ground-truth state steps returned, taken from
            index 1 onwards.
          is_training: accepted for interface compatibility; unused (the core
            is always run with `is_training=False`).
          unroll_direction: must be "forward".
          **kwargs: ignored.

        Returns:
          `(gt_state, z_out, z_in)`, where `z_in` and `z_out` are the second
          and third outputs of `self._models_core`.

        Raises:
          AssertionError: if `unroll_direction` is not "forward".
        """
        assert unroll_direction == "forward"
        all_frames = utils.extract_image(inputs)
        full_gt_state = utils.extract_gt_state(inputs)
        n_infer = self.num_inference_steps
        conditioning = all_frames[:, :n_infer]
        gt_window = full_gt_state[:, 1:seq_length + 1]

        core_outputs = self._models_core(
            params=params,
            keys=jnr.split(rng, 6),
            image_data=conditioning,
            is_training=False,
            num_steps_forward=all_frames.shape[1] - n_infer,
            num_steps_backward=0,
            include_z0=False,
        )
        _, latents_in, latents_out = core_outputs
        return gt_window, latents_out, latents_in
Example #2
0
    def gt_state_and_latents(
        self,
        params: hk.Params,
        rng: jnp.ndarray,
        inputs: Dict[str, jnp.ndarray],
        seq_length: int,
        is_training: bool = False,
        unroll_direction: str = "forward",
        **kwargs: Dict[str, Any]
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Union[distrax.Distribution,
                                               jnp.ndarray]]:
        """Computes the ground state and matching latents.

        The core model is run with `use_mean=True`, `is_training=False` and
        `include_z0=True` over a window of `seq_length` steps, and the
        ground-truth state is sliced to the same window.

        Args:
          params: model parameters forwarded to `self._models_core`.
          rng: PRNG key; split into 6 sub-keys for the core.
          inputs: raw inputs; images and ground-truth state are extracted with
            `utils.extract_image` / `utils.extract_gt_state`.
          seq_length: number of steps in the returned window.
          is_training: accepted for interface compatibility; unused.
          unroll_direction: "forward" or "backward".
          **kwargs: ignored.

        Returns:
          `(gt_state, z, z0)` when `self.has_latent_transform`, otherwise
          `(gt_state, z, q_z)`, where `q_z`, `z0` and `z` are outputs of
          `self._models_core`.

        Raises:
          AssertionError: if `unroll_direction` is neither "forward" nor
            "backward" (under `python -O` a NotImplementedError is raised
            instead).
          ValueError: if a backward unroll is requested but
            `self.can_run_backwards` is False.
        """
        assert unroll_direction in ("forward", "backward")
        backward = unroll_direction == "backward"
        if backward and not self.can_run_backwards:
            raise ValueError("This model can not be unrolled backwards.")

        images = utils.extract_image(inputs)
        all_gt_state = utils.extract_gt_state(inputs)

        if unroll_direction == "forward":
            conditioning = images[:, :self.num_inference_steps]
            # Bidirectional models unroll `inferred_index` steps backwards so
            # the window starts at frame 0; otherwise the unroll is
            # forward-only and the window starts at `inferred_index`.
            if self.can_run_backwards:
                steps_back, window_start = self.inferred_index, 0
            else:
                steps_back, window_start = 0, self.inferred_index
            steps_fwd = seq_length - steps_back - 1
            gt_window = all_gt_state[:, window_start:window_start + seq_length]
        elif backward:
            # Condition on the last `num_inference_steps` frames of the window.
            n_infer = self.num_inference_steps
            conditioning = images[:, seq_length - n_infer:seq_length]
            steps_fwd = n_infer - self.inferred_index - 1
            steps_back = seq_length - steps_fwd - 1
            gt_window = all_gt_state[:, :seq_length]
        else:
            raise NotImplementedError()

        _, q_z, _, z0, z, _ = self._models_core(
            params=params,
            keys=jnr.split(rng, 6),
            image_data=conditioning,
            use_mean=True,
            is_training=False,
            num_steps_forward=steps_fwd,
            num_steps_backward=steps_back,
            include_z0=True,
        )

        latents_ref = z0 if self.has_latent_transform else q_z
        return gt_window, z, latents_ref
Example #3
0
    def training_objectives(
        self,
        params: hk.Params,
        state: hk.State,
        rng: jnp.ndarray,
        inputs: jnp.ndarray,
        step: jnp.ndarray,
        is_training: bool = True,
        use_mean_for_eval_stats: bool = True
    ) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
        """Computes the training objective and any supporting stats.

        The loss is the `neg_log_p_x` statistic produced by
        `metrics.training_statistics`. When `is_training` is False, the
        statistics from `metrics.evaluation_only_statistics` are merged in.

        Args:
          params: model parameters.
          state: unused by this objective.
          rng: PRNG key; split into 6 sub-keys for the core and also passed
            unchanged to the evaluation statistics.
          inputs: raw inputs; images are extracted with `utils.extract_image`.
          step: unused by this objective.
          is_training: whether this is a training step.
          use_mean_for_eval_stats: forwarded as `use_mean` to
            `self.reconstruct` for the evaluation statistics.

        Returns:
          `(loss, (new_state, stats, other_stats))`; `new_state` and
          `other_stats` are always empty dicts here.
        """
        core_keys = jnr.split(rng, 6)

        # Split the images into model inputs, targets and unroll settings.
        frames = utils.extract_image(inputs)
        image_data, target_data, unroll_kwargs = self.train_data_split(frames)

        p_x, _, _ = self._models_core(
            params=params,
            keys=core_keys,
            image_data=image_data,
            is_training=is_training,
            **unroll_kwargs,
        )

        learned_sigma = self.decoder_kwargs.get("learned_sigma", False)
        stats = metrics.training_statistics(
            p_x=p_x,
            targets=target_data,
            rescale_by=self.rescale_by,
            p_x_learned_sigma=learned_sigma,
        )

        # The loss is just the negative log-likelihood (e.g. the L2 loss).
        stats["loss"] = stats["neg_log_p_x"]

        if not is_training:
            # `use_mean` is configurable because some tests need it False.
            eval_reconstruct = functools.partial(
                self.reconstruct, use_mean=use_mean_for_eval_stats)
            eval_stats = metrics.evaluation_only_statistics(
                reconstruct_func=eval_reconstruct,
                params=params,
                inputs=inputs,
                rng=rng,
                rescale_by=self.rescale_by,
                can_run_backwards=self.can_run_backwards,
                train_sequence_length=self.train_sequence_length,
                reconstruction_skip=1,
                p_x_learned_sigma=self.decoder_kwargs.get(
                    "learned_sigma", False),
            )
            stats.update(eval_stats)

        return stats["loss"], ({}, stats, {})
Example #4
0
 def init(
     self,
     rng: jnp.ndarray,
     inputs_or_shape: Union[jnp.ndarray, Mapping[str, jnp.ndarray],
                            Sequence[int]],
 ) -> Tuple[utils.Params, hk.State]:
   """Initializes the whole model parameters and state.

   Args:
     rng: PRNG key passed to `self._init`.
     inputs_or_shape: either a shape, given as a tuple/list whose first
       element is an `int` (a zero-filled image batch of that shape is
       created), or inputs from which images are taken with
       `utils.extract_image`.

   Returns:
     The result of the jitted `self._init(rng, images)`.
   """
   given_shape = (isinstance(inputs_or_shape, (tuple, list)) and
                  isinstance(inputs_or_shape[0], int))
   if given_shape:
     images = jnp.zeros(inputs_or_shape)
   else:
     images = utils.extract_image(inputs_or_shape)
   # Wrap `_init` with `jax.jit` on first use and reuse the wrapper afterwards.
   jitted = self._jit_init
   if jitted is None:
     jitted = self._jit_init = jax.jit(self._init)
   return jitted(rng, images)
    def _eval_batch_vpt(self, params, state, rng_key, batch):
        """Computes valid-prediction-time (VPT) scores for one batch.

        For each unroll direction ("forward", plus "backward" when
        `self.model.can_run_backwards`), the extrapolated reconstruction from
        `self._reconstruct_and_align` is compared with the aligned ground
        truth. The error is the squared error averaged over axes 3-5,
        divided by the mean squared ground truth over the same axes. For each
        index along axis 1 of that error (presumably the batch axis -- TODO
        confirm), the VPT is the time index (axis 2) of the first entry whose
        normalised error exceeds `self.config.evaluation_vpt.vpt_threshold`.
        If the threshold is never exceeded, the VPT is the full length.

        Args:
          params: unused here.
          state: unused here.
          rng_key: PRNG key forwarded to `self._reconstruct_and_align`.
          batch: batch from which the full trajectory is extracted with
            `utils.extract_image`.

        Returns:
          A dict holding, per direction, a sub-dict with `vpt_abs` (median
          VPT in steps) and `vpt_rel` (median VPT divided by the sequence
          length). It also holds top-level `vpt_abs` / `vpt_rel` averaged over
          directions.
        """
        full_trajectory = utils.extract_image(batch)
        if self.model.can_run_backwards:
            prefixes = ("forward", "backward")
        else:
            prefixes = ("forward",)
        threshold = self.config.evaluation_vpt.vpt_threshold

        stats = dict()
        vpt_abs_scores = []
        vpt_rel_scores = []
        seq_length = None
        for prefix in prefixes:
            reconstruction, gt_images = self._reconstruct_and_align(
                rng_key, full_trajectory, prefix, "extrapolation")
            seq_length = gt_images.shape[2]

            # Squared error normalised by the squared magnitude of the target.
            mse_norm = np.mean(
                (gt_images - reconstruction)**2, axis=(3, 4, 5)) / np.mean(
                    gt_images**2, axis=(3, 4, 5))

            vpt_scores = []
            for i in range(mse_norm.shape[1]):
                # `np.argwhere` returns indices in row-major order, so the
                # first row is the earliest exceedance along the time axis
                # (index 2) for the lowest index on axis 0.
                exceeded = np.argwhere(mse_norm[:, i:i + 1, :] > threshold)
                if exceeded.shape[0] > 0:
                    vpt_scores.append(exceeded[0][2])
                else:
                    vpt_scores.append(mse_norm.shape[-1])

            median_vpt = np.median(vpt_scores)
            vpt_abs_scores.append(median_vpt)
            vpt_rel_scores.append(median_vpt / seq_length)
            scores = {
                "vpt_abs": vpt_abs_scores[-1],
                "vpt_rel": vpt_rel_scores[-1]
            }
            scores = utils.to_numpy(scores)
            scores = utils.filter_only_scalar_stats(scores)
            stats[prefix] = scores

        stats["vpt_abs"] = utils.to_numpy(np.mean(vpt_abs_scores))
        stats["vpt_rel"] = utils.to_numpy(np.mean(vpt_rel_scores))
        logging.info("vpt_abs: %s, seq_length: %d", str(vpt_abs_scores),
                     seq_length)
        return stats
Example #6
0
    def reconstruct(
        self,
        params: utils.Params,
        inputs: jnp.ndarray,
        rng: jnp.ndarray,
        forward: bool,
        use_mean: bool = True,
    ) -> distrax.Distribution:
        """Reconstructs the input sequence.

        Conditions on the first `self.num_inference_steps` frames and unrolls
        forwards over the remaining frames of the sequence.

        Args:
          params: model parameters forwarded to `self._models_core`.
          inputs: raw inputs; images are extracted with `utils.extract_image`.
          rng: PRNG key; split into 6 sub-keys for the core.
          forward: must be truthy; this model only unrolls forwards.
          use_mean: accepted for interface compatibility; unused here.

        Returns:
          The first output of `self._models_core` (annotated as a
          `distrax.Distribution`).

        Raises:
          ValueError: if `forward` is falsy.
        """
        if forward:
            frames = utils.extract_image(inputs)
            n_infer = self.num_inference_steps
            conditioning = frames[:, :n_infer]
            core_outputs = self._models_core(
                params=params,
                keys=jnr.split(rng, 6),
                image_data=conditioning,
                is_training=False,
                num_steps_forward=frames.shape[1] - n_infer,
                num_steps_backward=0,
                include_z0=False,
            )
            return core_outputs[0]
        raise ValueError("This model can not run backwards.")
Example #7
0
    def reconstruct(
        self,
        params: utils.Params,
        inputs: jnp.ndarray,
        rng: Optional[jnp.ndarray],
        forward: bool,
        use_mean: bool = True,
    ) -> distrax.Distribution:
        """Reconstructs the input sequence by unrolling the model.

        Forward reconstructions condition on the first
        `self.num_inference_steps` frames; backward ones condition on the last
        `self.num_inference_steps` frames. Models that cannot run backwards
        take no backward unroll steps.

        Args:
          params: model parameters forwarded to `self._models_core`.
          inputs: raw inputs; images are extracted with `utils.extract_image`.
          rng: PRNG key, split into 6 sub-keys. NOTE(review): although
            annotated `Optional`, it is passed straight to `jnr.split`.
          forward: unroll direction; a falsy value requires
            `self.can_run_backwards`.
          use_mean: forwarded to `self._models_core`.

        Returns:
          The first output of `self._models_core` (annotated as a
          `distrax.Distribution`).

        Raises:
          ValueError: if `forward` is falsy and the model cannot run
            backwards.
        """
        bidirectional = self.can_run_backwards
        if not (forward or bidirectional):
            raise ValueError("This model can not be run backwards.")
        frames = utils.extract_image(inputs)
        total_steps = frames.shape[1]
        # This is intentionally matching the split for the training stats
        if forward:
            steps_back = self.inferred_index
            steps_fwd = total_steps - steps_back - 1
            conditioning = frames[:, :self.num_inference_steps]
        else:
            steps_fwd = self.num_inference_steps - self.inferred_index - 1
            steps_back = total_steps - steps_fwd - 1
            conditioning = frames[:, -self.num_inference_steps:]
        if not bidirectional:
            # Only the backward count is reset; `steps_fwd` keeps the value
            # computed above.
            steps_back = 0

        core_outputs = self._models_core(
            params=params,
            keys=jnr.split(rng, 6),
            image_data=conditioning,
            use_mean=use_mean,
            is_training=False,
            num_steps_forward=steps_fwd,
            num_steps_backward=steps_back,
            include_z0=True,
        )
        return core_outputs[0]
def evaluation_only_statistics(
    reconstruct_func: _ReconstructFunc,
    params: hk.Params,
    inputs: jnp.ndarray,
    rng: jnp.ndarray,
    rescale_by: str,
    can_run_backwards: bool,
    train_sequence_length: int,
    reconstruction_skip: int,
    p_x_learned_sigma: bool = False,
) -> Dict[str, jnp.ndarray]:
    """Computes various statistics we track only during evaluation.

    For each unroll direction ("forward", plus "backward" when
    `can_run_backwards`), the whole trajectory is reconstructed with
    `reconstruct_func` and `training_statistics` is evaluated on three windows
    of the prediction/target pair:

    * "train": `train_sequence_length - reconstruction_skip` steps,
    * "extrapolation": the remaining steps,
    * "full": the whole sequence.

    For the forward direction the "train" window is the head of the sequence.
    For the backward direction it is the tail. Forward targets drop the first
    `reconstruction_skip` frames; backward targets drop the last
    `reconstruction_skip` frames.

    Args:
      reconstruct_func: called as `reconstruct_func(params, inputs, rng,
        forward)`; must return a `distrax.Normal`.
      params: model parameters passed to `reconstruct_func`.
      inputs: model inputs; images are extracted with `utils.extract_image`.
      rng: PRNG key passed to `reconstruct_func`.
      rescale_by: forwarded to `training_statistics`.
      can_run_backwards: whether to also evaluate the backward direction and
        add "combined" (forward/backward average) statistics.
      train_sequence_length: training sequence length; the "train" window
        length is this minus `reconstruction_skip`.
      reconstruction_skip: number of leading (forward) or trailing (backward)
        frames that are not reconstruction targets.
      p_x_learned_sigma: forwarded to `training_statistics`.

    Returns:
      A flat dict keyed `"{prefix}_{suffix}_{stat}"` with prefix in
      forward/backward/combined and suffix in train/extrapolation/full.
    """
    full_trajectory = utils.extract_image(inputs)
    prefixes = ("forward", "backward") if can_run_backwards else ("forward", )

    def keep_last(x, n):
        # Last `n` steps along axis 1. Unlike `x[:, -n:]`, this yields an
        # empty slice (not the whole array) when n == 0.
        return x[:, max(x.shape[1] - n, 0):]

    def drop_last(x, n):
        # All but the last `n` steps along axis 1. Unlike `x[:, :-n]`, this
        # yields the whole array (not an empty slice) when n == 0.
        return x[:, :max(x.shape[1] - n, 0)]

    full_forward_targets = jax.tree_util.tree_map(
        lambda x: x[:, reconstruction_skip:], full_trajectory)
    full_backward_targets = jax.tree_util.tree_map(
        lambda x: x[:, :x.shape[1] - reconstruction_skip], full_trajectory)
    train_targets_length = train_sequence_length - reconstruction_skip
    full_targets_length = full_forward_targets.shape[1]

    stats = dict()
    keys = ()

    for prefix in prefixes:
        # Fully unroll the model and reconstruct the whole sequence
        full_prediction = reconstruct_func(params, full_trajectory, rng,
                                           prefix == "forward")
        assert isinstance(full_prediction, distrax.Normal)
        full_targets = (full_forward_targets
                        if prefix == "forward" else full_backward_targets)
        # In cases where the model can run backwards it is possible to
        # reconstruct parts which were intended to be skipped, so here we
        # trim the prediction to the target length.
        if full_prediction.mean().shape[1] > full_targets_length:
            if prefix == "forward":
                full_prediction = jax.tree_util.tree_map(
                    lambda x: keep_last(x, full_targets_length),
                    full_prediction)
            else:
                full_prediction = jax.tree_util.tree_map(
                    lambda x: x[:, :full_targets_length], full_prediction)

        # Based on the prefix and suffix fetch correct predictions and targets
        for suffix in ("train", "extrapolation", "full"):
            if prefix == "forward" and suffix == "train":
                predict, targets = jax.tree_util.tree_map(
                    lambda x: x[:, :train_targets_length],
                    (full_prediction, full_targets))
            elif prefix == "forward" and suffix == "extrapolation":
                predict, targets = jax.tree_util.tree_map(
                    lambda x: x[:, train_targets_length:],
                    (full_prediction, full_targets))
            elif prefix == "backward" and suffix == "train":
                predict, targets = jax.tree_util.tree_map(
                    lambda x: keep_last(x, train_targets_length),
                    (full_prediction, full_targets))
            elif prefix == "backward" and suffix == "extrapolation":
                predict, targets = jax.tree_util.tree_map(
                    lambda x: drop_last(x, train_targets_length),
                    (full_prediction, full_targets))
            else:
                predict, targets = full_prediction, full_targets

            # Compute train statistics
            train_stats = training_statistics(
                predict,
                targets,
                rescale_by,
                p_x_learned_sigma=p_x_learned_sigma)
            for key, value in train_stats.items():
                stats[prefix + "_" + suffix + "_" + key] = value
            # Remember the stat names for the combined metrics below.
            keys = tuple(train_stats.keys())

    # Make a combined metric averaging forward and backward
    if can_run_backwards:
        for suffix in ("train", "extrapolation", "full"):
            for key in keys:
                forward = stats["forward_" + suffix + "_" + key]
                backward = stats["backward_" + suffix + "_" + key]
                combined = (forward + backward) / 2
                stats["combined_" + suffix + "_" + key] = combined

    return stats
Example #9
0
    def training_objectives(
        self,
        params: utils.Params,
        state: hk.State,
        rng: jnp.ndarray,
        inputs: jnp.ndarray,
        step: jnp.ndarray,
        is_training: bool = True,
        use_mean_for_eval_stats: bool = True
    ) -> Tuple[jnp.ndarray, Sequence[Dict[str, jnp.ndarray]]]:
        """Computes the training objective and any supporting stats.

        Runs the core model on the training split of the inputs and computes
        the reconstruction/KL statistics. These are turned into a loss
        according to `self.objective_type` ("GECO", "ELBO" or "NON-PROB").
        When not training, evaluation-only statistics are added as well.

        Args:
          params: model parameters; for GECO,
            `params["GECO"]["geco_lambda_var"]` is read.
          state: model state; for GECO, `state["GECO"]["geco_constraint_ema"]`
            is read. Also used as the type template for the returned state.
          rng: PRNG key; split into 6 sub-keys, and passed unchanged to the
            evaluation statistics.
          inputs: raw inputs; images are extracted with `utils.extract_image`.
          step: current step, used only by `metrics.elbo_objective`.
          is_training: whether this is a training step.
          use_mean_for_eval_stats: forwarded as `use_mean` to
            `self.reconstruct` when computing evaluation statistics.

        Returns:
          `(loss, (new_state, stats, other_stats))`.
          `new_state` holds the updated GECO constraint EMA (empty for other
          objectives), converted to the type of `state`. `other_stats` holds
          `x_reconstruct` (the mean of `p_x`) and an empty `z_stats` dict.

        Raises:
          ValueError: if `self.objective_type` is not recognised.
          NotImplementedError: if evaluating with an unsupported
            `self.training_data_split`.
        """
        # Split all rng keys
        keys = jnr.split(rng, 6)

        # Process training data
        images = utils.extract_image(inputs)
        image_data, target_data, unroll_kwargs = self.train_data_split(images)

        p_x, q_z, prior, _, _, dyn_stats = self._models_core(
            params=params,
            keys=keys,
            image_data=image_data,
            use_mean=False,
            is_training=is_training,
            **unroll_kwargs)

        # Note: we reuse the rng key used to sample the latent variable here
        # so that it can be reused to evaluate a (non-analytical) KL at that sample.
        stats = metrics.training_statistics(
            p_x=p_x,
            targets=target_data,
            rescale_by=self.rescale_by,
            rng=keys[1],
            q_z=q_z,
            prior=prior,
            p_x_learned_sigma=self.decoder_kwargs.get("learned_sigma", False))
        stats.update(dyn_stats)

        # Compute other (non-reported) statistics
        z_stats = dict()
        other_stats = dict(x_reconstruct=p_x.mean(), z_stats=z_stats)

        # The loss computation and GECO state update
        new_state = dict()
        if self.objective_type == "GECO":
            geco_stats = metrics.geco_objective(
                l2_loss=stats["l2"],
                kl=stats["kl"],
                alpha=self.geco_alpha,
                kappa=self.geco_kappa,
                constraint_ema=state["GECO"]["geco_constraint_ema"],
                lambda_var=params["GECO"]["geco_lambda_var"],
                is_training=is_training)
            new_state["GECO"] = dict(
                geco_constraint_ema=geco_stats["geco_constraint_ema"])
            stats.update(geco_stats)
        elif self.objective_type == "ELBO":
            elbo_stats = metrics.elbo_objective(
                neg_log_p_x=stats["neg_log_p_x"],
                kl=stats["kl"],
                final_beta=self.elbo_beta_final,
                beta_delay=self.elbo_beta_delay,
                step=step)
            stats.update(elbo_stats)
        elif self.objective_type == "NON-PROB":
            stats["loss"] = stats["neg_log_p_x"]
        else:
            raise ValueError(
                f"Unknown objective_type {self.objective_type!r}; expected "
                "'GECO', 'ELBO' or 'NON-PROB'.")

        if not is_training:
            # Number of leading frames that have no reconstruction target.
            if self.training_data_split == "overlap_by_one":
                reconstruction_skip = self.num_inference_steps - 1
            elif self.training_data_split == "no_overlap":
                reconstruction_skip = self.num_inference_steps
            elif self.training_data_split == "include_inference":
                reconstruction_skip = 0
            else:
                raise NotImplementedError(
                    "Unsupported training_data_split "
                    f"{self.training_data_split!r}.")
            # We intentionally reuse the same rng as the training, in order to be able
            # to run tests and verify that the evaluation and reconstruction work
            # correctly.
            # We need to be able to set `use_mean = False` for some of the tests
            stats.update(
                metrics.evaluation_only_statistics(
                    reconstruct_func=functools.partial(
                        self.reconstruct, use_mean=use_mean_for_eval_stats),
                    params=params,
                    inputs=inputs,
                    rng=rng,
                    rescale_by=self.rescale_by,
                    can_run_backwards=self.can_run_backwards,
                    train_sequence_length=self.train_sequence_length,
                    reconstruction_skip=reconstruction_skip,
                    p_x_learned_sigma=self.decoder_kwargs.get(
                        "learned_sigma", False)))

        # Make new state the same type as state
        new_state = utils.convert_to_pytype(new_state, state)
        return stats["loss"], (new_state, stats, other_stats)