def forward(self, input: Tensor) -> Tensor:
    """Apply group normalization to ``input``.

    The channels (dim 1) are split into ``self.num_groups`` groups. For
    each sample, every group is centered by its mean and divided by
    ``sqrt(biased_variance + self.eps)``. The optional per-channel
    ``self.weight`` (scale) and ``self.bias`` (shift) are then applied.

    Args:
        input: tensor with at least 3 dimensions whose dim 1 equals
            ``self.num_channels``.

    Returns:
        A tensor with the same shape as ``input``.

    Raises:
        AssertionError: (when assertions are enabled) if ``input`` has
            fewer than 3 dimensions or its channel count is not
            ``self.num_channels``.
    """
    assert len(input.shape) >= 3, "The dimensions of input tensor must larger than 2"
    assert input.shape[1] == self.num_channels, "The channels of input tensor must equal num_channels"

    full_shape = input.shape
    batch = full_shape[0]

    # Collapse each group's channels and spatial positions into one trailing
    # axis so the statistics below are computed per (sample, group).
    grouped = flow.reshape(input, shape=[batch, self.num_groups, -1])
    group_mean = flow.mean(grouped, dim=2, keepdim=True)
    group_var = flow.var(grouped, dim=2, unbiased=False, keepdim=True)
    out = (grouped - group_mean) / flow.sqrt(group_var + self.eps)

    # Go back to a (batch, channel, rest) layout so the affine parameters,
    # viewed as (1, C, 1), broadcast per channel.
    out = flow.reshape(out, shape=[batch, self.num_channels, -1])
    per_channel = (1, self.num_channels, 1)
    if self.weight is not None:
        out = out * self.weight.reshape(*per_channel)
    if self.bias is not None:
        out = out + self.bias.reshape(*per_channel)

    return flow.reshape(out, shape=tuple(full_shape))
def forward(self, x):
    """Return the biased (population) variance of ``x`` along dim 1.

    ``keepdim=True`` keeps dim 1 in the result as a size-1 axis.
    """
    variance = flow.var(x, dim=1, unbiased=False, keepdim=True)
    return variance