def forward(self, x, qst):
    """Relation-network forward pass over conv-feature "objects" and a question.

    Every cell of the d x d feature map produced by ``self.conv`` is treated
    as an object. The steps are:

    1. Tag each object with its (row, col) grid coordinate, scaled to [-1, 1].
    2. Pair it with every other object and append the question to each pair.
    3. Run each pair through the ``g`` MLP (``g_fc1`` .. ``g_fc4``).
    4. Sum over all pairs and pass the result through ``f_fc1`` and ``fcout``.

    Args:
        x: image batch fed to ``self.conv``. The original shape notes give a
            (64, 3, 75, 75) input and a (64, 24, 5, 5) feature map.
        qst: question embedding. It is concatenated onto 4-D pair features
            after two ``expand_dims`` calls, so it must be 2-D (batch, q_dim).

    Returns:
        The output of ``self.fcout``.
    """
    x = self.conv(x)
    mb, n_channels, d = x.shape[0], x.shape[1], x.shape[2]
    n_obj = d * d
    # Half-width of the grid, used to map indices 0..d-1 onto [-1, 1].
    # A 1x1 map falls back to 1 so its single cell sits at 0.
    half = (d - 1) / 2. or 1.

    def cvt_coord(i):
        # Floor division is required: ``i / d`` would give a fractional row
        # index under Python 3 and therefore wrong coordinates.
        return [(i // d - half) / half, (i % d - half) / half]

    with x.context:
        # Create the coordinate tensor under x's context so that
        # F.zeros / F.array place it on the same device as x.
        self.coord_tensor = F.zeros((mb, n_obj, 2))
        for i in range(n_obj):
            self.coord_tensor[:, i, :] = F.array(cvt_coord(i))

    # (mb, C, d, d) -> (mb, n_obj, C), then append coordinates -> (mb, n_obj, C + 2)
    x_flat = x.reshape(shape=(mb, n_channels, n_obj))
    x_flat = F.swapaxes(x_flat, 1, 2)
    x_flat = F.concat(x_flat, self.coord_tensor, dim=2)

    # Question copied for every object: (mb, n_obj, 1, q_dim)
    qst = qst.expand_dims(1)
    qst = F.repeat(qst, repeats=n_obj, axis=1)
    qst = qst.expand_dims(2)

    # All ordered object pairs. x_i varies along axis 2; x_j (with the
    # question attached) varies along axis 1.
    x_i = x_flat.expand_dims(1)
    x_i = F.repeat(x_i, repeats=n_obj, axis=1)
    x_j = x_flat.expand_dims(2)
    x_j = F.concat(x_j, qst, dim=3)
    x_j = F.repeat(x_j, repeats=n_obj, axis=2)
    x_full = F.concat(x_i, x_j, dim=3)

    # g network applied to every pair. The feature width is read from the
    # tensor rather than hard-coded as 63, which equals 2 * (24 + 2) + 11
    # for the original sizes.
    x_ = x_full.reshape((-1, x_full.shape[3]))
    x_ = self.g_fc1(x_)
    x_ = self.g_fc2(x_)
    x_ = self.g_fc3(x_)
    x_ = self.g_fc4(x_)

    # Sum the g outputs over all n_obj * n_obj pairs.
    x_g = x_.reshape((mb, n_obj * n_obj, -1))
    x_g = x_g.sum(1)

    # f part
    x_f = self.f_fc1(x_g)
    return self.fcout(x_f)
def fftfilt_nd(x, params):
    """Overlap-add FFT filtering of ``x`` with coefficients ``b`` using MXNet ops.

    This appears to be a port of MATLAB's ``fftfilt``; the 1-based
    ``istart``/``iend`` indices and the ``y = real(y)`` remnant suggest so.
    The loop:

    1. Takes ``x`` in blocks of ``L`` rows and zero-pads each block to
       ``nfft`` columns after transposing.
    2. Multiplies the block's FFT with the FFT of ``b``.
    3. Inverse-transforms the product and accumulates it into ``y``.

    Args:
        x: 2-D NDArray. It is sliced along axis 0 in blocks, so rows are
            presumably time samples and columns channels (TODO confirm).
        params: tuple ``(b, m, nx, nb, L, nfft)``. ``m`` and ``nb`` are not
            used in this function. ``nx`` bounds the loop over rows of
            ``x``, ``L`` is the block length and ``nfft`` the padded FFT
            length.

    Returns:
        NDArray shaped like ``x.T`` (after any column replication below),
        holding the accumulated filter output.

    Note:
        All allocations use a module-level ``ctx`` that must be defined
        elsewhere.
    """
    (b, m, nx, nb, L, nfft) = params

    # FFT of b zero-padded to nfft. The (1, nfft - b.size) pad implies b.T
    # is a single row, i.e. b is presumably a column vector (n, 1) -- TODO confirm.
    B = nd.contrib.fft(data=nd.concatenate(
        [b.T, nd.zeros(shape=(1, (nfft - b.size)), ctx=ctx)], axis=1))
    if b.size == 1:
        B = B.T  # make sure fft of B is a column (might be a row if b is scalar)
    if b.shape[1] == 1:
        B = nd.repeat(data=B, repeats=x.shape[1], axis=0)  # replicate the column B
    # Even columns are treated as real parts and odd columns as imaginary
    # parts of the FFT output (interleaved complex layout).
    # NOTE(review): end=(0, None) sets the axis-0 end index to 0. Under the
    # documented nd.slice semantics (exclusive end, None = full extent) this
    # selects no rows. Confirm against the MXNet version in use;
    # end=(None, None) may be what was intended.
    B_re = nd.slice(data=B, begin=(0, 0), end=(0, None), step=(1, 2))
    B_im = nd.slice(data=B, begin=(0, 1), end=(0, None), step=(1, 2))
    if x.shape[1] == 1:
        x = nd.repeat(data=x, repeats=b.shape[1], axis=1)  # replicate the column x
    y = nd.zeros_like(x.T)

    istart = 1  # 1-based block start (MATLAB convention)
    while istart <= nx:
        iend = min(istart + L - 1, nx)
        if (iend - istart) == 0:
            # NOTE(review): the else branch slices from istart - 1, so
            # x[istart] looks off by one. This branch also multiplies an
            # NDArray by np.ones and does not produce the interleaved
            # (re, im) layout that the nd.slice calls below expect -- verify.
            X = x[istart] * np.ones((nfft, 1))  # need to fft a scalar
        else:
            # NOTE(review): begin/end are plain ints here but tuples
            # elsewhere -- confirm nd.slice accepts this form.
            temp = nd.slice(x, begin=istart - 1, end=iend).T
            X = nd.contrib.fft(data=nd.concatenate([
                temp,
                nd.zeros(shape=(temp.shape[0], (nfft - temp.shape[1])), ctx=ctx)
            ], axis=1))
        X_re = nd.slice(data=X, begin=(0, 0), end=(0, None), step=(1, 2))
        X_im = nd.slice(data=X, begin=(0, 1), end=(0, None), step=(1, 2))

        # Complex multiply (X_re + i X_im) * (B_re + i B_im).
        XprodB_re = (X_re * B_re - X_im * B_im)
        XprodB_im = (X_re * B_im + X_im * B_re)
        # Re-interleave into (re, im) column pairs for the inverse FFT.
        # Dividing by nfft presumably normalises an unscaled ifft -- confirm.
        Ytemp = nd.zeros((X.shape[0], X.shape[1]), ctx=ctx)
        Ytemp[:, ::2] = XprodB_re
        Ytemp[:, 1::2] = XprodB_im
        Y = mx.contrib.ndarray.ifft(Ytemp / nfft)  # only the real part!!!!

        yend = min(nx, istart + nfft - 1)

        # Overlap-add: accumulate this block's output into columns
        # istart-1 .. yend-1 of y. (The end=(0, ...) note above applies
        # to these slices too.)
        y[:, istart - 1:yend] = nd.slice(
            data=y, begin=(0, istart - 1), end=(0, yend),
            step=(1, 1)) + nd.slice(
                data=Y, begin=(0, 0), end=(0, yend - istart + 1),
                step=(1, 1))
        istart += L
    # y = real(y)  (MATLAB step; ifft above already yields only the real part)

    return y
def verify_l2normalization_rewrite(shape, eps, mode):
    """Check ``nd.L2Normalization`` against a rewrite built from simpler ops.

    The rewrite computes ``x / sqrt(sum(x * x, axes) + eps)`` with the
    reduced axes chosen by ``mode``:

    - "channel": axis 1
    - "instance": axes 1-3
    - "spatial": axes 2-3

    The reduced norm is re-expanded to the full input shape with
    ``expand_dims`` + ``repeat`` before dividing. The output shapes are
    asserted equal, and the norms from ``get_norm`` are printed.

    Args:
        shape: 4-D NCHW input shape.
        eps: epsilon added before the square root.
        mode: one of "channel", "instance", "spatial".

    Raises:
        AssertionError: if ``shape`` is not 4-D or the output shapes differ.
        ValueError: if ``mode`` is not a supported value. This is only
            reachable if ``nd.L2Normalization`` accepted the mode.
    """
    assert len(shape) == 4  # NCHW
    data_np = np.random.uniform(size=shape)
    x = nd.array(data_np)

    # org op
    y = nd.L2Normalization(x, eps=eps, mode=mode)

    # rewrite op
    mode_axes = {"channel": [1], "instance": [1, 2, 3], "spatial": [2, 3]}
    if mode not in mode_axes:
        # Raise instead of `assert "<message>"`: a non-empty string is
        # always truthy, so that assert could never fire.
        raise ValueError("not valid `mode` type: %s" % mode)
    axis = mode_axes[mode]

    z = nd.broadcast_mul(x, x)
    z = nd.sum(z, axis=axis)
    eps_tensor = nd.array([eps])
    z = nd.broadcast_add(z, eps_tensor)
    z = nd.sqrt(z)
    # Restore each reduced axis (in increasing order) at its original extent.
    for i in axis:
        z = nd.expand_dims(z, axis=i)
        z = nd.repeat(z, repeats=shape[i], axis=i)
    z = nd.broadcast_div(x, z)
    print(z.shape)

    # compare
    assert z.shape == y.shape
    zn, zp = get_norm(z)
    yn, yp = get_norm(y)
    rn = np.linalg.norm(zp - yp)
    print(zn, yn, rn)
def repeat(input, repeats, dim):
    """Repeat elements of ``input`` along axis ``dim``.

    ``repeats`` may be either of the following:

    - A scalar count, which is passed straight to ``nd.repeat``.
    - Per-element counts given as an NDArray, list, tuple or NumPy array.
      These are applied with ``np.repeat`` on the host, and the result is
      copied back to ``input``'s context and dtype.

    Args:
        input: NDArray to repeat.
        repeats: scalar count or per-element counts (see above).
        dim: axis along which to repeat.

    Returns:
        NDArray on ``input.context`` with ``input.dtype``.
    """
    if isinstance(repeats, nd.NDArray):
        repeats = repeats.asnumpy()
    elif not isinstance(repeats, (list, tuple, np.ndarray)):
        # Scalar count: stay on-device.
        return nd.repeat(input, repeats, axis=dim)
    # nd.repeat takes only a single count, so per-element counts go
    # through NumPy.
    out = np.repeat(input.asnumpy(), repeats, axis=dim)
    return nd.array(out, ctx=input.context, dtype=input.dtype)
def repeat(input, repeats, dim):
    """Repeat each element of ``input`` ``repeats`` times along axis ``dim``.

    Adapter that exposes a ``dim``-style signature on top of ``nd.repeat``,
    which names that argument ``axis``.
    """
    axis = dim
    return nd.repeat(input, repeats=repeats, axis=axis)
def verify_broadcast_like_dynamic(xshp, wshp, lhs_axes, rhs_axes):
    """Check a rewrite of ``nd.broadcast_like`` with explicit axes.

    The reference is ``nd.broadcast_like(x, w, lhs_axes, rhs_axes)`` on
    random inputs of shapes ``xshp`` and ``wshp``. The rewrite depends on
    which w axes are broadcast from:

    - If w's axis 0 (treated as the batch axis) is not among them, ``x`` is
      tiled directly.
    - Otherwise ``w`` is reshaped with slice / expand_dims / repeat /
      squeeze / transpose. The goal is the same rank as ``x``, with every
      broadcast source axis moved to the position of its lhs axis. Then
      ``_broadcast_like`` (defined elsewhere) is applied.

    Output shapes are asserted equal, and the norms from ``get_norm`` are
    printed.

    Args:
        xshp: shape of the lhs input ``x``.
        wshp: shape of the rhs input ``w``.
        lhs_axes, rhs_axes: paired axis sequences (negative values allowed),
            or both None.
    """
    # Tuples are required below (``wshp[:pv] + (wshp[pv],) + ...``), so
    # list inputs are normalised here.
    xshp, wshp = tuple(xshp), tuple(wshp)
    x_np = np.random.uniform(size=xshp)
    w_np = np.random.uniform(size=wshp)
    x = nd.array(x_np)
    w = nd.array(w_np)

    # org op
    y = nd.broadcast_like(x, w, lhs_axes=lhs_axes, rhs_axes=rhs_axes)
    print(y.shape)

    # rewrite op
    xndims, wndims = len(xshp), len(wshp)
    if lhs_axes is None or rhs_axes is None:
        assert xndims == wndims and lhs_axes is None \
            and rhs_axes is None
        z = _broadcast_like(x, w)
    else:
        lhs_axes, lndims = list(lhs_axes), len(lhs_axes)
        rhs_axes, rndims = list(rhs_axes), len(rhs_axes)
        assert lndims == rndims > 0

        # Normalise negative axes and range-check them.
        lhs_axes = tuple([v + xndims if v < 0 else v for v in lhs_axes])
        assert all([0 <= v < xndims for v in list(lhs_axes)])
        rhs_axes = tuple([v + wndims if v < 0 else v for v in rhs_axes])
        assert all([0 <= v < wndims for v in list(rhs_axes)])

        # Every broadcast lhs axis must have extent 1 in x.
        assert all([xshp[lhs_axes[i]] == 1 for i in range(lndims)])

        # Axis 0 is treated as the batch axis (presumably the dynamic one).
        batch_axes = [0]
        flg = all([batch_axis not in rhs_axes
                   for batch_axis in batch_axes])
        if flg:
            # No batch axis is a broadcast source, so tile x with repeat
            # counts taken directly from wshp.
            cnts = {v: wshp[rhs_axes[i]]
                    for i, v in enumerate(lhs_axes)}
            reps = tuple([cnts[v] if v in lhs_axes else 1
                          for v in range(xndims)])
            z = nd.tile(x, reps=reps)
        else:
            # axis_map: x axis -> w axis supplying its extent.
            axis_map = {}
            for i, v in enumerate(lhs_axes):
                axis_map[v] = rhs_axes[i]
            for batch_axis in batch_axes:
                # Parenthesised so `%` formats the whole concatenated
                # message; previously it bound only to the second literal
                # and raised TypeError when the assertion failed.
                assert sum([1 if v == batch_axis else 0
                            for k, v in axis_map.items()]) <= 1, (
                    "multiple broadcast on batch_axis: %s, "
                    "which is not support by dynamic shape fusion."
                    % batch_axis)
            assert wndims < 6, \
                "slice can manipulate at most 5d"

            # reduce shape to 1 for non-broadcast dimensions
            begin = tuple([0] * wndims)
            end = tuple([wshp[v] if v in axis_map.values() else 1
                         for v in range(wndims)])
            w = nd.slice(w, begin=begin, end=end)

            # decompose k1->v, k2->v into k1->v, k2->v2
            # which make axis
            # Each pass duplicates one shared w axis (expand_dims + repeat)
            # so that every x axis maps to a distinct w axis.
            while True:
                vs, flag, paxis_map = set(), True, axis_map
                for pk, pv in paxis_map.items():
                    if pv not in vs:
                        vs.add(pv)
                        continue
                    flag = False
                    axis_map = {k: (v + 1 if v > pv or k == pk else v)
                                for k, v in axis_map.items()}
                    w = nd.expand_dims(w, axis=pv)
                    w = nd.repeat(w, axis=pv, repeats=wshp[pv])
                    wshp = wshp[:pv] + (wshp[pv],) + wshp[pv:]
                    break
                if flag:
                    break
            wndims = len(wshp)

            # trim wndims if not equal to xndims
            # Squeeze unmapped (extent-1 after the slice) axes, then pad
            # trailing axes until w has x's rank.
            v = 0
            while wndims > xndims:
                while v in axis_map.values():
                    v += 1
                w = nd.squeeze(w, axis=v)
                wndims -= 1
                axis_map = {k: (nv - 1 if nv > v else nv)
                            for k, nv in axis_map.items()}
            while wndims < xndims:
                w = nd.expand_dims(w, axis=wndims)
                wndims += 1

            # Build a permutation that moves each w axis axis_map[k] to
            # position k. This is done by repeated swaps, relabelling the
            # mapping after each swap.
            axes = list(range(wndims))
            while True:
                dels = [k for k, v in axis_map.items() if k == v]
                for k in dels:
                    del axis_map[k]
                if not axis_map:
                    break
                keys = list(axis_map.keys())
                k, v = keys[0], axis_map[keys[0]]
                axes[k], axes[v] = axes[v], axes[k]
                for nk in keys:
                    nv = axis_map[nk]
                    if nv == k:
                        axis_map[nk] = v
                    elif nv == v:
                        axis_map[nk] = k
            axes = tuple(axes)
            if axes != tuple(range(wndims)):
                # This guard protects nd.transpose, not slice.
                assert wndims < 7, \
                    "transpose can manipulate at most 6d"
                w = nd.transpose(w, axes=axes)
            z = _broadcast_like(x, w)
    print(z.shape)

    # compare
    assert z.shape == y.shape
    zn, zp = get_norm(z)
    yn, yp = get_norm(y)
    rn = np.linalg.norm(zp - yp)
    print(zn, yn, rn)