def __call__(self, x: JaxArray, training: bool) -> JaxArray: """Performs batch normalization of input tensor. Args: x: input tensor. training: if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics). Returns: Batch normalized tensor. """ if training: m = x.mean(self.redux, keepdims=True) v = ((x - m)**2).mean( self.redux, keepdims=True) # Note: x^2 - m^2 is not numerically stable. self.running_mean.value += (1 - self.momentum) * ( m - self.running_mean.value) self.running_var.value += (1 - self.momentum) * ( v - self.running_var.value) else: m, v = self.running_mean.value, self.running_var.value y = self.gamma.value * ( x - m) * functional.rsqrt(v + self.eps) + self.beta.value return y
def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    if self.global_pool == 'avg':
        x = x.mean((2, 3), keepdims=True)
    x = self.conv_pw(x).reshape(x.shape[0], -1)
    x = self.act_fn(x)
    x = self.classifier(x)
    return x
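
# Illustrative only: a hypothetical sketch of this head's data flow in plain jax.numpy. The 1x1
# conv_pw is emulated as a channel matmul and the classifier as a dense layer; every name and
# shape below is an assumption for the example, not the module's real attribute.
import jax
import jax.numpy as jnp


def head_forward(x, w_pw, w_cls, b_cls, global_pool='avg'):
    if global_pool == 'avg':
        x = x.mean((2, 3), keepdims=True)                             # global average pool -> (N, C, 1, 1)
    x = jnp.einsum('nchw,dc->ndhw', x, w_pw).reshape(x.shape[0], -1)  # 1x1 conv as channel matmul
    x = jax.nn.relu(x)                                                # stand-in for act_fn
    return x @ w_cls + b_cls                                          # classifier


# x = jnp.ones((2, 320, 7, 7)); w_pw = jnp.ones((1280, 320))
# logits = head_forward(x, w_pw, jnp.ones((1280, 1000)), jnp.zeros(1000))   # -> (2, 1000)
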
def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    x = self.conv_pw(x)
    x = self.bn(x, training=training)
    x = self.act_fn(x)
    if self.global_pool == 'avg':
        x = x.mean((2, 3))
    if self.classifier is not None:
        x = self.classifier(x)
    return x
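
# Illustrative only: a hypothetical end-to-end sketch of this conv -> bn -> act -> pool ->
# classifier head, using jax.lax.conv for the 1x1 pointwise convolution and an eval-mode
# (affine-only) batch-norm step. All names, shapes, and defaults are assumptions.
import jax
import jax.numpy as jnp


def conv_head_forward(x, w_pw, bn_mean, bn_var, gamma, beta, w_cls, b_cls, eps=1e-5):
    x = jax.lax.conv(x, w_pw, (1, 1), 'SAME')                   # 1x1 pointwise conv (NCHW / OIHW)
    x = gamma * (x - bn_mean) / jnp.sqrt(bn_var + eps) + beta   # batch norm, evaluation mode
    x = jax.nn.relu(x)                                          # stand-in for act_fn
    x = x.mean((2, 3))                                          # global average pool -> (N, C)
    return x @ w_cls + b_cls                                    # classifier


# x = jnp.ones((2, 320, 7, 7)); w_pw = jnp.ones((1280, 320, 1, 1)); s = (1, 1280, 1, 1)
# logits = conv_head_forward(x, w_pw, jnp.zeros(s), jnp.ones(s), jnp.ones(s), jnp.zeros(s),
#                            jnp.ones((1280, 1000)), jnp.zeros(1000))       # -> (2, 1000)
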
def __call__(self, x: JaxArray, training: bool, batch_norm_update: bool = True) -> JaxArray:
    if training:
        # Average the batch statistics across all devices (synchronized batch norm).
        m = functional.parallel.pmean(x.mean(self.redux, keepdims=True))
        v = functional.parallel.pmean((x ** 2).mean(self.redux, keepdims=True) - m ** 2)
        if batch_norm_update:
            self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
    else:
        m, v = self.running_mean.value, self.running_var.value
    y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value
    return y
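
# Illustrative only: a minimal sketch of how cross-device ("synced") batch statistics can be
# computed with jax.pmap and jax.lax.pmean, mirroring the functional.parallel.pmean calls above.
# The axis name, reduction axes, and shapes are assumptions for the example.
import jax
import jax.numpy as jnp


def per_device_stats(x):
    # Local statistics over batch and spatial dims, averaged across all participating devices.
    m = jax.lax.pmean(x.mean((0, 2, 3), keepdims=True), axis_name='devices')
    v = jax.lax.pmean((x ** 2).mean((0, 2, 3), keepdims=True) - m ** 2, axis_name='devices')
    return m, v


# n_dev = jax.local_device_count()
# x = jnp.ones((n_dev, 4, 3, 8, 8))                             # leading axis = devices
# m, v = jax.pmap(per_device_stats, axis_name='devices')(x)
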
def __call__(self, x: JaxArray, training: bool = True) -> JaxArray:
    """Returns the result of applying group normalization to input x."""
    # The training argument is unused, so group norm can be used as a drop-in replacement for batch norm.
    del training
    group_shape = (-1, self.groups, self.nin // self.groups) + x.shape[2:]
    x = x.reshape(group_shape)
    mean = x.mean(axis=self.redux, keepdims=True)
    var = x.var(axis=self.redux, keepdims=True)
    x = (x - mean) * functional.rsqrt(var + self.eps)
    x = x.reshape((-1, self.nin) + group_shape[3:])
    x = x * self.gamma.value + self.beta.value
    return x
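
# Illustrative only: a self-contained group-norm sketch in plain jax.numpy for an NCHW tensor,
# following the same reshape / normalize / reshape pattern as above. Names and the default
# group count are assumptions for the example.
import jax.numpy as jnp


def group_norm(x, gamma, beta, groups=32, eps=1e-5):
    n, c, h, w = x.shape
    g = x.reshape(n, groups, c // groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)                # statistics per (sample, group)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    g = (g - mean) / jnp.sqrt(var + eps)
    return g.reshape(n, c, h, w) * gamma + beta


# x = jnp.ones((2, 64, 8, 8))
# y = group_norm(x, jnp.ones((1, 64, 1, 1)), jnp.zeros((1, 64, 1, 1)), groups=32)
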
def __call__(self, x: JaxArray, training: bool) -> JaxArray: """Performs batch normalization of input tensor. Args: x: input tensor. training: if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics). Returns: Batch normalized tensor. """ shape = (1, -1, 1, 1) weight = self.weight.value.reshape(shape) bias = self.bias.value.reshape(shape) if training: mean = x.mean(self.redux, keepdims=True) var = (x ** 2).mean(self.redux, keepdims=True) - mean ** 2 self.running_mean.value += (1 - self.momentum) * (mean.squeeze(axis=self.redux) - self.running_mean.value) self.running_var.value += (1 - self.momentum) * (var.squeeze(axis=self.redux) - self.running_var.value) else: mean, var = self.running_mean.value.reshape(shape), self.running_var.value.reshape(shape) y = weight * (x - mean) * functional.rsqrt(var + self.eps) + bias return y
def _mean_reduce(x: JaxArray) -> JaxArray:
    return x.mean((2, 3))
def reduce_mean(x: JaxArray) -> JaxArray:
    return x.mean(0)
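
# Illustrative only: the two helpers above reduce different axes. A quick shape check on a
# dummy NCHW tensor (shapes here are assumptions for the example):
import jax.numpy as jnp

x = jnp.ones((4, 16, 8, 8))
assert x.mean((2, 3)).shape == (4, 16)   # _mean_reduce: global average pool over H, W
assert x.mean(0).shape == (16, 8, 8)     # reduce_mean: average over the leading axis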