def forward(self, x: Tensor) -> Tensor:
    r"""
    Flatten every dimension after the batch dimension.

    Args:
        x (Tensor): (batch_size, dim_2, dim_3, ...) arbitrary number of
            dims after batch_size (zero extra dims is allowed)

    Returns:
        out (Tensor): (batch_size, dim_2 * dim_3 * ...) batch_size, then
            all other dims flattened. A 1-D input of shape (batch_size,)
            becomes (batch_size, 1). The result's ``name`` attribute is
            set to ``'flatten_res'``.
    """
    batch = x.shape[0]
    # np.prod(()) is the *float* 1.0, which integer-only reshape
    # implementations (e.g. NumPy's) reject for 1-D input. Forcing an
    # integer dtype and converting to int yields a plain Python int in
    # every case, including the empty product.
    features = int(np.prod(x.shape[1:], dtype=np.int64))
    out = x.reshape((batch, features))
    out.name = 'flatten_res'
    return out
def _normalize(self, x: Tensor, mean: Tensor, var: Tensor) -> Tensor: r""" Normalize a Tensor with mean and variance Args: x (Tensor): tensor to normalize mean (Tensor): mean of the tensor var (Tensor): variance of the tensor """ x_hat = (x - mean.reshape(shape=[1, -1, 1, 1])) / ( var + self.eps).sqrt().reshape(shape=[1, -1, 1, 1]) out = self.gamma.reshape( shape=[1, -1, 1, 1]) * x_hat + self.beta.reshape( shape=[1, -1, 1, 1]) out.name = 'bn_2d_res' return out