def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    """Performs batch normalization of input tensor.

    Args:
        x: input tensor.
        training: if True compute batch normalization in training mode
            (accumulating batch statistics), otherwise compute in evaluation
            mode (using already accumulated batch statistics).

    Returns:
        Batch normalized tensor.
    """
    if training:
        m = x.mean(self.redux, keepdims=True)
        # Note: x^2 - m^2 is not numerically stable, so use the centered form.
        v = ((x - m) ** 2).mean(self.redux, keepdims=True)
        self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
        self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
    else:
        m, v = self.running_mean.value, self.running_var.value
    y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value
    return y
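# Usage sketch: not part of the snippet above. The JaxArray / functional.rsqrt /
# .value naming suggests an objax-style BatchNorm layer (an assumption); if so,
# a layer with this __call__ is typically used roughly as follows.
import objax
import jax.numpy as jn

bn = objax.nn.BatchNorm2D(nin=3)   # hypothetical instance for NCHW inputs
x = jn.zeros((8, 3, 32, 32))
y_train = bn(x, training=True)     # compute batch stats, update running stats
y_eval = bn(x, training=False)     # normalize with the accumulated running stats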
def __call__(self, lr: float, grads: List[JaxArray],
             beta1: Optional[float] = None, beta2: Optional[float] = None):
    """Updates variables and other state based on Adam algorithm.

    Args:
        lr: the learning rate.
        grads: the gradients to apply.
        beta1: optional, override the default beta1.
        beta2: optional, override the default beta2.
    """
    assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'
    if beta1 is None:
        beta1 = self.beta1
    if beta2 is None:
        beta2 = self.beta2
    self.step.value += 1
    lr *= jn.sqrt(1 - beta2 ** self.step.value) / (1 - beta1 ** self.step.value)
    for g, p, m, v in zip(grads, self.train_vars, self.m, self.v):
        m.value += (1 - beta1) * (g - m.value)
        v.value += (1 - beta2) * (g ** 2 - v.value)
        p.value -= lr * m.value * functional.rsqrt(v.value + self.eps)
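# Usage sketch: not part of the snippet above. Assuming an objax-style Adam
# optimizer (suggested by the TrainVar-like .value state and jn/functional
# names), a typical training step pairs it with objax.GradValues:
import objax

model = objax.nn.Linear(4, 2)
opt = objax.optimizer.Adam(model.vars())
gv = objax.GradValues(lambda x, y: ((model(x) - y) ** 2).mean(), model.vars())

def train_step(x, y, lr=1e-3):
    grads, _ = gv(x, y)   # gradients in the same order as opt's train_vars
    opt(lr, grads)        # the __call__ above: bias-corrected Adam update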
def __call__(self, vals, mask, training):
    sum_dims = list(range(len(vals.shape[:-1])))
    x_or_zero = jnp.where(mask[..., None], vals, 0 * vals)
    smask = jax.device_put(scalar_mask(self.rep))
    if training:
        num_valid = mask.sum(sum_dims)
        m = x_or_zero.sum(sum_dims) / num_valid
        x2 = (x_or_zero ** 2).sum(sum_dims) / num_valid
        v = x2 - m ** 2
        # For non-scalar indices, use the (ragged) mean sum of squares instead of the variance.
        v = jnp.where(smask, v, ragged_gather_scatter(x2, self.rep))
        self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
        self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
    else:
        m, v = self.running_mean.value, self.running_var.value
    # Scalar channels get the full affine batch norm; other channels are only rescaled.
    # TODO: consider switching the second branch to (x_or_zero - m).
    y = jnp.where(smask,
                  self.gamma.value * (x_or_zero - m) * F.rsqrt(v + self.eps) + self.beta.value,
                  x_or_zero * F.rsqrt(v + self.eps))
    return y, mask
def __call__(self, x: JaxArray, training: bool, batch_norm_update: bool = True) -> JaxArray:
    """Performs batch normalization with statistics synchronized across devices via pmean.

    Args:
        x: input tensor.
        training: if True compute cross-replica batch statistics, otherwise use the
            accumulated running statistics.
        batch_norm_update: if True, update the running statistics while training.

    Returns:
        Batch normalized tensor.
    """
    if training:
        m = functional.parallel.pmean(x.mean(self.redux, keepdims=True))
        v = functional.parallel.pmean((x ** 2).mean(self.redux, keepdims=True) - m ** 2)
        if batch_norm_update:
            self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
    else:
        m, v = self.running_mean.value, self.running_var.value
    y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value
    return y
def __call__(self, x: JaxArray, training: bool = True) -> JaxArray:
    """Returns the result of applying group normalization to input x."""
    # The training argument is unused, so group norm can be used as a drop-in
    # replacement for batch norm.
    del training
    group_shape = (-1, self.groups, self.nin // self.groups) + x.shape[2:]
    x = x.reshape(group_shape)
    mean = x.mean(axis=self.redux, keepdims=True)
    var = x.var(axis=self.redux, keepdims=True)
    x = (x - mean) * functional.rsqrt(var + self.eps)
    x = x.reshape((-1, self.nin) + group_shape[3:])
    x = x * self.gamma.value + self.beta.value
    return x
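# Usage sketch: not part of the snippet above. Assuming an objax-style
# GroupNorm2D layer with this __call__ (the constructor signature is an
# assumption), the ignored `training` flag means the layer behaves identically
# at train and eval time:
import objax
import jax.numpy as jn

gn = objax.nn.GroupNorm2D(nin=32, groups=8)   # hypothetical instance
x = jn.ones((4, 32, 16, 16))
assert jn.allclose(gn(x, training=True), gn(x, training=False))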
def __call__(self, lr: float, grads: List[JaxArray]):
    """Updates variables and other state based on Adam algorithm.

    Args:
        lr: the learning rate.
        grads: the gradients to apply.
    """
    assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'
    self.step.value += 1
    lr *= jn.sqrt(1 - self.beta2 ** self.step.value) / (1 - self.beta1 ** self.step.value)
    for g, p, m, v in zip(grads, self.train_vars, self.m, self.v):
        m.value += (1 - self.beta1) * (g - m.value)
        v.value += (1 - self.beta2) * (g ** 2 - v.value)
        p.value -= lr * m.value * functional.rsqrt(v.value + self.eps)
def __call__(self, x, training):
    # TODO: support elementwise for regular reps.
    # return x  # Disabling batch norm entirely harms performance.
    smask = jax.device_put(scalar_mask(self.rep))
    if training:
        m = x.mean(self.redux, keepdims=True)
        v = (x ** 2).mean(self.redux, keepdims=True) - m ** 2
        # In non-scalar indices, use the mean sum of squares instead of the variance.
        v = jnp.where(smask, v, ragged_gather_scatter((x ** 2).mean(self.redux), self.rep))
        self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
        self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
    else:
        m, v = self.running_mean.value, self.running_var.value
    # Scalar channels get the full affine batch norm; other channels are only rescaled.
    # TODO: consider switching the second branch to (x - m) * F.rsqrt(v + self.eps).
    y = jnp.where(smask,
                  self.gamma.value * (x - m) * F.rsqrt(v + self.eps) + self.beta.value,
                  x * F.rsqrt(v + self.eps))
    return y
def __call__(self, vals, mask, training=True):
    sum_dims = list(range(len(vals.shape[:-1])))
    x_or_zero = jnp.where(mask[..., None], vals, 0 * vals)
    if training:
        num_valid = mask.sum(sum_dims)
        m = x_or_zero.sum(sum_dims) / num_valid
        v = (x_or_zero ** 2).sum(sum_dims) / num_valid - m ** 2
        self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
        self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
    else:
        m, v = self.running_mean.value, self.running_var.value
    return ((x_or_zero - m) * self.gamma.value * F.rsqrt(v + self.eps) + self.beta.value, mask)
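# Worked sketch (illustration only, not the module above): the masked mean and
# variance are taken over valid set elements only, per feature.
import jax.numpy as jnp

vals = jnp.arange(12.0).reshape(2, 3, 2)      # (batch, set size, features)
mask = jnp.array([[True, True, False],        # last element of the first set is padding
                  [True, True, True]])
sum_dims = list(range(len(vals.shape[:-1])))  # reduce over batch and set dims
x_or_zero = jnp.where(mask[..., None], vals, 0 * vals)
num_valid = mask.sum(sum_dims)                             # number of valid elements
m = x_or_zero.sum(sum_dims) / num_valid                    # per-feature masked mean
v = (x_or_zero ** 2).sum(sum_dims) / num_valid - m ** 2    # per-feature masked variance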
def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    """Performs batch normalization of input tensor.

    Args:
        x: input tensor.
        training: if True compute batch normalization in training mode
            (accumulating batch statistics), otherwise compute in evaluation
            mode (using already accumulated batch statistics).

    Returns:
        Batch normalized tensor.
    """
    shape = (1, -1, 1, 1)
    weight = self.weight.value.reshape(shape)
    bias = self.bias.value.reshape(shape)
    if training:
        mean = x.mean(self.redux, keepdims=True)
        var = (x ** 2).mean(self.redux, keepdims=True) - mean ** 2
        self.running_mean.value += (1 - self.momentum) * (mean.squeeze(axis=self.redux) - self.running_mean.value)
        self.running_var.value += (1 - self.momentum) * (var.squeeze(axis=self.redux) - self.running_var.value)
    else:
        mean, var = self.running_mean.value.reshape(shape), self.running_var.value.reshape(shape)
    y = weight * (x - mean) * functional.rsqrt(var + self.eps) + bias
    return y
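# Side note on the running-statistics update shared by the batch norm variants
# above: `running += (1 - momentum) * (batch - running)` is an exponential
# moving average, algebraically equal to
# `running = momentum * running + (1 - momentum) * batch`,
# so a momentum close to 1 makes the running statistics change slowly.
import jax.numpy as jn

running, batch, momentum = jn.array(2.0), jn.array(5.0), 0.9
assert jn.allclose(running + (1 - momentum) * (batch - running),
                   momentum * running + (1 - momentum) * batch)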