Beispiel #1
0
def _calculate_shape(output: torch.Tensor, grad: torch.Tensor,
                     is_grads_batched: bool) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
    """Return the shapes of ``output`` and ``grad`` as a pair, for comparison.

    Args:
        output: Tensor whose shape is reported first.
        grad: Gradient tensor whose shape is reported second.
        is_grads_batched: When true, the first dimension of ``grad`` is
            dropped before its shape is reported.

    Returns:
        ``(output_shape, grad_shape)``. For regular tensors these come from
        ``.shape``; for nested tensors they are whatever
        ``_nested_tensor_size()`` returns for each tensor.

    Raises:
        RuntimeError: If ``output`` is nested and ``is_grads_batched`` is set.
    """
    # Only ``output.is_nested`` is inspected here. Per the original note, an
    # ``is_same_size`` check (not visible in this block) ensures both tensors
    # are either nested or non-nested.
    if not output.is_nested:
        output_dims = output.shape
        if is_grads_batched:
            # Skip the leading (batch) dimension of the gradient.
            grad_dims = grad.shape[1:]
        else:
            grad_dims = grad.shape
        return output_dims, grad_dims

    if is_grads_batched:
        raise RuntimeError("Batched grads are not supported with Nested Tensor.")

    return output._nested_tensor_size(), grad._nested_tensor_size()