Example #1
    def get_params(self, i, x, output_size):
        w_init = hk.initializers.VarianceScaling(
            scale=1.0, mode="fan_avg", distribution="truncated_normal")

        # Pass a singly batched input to the parameter functions.
        # Don't use autobatching here because we might end up reducing
        # over the wrong axes during the data-dependent initialization.
        x, reshape = self.make_singly_batched(x)

        if self.parameter_norm == "weight_norm":
            w, b = init.weight_with_weight_norm(x=x,
                                                out_dim=output_size,
                                                name_suffix=str(i),
                                            w_init=w_init,
                                                b_init=jnp.zeros,
                                                is_training=True,
                                                use_bias=True)
        elif self.parameter_norm == "spectral_norm":
            w, b = init.weight_with_spectral_norm(x=x,
                                                  out_dim=output_size,
                                                  name_suffix=str(i),
                                              w_init=w_init,
                                                  b_init=jnp.zeros,
                                                  is_training=True,
                                                  use_bias=True)
        else:
            w = hk.get_parameter(f"w_{i}", (output_size, x.shape[-1]),
                                 x.dtype,
                                 init=self.w_init)
            b = hk.get_parameter(f"b_{i}", (output_size, ), init=jnp.zeros)

        # x = reshape(x)  # not needed: only the parameters are returned

        return w.T, b
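
A minimal usage sketch of this pattern (a hypothetical TinyLinear module, assuming only Haiku and JAX; it mirrors the plain hk.get_parameter branch above, since the init.* helpers are not shown here):

import haiku as hk
import jax
import jax.numpy as jnp

class TinyLinear(hk.Module):
    """Hypothetical module mirroring the plain hk.get_parameter branch above."""

    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size
        self.w_init = hk.initializers.VarianceScaling(
            scale=1.0, mode="fan_avg", distribution="truncated_normal")

    def __call__(self, x):
        w = hk.get_parameter("w_0", (self.output_size, x.shape[-1]),
                             x.dtype, init=self.w_init)
        b = hk.get_parameter("b_0", (self.output_size,), x.dtype, init=jnp.zeros)
        # The weight is stored as (output_size, in_dim) and applied transposed,
        # matching the `return w.T, b` convention above.
        return x @ w.T + b

net = hk.transform(lambda x: TinyLinear(16)(x))
x = jnp.ones((8, 32))
params = net.init(jax.random.PRNGKey(0), x)
y = net.apply(params, None, x)   # y.shape == (8, 16)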
Example #2
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: Optional[jnp.ndarray] = None,
           sample: Optional[bool] = False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:
    x = inputs["x"]
    outputs = {}

    x_dim, dtype = x.shape[-1], x.dtype

    if self.weight_norm:
      W, b = init.weight_with_weight_norm(x,
                                          out_dim=x_dim,
                                          w_init=hk.initializers.RandomNormal(0.1),
                                          b_init=jnp.zeros,
                                          is_training=kwargs.get("is_training", False),
                                          use_bias=True)
    else:
      W_init = hk.initializers.TruncatedNormal(1/jnp.sqrt(x_dim))
      W = hk.get_parameter("W", shape=(x_dim, x_dim), dtype=dtype, init=W_init)
      b = hk.get_parameter("b", shape=(x_dim,), dtype=dtype, init=jnp.zeros)

    if not sample:
      outputs["x"] = jnp.dot(x, W.T) + b
    else:
      W_inv = jnp.linalg.inv(W)
      outputs["x"] = jnp.dot(x - b, W_inv.T)

    outputs["log_det"] = jnp.linalg.slogdet(W)[1]*jnp.ones(self.batch_shape)

    return outputs
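
The forward/inverse pair in call can be checked in isolation; a standalone sketch (plain JAX, no Haiku) of the no-weight-norm branch:

import jax.numpy as jnp
from jax import random

# forward:  y = x @ W.T + b;   inverse:  x = (y - b) @ inv(W).T
x_dim = 4
W = random.normal(random.PRNGKey(0), (x_dim, x_dim)) / jnp.sqrt(x_dim)
b = jnp.zeros(x_dim)
x = random.normal(random.PRNGKey(1), (8, x_dim))

y = jnp.dot(x, W.T) + b                        # sample == False branch
x_rec = jnp.dot(y - b, jnp.linalg.inv(W).T)    # sample == True branch
assert jnp.allclose(x, x_rec, atol=1e-4)

# Log-determinant of the Jacobian, broadcast once per batch element as above:
log_det = jnp.linalg.slogdet(W)[1]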
Example #3
def data_dependent_param_init(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Optional[Callable] = None,
                              b_init: Optional[Callable] = None,
                              is_training: bool = True,
                              parameter_norm: Optional[str] = None,
                              use_bias: bool = True,
                              update_params: bool = True,
                              **kwargs):

    if parameter_norm == "spectral_norm":
        return init.weight_with_spectral_norm(x=x,
                                              out_dim=out_dim,
                                              name_suffix=name_suffix,
                                              w_init=w_init,
                                              b_init=b_init,
                                              is_training=is_training,
                                              use_bias=use_bias,
                                              **kwargs)
    elif parameter_norm == "differentiable_spectral_norm":
        return init.weight_with_good_spectral_norm(x=x,
                                                   out_dim=out_dim,
                                                   name_suffix=name_suffix,
                                                   w_init=w_init,
                                                   b_init=b_init,
                                                   is_training=is_training,
                                                   update_params=update_params,
                                                   use_bias=use_bias,
                                                   **kwargs)

    elif parameter_norm == "weight_norm":
        if x.shape[0] > 1:
            return init.weight_with_weight_norm(x=x,
                                                out_dim=out_dim,
                                                name_suffix=name_suffix,
                                                w_init=w_init,
                                                b_init=b_init,
                                                is_training=is_training,
                                                use_bias=use_bias,
                                                **kwargs)

    elif parameter_norm is not None:
        raise ValueError("Invalid parameter_norm.  Expected 'spectral_norm', "
                         "'differentiable_spectral_norm' or 'weight_norm'.")

    in_dim, dtype = x.shape[-1], x.dtype

    w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim), init=w_init)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_dim, ), init=b_init)

    if use_bias:
        return w, b
    return w
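
A sketch of how this function might be driven through hk.transform on the default path (parameter_norm=None); the other branches need the accompanying init module, which is not shown here:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    w, b = data_dependent_param_init(x,
                                     out_dim=8,
                                     w_init=hk.initializers.TruncatedNormal(0.1),
                                     b_init=jnp.zeros,
                                     parameter_norm=None,
                                     use_bias=True)
    return x @ w.T + b

net = hk.transform(forward)
x = jnp.ones((4, 16))
params = net.init(jax.random.PRNGKey(0), x)
y = net.apply(params, None, x)   # y.shape == (4, 8)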
Example #4
  def get_params(i, x, output_size):
    w_init = hk.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")

    if self.parameter_norm == "weight_norm":
      w, b = init.weight_with_weight_norm(x=x,
                                          out_dim=output_size,
                                          name_suffix=str(i),
                                          w_init=w_init,
                                          b_init=jnp.zeros,
                                          is_training=True,
                                          use_bias=True)
    elif self.parameter_norm == "spectral_norm":
      w, b = init.weight_with_spectral_norm(x=x,
                                            out_dim=output_size,
                                            name_suffix=str(i),
                                            w_init=w_init,
                                            b_init=jnp.zeros,
                                            is_training=True,
                                            use_bias=True)
    else:
      w = hk.get_parameter(f"w_{i}", (output_size, x.shape[-1]), x.dtype, init=w_init)
      b = hk.get_parameter(f"b_{i}", (output_size,), init=jnp.zeros)

    return w.T, b
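
Since get_params returns the transposed weight, a caller can apply each layer with a plain matmul; a self-contained sketch of that calling convention (hypothetical mlp consumer, plain hk.get_parameter only):

import haiku as hk
import jax
import jax.numpy as jnp

# The index `i` keeps the hk.get_parameter names unique across layers.
def mlp(x, layer_sizes=(32, 16)):
    w_init = hk.initializers.VarianceScaling(
        scale=1.0, mode="fan_avg", distribution="truncated_normal")
    for i, size in enumerate(layer_sizes):
        w = hk.get_parameter(f"w_{i}", (size, x.shape[-1]), x.dtype, init=w_init)
        b = hk.get_parameter(f"b_{i}", (size,), x.dtype, init=jnp.zeros)
        x = jax.nn.relu(x @ w.T + b)
    return x

net = hk.transform(mlp)
params = net.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))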
Example #5
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: PRNGKey,
           sample: Optional[bool] = False,
           no_noise: Optional[bool] = False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:

    # p(gamma|s) = N(gamma|mu(s), Sigma(s))
    if self.image_in:
      out_shape = self.input_shape[:-1] + (2*self.input_shape[-1],)
    else:
      out_shape = (2*self.big_dim,)
    self.p_gamma_given_s = vae.ParametrizedGaussian(out_shape=out_shape,
                                                    create_network=self.create_network,
                                                    network_kwargs=self.network_kwargs)

    #######################
    assert self.big_dim > self.small_dim

    # Initialize the tall or wide matrix.  We might want to parametrize a tall
    # matrix as the pseudo-inverse of a wide matrix, or vice versa.  B is wide and A is tall.
    init_fun = hk.initializers.RandomNormal(stddev=0.05)
    dtype = inputs["x"].dtype
    if self.reverse_params:
      x = inputs["x"].reshape(self.batch_shape + (-1,))

      if self.spectral_norm:
        self.B = init.weight_with_spectral_norm(x,
                                                self.small_dim,
                                                use_bias=False,
                                                w_init=init_fun,
                                                force_in_dim=self.big_dim,
                                                is_training=kwargs.get("is_training", True),
                                                update_params=kwargs.get("is_training", True))
      else:
        if self.weight_norm and self.kind == "tall":
          self.B = init.weight_with_weight_norm(x, self.small_dim, use_bias=False, force_in_dim=self.big_dim)
        else:
          self.B = hk.get_parameter("B", shape=(self.small_dim, self.big_dim), dtype=dtype, init=init_fun)
        self.B = util.whiten(self.B)
    else:
      if self.spectral_norm:
        self.A = init.weight_with_spectral_norm(x,
                                                self.big_dim,
                                                use_bias=False,
                                                w_init=init_fun,
                                                force_in_dim=self.small_dim,
                                                is_training=kwargs.get("is_training", True),
                                                update_params=kwargs.get("is_training", True))
      else:
        self.A = hk.get_parameter("A", shape=(self.big_dim, self.small_dim), dtype=dtype, init=init_fun)
        self.A = util.whiten(self.A)

    # Compute the riemannian metric matrix for later use.
    if self.reverse_params:
      self.BBT     = self.B @ self.B.T
      self.BBT_inv = jnp.linalg.inv(self.BBT)
    else:
      self.ATA     = self.A.T @ self.A
      self.ATA_inv = jnp.linalg.inv(self.ATA)

    #######################

    # Figure out which direction we should go
    if not sample:
      big_to_small = self.kind == "tall"
    else:
      big_to_small = self.kind != "tall"

    #######################

    # Compute the next value
    if big_to_small:
      t = inputs["x"]

      # If we're going from image -> vector, we need to flatten the image
      if self.image_in:
        t = t.reshape(self.batch_shape + (-1,))

      # Compute the pseudo inverse and projection
      # s <- self.A^+t
      s = self.pinv(t)
      t_proj = self.project(s=s)

      # Compute the perpendicular component of t for the log contribution
      # gamma_perp <- t - AA^+t
      gamma_perp = t - t_proj

      # Find mu(s), Sigma(s).  If we have an image as input, pass in the projected input image
      # mu, Sigma <- NN(s, theta)
      _, mu, log_diag_cov = self.orthogonal_distribution(s, t_proj, rng, no_noise=True)

      # Compute the log contribution
      # L <- logZ(mu - gamma_perp|self.A, Sigma)
      likelihood_contribution = self.likelihood_contribution(mu, gamma_perp, log_diag_cov, sample=sample, big_to_small=big_to_small)

      outputs = {"x": s, "log_det": likelihood_contribution}

    else:
      s = inputs["x"]

      # Compute the mean of t.  Primarily used if we have an image as input
      t_mean = self.project(s=s)

      # Find mu(s), Sigma(s).  If we have an image as input, pass in the projected input image
      # mu, Sigma <- NN(s, theta)
      # gamma ~ N(mu, Sigma)
      gamma, mu, log_diag_cov = self.orthogonal_distribution(s, t_mean, rng, no_noise=no_noise)

      # Compute the orthogonal component of the noise
      # gamma_perp <- gamma - AA^+ gamma
      gamma_proj = self.project(t=gamma)
      gamma_perp = gamma - gamma_proj

      # Add the orthogonal features
      # t <- As + gamma_perp
      t = t_mean + gamma_perp

      # Compute the log contribution
      # L <- logZ(mu - gamma_perp|self.A, Sigma)
      likelihood_contribution = -self.likelihood_contribution(mu, gamma_perp, log_diag_cov, sample=sample, big_to_small=big_to_small)

      # Reshape to an image if needed
      if self.image_in:
        t = t.reshape(self.batch_shape + self.input_shape)

      outputs = {"x": t, "log_det": likelihood_contribution}

    return outputs
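
The projection step above (s <- A^+ t, t_proj <- A s, gamma_perp <- t - t_proj) can be verified in isolation; a standalone sketch with a hypothetical tall matrix:

import jax.numpy as jnp
from jax import random

# With a tall matrix A: s = A^+ t (pseudo-inverse), t_proj = A s (projection
# onto the column space of A), and gamma_perp = t - t_proj is orthogonal to it.
big_dim, small_dim = 8, 3
A = random.normal(random.PRNGKey(0), (big_dim, small_dim))
t = random.normal(random.PRNGKey(1), (big_dim,))

s = jnp.linalg.pinv(A) @ t    # s <- A^+ t
t_proj = A @ s                # t_proj <- A A^+ t
gamma_perp = t - t_proj       # perpendicular component

# gamma_perp lies in the null space of A^T, so A^T @ gamma_perp is ~0:
assert jnp.allclose(A.T @ gamma_perp, 0.0, atol=1e-4)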