Ejemplo n.º 1
0
	def f(A,B,axes=(-2,-1)):
		"""Run the dtype-matched routine from ``funcs`` on ``A`` and a copy of ``B``.

		``funcs`` and ``vec2mat`` come from the enclosing scope (not visible
		in this block).

		Args:
			A: Array providing the operated axes; negative entries of ``axes``
				are counted relative to ``A.ndim``.
			B: Array with the same dtype as ``A``. It is copied first, so the
				caller's array is not written to by this function.
			axes: Axes of ``A`` to operate on. The default is an immutable
				tuple so no mutable object is shared between calls; any
				sequence of ints is accepted.

		Returns:
			The copy of ``B`` after it has been overwritten with the
			routine's output.

		Raises:
			AssertionError: If ``A.dtype != B.dtype`` (only while assertions
				are enabled).
		"""
		# Turn negative axis indices into non-negative ones.
		axes = [i if i >= 0 else A.ndim+i for i in axes]
		# B may have fewer dimensions than A: drop that many trailing entries.
		bax  = axes[:len(axes)-(A.ndim-B.ndim)]
		B  = B.copy()
		Af = utils.partial_flatten(A,axes)
		Bf = utils.partial_flatten(B,bax)
		# In vec2mat mode a 2-d flattened B gets an extra axis at position 1
		# for the duration of the call; it is removed again below.
		mustadd = vec2mat and Bf.ndim == 2
		if mustadd: Bf = utils.addaxes(Bf, [1])
		Bf = np.ascontiguousarray(Bf)
		assert A.dtype == B.dtype
		fun = get_dtype_fun(funcs, A.dtype)
		# The return value of fun is ignored; the result is read back from Bf,
		# so fun is expected to update that buffer in place.
		fun(Af.T, Bf.T)
		if mustadd: Bf = utils.delaxes(Bf, [1])
		B[...] = utils.partial_expand(Bf, B.shape, bax)
		return B
Ejemplo n.º 2
0
	def f(A,*args,**kwargs):
		"""Run the dtype-matched routine from ``funcs`` on a copy of ``A``.

		``funcs`` comes from the enclosing scope (not visible in this block).

		Args:
			A: Input array. It is copied before any processing.
			*args: Extra positional arguments forwarded to the routine.
			**kwargs: Only ``axes`` is read (default ``[-2,-1]``); negative
				entries are counted relative to ``A.ndim``. Other keywords
				are ignored.

		Returns:
			The flattened work buffer expanded back to the shape of ``A``
			by ``utils.partial_expand``.
		"""
		requested = kwargs.get("axes", [-2,-1])
		axes = [A.ndim+ax if ax < 0 else ax for ax in requested]
		work = A.copy()
		# Flatten, then force a C-contiguous buffer for the routine.
		flat = utils.partial_flatten(work, axes)
		flat = np.ascontiguousarray(flat)
		routine = get_dtype_fun(funcs, work.dtype)
		# The routine's return value is ignored; the result is taken from
		# flat, so it is expected to operate in place.
		routine(flat.T, *args)
		return utils.partial_expand(flat, work.shape, axes)
Ejemplo n.º 3
0
	def f(A,axes=(-2,-1)):
		"""Run the dtype-matched routine from ``funcs`` on ``A`` with a reduced output.

		``funcs`` comes from the enclosing scope (not visible in this block).

		Args:
			A: Input array; negative entries of ``axes`` are counted
				relative to ``A.ndim``.
			axes: The two axes of ``A`` to operate on. The default is an
				immutable tuple so no mutable object is shared between
				calls; any sequence of ints is accepted.

		Returns:
			A new array ``b`` that served as the routine's output buffer.
			It is created from ``A`` with the two operated axes indexed
			away, so it has the same dtype as ``A``.

		Raises:
			AssertionError: If ``A.dtype != b.dtype`` (only while assertions
				are enabled).
		"""
		# Turn negative axis indices into non-negative ones.
		axes = [i if i >= 0 else A.ndim+i for i in axes]
		# Allocate the output by copying element [0,0] along the operated axes
		# (presumably moveaxes places `axes` at positions 0 and 1 -- TODO confirm).
		b  = utils.moveaxes(A,axes,[0,1])[0,0].copy()
		Af = utils.partial_flatten(A,axes)
		# b is a fresh copy; bf is its 1-d form and is what the routine receives.
		bf = b.reshape(-1)
		assert A.dtype == b.dtype
		fun = get_dtype_fun(funcs, A.dtype)
		# The return value of fun is ignored; b is returned afterwards, so fun
		# is expected to fill bf (and thereby b) in place.
		fun(Af.T, bf.T)
		return b
Ejemplo n.º 4
0
	def f(A,*args,**kwargs):
		"""Run the dtype-matched routine from ``funcs`` on a copy of ``A`` and return that copy.

		``funcs`` comes from the enclosing scope (not visible in this block).

		Args:
			A: Input array. It is copied before any processing.
			*args: Extra positional arguments forwarded to the routine.
			**kwargs: Only ``axes`` is read (default ``[-2,-1]``); negative
				entries are counted relative to ``A.ndim``. Other keywords
				are ignored.

		Returns:
			The copy of ``A``. NOTE(review): the routine's output reaches
			this copy only through the flattened array, so this presumably
			relies on ``utils.partial_flatten`` returning a view -- confirm.
		"""
		requested = kwargs.get("axes", [-2,-1])
		axes = [A.ndim+ax if ax < 0 else ax for ax in requested]
		out = A.copy()
		flat = utils.partial_flatten(out, axes)
		packed = np.ascontiguousarray(flat)
		routine = get_dtype_fun(funcs, out.dtype)
		# The routine's return value is ignored; it is expected to work in place.
		routine(packed.T, *args)
		if packed is not flat:
			# The routine ran on a separate contiguous buffer; copy the
			# results back into the flattened array.
			flat[...] = packed[...]
		return out
Ejemplo n.º 5
0
def matmul(A,B,axes=(-2,-1)):
	"""Multiply ``A`` and ``B`` along ``axes`` using the core's ``matmul_multi``.

	Args:
		A: Array providing the operated axes; negative entries of ``axes``
			are counted relative to ``A.ndim``.
		B: Array to combine with ``A``. It may have fewer dimensions than
			``A``, in which case that many trailing entries of ``axes`` are
			dropped for it. If its dtype differs from ``A``'s, both flattened
			operands are converted to ``np.result_type`` of the two.
		axes: Axes of ``A`` to operate on. The default is an immutable tuple
			so no mutable object is shared between calls; any sequence of
			ints is accepted.

	Returns:
		A newly allocated array holding the result, expanded with
		``utils.partial_expand`` using the shape of ``B``.
	"""
	# Massage input arrays. This should be factored out,
	# as it is common for many functions
	axes = [i if i >= 0 else A.ndim+i for i in axes]
	bax  = axes[:len(axes)-(A.ndim-B.ndim)]
	Af = utils.partial_flatten(A,axes)
	Bf = utils.partial_flatten(B,bax)
	# A 2-d flattened B gets an extra axis at position 1 for the duration
	# of the call; the matching axis is removed from the output below.
	mustadd = Bf.ndim == 2
	if mustadd: Bf = utils.addaxes(Bf, [1])
	Bf = np.ascontiguousarray(Bf)
	if A.dtype != B.dtype:
		# Bring both operands to a common dtype before picking the core.
		dtype = np.result_type(A.dtype,B.dtype)
		Af = Af.astype(dtype,copy=False)
		Bf = Bf.astype(dtype,copy=False)
	# Compute the shape of the output array
	Xf = np.empty((Bf.shape[0],Bf.shape[1],Af.shape[1]),dtype=Bf.dtype)
	# Actually perform the operation
	core = get_core(Bf.dtype)
	core.matmul_multi(Af.T, Bf.T, Xf.T)
	# Unwrangle
	if mustadd: Xf = utils.delaxes(Xf, [1])
	# NOTE(review): the last dimension of Xf comes from Af, yet the expansion
	# uses B.shape; this presumably assumes A is square along the operated
	# axes -- confirm.
	X = utils.partial_expand(Xf, B.shape, bax)
	return X