def get_params(self, i, x, output_size):
  w_init = hk.initializers.VarianceScaling(scale=1.0,
                                           mode="fan_avg",
                                           distribution="truncated_normal")

  # Pass a singly batched input to the parameter functions.
  # Don't use autobatching here because we might end up reducing.
  x, reshape = self.make_singly_batched(x)

  if self.parameter_norm == "weight_norm":
    w, b = init.weight_with_weight_norm(x=x,
                                        out_dim=output_size,
                                        name_suffix=str(i),
                                        w_init=w_init,
                                        b_init=jnp.zeros,
                                        is_training=True,
                                        use_bias=True)
  elif self.parameter_norm == "spectral_norm":
    w, b = init.weight_with_spectral_norm(x=x,
                                          out_dim=output_size,
                                          name_suffix=str(i),
                                          w_init=w_init,
                                          b_init=jnp.zeros,
                                          is_training=True,
                                          use_bias=True)
  else:
    w = hk.get_parameter(f"w_{i}", (output_size, x.shape[-1]), x.dtype, init=w_init)
    b = hk.get_parameter(f"b_{i}", (output_size,), x.dtype, init=jnp.zeros)

  # x = reshape(x)
  return w.T, b
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray=None,
         sample: Optional[bool]=False,
         **kwargs
  ) -> Mapping[str, jnp.ndarray]:
  x = inputs["x"]
  outputs = {}
  x_dim, dtype = x.shape[-1], x.dtype

  if self.weight_norm:
    W, b = init.weight_with_weight_norm(x,
                                        out_dim=x_dim,
                                        w_init=hk.initializers.RandomNormal(0.1),
                                        b_init=jnp.zeros,
                                        is_training=kwargs.get("is_training", False),
                                        use_bias=True)
  else:
    W_init = hk.initializers.TruncatedNormal(1/jnp.sqrt(x_dim))
    W = hk.get_parameter("W", shape=(x_dim, x_dim), dtype=dtype, init=W_init)
    b = hk.get_parameter("b", shape=(x_dim,), dtype=dtype, init=jnp.zeros)

  if sample == False:
    outputs["x"] = jnp.dot(x, W.T) + b
  else:
    W_inv = jnp.linalg.inv(W)
    outputs["x"] = jnp.dot(x - b, W_inv.T)

  outputs["log_det"] = jnp.linalg.slogdet(W)[1]*jnp.ones(self.batch_shape)

  return outputs
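As a quick sanity check on the log-determinant above: the forward map is t = W x + b, whose Jacobian is W, so log|det J| is exactly jnp.linalg.slogdet(W)[1]. A minimal self-contained sketch (the toy dimension and random key are assumptions for illustration):

import jax
import jax.numpy as jnp

W = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
b = jnp.zeros(3)

f = lambda x: jnp.dot(x, W.T) + b    # same forward map as in call above
J = jax.jacobian(f)(jnp.ones(3))     # the Jacobian of this affine map is W itself

assert jnp.allclose(jnp.linalg.slogdet(J)[1], jnp.linalg.slogdet(W)[1])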
def data_dependent_param_init(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str="",
                              w_init: Callable=None,
                              b_init: Callable=None,
                              is_training: bool=True,
                              parameter_norm: str=None,
                              use_bias: bool=True,
                              update_params: bool=True,
                              **kwargs):
  if parameter_norm == "spectral_norm":
    return init.weight_with_spectral_norm(x=x,
                                          out_dim=out_dim,
                                          name_suffix=name_suffix,
                                          w_init=w_init,
                                          b_init=b_init,
                                          is_training=is_training,
                                          use_bias=use_bias,
                                          **kwargs)
  elif parameter_norm == "differentiable_spectral_norm":
    return init.weight_with_good_spectral_norm(x=x,
                                               out_dim=out_dim,
                                               name_suffix=name_suffix,
                                               w_init=w_init,
                                               b_init=b_init,
                                               is_training=is_training,
                                               update_params=update_params,
                                               use_bias=use_bias,
                                               **kwargs)
  elif parameter_norm == "weight_norm":
    # Data-dependent weight norm initialization needs more than one example in
    # the batch; otherwise fall through to the plain parameters below.
    if x.shape[0] > 1:
      return init.weight_with_weight_norm(x=x,
                                          out_dim=out_dim,
                                          name_suffix=name_suffix,
                                          w_init=w_init,
                                          b_init=b_init,
                                          is_training=is_training,
                                          use_bias=use_bias,
                                          **kwargs)
  elif parameter_norm is not None:
    assert 0, "Invalid parameter_norm.  Expected 'spectral_norm', 'differentiable_spectral_norm' or 'weight_norm'"

  in_dim, dtype = x.shape[-1], x.dtype

  w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim), dtype, init=w_init)
  if use_bias:
    b = hk.get_parameter(f"b_{name_suffix}", (out_dim,), dtype, init=b_init)
    return w, b
  return w
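A minimal usage sketch, assuming data_dependent_param_init is called from inside a Haiku-transformed function; the layer width, batch size, and initializers below are illustrative assumptions:

import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
  # parameter_norm=None exercises the plain hk.get_parameter path.
  w, b = data_dependent_param_init(x,
                                   out_dim=4,
                                   name_suffix="0",
                                   w_init=hk.initializers.TruncatedNormal(0.05),
                                   b_init=jnp.zeros,
                                   parameter_norm=None)
  return jnp.dot(x, w.T) + b

net = hk.transform(forward)
params = net.init(jax.random.PRNGKey(0), jnp.ones((8, 3)))
out = net.apply(params, None, jnp.ones((8, 3)))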
def get_params(i, x, output_size):
  w_init = hk.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")

  if self.parameter_norm == "weight_norm":
    w, b = init.weight_with_weight_norm(x=x,
                                        out_dim=output_size,
                                        name_suffix=str(i),
                                        w_init=w_init,
                                        b_init=jnp.zeros,
                                        is_training=True,
                                        use_bias=True)
  elif self.parameter_norm == "spectral_norm":
    w, b = init.weight_with_spectral_norm(x=x,
                                          out_dim=output_size,
                                          name_suffix=str(i),
                                          w_init=w_init,
                                          b_init=jnp.zeros,
                                          is_training=True,
                                          use_bias=True)
  else:
    w = hk.get_parameter(f"w_{i}", (output_size, x.shape[-1]), x.dtype, init=w_init)
    b = hk.get_parameter(f"b_{i}", (output_size,), x.dtype, init=jnp.zeros)

  return w.T, b
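Since get_params returns w already transposed to shape (in_dim, output_size), a hidden layer can consume it directly. A hedged sketch of how this closure might be used to build an MLP inside the same module method (the layer sizes and activation are assumptions):

def mlp_forward(x, layer_sizes=(32, 32)):
  for i, size in enumerate(layer_sizes):
    w, b = get_params(i, x, size)       # w has shape (x.shape[-1], size)
    x = jax.nn.relu(jnp.dot(x, w) + b)
  return x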
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: PRNGKey,
         sample: Optional[bool]=False,
         no_noise: Optional[bool]=False,
         **kwargs
  ) -> Mapping[str, jnp.ndarray]:
  # p(gamma|s) = N(gamma|mu(s), Sigma(s))
  if self.image_in:
    out_shape = self.input_shape[:-1] + (2*self.input_shape[-1],)
  else:
    out_shape = (2*self.big_dim,)
  self.p_gamma_given_s = vae.ParametrizedGaussian(out_shape=out_shape,
                                                  create_network=self.create_network,
                                                  network_kwargs=self.network_kwargs)

  #######################

  assert self.big_dim - self.small_dim > 0

  # Initialize the tall or wide matrix.  We might want to choose to parametrize a tall
  # matrix as the pseudo-inverse of a wide matrix or vice-versa.  B is wide and A is tall.
  init_fun = hk.initializers.RandomNormal(stddev=0.05)
  dtype = inputs["x"].dtype
  x = inputs["x"].reshape(self.batch_shape + (-1,))
  if self.reverse_params:
    if self.spectral_norm:
      self.B = init.weight_with_spectral_norm(x,
                                              self.small_dim,
                                              use_bias=False,
                                              w_init=init_fun,
                                              force_in_dim=self.big_dim,
                                              is_training=kwargs.get("is_training", True),
                                              update_params=kwargs.get("is_training", True))
    else:
      if self.weight_norm and self.kind == "tall":
        self.B = init.weight_with_weight_norm(x,
                                              self.small_dim,
                                              use_bias=False,
                                              force_in_dim=self.big_dim)
      else:
        self.B = hk.get_parameter("B", shape=(self.small_dim, self.big_dim), dtype=dtype, init=init_fun)
        self.B = util.whiten(self.B)
  else:
    if self.spectral_norm:
      self.A = init.weight_with_spectral_norm(x,
                                              self.big_dim,
                                              use_bias=False,
                                              w_init=init_fun,
                                              force_in_dim=self.small_dim,
                                              is_training=kwargs.get("is_training", True),
                                              update_params=kwargs.get("is_training", True))
    else:
      self.A = hk.get_parameter("A", shape=(self.big_dim, self.small_dim), dtype=dtype, init=init_fun)
      self.A = util.whiten(self.A)

  # Compute the riemannian metric matrix for later use.
  if self.reverse_params:
    self.BBT = [email protected]
    self.BBT_inv = jnp.linalg.inv(self.BBT)
  else:
    self.ATA = [email protected]
    self.ATA_inv = jnp.linalg.inv(self.ATA)

  #######################

  # Figure out which direction we should go
  if sample == False:
    big_to_small = self.kind == "tall"
  else:
    big_to_small = self.kind != "tall"

  #######################

  # Compute the next value
  if big_to_small:
    t = inputs["x"]

    # If we're going from image -> vector, we need to flatten the image
    if self.image_in:
      t = t.reshape(self.batch_shape + (-1,))

    # Compute the pseudo inverse and projection
    # s <- A^+ t
    s = self.pinv(t)
    t_proj = self.project(s=s)

    # Compute the perpendicular component of t for the log contribution
    # gamma_perp <- t - A A^+ t
    gamma_perp = t - t_proj

    # Find mu(s), Sigma(s).  If we have an image as input, pass in the projected input image
    # mu, Sigma <- NN(s, theta)
    _, mu, log_diag_cov = self.orthogonal_distribution(s, t_proj, rng, no_noise=True)

    # Compute the log contribution
    # L <- logZ(mu - gamma_perp | A, Sigma)
    likelihood_contribution = self.likelihood_contribution(mu,
                                                           gamma_perp,
                                                           log_diag_cov,
                                                           sample=sample,
                                                           big_to_small=big_to_small)

    outputs = {"x": s, "log_det": likelihood_contribution}

  else:
    s = inputs["x"]

    # Compute the mean of t.  Primarily used if we have an image as input
    t_mean = self.project(s=s)

    # Find mu(s), Sigma(s).  If we have an image as input, pass in the projected input image
    # mu, Sigma <- NN(s, theta)
    # gamma ~ N(mu, Sigma)
    gamma, mu, log_diag_cov = self.orthogonal_distribution(s, t_mean, rng, no_noise=no_noise)

    # Compute the orthogonal component of the noise
    # gamma_perp <- gamma - A A^+ gamma
    gamma_proj = self.project(t=gamma)
    gamma_perp = gamma - gamma_proj

    # Add the orthogonal features
    # t <- A s + gamma_perp
    t = t_mean + gamma_perp

    # Compute the log contribution
    # L <- logZ(mu - gamma_perp | A, Sigma)
    likelihood_contribution = -self.likelihood_contribution(mu,
                                                            gamma_perp,
                                                            log_diag_cov,
                                                            sample=sample,
                                                            big_to_small=big_to_small)

    # Reshape to an image if needed
    if self.image_in:
      t = t.reshape(self.batch_shape + self.input_shape)

    outputs = {"x": t, "log_det": likelihood_contribution}

  return outputs
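For reference, with a tall, full-column-rank A the pseudo-inverse used above is A^+ = (A^T A)^{-1} A^T, and the projection onto the column space of A is A A^+. A standalone sketch of what self.pinv and self.project compute for a single vector; the helper names and the reuse of the cached ATA_inv are assumptions based on the code above:

import jax.numpy as jnp

def pinv_apply(A, ATA_inv, t):
  # s = A^+ t = (A^T A)^{-1} A^T t
  return ATA_inv @ (A.T @ t)

def project_t(A, ATA_inv, t):
  # A A^+ t: the component of t lying in the column space of A
  return A @ pinv_apply(A, ATA_inv, t)

# gamma_perp = t - project_t(A, ATA_inv, t) is then orthogonal to the columns of A.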