def orth_test():
    # Check how close each conv layer's (spatially collapsed) filters are
    # to an orthogonal set: normalize the columns and compare W^T W with I.
    #p30 = NN("../plain30/data/plain30.data")
    #p30 = NN("../../resnet20/data/resnet20_acc91.7")
    #p30 = NN("../plain30_orth/data/plain30_orth.data")
    p30 = NN("../../densenetl100k24/data/densenetl100k24.data")
    #p30 = NN("../../lrvswc/wc/data/lr.data")
    net = p30.net
    loss = net.loss_var
    visitor = NetworkVisitor(loss)
    """
    for i in visitor.all_oprs:
        print(i)
    """
    for i in visitor.all_oprs:
        if ":W" in i.name:
            W = i.eval()
            print(W.shape)
            if W.ndim != 4:
                continue
            #W = W.reshape(W.shape[0], -1)
            # Collapse the spatial dimensions so each output filter becomes
            # a vector over input channels.
            W = W.sum(axis=3).sum(axis=2)
            print(W.shape)
            #plt.plot(range(len(W[:, 0])), abs(W[:, 1]))
            W = np.array(W)
            W = W.T
            W = W / ((W**2).sum(axis=0)**0.5)
            A = np.dot(W.T, W)
            print(list(np.round(A * 1000).astype(np.int32) / 1000))
            I = np.identity(A.shape[0])
            print(np.round(((A - I)**2).mean(), decimals=5))
            input()


def orth_test():
    # Variant of the check above: measure orthogonality of the fc0 weight
    # columns only, then dump the normalized weights.
    #p30 = NN("../plain30/data/plain30.data")
    #p30 = NN("../../resnet20/data/resnet20_acc91.7")
    #p30 = NN("../plain30_orth/data/plain30_orth.data")
    #p30 = NN("../../densenetl100k24/data/densenetl100k24.data")
    p30 = NN("../../lrvswc/wc/data/lr.data")
    net = p30.net
    loss = net.loss_var
    visitor = NetworkVisitor(loss)
    """
    for i in visitor.all_oprs:
        print(i)
    """
    W = visitor.all_oprs_dict["fc0:W"]
    W = W.eval()
    print(W.shape)
    #plt.plot(range(len(W[:, 0])), abs(W[:, 1]))
    W = np.array(W)
    # Normalize each column and compare the Gram matrix against the identity.
    W = W / ((W**2).sum(axis=0)**0.5)
    A = np.dot(W.T, W)
    print(A)
    I = np.identity(10)
    print(((A - I)**2).mean())
    with open("p30wcW.data", "wb") as f:
        pickle.dump(W, f)


def orth_test1():
    # Print every operator in the graph along with its partial shape.
    p30 = NN("../plain30_xcep/data/plain30_xcep.data")
    net = p30.net
    visitor = NetworkVisitor(net.loss_var)
    for i in visitor.all_oprs:
        print(i)
        print(i.partial_shape)


# slim() and test() below are disabled via a module-level string.
"""
def slim():
    # Collect the values of every ":k" (slimming scale) parameter.
    r20 = NN("./data/slm_res20.data")
    visitor = NetworkVisitor(r20.net.loss_var)
    lis_k = []
    for i in visitor.all_oprs:
        if ":k" in i.name:
            lis_k.append(i.eval())
    print(lis_k)


def test():
    # Compare gradient and activation statistics at the data/conv operators
    # for one CIFAR batch, and save the per-layer gradient stds.
    data, labels = load_CIFAR_data()
    p120 = NN("data/p120.data")
    net = p120.net
    loss = net.loss_var
    visitor = NetworkVisitor(loss)
    inp = []
    for i in visitor.all_oprs:
        if "data" in i.name:
            inp.append(i)
        if "conv" in i.name and ":" not in i.name:
            inp.append(i)
    print(inp)
    grad = []
    out = []
    for i in inp:
        grad.append(O.Grad(loss, i))
        out.append(i)
    F = Function()
    F._env.flags.train_batch_normalization = True
    func = F.compile(grad)
    F = Function()
    F._env.flags.train_batch_normalization = True
    func1 = F.compile(out)
    batch = data[:128]
    batch = batch.reshape(128, 3, 32, 32)
    mean, std = p120.mean, p120.std
    batch = (batch - mean) / std
    label = labels[:128]
    grad_out = func(data=batch, label=label)
    lay_out = func1(data=batch, label=label)
    idx = 0
    grad_list = []
    for i, j in zip(grad_out, lay_out):
        print(i.shape, idx)
        idx += 1
        f = i.flatten()
        print("grad")
        print(f)
        print(np.mean(f), np.std(f))
        grad_list.append(np.std(f))
        print("val")
        h = j.flatten()
        print(h)
        print(np.mean(h), np.std(h))
    pickle.dump(grad_list, open("p120_norelu_grad.data", "wb"))
"""


def trans_test():
    # Orthogonality check on the fc0 weights of a fixed-fc ResNet-110.
    #p30 = NN("/home/liuyanyi02/CIFAR/slimming/resnet20/data/slm_res20.data")
    p30 = NN("data/fixedfc_res110_rand.data")
    net = p30.net
    loss = net.loss_var
    visitor = NetworkVisitor(loss)
    W0 = None
    for i in visitor.all_oprs:
        if "fc0:W" in i.name:
            W1 = i.eval()
            #W1 = W1.sum(axis=3).sum(axis=2)
            #W1 = W1 / ((W1**2).sum(axis=0))
            #W1 = W1.T
            # Normalize columns and compare W1^T W1 with the identity.
            W1 = W1 / ((W1**2).sum(axis=0)**0.5)
            A = np.dot(W1.T, W1)
            print(np.round(A * 1000).astype(np.int32) / 1000)
            I = np.identity(A.shape[0])
            print(i)
            print(((A - I)**2).mean())
            input()


def myw_test():
    # Inspect the fc1-related activations for one batch, viewing the first
    # sample as a stack of 4x4 maps.
    d40_MY = NN("./data/r20_MY.data")
    net = d40_MY.net
    outputs = []
    visitor = NetworkVisitor(net.loss_var)
    for i in visitor.all_oprs:
        if "fc1" in i.name and ":W" not in i.name and ":b" not in i.name:
            outputs.append(i)
    func = Function().compile(outputs)
    data, labels = load_CIFAR_data()
    batch = data[:128]
    batch = batch.reshape(128, 3, 32, 32)
    mean, std = d40_MY.mean, d40_MY.std
    batch = (batch - mean) / std
    outputs_weights = func(data=batch)
    for i in outputs_weights:
        print(i.shape)
        w = i[0]
        w = w.reshape(-1, 4, 4)
        print(w)
        input()


def trans_test():
    # Multiply the (spatially collapsed) weight matrices of consecutive
    # same-shaped layers and compare the product with the identity.
    p30 = NN("/home/liuyanyi02/CIFAR/slimming/resnet20/data/slm_res20.data")
    net = p30.net
    loss = net.loss_var
    visitor = NetworkVisitor(loss)
    W0 = None
    for i in visitor.all_oprs:
        if ":W" in i.name:
            W1 = i.eval()
            W1 = W1.sum(axis=3).sum(axis=2)
            #W1 = W1 / ((W1**2).sum(axis=0))
            W1 = W1.T
            if W0 is None:
                W0 = W1
                continue
            if W0.shape != W1.shape:
                W0 = W1
                continue
            A = np.dot(W0, W1)
            print(np.round(A * 1000).astype(np.int32) / 1000)
            I = np.identity(A.shape[0])
            print(i)
            print(((A - I)**2).mean())
            input()


def init(net, batch):
    # Data-dependent initialization: set each bnaff scale (:k) and bias (:b)
    # from batch statistics so the affine output is approximately zero-mean,
    # unit-variance on the given batch.
    visitor = NetworkVisitor(net.loss_var)
    lisk = []
    lisb = []
    for i in visitor.all_oprs:
        if ":k" in i.name and "bnaff" in i.name:
            lisk.append(i)
        if ":b" in i.name and "bnaff" in i.name:
            lisb.append(i)
    for i, k, b in zip(range(len(lisk)), lisk, lisb):
        # Recompile after each replacement so earlier rescalings take effect.
        func = Function().compile(net.outputs)
        outputs = func(data=batch['data'])
        t = outputs[1 + i]
        # Per-channel mean and std over batch and spatial dimensions.
        mean = t.mean(axis=3).mean(axis=2).mean(axis=0)
        std = ((t - mean[np.newaxis, :, np.newaxis, np.newaxis])**2).mean(
            axis=3).mean(axis=2).mean(axis=0)**0.5
        nk = O.ParamProvider("new" + k.name, 1.0 / std)
        nb = O.ParamProvider("new" + b.name, -mean / std)
        visitor.replace_vars([(k, nk), (b, nb)], copy=False)
        visitor = NetworkVisitor(net.loss_var)
    for i in visitor.all_oprs:
        print(i)
    return net
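
# Usage sketch (an assumption, not part of the original script): these helpers
# look like ad-hoc inspection utilities, so a hypothetical entry point might
# simply dispatch to one of them. The chosen function is illustrative only.
if __name__ == "__main__":
    orth_test()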