Example #1
0
 def __call__(self, x: JaxArray, training: bool = True) -> JaxArray:
     """Group-normalizes x, then applies the learned scale (gamma) and offset (beta).

     Args:
         x: input array. Its first two axes are regrouped into (-1, groups, nin // groups), so the
             second axis is presumably the channel axis of size ``self.nin`` -- confirm against callers.
         training: ignored; present only so this layer is call-compatible with batch norm.

     Returns:
         Array reshaped back to (-1, nin) + x.shape[2:] after normalization and the affine transform.
     """
     # Never read: kept in the signature so group norm can be swapped in wherever batch norm is used.
     del training
     trailing = x.shape[2:]
     # Expose the group structure as separate axes so statistics can be taken over self.redux.
     grouped = x.reshape((-1, self.groups, self.nin // self.groups) + trailing)
     mu = grouped.mean(axis=self.redux, keepdims=True)
     sigma2 = grouped.var(axis=self.redux, keepdims=True)
     # eps is added to the variance before the reciprocal square root.
     inv_std = functional.rsqrt(sigma2 + self.eps)
     normalized = ((grouped - mu) * inv_std).reshape((-1, self.nin) + trailing)
     return normalized * self.gamma.value + self.beta.value
Example #2
0
def upscale_nn(x: JaxArray, scale: int = 2) -> JaxArray:
    """Upscales image batches of shape (N, C, H, W) by nearest-neighbor replication.

    Args:
        x: input tensor of shape (N, C, H, W).
        scale: integer factor applied to both spatial dimensions.

    Returns:
        Tensor of shape (N, C, H * scale, W * scale) in which every input pixel
        is repeated in a scale x scale square.
    """
    shape = x.shape
    leading, height, width = shape[:2], shape[2], shape[3]
    # Add a singleton axis after H and after W; tiling those axes repeats each pixel `scale` times.
    expanded = x.reshape(leading + (height, 1, width, 1))
    repeated = jn.tile(expanded, (1, 1, 1, scale, 1, scale))
    # Fold every repeat axis back into the spatial axis that precedes it.
    return repeated.reshape(leading + (scale * height, scale * width))
Example #3
0
def mixture_gaussian_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """Evaluates the log PDF of a per-feature mixture of Gaussians.

    Args:
        x: input vector, shape (D,).
        prior_logits: unnormalized log-weights of the components, shape (D, K).
        means: component means per feature, shape (D, K).
        scales: component scales per feature, shape (D, K).

    Returns:
        Log PDF of the mixture for each feature, shape (D,).
    """
    # (D,) -> (D, 1): a trailing component axis lets x broadcast against the (D, K) parameters.
    x_col = x.reshape(-1, 1)

    # Turn the logits into log mixture weights that sum to one over the K components.
    log_weights = log_softmax(prior_logits, axis=1)

    # Weighted per-component log densities, (D, K).
    weighted_log_pdfs = log_weights + jax.scipy.stats.norm.logpdf(x_col, means, scales)

    # Marginalize over the components in log space, (D, K) -> (D,).
    return logsumexp(weighted_log_pdfs, axis=1)
Example #4
0
def mixture_gaussian_cdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """Evaluates the CDF of a per-feature mixture of Gaussians.

    Args:
        x: input vector, shape (D,).
        prior_logits: unnormalized log-weights of the components, shape (D, K).
        means: component means per feature, shape (D, K).
        scales: component scales per feature, shape (D, K).

    Returns:
        CDF of the mixture for each feature, shape (D,).
    """
    # (D,) -> (D, 1): a trailing component axis lets x broadcast against the (D, K) parameters.
    x_col = x.reshape(-1, 1)

    # Log mixture weights; normalization runs over the last axis (the log_softmax default).
    log_weights = jax.nn.log_softmax(prior_logits, axis=-1)

    # Weighted per-component log CDFs, (D, K).
    weighted_log_cdfs = log_weights + jax.scipy.stats.norm.logcdf(x_col, means, scales)

    # Sum over the components in log space, then leave log space for the final value.
    mixture_log_cdf = jax.scipy.special.logsumexp(weighted_log_cdfs, axis=1)
    return np.exp(mixture_log_cdf)
Example #5
0
def mixture_logistic_cdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """Evaluates the CDF of a per-feature mixture of logistic components.

    Args:
        x: input vector, shape (D,).
        prior_logits: unnormalized log-weights of the components, shape (D, K).
        means: component means per feature, shape (D, K).
        scales: component scales per feature, shape (D, K).

    Returns:
        CDF of the mixture for each feature, shape (D,). The value is the
        exponentiated log-sum-exp, i.e. the CDF itself rather than its log.
    """
    # (D,) -> (D, 1): a trailing component axis lets x broadcast against the (D, K) parameters.
    x_col = x.reshape(-1, 1)

    # Turn the logits into log mixture weights that sum to one over the K components.
    log_weights = log_softmax(prior_logits, axis=1)

    # Weighted per-component log CDFs, (D, K). logistic_log_cdf is defined elsewhere in the project;
    # presumably it returns the log CDF of a logistic distribution -- confirm there.
    weighted_log_cdfs = log_weights + logistic_log_cdf(x_col, means, scales)

    # Sum over the components in log space, (D, K) -> (D,), then leave log space.
    mixture_log_cdf = logsumexp(weighted_log_cdfs, axis=1)
    return np.exp(mixture_log_cdf)
Example #6
0
def mixture_logistic_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """Evaluates the log PDF of a per-feature mixture of logistic components.

    Args:
        x: input vector, shape (D,).
        prior_logits: unnormalized log-weights of the components, shape (D, K).
        means: component means per feature, shape (D, K).
        scales: component scales per feature, shape (D, K).

    Returns:
        Log PDF of the mixture for each feature, shape (D,).
    """
    # (D,) -> (D, 1): a trailing component axis lets x broadcast against the (D, K) parameters.
    x_col = x.reshape(-1, 1)

    # Turn the logits into log mixture weights that sum to one over the K components.
    log_weights = log_softmax(prior_logits, axis=1)

    # Weighted per-component log densities, (D, K). logistic_log_pdf is defined elsewhere in the
    # project; presumably it returns the log PDF of a logistic distribution -- confirm there.
    weighted_log_pdfs = log_weights + logistic_log_pdf(x_col, means, scales)

    # Marginalize over the components in log space, (D, K) -> (D,).
    return logsumexp(weighted_log_pdfs, axis=1)
Example #7
0
 def device_reshape(self, x: JaxArray) -> JaxArray:
     """Utility to reshape an input array in order to broadcast to multiple devices.

     Args:
         x: array to distribute. A 0-d array is replicated once per device; any other array has its
             leading axis split evenly across the devices.

     Returns:
         An array of shape (ndevices,) for 0-d input, otherwise
         (ndevices, x.shape[0] // ndevices) + x.shape[1:].

     Raises:
         AssertionError: if x has no ``ndim`` attribute (e.g. a plain Python scalar), or if its leading
             axis is not divisible by the number of devices.
     """
     # The hint must name something that really has `ndim`. np.float was only an alias of the builtin
     # float (and is gone from NumPy >= 1.24), so following the old advice hit this same assertion again.
     assert hasattr(x, 'ndim'), f'Expected JaxArray, got {type(x)}. If you are trying to pass a scalar to ' \
                                f'parallel, first convert it to a JaxArray, for example np.float32(0.5)'
     if x.ndim == 0:
         # Scalars have no batch axis to split: give every device its own copy.
         return jn.broadcast_to(x, [self.ndevices])
     assert x.shape[0] % self.ndevices == 0, f'Must be able to equally divide batch {x.shape} among ' \
                                             f'{self.ndevices} devices, but does not go equally.'
     # Leading axis becomes (device index, per-device batch); remaining axes are untouched.
     return x.reshape((self.ndevices, x.shape[0] // self.ndevices) + x.shape[1:])
Example #8
0
def flatten(x: JaxArray) -> JaxArray:
    """Collapses every axis after the first into a single axis.

    Args:
        x: input tensor with dimensions (n_1, n_2, ..., n_k).

    Returns:
        A 2D tensor of shape (n_1, n_prod), where n_prod is the product
        of n_2 through n_k.
    """
    leading = x.shape[0]
    # -1 lets reshape infer the size of the merged trailing axes.
    return x.reshape([leading, -1])
Example #9
0
    def reshape_microbatch(self, x: JaxArray) -> JaxArray:
        """Splits the leading axis of x into microbatches.

        DP-SGD requires per-example gradients to be clipped and noised, which can be inefficient.
        Clipping and noising a whole microbatch of examples at once is faster, at a slight cost to privacy.
        If speed is not an issue, the microbatch size should be set to 1.

        If x has shape [D0, D1, ..., Dn], the reshaped output has shape
        [D0 // microbatch_size, microbatch_size, D1, ..., Dn].

        Args:
            x: items to be reshaped.

        Returns:
            The reshaped items.
        """
        dims = x.shape
        num_microbatches = dims[0] // self.microbatch
        return x.reshape([num_microbatches, self.microbatch] + list(dims[1:]))
Example #10
0
def channel_to_space2d(x: JaxArray,
                       size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Moves blocks of the channel dimension C into the spatial dimensions (H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of the spatial area, as an int or an (int, int) pair.

    Returns:
        Output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
    """
    block = to_tuple(size, 2)
    block_h, block_w = block[0], block[1]
    shape = x.shape
    batch, height, width = shape[0], shape[2], shape[3]
    # Split C into (C', block_h, block_w) ...
    split = x.reshape((batch, -1, block_h, block_w, height, width))
    # ... then place each block axis right after the spatial axis it extends.
    interleaved = split.transpose((0, 1, 4, 2, 5, 3))
    out_channels = shape[1] // (block_h * block_w)
    return interleaved.reshape(
        (batch, out_channels, height * block_h, width * block_w))
Example #11
0
def space_to_channel2d(x: JaxArray,
                       size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Moves blocks of the spatial dimensions (H, W) into the channel dimension C.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of the spatial area, as an int or an (int, int) pair.

    Returns:
        Output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
    """
    block = to_tuple(size, 2)
    block_h, block_w = block[0], block[1]
    shape = x.shape
    batch, channels = shape[0], shape[1]
    out_h, out_w = shape[2] // block_h, shape[3] // block_w
    # Split H into (out_h, block_h) and W into (out_w, block_w) ...
    split = x.reshape((batch, channels, out_h, block_h, out_w, block_w))
    # ... then move both block axes next to the channel axis so they can be merged into it.
    gathered = split.transpose((0, 1, 3, 5, 2, 4))
    return gathered.reshape(
        (batch, channels * block_h * block_w, out_h, out_w))
Example #12
0
 def device_reshape(self, x: JaxArray) -> JaxArray:
     """Splits the leading axis of x into (ndevices, per-device batch) for broadcasting to devices.

     Args:
         x: array whose leading axis must be divisible by ``self.ndevices``.

     Returns:
         x reshaped to (ndevices, x.shape[0] // ndevices) + x.shape[1:].

     Raises:
         AssertionError: if the leading axis cannot be divided equally among the devices.
     """
     batch = x.shape[0]
     assert batch % self.ndevices == 0, f'Must be able to equally divide batch {x.shape} among ' \
                                        f'{self.ndevices} devices, but does not go equally.'
     per_device = batch // self.ndevices
     return x.reshape((self.ndevices, per_device) + x.shape[1:])
Example #13
0
 def device_reshape(self, x: JaxArray) -> JaxArray:
     """Splits the leading axis of x into (ndevices, per-device batch) for broadcasting to devices.

     Args:
         x: array to distribute; no divisibility check is made here, so an uneven leading axis
             surfaces as whatever error ``reshape`` raises.

     Returns:
         x reshaped to (ndevices, x.shape[0] // ndevices) + x.shape[1:].
     """
     per_device = x.shape[0] // self.ndevices
     target_shape = (self.ndevices, per_device) + x.shape[1:]
     return x.reshape(target_shape)