def train_step(xy_pair, lr):
    with tf.GradientTape() as g2nd:
        with tf.GradientTape() as g1st:
            loss = train_loss(xy_pair)
        grads = g1st.gradient(loss, lenet5_vars)
    vs = [tf.random.normal(W.shape) for W in lenet5_vars]  # a random vector
    hess_vs = g2nd.gradient(grads, lenet5_vars, vs)  # Hessian-vector products
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
        for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)
    ]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])]
     for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)
    ]
    grad_norm = tf.sqrt(sum([tf.reduce_sum(g * g) for g in pre_grads]))
    lr_adjust = tf.minimum(grad_norm_clip_thr / grad_norm, 1.0)
    [
        W.assign_sub(lr_adjust * lr * g)
        for (W, g) in zip(lenet5_vars, pre_grads)
    ]
    return loss
def opt_step():
    with tf.GradientTape() as g2nd:  # second order derivative
        with tf.GradientTape() as g1st:  # first order derivative
            cost = f()
        grads = g1st.gradient(cost, xyz)  # gradient
    vs = [tf.random.normal(w.shape) for w in xyz]  # a random vector
    hess_vs = g2nd.gradient(grads, xyz, vs)  # Hessian-vector products
    new_Qs = [psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv, step=0.1)
              for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])]
     for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    pre_grads = [psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
                 for (Qlr, g) in zip(Qs, grads)]
    [w.assign_sub(0.1 * g) for (w, g) in zip(xyz, pre_grads)]
    return cost
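The two steps above assume the Kronecker-product preconditioners in Qs are already initialized; since the step calls .assign on their factors, each factor must be a tf.Variable. A minimal sketch of that initialization, mirroring the 0.1*eye / eye pattern used by the PyTorch demos later in this document (the variable list xyz and the 0.1 scale are taken from those demos, not from this snippet):

# one (left, right) preconditioner factor pair per 2-D variable
Qs = [[tf.Variable(0.1 * tf.eye(int(w.shape[0]))),
       tf.Variable(tf.eye(int(w.shape[1])))] for w in xyz]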
def train_step_apprx_Hvp(inp, targ, lr):
    """Demonstrate the usage of the approximated Hessian-vector product (Hvp).
    Typically a smaller graph and faster execution, but the accuracy of the
    Hvp might be in question.
    """
    delta = tf.constant(2**(-23 / 2))  # sqrt(float32.eps), the scale of the perturbation for the Hvp approximation

    def eval_loss_grads():
        with tf.GradientTape() as gt:
            enc_output = encoder(inp)
            dec_input = targ[:, 0]
            dec_h = enc_output[:, -1, :]  # the last hidden state of the encoder as the initial one for the decoder
            loss = 0.0
            for t in range(1, targ.shape[1]):
                pred, dec_h = decoder(dec_input, dec_h, enc_output)
                loss += loss_function(targ[:, t], pred)
                dec_input = targ[:, t]
            loss /= targ.shape[1]
        return loss, gt.gradient(loss, trainable_vars)  # loss and gradients

    loss, grads = eval_loss_grads()
    # calculate the perturbed gradients
    vs = [delta * tf.random.normal(W.shape) for W in trainable_vars]
    [W.assign_add(v) for (W, v) in zip(trainable_vars, vs)]
    _, perturbed_grads = eval_loss_grads()
    # update the preconditioners
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, tf.subtract(perturbed_g, g))
        for (Qlr, v, perturbed_g, g) in zip(Qs, vs, perturbed_grads, grads)
    ]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])]
     for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    # calculate the preconditioned gradients
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)
    ]
    # update variables; do not forget to remove the perturbations on the variables
    [W.assign_sub(lr * g + v) for (W, g, v) in zip(trainable_vars, pre_grads, vs)]
    return loss
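A quick sanity check (not part of the demo) of the finite-difference identity behind train_step_apprx_Hvp: for a small perturbation v, grad(theta + v) - grad(theta) ≈ Hv, so the pair (v, gradient difference) can stand in for the exact (v, Hv) pair. Verified here on a toy quadratic whose Hessian A is known:

import tensorflow as tf

A = tf.constant([[2.0, 0.5], [0.5, 1.0]])  # Hessian of the toy loss 0.5*x'Ax
x = tf.Variable(tf.random.normal([2]))
v = tf.constant(2**(-23 / 2)) * tf.random.normal([2])  # perturbation at scale sqrt(float32.eps)

def loss_grad():
    with tf.GradientTape() as gt:
        loss = 0.5 * tf.reduce_sum(x * tf.linalg.matvec(A, x))
    return gt.gradient(loss, x)

g = loss_grad()
x.assign_add(v)
g_perturbed = loss_grad()
x.assign_sub(v)                     # undo the perturbation, as in the train step
print(g_perturbed - g)              # ≈ Hv
print(tf.linalg.matvec(A, v))       # exact Hv for comparison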
def train_step_exact_Hvp(inp, targ, lr):
    """Demonstrate the usage of the exact Hessian-vector product (Hvp).
    Typically a larger graph and slower execution, but the Hvp is exact,
    and the code is cleaner.
    """
    with tf.GradientTape() as g2nd:
        with tf.GradientTape() as g1st:
            enc_output = encoder(inp)
            dec_input = targ[:, 0]
            dec_h = enc_output[:, -1, :]  # the last hidden state of the encoder as the initial one for the decoder
            loss = 0.0
            for t in range(1, targ.shape[1]):
                pred, dec_h = decoder(dec_input, dec_h, enc_output)
                loss += loss_function(targ[:, t], pred)
                dec_input = targ[:, t]
            loss /= targ.shape[1]
        grads = g1st.gradient(loss, trainable_vars)
        # convert sparse IndexedSlices gradients (e.g., from embedding lookups) to dense tensors
        grads = [
            tf.convert_to_tensor(g) if isinstance(g, tf.IndexedSlices) else g
            for g in grads
        ]
    vs = [tf.random.normal(W.shape) for W in trainable_vars]  # a random vector
    hess_vs = g2nd.gradient(grads, trainable_vars, vs)  # Hessian-vector products
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
        for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)
    ]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])]
     for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)
    ]
    [W.assign_sub(lr * g) for (W, g) in zip(trainable_vars, pre_grads)]
    return loss
new_images = torch.tensor(images / 256, dtype=torch.float, device=device)
new_labels = []
for label in labels:
    new_labels.append(U.remove_repetitive_labels(label))
loss = train_loss(new_images, new_labels)
grads = grad(loss, Ws, create_graph=True)
TrainLoss.append(loss.item())
v = [torch.randn(W.shape, device=device) for W in Ws]
Hv = grad(grads, Ws, v)
with torch.no_grad():
    Qs = [
        psgd.update_precond_kron(q[0], q[1], dw, dg)
        for (q, dw, dg) in zip(Qs, v, Hv)
    ]
    pre_grads = [
        psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)
    ]
    grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
    lr_adjust = min(grad_norm_clip_thr / (grad_norm + 1.2e-38), 1.0)
    for i in range(len(Ws)):
        Ws[i] -= lr_adjust * lr * pre_grads[i]
if num_iter % 100 == 0:
    print(
        'epoch: {}; iter: {}; train loss: {}; elapsed time: {}'.format(
            epoch, num_iter, TrainLoss[-1],
Qs = [[0.1 * torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)]
      for W in Ws]
step_size = 0.1  # normalized step size
grad_norm_clip_thr = 100.0  # the size of the trust region
# begin iteration here
TrainLoss = []
TestLoss = []
echo_every = 100  # iterations
for num_iter in range(5000):
    x, y = get_batches()
    # calculate the loss and gradient
    loss = train_criterion(Ws, x, y)
    grads = grad(loss, Ws, create_graph=True)
    TrainLoss.append(loss.item())
    v = [torch.randn(W.shape).to(device) for W in Ws]
    # grad_v = sum([torch.sum(g*d) for (g, d) in zip(grads, v)])
    # hess_v = grad(grad_v, Ws)
    Hv = grad(grads, Ws, v)  # replaces the above two lines, due to Bulatov
    # Hv = grads  # just let Hv=grads if you use a Fisher type preconditioner
    with torch.no_grad():
        Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg)
              for (q, dw, dg) in zip(Qs, v, Hv)]
        pre_grads = [psgd.precond_grad_kron(q[0], q[1], g)
                     for (q, g) in zip(Qs, grads)]
        grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
        step_adjust = min(grad_norm_clip_thr / (grad_norm + 1.2e-38), 1.0)
        for i in range(len(Ws)):
            Ws[i] -= step_adjust * step_size * pre_grads[i]
    if (num_iter + 1) % echo_every == 0:
        loss = test_criterion(Ws)
        TestLoss.append(loss.item())
        print('train loss: {}; test loss: {}'.format(TrainLoss[-1], TestLoss[-1]))
def main():
    """DNN speech prior training"""
    device = config.device
    wavloader = artificial_mixture_generator.WavLoader(config.wav_dir)
    mixergenerator = artificial_mixture_generator.MixerGenerator(
        wavloader, config.batch_size, config.num_mic, config.Lh,
        config.iva_hop_size * config.num_frame, config.prb_mix_change)

    Ws = [
        W.to(device)
        for W in initW(config.iva_fft_size // 2 - 1,
                       config.src_prior['num_layer'],
                       config.src_prior['num_state'],
                       config.src_prior['dim_h'])
    ]
    hs = [
        torch.zeros(config.batch_size * config.num_mic,
                    config.src_prior['num_state']).to(device)
        for _ in range(config.src_prior['num_layer'] - 1)
    ]
    for W in Ws:
        W.requires_grad = True

    # W_iva initialization
    W_iva = (100.0 * torch.randn(config.batch_size, config.iva_fft_size // 2 - 1,
                                 config.num_mic, config.num_mic, 2)).to(device)
    # STFT window for IVA
    win_iva = stft.pre_def_win(config.iva_fft_size, config.iva_hop_size)

    # preconditioners for the source prior neural network optimization
    Qs = [[
        torch.eye(W.shape[0], device=device),
        torch.eye(W.shape[1], device=device)
    ] for W in Ws]  # preconditioners for SGD

    # buffers for STFT
    s_bfr = torch.zeros(config.batch_size, config.num_mic,
                        config.iva_fft_size - config.iva_hop_size)
    x_bfr = torch.zeros(config.batch_size, config.num_mic,
                        config.iva_fft_size - config.iva_hop_size)
    # buffer for overlap-add synthesis and reconstruction loss calculation
    ola_bfr = torch.zeros(config.batch_size, config.num_mic,
                          config.iva_fft_size).to(device)
    # extra buffer for reconstruction loss calculation
    xtr_bfr = torch.zeros(config.batch_size, config.num_mic,
                          config.iva_fft_size - config.iva_hop_size)

    Loss, lr = [], config.psgd_setting['lr']
    for bi in range(config.psgd_setting['num_iter']):
        srcs, xs = mixergenerator.get_mixtures()
        Ss, s_bfr = stft.stft(srcs[:, :, config.Lh:-config.Lh], win_iva,
                              config.iva_hop_size, s_bfr)
        Xs, x_bfr = stft.stft(xs, win_iva, config.iva_hop_size, x_bfr)
        Ss, Xs = Ss.to(device), Xs.to(device)

        Ys, W_iva, hs = iva(Xs, W_iva.detach(), [h.detach() for h in hs], Ws,
                            config.iva_lr)

        # loss calculation
        coherence = losses.di_pi_coh(Ss, Ys)
        loss = 1.0 - coherence
        if config.use_spectra_dist_loss:
            spectra_dist = losses.di_pi_is_dist(
                Ss[:, :, :, 0] * Ss[:, :, :, 0] + Ss[:, :, :, 1] * Ss[:, :, :, 1],
                Ys[:, :, :, 0] * Ys[:, :, :, 0] + Ys[:, :, :, 1] * Ys[:, :, :, 1])
            loss = loss + spectra_dist
        if config.reconstruction_loss_fft_sizes:
            srcs = torch.cat([xtr_bfr, srcs], dim=2)
            ys, ola_bfr = stft.istft(Ys, win_iva.to(device), config.iva_hop_size,
                                     ola_bfr.detach())
            for fft_size in config.reconstruction_loss_fft_sizes:
                win = stft.coswin(fft_size)
                Ss, _ = stft.stft(srcs, win, fft_size // 2)
                Ys, _ = stft.stft(ys, win.to(device), fft_size // 2)
                Ss = Ss.to(device)
                coherence = losses.di_pi_coh(Ss, Ys)
                loss = loss + 1.0 - coherence
                if config.use_spectra_dist_loss:
                    spectra_dist = losses.di_pi_is_dist(
                        Ss[:, :, :, 0] * Ss[:, :, :, 0] + Ss[:, :, :, 1] * Ss[:, :, :, 1],
                        Ys[:, :, :, 0] * Ys[:, :, :, 0] + Ys[:, :, :, 1] * Ys[:, :, :, 1])
                    loss = loss + spectra_dist
            xtr_bfr = srcs[:, :, -(config.iva_fft_size - config.iva_hop_size):]

        Loss.append(loss.item())
        if config.use_spectra_dist_loss:
            print('Loss: {}; coherence: {}; spectral_distance: {}'.format(
                loss.item(), coherence.item(), spectra_dist.item()))
        else:
            print('Loss: {}; coherence: {}'.format(loss.item(), coherence.item()))

        # Preconditioned SGD optimizer for source prior network optimization
        Q_update_gap = max(math.floor(math.log10(bi + 1)), 1)
        if bi % Q_update_gap == 0:  # update preconditioner less frequently
            grads = grad(loss, Ws, create_graph=True)
            v = [torch.randn(W.shape, device=device) for W in Ws]
            Hv = grad(grads, Ws, v)
            with torch.no_grad():
                Qs = [
                    psgd.update_precond_kron(q[0], q[1], dw, dg)
                    for (q, dw, dg) in zip(Qs, v, Hv)
                ]
        else:
            grads = grad(loss, Ws)
        with torch.no_grad():
            pre_grads = [
                psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)
            ]
            for i in range(len(Ws)):
                Ws[i] -= lr * pre_grads[i]

        if bi == int(0.9 * config.psgd_setting['num_iter']):
            lr *= 0.1
        if (bi + 1) % 1000 == 0 or bi + 1 == config.psgd_setting['num_iter']:
            scipy.io.savemat(
                'src_prior.mat',
                dict([('W' + str(i), W.cpu().detach().numpy())
                      for (i, W) in enumerate(Ws)]))
            plt.plot(Loss)
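Note the preconditioner update schedule: Q_update_gap grows logarithmically with the iteration count, so the Hessian-vector product (and its extra backward pass) is computed on every iteration early in training and progressively less often later, amortizing its cost. A small illustration of the gap values produced by the formula above:

import math
gaps = [max(math.floor(math.log10(bi + 1)), 1) for bi in (0, 9, 99, 999, 9999)]
print(gaps)  # [1, 1, 2, 3, 4]: update Q every iteration at first, every 4th near bi ~ 1e4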
grad_norm = tf.sqrt(
    tf.reduce_sum([tf.reduce_sum(g * g) for g in precond_grads]))
step_size_adjust = tf.minimum(1.0, grad_norm_clip_thr / (grad_norm + 1.2e-38))
new_Ws = [
    W - (step_size_adjust * step_size) * g
    for (W, g) in zip(Ws, precond_grads)
]
update_Ws = [tf.assign(W, new_W) for (W, new_W) in zip(Ws, new_Ws)]

delta_Ws = [tf.random_normal(W.shape, dtype=dtype) for W in Ws]
grad_deltaw = tf.reduce_sum(
    [tf.reduce_sum(g * v) for (g, v) in zip(grads, delta_Ws)])
hess_deltaw = tf.gradients(grad_deltaw, Ws)
new_Qs = [
    psgd.update_precond_kron(ql, qr, dw, dg)
    for (ql, qr, dw, dg) in zip(Qs_left, Qs_right, delta_Ws, hess_deltaw)
]
update_Qs = [[tf.assign(old_ql, new_q[0]), tf.assign(old_qr, new_q[1])]
             for (old_ql, old_qr, new_q) in zip(Qs_left, Qs_right, new_Qs)]

test_loss = test_criterion(Ws)

sess.run(tf.global_variables_initializer())
avg_train_loss = 0.0
TrainLoss = list()
TestLoss = list()
Time = list()
for num_iter in range(20000):
    _train_inputs, _train_outputs = get_batches()
def test_loss():
    num_errs = 0
    with torch.no_grad():
        for data, target in test_loader:
            y = lenet5(data)
            _, pred = torch.max(y, dim=1)
            num_errs += torch.sum(pred != target)
    return num_errs.item() / len(test_loader.dataset)

Qs = [[torch.eye(W.shape[0]), torch.eye(W.shape[1])]
      for W in lenet5.parameters()]
lr = 0.1
grad_norm_clip_thr = 0.1 * sum(W.numel() for W in lenet5.parameters())**0.5
TrainLosses, best_test_loss = [], 1.0
for epoch in range(10):
    for _, (data, target) in enumerate(train_loader):
        loss = train_loss(data, target)
        grads = torch.autograd.grad(loss, lenet5.parameters(), create_graph=True)
        vs = [torch.randn_like(W) for W in lenet5.parameters()]
        Hvs = torch.autograd.grad(grads, lenet5.parameters(), vs)
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
                  for (Qlr, v, Hv) in zip(Qs, vs, Hvs)]
            pre_grads = [psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
                         for (Qlr, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
            lr_adjust = min(grad_norm_clip_thr / grad_norm, 1.0)
            [W.subtract_(lr_adjust * lr * g)
             for (W, g) in zip(lenet5.parameters(), pre_grads)]
        TrainLosses.append(loss.item())
    best_test_loss = min(best_test_loss, test_loss())
    lr *= (0.01)**(1/9)  # anneal lr from 0.1 to roughly 1e-3 over the ten epochs
    print('Epoch: {}; best test classification error rate: {}'.format(
        epoch + 1, best_test_loss))
plt.plot(TrainLosses)
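The train_loss referenced above is defined elsewhere in the demo. A minimal stand-in consistent with this loop (hypothetical, assuming the usual cross-entropy objective for a LeNet5 classifier) could look like:

criterion = torch.nn.CrossEntropyLoss()

def train_loss(data, target):
    # forward pass through the network, returning a scalar loss
    return criterion(lenet5(data), target)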
    ],
    [0.1 * torch.eye(R), torch.ones(1, J)],
    [
        0.1 * torch.stack([torch.ones(R), torch.zeros(R)], dim=0),
        torch.ones(1, K)
    ],
]

# # example 3
# Qs = [[0.1*torch.eye(w.shape[0]), torch.eye(w.shape[1])] for w in xyz]

for _ in range(100):
    loss = f()
    f_values.append(loss.item())
    grads = torch.autograd.grad(loss, xyz, create_graph=True)
    vs = [torch.randn_like(w) for w in xyz]
    Hvs = torch.autograd.grad(grads, xyz, vs)
    with torch.no_grad():
        Qs = [
            psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv, step=0.1)
            for (Qlr, v, Hv) in zip(Qs, vs, Hvs)
        ]
        pre_grads = [
            psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
            for (Qlr, g) in zip(Qs, grads)
        ]
        [w.subtract_(0.1 * g) for (w, g) in zip(xyz, pre_grads)]

plt.semilogy(f_values)
plt.xlabel('Iterations')
plt.ylabel('Decomposition losses')
for num_iter in range(10000):
    x, y = get_batches()
    # calculate the loss and gradient
    loss = train_criterion(Ws, x, y)
    grads = grad(loss, Ws, create_graph=True)
    Loss.append(loss.data.numpy()[0])

    # update preconditioners
    Q_update_gap = max(int(np.floor(np.log10(num_iter + 1.0))), 1)
    if num_iter % Q_update_gap == 0:  # let us update Q less frequently
        delta = [Variable(torch.randn(W.size())) for W in Ws]
        grad_delta = sum([torch.sum(g * d) for (g, d) in zip(grads, delta)])
        hess_delta = grad(grad_delta, Ws)
        Qs = [
            psgd.update_precond_kron(q[0], q[1], dw.data.numpy(), dg.data.numpy())
            for (q, dw, dg) in zip(Qs, delta, hess_delta)
        ]

    # update Ws
    pre_grads = [
        psgd.precond_grad_kron(q[0], q[1], g.data.numpy())
        for (q, g) in zip(Qs, grads)
    ]
    grad_norm = np.sqrt(sum([np.sum(g * g) for g in pre_grads]))
    if grad_norm > grad_norm_clip_thr:
        step_adjust = grad_norm_clip_thr / grad_norm
    else:
        step_adjust = 1.0
    for i in range(len(Ws)):
        Ws[i].data = Ws[i].data - step_adjust * step_size * torch.FloatTensor(
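This older snippet predates the one-call Hessian-vector product and builds it explicitly from the inner product grad_delta. With a current PyTorch, those two lines can be replaced by the double-backward form already used in the other demos (cf. the "due to Bulatov" comment earlier); a sketch keeping the same names:

delta = [torch.randn(W.size()) for W in Ws]  # no Variable wrapper needed in modern PyTorch
hess_delta = grad(grads, Ws, delta)          # Hessian-vector products in a single call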
    torch.log(torch.sigmoid(xy_pair[1] * lstm_net(xy_pair[0]))))

Qs = [[torch.eye(W.shape[0]), torch.eye(W.shape[1])]
      for W in lstm_net.parameters()]
lr = 0.02
grad_norm_clip_thr = 1.0
Losses = []
for num_iter in range(100000):
    loss = train_loss(generate_train_data())
    grads = torch.autograd.grad(loss, lstm_net.parameters(), create_graph=True)
    vs = [torch.randn_like(W) for W in lstm_net.parameters()]
    Hvs = torch.autograd.grad(grads, lstm_net.parameters(), vs)
    with torch.no_grad():
        Qs = [
            psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
            for (Qlr, v, Hv) in zip(Qs, vs, Hvs)
        ]
        pre_grads = [
            psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
            for (Qlr, g) in zip(Qs, grads)
        ]
        grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
        lr_adjust = min(grad_norm_clip_thr / grad_norm, 1.0)
        [
            W.subtract_(lr_adjust * lr * g)
            for (W, g) in zip(lstm_net.parameters(), pre_grads)
        ]
    Losses.append(loss.item())
    print('Iteration: {}; loss: {}'.format(num_iter, Losses[-1]))
    if Losses[-1] < 0.1: