def elementwise_trinary(alpha, A, desc_A, mode_A,
                        beta, B, desc_B, mode_B,
                        gamma, C, desc_C, mode_C,
                        out=None,
                        op_AB=cutensor.OP_ADD, op_ABC=cutensor.OP_ADD,
                        compute_dtype=None):
    """Combine three tensors element-wise through cuTENSOR.

    The result is computed as:

        D_{Pi^C(i_0,i_1,...,i_nc)} =
            op_ABC(op_AB(alpha * uop_A(A_{Pi^A(i_0,i_1,...,i_na)}),
                         beta  * uop_B(B_{Pi^B(i_0,i_1,...,i_nb)})),
                   gamma * uop_C(C_{Pi^C(i_0,i_1,...,i_nc)}))

    See cupy/cuda/cutensor.elementwiseTrinary() for details.

    Args:
        alpha (scalar or 0-dim numpy.ndarray): Scaling factor for tensor A.
        A (cupy.ndarray): Input tensor.
        desc_A (class Descriptor): A descriptor that holds the information
            about the data type, modes, and strides of tensor A.
        mode_A (cutensor.Mode): A mode object created by `create_mode`.
        beta (scalar or 0-dim numpy.ndarray): Scaling factor for tensor B.
        B (cupy.ndarray): Input tensor.
        desc_B (class Descriptor): A descriptor that holds the information
            about the data type, modes, and strides of tensor B.
        mode_B (cutensor.Mode): A mode object created by `create_mode`.
        gamma (scalar or 0-dim numpy.ndarray): Scaling factor for tensor C.
        C (cupy.ndarray): Input tensor.
        desc_C (class Descriptor): A descriptor that holds the information
            about the data type, modes, and strides of tensor C. It is also
            the descriptor handed to cuTENSOR for the output tensor.
        mode_C (cutensor.Mode): A mode object created by `create_mode`. It is
            also the mode handed to cuTENSOR for the output tensor.
        out (cupy.ndarray): Output tensor. When omitted, a new array with the
            shape and dtype of ``C`` is allocated.
        op_AB (cutensorOperator_t): Element-wise binary operator.
        op_ABC (cutensorOperator_t): Element-wise binary operator.
        compute_dtype (numpy.dtype): Compute type for the intermediate
            computation. Defaults to the dtype of ``A``.

    Returns:
        out (cupy.ndarray): Output tensor.

    Raises:
        ValueError: If the dtypes of ``A``, ``B`` and ``C`` differ, if any
            of them is not C-contiguous, or if a supplied ``out`` does not
            match ``C`` in dtype or shape or is not C-contiguous.

    Examples:
        See examples/cutensor/elementwise_trinary.py
    """
    dtype_a, dtype_b, dtype_c = A.dtype, B.dtype, C.dtype
    if not (dtype_a == dtype_b == dtype_c):
        raise ValueError('dtype mismatch: ({}, {}, {})'.format(
            dtype_a, dtype_b, dtype_c))

    if not all(tensor.flags.c_contiguous for tensor in (A, B, C)):
        raise ValueError('The inputs should be contiguous arrays.')

    if out is None:
        out = cupy.ndarray(C.shape, dtype=C.dtype)
    else:
        # A caller-provided buffer has to be a drop-in replacement for C,
        # because it is passed to cuTENSOR with C's descriptor and mode.
        if C.dtype != out.dtype:
            raise ValueError(
                'dtype mismatch: {} != {}'.format(C.dtype, out.dtype))
        if C.shape != out.shape:
            raise ValueError(
                'shape mismatch: {} != {}'.format(C.shape, out.shape))
        if not out.flags.c_contiguous:
            raise ValueError('`out` should be a contiguous array.')

    mode_A = _auto_create_mode(A, mode_A)
    mode_B = _auto_create_mode(B, mode_B)
    mode_C = _auto_create_mode(C, mode_C)

    compute_dtype = A.dtype if compute_dtype is None else compute_dtype

    # The scaling factors are handed over by host address, so they are held
    # in NumPy arrays of the compute type for the duration of the call.
    scale_a = numpy.asarray(alpha, compute_dtype)
    scale_b = numpy.asarray(beta, compute_dtype)
    scale_c = numpy.asarray(gamma, compute_dtype)

    handle = get_handle()
    cuda_dtype = get_cuda_dtype(compute_dtype)

    cutensor.elementwiseTrinary(
        handle,
        scale_a.ctypes.data, A.data.ptr, desc_A, mode_A.data,
        scale_b.ctypes.data, B.data.ptr, desc_B, mode_B.data,
        scale_c.ctypes.data, C.data.ptr, desc_C, mode_C.data,
        out.data.ptr, desc_C, mode_C.data,
        op_AB, op_ABC, cuda_dtype)
    return out
def elementwise_trinary(alpha, A, desc_A, mode_A,
                        beta, B, desc_B, mode_B,
                        gamma, C, desc_C, mode_C,
                        out=None,
                        op_AB=cutensor.OP_ADD, op_ABC=cutensor.OP_ADD,
                        compute_dtype=None):
    """Element-wise tensor operation for three input tensors

    This function performs an element-wise tensor operation of the form:

        D_{Pi^C(i_0,i_1,...,i_nc)} =
            op_ABC(op_AB(alpha * uop_A(A_{Pi^A(i_0,i_1,...,i_na)}),
                         beta  * uop_B(B_{Pi^B(i_0,i_1,...,i_nb)})),
                   gamma * uop_C(C_{Pi^C(i_0,i_1,...,i_nc)}))

    See cupy/cuda/cutensor.elementwiseTrinary() for details.

    Args:
        alpha: Scaling factor for tensor A.
        A (cupy.ndarray): Input tensor.
        desc_A (class Descriptor): A descriptor that holds the information
            about the data type, modes, and strides of tensor A.
        mode_A (tuple of int/str): A tuple that holds the labels of the modes
            of tensor A (e.g., if A_{x,y,z}, mode_A = ('x', 'y', 'z')).
        beta: Scaling factor for tensor B.
        B (cupy.ndarray): Input tensor.
        desc_B (class Descriptor): A descriptor that holds the information
            about the data type, modes, and strides of tensor B.
        mode_B (tuple of int/str): A tuple that holds the labels of the modes
            of tensor B.
        gamma: Scaling factor for tensor C.
        C (cupy.ndarray): Input tensor.
        desc_C (class Descriptor): A descriptor that holds the information
            about the data type, modes, and strides of tensor C. It is also
            the descriptor handed to cuTENSOR for the output tensor.
        mode_C (tuple of int/str): A tuple that holds the labels of the modes
            of tensor C. It is also used as the mode of the output tensor.
        out (cupy.ndarray): Output tensor. When omitted, a new array with the
            shape and dtype of ``C`` is allocated.
        op_AB (cutensorOperator_t): Element-wise binary operator.
        op_ABC (cutensorOperator_t): Element-wise binary operator.
        compute_dtype (numpy.dtype): Compute type for the intermediate
            computation. Defaults to the dtype of ``A``.

    Returns:
        out (cupy.ndarray): Output tensor.

    Raises:
        ValueError: If the dtypes of ``A``, ``B`` and ``C`` differ, if the
            number of mode labels does not equal the number of dimensions of
            the corresponding tensor, or if a supplied ``out`` does not match
            ``C`` in dtype or shape.

    Examples:
        See examples/cutensor/elementwise_trinary.py
    """
    # Validation raises instead of asserting: assert statements are removed
    # under ``python -O``, which would let bad arguments reach cuTENSOR.
    if not (A.dtype == B.dtype == C.dtype):
        raise ValueError('dtype mismatch: ({}, {}, {})'.format(
            A.dtype, B.dtype, C.dtype))
    for label, tensor, mode in (('A', A, mode_A),
                                ('B', B, mode_B),
                                ('C', C, mode_C)):
        if tensor.ndim != len(mode):
            raise ValueError(
                'ndim mismatch for {}: tensor has {} dimensions but mode '
                'has {} labels'.format(label, tensor.ndim, len(mode)))

    mode_A = _convert_mode(mode_A)
    mode_B = _convert_mode(mode_B)
    mode_C = _convert_mode(mode_C)

    if out is None:
        out = cupy.ndarray(C.shape, dtype=C.dtype)
    elif C.dtype != out.dtype:
        raise ValueError('dtype mismatch: {} != {}'.format(C.dtype, out.dtype))
    elif C.shape != out.shape:
        # Comparing the shape tuples covers both rank and per-axis extents.
        raise ValueError('shape mismatch: {} != {}'.format(C.shape, out.shape))

    if compute_dtype is None:
        compute_dtype = A.dtype

    # The scaling factors are handed over by host address, so they are held
    # in NumPy arrays of the compute type for the duration of the call.
    alpha = numpy.array(alpha, compute_dtype)
    beta = numpy.array(beta, compute_dtype)
    gamma = numpy.array(gamma, compute_dtype)

    handle = get_handle()
    cuda_dtype = get_cuda_dtype(compute_dtype)
    cutensor.elementwiseTrinary(
        handle,
        alpha.ctypes.data, A.data.ptr, desc_A, mode_A.ctypes.data,
        beta.ctypes.data, B.data.ptr, desc_B, mode_B.ctypes.data,
        gamma.ctypes.data, C.data.ptr, desc_C, mode_C.ctypes.data,
        out.data.ptr, desc_C, mode_C.ctypes.data,
        op_AB, op_ABC, cuda_dtype)
    return out