def __call__(self, x: JaxArray, training: bool = True) -> JaxArray:
    """Applies group normalization to x and returns the normalized tensor.

    Args:
        x: input tensor. It is reshaped to
            (-1, groups, nin // groups, *x.shape[2:]) before the statistics
            are computed.
        training: ignored. It is accepted only so that group norm can be
            called exactly like batch norm.

    Returns:
        Tensor of shape (-1, nin, *x.shape[2:]), normalized over the axes in
        self.redux, then scaled by gamma and shifted by beta.
    """
    # Kept in the signature purely for call-compatibility with batch norm.
    del training
    channels_per_group = self.nin // self.groups
    grouped_shape = (-1, self.groups, channels_per_group) + x.shape[2:]
    grouped = x.reshape(grouped_shape)
    # Statistics are computed over self.redux with kept dims so they broadcast.
    group_mean = grouped.mean(axis=self.redux, keepdims=True)
    group_var = grouped.var(axis=self.redux, keepdims=True)
    normalized = (grouped - group_mean) * functional.rsqrt(group_var + self.eps)
    # Merge the (groups, channels-per-group) axes back into one channel axis.
    normalized = normalized.reshape((-1, self.nin) + grouped_shape[3:])
    return normalized * self.gamma.value + self.beta.value
def upscale_nn(x: JaxArray, scale: int = 2) -> JaxArray:
    """Upscales a batch of images by nearest-neighbor replication.

    Args:
        x: input tensor of shape (N, C, H, W).
        scale: integer scaling factor.

    Returns:
        Output tensor of shape (N, C, H * scale, W * scale).
    """
    shape = x.shape
    leading = shape[:2]
    height, width = shape[2], shape[3]
    # Insert a unit axis after H and after W, then repeat along those axes.
    expanded = x.reshape(leading + (height, 1, width, 1))
    repeated = jn.tile(expanded, (1, 1, 1, scale, 1, scale))
    # Fold each repeated axis back into its spatial dimension.
    return repeated.reshape(leading + (scale * height, scale * width))
def mixture_gaussian_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """Log-density of a per-feature mixture of Gaussians.

    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): unnormalized log-weights of the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)

    Returns:
        log_pdf (JaxArray) : log PDF of the mixture, reduced over the
            component axis
    """
    # Add a trailing component axis so x broadcasts against (D, K) parameters.
    x_column = x.reshape(-1, 1)

    # Normalize the mixture weights in log space, (D,K)->(D,K).
    log_weights = log_softmax(prior_logits, axis=1)

    # Weighted per-component log densities, (D,K)->(D,K).
    weighted_log_pdfs = log_weights + jax.scipy.stats.norm.logpdf(
        x_column, means, scales
    )

    # Sum the components in log space, (D,K)->(D,).
    return logsumexp(weighted_log_pdfs, axis=1)
def mixture_gaussian_cdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """CDF of a per-feature mixture of Gaussians.

    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): unnormalized log-weights of the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)

    Returns:
        x_cdf (JaxArray) : CDF of the mixture, reduced over the component axis
    """
    # Add a trailing component axis so x broadcasts against the parameters.
    x_column = x.reshape(-1, 1)

    # Normalize the mixture weights in log space over the last axis
    # (axis=-1 is the library default, spelled out here).
    log_weights = jax.nn.log_softmax(prior_logits, axis=-1)

    # Weighted per-component log CDFs.
    weighted_log_cdfs = log_weights + jax.scipy.stats.norm.logcdf(
        x_column, means, scales
    )

    # Sum the components in log space, then leave log space.
    mixture_log_cdf = jax.scipy.special.logsumexp(weighted_log_cdfs, axis=1)
    return np.exp(mixture_log_cdf)
def mixture_logistic_cdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """CDF of a per-feature mixture of logistic distributions.

    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): unnormalized log-weights of the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)

    Returns:
        cdf (JaxArray) : CDF of the mixture (the exponentiated log CDF),
            reduced over the component axis
    """
    # Add a trailing component axis so x broadcasts against (D, K) parameters.
    x_column = x.reshape(-1, 1)

    # Normalize the mixture weights in log space, (D,K)->(D,K).
    log_weights = log_softmax(prior_logits, axis=1)

    # Weighted per-component log CDFs, (D,K)->(D,K).
    weighted_log_cdfs = log_weights + logistic_log_cdf(x_column, means, scales)

    # Sum the components in log space, (D,K)->(D,), then leave log space.
    mixture_log_cdf = logsumexp(weighted_log_cdfs, axis=1)
    return np.exp(mixture_log_cdf)
def mixture_logistic_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """Log-density of a per-feature mixture of logistic distributions.

    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): unnormalized log-weights of the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)

    Returns:
        log_pdf (JaxArray) : log PDF of the mixture, reduced over the
            component axis
    """
    # Add a trailing component axis, (D,)->(D,1), so x broadcasts against
    # the (D, K) parameters.
    x_column = x.reshape(-1, 1)

    # Normalize the mixture weights in log space, (D,K)->(D,K).
    log_weights = log_softmax(prior_logits, axis=1)

    # Weighted per-component log densities, (D,K)->(D,K).
    weighted_log_pdfs = log_weights + logistic_log_pdf(x_column, means, scales)

    # Sum the components in log space, (D,K)->(D,).
    return logsumexp(weighted_log_pdfs, axis=1)
def device_reshape(self, x: JaxArray) -> JaxArray:
    """Utility to reshape an input array in order to broadcast to multiple devices.

    Args:
        x: array exposing ``ndim``, ``shape`` and ``reshape``. Plain Python
            scalars are rejected because they have no ``ndim``.

    Returns:
        For a 0-dimensional input, x broadcast to shape ``[self.ndevices]``.
        Otherwise x reshaped to
        ``(self.ndevices, x.shape[0] // self.ndevices, *x.shape[1:])``.

    Raises:
        AssertionError: if x has no ``ndim`` attribute, or if its leading
            dimension is not a multiple of ``self.ndevices``.
    """
    # The hint used to recommend np.float(0.5). That alias was the builtin
    # float (which has no ndim) and no longer exists in current NumPy, so
    # following it could never satisfy this check.
    assert hasattr(x, 'ndim'), f'Expected JaxArray, got {type(x)}. If you are trying to pass a scalar to ' \
                               f'parallel, first convert it to a JaxArray, for example jn.array(0.5)'
    if x.ndim == 0:
        # Scalars have no batch axis to split: give every device a copy.
        return jn.broadcast_to(x, [self.ndevices])
    assert x.shape[0] % self.ndevices == 0, f'Must be able to equally divide batch {x.shape} among ' \
                                            f'{self.ndevices} devices, but does not go equally.'
    # Split the leading (batch) axis into (device, per-device batch).
    return x.reshape((self.ndevices, x.shape[0] // self.ndevices) + x.shape[1:])
def flatten(x: JaxArray) -> JaxArray:
    """Collapses every dimension after the first into a single dimension.

    Args:
        x: input tensor with dimensions (n_1, n_2, ..., n_k)

    Returns:
        The input tensor reshaped to two dimensions (n_1, n_prod),
        where n_prod is equal to the product of n_2 to n_k.
    """
    leading = x.shape[0]
    # -1 lets reshape infer the product of the remaining dimensions.
    return x.reshape([leading, -1])
def reshape_microbatch(self, x: JaxArray) -> JaxArray:
    """Splits the leading axis of x into microbatches.

    DP-SGD requires that per-example gradients are clipped and noised, however
    this can be inefficient. To speed this up, it is possible to clip and noise
    a microbatch of examples, at a slight cost to privacy. If speed is not an
    issue, the microbatch size should be set to 1.

    If x has shape [D0, D1, ..., Dn], the reshaped output will have shape
    [number_of_microbatches, microbatch_size, D1, ..., Dn], where
    number_of_microbatches is D0 // microbatch_size.

    Args:
        x: items to be reshaped.

    Returns:
        The reshaped items.
    """
    batch, *trailing = x.shape
    n_microbatches = batch // self.microbatch
    return x.reshape([n_microbatches, self.microbatch, *trailing])
def channel_to_space2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Moves blocks of the channel dimension C into the spatial dimensions (H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area.

    Returns:
        output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
    """
    block = to_tuple(size, 2)
    block_h, block_w = block[0], block[1]
    shape = x.shape
    batch, channels, height, width = shape[0], shape[1], shape[2], shape[3]
    # Split C into (C', block_h, block_w); -1 lets reshape infer C'.
    split = x.reshape((batch, -1, block_h, block_w, height, width))
    # Interleave the block axes with the spatial axes:
    # (N, C', H, block_h, W, block_w).
    interleaved = split.transpose((0, 1, 4, 2, 5, 3))
    out_shape = (
        batch,
        channels // (block_h * block_w),
        height * block_h,
        width * block_w,
    )
    return interleaved.reshape(out_shape)
def space_to_channel2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Moves blocks of the spatial dimensions (H, W) into the channel dimension C.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area.

    Returns:
        output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
    """
    block = to_tuple(size, 2)
    block_h, block_w = block[0], block[1]
    shape = x.shape
    batch, channels = shape[0], shape[1]
    out_h = shape[2] // block_h
    out_w = shape[3] // block_w
    # Split each spatial axis into (coarse position, offset within the block).
    split = x.reshape((batch, channels, out_h, block_h, out_w, block_w))
    # Bring the in-block offsets next to the channel axis:
    # (N, C, block_h, block_w, H', W').
    grouped = split.transpose((0, 1, 3, 5, 2, 4))
    return grouped.reshape((batch, channels * block_h * block_w, out_h, out_w))
def device_reshape(self, x: JaxArray) -> JaxArray:
    """Splits the leading axis of x into (device, per-device batch).

    Args:
        x: array whose leading dimension is a multiple of ``self.ndevices``.

    Returns:
        x reshaped to
        ``(self.ndevices, x.shape[0] // self.ndevices, *x.shape[1:])``.

    Raises:
        AssertionError: if the leading dimension of x cannot be divided
            evenly among ``self.ndevices`` devices.
    """
    assert x.shape[0] % self.ndevices == 0, (
        f'Must be able to equally divide batch {x.shape} among '
        f'{self.ndevices} devices, but does not go equally.'
    )
    n_devices = self.ndevices
    per_device = x.shape[0] // n_devices
    return x.reshape((n_devices, per_device, *x.shape[1:]))
def device_reshape(self, x: JaxArray) -> JaxArray:
    """Splits the leading axis of x into (device, per-device batch).

    No divisibility check is performed here: the per-device batch is
    ``x.shape[0] // self.ndevices`` and the reshape itself fails if the
    element counts do not match.

    Args:
        x: array to be distributed over ``self.ndevices`` devices.

    Returns:
        x reshaped to
        ``(self.ndevices, x.shape[0] // self.ndevices, *x.shape[1:])``.
    """
    n_devices = self.ndevices
    per_device = x.shape[0] // n_devices
    return x.reshape((n_devices, per_device, *x.shape[1:]))