def _calculate_shape(output: torch.Tensor, grad: torch.Tensor, is_grads_batched: bool) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]: # is_same_size ensures that both tensors are either nested or non nested if output.is_nested: if is_grads_batched: raise RuntimeError("Batched grads are not supported with Nested Tensor.") out_shape = output._nested_tensor_size() grad_shape = grad._nested_tensor_size() return out_shape, grad_shape reg_out_shape = output.shape reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:] return reg_out_shape, reg_grad_shape