# Assumes the MXNet test-suite imports: nd, image, transforms,
# assert_almost_equal, assertRaises, and MXNetError.
def _test_crop_resize_with_diff_type(dtype):
    # test normal case
    data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype)
    out_nd = transforms.CropResize(0, 0, 3, 2)(data_in)
    out_np = out_nd.asnumpy()
    assert out_np.sum() == 180
    assert (out_np[0:2, 1, 1].flatten() == [4, 16]).all()

    # test 4D input
    data_batch_in = nd.arange(180).reshape((2, 6, 5, 3)).astype(dtype)
    out_batch_nd = transforms.CropResize(1, 2, 3, 4)(data_batch_in)
    out_batch_np = out_batch_nd.asnumpy()
    assert out_batch_np.sum() == 7524
    assert (out_batch_np[0:2, 0:4, 1, 1].flatten() ==
            [37, 52, 67, 82, 127, 142, 157, 172]).all()

    # test normal case with resize
    data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype(dtype)
    out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_in)
    data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100, 3)), 25, 25, 2)
    assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())

    # test 4D input with resize
    data_batch_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype(dtype)
    out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_batch_in)
    for i in range(len(out_batch_nd)):
        assert_almost_equal(
            image.imresize(nd.slice(data_batch_in[i], (0, 0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(),
            out_batch_nd[i].asnumpy())

    # with resize, height and width should be greater than 0
    transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 2)
    assertRaises(MXNetError, transformer, data_in)

    # height and width should be greater than 0
    transformer = transforms.CropResize(0, 0, -100, -50)
    assertRaises(MXNetError, transformer, data_in)

    # cropped area must not be bigger than the input data
    transformer = transforms.CropResize(150, 200, 200, 500)
    assertRaises(MXNetError, transformer, data_in)
    assertRaises(MXNetError, transformer, data_batch_in)
def _test_crop_resize_with_diff_type(dtype):
    # test normal case
    data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype)
    out_nd = transforms.CropResize(0, 0, 3, 2)(data_in)
    out_np = out_nd.asnumpy()
    assert out_np.sum() == 180
    assert (out_np[0:2, 1, 1].flatten() == [4, 16]).all()

    # test 4D input
    data_batch_in = nd.arange(180).reshape((2, 6, 5, 3)).astype(dtype)
    out_batch_nd = transforms.CropResize(1, 2, 3, 4)(data_batch_in)
    out_batch_np = out_batch_nd.asnumpy()
    assert out_batch_np.sum() == 7524
    assert (out_batch_np[0:2, 0:4, 1, 1].flatten() ==
            [37, 52, 67, 82, 127, 142, 157, 172]).all()

    # test normal case with resize
    data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype(dtype)
    out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 1)(data_in)
    data_expected = transforms.Resize(size=25, interpolation=1)(
        nd.slice(data_in, (0, 0, 0), (50, 100, 3)))
    assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())

    # test 4D input with resize
    data_batch_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype(dtype)
    out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 1)(data_batch_in)
    for i in range(len(out_batch_nd)):
        expected = transforms.Resize(size=25, interpolation=1)(
            nd.slice(data_batch_in[i], (0, 0, 0), (50, 100, 3))).asnumpy()
        actual = out_batch_nd[i].asnumpy()
        assert_almost_equal(actual, expected)

    # with resize, height and width should be greater than 0
    transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 1)
    assertRaises(MXNetError, transformer, data_in)

    # height and width should be greater than 0
    transformer = transforms.CropResize(0, 0, -100, -50)
    assertRaises(MXNetError, transformer, data_in)

    # cropped area must not be bigger than the input data
    transformer = transforms.CropResize(150, 200, 200, 500)
    assertRaises(MXNetError, transformer, data_in)
    assertRaises(MXNetError, transformer, data_batch_in)
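# A typical driver for the dtype-parameterized tests above; a sketch only,
# the dtype list is an assumption and the original suite may differ.
def test_crop_resize():
    for dtype in ['uint8', 'float32', 'float64']:
        _test_crop_resize_with_diff_type(dtype)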
def sym_slice(op, ichannel, step):
    """Split `op` into chunks of size `step` along axis `ichannel`."""
    shp = op.shape
    ndims = len(shp)
    nodes = []
    rchannel = ndims - ichannel - 1
    for i in range(0, shp[ichannel], step):
        opi = nd.slice(op,
                       begin=(None,) * ichannel + (i,) + (None,) * rchannel,
                       end=(None,) * ichannel + (i + step,) + (None,) * rchannel)
        nodes.append(opi)
    return nodes
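# Minimal usage sketch for sym_slice, assuming `from mxnet import nd`;
# the array below is illustrative only. Splitting a (2, 6, 4) array in
# steps of 2 along axis 1 yields three (2, 2, 4) chunks.
x = nd.arange(48).reshape((2, 6, 4))
chunks = sym_slice(x, ichannel=1, step=2)
assert len(chunks) == 3 and all(c.shape == (2, 2, 4) for c in chunks)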
def test_strided_slice():
    print("test strided slice")
    tmp_dir = DIR + "strided_slice/"

    # case 0: negative step on axis 0, step 2 on axis 1
    os.makedirs(tmp_dir + "0/", exist_ok=True)
    shape = np.random.randint(low=3, high=4, size=(4))
    print(shape)
    a = np.random.randint(low=-127, high=127, size=shape)
    print(a)
    np.save(tmp_dir + "0/in_0.npy", a.astype("int32"))
    params = {"begin": [2, 0], "end": [0, 3], "step": [-1, 2]}
    save_dict(params, tmp_dir + "0/attr.txt")
    b = nd.slice(nd.array(a), **params)
    np.save(tmp_dir + "0/out_0.npy", b.asnumpy().astype("int32"))
    print(b.shape)
    print(b)

    # case 1: plain slice on the first two axes
    os.makedirs(tmp_dir + "1/", exist_ok=True)
    shape = np.random.randint(low=3, high=4, size=(4))
    print(shape)
    a = np.random.randint(low=-127, high=127, size=shape)
    print(a)
    np.save(tmp_dir + "1/in_0.npy", a.astype("int32"))
    params = {"begin": [0, 0], "end": [2, 3]}
    save_dict(params, tmp_dir + "1/attr.txt")
    b = nd.slice(nd.array(a), **params)
    np.save(tmp_dir + "1/out_0.npy", b.asnumpy().astype("int32"))
    print(b.shape)

    # case 2: slice on all four axes
    os.makedirs(tmp_dir + "2/", exist_ok=True)
    shape = np.random.randint(low=3, high=4, size=(4))
    print(shape)
    a = np.random.randint(low=-127, high=127, size=shape)
    print(a)
    np.save(tmp_dir + "2/in_0.npy", a.astype("int32"))
    params = {"begin": [0, 0, 1, 1], "end": [1, 2, 3, 3]}
    save_dict(params, tmp_dir + "2/attr.txt")
    b = nd.slice(nd.array(a), **params)
    np.save(tmp_dir + "2/out_0.npy", b.asnumpy().astype("int32"))
    print(b.shape)
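# `DIR` and `save_dict` are defined elsewhere in the original test
# harness; hypothetical stand-ins consistent with how they are used above:
import json

DIR = "./testdata/"  # assumed output root

def save_dict(d, path):
    # persist the slice attributes as text alongside the .npy files
    with open(path, "w") as f:
        json.dump(d, f)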
def mxwindow(mna, window):
    """Gather every sliding `window` patch from the last two axes of `mna`."""
    mnas = mna.shape
    mnout = (*mnas[:-2], *window,
             (mnas[-2] - window[-2]) + 1,
             (mnas[-1] - window[-1]) + 1)
    mne2 = None
    for R in range(window[0]):
        j_lim = R + mnout[-2]
        for H in range(window[1]):
            tdata = mnd.slice(mna,
                              begin=(None, None, R, H),
                              end=(None, None, j_lim, H + mnout[-1]),
                              step=(None, None, 1, 1))
            if mne2 is None:
                mne2 = tdata
            else:
                mne2 = mnd.concat(mne2, tdata, dim=1)
    return mnd.expand_dims(
        mnd.transpose(mnd.reshape(mne2, shape=mnout), axes=(0, 5, 4, 3, 2, 1)), 3)
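# Hypothetical check for mxwindow, assuming `import mxnet.ndarray as mnd`:
# a 4x4 single-channel image with a 2x2 window exposes all nine
# overlapping patches; the window offsets end up in the trailing axes.
x = mnd.arange(16).reshape((1, 1, 4, 4))
patches = mxwindow(x, (2, 2))
print(patches.shape)  # (1, 3, 3, 1, 2, 2, 1)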
def fftfilt_nd(x, params):
    (b, m, nx, nb, L, nfft) = params
    B = nd.contrib.fft(data=nd.concatenate(
        [b.T, nd.zeros(shape=(1, (nfft - b.size)), ctx=ctx)], axis=1))
    if b.size == 1:
        B = B.T  # make sure fft of B is a column (might be a row if b is scalar)
    if b.shape[1] == 1:
        B = nd.repeat(data=B, repeats=x.shape[1], axis=0)  # replicate the column B
    B_re = nd.slice(data=B, begin=(0, 0), end=(0, None), step=(1, 2))
    B_im = nd.slice(data=B, begin=(0, 1), end=(0, None), step=(1, 2))
    if x.shape[1] == 1:
        x = nd.repeat(data=x, repeats=b.shape[1], axis=1)  # replicate the column x
    y = nd.zeros_like(x.T)
    istart = 1
    while istart <= nx:
        iend = min(istart + L - 1, nx)
        if (iend - istart) == 0:
            # need to fft a scalar; istart is 1-based, hence istart - 1
            X = x[istart - 1] * nd.ones((nfft, 1), ctx=ctx)
        else:
            temp = nd.slice(x, begin=istart - 1, end=iend).T
            X = nd.contrib.fft(data=nd.concatenate([
                temp,
                nd.zeros(shape=(temp.shape[0], (nfft - temp.shape[1])), ctx=ctx)
            ], axis=1))
        X_re = nd.slice(data=X, begin=(0, 0), end=(0, None), step=(1, 2))
        X_im = nd.slice(data=X, begin=(0, 1), end=(0, None), step=(1, 2))
        # complex multiplication X * B on the interleaved re/im parts
        XprodB_re = (X_re * B_re - X_im * B_im)
        XprodB_im = (X_re * B_im + X_im * B_re)
        Ytemp = nd.zeros((X.shape[0], X.shape[1]), ctx=ctx)
        Ytemp[:, ::2] = XprodB_re
        Ytemp[:, 1::2] = XprodB_im
        Y = mx.contrib.ndarray.ifft(Ytemp / nfft)  # only the real part
        yend = min(nx, istart + nfft - 1)
        # overlap-add the current block into the output
        y[:, istart - 1:yend] = nd.slice(
            data=y, begin=(0, istart - 1), end=(0, yend), step=(1, 1)) + nd.slice(
            data=Y, begin=(0, 0), end=(0, yend - istart + 1), step=(1, 1))
        istart += L
    # y = real(y)
    return y
def forward(self, input_vec, loss=None):
    assert input_vec.shape[1] == self.input_dimension
    # get inputs for every slot (including global)
    inputs = {}
    for slot in self.slots:
        inputs[slot] = input_vec[:, self.slot_dimension[slot][0]:
                                    self.slot_dimension[slot][1]]
    input_global = []
    for seg in self.global_dimension:
        input_global.append(input_vec[:, seg[0]:seg[1]])
    inputs['global'] = nd.concat(*input_global, dim=1)

    layer = []
    # inputs -> first_hidden_layer
    if (not self.sort_input_vec) and self.state_feature != 'dip':
        layer.append([])
        for slot in self.slots:
            layer[0].append(self.input_trans[slot](inputs[slot]))
        layer[0].append(self.input_trans['global'](inputs['global']))
    elif self.state_feature == 'dip':
        sorted_inputs = []
        for slot in self.slots:
            sorted_inputs.append(inputs[slot])
        sorted_inputs.append(inputs['global'])
        layer.append(self.input_trans(sorted_inputs, loss))
    elif self.sort_input_vec:
        sorted_inputs = []
        for slot in self.slots:
            tmp = inputs[slot][:, :-2].sort(is_ascend=False)
            if tmp.shape[1] < 20:
                tmp = nd.concat(tmp, nd.zeros((tmp.shape[0], 20 - tmp.shape[1]), ctx=CTX), dim=1)
            else:
                tmp = nd.slice_axis(tmp, axis=1, begin=0, end=20)
            sorted_inputs.append(nd.concat(tmp, inputs[slot][:, -2:], dim=1))
        sorted_inputs.append(inputs['global'])
        layer.append(self.input_trans(sorted_inputs, loss))

    # hidden_layers
    for i in range(self.hidden_layers - 1):
        if self.recurrent_mode is False:
            # equal to 'layer.append(self.ma_trans[i](layer[-1], loss))'
            layer.append(self.ma_trans[i](layer[i], loss))
        else:
            layer.append(self.ma_trans(layer[i], loss))

    if self.share_last_layer is False:
        # dropout of last hidden layer
        for j in range(len(self.slots)):
            layer[-1][j] = self.local_out_drop_op(layer[-1][j])
        layer[-1][-1] = self.global_out_drop_op(layer[-1][-1])

        # last_hidden_layer -> outputs
        outputs = []
        for i in range(len(self.slots) + 1):
            if self.use_dueling is False:
                outputs.append(self.output_trans[i](layer[-1][i]))
            else:
                if i < len(self.slots):
                    tmp_adv = self.output_trans_local_advantage(sorted_inputs[i])
                else:
                    tmp_adv = self.output_trans_global_advantage(sorted_inputs[-1])
                if self.dueling_share_last:
                    if i < len(self.slots):
                        cur_value = self.output_trans_local_value(layer[-1][i])
                        if self.shared_last_layer_use_bias:
                            cur_value = cur_value + nd.slice(
                                self.value_bias_local.data(),
                                begin=(i,), end=(i + 1,))
                    else:
                        cur_value = self.output_trans_global_value(layer[-1][i])
                else:
                    cur_value = self.output_trans_value[i](layer[-1][i])
                outputs.append(
                    cur_value + tmp_adv - tmp_adv.mean(axis=1).reshape(
                        (tmp_adv.shape[0], 1)).broadcast_axes(axis=1, size=tmp_adv.shape[1]))
    else:
        outputs = []
        for i in range(len(self.slots)):
            output_i = self.output_trans_local(layer[-1][i])
            if self.shared_last_layer_use_bias:
                output_i = output_i + self.output_trans_local_biases[i].data()
            outputs.append(output_i)
        outputs.append(self.output_trans_global(layer[-1][-1]))
    return nd.concat(*outputs, dim=1)
def forward(self, input_vec, loss=None, training=True):
    # print('************* ' + str(input_vec.shape[1]) + ' *************')
    # print('############# ' + str(input_vec.shape) + ' #############')
    assert input_vec.shape[1] == self.input_dimension
    # get inputs for every slot (including global)
    inputs = {}
    for slot in self.slots:
        inputs[slot] = input_vec[:, self.slot_dimension[slot][0]:
                                    self.slot_dimension[slot][1]]
    input_global = []
    for seg in self.global_dimension:
        input_global.append(input_vec[:, seg[0]:seg[1]])
    inputs['global'] = nd.concat(*input_global, dim=1)

    layer = []
    # inputs -> first_hidden_layer
    if (not self.sort_input_vec) and self.state_feature != 'dip':
        layer.append([])
        for slot in self.slots:
            layer[0].append(self.input_trans[slot](inputs[slot]))
        layer[0].append(self.input_trans['global'](inputs['global']))
    elif self.state_feature == 'dip':
        sorted_inputs = []
        for slot in self.slots:
            sorted_inputs.append(inputs[slot])
        sorted_inputs.append(inputs['global'])
        layer.append(self.input_trans.forward(sorted_inputs, loss, training=training))
    elif self.sort_input_vec:
        sorted_inputs = []
        for slot in self.slots:
            tmp = inputs[slot][:, :-2].sort(is_ascend=False)
            if tmp.shape[1] < 20:
                tmp = nd.concat(tmp, nd.zeros((tmp.shape[0], 20 - tmp.shape[1]), ctx=CTX), dim=1)
            else:
                tmp = nd.slice_axis(tmp, axis=1, begin=0, end=20)
            sorted_inputs.append(nd.concat(tmp, inputs[slot][:, -2:], dim=1))
        sorted_inputs.append(inputs['global'])
        layer.append(self.input_trans.forward(sorted_inputs, loss, training=training))

    # hidden_layers
    for i in range(self.hidden_layers - 1):
        if self.recurrent_mode is False:
            # equal to 'layer.append(self.ma_trans[i](layer[-1], loss))'
            layer.append(self.ma_trans[i](layer[i], loss))
        else:
            layer.append(self.ma_trans(layer[i], loss))

    if self.share_last_layer is False:
        # dropout of last hidden layer
        for j in range(len(self.slots)):
            layer[-1][j] = self.local_out_drop_op.forward(layer[-1][j])
        layer[-1][-1] = self.global_out_drop_op.forward(layer[-1][-1])

    # last_hidden_layer -> outputs
    outputs = []
    slotv_probs = []
    slotqs = []
    slot_probs = []
    top_decision = []
    for i in range(len(self.slots) + 1):
        if self.use_dueling is False:
            outputs.append(self.output_trans[i](layer[-1][i]))
        else:
            if i < len(self.slots):
                cur_slotv_prob = self.output_trans_local_valueP.forward(layer[-1][i], training=training)
                cur_slotv_prob = nd.softmax(cur_slotv_prob)
            else:
                cur_slotv_prob = self.output_trans_global_valueP.forward(layer[-1][i], training=training)
                cur_slotv_prob = nd.softmax(cur_slotv_prob)
            if self.dueling_share_last:
                if i < len(self.slots):
                    cur_slotq = self.output_trans_local_slotQ.forward(layer[-1][i], training=training)
                    cur_slot_prob = self.output_trans_local_slotP.forward(layer[-1][i], training=training).reshape(-1, 1)
                    cur_slotv_prob = cur_slotv_prob * cur_slot_prob
                    # cur_slot_prob = nd.softmax(cur_slot_prob)
                    if self.shared_last_layer_use_bias:
                        cur_slotq = cur_slotq + nd.slice(self.value_bias_local.data(),
                                                         begin=(i,), end=(i + 1,))
                else:
                    cur_slotq = self.output_trans_global_slotQ.forward(layer[-1][i], training=training)
                    cur_slot_prob = self.output_trans_global_slotP.forward(layer[-1][i], training=training).reshape(-1, 1)
                    cur_slotv_prob = cur_slotv_prob * cur_slot_prob
                    # cur_slot_prob = nd.softmax(cur_slot_prob)
                top_decision.append(cur_slot_prob)
            else:
                cur_slotq = self.output_trans_value[i](layer[-1][i])
            slotv_probs.append(cur_slotv_prob)
            slot_probs.append(cur_slot_prob)
            slotqs.append(cur_slotq)

    # batch_slotv_probs_list = []
    # slot_prob_softmax = nd.softmax(nd.concat(*slot_probs, dim=1))
    # slot_prob_split = nd.split(slot_prob_softmax, axis=1, num_outputs=len(self.slots) + 1)
    # assert len(slotv_probs) == len(self.slots) + 1
    # for i in range(len(slotv_probs)):
    #     tmp = slot_prob_split[i].reshape(-1, 1) * slotv_probs[i]
    #     batch_slotv_probs_list.append(tmp)
    batch_slot_prob = nd.softmax(nd.concat(*slot_probs, dim=1))
    batch_slot_slotq = nd.concat(*slotqs, dim=1)
    batch_slotv_prob = nd.softmax(nd.concat(*slotv_probs, dim=1))
    batch_top_decision = nd.softmax(nd.concat(*top_decision, dim=1))
    # print(batch_slotv_prob)
    # print(batch_slot_prob.shape)
    # print(batch_slot_slotq.shape)
    # print(batch_slotv_prob.shape)
    prob = batch_slotv_prob
    value = nd.max(batch_slot_slotq, axis=1)
    top_decision = batch_top_decision
    # CTname = threading.currentThread().getName()
    # print(CTname + ' top decision is : ')
    # print(top_decision)
    return prob, value, top_decision
def forward(self, input_vec, loss=None, training=True):
    assert input_vec.shape[1] == self.input_dimension
    # get inputs for every slot (including global)
    inputs = {}
    for slot in self.slots:
        inputs[slot] = input_vec[:, self.slot_dimension[slot][0]:
                                    self.slot_dimension[slot][1]]
    input_global = []
    for seg in self.global_dimension:
        input_global.append(input_vec[:, seg[0]:seg[1]])
    inputs['global'] = nd.concat(*input_global, dim=1)

    layer = []
    # inputs -> first_hidden_layer
    sorted_inputs = []
    for slot in self.slots:
        sorted_inputs.append(inputs[slot])
    sorted_inputs.append(inputs['global'])
    layer.append(self.input_trans.forward(sorted_inputs, loss, training=training))

    # hidden_layers
    for i in range(self.hidden_layers - 1):
        layer.append(self.ma_trans[i](layer[i], loss))

    if self.share_last_layer is False:
        # dropout of last hidden layer
        for j in range(len(self.slots)):
            layer[-1][j] = self.local_out_drop_op.forward(layer[-1][j])
        layer[-1][-1] = self.global_out_drop_op.forward(layer[-1][-1])

    # last_hidden_layer -> outputs
    slotv_probs = []
    slotqs = []
    slot_probs = []
    top_decision = []
    for i in range(len(self.slots) + 1):
        if i < len(self.slots):
            cur_slotv_prob = self.output_trans_local_valueP.forward(layer[-1][i], training=training)
        else:
            cur_slotv_prob = self.output_trans_global_valueP.forward(layer[-1][i], training=training)
        cur_slotv_prob_adv = cur_slotv_prob - nd.max(cur_slotv_prob, axis=1, keepdims=True)
        if i < len(self.slots):
            cur_slotq = self.output_trans_local_slotQ.forward(layer[-1][i], training=training)
            cur_slot_prob = self.output_trans_local_slotP.forward(layer[-1][i], training=training).reshape(-1, 1)
            if self.shared_last_layer_use_bias:
                cur_slotq = cur_slotq + nd.slice(self.value_bias_local.data(),
                                                 begin=(i,), end=(i + 1,))
        else:
            cur_slotq = self.output_trans_global_slotQ.forward(layer[-1][i], training=training)
            cur_slot_prob = self.output_trans_global_slotP.forward(layer[-1][i], training=training).reshape(-1, 1)
        cur_slotv_prob = cur_slot_prob + cur_slotv_prob_adv
        top_decision.append(cur_slot_prob)
        slotv_probs.append(cur_slotv_prob)
        slot_probs.append(cur_slot_prob)
        slotqs.append(cur_slotq)

    batch_slot_slotq = nd.concat(*slotqs, dim=1)
    batch_slotv_prob = nd.softmax(nd.concat(*slotv_probs, dim=1))
    batch_top_decision = nd.softmax(nd.concat(*top_decision, dim=1))
    prob = batch_slotv_prob
    value = nd.sum(batch_top_decision * batch_slot_slotq, axis=1)
    top_decision = batch_top_decision
    return prob, value, top_decision
def narrow_row(data, start, stop):
    """Return rows [start, stop) of `data` along the first axis."""
    return nd.slice(data, begin=start, end=stop)
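# Minimal usage sketch, assuming `from mxnet import nd`:
x = nd.arange(12).reshape((4, 3))
print(narrow_row(x, 1, 3))  # rows 1 and 2, shape (2, 3)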
def pre_fftfilt(b, shape, nfft=None):
    (numsamples, numsamplepoints) = shape
    inputNoise = nd.random.randn(numsamples, numsamplepoints, ctx=mx.cpu())
    b, x = b.reshape((-1, 1)), inputNoise.T
    m = x.shape[0]
    if m == 1:
        x = x.reshape((-1, 1))  # turn row into a column
    nx = x.shape[0]
    if min(b.shape) > 1:
        assert b.shape[1] == x.shape[1] and x.shape[1] <= 1, \
            "signal:fftfilt:InvalidDimensions"
    else:
        b = b.reshape((-1, 1))  # make input a column
    nb = b.shape[0]
    if nfft is None:
        # figure out which nfft and L to use
        if (nb >= nx) or (nb > 2**20):
            # take a single FFT in this case
            nfft = int(2**round(np.log(nb + nx - 1) / np.log(2)))
            L = nx
        else:
            # flop counts for FFT sizes 2^1 .. 2^21 (as in MATLAB's fftfilt)
            fftflops = nd.array([
                18, 59, 138, 303, 660, 1441, 3150, 6875, 14952, 32373,
                69762, 149647, 319644, 680105, 1441974, 3047619, 6422736,
                13500637, 28311786, 59244791, 59244791 * 2.09
            ])
            n = 2**nd.arange(1, 22, 1)
            validset_first = nd.argmax(n > nb - 1, axis=0).asscalar()
            n = nd.slice(n, begin=[int(validset_first)], end=(None,))
            fftflops = nd.slice(fftflops, begin=[int(validset_first)], end=(None,))
            # minimize (number of blocks) * (number of flops per fft)
            L = n - (nb - 1)
            temp = nd.ceil(nx / L) * fftflops
            dum, ind = nd.min(temp), nd.argmin(temp, axis=0)
            nfft = int(n[int(ind.asscalar())].asscalar())
            L = int(L[int(ind.asscalar())].asscalar())
    else:
        # nfft is given; would need a cast to enforce precision rules
        raise NotImplementedError('explicit nfft is not supported here')
        # MATLAB reference:
        # nfft = signal.internal.sigcasttofloat(nfft,'double','fftfilt','N','allownumeric');
        # if nfft < nb, nfft = nb; end
        # nfft = 2.^(ceil(log(nfft)/log(2)));  % force this to a power of 2 for speed
        # L = nfft - nb + 1;
    # Check the input data type. Single precision is not supported.
    # MATLAB reference:
    # try, chkinputdatatype(b,x,nfft); catch ME, throwAsCaller(ME); end
    return (b, m, nx, nb, L, nfft)
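# Hypothetical smoke test for pre_fftfilt, assuming `import mxnet as mx`,
# `from mxnet import nd`, and `import numpy as np` are in scope:
b = nd.array([0.25, 0.5, 0.25])  # small FIR smoothing kernel
b_col, m, nx, nb, L, nfft = pre_fftfilt(b, shape=(4, 1024))
print(nb, L, nfft)  # kernel length, block length, chosen power-of-2 FFT size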
def verify_broadcast_like_dynamic(xshp, wshp, lhs_axes, rhs_axes):
    x_np = np.random.uniform(size=xshp)
    w_np = np.random.uniform(size=wshp)
    x = nd.array(x_np)
    w = nd.array(w_np)

    # org op
    y = nd.broadcast_like(x, w, lhs_axes=lhs_axes, rhs_axes=rhs_axes)
    print(y.shape)

    # rewrite op
    xndims, wndims = len(xshp), len(wshp)
    if lhs_axes is None or rhs_axes is None:
        assert xndims == wndims and lhs_axes is None \
            and rhs_axes is None
        z = _broadcast_like(x, w)
    else:
        lhs_axes, lndims = list(lhs_axes), len(lhs_axes)
        rhs_axes, rndims = list(rhs_axes), len(rhs_axes)
        assert lndims == rndims > 0

        lhs_axes = tuple([v + xndims if v < 0 else v for v in lhs_axes])
        assert all([0 <= v < xndims for v in list(lhs_axes)])

        rhs_axes = tuple([v + wndims if v < 0 else v for v in rhs_axes])
        assert all([0 <= v < wndims for v in list(rhs_axes)])

        assert all([xshp[lhs_axes[i]] == 1 for i in range(lndims)])

        batch_axes = [0]
        flg = all([batch_axis not in rhs_axes \
            for batch_axis in batch_axes])
        if flg:
            cnts = {v: wshp[rhs_axes[i]] \
                for i, v in enumerate(lhs_axes)}
            reps = tuple([cnts[v] if v in lhs_axes else 1 \
                for v in range(xndims)])
            z = nd.tile(x, reps=reps)
        else:
            axis_map = {}
            for i, v in enumerate(lhs_axes):
                axis_map[v] = rhs_axes[i]
            for batch_axis in batch_axes:
                assert sum([1 if v == batch_axis else 0 \
                    for k, v in axis_map.items()]) <= 1, \
                    ("multiple broadcast on batch_axis: %s, "
                     "which is not supported by dynamic shape fusion.") % batch_axis
            assert wndims < 6, \
                "slice can manipulate at most 5d"

            # reduce shape to 1 for non-broadcast dimensions
            begin = tuple([0] * wndims)
            end = tuple([wshp[v] if v in axis_map.values() else 1 \
                for v in range(wndims)])
            w = nd.slice(w, begin=begin, end=end)

            # decompose k1->v, k2->v into k1->v, k2->v2,
            # which makes the axis map injective
            while True:
                vs, flag, paxis_map = set(), True, axis_map
                for pk, pv in paxis_map.items():
                    if pv not in vs:
                        vs.add(pv)
                        continue
                    flag = False
                    axis_map = {k: (v + 1 if v > pv or k == pk else v) \
                        for k, v in axis_map.items()}
                    w = nd.expand_dims(w, axis=pv)
                    w = nd.repeat(w, axis=pv, repeats=wshp[pv])
                    wshp = wshp[:pv] + (wshp[pv],) + wshp[pv:]
                    break
                if flag:
                    break
            wndims = len(wshp)

            # trim wndims if not equal to xndims
            v = 0
            while wndims > xndims:
                while v in axis_map.values():
                    v += 1
                w = nd.squeeze(w, axis=v)
                wndims -= 1
                axis_map = {k: (nv - 1 if nv > v else nv) \
                    for k, nv in axis_map.items()}
            while wndims < xndims:
                w = nd.expand_dims(w, axis=wndims)
                wndims += 1

            axes = list(range(wndims))
            while True:
                dels = [k for k, v in axis_map.items() if k == v]
                for k in dels:
                    del axis_map[k]
                if not axis_map:
                    break
                keys = list(axis_map.keys())
                k, v = keys[0], axis_map[keys[0]]
                axes[k], axes[v] = axes[v], axes[k]
                for nk in keys:
                    nv = axis_map[nk]
                    if nv == k:
                        axis_map[nk] = v
                    elif nv == v:
                        axis_map[nk] = k
            axes = tuple(axes)
            if axes != tuple(range(wndims)):
                assert wndims < 7, \
                    "transpose can manipulate at most 6d"
                w = nd.transpose(w, axes=axes)
            z = _broadcast_like(x, w)
    print(z.shape)

    # compare
    assert z.shape == y.shape
    zn, zp = get_norm(z)
    yn, yp = get_norm(y)
    rn = np.linalg.norm(zp - yp)
    print(zn, yn, rn)
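# `_broadcast_like` and `get_norm` are helpers defined elsewhere in the
# original test module; hypothetical stand-ins matching their use above:
def _broadcast_like(lhs, rhs):
    # axis-free broadcast of lhs to rhs's shape (assumption)
    return nd.broadcast_like(lhs, rhs)

def get_norm(arr):
    # assumed to return the L2 norm together with the raw numpy values
    np_arr = arr.asnumpy()
    return np.linalg.norm(np_arr), np_arr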