def __call__(self, x: JaxArray, training: bool = True) -> JaxArray:
    """Returns the results of applying group normalization to input x."""
    # The training argument is unused, so group norm can be used as a drop-in replacement for batch norm.
    del training
    group_shape = ((-1, self.groups, self.nin // self.groups) + x.shape[2:])
    x = x.reshape(group_shape)
    mean = x.mean(axis=self.redux, keepdims=True)
    var = x.var(axis=self.redux, keepdims=True)
    x = (x - mean) * functional.rsqrt(var + self.eps)
    x = x.reshape((-1, self.nin) + group_shape[3:])
    x = x * self.gamma.value + self.beta.value
    return x

def upscale_nn(x: JaxArray, scale: int = 2) -> JaxArray:
    """Nearest neighbor upscale for image batches of shape (N, C, H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        scale: integer scaling factor.

    Returns:
        Output tensor of shape (N, C, H * scale, W * scale).
    """
    s = x.shape
    x = x.reshape(s[:2] + (s[2], 1, s[3], 1))
    x = jn.tile(x, (1, 1, 1, scale, 1, scale))
    return x.reshape(s[:2] + (scale * s[2], scale * s[3]))

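# Hedged usage sketch (not part of the library): exercising upscale_nn above on a
# tiny batch, assuming jax.numpy is available as jn as in the function body.
import jax.numpy as jn

_x = jn.arange(2 * 3 * 4 * 4, dtype=jn.float32).reshape(2, 3, 4, 4)
_y = upscale_nn(_x, scale=2)
assert _y.shape == (2, 3, 8, 8)
# Each input pixel is replicated into a 2x2 block, e.g. the top-left 2x2 patch
# of _y[0, 0] equals _x[0, 0, 0, 0] in all four positions.
assert bool((_y[0, 0, :2, :2] == _x[0, 0, 0, 0]).all())
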
def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    """Performs batch normalization of input tensor.

    Args:
        x: input tensor.
        training: if True compute batch normalization in training mode (accumulating batch statistics),
            otherwise compute in evaluation mode (using already accumulated batch statistics).

    Returns:
        Batch normalized tensor.
    """
    if training:
        m = x.mean(self.redux, keepdims=True)
        v = ((x - m) ** 2).mean(self.redux, keepdims=True)  # Note: x^2 - m^2 is not numerically stable.
        self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
        self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
    else:
        m, v = self.running_mean.value, self.running_var.value
    y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value
    return y

def mixture_gaussian_cdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): prior logits to weight the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)
    Returns:
        x_cdf (JaxArray): CDF for the mixture distribution (D,)
    """
    # add component dimension for broadcasting, (D,) -> (D, 1)
    x_r = x.reshape(-1, 1)

    # normalize logit weights to sum to 1, (D, K) -> (D, K)
    prior_logits = jax.nn.log_softmax(prior_logits)

    # weighted log CDF per component, (D, K) -> (D, K)
    log_cdfs = prior_logits + jax.scipy.stats.norm.logcdf(x_r, means, scales)

    # sum over components, (D, K) -> (D,)
    log_cdf = jax.scipy.special.logsumexp(log_cdfs, axis=1)

    return np.exp(log_cdf)

def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    if self.global_pool == 'avg':
        x = x.mean((2, 3), keepdims=True)
    x = self.conv_pw(x).reshape(x.shape[0], -1)
    x = self.act_fn(x)
    x = self.classifier(x)
    return x

def mixture_logistic_cdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): prior logits to weight the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)
    Returns:
        x_cdf (JaxArray): CDF for the mixture distribution (D,)
    """
    # add component dimension for broadcasting, (D,) -> (D, 1)
    x_r = x.reshape(-1, 1)

    # normalize logit weights to sum to 1, (D, K) -> (D, K)
    prior_logits = log_softmax(prior_logits, axis=1)

    # weighted log CDF per component, (D, K) -> (D, K)
    log_cdfs = prior_logits + logistic_log_cdf(x_r, means, scales)

    # sum over components, (D, K) -> (D,)
    log_cdf = logsumexp(log_cdfs, axis=1)

    return np.exp(log_cdf)

def mixture_logistic_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): prior logits to weight the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)
    Returns:
        log_pdf (JaxArray): log PDF for the mixture distribution (D,)
    """
    # add component dimension for broadcasting, (D,) -> (D, 1)
    x_r = x.reshape(-1, 1)

    # normalize logit weights to sum to 1, (D, K) -> (D, K)
    prior_logits = log_softmax(prior_logits, axis=1)

    # weighted log PDF per component, (D, K) -> (D, K)
    log_pdfs = prior_logits + logistic_log_pdf(x_r, means, scales)

    # sum over components, (D, K) -> (D,)
    log_pdf = logsumexp(log_pdfs, axis=1)

    return log_pdf

def mixture_gaussian_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): prior logits to weight the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)
    Returns:
        log_pdf (JaxArray): log PDF for the mixture distribution (D,)
    """
    # add component dimension for broadcasting, (D,) -> (D, 1)
    x_r = x.reshape(-1, 1)

    # normalize logit weights to sum to 1, (D, K) -> (D, K)
    prior_logits = log_softmax(prior_logits, axis=1)

    # weighted log PDF per component, (D, K) -> (D, K)
    log_pdfs = prior_logits + jax.scipy.stats.norm.logpdf(x_r, means, scales)

    # sum over components, (D, K) -> (D,)
    log_pdf = logsumexp(log_pdfs, axis=1)

    return log_pdf

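# Hedged usage sketch (illustrative values only) for the Gaussian mixture helpers
# above. It assumes the snippets' free names resolve as: np -> jax.numpy,
# log_softmax -> jax.nn.log_softmax, logsumexp -> jax.scipy.special.logsumexp.
import jax
import jax.numpy as np
import jax.scipy.stats  # ensures jax.scipy.stats.norm is available for the snippets above
from jax.nn import log_softmax
from jax.scipy.special import logsumexp

_D, _K = 5, 3  # D features, K mixture components per feature
_key = jax.random.PRNGKey(0)
_x = jax.random.normal(_key, (_D,))
_prior_logits = np.zeros((_D, _K))  # uniform mixture weights after log_softmax
_means = jax.random.normal(_key, (_D, _K))
_scales = np.ones((_D, _K))

_log_pdf = mixture_gaussian_log_pdf(_x, _prior_logits, _means, _scales)
_cdf = mixture_gaussian_cdf(_x, _prior_logits, _means, _scales)
assert _log_pdf.shape == (_D,) and _cdf.shape == (_D,)
assert bool(((_cdf >= 0) & (_cdf <= 1)).all())
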
def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    x = self.conv_pw(x)
    x = self.bn(x, training=training)
    x = self.act_fn(x)
    if self.global_pool == 'avg':
        x = x.mean((2, 3))
    if self.classifier is not None:
        x = self.classifier(x)
    return x

def device_reshape(self, x: JaxArray) -> JaxArray:
    """Utility to reshape an input array in order to broadcast to multiple devices."""
    assert hasattr(x, 'ndim'), f'Expected JaxArray, got {type(x)}. If you are trying to pass a scalar to ' \
                               f'parallel, first convert it to a JaxArray, for example np.float(0.5)'
    if x.ndim == 0:
        return jn.broadcast_to(x, [self.ndevices])
    assert x.shape[0] % self.ndevices == 0, f'Must be able to equally divide batch {x.shape} among ' \
                                            f'{self.ndevices} devices, but does not go equally.'
    return x.reshape((self.ndevices, x.shape[0] // self.ndevices) + x.shape[1:])

def __call__(self, x: JaxArray, training: bool, batch_norm_update: bool = True) -> JaxArray:
    if training:
        m = functional.parallel.pmean(x.mean(self.redux, keepdims=True))
        v = functional.parallel.pmean((x ** 2).mean(self.redux, keepdims=True) - m ** 2)
        if batch_norm_update:
            self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
    else:
        m, v = self.running_mean.value, self.running_var.value
    y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value
    return y

def flatten(x: JaxArray) -> JaxArray:
    """Flattens input tensor to a 2D tensor.

    Args:
        x: input tensor with dimensions (n_1, n_2, ..., n_k)

    Returns:
        The input tensor reshaped to two dimensions (n_1, n_prod), where n_prod is equal to the product
        of n_2 to n_k.
    """
    return x.reshape([x.shape[0], -1])

def channel_to_space2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Transfer channel dimension C into spatial dimensions (H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area.

    Returns:
        output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
    """
    size = to_tuple(size, 2)
    s = x.shape
    y = x.reshape((s[0], -1, size[0], size[1], s[2], s[3]))
    y = y.transpose((0, 1, 4, 2, 5, 3))
    return y.reshape((s[0], s[1] // (size[0] * size[1]), s[2] * size[0], s[3] * size[1]))

def reshape_microbatch(self, x: JaxArray) -> JaxArray:
    """Reshapes examples into microbatches.

    DP-SGD requires that per-example gradients are clipped and noised, however this can be inefficient.
    To speed this up, it is possible to clip and noise a microbatch of examples, at a slight cost to
    privacy. If speed is not an issue, the microbatch size should be set to 1.

    If x has shape [D0, D1, ..., Dn], the reshaped output will have
    shape [number_of_microbatches, microbatch_size, D1, ..., Dn].

    Args:
        x: items to be reshaped.

    Returns:
        The reshaped items.
    """
    s = x.shape
    return x.reshape([s[0] // self.microbatch, self.microbatch, *s[1:]])

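# Minimal shape sketch for reshape_microbatch above, with an illustrative batch of
# 8 examples and a hypothetical microbatch size of 4 (written outside the class,
# so the reshape is spelled out directly).
import jax.numpy as jn

_batch = jn.zeros((8, 3, 32, 32))
_microbatch = 4
_reshaped = _batch.reshape([_batch.shape[0] // _microbatch, _microbatch, *_batch.shape[1:]])
assert _reshaped.shape == (2, 4, 3, 32, 32)  # (number_of_microbatches, microbatch_size, D1, ..., Dn)
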
def space_to_channel2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Transfer spatial dimensions (H, W) into channel dimension C.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area.

    Returns:
        output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
    """
    size = to_tuple(size, 2)
    s = x.shape
    y = x.reshape((s[0], s[1], s[2] // size[0], size[0], s[3] // size[1], size[1]))
    y = y.transpose((0, 1, 3, 5, 2, 4))
    return y.reshape((s[0], s[1] * size[0] * size[1], s[2] // size[0], s[3] // size[1]))

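# Hedged round-trip sketch: space_to_channel2d followed by channel_to_space2d with
# the same size should reproduce the input, since the two reshape/transpose
# pipelines are inverses. Assumes jax.numpy is available as jn.
import jax.numpy as jn

_x = jn.arange(1 * 4 * 4 * 4, dtype=jn.float32).reshape(1, 4, 4, 4)
_y = space_to_channel2d(_x, size=2)
assert _y.shape == (1, 16, 2, 2)  # (N, C * 4, H // 2, W // 2)
_z = channel_to_space2d(_y, size=2)
assert bool(jn.allclose(_z, _x))
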
def upsample_2d(x: JaxArray, scale: Union[Tuple[int, int], int],
                method: Union[Interpolate, str] = Interpolate.BILINEAR) -> JaxArray:
    """Function to upscale 2D images.

    Args:
        x: input tensor.
        scale: int or tuple scaling factor.
        method: str or Interpolate interpolation method, e.g. ['bilinear', 'nearest'].

    Returns:
        upscaled 2D image tensor.
    """
    s = x.shape
    assert len(s) == 4, f'{s} must have 4 dimensions to be upsampled, or you can try the interpolate function.'
    scale = util.to_tuple(scale, 2)
    y = jax.image.resize(x.transpose([0, 2, 3, 1]),
                         shape=(s[0], s[2] * scale[0], s[3] * scale[1], s[1]),
                         method=util.to_interpolate(method))
    return y.transpose([0, 3, 1, 2])

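# Hedged usage sketch for upsample_2d, assuming the surrounding module provides the
# Interpolate enum and util helpers the function relies on. With method='nearest'
# and an integer factor, the result is expected to match upscale_nn defined earlier.
import jax.numpy as jn

_x = jn.arange(1 * 2 * 3 * 3, dtype=jn.float32).reshape(1, 2, 3, 3)
_y = upsample_2d(_x, scale=2, method='nearest')
assert _y.shape == (1, 2, 6, 6)
assert bool(jn.allclose(_y, upscale_nn(_x, scale=2)))
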
def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    """Performs batch normalization of input tensor.

    Args:
        x: input tensor.
        training: if True compute batch normalization in training mode (accumulating batch statistics),
            otherwise compute in evaluation mode (using already accumulated batch statistics).

    Returns:
        Batch normalized tensor.
    """
    shape = (1, -1, 1, 1)
    weight = self.weight.value.reshape(shape)
    bias = self.bias.value.reshape(shape)
    if training:
        mean = x.mean(self.redux, keepdims=True)
        var = (x ** 2).mean(self.redux, keepdims=True) - mean ** 2
        self.running_mean.value += (1 - self.momentum) * (mean.squeeze(axis=self.redux) - self.running_mean.value)
        self.running_var.value += (1 - self.momentum) * (var.squeeze(axis=self.redux) - self.running_var.value)
    else:
        mean, var = self.running_mean.value.reshape(shape), self.running_var.value.reshape(shape)
    y = weight * (x - mean) * functional.rsqrt(var + self.eps) + bias
    return y

def device_reshape(self, x: JaxArray) -> JaxArray:
    """Utility to reshape an input array in order to broadcast to multiple devices."""
    return x.reshape((self.ndevices, x.shape[0] // self.ndevices) + x.shape[1:])

def _mean_reduce(x: JaxArray) -> JaxArray:
    return x.mean((2, 3))

def reduce_mean(x: JaxArray) -> JaxArray:
    return x.mean(0)

def kl(p: JaxArray, q: JaxArray, eps: float = 2 ** -17) -> JaxArray:
    """Calculates the Kullback-Leibler divergence between arrays p and q."""
    return p.dot(jn.log(p + eps) - jn.log(q + eps))

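# Small worked example for kl above (illustrative distributions), assuming
# jax.numpy is available as jn. For p = [0.5, 0.5] and q = [0.9, 0.1],
# KL(p || q) = 0.5 * log(0.5 / 0.9) + 0.5 * log(0.5 / 0.1) ~= 0.511 nats,
# up to the eps smoothing inside kl.
import jax.numpy as jn

_p = jn.array([0.5, 0.5])
_q = jn.array([0.9, 0.1])
_expected = 0.5 * jn.log(0.5 / 0.9) + 0.5 * jn.log(0.5 / 0.1)
assert bool(jn.isclose(kl(_p, _q), _expected, atol=1e-3))
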
def device_reshape(self, x: JaxArray) -> JaxArray:
    """Utility to reshape an input array in order to broadcast to multiple devices."""
    assert x.shape[0] % self.ndevices == 0, f'Must be able to equally divide batch {x.shape} among ' \
                                            f'{self.ndevices} devices, but does not go equally.'
    return x.reshape((self.ndevices, x.shape[0] // self.ndevices) + x.shape[1:])