Example 1
def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    """Performs batch normalization of input tensor.

    Args:
        x: input tensor.
        training: if True compute batch normalization in training mode (accumulating batch statistics),
            otherwise compute in evaluation mode (using already accumulated batch statistics).

    Returns:
        Batch normalized tensor.
    """
    if training:
        m = x.mean(self.redux, keepdims=True)
        v = ((x - m) ** 2).mean(self.redux, keepdims=True)  # Note: E[x^2] - m^2 is not numerically stable.
        self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
        self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
    else:
        m, v = self.running_mean.value, self.running_var.value
    y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value
    return y
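
This reads like the Objax batch norm layer (JaxArray, functional.rsqrt, StateVar-style .value fields). A minimal usage sketch under that assumption, using objax.nn.BatchNorm2D on an NCHW batch; the channel count and shapes are illustrative only:

    import jax.numpy as jnp
    import objax

    bn = objax.nn.BatchNorm2D(nin=8)  # 8 channels; nin is an assumed value
    x = jnp.ones((4, 8, 16, 16))      # NCHW batch; assumed shape
    y = bn(x, training=True)          # normalize with batch stats, update running_mean/running_var
    y = bn(x, training=False)         # normalize with the accumulated running statistics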
Example 2
def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    # Classifier head: global average pooling, pointwise conv, activation, classifier.
    # The training flag is unused here.
    if self.global_pool == 'avg':
        x = x.mean((2, 3), keepdims=True)  # pool NCHW spatial dims down to 1x1
    x = self.conv_pw(x).reshape(x.shape[0], -1)  # 1x1 conv, then flatten to (N, C)
    x = self.act_fn(x)
    x = self.classifier(x)
    return x
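
For intuition, a hedged shape trace of the pooling and pointwise-conv steps above, with assumed sizes (2 images, 1280 channels, a 7x7 feature map, 1000 classes); the einsum stands in for conv_pw plus the flatten, and act_fn/classifier are omitted:

    import jax.numpy as jnp

    x = jnp.ones((2, 1280, 7, 7))                    # assumed input shape (NCHW)
    pooled = x.mean((2, 3), keepdims=True)           # (2, 1280, 1, 1) after global average pooling
    w = jnp.zeros((1000, 1280, 1, 1))                # stand-in for the 1x1 conv_pw kernel
    logits = jnp.einsum('nchw,ochw->no', pooled, w)  # 1x1 conv on a 1x1 map + flatten -> (2, 1000)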
Example 3
def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    # Head variant with batch norm: pointwise conv, norm, activation, then pooling.
    x = self.conv_pw(x)
    x = self.bn(x, training=training)
    x = self.act_fn(x)
    if self.global_pool == 'avg':
        x = x.mean((2, 3))  # no keepdims: output is already flat (N, C)
    if self.classifier is not None:
        x = self.classifier(x)
    return x
Example 4
def __call__(self, x: JaxArray, training: bool, batch_norm_update: bool = True) -> JaxArray:
    if training:
        # pmean averages the per-device statistics across all devices (synced batch norm).
        m = functional.parallel.pmean(x.mean(self.redux, keepdims=True))
        v = functional.parallel.pmean((x ** 2).mean(self.redux, keepdims=True) - m ** 2)
        if batch_norm_update:
            self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
    else:
        m, v = self.running_mean.value, self.running_var.value
    y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value
    return y
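
functional.parallel.pmean is a cross-device mean, so this module only makes sense under a parallel transform. A sketch of one way to run it, assuming this is Objax's SyncedBatchNorm2D (an assumption) and relying on objax.Parallel splitting the leading batch axis across local devices:

    import jax
    import jax.numpy as jnp
    import objax

    bn = objax.nn.SyncedBatchNorm2D(nin=8)                   # assumed module name and channel count
    run = objax.Parallel(lambda x: bn(x, training=True), bn.vars())
    x = jnp.ones((jax.local_device_count() * 2, 8, 16, 16))  # leading axis is split across devices
    with bn.vars().replicate():                              # replicate weights on all devices
        y = run(x)                                           # per-device stats are averaged via pmean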
Example 5
def __call__(self, x: JaxArray, training: bool = True) -> JaxArray:
    """Returns the result of applying group normalization to input x."""
    # The training argument is unused, so group norm can be used as a drop-in replacement for batch norm.
    del training
    group_shape = (-1, self.groups, self.nin // self.groups) + x.shape[2:]
    x = x.reshape(group_shape)
    mean = x.mean(axis=self.redux, keepdims=True)
    var = x.var(axis=self.redux, keepdims=True)
    x = (x - mean) * functional.rsqrt(var + self.eps)
    x = x.reshape((-1, self.nin) + group_shape[3:])
    x = x * self.gamma.value + self.beta.value
    return x
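
A worked shape sketch of the grouping logic, with assumed values nin=6 and groups=3 on an NCHW input; jax.lax.rsqrt stands in for functional.rsqrt and eps is assumed:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((2, 6, 4, 4))                    # (N, nin, H, W); assumed shape
    g = x.reshape(-1, 3, 2, 4, 4)                 # (N, groups, nin // groups, H, W)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)  # one statistic per (sample, group)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    g = (g - mean) * jax.lax.rsqrt(var + 1e-6)    # eps assumed to be 1e-6
    y = g.reshape(2, 6, 4, 4)                     # back to (N, nin, H, W)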
Example 6
def __call__(self, x: JaxArray, training: bool) -> JaxArray:
    """Performs batch normalization of input tensor.

    Args:
        x: input tensor.
        training: if True compute batch normalization in training mode (accumulating batch statistics),
            otherwise compute in evaluation mode (using already accumulated batch statistics).

    Returns:
        Batch normalized tensor.
    """
    shape = (1, -1, 1, 1)
    weight = self.weight.value.reshape(shape)
    bias = self.bias.value.reshape(shape)
    if training:
        mean = x.mean(self.redux, keepdims=True)
        var = (x ** 2).mean(self.redux, keepdims=True) - mean ** 2  # less stable than E[(x - m)^2], cf. example 1
        self.running_mean.value += (1 - self.momentum) * (mean.squeeze(axis=self.redux) - self.running_mean.value)
        self.running_var.value += (1 - self.momentum) * (var.squeeze(axis=self.redux) - self.running_var.value)
    else:
        mean, var = self.running_mean.value.reshape(shape), self.running_var.value.reshape(shape)

    y = weight * (x - mean) * functional.rsqrt(var + self.eps) + bias
    return y
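
The running-statistics update used here (and in examples 1 and 4) is an exponential moving average written incrementally: running + (1 - momentum) * (stat - running) equals momentum * running + (1 - momentum) * stat. A tiny numeric check with assumed toy values:

    momentum, running, batch_mean = 0.9, 0.0, 1.0       # assumed toy values
    running += (1 - momentum) * (batch_mean - running)  # the update used above
    assert running == (1 - momentum) * batch_mean       # == 0.1 (momentum * 0 contributes nothing)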
Example 7
def _mean_reduce(x: JaxArray) -> JaxArray:
    # Global average pooling over the spatial dimensions of an NCHW tensor.
    return x.mean((2, 3))
Example 8
def reduce_mean(x: JaxArray) -> JaxArray:
    # Mean over the leading (batch) axis.
    return x.mean(0)