Esempio n. 1
0
 def __call__(self, x: JaxArray, training: bool = True) -> JaxArray:
     """Applies group normalization to x.

     The tensor is viewed as (-1, self.groups, self.nin // self.groups) + x.shape[2:], normalized with the mean
     and variance taken over the axes in ``self.redux``, viewed back as (-1, self.nin) + x.shape[2:], and then
     scaled by ``gamma`` and shifted by ``beta``.

     Args:
         x: input tensor; axis 1 is presumably the ``self.nin`` channels (TODO confirm against the constructor).
         training: ignored. It is accepted only so this module can be a drop-in replacement for batch norm.

     Returns:
         The normalized, scaled and shifted tensor.
     """
     del training  # Unused on purpose: keeps the call signature compatible with batch norm.
     spatial = x.shape[2:]
     grouped = x.reshape((-1, self.groups, self.nin // self.groups) + spatial)
     mu = grouped.mean(axis=self.redux, keepdims=True)
     sigma2 = grouped.var(axis=self.redux, keepdims=True)
     normed = (grouped - mu) * functional.rsqrt(sigma2 + self.eps)
     normed = normed.reshape((-1, self.nin) + spatial)
     return normed * self.gamma.value + self.beta.value
Esempio n. 2
0
def upscale_nn(x: JaxArray, scale: int = 2) -> JaxArray:
    """Upscales a batch of images by pixel replication (nearest neighbour).

    Args:
        x: input tensor of shape (N, C, H, W).
        scale: integer factor applied to both H and W.

    Returns:
        Tensor of shape (N, C, H * scale, W * scale) in which every input pixel
        fills a scale x scale square.
    """
    n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
    # Add a singleton axis after H and after W, repeat each `scale` times,
    # then merge every repeated axis into its neighbour.
    expanded = x.reshape((n, c, h, 1, w, 1))
    tiled = jn.tile(expanded, (1, 1, 1, scale, 1, scale))
    return tiled.reshape((n, c, h * scale, w * scale))
Esempio n. 3
0
    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        """Batch-normalizes x.

        Args:
            x: input tensor.
            training: when True, normalize with the statistics of the current batch (reduced over ``self.redux``)
                and fold them into the running averages; when False, normalize with the stored running averages.

        Returns:
            ``gamma * (x - mean) * rsqrt(var + eps) + beta``.
        """
        if not training:
            mean, var = self.running_mean.value, self.running_var.value
        else:
            mean = x.mean(self.redux, keepdims=True)
            # Variance from centered values: x^2 - m^2 is not numerically stable.
            var = ((x - mean) ** 2).mean(self.redux, keepdims=True)
            decay = 1 - self.momentum
            self.running_mean.value += decay * (mean - self.running_mean.value)
            self.running_var.value += decay * (var - self.running_var.value)
        inv_std = functional.rsqrt(var + self.eps)
        return self.gamma.value * (x - mean) * inv_std + self.beta.value
Esempio n. 4
0
def mixture_gaussian_cdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """CDF of a per-feature mixture of Gaussians.

    Each of the D features has its own K-component mixture. The result is
    sum_k softmax(prior_logits[d])[k] * Phi((x[d] - means[d, k]) / scales[d, k]),
    accumulated in log space with logsumexp and exponentiated at the end.

    Args:
        x (JaxArray): input vector
            (D,)
        prior_logits (JaxArray): unnormalized log-weights of the components
            (D, K)
        means (JaxArray): component means per feature
            (D, K)
        scales (JaxArray): component standard deviations per feature
            (D, K)
    Returns:
        x_cdf (JaxArray): CDF of the mixture for each feature
            (D,)
    """
    # Normalize the component weights over the last (component) axis.
    log_weights = jax.nn.log_softmax(prior_logits)
    # Column view of x, (D,) -> (D, 1), so it broadcasts against every component.
    component_log_cdfs = jax.scipy.stats.norm.logcdf(x.reshape(-1, 1), means, scales)
    weighted = log_weights + component_log_cdfs
    return np.exp(jax.scipy.special.logsumexp(weighted, axis=1))
Esempio n. 5
0
 def __call__(self, x: JaxArray, training: bool) -> JaxArray:
     """Runs the head on x.

     If ``self.global_pool == 'avg'``, axes 2 and 3 are first averaged (kept as size-1 axes). The result goes
     through ``self.conv_pw``, is flattened to (leading axis, -1), then passes through ``self.act_fn`` and
     ``self.classifier``. ``training`` is accepted but not used here.
     """
     pooled = x.mean((2, 3), keepdims=True) if self.global_pool == 'avg' else x
     features = self.conv_pw(pooled)
     features = features.reshape(pooled.shape[0], -1)
     return self.classifier(self.act_fn(features))
Esempio n. 6
0
def mixture_logistic_cdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """CDF of a per-feature mixture of logistic distributions.

    The mixture is evaluated in log space (log-softmax weights plus per-component
    log CDFs, combined with logsumexp) and exponentiated before returning.

    Args:
        x (JaxArray): input vector
            (D,)
        prior_logits (JaxArray): prior logits to weight the components
            (D, K)
        means (JaxArray): means per component per feature
            (D, K)
        scales (JaxArray): scales per component per feature
            (D, K)
    Returns:
        x_cdf (JaxArray): CDF (not its log) of the mixture for each feature
            (D,)
    """
    # add component dimension, (D,)->(D,1), so x broadcasts against every component
    x_r = x.reshape(-1, 1)

    # normalize logit weights to 1, (D,K)->(D,K)
    prior_logits = log_softmax(prior_logits, axis=1)

    # weighted per-component log CDF, (D,K)->(D,K)
    # (logistic_log_cdf is defined elsewhere; presumably the logistic log CDF with location `means`
    # and scale `scales` -- confirm)
    log_cdfs = prior_logits + logistic_log_cdf(x_r, means, scales)

    # combine components, (D,K)->(D,)
    log_cdf = logsumexp(log_cdfs, axis=1)

    # leave log space: callers receive the CDF itself
    return np.exp(log_cdf)
Esempio n. 7
0
def mixture_logistic_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """Log density of a per-feature mixture of logistic distributions.

    For every feature d this returns log sum_k softmax(prior_logits[d])[k] * p_k(x[d]),
    where log p_k comes from ``logistic_log_pdf`` (defined elsewhere), combined with logsumexp.

    Args:
        x (JaxArray): input vector
            (D,)
        prior_logits (JaxArray): unnormalized log-weights of the components
            (D, K)
        means (JaxArray): means per component per feature
            (D, K)
        scales (JaxArray): scales per component per feature
            (D, K)
    Returns:
        log_pdf (JaxArray): log PDF of the mixture for each feature
            (D,)
    """
    # Normalized component log-weights, (D, K) -> (D, K).
    log_weights = log_softmax(prior_logits, axis=1)
    # Evaluate every component at x; the column view (D,) -> (D, 1) broadcasts across K.
    component_log_pdfs = logistic_log_pdf(x.reshape(-1, 1), means, scales)
    # Reduce over the component axis, (D, K) -> (D,).
    return logsumexp(log_weights + component_log_pdfs, axis=1)
Esempio n. 8
0
def mixture_gaussian_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """Log density of a per-feature mixture of Gaussians.

    For every feature d this returns
    log sum_k softmax(prior_logits[d])[k] * N(x[d]; means[d, k], scales[d, k]),
    evaluated with logsumexp for numerical stability.

    Args:
        x (JaxArray): input vector
            (D,)
        prior_logits (JaxArray): unnormalized log-weights of the components
            (D, K)
        means (JaxArray): component means per feature
            (D, K)
        scales (JaxArray): component standard deviations per feature
            (D, K)
    Returns:
        log_pdf (JaxArray): log PDF of the mixture for each feature
            (D,)
    """
    log_weights = log_softmax(prior_logits, axis=1)
    column = x.reshape(-1, 1)  # (D,) -> (D, 1) so x broadcasts against all K components
    weighted = log_weights + jax.scipy.stats.norm.logpdf(column, means, scales)
    return logsumexp(weighted, axis=1)
Esempio n. 9
0
 def __call__(self, x: JaxArray, training: bool) -> JaxArray:
     """Runs conv_pw, bn and act_fn on x, then optional pooling and classifier.

     ``training`` is forwarded to ``self.bn``. If ``self.global_pool == 'avg'``, axes 2 and 3 are averaged away.
     If ``self.classifier`` is not None, it is applied last; otherwise the features are returned.
     """
     y = self.act_fn(self.bn(self.conv_pw(x), training=training))
     if self.global_pool == 'avg':
         y = y.mean((2, 3))
     return y if self.classifier is None else self.classifier(y)
Esempio n. 10
0
 def device_reshape(self, x: JaxArray) -> JaxArray:
     """Utility to reshape an input array in order to broadcast to multiple devices.

     A 0-d array is broadcast to shape ``(self.ndevices,)``. Any other array has its leading axis of size B
     split into ``(self.ndevices, B // self.ndevices)``; the trailing axes are kept.

     Raises:
         AssertionError: if ``x`` has no ``ndim`` attribute (e.g. a Python scalar), or if its leading axis is not
             divisible by ``self.ndevices``.
     """
     # The hint must name something that really produces an array: the former suggestion `np.float(0.5)` returns
     # a plain Python float (no `ndim`, so it re-triggers this assert), and `np.float` no longer exists in NumPy.
     assert hasattr(x, 'ndim'), f'Expected JaxArray, got {type(x)}. If you are trying to pass a scalar to ' \
                                f'parallel, first convert it to a JaxArray, for example jax.numpy.array(0.5)'
     if x.ndim == 0:
         return jn.broadcast_to(x, [self.ndevices])
     assert x.shape[0] % self.ndevices == 0, f'Must be able to equally divide batch {x.shape} among ' \
                                             f'{self.ndevices} devices, but does not go equally.'
     return x.reshape((self.ndevices, x.shape[0] // self.ndevices) + x.shape[1:])
Esempio n. 11
0
 def __call__(self, x: JaxArray, training: bool, batch_norm_update: bool = True) -> JaxArray:
     """Batch-normalizes x using batch statistics combined with ``functional.parallel.pmean``.

     ``pmean`` is presumably a mean across devices (TODO confirm), so the statistics cover the whole
     sharded batch.

     Args:
         x: input tensor (presumably this device's shard of the batch).
         training: if True, normalize with the batch statistics reduced over ``self.redux``; otherwise use the
             stored running statistics.
         batch_norm_update: when training, whether to fold the batch statistics into the running averages.

     Returns:
         ``gamma * (x - m) * rsqrt(v + eps) + beta``.
     """
     if training:
         m = functional.parallel.pmean(x.mean(self.redux, keepdims=True))
         # Centered second moment around the combined mean m. It is mathematically equal to the former
         # pmean(E[x^2] - m^2), but it cannot go negative through cancellation. A negative value would make
         # rsqrt(v + eps) NaN.
         v = functional.parallel.pmean(((x - m) ** 2).mean(self.redux, keepdims=True))
         if batch_norm_update:
             self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
             self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
     else:
         m, v = self.running_mean.value, self.running_var.value
     y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value
     return y
Esempio n. 12
0
def flatten(x: JaxArray) -> JaxArray:
    """Flattens input tensor to a 2D tensor.

    Args:
        x: input tensor with dimensions (n_1, n_2, ..., n_k), k >= 1.

    Returns:
        The input tensor reshaped to two dimensions (n_1, n_prod),
        where n_prod is equal to the product of n_2 to n_k (1 when k == 1).
        An empty batch (n_1 == 0) yields an empty (0, n_prod) tensor.
    """
    import math

    # Spell out the trailing size instead of passing -1: with n_1 == 0 the -1 cannot be
    # inferred and reshape raises rather than returning a (0, n_prod) tensor.
    return x.reshape((x.shape[0], math.prod(x.shape[1:])))
Esempio n. 13
0
def channel_to_space2d(x: JaxArray,
                       size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Moves groups of channels into the spatial dimensions (H, W).

    Args:
        x: input tensor of shape (N, C, H, W); C is assumed divisible by size[0] * size[1].
        size: block height and width; an int is expanded to a pair with ``to_tuple``.

    Returns:
        output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
    """
    bh, bw = to_tuple(size, 2)
    shape = x.shape
    n, c, h, w = shape[0], shape[1], shape[2], shape[3]
    # Split channels into (C', bh, bw), then interleave bh with H and bw with W.
    blocks = x.reshape((n, -1, bh, bw, h, w)).transpose((0, 1, 4, 2, 5, 3))
    return blocks.reshape((n, c // (bh * bw), h * bh, w * bw))
Esempio n. 14
0
    def reshape_microbatch(self, x: JaxArray) -> JaxArray:
        """Groups examples into microbatches.

        DP-SGD clips and noises per-example gradients, which can be slow. Clipping and noising a microbatch of
        examples instead is faster, at a slight cost to privacy. If speed is not an issue, the microbatch size
        should be set to 1.

        An input of shape [D0, D1, ..., Dn] becomes [D0 // self.microbatch, self.microbatch, D1, ..., Dn].

        Args:
            x: items to be regrouped; the leading axis is assumed divisible by ``self.microbatch``.

        Returns:
            The regrouped items.
        """
        dims = tuple(x.shape)
        mb = self.microbatch
        n_micro = dims[0] // mb
        return x.reshape([n_micro, mb] + list(dims[1:]))
Esempio n. 15
0
def space_to_channel2d(x: JaxArray,
                       size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Folds non-overlapping spatial blocks into the channel dimension C.

    Args:
        x: input tensor of shape (N, C, H, W); H and W are assumed divisible by size[0] and size[1].
        size: block height and width; an int is expanded to a pair with ``to_tuple``.

    Returns:
        output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
    """
    bh, bw = to_tuple(size, 2)
    shape = x.shape
    n, c, h, w = shape[0], shape[1], shape[2], shape[3]
    # (N, C, H', bh, W', bw) -> (N, C, bh, bw, H', W') so the pixels of each block become adjacent channels.
    blocks = x.reshape((n, c, h // bh, bh, w // bw, bw)).transpose((0, 1, 3, 5, 2, 4))
    return blocks.reshape((n, c * bh * bw, h // bh, w // bw))
Esempio n. 16
0
def upsample_2d(x: JaxArray,
                scale: Union[Tuple[int, int], int],
                method: Union[Interpolate, str] = Interpolate.BILINEAR) -> JaxArray:
    """Upscales a batch of 2D images with ``jax.image.resize``.

    Args:
        x: 4-D input tensor; axis 1 is treated as channels and axes 2, 3 as height and width.
        scale: int or (height, width) scaling factors.
        method: interpolation method, given as an Interpolate member or a string (e.g. 'bilinear', 'nearest').

    Returns:
        Upscaled tensor of shape (N, C, H * scale[0], W * scale[1]).
    """
    shape = x.shape
    assert len(shape) == 4, f'{shape} must have 4 dimensions to be upsampled, or you can try interpolate function.'
    sh, sw = util.to_tuple(scale, 2)
    n, c, h, w = shape
    # Move channels last, resize height and width, then move channels back to axis 1.
    nhwc = x.transpose([0, 2, 3, 1])
    resized = jax.image.resize(nhwc,
                               shape=(n, h * sh, w * sw, c),
                               method=util.to_interpolate(method))
    return resized.transpose([0, 3, 1, 2])
Esempio n. 17
0
    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        """Performs batch normalization of input tensor.

        ``weight``, ``bias`` and the running statistics are reshaped to (1, -1, 1, 1), i.e. broadcast along
        axis 1 of ``x``.

        Args:
            x: input tensor.
            training: if True compute batch normalization in training mode (normalizing with the current batch
                statistics over ``self.redux`` and accumulating them into the running averages), otherwise
                compute in evaluation mode (using already accumulated batch statistics).

        Returns:
            Batch normalized tensor, ``weight * (x - mean) * rsqrt(var + eps) + bias``.
        """
        shape = (1, -1, 1, 1)
        weight = self.weight.value.reshape(shape)
        bias = self.bias.value.reshape(shape)
        if training:
            mean = x.mean(self.redux, keepdims=True)
            # Centered variance. The former E[x^2] - mean^2 form cancels catastrophically when |mean| >> std and
            # can come out negative, turning rsqrt(var + eps) into NaN.
            var = ((x - mean) ** 2).mean(self.redux, keepdims=True)
            # Running statistics are stored without the reduced singleton axes.
            self.running_mean.value += (1 - self.momentum) * (mean.squeeze(axis=self.redux) - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (var.squeeze(axis=self.redux) - self.running_var.value)
        else:
            mean, var = self.running_mean.value.reshape(shape), self.running_var.value.reshape(shape)

        y = weight * (x - mean) * functional.rsqrt(var + self.eps) + bias
        return y
Esempio n. 18
0
 def device_reshape(self, x: JaxArray) -> JaxArray:
     """Utility to reshape an input array in order to broadcast to multiple devices.

     The leading axis of size B becomes (self.ndevices, B // self.ndevices); trailing axes are unchanged.
     Divisibility is not checked here; reshape will generally reject an inexact split.
     """
     per_device = x.shape[0] // self.ndevices
     new_shape = (self.ndevices, per_device) + x.shape[1:]
     return x.reshape(new_shape)
Esempio n. 19
0
 def _mean_reduce(x: JaxArray) -> JaxArray:
     """Averages x over axes 2 and 3 and drops them (presumably spatial pooling of an NCHW tensor -- confirm)."""
     return x.mean(axis=(2, 3))
Esempio n. 20
0
def reduce_mean(x: JaxArray) -> JaxArray:
    """Averages x over its leading axis (axis 0), removing that axis."""
    return x.mean(axis=0)
Esempio n. 21
0
def kl(p: JaxArray, q: JaxArray, eps: float = 2**-17) -> JaxArray:
    """Calculates the Kullback-Leibler divergence KL(p || q) along the last axis.

    Args:
        p: probabilities of the reference distribution; the last axis indexes outcomes
            and any leading axes are treated as a batch.
        q: probabilities of the other distribution, broadcast-compatible with p.
        eps: small constant added inside both logarithms so zero probabilities do not give -inf.

    Returns:
        sum_i p_i * (log(p_i + eps) - log(q_i + eps)) over the last axis: a scalar for 1-D
        inputs, one value per leading index for batched inputs.
    """
    # For 1-D inputs this equals p.dot(...). For 2-D inputs, `dot` would be a matrix product
    # (a shape error, or a meaningless matrix when square) rather than a per-row KL.
    return (p * (jn.log(p + eps) - jn.log(q + eps))).sum(axis=-1)
Esempio n. 22
0
 def device_reshape(self, x: JaxArray) -> JaxArray:
     """Utility to reshape an input array in order to broadcast to multiple devices.

     The leading axis of size B becomes (self.ndevices, B // self.ndevices); trailing axes are unchanged.

     Raises:
         AssertionError: if B is not divisible by ``self.ndevices``.
     """
     n_dev = self.ndevices
     batch = x.shape[0]
     assert batch % n_dev == 0, f'Must be able to equally divide batch {x.shape} among ' \
                                f'{self.ndevices} devices, but does not go equally.'
     return x.reshape((n_dev, batch // n_dev) + x.shape[1:])