def train_step(xy_pair, lr):
    """One PSGD training step on ``lenet5_vars``.

    Estimates a Hessian-vector product with a random probe vector, refreshes
    the Kronecker-factored preconditioners ``Qs`` in place, then applies a
    norm-clipped preconditioned gradient update.

    Args:
        xy_pair: mini-batch passed straight through to ``train_loss``.
        lr: base learning rate; further scaled down by the trust-region clip.

    Returns:
        The scalar training loss for this batch.
    """
    with tf.GradientTape() as g2nd:
        with tf.GradientTape() as g1st:
            loss = train_loss(xy_pair)
        # computed inside g2nd's context so the gradient graph is traced
        grads = g1st.gradient(loss, lenet5_vars)
    vs = [tf.random.normal(W.shape) for W in lenet5_vars]  # a random probe vector
    hess_vs = g2nd.gradient(grads, lenet5_vars, vs)  # Hessian-vector products
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
        for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)
    ]
    # Plain loops for the in-place variable assignments; a list comprehension
    # used only for side effects is an anti-pattern.
    for Qlr, new_Qlr in zip(Qs, new_Qs):
        Qlr[0].assign(new_Qlr[0])
        Qlr[1].assign(new_Qlr[1])
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)
    ]
    # Trust-region style clipping of the preconditioned gradient norm.
    grad_norm = tf.sqrt(sum(tf.reduce_sum(g * g) for g in pre_grads))
    lr_adjust = tf.minimum(grad_norm_clip_thr / grad_norm, 1.0)
    for W, g in zip(lenet5_vars, pre_grads):
        W.assign_sub(lr_adjust * lr * g)
    return loss
def opt_step():
    """One PSGD optimization step driving ``f()`` down over variables ``xyz``.

    Updates the Kronecker preconditioners ``Qs`` from a Hessian-vector product
    and takes a fixed-size (0.1) preconditioned gradient step.

    Returns:
        The scalar cost before the update.
    """
    with tf.GradientTape() as g2nd:  # second order derivative
        with tf.GradientTape() as g1st:  # first order derivative
            cost = f()
        grads = g1st.gradient(cost, xyz)  # gradient, traced by g2nd
    vs = [tf.random.normal(w.shape) for w in xyz]  # a random probe vector
    hess_vs = g2nd.gradient(grads, xyz, vs)  # Hessian-vector products
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv, step=0.1)
        for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)
    ]
    # Explicit loops for the in-place assignments rather than side-effect
    # list comprehensions.
    for Qlr, new_Qlr in zip(Qs, new_Qs):
        Qlr[0].assign(new_Qlr[0])
        Qlr[1].assign(new_Qlr[1])
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)
    ]
    for w, g in zip(xyz, pre_grads):
        w.assign_sub(0.1 * g)
    return cost
def train_step_apprx_Hvp(inp, targ, lr):
    """Demonstrate the usage of an approximated Hessian-vector product (Hvp).

    Typically a smaller graph and faster execution than the exact Hvp, but the
    accuracy of the finite-difference approximation might be in question.

    Args:
        inp: encoder input batch.
        targ: target sequence batch; column 0 seeds the decoder and subsequent
            columns provide teacher forcing.
        lr: learning rate for the preconditioned gradient step.

    Returns:
        The averaged training loss for this batch.
    """
    # sqrt(float32.eps): the scale of the parameter perturbation used for the
    # finite-difference Hessian-vector product approximation.
    delta = tf.constant(2**(-23 / 2))

    def eval_loss_grads():
        # Teacher-forced seq2seq loss; returns (loss, gradients).
        with tf.GradientTape() as gt:
            enc_output = encoder(inp)
            dec_input = targ[:, 0]
            # the last hidden state of encoder as the initial one for decoder
            dec_h = enc_output[:, -1, :]
            loss = 0.0
            for t in range(1, targ.shape[1]):
                pred, dec_h = decoder(dec_input, dec_h, enc_output)
                loss += loss_function(targ[:, t], pred)
                dec_input = targ[:, t]
            loss /= targ.shape[1]
        return loss, gt.gradient(loss, trainable_vars)  # loss and gradients

    loss, grads = eval_loss_grads()
    # calculate the perturbed gradients
    vs = [delta * tf.random.normal(W.shape) for W in trainable_vars]
    for W, v in zip(trainable_vars, vs):  # apply the perturbations
        W.assign_add(v)
    _, perturbed_grads = eval_loss_grads()
    # update the preconditioners from (perturbation, gradient change) pairs
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, tf.subtract(perturbed_g, g))
        for (Qlr, v, perturbed_g, g) in zip(Qs, vs, perturbed_grads, grads)
    ]
    for Qlr, new_Qlr in zip(Qs, new_Qs):
        Qlr[0].assign(new_Qlr[0])
        Qlr[1].assign(new_Qlr[1])
    # calculate the preconditioned gradients
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)
    ]
    # update variables; do not forget to remove the perturbations on variables
    for W, g, v in zip(trainable_vars, pre_grads, vs):
        W.assign_sub(lr * g + v)
    return loss
def train_step_exact_Hvp(inp, targ, lr):
    """Demonstrate the usage of the exact Hessian-vector product (Hvp).

    Typically a larger graph and slower execution than the approximation, but
    the Hvp is exact and the code is cleaner.

    Args:
        inp: encoder input batch.
        targ: target sequence batch; column 0 seeds the decoder and subsequent
            columns provide teacher forcing.
        lr: learning rate for the preconditioned gradient step.

    Returns:
        The averaged training loss for this batch.
    """
    with tf.GradientTape() as g2nd:
        with tf.GradientTape() as g1st:
            enc_output = encoder(inp)
            dec_input = targ[:, 0]
            # the last hidden state of encoder as the initial one for decoder
            dec_h = enc_output[:, -1, :]
            loss = 0.0
            for t in range(1, targ.shape[1]):
                pred, dec_h = decoder(dec_input, dec_h, enc_output)
                loss += loss_function(targ[:, t], pred)
                dec_input = targ[:, t]
            loss /= targ.shape[1]
        grads = g1st.gradient(loss, trainable_vars)
        # Densify IndexedSlices (e.g. from embedding lookups) so g2nd can
        # differentiate through them.
        grads = [
            tf.convert_to_tensor(g) if isinstance(g, tf.IndexedSlices) else g
            for g in grads
        ]
    vs = [tf.random.normal(W.shape) for W in trainable_vars]  # a random vector
    hess_vs = g2nd.gradient(grads, trainable_vars, vs)  # Hessian-vector products
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
        for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)
    ]
    # Plain loops for the in-place assignments rather than side-effect
    # list comprehensions.
    for Qlr, new_Qlr in zip(Qs, new_Qs):
        Qlr[0].assign(new_Qlr[0])
        Qlr[1].assign(new_Qlr[1])
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)
    ]
    for W, g in zip(trainable_vars, pre_grads):
        W.assign_sub(lr * g)
    return loss
# NOTE(review): this is a fragment of a larger training loop — `labels`,
# `new_labels`, `new_images`, `Ws`, `Qs`, `num_iter`, `epoch`, `t0`,
# `BestTestLoss` etc. come from the enclosing (unseen) scope, and the final
# `if` has its body outside this view.
for label in labels:
    # presumably collapses repeated labels for CTC-style targets — TODO confirm
    new_labels.append(U.remove_repetitive_labels(label))
loss = train_loss(new_images, new_labels)
# create_graph=True keeps the gradient graph so grad-of-grad (Hvp) works below
grads = grad(loss, Ws, create_graph=True)
TrainLoss.append(loss.item())
v = [torch.randn(W.shape, device=device) for W in Ws]  # random probe vector
Hv = grad(grads, Ws, v)  # Hessian-vector products via double backward
with torch.no_grad():
    # Refresh the Kronecker-factored preconditioners, then precondition grads.
    Qs = [
        psgd.update_precond_kron(q[0], q[1], dw, dg)
        for (q, dw, dg) in zip(Qs, v, Hv)
    ]
    pre_grads = [
        psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)
    ]
    grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
    # 1.2e-38 (~float32 min normal) guards against division by zero
    lr_adjust = min(grad_norm_clip_thr / (grad_norm + 1.2e-38), 1.0)
    for i in range(len(Ws)):
        Ws[i] -= lr_adjust * lr * pre_grads[i]
if num_iter % 100 == 0:
    print(
        'epoch: {}; iter: {}; train loss: {}; elapsed time: {}'.format(
            epoch, num_iter, TrainLoss[-1], time.time() - t0))
    TestLoss.append(test_loss())
    # body of this conditional lies outside the visible chunk
    if TestLoss[-1] < BestTestLoss:
# Initialize (left, right) Kronecker preconditioner factors, one pair per
# weight matrix; the 0.1 scale sets the initial preconditioner magnitude.
Qs = [[0.1*torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)]
      for W in Ws]
step_size = 0.1  # normalized step size
grad_norm_clip_thr = 100.0  # the size of trust region
# begin iteration here
TrainLoss = []
TestLoss = []
echo_every = 100  # iterations
for num_iter in range(5000):
    x, y = get_batches()
    # calculate the loss and gradient
    loss = train_criterion(Ws, x, y)
    # create_graph=True enables the second backward pass for the Hvp below
    grads = grad(loss, Ws, create_graph=True)
    TrainLoss.append(loss.item())
    v = [torch.randn(W.shape).to(device) for W in Ws]  # random probe vector
    # grad_v = sum([torch.sum(g*d) for (g, d) in zip(grads, v)])
    # hess_v = grad(grad_v, Ws)
    Hv = grad(grads, Ws, v)  # replace the above two lines, due to Bulatov
    # Hv = grads # just let Hv=grads if you use Fisher type preconditioner
    with torch.no_grad():
        Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg)
              for (q, dw, dg) in zip(Qs, v, Hv)]
        pre_grads = [psgd.precond_grad_kron(q[0], q[1], g)
                     for (q, g) in zip(Qs, grads)]
        grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
        # clip the step when the preconditioned gradient leaves the trust region;
        # 1.2e-38 (~float32 min normal) avoids division by zero
        step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
        for i in range(len(Ws)):
            Ws[i] -= step_adjust*step_size*pre_grads[i]
    if (num_iter+1) % echo_every == 0:
        loss = test_criterion(Ws)
        TestLoss.append(loss.item())
        print('train loss: {}; test loss: {}'.format(TrainLoss[-1], TestLoss[-1]))
def main():
    """Train LeNet5 on MNIST with Kronecker-factored PSGD, logging to
    TensorBoard via the project helper module ``u``.

    NOTE(review): relies on module-level ``args``, ``u``, ``psgd``,
    ``num_updates``, ``LeNet5`` — none visible in this chunk.
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("using device ", device)
    torch.manual_seed(args.seed)
    u.set_runs_directory('runs3')
    logger = u.TensorboardLogger(args.run)
    batch_size = 64
    shuffle = True
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '/tmp/data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=batch_size, shuffle=shuffle, **kwargs)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '/tmp/data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=1000, shuffle=shuffle, **kwargs)
    """input image size for the original LeNet5 is 32x32, here is 28x28"""
    # W1 = 0.1 * torch.randn(1 * 5 * 5 + 1, 6)
    net = LeNet5().to(device)

    def train_loss(data, target):
        # NLL loss plus a small L2 penalty (2e-4) on every weight matrix.
        y = net(data)
        y = F.log_softmax(y, dim=1)
        loss = F.nll_loss(y, target)
        for w in net.W:
            loss += 0.0002 * torch.sum(w * w)
        return loss

    def test_loss():
        # Classification error rate on the whole test set (not an NLL loss).
        num_errs = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                y = net(data)
                _, pred = torch.max(y, dim=1)
                num_errs += torch.sum(pred != target)
        return num_errs.item() / len(test_loader.dataset)

    # One (left, right) identity preconditioner pair per weight matrix.
    Qs = [[torch.eye(w.shape[0]), torch.eye(w.shape[1])] for w in net.W]
    for i in range(len(Qs)):
        for j in range(len(Qs[i])):
            Qs[i][j] = Qs[i][j].to(device)
    step_size = 0.1  # tried 0.15, diverges
    grad_norm_clip_thr = 1e10  # effectively no clipping
    TrainLoss, TestLoss = [], []
    example_count = 0
    step_time_ms = 0
    for epoch in range(10):
        for batch_idx, (data, target) in enumerate(train_loader):
            step_start = time.perf_counter()
            data, target = data.to(device), target.to(device)
            loss = train_loss(data, target)
            with u.timeit('grad'):
                # create_graph=True keeps the graph for the Hvp below
                grads = autograd.grad(loss, net.W, create_graph=True)
            TrainLoss.append(loss.item())
            logger.set_step(example_count)
            logger('loss/train', TrainLoss[-1])
            if batch_idx % 10 == 0:
                print(
                    f'Epoch: {epoch}; batch: {batch_idx}; train loss: {TrainLoss[-1]:.2f}, step time: {step_time_ms:.0f}'
                )
            with u.timeit('Hv'):
                # noise.normal_()
                # torch.manual_seed(args.seed)
                v = [torch.randn(w.shape).to(device) for w in net.W]
                # v = grads
                Hv = autograd.grad(grads, net.W, v)  # Hessian-vector products
            if args.verbose:
                print("v", v[0].mean())
                print("data", data.mean())
                print("Hv", Hv[0].mean())
            n = len(net.W)
            with torch.no_grad():
                with u.timeit('P_update'):
                    # num_updates preconditioner refinement sweeps per batch
                    for i in range(num_updates):
                        psteps = []
                        for j in range(n):
                            q = Qs[j]
                            dw = v[j]
                            dg = Hv[j]
                            Qs[j][0], Qs[j][
                                1], pstep = psgd.update_precond_kron_with_step(
                                    q[0], q[1], dw, dg)
                            psteps.append(pstep)
                    # print(np.array(psteps).mean())
                    logger('p_residual', np.array(psteps).mean())
                with u.timeit('g_update'):
                    pre_grads = [
                        psgd.precond_grad_kron(q[0], q[1], g)
                        for (q, g) in zip(Qs, grads)
                    ]
                    grad_norm = torch.sqrt(
                        sum([torch.sum(g * g) for g in pre_grads]))
                with u.timeit('gradstep'):
                    # trust-region clip; 1.2e-38 guards against /0
                    step_adjust = min(
                        grad_norm_clip_thr / (grad_norm + 1.2e-38), 1.0)
                    for i in range(len(net.W)):
                        net.W[i] -= step_adjust * step_size * pre_grads[i]
                total_step = step_adjust * step_size
                logger('step/adjust', step_adjust)
                logger('step/size', step_size)
                logger('step/total', total_step)
                logger('grad_norm', grad_norm)
            if args.verbose:
                print(data.mean())
                import pdb
                pdb.set_trace()
            if args.early_stop:
                sys.exit()
            example_count += batch_size
            step_time_ms = 1000 * (time.perf_counter() - step_start)
            logger('time/step', step_time_ms)
            if args.test and batch_idx >= 100:
                break
        if args.test and batch_idx >= 100:
            break
        test_loss0 = test_loss()
        TestLoss.append(test_loss0)
        logger('loss/test', test_loss0)
        # decay the step size by 10x spread over the 10 epochs
        step_size = (0.1**0.1) * step_size
        print('Epoch: {}; best test loss: {}'.format(epoch, min(TestLoss)))
    if args.test:
        step_times = logger.d['time/step']
        assert step_times[-1] < 30, step_times  # should be around 20ms
        losses = logger.d['loss/train']
        assert losses[0] > 2  # around 2.3887393474578857
        assert losses[-1] < 0.5, losses
        print("Test passed")
def main():
    """DNN speech prior training.

    Trains an RNN source-prior network for independent vector analysis (IVA)
    on artificial mixtures, optimizing with Kronecker-factored PSGD.
    NOTE(review): depends on module-level ``config``, ``stft``, ``losses``,
    ``iva``, ``initW``, ``psgd`` — not visible in this chunk.
    """
    device = config.device
    wavloader = artificial_mixture_generator.WavLoader(config.wav_dir)
    mixergenerator = artificial_mixture_generator.MixerGenerator(
        wavloader, config.batch_size, config.num_mic, config.Lh,
        config.iva_hop_size * config.num_frame, config.prb_mix_change)
    # Source-prior network weights (trainable) and recurrent hidden states.
    Ws = [
        W.to(device)
        for W in initW(config.iva_fft_size // 2 - 1,
                       config.src_prior['num_layer'], config.
                       src_prior['num_state'], config.src_prior['dim_h'])
    ]
    hs = [
        torch.zeros(config.batch_size * config.num_mic,
                    config.src_prior['num_state']).to(device)
        for _ in range(config.src_prior['num_layer'] - 1)
    ]
    for W in Ws:
        W.requires_grad = True
    # W_iva initialization
    W_iva = (100.0 * torch.randn(config.batch_size,
                                 config.iva_fft_size // 2 - 1, config.num_mic,
                                 config.num_mic, 2)).to(device)
    # STFT window for IVA
    win_iva = stft.pre_def_win(config.iva_fft_size, config.iva_hop_size)
    # preconditioners for the source prior neural network optimization
    Qs = [[
        torch.eye(W.shape[0], device=device),
        torch.eye(W.shape[1], device=device)
    ] for W in Ws]  # preconditioners for SGD
    # buffers for STFT
    s_bfr = torch.zeros(config.batch_size, config.num_mic,
                        config.iva_fft_size - config.iva_hop_size)
    x_bfr = torch.zeros(config.batch_size, config.num_mic,
                        config.iva_fft_size - config.iva_hop_size)
    # buffer for overlap-add synthesis and reconstruction loss calculation
    ola_bfr = torch.zeros(config.batch_size, config.num_mic,
                          config.iva_fft_size).to(device)
    xtr_bfr = torch.zeros(config.batch_size, config.num_mic,
                          config.iva_fft_size - config.iva_hop_size
                          )  # extra buffer for reconstruction loss calculation
    Loss, lr = [], config.psgd_setting['lr']
    for bi in range(config.psgd_setting['num_iter']):
        srcs, xs = mixergenerator.get_mixtures()
        Ss, s_bfr = stft.stft(srcs[:, :, config.Lh:-config.Lh], win_iva,
                              config.iva_hop_size, s_bfr)
        Xs, x_bfr = stft.stft(xs, win_iva, config.iva_hop_size, x_bfr)
        Ss, Xs = Ss.to(device), Xs.to(device)
        # detach so gradients do not flow back through previous iterations
        Ys, W_iva, hs = iva(Xs, W_iva.detach(), [h.detach() for h in hs], Ws,
                            config.iva_lr)
        # loss calculation
        coherence = losses.di_pi_coh(Ss, Ys)
        loss = 1.0 - coherence
        if config.use_spectra_dist_loss:
            # Itakura-Saito-style distance on power spectra (re^2 + im^2) —
            # TODO confirm the last axis is (real, imag)
            spectra_dist = losses.di_pi_is_dist(
                Ss[:, :, :, 0] * Ss[:, :, :, 0] +
                Ss[:, :, :, 1] * Ss[:, :, :, 1],
                Ys[:, :, :, 0] * Ys[:, :, :, 0] +
                Ys[:, :, :, 1] * Ys[:, :, :, 1])
            loss = loss + spectra_dist
        if config.reconstruction_loss_fft_sizes:
            # Multi-resolution reconstruction losses on the re-synthesized ys.
            srcs = torch.cat([xtr_bfr, srcs], dim=2)
            ys, ola_bfr = stft.istft(Ys, win_iva.to(device),
                                     config.iva_hop_size, ola_bfr.detach())
            for fft_size in config.reconstruction_loss_fft_sizes:
                win = stft.coswin(fft_size)
                Ss, _ = stft.stft(srcs, win, fft_size // 2)
                Ys, _ = stft.stft(ys, win.to(device), fft_size // 2)
                Ss = Ss.to(device)
                coherence = losses.di_pi_coh(Ss, Ys)
                loss = loss + 1.0 - coherence
                if config.use_spectra_dist_loss:
                    spectra_dist = losses.di_pi_is_dist(
                        Ss[:, :, :, 0] * Ss[:, :, :, 0] +
                        Ss[:, :, :, 1] * Ss[:, :, :, 1],
                        Ys[:, :, :, 0] * Ys[:, :, :, 0] +
                        Ys[:, :, :, 1] * Ys[:, :, :, 1])
                    loss = loss + spectra_dist
            xtr_bfr = srcs[:, :, -(config.iva_fft_size -
                                   config.iva_hop_size):]
        Loss.append(loss.item())
        if config.use_spectra_dist_loss:
            print('Loss: {}; coherence: {}; spectral_distance: {}'.format(
                loss.item(), coherence.item(), spectra_dist.item()))
        else:
            print('Loss: {}; coherence: {}'.format(loss.item(),
                                                   coherence.item()))
        # Preconditioned SGD optimizer for source prior network optimization
        Q_update_gap = max(math.floor(math.log10(bi + 1)), 1)
        if bi % Q_update_gap == 0:  # update preconditioner less frequently
            grads = grad(loss, Ws, create_graph=True)
            v = [torch.randn(W.shape, device=device) for W in Ws]
            Hv = grad(grads, Ws, v)  # Hessian-vector products
            with torch.no_grad():
                Qs = [
                    psgd.update_precond_kron(q[0], q[1], dw, dg)
                    for (q, dw, dg) in zip(Qs, v, Hv)
                ]
        else:
            grads = grad(loss, Ws)
        with torch.no_grad():
            pre_grads = [
                psgd.precond_grad_kron(q[0], q[1], g)
                for (q, g) in zip(Qs, grads)
            ]
            for i in range(len(Ws)):
                Ws[i] -= lr * pre_grads[i]
        # decay the learning rate 10x for the final 10% of iterations
        if bi == int(0.9 * config.psgd_setting['num_iter']):
            lr *= 0.1
        # periodic checkpoint of the prior network weights to MATLAB format
        if (bi + 1) % 1000 == 0 or bi + 1 == config.psgd_setting['num_iter']:
            scipy.io.savemat(
                'src_prior.mat',
                dict([('W' + str(i), W.cpu().detach().numpy())
                      for (i, W) in enumerate(Ws)]))
    plt.plot(Loss)
def adjust_grads(self, grads):
    """Precondition each raw gradient with its stored Kronecker factor pair.

    Expects one gradient per weight matrix in ``self.net.W``; returns the
    list of preconditioned gradients in the same order.
    """
    assert len(grads) == len(self.net.W)
    preconditioned = []
    for factors, raw_grad in zip(self.Qs, grads):
        left, right = factors[0], factors[1]
        preconditioned.append(psgd.precond_grad_kron(left, right, raw_grad))
    return preconditioned
# NOTE(review): TF1-style graph construction fragment — ``Ws``, ``dtype``,
# ``train_criterion``, ``step_size`` and ``grad_norm_clip_thr`` are defined
# outside this chunk, and the session-run loop follows after it.
with tf.Session() as sess:
    # Non-trainable (left, right) Kronecker preconditioner factors, one pair
    # per weight matrix.
    Qs_left = [
        tf.Variable(tf.eye(W.shape.as_list()[0], dtype=dtype),
                    trainable=False) for W in Ws
    ]
    Qs_right = [
        tf.Variable(tf.eye(W.shape.as_list()[1], dtype=dtype),
                    trainable=False) for W in Ws
    ]
    train_loss = train_criterion(Ws)
    grads = tf.gradients(train_loss, Ws)
    precond_grads = [
        psgd.precond_grad_kron(ql, qr, g)
        for (ql, qr, g) in zip(Qs_left, Qs_right, grads)
    ]
    grad_norm = tf.sqrt(
        tf.reduce_sum([tf.reduce_sum(g * g) for g in precond_grads]))
    # trust-region clip; 1.2e-38 (~float32 min normal) avoids division by zero
    step_size_adjust = tf.minimum(1.0,
                                  grad_norm_clip_thr / (grad_norm + 1.2e-38))
    new_Ws = [
        W - (step_size_adjust * step_size) * g
        for (W, g) in zip(Ws, precond_grads)
    ]
    # graph op that applies the update when run in the session
    update_Ws = [tf.assign(W, new_W) for (W, new_W) in zip(Ws, new_Ws)]
    # random probe vectors and directional derivative for the Hvp machinery
    delta_Ws = [tf.random_normal(W.shape, dtype=dtype) for W in Ws]
    grad_deltaw = tf.reduce_sum(
        [tf.reduce_sum(g * v) for (g, v) in zip(grads, delta_Ws)])
def test_loss():
    """Return the classification error rate of ``lenet5`` on the test set."""
    num_errs = 0
    with torch.no_grad():
        for data, target in test_loader:
            y = lenet5(data)
            _, pred = torch.max(y, dim=1)
            num_errs += torch.sum(pred != target)
    return num_errs.item() / len(test_loader.dataset)


# One (left, right) identity preconditioner pair per parameter tensor.
Qs = [[torch.eye(W.shape[0]), torch.eye(W.shape[1])]
      for W in lenet5.parameters()]
lr = 0.1
# clip threshold scales with sqrt of the total parameter count
grad_norm_clip_thr = 0.1 * sum(W.numel() for W in lenet5.parameters()) ** 0.5
TrainLosses, best_test_loss = [], 1.0
for epoch in range(10):
    for _, (data, target) in enumerate(train_loader):
        loss = train_loss(data, target)
        # create_graph=True keeps the graph for the Hvp double backward
        grads = torch.autograd.grad(loss, lenet5.parameters(),
                                    create_graph=True)
        vs = [torch.randn_like(W) for W in lenet5.parameters()]
        Hvs = torch.autograd.grad(grads, lenet5.parameters(), vs)
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
                  for (Qlr, v, Hv) in zip(Qs, vs, Hvs)]
            pre_grads = [psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
                         for (Qlr, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
            lr_adjust = min(grad_norm_clip_thr / grad_norm, 1.0)
            [W.subtract_(lr_adjust * lr * g)
             for (W, g) in zip(lenet5.parameters(), pre_grads)]
        TrainLosses.append(loss.item())
    best_test_loss = min(best_test_loss, test_loss())
    # decay lr by 100x spread over the remaining 9 epochs
    lr *= (0.01) ** (1 / 9)
    print('Epoch: {}; best test classification error rate: {}'.format(
        epoch + 1, best_test_loss))
plt.plot(TrainLosses)
# NOTE(review): fragment — the opening of this Qs-style initializer list lies
# outside the visible chunk; R, J, K, f, f_values and xyz are defined there.
    ],
    [0.1 * torch.eye(R), torch.ones(1, J)],
    [
        0.1 * torch.stack([torch.ones(R), torch.zeros(R)], dim=0),
        torch.ones(1, K)
    ],
]
# # example 3
# Qs = [[0.1*torch.eye(w.shape[0]), torch.eye(w.shape[1])] for w in xyz]
for _ in range(100):
    loss = f()
    f_values.append(loss.item())
    # create_graph=True keeps the graph for the Hvp double backward below
    grads = torch.autograd.grad(loss, xyz, create_graph=True)
    vs = [torch.randn_like(w) for w in xyz]  # random probe vectors
    Hvs = torch.autograd.grad(grads, xyz, vs)  # Hessian-vector products
    with torch.no_grad():
        Qs = [
            psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv, step=0.1)
            for (Qlr, v, Hv) in zip(Qs, vs, Hvs)
        ]
        pre_grads = [
            psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
            for (Qlr, g) in zip(Qs, grads)
        ]
        # fixed step size 0.1 on the preconditioned gradients
        [w.subtract_(0.1 * g) for (w, g) in zip(xyz, pre_grads)]
plt.semilogy(f_values)
plt.xlabel('Iterations')
plt.ylabel('Decomposition losses')
# update preconditioners Q_update_gap = max(int(np.floor(np.log10(num_iter + 1.0))), 1) if num_iter % Q_update_gap == 0: # let us update Q less frequently delta = [Variable(torch.randn(W.size())) for W in Ws] grad_delta = sum([torch.sum(g * d) for (g, d) in zip(grads, delta)]) hess_delta = grad(grad_delta, Ws) Qs = [ psgd.update_precond_kron(q[0], q[1], dw.data.numpy(), dg.data.numpy()) for (q, dw, dg) in zip(Qs, delta, hess_delta) ] # update Ws pre_grads = [ psgd.precond_grad_kron(q[0], q[1], g.data.numpy()) for (q, g) in zip(Qs, grads) ] grad_norm = np.sqrt(sum([np.sum(g * g) for g in pre_grads])) if grad_norm > grad_norm_clip_thr: step_adjust = grad_norm_clip_thr / grad_norm else: step_adjust = 1.0 for i in range(len(Ws)): Ws[i].data = Ws[i].data - step_adjust * step_size * torch.FloatTensor( pre_grads[i]) if num_iter % 100 == 0: print('training loss: {}'.format(Loss[-1])) plt.semilogy(Loss)