Example No. 1
 def __call__(self) -> Array:
     initial_values = hk.get_parameter("initial_values", [self.output_size],
                                       init=hk.initializers.Constant(0.))
     return initial_values
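A minimal usage sketch (added here for context, not part of the scraped example): hk.get_parameter only works inside an hk.transform, and the __call__ above presumably belongs to an hk.Module subclass that stores output_size. The Bias wrapper name below is a hypothetical stand-in.

 import haiku as hk
 import jax

 class Bias(hk.Module):  # hypothetical wrapper around the __call__ above
     def __init__(self, output_size, name=None):
         super().__init__(name=name)
         self.output_size = output_size

     def __call__(self):
         return hk.get_parameter("initial_values", [self.output_size],
                                 init=hk.initializers.Constant(0.))

 # init materializes the parameter; the params tree is keyed by module name.
 init, apply = hk.transform(lambda: Bias(output_size=3)())
 params = init(jax.random.PRNGKey(0))
 # params looks like {'bias': {'initial_values': float32[3] of zeros}}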
Example No. 2
 def __call__(self) -> Array:
     likelihood = hk.get_parameter(
         "likelihood", [self.output_size],
         init=hk.initializers.RandomNormal(mean=-5))
     return aux_math.diag(jax.nn.softplus(likelihood))
Example No. 3
 def field(x, aux, dropout: bool = False):
     mlp = nets.MLP(self.sizes)
     scale = hk.get_parameter("scale", (),
                              init=lambda *args: np.ones(*args))
     mlp_input = np.concatenate([x, aux]) if self.aux else x
     return scale * mlp(mlp_input, dropout)
Example No. 4
    def __call__(self, inputs_1d, inputs_2d, mask, affine):
        """Compute geometry-aware attention.

    Given a set of query residues (defined by affines and associated scalar
    features), this function computes geometry-aware attention between the
    query residues and target residues.

    The residues produce points in their local reference frame, which
    are converted into the global frame in order to compute attention via
    euclidean distance.

    Equivalently, the target residues produce points in their local frame to be
    used as attention values, which are converted into the query residues'
    local frames.

    Args:
      inputs_1d: (N, C) 1D input embedding that is the basis for the
        scalar queries.
      inputs_2d: (N, M, C') 2D input embedding, used for biases and values.
      mask: (N, 1) mask to indicate which elements of inputs_1d participate
        in the attention.
      affine: QuatAffine object describing the position and orientation of
        every element in inputs_1d.

    Returns:
      Transformation of the input embedding.
    """
        num_residues, _ = inputs_1d.shape

        # Improve readability by removing a large number of 'self's.
        num_head = self.config.num_head
        num_scalar_qk = self.config.num_scalar_qk
        num_point_qk = self.config.num_point_qk
        num_scalar_v = self.config.num_scalar_v
        num_point_v = self.config.num_point_v
        num_output = self.config.num_channel

        assert num_scalar_qk > 0
        assert num_point_qk > 0
        assert num_point_v > 0

        # Construct scalar queries of shape:
        # [num_query_residues, num_head, num_points]
        q_scalar = common_modules.Linear(num_head * num_scalar_qk,
                                         name='q_scalar')(inputs_1d)
        q_scalar = jnp.reshape(q_scalar,
                               [num_residues, num_head, num_scalar_qk])

        # Construct scalar keys/values of shape:
        # [num_target_residues, num_head, num_points]
        kv_scalar = common_modules.Linear(num_head *
                                          (num_scalar_v + num_scalar_qk),
                                          name='kv_scalar')(inputs_1d)
        kv_scalar = jnp.reshape(
            kv_scalar, [num_residues, num_head, num_scalar_v + num_scalar_qk])
        k_scalar, v_scalar = jnp.split(kv_scalar, [num_scalar_qk], axis=-1)

        # Construct query points of shape:
        # [num_residues, num_head, num_point_qk]

        # First construct query points in local frame.
        q_point_local = common_modules.Linear(num_head * 3 * num_point_qk,
                                              name='q_point_local')(inputs_1d)
        q_point_local = jnp.split(q_point_local, 3, axis=-1)
        # Project query points into global frame.
        q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
        # Reshape query point for later use.
        q_point = [
            jnp.reshape(x, [num_residues, num_head, num_point_qk])
            for x in q_point_global
        ]

        # Construct key and value points.
        # Key points have shape [num_residues, num_head, num_point_qk]
        # Value points have shape [num_residues, num_head, num_point_v]

        # Construct key and value points in local frame.
        kv_point_local = common_modules.Linear(
            num_head * 3 * (num_point_qk + num_point_v),
            name='kv_point_local')(inputs_1d)
        kv_point_local = jnp.split(kv_point_local, 3, axis=-1)
        # Project key and value points into global frame.
        kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
        kv_point_global = [
            jnp.reshape(x,
                        [num_residues, num_head, (num_point_qk + num_point_v)])
            for x in kv_point_global
        ]
        # Split key and value points.
        k_point, v_point = list(
            zip(*[
                jnp.split(x, [
                    num_point_qk,
                ], axis=-1) for x in kv_point_global
            ]))

        # We assume that all queries and keys come iid from N(0, 1) distribution
        # and compute the variances of the attention logits.
        # Each scalar pair (q, k) contributes Var q*k = 1
        scalar_variance = max(num_scalar_qk, 1) * 1.
        # Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
        point_variance = max(num_point_qk, 1) * 9. / 2

        # Allocate equal variance to scalar, point and attention 2d parts so that
        # the sum is 1.

        num_logit_terms = 3

        scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
        point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
        attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))

        # Trainable per-head weights for points.
        trainable_point_weights = jax.nn.softplus(
            hk.get_parameter(
                'trainable_point_weights',
                shape=[num_head],
                # softplus^{-1} (1)
                init=hk.initializers.Constant(np.log(np.exp(1.) - 1.))))
        point_weights *= jnp.expand_dims(trainable_point_weights, axis=1)

        v_point = [jnp.swapaxes(x, -2, -3) for x in v_point]

        q_point = [jnp.swapaxes(x, -2, -3) for x in q_point]
        k_point = [jnp.swapaxes(x, -2, -3) for x in k_point]
        dist2 = [
            squared_difference(qx[:, :, None, :], kx[:, None, :, :])
            for qx, kx in zip(q_point, k_point)
        ]
        dist2 = sum(dist2)
        attn_qk_point = -0.5 * jnp.sum(point_weights[:, None, None, :] * dist2,
                                       axis=-1)

        v = jnp.swapaxes(v_scalar, -2, -3)
        q = jnp.swapaxes(scalar_weights * q_scalar, -2, -3)
        k = jnp.swapaxes(k_scalar, -2, -3)
        attn_qk_scalar = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
        attn_logits = attn_qk_scalar + attn_qk_point

        attention_2d = common_modules.Linear(num_head,
                                             name='attention_2d')(inputs_2d)

        attention_2d = jnp.transpose(attention_2d, [2, 0, 1])
        attention_2d = attention_2d_weights * attention_2d
        attn_logits += attention_2d

        mask_2d = mask * jnp.swapaxes(mask, -1, -2)
        attn_logits -= 1e5 * (1. - mask_2d)

        # [num_head, num_query_residues, num_target_residues]
        attn = jax.nn.softmax(attn_logits)

        # [num_head, num_query_residues, num_head * num_scalar_v]
        result_scalar = jnp.matmul(attn, v)

        # For point result, implement matmul manually so that it will be a float32
        # on TPU.  This is equivalent to
        # result_point_global = [jnp.einsum('bhqk,bhkc->bhqc', attn, vx)
        #                        for vx in v_point]
        # but on the TPU, doing the multiply and reduce_sum ensures the
        # computation happens in float32 instead of bfloat16.
        result_point_global = [
            jnp.sum(attn[:, :, :, None] * vx[:, None, :, :], axis=-2)
            for vx in v_point
        ]

        # [num_query_residues, num_head, num_head * num_(scalar|point)_v]
        result_scalar = jnp.swapaxes(result_scalar, -2, -3)
        result_point_global = [
            jnp.swapaxes(x, -2, -3) for x in result_point_global
        ]

        # Features used in the linear output projection. Should have the size
        # [num_query_residues, ?]
        output_features = []

        result_scalar = jnp.reshape(result_scalar,
                                    [num_residues, num_head * num_scalar_v])
        output_features.append(result_scalar)

        result_point_global = [
            jnp.reshape(r, [num_residues, num_head * num_point_v])
            for r in result_point_global
        ]
        result_point_local = affine.invert_point(result_point_global,
                                                 extra_dims=1)
        output_features.extend(result_point_local)

        output_features.append(
            jnp.sqrt(self._dist_epsilon + jnp.square(result_point_local[0]) +
                     jnp.square(result_point_local[1]) +
                     jnp.square(result_point_local[2])))

        # Dimensions: h = heads, i and j = residues,
        # c = inputs_2d channels
        # Contraction happens over the second residue dimension, similarly to how
        # the usual attention is performed.
        result_attention_over_2d = jnp.einsum('hij, ijc->ihc', attn, inputs_2d)
        num_out = num_head * result_attention_over_2d.shape[-1]
        output_features.append(
            jnp.reshape(result_attention_over_2d, [num_residues, num_out]))

        final_init = 'zeros' if self._zero_initialize_last else 'linear'

        final_act = jnp.concatenate(output_features, axis=-1)

        return common_modules.Linear(num_output,
                                     initializer=final_init,
                                     name='output_projection')(final_act)
Example No. 5
 def fn(x):
     return x * hk.get_parameter('p', [],
                                 init=hk.initializers.Constant(0.))
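A sketch of how a top-level function like fn is typically transformed and differentiated (an illustrative addition, not from the source; the loss and data below are assumptions). Note that a parameter created outside any hk.Module lands under Haiku's "~" scope.

 import haiku as hk
 import jax
 import jax.numpy as jnp

 def fn(x):
     return x * hk.get_parameter('p', [],
                                 init=hk.initializers.Constant(0.))

 forward = hk.without_apply_rng(hk.transform(fn))
 x = jnp.array([1.0, 2.0, 3.0])
 params = forward.init(jax.random.PRNGKey(0), x)

 def loss(params, x):
     # Illustrative squared-error target: pull fn(x) towards x.
     return jnp.mean((forward.apply(params, x) - x) ** 2)

 grads = jax.grad(loss)(params, x)  # same tree structure as params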
Example No. 6
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: PRNGKey,
             sample: Optional[bool] = False,
             reconstruction: Optional[bool] = False,
             manifold_sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        assert len(
            self.unbatched_input_shapes["x"]) == 1, "Only works with 1d inputs"
        assert self.z_dim < self.x_dim

        dtype = inputs["x"].dtype
        init_fun = hk.initializers.RandomNormal(0.01)
        A = hk.get_parameter("A",
                             shape=(self.x_dim, self.z_dim),
                             dtype=dtype,
                             init=init_fun)
        b = hk.get_parameter("b",
                             shape=(self.x_dim, ),
                             dtype=dtype,
                             init=init_fun)
        log_diag_cov = hk.get_parameter("log_diag_cov",
                                        shape=(self.x_dim, ),
                                        dtype=dtype,
                                        init=init_fun)
        diag_cov = jnp.exp(log_diag_cov)

        # Go from x -> z or z -> x
        if sample == False:
            x = inputs["x"]
            x -= b

            # Compute the posterior natural parameters
            J = jnp.eye(self.z_dim) + (A.T / diag_cov) @ A
            J_inv = jnp.linalg.inv(J)
            sigma_inv_x = x / diag_cov
            h = jnp.dot(sigma_inv_x, A)

            # Compute the posterior parameters
            Sigma_z = J_inv
            mu_z = jnp.dot(h, Sigma_z)

            # Sample z
            Sigma_z_chol = jnp.linalg.cholesky(Sigma_z)
            noise = random.normal(rng, mu_z.shape)
            z = mu_z + jnp.dot(noise, Sigma_z_chol.T)

            # Compute the log likelihood contribution
            J_inv_h = jnp.dot(h, J_inv.T)

            llc = 0.5 * jnp.sum(h * J_inv_h, axis=-1)
            llc -= 0.5 * jnp.linalg.slogdet(J)[1]
            llc -= 0.5 * jnp.sum(x * sigma_inv_x, axis=-1)
            llc -= 0.5 * log_diag_cov.sum()
            llc -= 0.5 * self.x_dim * jnp.log(2 * jnp.pi)

            outputs = {"x": z, "log_pz": llc}

        else:
            k1, k2 = random.split(rng, 2)
            z = inputs["x"]

            if reconstruction == False:
                z = random.normal(k1, z.shape)

            # Sample x
            mu_x = jnp.dot(z, A.T) + b
            if manifold_sample == True:
                noise = random.normal(k2, mu_x.shape)
            else:
                noise = jnp.zeros_like(mu_x)

            x = mu_x + jnp.sqrt(diag_cov) * noise

            # If we're doing a reconstruction, we need to compute log p(x|z)
            llc = -0.5 * jnp.sum(noise**2, axis=-1)
            llc -= 0.5 * jnp.sum(log_diag_cov)
            llc -= 0.5 * self.x_dim * jnp.log(2 * jnp.pi)

            outputs = {"x": x, "log_pz": llc}

        return outputs
Example No. 7
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: jnp.ndarray = None,
             sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        outputs = {}

        dim, dtype = inputs["x"].shape[-1], inputs["x"].dtype

        L = hk.get_parameter("L",
                             shape=(dim, dim),
                             dtype=dtype,
                             init=hk.initializers.RandomNormal(0.01))
        U = hk.get_parameter("U",
                             shape=(dim, dim),
                             dtype=dtype,
                             init=hk.initializers.RandomNormal(0.01))
        log_d = hk.get_parameter("log_d",
                                 shape=(dim, ),
                                 dtype=dtype,
                                 init=jnp.zeros)
        lower_mask = jnp.ones((dim, dim), dtype=bool)
        lower_mask = jax.ops.index_update(lower_mask, jnp.triu_indices(dim),
                                          False)

        if self.safe_diag:
            d = util.proximal_relu(log_d) + 1e-5
            log_d = jnp.log(d)

        if self.use_bias:

            def b_init(shape, dtype):
                x = inputs["x"]
                if x.ndim == 1:
                    return jnp.zeros(shape, dtype=dtype)

                # Initialize to the batch mean
                z = jnp.dot(x, (U * lower_mask.T).T) + x
                z *= jnp.exp(log_d)
                z = jnp.dot(z, (L * lower_mask).T) + z
                b = -jnp.mean(z, axis=0)
                return b

            b = hk.get_parameter("b", shape=(dim, ), dtype=dtype, init=b_init)

        # It's way faster to allocate a full matrix for L and U and then mask than it
        # is to allocate only the lower/upper parts and then reshape.
        if sample == False:
            x = inputs["x"]
            z = jnp.dot(x, (U * lower_mask.T).T) + x
            z *= jnp.exp(log_d)
            z = jnp.dot(z, (L * lower_mask).T) + z
            outputs["x"] = z
            if self.use_bias:
                outputs["x"] += b
        else:
            z = inputs["x"]

            @self.auto_batch
            def invert(z):
                if self.use_bias:
                    x = L_solve(L, z - b)
                else:
                    x = L_solve(L, z)
                x = x * jnp.exp(-log_d)
                return U_solve(U, x)

            outputs["x"] = invert(z)

        outputs["log_det"] = jnp.sum(log_d, axis=-1) * jnp.ones(
            self.batch_shape)
        return outputs
Example No. 8
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: jnp.ndarray = None,
             sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        outputs = {}
        x = inputs["x"]
        height, width, channel = x.shape[-3:]

        # Using lax.conv instead of matrix multiplication over the channel dimension
        # is faster and also more numerically stable for some reason.
        @partial(self.auto_batch, in_axes=(None, 0), expected_depth=1)
        def conv(W, x):
            return jax.lax.conv_general_dilated(x,
                                                W[None, None, ...], (1, 1),
                                                'SAME', (1, 1), (1, 1),
                                                dimension_numbers=('NHWC',
                                                                   'HWIO',
                                                                   'NHWC'))

        dtype = x.dtype
        W = hk.get_parameter("W",
                             shape=(channel, channel),
                             dtype=dtype,
                             init=self.W_init)

        # Initialize with weight norm https://arxiv.org/pdf/1602.07868.pdf
        # This seems to improve performance.
        if self.weight_norm and x.ndim > 3:
            W *= jax.lax.rsqrt(jnp.sum(W**2, axis=0))

            def g_init(shape, dtype):
                t = conv(W, x)
                g = 1 / (jnp.std(t, axis=(0, 1, 2)) + 1e-5)
                return g

            def b_init(shape, dtype):
                t = conv(W, x)
                return -jnp.mean(t, axis=(0, 1, 2)) / (
                    jnp.std(t, axis=(0, 1, 2)) + 1e-5)

            g = hk.get_parameter("g", (channel, ), dtype, init=g_init)
            b = hk.get_parameter("b", (channel, ), dtype, init=b_init)

            W *= g

        else:
            b = hk.get_parameter("b",
                                 shape=(channel, ),
                                 dtype=dtype,
                                 init=jnp.zeros)

        # Run the flow
        if sample == False:
            z = conv(W, x)
            outputs["x"] = z + b
        else:
            W_inv = jnp.linalg.inv(W)
            outputs["x"] = conv(W_inv, x - b)

        outputs["log_det"] = jnp.linalg.slogdet(
            W)[1] * height * width * jnp.ones(self.batch_shape)

        return outputs
Example No. 9
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: PRNGKey,
           sample: Optional[bool]=False,
           no_noise: Optional[bool]=False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:

    # p(gamma|s) = N(gamma|mu(s), Sigma(s))
    if self.image_in:
      out_shape = self.input_shape[:-1] + (2*self.input_shape[-1],)
    else:
      out_shape = (2*self.big_dim,)
    self.p_gamma_given_s = vae.ParametrizedGaussian(out_shape=out_shape,
                                                    create_network=self.create_network,
                                                    network_kwargs=self.network_kwargs)

    #######################
    assert self.big_dim - self.small_dim > 0

    # Initialize the tall or wide matrix.  We might want to choose to parametrize a tall
    # matrix as the pseudo-inverse of a wide matrix or vice-versa.  B is wide and A is tall.
    init_fun = hk.initializers.RandomNormal(stddev=0.05)
    dtype = inputs["x"].dtype
    if self.reverse_params:
      x = inputs["x"].reshape(self.batch_shape + (-1,))

      if self.spectral_norm:
        self.B = init.weight_with_spectral_norm(x,
                                                self.small_dim,
                                                use_bias=False,
                                                w_init=init_fun,
                                                force_in_dim=self.big_dim,
                                                is_training=kwargs.get("is_training", True),
                                                update_params=kwargs.get("is_training", True))
      else:
        if self.weight_norm and self.kind == "tall":
          self.B = init.weight_with_weight_norm(x, self.small_dim, use_bias=False, force_in_dim=self.big_dim)
        else:
          self.B = hk.get_parameter("B", shape=(self.small_dim, self.big_dim), dtype=dtype, init=init_fun)
        self.B = util.whiten(self.B)
    else:
      if self.spectral_norm:
        self.A = init.weight_with_spectral_norm(x,
                                                self.big_dim,
                                                use_bias=False,
                                                w_init=init_fun,
                                                force_in_dim=self.small_dim,
                                                is_training=kwargs.get("is_training", True),
                                                update_params=kwargs.get("is_training", True))
      else:
        self.A = hk.get_parameter("A", shape=(self.big_dim, self.small_dim), dtype=dtype, init=init_fun)
        self.A = util.whiten(self.A)

    # Compute the riemannian metric matrix for later use.
    if self.reverse_params:
      self.BBT     = self.B @ self.B.T
      self.BBT_inv = jnp.linalg.inv(self.BBT)
    else:
      self.ATA     = self.A.T @ self.A
      self.ATA_inv = jnp.linalg.inv(self.ATA)

    #######################

    # Figure out which direction we should go
    if sample == False:
      big_to_small = True if self.kind == "tall" else False
    else:
      big_to_small = False if self.kind == "tall" else True

    #######################

    # Compute the next value
    if big_to_small:
      t = inputs["x"]

      # If we're going from image -> vector, we need to flatten the image
      if self.image_in:
        t = t.reshape(self.batch_shape + (-1,))

      # Compute the pseudo inverse and projection
      # s <- self.A^+t
      s = self.pinv(t)
      t_proj = self.project(s=s)

      # Compute the perpendicular component of t for the log contribution
      # gamma_perp <- t - AA^+t
      gamma_perp = t - t_proj

      # Find mu(s), Sigma(s).  If we have an image as input, pass in the projected input image
      # mu, Sigma <- NN(s, theta)
      _, mu, log_diag_cov = self.orthogonal_distribution(s, t_proj, rng, no_noise=True)

      # Compute the log contribution
      # L <- logZ(mu - gamma_perp|self.A, Sigma)
      likelihood_contribution = self.likelihood_contribution(mu, gamma_perp, log_diag_cov, sample=sample, big_to_small=big_to_small)

      outputs = {"x": s, "log_det": likelihood_contribution}

    else:
      s = inputs["x"]

      # Compute the mean of t.  Primarily used if we have an image as input
      t_mean = self.project(s=s)

      # Find mu(s), Sigma(s).  If we have an image as input, pass in the projected input image
      # mu, Sigma <- NN(s, theta)
      # gamma ~ N(mu, Sigma)
      gamma, mu, log_diag_cov = self.orthogonal_distribution(s, t_mean, rng, no_noise=no_noise)

      # Compute the orthogonal component of the noise
      # gamma_perp <- gamma - AA^+ gamma
      gamma_proj = self.project(t=gamma)
      gamma_perp = gamma - gamma_proj

      # Add the orthogonal features
      # t <- As + gamma_perp
      t = t_mean + gamma_perp

      # Compute the log contribution
      # L <- logZ(mu - gamma_perp|self.A, Sigma)
      likelihood_contribution = -self.likelihood_contribution(mu, gamma_perp, log_diag_cov, sample=sample, big_to_small=big_to_small)

      # Reshape to an image if needed
      if self.image_in:
        t = t.reshape(self.batch_shape + self.input_shape)

      outputs = {"x": t, "log_det": likelihood_contribution}

    return outputs
Example No. 10
 def mass_matrix_inv_mul(self, q: jnp.ndarray, v: jnp.ndarray,
                         **kwargs) -> jnp.ndarray:
     """Computes the product of the inverse mass matrix with a vector."""
     if self.kinetic_func_form in ("separable_net", "dep_net"):
         raise ValueError(
             "It is not possible to compute `M^-1 p` when using a "
             "network for the kinetic energy.")
     if self.kinetic_func_form in ("pure_quad", "embed_quad"):
         return v
     if self.kinetic_func_form == "matrix_diag_quad":
         if self.parametrize_mass_matrix:
             m_diag_log = hk.get_parameter(
                 "MassMatrixDiagLog",
                 shape=[self.system_dim],
                 init=hk.initializers.Constant(0.0))
             m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps)
         else:
             m_inv_diag_log = hk.get_parameter(
                 "InvMassMatrixDiagLog",
                 shape=[self.system_dim],
                 init=hk.initializers.Constant(0.0))
             m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps
         return m_inv_diag * v
     if self.kinetic_func_form == "matrix_quad":
         if self.parametrize_mass_matrix:
             m_triu = hk.get_parameter(
                 "MassMatrixU",
                 shape=[self.system_dim, self.system_dim],
                 init=hk.initializers.Identity())
             m_triu = jnp.triu(m_triu)
             m = jnp.matmul(m_triu.T, m_triu)
             m = m + self.mass_eps * jnp.eye(self.system_dim)
             solve = jnp.linalg.solve
             for _ in range(v.ndim + 1 - m.ndim):
                 solve = jax.vmap(solve, in_axes=(None, 0))
             return solve(m, v)
         else:
             m_inv_triu = hk.get_parameter(
                 "InvMassMatrixU",
                 shape=[self.system_dim, self.system_dim],
                 init=hk.initializers.Identity())
             m_inv_triu = jnp.triu(m_inv_triu)
             m_inv = jnp.matmul(m_inv_triu.T, m_inv_triu)
             m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim)
             return self.feature_matrix_vector(m_inv, v)
     if self.kinetic_func_form in ("matrix_dep_diag_quad",
                                   "matrix_dep_diag_embed_quad"):
         if self.parametrize_mass_matrix:
             m_diag_log = self.mass_matrix_net(q, **kwargs)
             m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps)
         else:
             m_inv_diag_log = self.mass_matrix_net(q, **kwargs)
             m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps
         return m_inv_diag * v
     if self.kinetic_func_form in ("matrix_dep_quad",
                                   "matrix_dep_embed_quad"):
         if self.parametrize_mass_matrix:
             m_triu = self.mass_matrix_net(q, **kwargs)
             m_triu = utils.triu_matrix_from_v(m_triu, self.system_dim)
             m = jnp.matmul(jnp.swapaxes(m_triu, -2, -1), m_triu)
             m = m + self.mass_eps * jnp.eye(self.system_dim)
             return jnp.linalg.solve(m, v)
         else:
             m_inv_triu = self.mass_matrix_net(q, **kwargs)
             m_inv_triu = utils.triu_matrix_from_v(m_inv_triu,
                                                   self.system_dim)
             m_inv = jnp.matmul(jnp.swapaxes(m_inv_triu, -2, -1),
                                m_inv_triu)
             m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim)
             return self.feature_matrix_vector(m_inv, v)
     raise NotImplementedError()
Example No. 11
def apply_sn(*,
             mvp,
             mvpT,
             w_shape,
             b_shape,
             out_shape,
             dtype,
             w_init,
             b_init,
             name_suffix,
             is_training,
             use_bias,
             max_singular_value,
             max_power_iters,
             use_proximal_gradient=False,
             monitor_progress=False,
             monitor_iters=20,
             return_sigma=False,
             **kwargs):

    w_exists = util.check_if_parameter_exists(f"w_{name_suffix}")

    w = hk.get_parameter(f"w_{name_suffix}", w_shape, dtype, init=w_init)
    u = hk.get_state(f"u_{name_suffix}",
                     out_shape,
                     dtype,
                     init=hk.initializers.RandomNormal())
    if use_proximal_gradient == False:
        zeta = hk.get_state(f"zeta_{name_suffix}",
                            out_shape,
                            dtype,
                            init=hk.initializers.RandomNormal())
        state = (u, zeta)
    else:
        state = (u, )

    if use_proximal_gradient == False:
        estimate_max_singular_value = jax.jit(sn.max_singular_value,
                                              static_argnums=(0, 1))
    else:
        estimate_max_singular_value = jax.jit(sn.max_singular_value_no_grad,
                                              static_argnums=(0, 1))

    if w_exists == False:
        max_power_iters = 1000

    if monitor_progress:
        estimates = []

    for i in range(max_power_iters):
        sigma, *state = estimate_max_singular_value(mvp, mvpT, w, *state)
        if monitor_progress:
            estimates.append(sigma)

    if monitor_progress:
        sigma_for_test = sigma
        state_for_test = state
        for i in range(monitor_iters - max_power_iters):
            sigma_for_test, *state_for_test = estimate_max_singular_value(
                mvp, mvpT, w, *state_for_test)
            estimates.append(sigma_for_test)

        estimates = jnp.array(estimates)

        sigma_for_test = jax.lax.stop_gradient(sigma_for_test)
        state_for_test = jax.lax.stop_gradient(state_for_test)

    state = jax.lax.stop_gradient(state)

    if is_training == True or w_exists == False:
        u = state[0]
        hk.set_state(f"u_{name_suffix}", u)
        if use_proximal_gradient == False:
            zeta = state[1]
            hk.set_state(f"zeta_{name_suffix}", zeta)

    if return_sigma == False:
        factor = jnp.where(max_singular_value < sigma,
                           max_singular_value / sigma, 1.0)
        w = w * factor
        w_ret = w
    else:
        w_ret = (w, sigma)

    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", b_shape, dtype, init=b_init)
        ret = (w_ret, b)
    else:
        ret = w_ret

    if monitor_progress:
        ret = (ret, estimates)

    return ret
Example No. 12
    def safe_init(self,
                  x,
                  weight_logits,
                  means,
                  log_scales,
                  log_s=None,
                  t=None,
                  conditioned_params=False):
        """ We want to initialize this to be close to the identity funtion
        but also want the means to be spread out initially
    """
        wl_scale = hk.get_parameter("weight_logits_scale",
                                    shape=(),
                                    dtype=x.dtype,
                                    init=jnp.zeros)
        ls_scale = hk.get_parameter("log_scales_scale",
                                    shape=(),
                                    dtype=x.dtype,
                                    init=jnp.zeros)

        weight_logits *= wl_scale
        log_scales *= ls_scale

        if self.with_affine_coupling:
            batch_dim = x.ndim - len(self.batch_shape)
            name_prefix = "transform_" if conditioned_params else ""

            # Initialize log_s to divide off the stddev
            def log_s_shift_init(shape, dtype):
                if x.ndim == len(shape):
                    return jnp.zeros(shape, dtype)

                z = self.f(weight_logits, means, log_scales, x)
                axes = tuple(jnp.arange(len(z.shape) - len(shape)))
                return jnp.log(jnp.std(z, axis=axes) + 1e-5)

            log_s_shape = log_s.shape[
                batch_dim:] if conditioned_params else log_s.shape
            log_s_shift = hk.get_parameter(f"{name_prefix}log_s_shift",
                                           shape=log_s_shape,
                                           dtype=x.dtype,
                                           init=log_s_shift_init)
            log_s_scale = hk.get_parameter(f"{name_prefix}log_s_scale",
                                           shape=log_s_shape,
                                           dtype=x.dtype,
                                           init=jnp.zeros)

            # Constrain between -1 and 1 so that things don't blow up
            log_s_shift = -jnp.maximum(-1.0, -log_s_shift)
            log_s_shift = jnp.maximum(-1.0, log_s_shift)
            log_s = log_s * log_s_scale + log_s_shift

            # Initialize t to subtract off the mean
            def t_shift_init(shape, dtype):
                if x.ndim == len(shape):
                    return jnp.zeros(shape, dtype)

                z = self.f(weight_logits, means, log_scales, x)
                axes = tuple(jnp.arange(len(z.shape) - len(shape)))
                return jnp.mean(z, axis=axes)

            name_prefix = "transform_" if conditioned_params else ""
            t_shape = t.shape[batch_dim:] if conditioned_params else t.shape
            t_shift = hk.get_parameter(f"{name_prefix}t_shift",
                                       shape=t_shape,
                                       dtype=x.dtype,
                                       init=t_shift_init)
            t_scale = hk.get_parameter(f"{name_prefix}t_scale",
                                       shape=t_shape,
                                       dtype=x.dtype,
                                       init=jnp.zeros)
            t = t * t_scale + t_shift

            return weight_logits, means, log_scales, log_s, t

        return weight_logits, means, log_scales
Example No. 13
 def __call__(self, x):
     alpha = hk.get_parameter("alpha",
                              shape=(),
                              dtype=x.dtype,
                              init=self.alpha_init)
     return jnp.where(x < 0, alpha * x, x)
Example No. 14
    def __call__(self,
                 x: jnp.ndarray,
                 mask: Optional[jnp.ndarray] = None,
                 is_training: bool = True,
                 should_reset: Optional[jnp.ndarray] = None,
                 cache_steps: int = 0,
                 extra: Optional[jnp.ndarray] = None,
                 extra_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        """Computes the outputs of the TransformerXL.

    Args:
      x: [batch, timesteps]. Inputs at time step t.
      mask: [batch, timesteps]. It indicates which tokens are to be predicted.
        In other words, it corresponds to non-pad tokens in x_{t+1}.
      is_training: whether the current stage is training or not.
      should_reset: reset marker [batch, timesteps].
      cache_steps: number of timesteps in the cache.
      extra: if provided should be extra key-value input
        [batch, extra_timesteps, in_dim].
      extra_mask: if provided should be the mask for extra key-value input,
        [batch, extra_timesteps].

    Returns:
      output: transformer output [batch, timesteps].
    """
        if cache_steps == 0:
            cache_steps = x.shape[1]
        if should_reset is None:
            should_reset = jnp.where(x == 1, 1, 0)
        h = self._io_emb.embed_input(x)

        if mask is not None:
            attention_mask = mask[:, None, None, :]
        else:
            attention_mask = None

        head_dim = self._emb_dim // self._num_heads
        assert self._emb_dim % self._num_heads == 0, 'Head dim should be an int.'

        # Biases for relative position embedding shared across all layers
        r_w_bias = hk.get_parameter(
            'r_w_bias', [1, 1, self._num_heads, head_dim],
            init=init.RandomNormal(stddev=self._self_att_init_scale))
        r_r_bias = hk.get_parameter(
            'r_r_bias', [1, 1, self._num_heads, head_dim],
            init=init.RandomNormal(stddev=self._self_att_init_scale))

        for i in range(self._num_layers):
            if mask is not None:
                h *= mask[:, :, None]
            h = transformer_block.GPT2Block(
                r_w_bias=r_w_bias,
                r_r_bias=r_r_bias,
                causal=True,
                dense_dim=self._dense_dim,
                dropout_prob=self._dropout_prob,
                dropout_attn_prob=self._dropout_attn_prob,
                num_heads=self._num_heads,
                self_att_init_scale=self._self_att_init_scale,
                dense_init_scale=self._dense_init_scale,
                relative_pos_clamp_len=self._relative_pos_clamp_len,
                name='transformer_block_{}'.format(i),
            )(h,
              mask=attention_mask,
              is_training=is_training,
              should_reset=should_reset,
              cache_steps=cache_steps,
              extra=extra,
              extra_mask=extra_mask)

        if mask is not None:
            h *= mask[:, :, None]
        return self._io_emb.embed_output(h)
Example No. 15
    def __call__(self,
                 inputs,
                 targets=None,
                 train_forced=True,
                 return_bare_loss=False):
        dimensionality = self.dimensionality
        num_symbols = self.num_symbols
        batch_size, input_seq_length = inputs.shape

        # input handling
        oh_inputs = jax.nn.one_hot(inputs, num_classes=num_symbols)

        embedding = hk.Linear(dimensionality)
        embedded_inputs = embedding(oh_inputs)

        input_position_embeddings = positional_encodings(
            input_seq_length, dimensionality)
        encoder_outputs = embedded_inputs + input_position_embeddings

        # encoder stack

        for layer_i in range(self.num_encoder_layers):
            encoder_outputs = EncoderLayer(
                dimensionality=dimensionality,
                num_heads=self.num_heads)(encoder_outputs)

        # target handling
        if targets is not None:
            oh_targets = jax.nn.one_hot(targets, num_classes=num_symbols)

        # decoder stack
        start_token = hk.get_parameter("start",
                                       shape=[1, 1, dimensionality],
                                       init=hk.initializers.TruncatedNormal(
                                           1. / np.sqrt(dimensionality)))
        start_token = jnp.tile(start_token, [batch_size, 1, 1])

        output_embedding = hk.Linear(dimensionality)

        decoder_layers = [
            DecoderLayer(dimensionality=dimensionality,
                         num_heads=self.num_heads)
            for layer_i in range(self.num_decoder_layers)
        ]

        output_decoding = hk.Linear(num_symbols)

        if train_forced:  # train decoding w/ teacher forcing
            encoder_inputs = output_embedding(oh_targets[:, :-1, :])
            encoder_inputs = jnp.concatenate([start_token, encoder_inputs],
                                             axis=1)
            results = encoder_inputs
            for decoder_layer in decoder_layers:
                results = decoder_layer(encoder_outputs=encoder_outputs,
                                        inputs=results)

            output_logits = output_decoding(results)
            hard_outputs = jnp.argmax(output_logits, axis=-1)

        else:  # not train_forced, eval decoding (one step at a time)
            encoder_inputs = start_token

            output_logits = []
            hard_outputs = []
            for output_step_i in range(self.output_seq_length):
                results = encoder_inputs
                for decoder_layer in decoder_layers:
                    results = decoder_layer(encoder_outputs=encoder_outputs,
                                            inputs=results)

                this_output_logits = output_decoding(results[:, -1:, :])
                this_hard_output = jnp.argmax(this_output_logits, axis=-1)
                output_logits.append(this_output_logits)
                hard_outputs.append(this_hard_output)

                # cue next with previous output
                this_encoded_hard_output = jax.nn.one_hot(
                    this_hard_output, num_classes=num_symbols)
                this_encoded_hard_output = output_embedding(
                    this_encoded_hard_output)
                #                this_encoded_hard_output = jnp.expand_dims(
                #                    this_encoded_hard_output, axis=1)
                encoder_inputs = jnp.concatenate(
                    [encoder_inputs, this_encoded_hard_output], axis=1)

            output_logits = jnp.concatenate(output_logits, axis=1)
            hard_outputs = jnp.concatenate(hard_outputs, axis=1)

        if targets is not None:
            loss = softmax_xe(output_logits, oh_targets)
            total_loss = jnp.mean(jnp.sum(loss, axis=-1))

            if return_bare_loss:
                return total_loss

            accuracy = jnp.mean(hard_outputs == targets)

            return {
                "outputs": hard_outputs,
                "loss": total_loss,
                "accuracy": accuracy
            }
        else:
            return {"outputs": hard_outputs}
Example No. 16
    def __init__(self,
                 dim: int,
                 vocab_size: int,
                 cutoffs: List[int],
                 tail_shrink_factor: int = 4,
                 hierarchical: bool = True,
                 init_std: float = 0.02,
                 init_proj_std: float = 0.01,
                 dtype: jnp.dtype = jnp.float32,
                 name: Optional[str] = None):
        """Initialize a AdaptiveSoftmaxEmbedding.

    Args:
      dim: dimensionality of the hidden space.
      vocab_size: the size of the vocabulary.
      cutoffs: the cutoff indices of the vocabulary used for the adaptive
        softmax embedding.
      tail_shrink_factor: how many times to shrink the hidden dimensionality
        for low-frequency vocabulary after each cutoff.
      hierarchical: whether to use hierarchical softmax.
      init_std: standard deviation of the Normal distribution used to initialize
        the embedding weights.
      init_proj_std: standard deviation of the Normal distribution used to
        initialize the projection weights.
      dtype: Optional data type default to jnp.float32.
      name: Optional name for this Haiku module.
    """
        super(AdaptiveSoftmaxEmbedding, self).__init__(name=name)
        self._hidden_size = dim
        self._vocab_size = vocab_size
        self._cutoffs = [0] + list(cutoffs) + [self._vocab_size]
        self._tail_shrink_factor = tail_shrink_factor
        self._hierarchical = hierarchical
        self._dtype = dtype
        self._embeddings = []
        self._projections = []

        self._bias = hk.get_parameter('bias', [self._vocab_size],
                                      dtype=self._dtype,
                                      init=jnp.zeros)

        l_cutoffs = self._cutoffs[:-1]
        r_cutoffs = self._cutoffs[1:]
        for i, (l_cutoff, r_cutoff) in enumerate(zip(l_cutoffs, r_cutoffs)):
            hidden_size = self._hidden_size // (self._tail_shrink_factor**i)
            embedding = hk.get_parameter(
                f'embeddings_{l_cutoff}_{r_cutoff}',
                [r_cutoff - l_cutoff, hidden_size],
                dtype=self._dtype,
                init=hk.initializers.RandomNormal(stddev=init_std))
            self._embeddings += [embedding]
            if self._tail_shrink_factor != 1:
                projection = hk.get_parameter(
                    f'projection_{l_cutoff}_{r_cutoff}',
                    [hidden_size, self._hidden_size],
                    dtype=self._dtype,
                    init=hk.initializers.RandomNormal(stddev=init_proj_std))
                self._projections += [projection]

        if self._tail_shrink_factor != 1:
            self._output_projection = hk.get_parameter(
                'output_head_projection',
                [self._hidden_size, self._hidden_size],
                dtype=self._dtype,
                init=hk.initializers.RandomNormal(stddev=init_proj_std))

        if self._hierarchical:
            self._class_weights = hk.get_parameter(
                'tail_class_weights',
                [self._hidden_size, len(cutoffs)],
                init=hk.initializers.RandomNormal(stddev=init_std))
            self._class_bias = hk.get_parameter('tail_class_bias',
                                                [len(cutoffs)],
                                                dtype=self._dtype,
                                                init=jnp.zeros)
Example No. 17
    def __call__(
        self,
        inputs: jnp.ndarray,
        is_training: bool,
        test_local_stats: bool = False,
        scale: Optional[jnp.ndarray] = None,
        offset: Optional[jnp.ndarray] = None,
        return_lipschitz_const: bool = False,
    ) -> jnp.ndarray:
        """Computes the normalized version of the input.
    Args:
      inputs: An array, where the data format is ``[..., C]``.
      is_training: Whether this is during training.
      test_local_stats: Whether local stats are used when is_training=False.
      scale: An array up to n-D. The shape of this tensor must be broadcastable
        to the shape of ``inputs``. This is the scale applied to the normalized
        inputs. This cannot be passed in if the module was constructed with
        ``create_scale=True``.
      offset: An array up to n-D. The shape of this tensor must be broadcastable
        to the shape of ``inputs``. This is the offset applied to the normalized
        inputs. This cannot be passed in if the module was constructed with
        ``create_offset=True``.
    Returns:
      The array, normalized across all but the last dimension.
    """
        if self.create_scale and scale is not None:
            raise ValueError(
                "Cannot pass `scale` at call time if `create_scale=True`.")
        if self.create_offset and offset is not None:
            raise ValueError(
                "Cannot pass `offset` at call time if `create_offset=True`.")

        channel_index = self.channel_index
        if channel_index < 0:
            channel_index += inputs.ndim

        if self.axis is not None:
            axis = self.axis
        else:
            axis = [i for i in range(inputs.ndim) if i != channel_index]

        if is_training or test_local_stats:
            mean = jnp.mean(inputs, axis, keepdims=True)
            if self.mean_only == False:
                mean_of_squares = jnp.mean(inputs**2, axis, keepdims=True)
            if self.cross_replica_axis:
                mean = jax.lax.pmean(
                    mean,
                    axis_name=self.cross_replica_axis,
                    axis_index_groups=self.cross_replica_axis_index_groups)
                if self.mean_only == False:
                    mean_of_squares = jax.lax.pmean(
                        mean_of_squares,
                        axis_name=self.cross_replica_axis,
                        axis_index_groups=self.cross_replica_axis_index_groups)

            if self.mean_only == False:
                var = mean_of_squares - mean**2
        else:
            mean = self.mean_ema.average
            if self.mean_only == False:
                var = self.var_ema.average

        if is_training:
            self.mean_ema(mean)
            if self.mean_only == False:
                self.var_ema(var)

        w_shape = [
            1 if i in axis else inputs.shape[i] for i in range(inputs.ndim)
        ]
        w_dtype = inputs.dtype

        if self.mean_only == False:
            if self.create_scale:
                scale = hk.get_parameter("scale", w_shape, w_dtype,
                                         self.scale_init)
            elif scale is None:
                scale = np.ones([], dtype=w_dtype)

        if self.create_offset:
            offset = hk.get_parameter("offset", w_shape, w_dtype,
                                      self.offset_init)
        elif offset is None:
            offset = np.zeros([], dtype=w_dtype)

        if self.mean_only == False:
            eps = jax.lax.convert_element_type(self.eps, var.dtype)
            inv = scale * jax.lax.rsqrt(var + eps)
            ret = (inputs - mean) * inv + offset
            lip = jnp.max(inv)
        else:
            ret = inputs - mean + offset
            lip = 1.0

        if return_lipschitz_const:
            return ret, lip

        return ret
Example No. 18
 def fn(x):
     return jnp.arctanh(x) * hk.get_parameter(
         'p', [], init=hk.initializers.Constant(0.))
Example No. 19
    def __call__(self, inputs: jnp.ndarray,
                 latents: jnp.ndarray) -> jnp.ndarray:
        """Computes the transposed convolution of the input.
        Args:
        inputs: An array of shape ``[spatial_dims, C]`` and rank-N+1 if unbatched,
            or an array of shape ``[N, spatial_dims, C]`` and rank-N+2 if batched.
        Returns:
        An array of shape ``[spatial_dims, output_channels]`` and rank-N+1 if
            unbatched, or an array of shape ``[N, spatial_dims, output_channels]``
            and rank-N+2 if batched.
        """
        assert self.mask is None
        unbatched_rank = self.num_spatial_dims + 1
        allowed_ranks = [unbatched_rank, unbatched_rank + 1]
        if inputs.ndim not in allowed_ranks:
            raise ValueError(
                f"Input to ConvNDTranspose needs to have rank in "
                f"{allowed_ranks}, but input has shape {inputs.shape}.")

        unbatched = inputs.ndim == unbatched_rank
        if unbatched:
            inputs = jnp.expand_dims(inputs, axis=0)
            latents = jnp.expand_dims(latents, axis=0)
        assert latents.ndim == 2

        input_channels = inputs.shape[self.channel_index]
        w_shape = self.kernel_shape + (self.output_channels, input_channels)

        w_init = self.w_init
        if w_init is None:
            fan_in_shape = self.kernel_shape + (input_channels, )
            stddev = 1.0 / np.sqrt(np.prod(fan_in_shape))
            w_init = hk.initializers.TruncatedNormal(stddev=stddev)
        w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)

        conv_fn = ft.partial(
            mod_demod_conv_transpose,
            orig_weight=w,
            channel_index=self.channel_index,
            demodulate=self.demodulate,
            strides=self.stride,
            padding=self.padding,
            dimension_numbers=self.dimension_numbers,
        )
        # Modulate; the +1 makes the default bias equal to 1.
        styles = hk.Linear(inputs.shape[self.channel_index])(latents) + 1
        out = jax.vmap(conv_fn)(inputs, styles)

        if self.with_bias:
            if self.channel_index == -1:
                bias_shape = (self.output_channels, )
            else:
                bias_shape = (
                    self.output_channels, ) + (1, ) * self.num_spatial_dims
            b = hk.get_parameter("b",
                                 bias_shape,
                                 inputs.dtype,
                                 init=self.b_init)
            b = jnp.broadcast_to(b, out.shape)
            out = out + b

        if unbatched:
            out = jnp.squeeze(out, axis=0)
        return out
Example No. 20
 def __call__(self, x):
     j, k = x.shape[-1], self.output_size
     w_init = hk.initializers.TruncatedNormal(1.0 / np.sqrt(j))
     w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
     b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
     return jnp.dot(x, w) + b
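For completeness (an added sketch, not from the source): wrapped in an hk.Module this is essentially a hand-rolled linear layer. The MyLinear class name and the concrete shapes below are assumptions for illustration only.

 import haiku as hk
 import jax
 import jax.numpy as jnp
 import numpy as np

 class MyLinear(hk.Module):  # hypothetical wrapper around the __call__ above
     def __init__(self, output_size, name=None):
         super().__init__(name=name)
         self.output_size = output_size

     def __call__(self, x):
         j, k = x.shape[-1], self.output_size
         w_init = hk.initializers.TruncatedNormal(1.0 / np.sqrt(j))
         w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
         b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
         return jnp.dot(x, w) + b

 forward = hk.without_apply_rng(hk.transform(lambda x: MyLinear(4)(x)))
 x = jnp.ones([2, 3])
 params = forward.init(jax.random.PRNGKey(42), x)  # 'my_linear': w (3, 4), b (4,)
 y = forward.apply(params, x)                      # y.shape == (2, 4)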
Example No. 21
    def __call__(self, inputs: jnp.ndarray, rng_key,
                 stochastic) -> jnp.ndarray:
        """Connects ``ConvND`` layer.

    Args:
      inputs: An array of shape ``[spatial_dims, C]`` and rank-N+1 if
        unbatched, or an array of shape ``[N, spatial_dims, C]`` and rank-N+2
        if batched.
      rng_key: RNG.
      stochastic: whether or not in stochastic or deterministic mode.

    Returns:
      An array of shape ``[spatial_dims, output_channels]`` and rank-N+1 if
      unbatched, or an array of shape ``[N, spatial_dims, output_channels]``
      and rank-N+2 if batched.
    """
        dtype = inputs.dtype

        unbatched_rank = self.num_spatial_dims + 1
        allowed_ranks = [unbatched_rank, unbatched_rank + 1]
        if inputs.ndim not in allowed_ranks:
            raise ValueError(
                f"Input to ConvND needs to have rank in {allowed_ranks},"
                f" but input has shape {inputs.shape}.")

        unbatched = inputs.ndim == unbatched_rank
        if unbatched:
            inputs = jnp.expand_dims(inputs, axis=0)

        if inputs.shape[self.channel_index] % self.feature_group_count != 0:
            raise ValueError(
                f"Inputs channels {inputs.shape[self.channel_index]} "
                f"should be a multiple of feature_group_count "
                f"{self.feature_group_count}")
        w_shape = self.kernel_shape + (
            inputs.shape[self.channel_index] // self.feature_group_count,
            self.output_channels,
        )

        if self.mask is not None and self.mask.shape != w_shape:
            raise ValueError("Mask needs to have the same shape as weights. "
                             f"Shapes are: {self.mask.shape}, {w_shape}")

        fan_in_shape = np.prod(w_shape[:-1])
        stddev = 1.0 / np.sqrt(fan_in_shape)
        self.w_init = parse_w_init(init_type=self.w_init,
                                   uniform_stddev=stddev)
        self.b_init = parse_b_init(init_type=self.b_init,
                                   uniform_stddev=stddev)

        w_mu = hk.get_parameter("w_mu", w_shape, dtype,
                                init=self.w_init)  ### changed code!

        if self.stochastic_parameters:
            w_logvar = hk.get_parameter(
                "w_logvar",
                w_shape,
                dtype=dtype,
                init=uniform_initializer(self.uniform_init_minval,
                                         self.uniform_init_maxval),
            )
            rng_key, sub_key = jax.random.split(rng_key)
            w = gaussian_sample(w_mu, w_logvar, stochastic, sub_key)
            out = lax.conv_general_dilated(
                inputs,
                w,
                window_strides=self.stride,
                padding=self.padding,
                lhs_dilation=self.lhs_dilation,
                rhs_dilation=self.kernel_dilation,
                dimension_numbers=self.dimension_numbers,
                feature_group_count=self.feature_group_count,
            )
        else:
            out = lax.conv_general_dilated(
                inputs,
                w_mu,
                window_strides=self.stride,
                padding=self.padding,
                lhs_dilation=self.lhs_dilation,
                rhs_dilation=self.kernel_dilation,
                dimension_numbers=self.dimension_numbers,
                feature_group_count=self.feature_group_count,
            )

        if self.with_bias:
            if self.channel_index == -1:
                bias_shape = (self.output_channels, )
            else:
                bias_shape = (
                    self.output_channels, ) + (1, ) * self.num_spatial_dims
            b_mu = hk.get_parameter("b_mu",
                                    bias_shape,
                                    inputs.dtype,
                                    init=self.b_init)
            if self.stochastic_parameters:
                b_logvar = hk.get_parameter(
                    "b_logvar",
                    shape=bias_shape,
                    dtype=inputs.dtype,
                    init=uniform_initializer(self.uniform_init_minval,
                                             self.uniform_init_maxval),
                )
                rng_key, sub_key = jax.random.split(rng_key)
                b = gaussian_sample(b_mu, b_logvar, stochastic, sub_key)
                b = jnp.broadcast_to(b, out.shape)
            else:
                b = jnp.broadcast_to(b_mu, out.shape)
            out = out + b

        if unbatched:
            out = jnp.squeeze(out, axis=0)
        return out
Example No. 22
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: PRNGKey,
             sample: Optional[bool] = False,
             reconstruction: Optional[bool] = False,
             is_training: bool = True,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        x = inputs["x"]
        outputs = {}
        x_shape = self.get_unbatched_shapes(sample)["x"]

        # Make sure that we're using 1d inputs
        x = inputs["x"].reshape(self.batch_shape + (-1, ))
        x_dim = x.shape[-1]

        # Work with one-hot labels
        y_one_hot = inputs.get("y", None)
        if y_one_hot is not None:
            assert y_one_hot.shape == self.batch_shape + (self.n_classes, )
            y_one_hot *= 1.0
        else:
            if sample == False:
                # Assign equal probability to each class
                y_one_hot = jnp.ones(self.batch_shape + (self.n_classes, ))
            else:
                # Sample class labels
                y = random.randint(rng,
                                   minval=0,
                                   maxval=self.n_classes,
                                   shape=self.batch_shape)
                y_one_hot = y[..., None] == jnp.arange(self.n_classes)[..., :]
                y_one_hot *= 1.0

        # GMM parameters.  Assume uniform mixture-component weights so that only the
        # means and diagonal covariances are learned (and everything stays differentiable).
        means = hk.get_parameter("means",
                                 shape=(self.n_classes, x_dim),
                                 dtype=x.dtype,
                                 init=hk.initializers.RandomNormal())
        log_diag_covs = hk.get_parameter("log_diag_covs",
                                         shape=(self.n_classes, x_dim),
                                         dtype=x.dtype,
                                         init=jnp.ones)
        # Map the unconstrained parameter to strictly positive covariances,
        # then move back to log space.
        diag_covs = util.proximal_relu(log_diag_covs) + 1e-3
        log_diag_covs = jnp.log(diag_covs)

        # Sample a new input
        if sample and not reconstruction:
            # Sample from all of the clusters
            noise = random.normal(rng,
                                  self.batch_shape + (self.n_classes, x_dim))
            xs = means + jnp.exp(0.5 * log_diag_covs) * noise

            # Select the mixture component
            x = xs * y_one_hot[..., None]
            x = x.sum(axis=-2)

        # Evaluate the log pdf for each mixture component
        @partial(jax.vmap, in_axes=(0, 0, None))
        def diag_gaussian(mean, log_diag_cov, x):
            dx = x - mean
            log_pdf = jnp.sum(dx**2 * jnp.exp(-log_diag_cov), axis=-1)
            log_pdf += log_diag_cov.sum()
            log_pdf += x_dim * jnp.log(2 * jnp.pi)
            return -0.5 * log_pdf

        # Last axis will be across the mixture components
        log_pdfs = self.auto_batch(partial(diag_gaussian, means,
                                           log_diag_covs))(x)

        # Make a class prediction
        y_pred = jnp.argmax(log_pdfs, axis=-1)
        y_pred_one_hot = y_pred[...,
                                None] == jnp.arange(self.n_classes)[..., :]
        y_pred_one_hot *= 1.0

        # Compute p(x,y) = p(x|y)p(y) if we have a label, p(x) otherwise.
        # If we have a label, zero out all but the label index then reduce.
        # Otherwise, reduce over all of the indices.
        if is_training:

            # Apply the label masks
            if "y_is_labeled" in inputs:
                y_is_labeled = inputs["y_is_labeled"][..., None].astype(bool)
                y_one_hot = y_one_hot * y_is_labeled + jnp.ones_like(
                    y_one_hot) * (~y_is_labeled)

            log_pz = util.lse(log_pdfs, b=y_one_hot, axis=-1)
            # log_pz = logsumexp(log_pdfs, b=y_one_hot, axis=-1)
        else:
            # If we're doing classification, use the predicted label
            if "y" in inputs:
                log_pz = util.lse(log_pdfs, b=y_pred_one_hot, axis=-1)
            else:
                log_pz = logsumexp(log_pdfs, axis=-1)

        # Account for the uniform class prior p(y) = 1/n_classes
        log_pz -= jnp.log(self.n_classes)

        # p(y|x) is a categorical distribution
        log_pygx = jax.nn.log_softmax(log_pdfs)
        if is_training:
            log_pygx *= y_one_hot

            if "y_is_labeled" in inputs:
                # This time, zero out values that aren't labeled
                log_pygx *= y_is_labeled

        else:
            if "y" in inputs:
                log_pygx *= y_pred_one_hot

        log_pygx = log_pygx.sum(axis=-1)

        # Reshape the output
        x = x.reshape(self.batch_shape + x_shape)

        outputs = {"x": x, "log_pz": log_pz, "log_pygx": log_pygx}
        outputs["prediction"] = y_pred
        outputs["prediction_one_hot"] = outputs["prediction"][
            ..., None] == jnp.arange(self.n_classes)[..., :]
        return outputs
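
The vmapped `diag_gaussian` above evaluates the standard diagonal-covariance Gaussian log-density. Here is a small self-contained check of that closed form against jax.scipy; the shapes and values are made up purely for illustration.

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

x = jnp.array([0.5, -1.0, 2.0, 0.0])
mean = jnp.zeros(4)
log_diag_cov = jnp.log(jnp.array([1.0, 0.5, 2.0, 1.5]))

# Closed form used in the example above.
dx = x - mean
log_pdf = -0.5 * (jnp.sum(dx**2 * jnp.exp(-log_diag_cov))
                  + log_diag_cov.sum()
                  + x.shape[-1] * jnp.log(2 * jnp.pi))

# Reference value for the same diagonal covariance.
reference = multivariate_normal.logpdf(x, mean, jnp.diag(jnp.exp(log_diag_cov)))
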
Example #23
0
    def __call__(
        self,
        query: jnp.ndarray,
        key: tp.Optional[jnp.ndarray] = None,
        value: tp.Optional[jnp.ndarray] = None,
        mask=None,
        is_training=None,
    ):

        # einsum nomenclature
        # ------------------------
        # N = query elements
        # M = key/value elements
        # H = heads
        # I = input features
        # O = output features

        if key is None:
            key = query

        if value is None:
            value = key

        output_size = (self.output_size
                       if self.output_size is not None else value.shape[-1])

        # verify shapes
        if key.shape[-2] != value.shape[-2]:
            raise ValueError(
                "the number of elements in 'key' must be equal to the number of elements in 'value'"
            )

        if mask is not None:
            if len(mask.shape) < 2:
                raise ValueError("'mask' must have at least 2 dimensions")
            if query.shape[-2] != mask.shape[-2]:
                raise ValueError(
                    "mask's second-to-last dimension must be equal to the number of elements in 'query'"
                )
            if key.shape[-2] != mask.shape[-1]:
                raise ValueError(
                    "mask's last dimension must be equal to the number of elements in 'key'"
                )

        # get weights
        query_kernel = hk.get_parameter(
            "query_kernel",
            [self.num_heads, query.shape[-1], self.head_size],
            init=self.kernel_initializer,
        )
        key_kernel = hk.get_parameter(
            "key_kernel",
            [self.num_heads, key.shape[-1], self.head_size],
            init=self.kernel_initializer,
        )
        value_kernel = hk.get_parameter(
            "value_kernel",
            [self.num_heads, value.shape[-1], self.head_size],
            init=self.kernel_initializer,
        )
        projection_kernel = hk.get_parameter(
            name="projection_kernel",
            shape=[self.num_heads, self.head_size, output_size],
            init=self.kernel_initializer,
        )

        # Linear transformations
        query = jnp.einsum("...NI , HIO -> ...NHO", query, query_kernel)
        key = jnp.einsum("...MI , HIO -> ...MHO", key, key_kernel)
        value = jnp.einsum("...MI , HIO -> ...MHO", value, value_kernel)

        # Scale the dot product; dividing the query (or the key) instead of
        # their product saves a little computation
        query /= jnp.sqrt(self.head_size)

        # Calculate dot product attention
        logits = jnp.einsum("...NHO,...MHO->...HNM", query, key)

        # apply mask
        if mask is not None:
            mask = mask.astype(jnp.float32)

            # possibly expand on the head dimension so broadcasting works
            if len(mask.shape) != len(logits.shape):
                mask = jnp.expand_dims(mask, -3)

            logits += -10e9 * (1.0 - mask)

        attn_coef = jax.nn.softmax(logits)

        # attention dropout
        attn_coef_dropout = hk.dropout(
            hk.next_rng_key(), self.droput_rate if is_training else 0.0,
            attn_coef)

        # attention * value
        multihead_output = jnp.einsum("...HNM,...MHI->...NHI",
                                      attn_coef_dropout, value)

        # Project back to the output size; the einsum recombines the heads
        # in the same step.
        output = jnp.einsum("...NHI,HIO->...NO", multihead_output,
                            projection_kernel)

        if self.use_projection_bias:
            output += hk.get_parameter(
                name="projection_bias",
                shape=[output_size],
                init=self.bias_initializer,
            )

        if self.return_attn_coef:
            return output, attn_coef
        else:
            return output
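
The attention module above is only shown from __call__ onward, so its constructor is not visible. Below is a hedged usage sketch, assuming the enclosing class is called MultiHeadAttention and takes num_heads, head_size, and output_size arguments; those names are assumptions based on the attributes used above, not the original API.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(query):
    # MultiHeadAttention is the (unshown) class that owns __call__ above;
    # its constructor arguments are assumptions. Self-attention: key and
    # value default to the query inside __call__.
    mha = MultiHeadAttention(num_heads=4, head_size=16, output_size=32)
    return mha(query, is_training=False)

forward_fn = hk.transform(forward)
query = jnp.ones((2, 10, 64))                                   # (batch, elements, features)
params = forward_fn.init(jax.random.PRNGKey(0), query)
out = forward_fn.apply(params, jax.random.PRNGKey(1), query)    # (2, 10, 32)
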