def multidim_dot(space, left, right, result, dtype, right_critical_dim):
    '''Accumulate the N-dimensional dot product of left and right into result.

    left and right are assumed to be concrete arrays.  All three arrays are
    walked in lockstep over one shared "broadcast" shape, which is
    left.shape without its last dimension followed by all of right.shape.

    Worked example:
        left.shape   == [3, 5, 7]
        right.shape  == [2, 7, 4]
        result.shape == [3, 5, 2, 4]
        broadcast    == [3, 5, 2, 7, 4]

    Each array is handed to calculate_dot_strides together with the list of
    broadcast dimensions it skips:
      * result skips dim 3, i.e. len(result.shape) - 1 (when right is 1-d
        it skips len(result.shape) instead),
      * left skips dims 2 and 4: the dims contributed by right, except the
        one at right_critical_dim,
      * right skips dims 0 and 1: the leading dims contributed by left.

    For every broadcast position the product of the left and right items is
    added to the item currently stored in result, so result presumably has
    to be zero-filled by the caller (TODO confirm).  `space` is not used.
    Returns result.
    '''
    left_ndim = len(left.shape)
    right_ndim = len(right.shape)
    broadcast_shape = left.shape[:-1] + right.shape
    shapelen = len(broadcast_shape)

    # Broadcast dims that come from right, minus the contracted one.
    left_skip = []
    for i in range(right_ndim):
        if i != right_critical_dim:
            left_skip.append(left_ndim - 1 + i)

    # Broadcast dims that come from left (all but its last dimension).
    right_skip = range(left_ndim - 1)

    # Position of the contracted dimension as seen from result.
    if right_ndim > 1:
        result_skip = [len(result.shape) - 1]
    else:
        result_skip = [len(result.shape)]

    out_strides = calculate_dot_strides(result.strides, result.backstrides,
                                        broadcast_shape, result_skip)
    outi = ViewIterator(result.start, out_strides[0], out_strides[1],
                        broadcast_shape)

    left_strides = calculate_dot_strides(left.strides, left.backstrides,
                                         broadcast_shape, left_skip)
    lefti = ViewIterator(left.start, left_strides[0], left_strides[1],
                         broadcast_shape)

    right_strides = calculate_dot_strides(right.strides, right.backstrides,
                                          broadcast_shape, right_skip)
    righti = ViewIterator(right.start, right_strides[0], right_strides[1],
                          broadcast_shape)

    while not outi.done():
        # Keyword names must stay in sync with the variables declared on
        # dot_driver; only those names may be live across this point.
        dot_driver.jit_merge_point(left=left,
                                   right=right,
                                   shapelen=shapelen,
                                   lefti=lefti,
                                   righti=righti,
                                   outi=outi,
                                   result=result,
                                   dtype=dtype,
                                   )
        lval = left.getitem(lefti.offset).convert_to(dtype)
        rval = right.getitem(righti.offset).convert_to(dtype)
        outval = result.getitem(outi.offset).convert_to(dtype)
        # result[out] += left[l] * right[r], computed in dtype.
        product = dtype.itemtype.mul(lval, rval)
        total = dtype.itemtype.add(product, outval).convert_to(dtype)
        result.setitem(outi.offset, total)
        outi = outi.next(shapelen)
        righti = righti.next(shapelen)
        lefti = lefti.next(shapelen)
    return result
def create_dot_iter(self, shape, skip):
    '''Build an iterator over this array for use in a dot product.

    The array's strides and backstrides are passed to
    calculate_dot_strides along with `shape` and `skip`; the first two
    items of its result are used as the strides and backstrides of a
    MultiDimViewIterator that starts at self.start and iterates over
    `shape`.
    '''
    own_strides = self.get_strides()
    own_backstrides = self.get_backstrides()
    dot_strides = calculate_dot_strides(own_strides, own_backstrides,
                                        shape, skip)
    strides = dot_strides[0]
    backstrides = dot_strides[1]
    return iter.MultiDimViewIterator(self, self.dtype, self.start,
                                     strides, backstrides, shape)