def forward(self, inp: torch.Tensor):
    """Batch-norm-style forward pass with special handling for wrapped tensors.

    Dispatch order:
      1. ``MultiSampleTensor`` inputs recurse per contained batch.
      2. ``AbstractTensor`` inputs get a linear abstraction built from the
         running statistics only (no stat update).
      3. Concrete tensors are normalized with batch stats (training) or
         running stats (eval); training also updates the EMA buffers.

    NOTE(review): assumes concrete inputs are (N, C) or (N, C, H, W) — the
    assert below enforces ndim in {2, 4}.
    """
    # Multi-sample wrapper: apply this forward to each contained batch.
    if type(inp) is MultiSampleTensor:
        return inp.apply_batch(self.forward)

    # Axes the statistics are reduced over, and a helper reshaping a
    # per-channel (C,) vector so it broadcasts against the input.
    if inp.ndim == 2:
        stat_axes = [0]

        def unsqueeze(t):
            return t.view(1, t.size(0))
    else:
        assert inp.ndim == 4
        stat_axes = [0, 2, 3]

        def unsqueeze(t):
            return t.view(1, t.size(0), 1, 1)

    # Abstract-interpretation path: express BN as a linear map derived
    # from the running statistics; `bias` is captured by the closure.
    if type(inp) is AbstractTensor:
        scale, bias = self._get_scale_bias(self.running_var, self.running_mean)
        scale, bias = unsqueeze(scale), unsqueeze(bias)
        return inp.apply_linear(scale, lambda x, w: x * w + bias)

    if not self.training:
        # Eval: normalize with the tracked running statistics.
        var, mean = self.running_var, self.running_mean
    else:
        # Training: use batch statistics (scalar variance over the whole
        # tensor when `use_scalar_scale` is set, per-channel otherwise).
        if self.use_scalar_scale:
            var = inp.flatten().var(unbiased=True)
        else:
            var = inp.var(stat_axes, unbiased=True)
        mean = inp.mean(stat_axes)
        # EMA update of the running buffers, kept out of the autograd graph.
        with torch.no_grad():
            self.running_var = (
                self.running_var.mul_(1 - self.momentum)
                .add_(var * self.momentum).detach_())
            self.running_mean = (
                self.running_mean.mul_(1 - self.momentum)
                .add_(mean * self.momentum).detach_())

    scale, bias = self._get_scale_bias(var, mean)
    out = inp * unsqueeze(scale) + unsqueeze(bias)
    # Report the (un-broadcast) scale/bias to the owner's hook.
    self.owner.on_bn_internals(self, scale, bias)
    return out