def native_layer_norm(
    input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Layer-normalize ``input`` over its trailing ``len(normalized_shape)`` dims.

    Args:
        input: Tensor to normalize.
        normalized_shape: Only its length is used here, to locate the first
            normalized dimension.
        weight: Optional factor multiplied into the normalized output.
        bias: Optional term added to the (scaled) normalized output.
        eps: Forwarded to ``normalize``, which adds it to the variance.

    Returns:
        ``(out, mean, rstd)``. ``out`` is cast back to ``input.dtype``.
        ``mean`` and ``rstd`` are the statistics returned by ``normalize``;
        on CPU they are cast to ``input.dtype`` as well. If the product of
        the leading (non-normalized) dims is zero, ``out`` is ``input``
        passed through and ``mean`` / ``rstd`` are empty shape-``(0,)``
        tensors.
    """
    stats_dtype = utils.get_computation_dtype(input.dtype)
    ndim = input.dim()
    first_norm_dim = ndim - len(normalized_shape)
    leading_shape = list(input.shape[:first_norm_dim])

    if prod(leading_shape) != 0:
        norm_dims = list(range(first_norm_dim, ndim))
        out, mean, rstd = normalize(input, norm_dims, eps)
        if weight is not None:
            out = out * weight
        if bias is not None:
            out = out + bias
    else:
        # Nothing to normalize: no rows exist, so the statistics are empty.
        mean = input.new_zeros((0, ), dtype=stats_dtype)
        rstd = input.new_zeros((0, ), dtype=stats_dtype)
        out = input

    out = out.to(dtype=input.dtype)

    # On CPU the statistics are reported in the input dtype rather than the
    # computation dtype.
    if input.device.type == 'cpu':
        mean = mean.to(dtype=input.dtype)
        rstd = rstd.to(dtype=input.dtype)

    return (out, mean, rstd)
def native_batch_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Batch-normalize ``input`` over every dim except dim 1.

    Args:
        input: Tensor to normalize; statistics are reduced over dim 0 and
            dims ``2..input.dim() - 1``.
        weight: Optional scale; a scalar one is used when ``None``.
        bias: Optional shift; a scalar zero is used when ``None``.
        running_mean: Running mean. Updated via ``copy_`` when training (if
            given); required and read when not training.
        running_var: Running variance. Updated via ``copy_`` when training
            (if given); required and read when not training.
        training: Use batch statistics (``True``) or running statistics.
        momentum: Blend factor for the running-statistic updates.
        eps: Added to the variance before the inverse square root.

    Returns:
        ``(output, save_mean, save_rstd)``. ``output`` is cast to
        ``input.dtype``; on CPU the saved statistics are cast to
        ``input.dtype`` too. In eval mode on CPU the saved statistics are
        empty shape-``(0,)`` tensors.
    """
    ndim = input.dim()
    on_cpu = input.device.type == 'cpu'
    reduction_dims = [0] + list(range(2, ndim))
    computation_dtype = utils.get_computation_dtype(input.dtype)

    if not training:
        assert running_mean is not None and running_var is not None
        mean_acc = running_mean.to(dtype=computation_dtype)
        var_acc = running_var.to(dtype=computation_dtype)
        invstd = 1 / (torch.sqrt(var_acc + eps))
        # Very annoying inconsistency where CPU and CUDA give different shapes
        if on_cpu:
            save_mean = input.new_zeros((0, ))
            save_rstd = input.new_zeros((0, ))
        else:
            save_mean = mean_acc
            save_rstd = invstd
        mean_bcast = _unsqueeze_to_dim(mean_acc, ndim - 1)
        invstd_bcast = _unsqueeze_to_dim(invstd, ndim - 1)
        output = ((input - mean_bcast) * invstd_bcast)
    else:
        output, batch_mean, batch_rstd = normalize(input, reduction_dims, eps)
        save_mean = _squeeze_multiple(batch_mean, reduction_dims)
        save_rstd = _squeeze_multiple(batch_rstd, reduction_dims)

        if running_mean is not None:
            blended_mean = momentum * save_mean + (1 - momentum) * running_mean
            running_mean.copy_(blended_mean)

        if running_var is not None:
            # Number of elements reduced per index of dim 1.
            n = input.numel() / input.shape[1]
            # This does not strictly match eager's numerics: eager
            # accumulates the variance sum and applies the correction
            # directly. Matching it would mean re-implementing var here for
            # a negligible numerics gain on a tensor whose numerics
            # probably don't matter.
            unbiased_var = torch.var(input, reduction_dims, unbiased=False) * (n / (n - 1))
            blended_var = momentum * unbiased_var + (1 - momentum) * running_var
            running_var.copy_(blended_var)

    # Missing affine parameters default to scalar identity values.
    scale = input.new_ones(()) if weight is None else weight
    shift = input.new_zeros(()) if bias is None else bias
    scale = _unsqueeze_to_dim(scale, ndim - 1)
    shift = _unsqueeze_to_dim(shift, ndim - 1)
    output = output * scale + shift

    if on_cpu:
        save_mean = save_mean.to(dtype=input.dtype)
        save_rstd = save_rstd.to(dtype=input.dtype)

    return output.to(dtype=input.dtype), save_mean, save_rstd
def normalize(input, norm_dims, eps):
    """Standardize ``input`` over ``norm_dims``.

    The statistics are computed on a copy of ``input`` cast to the
    computation dtype reported by ``utils.get_computation_dtype``.

    Args:
        input: Tensor to standardize.
        norm_dims: Dimensions reduced when computing the statistics.
        eps: Added to the biased variance before the reciprocal square root.

    Returns:
        ``(out, mean, rstd)`` where ``mean`` and ``rstd`` keep the reduced
        dims (``keepdim=True``) and ``out = (input - mean) * rstd``.
    """
    acc_dtype = utils.get_computation_dtype(input.dtype)
    upcast = input.to(dtype=acc_dtype)

    variance = torch.var(upcast, dim=norm_dims, unbiased=False, keepdim=True)
    mean = torch.mean(upcast, dim=norm_dims, keepdim=True)
    rstd = torch.rsqrt(variance + eps)

    # The subtraction deliberately uses the original ``input``, as before.
    centered = input - mean
    return centered * rstd, mean, rstd
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Compute the gradients of layer normalization.

    Args:
        grad_out: Gradient with respect to the layer-norm output.
        input: The tensor that was normalized in the forward pass.
        normalized_shape: Only its length is used, to split ``input``'s dims
            into outer (batch-like) and inner (normalized) dims.
        mean: Forward-pass mean. NOTE(review): assumed to broadcast against
            ``input`` (e.g. keepdim statistics) -- confirm against callers.
        rstd: Forward-pass reciprocal standard deviation; same broadcasting
            assumption as ``mean``.
        weight: Optional scale used in the forward pass.
        bias: Optional shift used in the forward pass; only checked for
            ``None`` here.
        output_mask: Three flags selecting which of
            ``(d_input, d_weight, d_bias)`` to compute.

    Returns:
        ``(d_input, d_weight, d_bias)``. An entry is ``None`` when its
        ``output_mask`` flag is false, or (for ``d_weight`` / ``d_bias`` on
        non-empty inputs) when the matching parameter is ``None``. Computed
        entries are passed through ``_maybe_cast`` with ``input.dtype``.
    """
    input_shape = input.shape
    input_ndim = input.dim()
    computation_dtype = utils.get_computation_dtype(input.dtype)
    grad_out_cast, input_cast, weight_cast, bias_cast = [
        x.to(computation_dtype) if x is not None else x
        for x in (grad_out, input, weight, bias)
    ]
    assert grad_out_cast is not None

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices: List[int] = [i for i in range(input_ndim) if i >= axis]
    outer_dim_indices: List[int] = [i for i in range(input_ndim) if i < axis]

    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    if M <= 0 or N <= 0:
        # Degenerate (empty) input: gradients are all zeros. Honor
        # output_mask so unrequested gradients are None, consistent with
        # the non-degenerate path below.
        return (
            input.new_zeros(input_shape) if output_mask[0] else None,
            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
            input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
        )

    x_hat = (input_cast - mean) * rstd
    if weight_cast is not None:
        grad_x_hat = grad_out_cast * weight_cast
    else:
        grad_x_hat = grad_out_cast

    # d_input = (rstd / N) * (N * g - sum(g) - x_hat * sum(g * x_hat)),
    # with g = grad_x_hat and the sums taken over the normalized dims.
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3

    d_input: Optional[Tensor] = None
    d_weight: Optional[Tensor] = None
    d_bias: Optional[Tensor] = None

    if output_mask[0]:
        d_input = (rstd / N) * inner

    if output_mask[1] and weight_cast is not None:
        if len(outer_dim_indices) > 0:
            # Reduce over the outer dims so the result has the inner shape.
            d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False)
        else:
            d_weight = grad_out_cast * x_hat

    if output_mask[2] and bias_cast is not None:
        if len(outer_dim_indices) > 0:
            d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
        else:
            d_bias = grad_out_cast

    return (
        _maybe_cast(d_input, input.dtype),
        _maybe_cast(d_weight, input.dtype),
        _maybe_cast(d_bias, input.dtype),
    )