Example #1
  def get_normalized_weights(self,
                             weights: jnp.ndarray,
                             renormalize: bool = False) -> jnp.ndarray:

    def _l2_normalize(x, axis=None, eps=1e-12):
      return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

    output_size = self.output_size
    dtype = weights.dtype
    assert output_size == weights.shape[-1]
    # sigma persists across calls; when renormalize is False the most recent
    # estimate (initially 1) is reused.
    sigma = hk.get_state('sigma', (), init=jnp.ones)
    if renormalize:
      # Power iteration to estimate the spectral norm sigma = v @ W @ u^T
      # (assumes self.num_iterations >= 1, otherwise v below is undefined).
      u = hk.get_state(
          'u', (1, output_size), dtype, init=hk.initializers.RandomNormal())
      for _ in range(self.num_iterations):
        v = _l2_normalize(jnp.matmul(u, weights.transpose()), eps=self.eps)
        u = _l2_normalize(jnp.matmul(v, weights), eps=self.eps)
      # Treat the singular-vector estimates as constants during backprop.
      u = jax.lax.stop_gradient(u)
      v = jax.lax.stop_gradient(v)
      sigma = jnp.matmul(jnp.matmul(v, weights), jnp.transpose(u))[0, 0]
      # Persist the estimates; the next call's power iteration resumes from u.
      hk.set_state('u', u)
      hk.set_state('v', v)
      hk.set_state('sigma', sigma)
    # Rescale only when the estimated norm exceeds the target Lipschitz bound.
    factor = jnp.maximum(1, sigma / self.lipschitz_coeff)
    return weights / factor
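
For intuition, here is the power-iteration step lifted out of the Haiku module into a standalone, runnable sketch. The weight shape and the iteration count are arbitrary choices for illustration; jnp.linalg.norm(..., ord=2) gives the exact spectral norm to compare against.

import jax
import jax.numpy as jnp

def _l2_normalize(x, axis=None, eps=1e-12):
    return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

weights = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
u = jax.random.normal(jax.random.PRNGKey(1), (1, 4))  # (1, output_size)
for _ in range(50):  # more iterations than the module would typically run
    v = _l2_normalize(jnp.matmul(u, weights.transpose()))
    u = _l2_normalize(jnp.matmul(v, weights))
sigma = jnp.matmul(jnp.matmul(v, weights), jnp.transpose(u))[0, 0]
print(sigma, jnp.linalg.norm(weights, ord=2))  # the two should agree closely

# The final rescaling caps the spectral norm at lipschitz_coeff (1.0 here):
factor = jnp.maximum(1, sigma / 1.0)
print(jnp.linalg.norm(weights / factor, ord=2))  # <= 1.0 up to estimation error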
Example #2
def transp(a: jnp.ndarray) -> jnp.ndarray:
    """Returns transposed matrix.

    Args:
        a: tensor of shape (..., n1, n2)

    Returns:
        tensor of shape (..., n2, n1)"""

    matrix_shape = a.shape[-2:]
    bs_shape = a.shape[:-2]
    # Flatten the batch dims, swap the last two axes, then restore the batch.
    a = a.reshape((-1, *matrix_shape))
    a = a.transpose((0, 2, 1))
    a = a.reshape((*bs_shape, matrix_shape[1], matrix_shape[0]))
    return a
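
A quick usage check with an arbitrary batched shape; note that jnp.swapaxes(a, -2, -1) performs the same batched transpose in a single call, so the reshape round-trip above is explicit bookkeeping rather than a necessity.

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (5, 3, 2, 4))  # batch shape (5, 3)
b = transp(a)
print(b.shape)  # (5, 3, 4, 2)
assert jnp.allclose(b, jnp.swapaxes(a, -2, -1))  # single-call equivalent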
Example #3
    def __call__(self, x: jnp.ndarray):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.transpose((0, 2, 1))  # shape = [*, grid ** 2, width]
        # Prepend a learned class token, broadcast across the batch.
        class_token = self.class_embedding + jnp.zeros(
            (x.shape[0], 1, x.shape[-1]), dtype=x.dtype)
        x = jnp.concatenate([class_token, x],
                            axis=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding

        x = self.ln_pre(x)
        x = x.transpose((1, 0, 2))  # NLD -> LND

        x = self.transformer(x)
        x = x.transpose((1, 0, 2))  # LND -> NLD

        x = self.ln_post(x[:, 0, :])  # keep only the class-token representation

        if self.proj is not None:
            x = x @ self.proj

        return x
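
The shape bookkeeping is the subtle part of this forward pass. Below is a standalone sketch of just the flatten / transpose / class-token steps with made-up sizes; conv1, the transformer, the layer norms, and the projection are stubbed out, so every name and size here is a placeholder.

import jax.numpy as jnp

batch, width, grid = 2, 64, 7                       # illustrative sizes only
x = jnp.ones((batch, width, grid, grid))            # stand-in for conv1 output
x = x.reshape(x.shape[0], x.shape[1], -1)           # [batch, width, grid ** 2]
x = x.transpose((0, 2, 1))                          # [batch, grid ** 2, width]
class_embedding = jnp.zeros((width,))               # stand-in for the parameter
x = jnp.concatenate(
    [class_embedding + jnp.zeros((batch, 1, width), dtype=x.dtype), x], axis=1)
print(x.shape)  # (2, 50, 64): grid ** 2 + 1 = 50 tokens per image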
Example #4
def adj(a: jnp.ndarray) -> jnp.ndarray:
    """Returns adjoint matrix.

    Args:
        a: complex valued tensor of shape (..., n1, n2)

    Returns:
        complex valued tensor of shape (..., n2, n1)"""

    matrix_shape = a.shape[-2:]
    bs_shape = a.shape[:-2]
    # Flatten the batch dims, swap the last two axes, then restore the batch.
    a = a.reshape((-1, *matrix_shape))
    a = a.transpose((0, 2, 1))
    a = a.reshape((*bs_shape, matrix_shape[1], matrix_shape[0]))
    # Conjugate entrywise to complete the Hermitian adjoint.
    a = a.conj()
    return a
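
A usage check on a random complex batch (shapes chosen arbitrarily): adj agrees with the one-call conjugate transpose jnp.conj(jnp.swapaxes(a, -2, -1)), and for a unitary matrix it acts as the inverse.

import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(k1, (3, 2, 4)) + 1j * jax.random.normal(k2, (3, 2, 4))
print(adj(a).shape)  # (3, 4, 2)
assert jnp.allclose(adj(a), jnp.conj(jnp.swapaxes(a, -2, -1)))

# For a unitary q, the adjoint is the inverse: q @ adj(q) is the identity.
q, _ = jnp.linalg.qr(jax.random.normal(k1, (4, 4))
                     + 1j * jax.random.normal(k2, (4, 4)))
assert jnp.allclose(q @ adj(q), jnp.eye(4), atol=1e-5)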