Example #1
    def forward(self, inp: torch.Tensor):
        """Apply the (batch-norm style) affine transform to ``inp``.

        Dispatches on the exact input type: ``MultiSampleTensor`` inputs are
        handled per-sample by recursion, ``AbstractTensor`` inputs get a
        linear transform built from the running statistics, and plain tensors
        are normalized with either batch statistics (training) or running
        statistics (eval).  Also reports the computed scale/bias to
        ``self.owner`` via ``on_bn_internals``.
        """
        # Multi-sample inputs: recurse into each sample of the batch.
        if type(inp) is MultiSampleTensor:
            return inp.apply_batch(self.forward)

        # Pick the reduction axes and a broadcast helper based on rank
        # (2-D: (N, C); 4-D: (N, C, H, W) — only these ranks are supported).
        if inp.ndim == 2:
            stat_axes = [0]

            def broadcast(t):
                return t.view(1, t.size(0))
        else:
            assert inp.ndim == 4
            stat_axes = [0, 2, 3]

            def broadcast(t):
                return t.view(1, t.size(0), 1, 1)

        # Abstract (symbolic) inputs always use the running statistics.
        if type(inp) is AbstractTensor:
            scale, bias = self._get_scale_bias(self.running_var,
                                               self.running_mean)
            scale, bias = broadcast(scale), broadcast(bias)
            return inp.apply_linear(scale, lambda x, w: x * w + bias)

        if not self.training:
            var, mean = self.running_var, self.running_mean
        else:
            # Batch statistics; optionally a single scalar variance over
            # the whole tensor instead of a per-channel one.
            if self.use_scalar_scale:
                var = inp.flatten().var(unbiased=True)
            else:
                var = inp.var(stat_axes, unbiased=True)
            mean = inp.mean(stat_axes)
            # Exponential-moving-average update of the running buffers,
            # done in place and kept out of the autograd graph.
            with torch.no_grad():
                m = self.momentum
                self.running_var = (
                    self.running_var.mul_(1 - m).add_(var * m).detach_())
                self.running_mean = (
                    self.running_mean.mul_(1 - m).add_(mean * m).detach_())

        scale, bias = self._get_scale_bias(var, mean)
        out = inp * broadcast(scale) + broadcast(bias)
        # Let the owner observe the effective scale/bias (e.g. for bounds).
        self.owner.on_bn_internals(self, scale, bias)
        return out