def svd_decomposition(
    np,  # TODO: Typing
    tensor: Tensor,
    split_axis: int,
    max_singular_values: Optional[int] = None,
    max_truncation_error: Optional[float] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  """Computes the singular value decomposition (SVD) of a tensor.

  See tensornetwork.backends.tensorflow.decompositions for details.
  """
  left_dims = tensor.shape[:split_axis]
  right_dims = tensor.shape[split_axis:]

  tensor = np.reshape(tensor,
                      [numpy.prod(left_dims), numpy.prod(right_dims)])
  u, s, vh = np.linalg.svd(tensor)

  if max_singular_values is None:
    max_singular_values = np.size(s)

  if max_truncation_error is not None:
    # Cumulative norms of singular values in ascending order.
    trunc_errs = np.sqrt(np.cumsum(np.square(s[::-1])))
    # We must keep at least this many singular values to ensure the
    # truncation error is <= max_truncation_error.
    num_sing_vals_err = np.count_nonzero(
        (trunc_errs > max_truncation_error).astype(np.int32))
  else:
    num_sing_vals_err = max_singular_values

  num_sing_vals_keep = min(max_singular_values, num_sing_vals_err)

  # The SVD returns the singular values as a real-valued vector, so s may
  # not have the same dtype as the original tensor, which can cause issues
  # for later contractions. To fix it, we recast to the original dtype.
  s = s.astype(tensor.dtype)

  s_rest = s[num_sing_vals_keep:]
  s = s[:num_sing_vals_keep]
  u = u[:, :num_sing_vals_keep]
  vh = vh[:num_sing_vals_keep, :]

  dim_s = s.shape[0]
  u = np.reshape(u, list(left_dims) + [dim_s])
  vh = np.reshape(vh, [dim_s] + list(right_dims))

  return u, s, vh, s_rest
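# Usage sketch (illustrative only, not part of the library): truncated SVD of
# a rank-4 numpy array, followed by a reconstruction check. The helper name
# `_example_np_svd_usage` and the shapes are assumptions for demonstration.
def _example_np_svd_usage():
  import numpy as np
  t = np.random.randn(2, 3, 4, 5)
  # With split_axis=2 the matricized tensor is 6 x 20, so keeping 6 singular
  # values discards nothing and s_rest is empty.
  u, s, vh, s_rest = svd_decomposition(np, t, split_axis=2,
                                       max_singular_values=6)
  assert u.shape == (2, 3, 6) and s.shape == (6,) and vh.shape == (6, 4, 5)
  # Contract u * diag(s) * vh back into the original tensor.
  reconstructed = np.einsum('abj,j,jcd->abcd', u, s, vh)
  assert np.allclose(reconstructed, t)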
def _generate_random_tensors_and_dims(dtype_, rank_a_, rank_b_, num_dims_):
  a_shape = np.random.randint(1, _MAXDIM + 1, rank_a_)
  b_shape = np.random.randint(1, _MAXDIM + 1, rank_b_)
  shared_shape = np.random.randint(1, _MAXDIM + 1, num_dims_)
  a_dims = _random_subset(num_dims_, rank_a_)
  b_dims = _random_subset(num_dims_, rank_b_)
  for i in range(num_dims_):
    a_shape[a_dims[i]] = shared_shape[i]
    b_shape[b_dims[i]] = shared_shape[i]
  a = np.random.uniform(
      low=-1.0, high=1.0,
      size=np.prod(a_shape)).reshape(a_shape).astype(dtype_)
  b = np.random.uniform(
      low=-1.0, high=1.0,
      size=np.prod(b_shape)).reshape(b_shape).astype(dtype_)
  return a, b, a_dims, b_dims
def rq_decomposition(
    torch: Any,
    tensor: Tensor,
    split_axis: int,
) -> Tuple[Tensor, Tensor]:
  """Computes the RQ decomposition of a tensor.

  The RQ decomposition is performed by treating the tensor as a matrix,
  with an effective left (row) index resulting from combining the axes
  `tensor.shape[:split_axis]` and an effective right (column) index
  resulting from combining the axes `tensor.shape[split_axis:]`.

  For example, if `tensor` had a shape (2, 3, 4, 5) and `split_axis` was 2,
  then `r` would have shape (2, 3, 6), and `q` would have shape (6, 4, 5).

  The output consists of two tensors `R, Q` such that:

  ```python
    R[i1,...,iN, j] * Q[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM]
  ```

  `R` is a lower triangular matrix, `Q` is an orthonormal matrix.
  Note that the output ordering matches numpy.linalg.svd rather than tf.svd.

  Args:
    torch: The pytorch module.
    tensor: A tensor to be decomposed.
    split_axis: Where to split the tensor's axes before flattening into a
      matrix.

  Returns:
    R: Left tensor factor.
    Q: Right tensor factor.
  """
  left_dims = tensor.shape[:split_axis]
  right_dims = tensor.shape[split_axis:]
  tensor = torch.reshape(tensor,
                         [np.prod(left_dims), np.prod(right_dims)])
  # torch currently has no support for complex dtypes.
  q, r = torch.qr(torch.transpose(tensor, 0, 1))
  r, q = torch.transpose(r, 0, 1), torch.transpose(q, 0, 1)
  # M = r @ q at this point.
  center_dim = r.shape[1]
  r = torch.reshape(r, list(left_dims) + [center_dim])
  q = torch.reshape(q, [center_dim] + list(right_dims))
  return r, q
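# Usage sketch (illustrative only, not part of the library): RQ-decompose a
# rank-3 torch tensor, check that the matricized `r` is lower triangular and
# that the factors reconstruct the input. The helper name and shapes are
# assumptions for demonstration.
def _example_torch_rq_usage():
  import torch
  t = torch.randn(3, 4, 5)
  r, q = rq_decomposition(torch, t, split_axis=1)
  # r: (3, 3), q: (3, 4, 5) since min(3, 4 * 5) == 3.
  assert torch.allclose(r, torch.tril(r))
  assert torch.allclose(torch.tensordot(r, q, dims=1), t, atol=1e-5)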
def qr_decomposition(
    torch: Any,
    tensor: Tensor,
    split_axis: int,
) -> Tuple[Tensor, Tensor]:
  """Computes the QR decomposition of a tensor.

  The QR decomposition is performed by treating the tensor as a matrix,
  with an effective left (row) index resulting from combining the axes
  `tensor.shape[:split_axis]` and an effective right (column) index
  resulting from combining the axes `tensor.shape[split_axis:]`.

  For example, if `tensor` had a shape (2, 3, 4, 5) and `split_axis` was 2,
  then `q` would have shape (2, 3, 6), and `r` would have shape (6, 4, 5).

  The output consists of two tensors `Q, R` such that:

  ```python
    Q[i1,...,iN, j] * R[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM]
  ```

  `R` is an upper triangular matrix, `Q` is an orthonormal matrix.
  Note that the output ordering matches numpy.linalg.svd rather than tf.svd.

  Args:
    torch: The pytorch module.
    tensor: A tensor to be decomposed.
    split_axis: Where to split the tensor's axes before flattening into a
      matrix.

  Returns:
    Q: Left tensor factor.
    R: Right tensor factor.
  """
  left_dims = list(tensor.shape)[:split_axis]
  right_dims = list(tensor.shape)[split_axis:]

  tensor = torch.reshape(tensor,
                         (np.prod(left_dims), np.prod(right_dims)))
  q, r = torch.qr(tensor)
  center_dim = q.shape[1]
  q = torch.reshape(q, list(left_dims) + [center_dim])
  r = torch.reshape(r, [center_dim] + list(right_dims))
  return q, r
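# Usage sketch (illustrative only, not part of the library): QR-decompose a
# rank-3 torch tensor and check that the matricized `q` has orthonormal
# columns. The helper name and shapes are assumptions for demonstration.
def _example_torch_qr_usage():
  import torch
  t = torch.randn(4, 5, 3)
  q, r = qr_decomposition(torch, t, split_axis=2)
  # q: (4, 5, 3), r: (3, 3) since min(4 * 5, 3) == 3.
  q_mat = torch.reshape(q, (4 * 5, 3))
  assert torch.allclose(q_mat.t() @ q_mat, torch.eye(3), atol=1e-5)
  assert torch.allclose(torch.tensordot(q, r, dims=1), t, atol=1e-5)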
def qr_decomposition(
    np,  # TODO: Typing
    tensor: Tensor,
    split_axis: int,
) -> Tuple[Tensor, Tensor]:
  """Computes the QR decomposition of a tensor.

  See tensornetwork.backends.tensorflow.decompositions for details.
  """
  left_dims = tensor.shape[:split_axis]
  right_dims = tensor.shape[split_axis:]
  tensor = np.reshape(tensor,
                      [numpy.prod(left_dims), numpy.prod(right_dims)])
  q, r = np.linalg.qr(tensor)
  center_dim = q.shape[1]
  q = np.reshape(q, list(left_dims) + [center_dim])
  r = np.reshape(r, [center_dim] + list(right_dims))
  return q, r
def rq_decomposition(
    np,  # TODO: Typing
    tensor: Tensor,
    split_axis: int,
) -> Tuple[Tensor, Tensor]:
  """Computes the RQ (reversed QR) decomposition of a tensor.

  See tensornetwork.backends.tensorflow.decompositions for details.
  """
  left_dims = tensor.shape[:split_axis]
  right_dims = tensor.shape[split_axis:]
  tensor = np.reshape(tensor,
                      [numpy.prod(left_dims), numpy.prod(right_dims)])
  q, r = np.linalg.qr(np.conj(np.transpose(tensor)))
  r, q = np.conj(np.transpose(r)), np.conj(np.transpose(q))
  # M = r @ q at this point.
  center_dim = r.shape[1]
  r = np.reshape(r, list(left_dims) + [center_dim])
  q = np.reshape(q, [center_dim] + list(right_dims))
  return r, q
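# Usage sketch (illustrative only, not part of the library): RQ of a complex
# numpy tensor; the conjugate transposes above make the decomposition valid
# for complex dtypes. The helper name and shapes are assumptions for
# demonstration.
def _example_np_rq_usage():
  import numpy
  t = numpy.random.randn(2, 3, 6) + 1j * numpy.random.randn(2, 3, 6)
  r, q = rq_decomposition(numpy, t, split_axis=2)
  # r: (2, 3, 6), q: (6, 6); contracting the shared index recovers t.
  assert numpy.allclose(numpy.tensordot(r, q, axes=1), t)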
def test_tensordot_scalar_axes(dtype_, rank_a_, rank_b_, num_dims_):
  if not num_dims_ <= min(rank_a_, rank_b_):
    pytest.skip("Not a test")
  if dtype_ == np.float16:
    tol = 0.05
  elif dtype_ in (np.float32, np.complex64):
    tol = 1e-5
  else:
    tol = 1e-12
  shape = [5] * num_dims_
  a_np = np.random.uniform(
      low=-1.0, high=1.0,
      size=np.prod(shape)).reshape(shape).astype(dtype_)
  b_np = np.random.uniform(
      low=-1.0, high=1.0,
      size=np.prod(shape)).reshape(shape).astype(dtype_)
  all_axes = [0, 1]
  if a_np.ndim > 2:
    all_axes.append(a_np.ndim - 1)
  for axes in all_axes:
    np_ans = np.tensordot(a_np, b_np, axes=axes)
    tf_ans = tensordot2.tensordot(tf, a_np, b_np, axes=axes)
    np.testing.assert_allclose(tf_ans, np_ans, rtol=tol, atol=tol)
    assert tf_ans.shape == np_ans.shape
def svd_decomposition(
    torch: Any,
    tensor: Tensor,
    split_axis: int,
    max_singular_values: Optional[int] = None,
    max_truncation_error: Optional[float] = None
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  """Computes the singular value decomposition (SVD) of a tensor.

  The SVD is performed by treating the tensor as a matrix, with an effective
  left (row) index resulting from combining the axes
  `tensor.shape[:split_axis]` and an effective right (column) index resulting
  from combining the axes `tensor.shape[split_axis:]`.

  For example, if `tensor` had a shape (2, 3, 4, 5) and `split_axis` was 2,
  then `u` would have shape (2, 3, 6), `s` would have shape (6), and `vh`
  would have shape (6, 4, 5).

  If `max_singular_values` is set to an integer, the SVD is truncated to keep
  at most this many singular values.

  If `max_truncation_error > 0`, as many singular values will be truncated as
  possible, so that the truncation error (the norm of discarded singular
  values) is at most `max_truncation_error`.

  If both `max_singular_values` and `max_truncation_error` are specified, the
  number of retained singular values will be
  `min(max_singular_values, nsv_auto_trunc)`, where `nsv_auto_trunc` is the
  number of singular values that must be kept to maintain a truncation error
  smaller than `max_truncation_error`.

  The output consists of three tensors `u, s, vh` such that:

  ```python
    u[i1,...,iN, j] * s[j] * vh[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM]
  ```

  Note that the output ordering matches numpy.linalg.svd rather than tf.svd.

  Args:
    torch: The pytorch module.
    tensor: A tensor to be decomposed.
    split_axis: Where to split the tensor's axes before flattening into a
      matrix.
    max_singular_values: The number of singular values to keep, or `None` to
      keep them all.
    max_truncation_error: The maximum allowed truncation error or `None` to
      not do any truncation.

  Returns:
    u: Left tensor factor.
    s: Vector of ordered singular values from largest to smallest.
    vh: Right tensor factor.
    s_rest: Vector of discarded singular values (length zero if no
      truncation).
  """
  left_dims = list(tensor.shape)[:split_axis]
  right_dims = list(tensor.shape)[split_axis:]

  tensor = torch.reshape(tensor,
                         (np.prod(left_dims), np.prod(right_dims)))
  u, s, v = torch.svd(tensor)

  if max_singular_values is None:
    max_singular_values = s.nelement()

  if max_truncation_error is not None:
    # Cumulative norms of singular values in ascending order.
    s_sorted, _ = torch.sort(s**2)
    trunc_errs = torch.sqrt(torch.cumsum(s_sorted, 0))
    # We must keep at least this many singular values to ensure the
    # truncation error is <= max_truncation_error.
    num_sing_vals_err = torch.nonzero(
        trunc_errs > max_truncation_error).nelement()
  else:
    num_sing_vals_err = max_singular_values

  num_sing_vals_keep = min(max_singular_values, num_sing_vals_err)

  # We recast to the original dtype.
  s = s.type(tensor.type())

  s_rest = s[num_sing_vals_keep:]
  s = s[:num_sing_vals_keep]
  u = u[:, :num_sing_vals_keep]
  v = v[:, :num_sing_vals_keep]

  vh = torch.transpose(v, 0, 1)

  dim_s = s.shape[0]
  u = torch.reshape(u, left_dims + [dim_s])
  vh = torch.reshape(vh, [dim_s] + right_dims)

  return u, s, vh, s_rest
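# Usage sketch (illustrative only, not part of the library): truncate by
# error rather than by count. The helper name and the test matrix are
# assumptions for demonstration.
def _example_torch_svd_truncation():
  import torch
  # A rank-1 matrix plus a tiny perturbation has one dominant singular value.
  t = torch.ones(4, 4) + 1e-6 * torch.randn(4, 4)
  u, s, vh, s_rest = svd_decomposition(torch, t, split_axis=1,
                                       max_truncation_error=1e-3)
  # Only the dominant singular value survives; the rest are discarded.
  assert s.shape[0] == 1 and s_rest.shape[0] == 3
  assert torch.allclose(u @ torch.diag(s) @ vh, t, atol=1e-3)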
def shape_prod(self, values: Tensor) -> int:
  return np.prod(np.array(values))
def _tensordot_reshape(
    a: Tensor, axes: Union[Sequence[int], Tensor], is_right_term=False
) -> Tuple[Tensor, Union[List[int], Tensor], Optional[List[int]], bool]:
  """Helper method to perform transpose and reshape for contraction op.

  This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul`
  using `array_ops.transpose` and `array_ops.reshape`. The method takes a
  tensor and performs the correct transpose and reshape operation for a given
  set of indices. It returns the reshaped tensor as well as a list of indices
  necessary to reshape the tensor again after matrix multiplication.

  Args:
    a: `Tensor`.
    axes: List or `int32` `Tensor` of unique indices specifying valid axes of
      `a`.
    is_right_term: Whether `a` is the right (second) argument to `matmul`.

  Returns:
    A tuple `(reshaped_a, free_dims, free_dims_static, transpose_needed)`
    where `reshaped_a` is the tensor `a` reshaped to allow contraction via
    `matmul`, `free_dims` is either a list of integers or an `int32` `Tensor`,
    depending on whether the shape of a is fully specified, and
    `free_dims_static` is either a list of integers and None values, or None,
    representing the inferred static shape of the free dimensions.
    `transpose_needed` indicates whether `reshaped_a` must be transposed, or
    not, when calling `matmul`.
  """
  if a.get_shape().is_fully_defined() and isinstance(axes, (list, tuple)):
    shape_a = a.get_shape().as_list()
    # NOTE: This will fail if axes contains any tensors
    axes = [i if i >= 0 else i + len(shape_a) for i in axes]
    free = [i for i in range(len(shape_a)) if i not in axes]
    flipped = _tensordot_should_flip(axes, free)

    free_dims = [shape_a[i] for i in free]
    prod_free = int(np.prod([shape_a[i] for i in free]))
    prod_axes = int(np.prod([shape_a[i] for i in axes]))
    perm = axes + free if flipped else free + axes
    new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes]
    transposed_a = _tranpose_if_necessary(a, perm)
    reshaped_a = _reshape_if_necessary(transposed_a, new_shape)
    transpose_needed = (not flipped) if is_right_term else flipped
    return reshaped_a, free_dims, free_dims, transpose_needed

  if a.get_shape().ndims is not None and isinstance(axes, (list, tuple)):
    shape_a = a.get_shape().as_list()
    axes = [i if i >= 0 else i + len(shape_a) for i in axes]
    free = [i for i in range(len(shape_a)) if i not in axes]
    flipped = _tensordot_should_flip(axes, free)
    perm = axes + free if flipped else free + axes

    axes_dims = [shape_a[i] for i in axes]
    free_dims = [shape_a[i] for i in free]
    free_dims_static = free_dims
    axes = tf.convert_to_tensor(axes, dtype=tf.dtypes.int32, name="axes")
    free = tf.convert_to_tensor(free, dtype=tf.dtypes.int32, name="free")
    shape_a = tf.shape(a)
    transposed_a = _tranpose_if_necessary(a, perm)
  else:
    free_dims_static = None
    shape_a = tf.shape(a)
    rank_a = tf.rank(a)
    axes = tf.convert_to_tensor(axes, dtype=tf.dtypes.int32, name="axes")
    axes = tf.where(axes >= 0, axes, axes + rank_a)
    free, _ = tf.compat.v1.setdiff1d(tf.range(rank_a), axes)
    # Matmul does not accept tensors for its transpose arguments, so fall
    # back to the previous, fixed behavior.
    # NOTE(amilsted): With a suitable wrapper for `matmul` using e.g. `case`
    #   to match transpose arguments to tensor values, we could also avoid
    #   unneeded tranposes in this case at the expense of a somewhat more
    #   complicated graph. Unclear whether this would be beneficial overall.
    flipped = is_right_term
    perm = (
        tf.concat([axes, free], 0) if flipped else tf.concat([free, axes], 0))
    transposed_a = tf.transpose(a, perm)

  free_dims = tf.gather(shape_a, free)
  axes_dims = tf.gather(shape_a, axes)
  prod_free_dims = tf.reduce_prod(free_dims)
  prod_axes_dims = tf.reduce_prod(axes_dims)

  if flipped:
    new_shape = tf.stack([prod_axes_dims, prod_free_dims])
  else:
    new_shape = tf.stack([prod_free_dims, prod_axes_dims])
  reshaped_a = tf.reshape(transposed_a, new_shape)
  transpose_needed = (not flipped) if is_right_term else flipped

  return reshaped_a, free_dims, free_dims_static, transpose_needed
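# Illustration (not part of the library) of the transpose + reshape strategy
# implemented above, written in plain numpy: contracting axes [0, 2] of one
# array against matching axes of another reduces to a single matrix
# multiplication. Array names and shapes are assumptions for demonstration.
def _example_tensordot_via_matmul():
  import numpy as np
  a = np.random.randn(2, 3, 4)
  b = np.random.randn(2, 4, 5)
  a_axes, b_axes = [0, 2], [0, 1]
  # Move the contracted axes of `a` to the back and the contracted axes of
  # `b` to the front, then flatten each side to a matrix.
  a_mat = a.transpose([1, 0, 2]).reshape(3, 2 * 4)
  b_mat = b.reshape(2 * 4, 5)
  result = a_mat @ b_mat
  assert np.allclose(result, np.tensordot(a, b, axes=(a_axes, b_axes)))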