示例#1
0
文件: utils.py 项目: adhamrait/nums
def get_bop_output_type(op_name, dtype_a, dtype_b):
    """Return the scalar type class produced by binary op ``op_name``.

    The op is applied to two 0-d sample arrays of ``dtype_a`` and ``dtype_b``.
    The dtype of the result is returned as its scalar type class
    (e.g. ``np.float64``, ``np.bool_``).

    Args:
        op_name: Name of the binary op. It is first translated through
            ``np_ufunc_map``. The translated name is looked up on ``numpy``,
            falling back to ``scipy.special``.
        dtype_a: dtype of the left operand.
        dtype_b: dtype of the right operand.

    Returns:
        The scalar type class of the op's result dtype.

    Raises:
        TypeError: propagated from the NumPy attempt (e.g. NumPy has the op
            but rejects these operand dtypes) instead of being retried
            against ``scipy.special``.
    """
    # Only the result dtype is used, so the sample values are arbitrary.
    a = np.array(1, dtype=dtype_a)
    b = np.array(2, dtype=dtype_b)
    op_name = np_ufunc_map.get(op_name, op_name)
    try:
        dtype = getattr(np, op_name)(a, b).dtype
    except TypeError:
        # Retrying on scipy.special would replace this informative error
        # with an unrelated one (typically a missing-attribute error).
        raise
    except Exception:
        dtype = getattr(scipy.special, op_name)(a, b).dtype
    # dtype.type is the scalar class itself. Looking str(dtype) up on the
    # numpy module fails for names that are not numpy attributes (e.g.
    # "bool" on NumPy 1.24-1.26).
    return dtype.type
示例#2
0
def get_bop_output_type(op_name, dtype_a, dtype_b):
    """Infer the type class that binary op ``op_name`` yields for two dtypes.

    The op is run on a pair of 0-d sample arrays (``dtype_a`` on the left,
    ``dtype_b`` on the right). The result's dtype is converted with
    ``to_dtype_cls``.

    The name is translated through ``np_ufunc_map`` and resolved on
    ``numpy`` first. A ``TypeError`` from that attempt is propagated
    unchanged. Any other exception makes the same name be resolved on
    ``scipy.special`` instead.
    """
    lhs = np.array(1, dtype=dtype_a)
    rhs = np.array(2, dtype=dtype_b)
    name = np_ufunc_map.get(op_name, op_name)
    try:
        return to_dtype_cls(getattr(np, name)(lhs, rhs).dtype)
    except TypeError:
        # Propagate as-is; do not retry against scipy.special.
        raise
    except Exception:
        special_fn = getattr(scipy.special, name)
        return to_dtype_cls(special_fn(lhs, rhs).dtype)
示例#3
0
    def bop(self, op, a1, a2, a1_T, a2_T, axes):
        """Apply binary operation ``op`` to two operands and return the result.

        Args:
            op: Operation name. ``"tensordot"`` is executed directly with
                ``np.tensordot``. Any other name is translated through
                ``np_ufunc_map`` and resolved on ``numpy``, falling back to
                ``scipy.special``.
            a1: Left operand (needs ``.T``; presumably a NumPy array —
                confirm with callers).
            a2: Right operand (same assumption as ``a1``).
            a1_T: If true, ``a1`` is transposed before the op.
            a2_T: If true, ``a2`` is transposed before the op.
            axes: Forwarded to ``np.tensordot``; unused by other ops.

        Returns:
            The result of the op on the (optionally transposed) operands.

        Raises:
            AttributeError: if the translated name exists in neither
                ``numpy`` nor ``scipy.special``.
        """
        if a1_T:
            a1 = a1.T
        if a2_T:
            a2 = a2.T

        if op == "tensordot":
            return np.tensordot(a1, a2, axes=axes)

        op = np_ufunc_map.get(op, op)
        try:
            ufunc = getattr(np, op)
        except AttributeError:
            # Fall back only when numpy lacks the name; a broad except here
            # would also swallow unrelated errors raised during the lookup.
            ufunc = getattr(scipy.special, op)
        return ufunc(a1, a2)
示例#4
0
def get_bop_output_type(op_name, dtype_a, dtype_b):
    """Return the scalar type class produced by binary op ``op_name``.

    The op is applied to two 0-d sample arrays of ``dtype_a`` and ``dtype_b``.
    The dtype of the result is returned as its scalar type class
    (e.g. ``np.float64``, ``np.bool_``).

    ``"sparse_tensordot"`` is handled specially: its type is taken from the
    product ``a * b`` of two empty 1x1 ``coo_matrix`` objects of the given
    dtypes.

    Args:
        op_name: Name of the binary op. It is first translated through
            ``np_ufunc_map``. The translated name is looked up on ``numpy``,
            falling back to ``scipy.special``.
        dtype_a: dtype of the left operand.
        dtype_b: dtype of the right operand.

    Returns:
        The scalar type class of the op's result dtype.

    Raises:
        TypeError: propagated from the NumPy attempt (e.g. NumPy has the op
            but rejects these operand dtypes) instead of being retried
            against ``scipy.special``.
    """
    # Only the result dtype is used, so the sample values are arbitrary.
    a = np.array(1, dtype=dtype_a)
    b = np.array(2, dtype=dtype_b)
    op_name = np_ufunc_map.get(op_name, op_name)

    if op_name == "sparse_tensordot":
        # Explicit branch rather than relying on the numpy lookup failing.
        sparse_a = coo_matrix((1, 1), dtype=dtype_a)
        sparse_b = coo_matrix((1, 1), dtype=dtype_b)
        return (sparse_a * sparse_b).dtype.type

    try:
        dtype = getattr(np, op_name)(a, b).dtype
    except TypeError:
        # Retrying on scipy.special would replace this informative error
        # with an unrelated one (typically a missing-attribute error).
        raise
    except Exception:
        dtype = getattr(scipy.special, op_name)(a, b).dtype
    # dtype.type is the scalar class itself. Looking str(dtype) up on the
    # numpy module fails for names that are not numpy attributes (e.g.
    # "bool" on NumPy 1.24-1.26).
    return dtype.type
示例#5
0
 def bop(self, op, a1, a2, a1_T, a2_T, axes):
     """Apply binary operation ``op`` to two operands and return the result.

     Args:
         op: Operation name. ``"tensordot"`` is executed with ``np.matmul``
             when ``axes == 1`` and neither operand has more than 2 dims,
             otherwise with ``np.tensordot``. Any other name is translated
             through ``np_ufunc_map`` and resolved on ``numpy``, falling
             back to ``scipy.special``.
         a1: Left operand (needs ``.T`` and ``.shape``; presumably a NumPy
             array — confirm with callers).
         a2: Right operand (same assumption as ``a1``).
         a1_T: If true, ``a1`` is transposed before the op.
         a2_T: If true, ``a2`` is transposed before the op.
         axes: Forwarded to ``np.tensordot``; unused by other ops.

     Returns:
         The result of the op on the (optionally transposed) operands. For
         ``"tensordot"`` this is always an ndarray (0-d for a full
         contraction), whichever path is taken.

     Raises:
         AttributeError: if the translated name exists in neither
             ``numpy`` nor ``scipy.special``.
     """
     if a1_T:
         a1 = a1.T
     if a2_T:
         a2 = a2.T
     if op == "tensordot":
         if axes == 1 and max(len(a1.shape), len(a2.shape)) <= 2:
             # Execute this as a matmul.
             # np.matmul returns a NumPy scalar when both operands are 1-D,
             # while np.tensordot returns a 0-d ndarray. asarray keeps the
             # fast path's return type identical to the tensordot path.
             # TODO: Outer product is optimized.
             #  detect here and execute np.outer(...)
             return np.asarray(np.matmul(a1, a2))
         return np.tensordot(a1, a2, axes=axes)
     op = np_ufunc_map.get(op, op)
     try:
         ufunc = getattr(np, op)
     except AttributeError:
         # Fall back only when numpy lacks the name; a broad except here
         # would also swallow unrelated errors raised during the lookup.
         ufunc = getattr(scipy.special, op)
     return ufunc(a1, a2)