return Lamda[j:n1, i:n0]( Sum[sum_limit(r, w, n1, dj, j0, 1), sum_limit(r, w, n0, di, i0, 0)]( x[i0 + di * r[0], j0 + dj * r[1]] @ w[di, dj])) if batch_size: batch_size = batch_size[0] k = Symbol.k(integer=True) return Lamda[k:batch_size](conv2d(x[k], w)) else: return conv2d(x, w) conv2d = Function.conv2d(real=True, nargs=(2, ), eval=conv2d, shape=property(shape)) def conv3d(x, w, *limits): if limits: (r, ), *_ = limits else: r = (1, 1, 1) l0, l1, l2, in_channels, out_channels = w.shape *batch_size, n0, n1, n2, _in_channels = x.shape assert in_channels == _in_channels def conv3d(x, w):