Exemplo n.º 1
0
 def latent_diffusion_function(self, x, t, param_dict, gp_matrices):
     """Return the latent SDE diffusion evaluated at state ``x`` and time ``t``.

     Two modes, chosen by ``self.config["constant_diffusion"]``:

     * constant diffusion: ``inverse_lambdas`` is drawn with ``numpyro.sample``
       from an ``InverseGamma`` whose concentration and rate are the softplus
       of ``self.gamma_params.build(param_dict["gamma_params"])``. The result
       ``aux_math.diag(inverse_lambdas)`` is repeated once per row of ``x``.
       Splits and advances ``self.key`` as a side effect.
     * GP diffusion: ``self.sde_gp_diffusion`` is applied to ``x``. When
       ``self.config["time_dependent_gp"]`` is set, a column filled with ``t``
       is appended first. The transposed GP output is passed through
       ``aux_math.diag``.

     Args:
         x: Latent states; rows index the batch (``x.shape[0]``).
         t: Current time (presumably a scalar — confirm against callers).
         param_dict: Parameter dictionary; reads ``"gamma_params"`` or
             ``"sde_gp"`` depending on the mode.
         gp_matrices: GP quantities forwarded unchanged to
             ``sde_gp_diffusion``.
     """
     if not self.config["constant_diffusion"]:
         gp_inputs = x
         if self.config["time_dependent_gp"]:
             # Broadcast t into one extra input column so the GP also sees time.
             time_column = np.ones(shape=(x.shape[0], 1)) * t
             gp_inputs = np.concatenate((x, time_column), axis=1)
         gp_diffusion = self.sde_gp_diffusion(gp_inputs, param_dict["sde_gp"],
                                              gp_matrices)
         return aux_math.diag(np.transpose(gp_diffusion))

     concentration, rate = self.gamma_params.build(param_dict["gamma_params"])
     self.key, subkey = random.split(self.key)
     prior = InverseGamma(nn.softplus(concentration), nn.softplus(rate))
     inverse_lambdas = numpyro.sample("inverse_lambdas", prior, rng_key=subkey)
     # One diagonal matrix, stacked along a new leading batch axis.
     single = np.expand_dims(aux_math.diag(inverse_lambdas), axis=0)
     return np.tile(single, [x.shape[0], 1, 1])
Exemplo n.º 2
0
 def __call__(self, y, sc, multiplicative_factor=None):
     """Encode ``y`` into a mean and a scale matrix.

     The network output ``self.predict(sc["encoder_params"], y)`` is split on
     its last axis. The first ``self.output_dims`` entries are returned as
     the mean. The remaining entries go through softplus and
     ``aux_math.diag`` to form the scale.

     Args:
         y: Encoder input.
         sc: Container holding ``"encoder_params"``.
         multiplicative_factor: Optional matrix. When given, each batch
             element's scale matrix is left-multiplied by it
             (``einsum("ab,cbd->cad")``). The result is not guaranteed to
             stay lower-triangular unless the factor is.

     Returns:
         ``(mean, scale_tril)``.
     """
     outputs = self.predict(sc["encoder_params"], y)
     mean = outputs[..., :self.output_dims]
     scale_tril = aux_math.diag(nn.softplus(outputs[..., self.output_dims:]))
     if multiplicative_factor is not None:
         scale_tril = np.einsum("ab,cbd->cad", multiplicative_factor, scale_tril)
     return mean, scale_tril
Exemplo n.º 3
0
    def step(self, x_0, time):
        """Advance the SDE state ``x_0`` by one step of size ``self.delta_t``.

        The update appears to be the derivative-free order-1.5 stochastic
        Runge-Kutta scheme for additive noise (Kloeden & Platen, ch. 11).
        Per batch element and noise dimension ``j`` it draws the joint Gaussian
        increments ``(dZ_j, dW_j)`` with Var(dZ)=dt^3/3, Cov(dZ, dW)=dt^2/2 and
        Var(dW)=dt. The drift is then corrected at the supporting points
        ``x_0 + f dt / m +/- b_j sqrt(dt)``.

        The diffusion ``b(x_0, time)`` is evaluated exactly once and reused
        everywhere. Earlier code re-evaluated it inside the scan body, which
        was redundant and, for a stochastic diffusion function, mixed
        different samples of ``b`` within one step.

        Args:
            x_0: Current states, shape ``(batch, dim)`` (assumed from the
                einsum/indexing below).
            time: Current time, forwarded to the drift/diffusion functions.

        Returns:
            The next states, same shape as ``x_0``.
        """
        batch_size = np.shape(x_0)[0]
        shape = [batch_size, self.beta_dims]
        sqrt_dt = np.sqrt(self.delta_t)

        # Zero mean for the stacked (gamma, beta) increments.
        beta_mean_vector = np.concatenate((np.zeros(shape), np.zeros(shape)), axis=1)

        # Block covariance [[dt^3/3 I, dt^2/2 I], [dt^2/2 I, dt I]] per batch element.
        identity = aux_math.diag(np.ones(shape))
        cov_gamma = self.delta_t ** 3 / 3 * identity
        cov_cross = self.delta_t ** 2 / 2 * identity
        cov_beta = self.delta_t * identity
        beta_covariance = np.concatenate(
            (np.concatenate((cov_gamma, cov_cross), axis=2),
             np.concatenate((cov_cross, cov_beta), axis=2)),
            axis=1)

        self.key, subkey = random.split(self.key)
        delta_gamma_beta = numpyro.sample("delta_gamma_beta",
                                          MultivariateNormal(loc=beta_mean_vector,
                                                             covariance_matrix=beta_covariance),
                                          rng_key=subkey)

        delta_gamma = delta_gamma_beta[:, :self.beta_dims]
        delta_beta = delta_gamma_beta[:, self.beta_dims:]

        # Drift first, then diffusion: same call order as before.
        drift_0 = self.drift_function(x_0, time) * self.delta_t
        diffusion_0 = self.diffusion_function(x_0, time)
        drift_share = drift_0 / self.beta_dims

        # Euler-Maruyama part: x_0 + f dt + b dW.
        init_x_1 = x_0 + drift_0 + np.einsum("abc,ac->ab", diffusion_0, delta_beta)

        def scan_fn(x_1, s):
            # Supporting points along the s-th noise column of b(x_0).
            noise_column = diffusion_0[..., s] * sqrt_dt
            x_0_plus = x_0 + drift_share + noise_column
            x_0_minus = x_0 + drift_share - noise_column

            drift_0_plus = self.drift_function(x_0_plus, time)
            drift_0_minus = self.drift_function(x_0_minus, time)
            # Second-order drift correction: dt/4 (f+ + f-) - f dt / 2.
            x_1 = x_1 + 0.25 * self.delta_t * (drift_0_plus + drift_0_minus)
            x_1 = x_1 - 0.5 * drift_0
            # Mixed correction driven by the dZ increment.
            x_1 = x_1 + \
                1. / (2 * sqrt_dt) * (drift_0_plus - drift_0_minus) * \
                np.expand_dims(delta_gamma[:, s], axis=-1)
            return x_1, None

        final_x_1, _ = lax.scan(scan_fn, init_x_1, np.arange(self.beta_dims))

        return final_x_1
Exemplo n.º 4
0
    def loss(self, y_input, t_indices: list, param_dict, num_steps):
        """Simulate the SDE from the first observation and return the loss.

        Builds the latent drift/diffusion from ``param_dict`` and installs them
        on ``self.sde_var``. It then simulates ``num_steps`` steps from
        ``y_input[0]`` and scores the simulated states at ``t_indices``
        against ``y_input[t_indices]`` with a diagonal Gaussian likelihood.
        The mean of that score is stored as ``metrics["reco"]``.

        Args:
            y_input: Observations indexed by time along the first axis.
            t_indices: Time indices (a ``list``) at which to compare.
            param_dict: Model parameters; ``"likelihood"`` holds the raw
                signal-variance parameters.
            num_steps: Number of solver steps.

        Returns:
            ``(-metrics["elbo"], metrics)``. ``self.get_metrics`` is expected to
            add the ``"elbo"`` entry to ``metrics`` in place (not visible here).

        Raises:
            TypeError: If ``t_indices`` is not a list.
        """
        if not isinstance(t_indices, list):
            raise TypeError("Time indices object must be a list")

        # Flatten all non-leading axes of the first observation into one feature axis.
        y_0 = y_input[0].reshape(y_input[0].shape[0], -1)

        gp_matrices, latent_drift_function, latent_diffusion_function = self.build(
            param_dict)

        # Plug the freshly built latent dynamics into the solver.
        self.sde_var.drift_function = latent_drift_function
        self.sde_var.diffusion_function = latent_diffusion_function

        self.y_t, self.paths_y = self.sde_var(y_0, num_steps)

        # Direct indexing; the former jax.ops.index[...] returned its index unchanged
        # and has been removed from JAX.
        y_t_to_compare = self.paths_y[t_indices]
        # Standard deviations = sqrt(softplus(raw variance)), on the diagonal.
        likelihood_scale = aux_math.diag(
            np.sqrt(nn.softplus(self.signal_variance.build(param_dict["likelihood"]))))

        metrics = {}
        metrics["reco"] = np.mean(
            aux_math.log_prob_multivariate_normal(
                y_t_to_compare, likelihood_scale, y_input[t_indices]))

        self.get_metrics(metrics, gp_matrices, param_dict)
        return -metrics["elbo"], metrics
Exemplo n.º 5
0
 def __call__(self, x: Array, t: float) -> Array:
     """Return a diagonal matrix built from a learned, softplus-positive vector.

     The Haiku parameter ``"diag"`` has ``self.output_size`` entries and is
     initialised with ``RandomNormal()``. ``x`` is only consulted for its
     dtype, and ``t`` is unused (the result does not depend on state or time).
     """
     raw_diagonal = hk.get_parameter(
         "diag",
         shape=[self.output_size],
         dtype=x.dtype,
         init=hk.initializers.RandomNormal(),
     )
     positive_diagonal = jax.nn.softplus(raw_diagonal)
     return aux_math.diag(positive_diagonal)
Exemplo n.º 6
0
    def __call__(self, y_input: Array, t_mask: Array,
                 training: int) -> Tuple[Metrics, ItoGeneralOutput]:
        """Simulate latent/observed paths and score them against ``y_input``.

        Returns:
            ``Metrics(elbo, elbo_generated, mse)``, plus an ``ItoGeneralOutput``
            holding the time grid and the latent paths with their drifts and
            diffusion diagonals. It also holds the reported observed paths
            with their drifts and diffusions, and ``t_mask``.
        """
        x_0 = self.initial_latents()
        # NOTE(review): the result is unused; the call is kept because it may
        # create/read Haiku parameters — confirm before removing.
        self.likelihood()

        t_seq, paths_x, paths_y, paths_y_generated = self.sde(
            x_0, y_input, t_mask, training)

        def along_paths(fn):
            # Apply fn(x, t) at every index of the shared leading axis of
            # paths_x and t_seq.
            return jax.vmap(fn, (0, 0))(paths_x, t_seq)

        drift_y = along_paths(lambda x, t: self.drift_y(x, t))
        diffusion_y = along_paths(lambda x, t: self.diffusion_y(x, t))
        drift_x = along_paths(lambda x, t: self.drift_x(x, t))
        diff_x = aux_math.diag_part(
            along_paths(lambda x, t: self.diffusion_x(x, t)))

        # Data where t_mask is 1, simulated path where it is 0 (assumes a 0/1 mask).
        y_objective = y_input * t_mask + jnp.abs(t_mask - 1) * paths_y
        targets = y_objective[1:]

        # One-step Gaussian scale shared by both likelihood terms:
        # previous-step diffusion times sqrt(delta_t).
        step_scale = aux_math.diag(diffusion_y[:-1] * jnp.sqrt(self.data.delta_t))
        elbo = jnp.mean(
            aux_math.log_prob_multivariate_normal(paths_y[1:], step_scale, targets))
        elbo_generated = jnp.mean(
            aux_math.log_prob_multivariate_normal(
                paths_y_generated[1:], step_scale, targets))
        mse = jnp.mean(jnp.sum((paths_y[1:] - targets)**2, axis=-1))

        # training picks the reported paths: generated when 1, simulated when 0.
        reported_paths_y = paths_y_generated * training + jnp.abs(1 - training) * paths_y

        metrics = Metrics(elbo, elbo_generated, mse)
        output = ItoGeneralOutput(t_seq, paths_x, drift_x, diff_x, reported_paths_y,
                                  drift_y, diffusion_y, t_mask)
        return metrics, output
Exemplo n.º 7
0
    def step(self, x_0, time):
        """Advance the SDE state ``x_0`` by one Euler-Maruyama step.

        Computes ``x_0 + f(x_0, time) * dt + b(x_0, time) dW``, where
        ``dW ~ N(0, dt * I)`` is drawn per batch element with ``self.beta_dims``
        components. The diffusion is applied per batch element via
        ``einsum("abc,ac->ab")``. Splits and advances ``self.key``.

        Args:
            x_0: Current states; the first axis is the batch.
            time: Current time, forwarded to the drift/diffusion functions.

        Returns:
            The next states.
        """
        self.key, subkey = random.split(self.key)

        # Plain Python ints: a narrow integer dtype such as int8 overflows once
        # the batch size or beta_dims exceeds 127.
        shape = (np.shape(x_0)[0], self.beta_dims)
        # NOTE(review): the sample site is named "delta_gamma_beta" although only
        # the Brownian (beta) increment is drawn here.
        delta_beta = numpyro.sample("delta_gamma_beta",
                                    MultivariateNormal(
                                        loc=np.zeros(shape),
                                        scale_tril=np.sqrt(self.delta_t) * aux_math.diag(np.ones(shape))),
                                    rng_key=subkey)
        x_1 = x_0 + self.drift_function(x_0, time) * self.delta_t + \
            np.einsum("abc,ac->ab", self.diffusion_function(x_0, time), delta_beta)

        return x_1
Exemplo n.º 8
0
 def __call__(self) -> Array:
     """Return a diagonal matrix of softplus-transformed likelihood parameters.

     The Haiku parameter ``"likelihood"`` has ``self.output_size`` entries and
     is initialised from ``RandomNormal(mean=-5)``. Softplus keeps the
     diagonal positive.
     """
     initializer = hk.initializers.RandomNormal(mean=-5)
     raw_values = hk.get_parameter(
         "likelihood", shape=[self.output_size], init=initializer)
     return aux_math.diag(jax.nn.softplus(raw_values))