def train_step(xy_pair, lr):
    with tf.GradientTape() as g2nd:
        with tf.GradientTape() as g1st:
            loss = train_loss(xy_pair)
        grads = g1st.gradient(loss, lenet5_vars)
        vs = [tf.random.normal(W.shape)
              for W in lenet5_vars]  # a random vector
    hess_vs = g2nd.gradient(grads, lenet5_vars, vs)  # Hessian-vector products
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
        for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)
    ]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])]
     for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
        for (Qlr, g) in zip(Qs, grads)
    ]
    grad_norm = tf.sqrt(sum([tf.reduce_sum(g * g) for g in pre_grads]))
    lr_adjust = tf.minimum(grad_norm_clip_thr / grad_norm, 1.0)
    [
        W.assign_sub(lr_adjust * lr * g)
        for (W, g) in zip(lenet5_vars, pre_grads)
    ]
    return loss
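This step function relies on train_loss, lenet5_vars, Qs, and grad_norm_clip_thr being defined in the enclosing scope. A minimal sketch of that setup is given below; the model object name, the identity initialization, and the clipping threshold are illustrative assumptions (each variable is also assumed to be stored as a 2-D matrix so the Kronecker-factored preconditioner applies directly), not code from the original snippet.

import tensorflow as tf
import psgd  # the TensorFlow implementation of PSGD, assumed to be importable

lenet5_vars = lenet5.trainable_variables  # assumes a model object named `lenet5`
# one (left, right) pair of Kronecker preconditioner factors per 2-D variable
Qs = [[tf.Variable(tf.eye(W.shape[0]), trainable=False),
       tf.Variable(tf.eye(W.shape[1]), trainable=False)]
      for W in lenet5_vars]
# trust-region threshold, scaled with the square root of the model size
grad_norm_clip_thr = 0.1 * sum(int(tf.size(W)) for W in lenet5_vars) ** 0.5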
def opt_step():
    with tf.GradientTape() as g2nd: # second order derivative
        with tf.GradientTape() as g1st: # first order derivative
            cost = f()
        grads = g1st.gradient(cost, xyz) # gradient
        vs = [tf.random.normal(w.shape) for w in xyz] # a random vector
    hess_vs = g2nd.gradient(grads, xyz, vs) # Hessian-vector products
    new_Qs = [psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv, step=0.1) for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])] for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    pre_grads = [psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)]
    [w.assign_sub(0.1*g) for (w, g) in zip(xyz, pre_grads)]
    return cost
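A minimal sketch of how opt_step might be driven, assuming xyz is a list of 2-D tf.Variable parameters, f is the scalar objective, and the preconditioner factors start from identities (names and iteration count are illustrative):

Qs = [[tf.Variable(tf.eye(w.shape[0])), tf.Variable(tf.eye(w.shape[1]))]
      for w in xyz]  # identity Kronecker factors, one pair per variable
for it in range(200):
    cost = opt_step()
    if it % 20 == 0:
        print('iteration {}: cost {}'.format(it, cost.numpy()))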
def train_step_apprx_Hvp(inp, targ, lr):
    """
    Demenstrate the usage of approximated Hessian-vector product (Hvp). 
    Typically smaller graph and faster excution. But, the accuracy of Hvp might be in question. 
    """
    delta = tf.constant(
        2**(-23 / 2)
    )  # sqrt(float32.eps), the scale of perturbation for Hessian-vector product approximation

    def eval_loss_grads():
        with tf.GradientTape() as gt:
            enc_output = encoder(inp)
            dec_input = targ[:, 0]
            dec_h = enc_output[:, -1, :]  # use the encoder's last hidden state as the decoder's initial state
            loss = 0.0
            for t in range(1, targ.shape[1]):
                pred, dec_h = decoder(dec_input, dec_h, enc_output)
                loss += loss_function(targ[:, t], pred)
                dec_input = targ[:, t]
            loss /= targ.shape[1]
        return loss, gt.gradient(loss, trainable_vars)

    # loss and gradients
    loss, grads = eval_loss_grads()

    # calculate the perturbed gradients
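    # each direction v is pre-scaled by delta, so the difference of the perturbed
    # and unperturbed gradients below approximates the Hessian-vector product:
    # grad(W + v) - grad(W) ≈ H v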
    vs = [delta * tf.random.normal(W.shape) for W in trainable_vars]
    [W.assign_add(v) for (W, v) in zip(trainable_vars, vs)]
    _, perturbed_grads = eval_loss_grads()
    # update the preconditioners
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v,
                                 tf.subtract(perturbed_g, g))
        for (Qlr, v, perturbed_g, g) in zip(Qs, vs, perturbed_grads, grads)
    ]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])]
     for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    # calculate the preconditioned gradients
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
        for (Qlr, g) in zip(Qs, grads)
    ]
    # update variables; do not forget to remove the perturbations on variables
    [
        W.assign_sub(lr * g + v)
        for (W, g, v) in zip(trainable_vars, pre_grads, vs)
    ]
    return loss
def train_step_exact_Hvp(inp, targ, lr):
    """
    Demenstrate the usage of exact Hessian-vector product (Hvp). 
    Typically larger graph and slower excution. But, Hvp is exact, and cleaner code. 
    """
    with tf.GradientTape() as g2nd:
        with tf.GradientTape() as g1st:
            enc_output = encoder(inp)
            dec_input = targ[:, 0]
            dec_h = enc_output[:, -1, :]  # use the encoder's last hidden state as the decoder's initial state
            loss = 0.0
            for t in range(1, targ.shape[1]):
                pred, dec_h = decoder(dec_input, dec_h, enc_output)
                loss += loss_function(targ[:, t], pred)
                dec_input = targ[:, t]
            loss /= targ.shape[1]
        grads = g1st.gradient(loss, trainable_vars)
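        # gradients of embedding lookups may come back as tf.IndexedSlices;
        # convert them to dense tensors before the second differentiation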
        grads = [
            tf.convert_to_tensor(g) if isinstance(g, tf.IndexedSlices) else g
            for g in grads
        ]
        vs = [tf.random.normal(W.shape)
              for W in trainable_vars]  # a random vector
    hess_vs = g2nd.gradient(grads, trainable_vars,
                            vs)  # Hessian-vector products
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
        for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)
    ]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])]
     for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
        for (Qlr, g) in zip(Qs, grads)
    ]
    [W.assign_sub(lr * g) for (W, g) in zip(trainable_vars, pre_grads)]
    return loss
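Either step function can then be driven by a plain training loop like the sketch below; the dataset iterator, epoch count, and learning rate are illustrative assumptions rather than part of the original snippets.

lr = tf.constant(0.01)
for epoch in range(num_epochs):  # num_epochs is assumed
    for inp, targ in dataset:  # a tf.data.Dataset yielding (input, target) batches
        loss = train_step_exact_Hvp(inp, targ, lr)  # or train_step_apprx_Hvp(inp, targ, lr)
    print('epoch {}: last batch loss {}'.format(epoch, float(loss)))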
Example #5
        new_images = torch.tensor(images / 256,
                                  dtype=torch.float,
                                  device=device)
        new_labels = []
        for label in labels:
            new_labels.append(U.remove_repetitive_labels(label))

        loss = train_loss(new_images, new_labels)
        grads = grad(loss, Ws, create_graph=True)
        TrainLoss.append(loss.item())

        v = [torch.randn(W.shape, device=device) for W in Ws]
        Hv = grad(grads, Ws, v)
        with torch.no_grad():
            Qs = [
                psgd.update_precond_kron(q[0], q[1], dw, dg)
                for (q, dw, dg) in zip(Qs, v, Hv)
            ]
            pre_grads = [
                psgd.precond_grad_kron(q[0], q[1], g)
                for (q, g) in zip(Qs, grads)
            ]
            grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
            lr_adjust = min(grad_norm_clip_thr / (grad_norm + 1.2e-38), 1.0)  # tiny constant guards against division by zero
            for i in range(len(Ws)):
                Ws[i] -= lr_adjust * lr * pre_grads[i]

        if num_iter % 100 == 0:
            print(
                'epoch: {}; iter: {}; train loss: {}; elapsed time: {}'.format(
                    epoch, num_iter, TrainLoss[-1],
Example #6
Qs = [[0.1*torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.1  # normalized step size
grad_norm_clip_thr = 100.0  # the size of the trust region
# begin iteration here
TrainLoss = []
TestLoss = []
echo_every = 100  # iterations
for num_iter in range(5000):
    x, y = get_batches()
    # calculate the loss and gradient
    loss = train_criterion(Ws, x, y)
    grads = grad(loss, Ws, create_graph=True)
    TrainLoss.append(loss.item())
    
    v = [torch.randn(W.shape).to(device) for W in Ws]
    # grad_v = sum([torch.sum(g*d) for (g, d) in zip(grads, v)])
    # hess_v = grad(grad_v, Ws)
    Hv = grad(grads, Ws, v)  # replaces the two commented lines above, due to Bulatov
    # Hv = grads  # just let Hv = grads if you use a Fisher type preconditioner
    with torch.no_grad():
        Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
        pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
        grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
        step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
        for i in range(len(Ws)):
            Ws[i] -= step_adjust*step_size*pre_grads[i]
            
        if (num_iter+1) % echo_every == 0:
            loss = test_criterion(Ws)
            TestLoss.append(loss.item())
            print('train loss: {}; test loss: {}'.format(TrainLoss[-1], TestLoss[-1]))
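Note that step_adjust implements a simple trust-region rule: with preconditioned gradient Pg and threshold r = grad_norm_clip_thr, the applied update is step_size * min(1, r/||Pg||) * Pg, so its norm never exceeds step_size * r.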
Example #7
def main():
    """DNN speech prior training"""
    device = config.device
    wavloader = artificial_mixture_generator.WavLoader(config.wav_dir)
    mixergenerator = artificial_mixture_generator.MixerGenerator(
        wavloader, config.batch_size, config.num_mic, config.Lh,
        config.iva_hop_size * config.num_frame, config.prb_mix_change)

    Ws = [
        W.to(device)
        for W in initW(config.iva_fft_size // 2 - 1,
                       config.src_prior['num_layer'],
                       config.src_prior['num_state'],
                       config.src_prior['dim_h'])
    ]
    hs = [
        torch.zeros(config.batch_size * config.num_mic,
                    config.src_prior['num_state']).to(device)
        for _ in range(config.src_prior['num_layer'] - 1)
    ]
    for W in Ws:
        W.requires_grad = True

    # W_iva initialization
    W_iva = (100.0 *
             torch.randn(config.batch_size, config.iva_fft_size // 2 - 1,
                         config.num_mic, config.num_mic, 2)).to(device)

    # STFT window for IVA
    win_iva = stft.pre_def_win(config.iva_fft_size, config.iva_hop_size)

    # preconditioners for the source prior neural network optimization
    Qs = [[
        torch.eye(W.shape[0], device=device),
        torch.eye(W.shape[1], device=device)
    ] for W in Ws]  # preconditioners for SGD

    # buffers for STFT
    s_bfr = torch.zeros(config.batch_size, config.num_mic,
                        config.iva_fft_size - config.iva_hop_size)
    x_bfr = torch.zeros(config.batch_size, config.num_mic,
                        config.iva_fft_size - config.iva_hop_size)

    # buffer for overlap-add synthesis and reconstruction loss calculation
    ola_bfr = torch.zeros(config.batch_size, config.num_mic,
                          config.iva_fft_size).to(device)
    xtr_bfr = torch.zeros(config.batch_size, config.num_mic,
                          config.iva_fft_size - config.iva_hop_size
                          )  # extra buffer for reconstruction loss calculation

    Loss, lr = [], config.psgd_setting['lr']
    for bi in range(config.psgd_setting['num_iter']):
        srcs, xs = mixergenerator.get_mixtures()

        Ss, s_bfr = stft.stft(srcs[:, :, config.Lh:-config.Lh], win_iva,
                              config.iva_hop_size, s_bfr)
        Xs, x_bfr = stft.stft(xs, win_iva, config.iva_hop_size, x_bfr)
        Ss, Xs = Ss.to(device), Xs.to(device)

        Ys, W_iva, hs = iva(Xs, W_iva.detach(), [h.detach() for h in hs], Ws,
                            config.iva_lr)

        # loss calculation
        coherence = losses.di_pi_coh(Ss, Ys)
        loss = 1.0 - coherence
        if config.use_spectra_dist_loss:
            spectra_dist = losses.di_pi_is_dist(
                Ss[:, :, :, 0] * Ss[:, :, :, 0] +
                Ss[:, :, :, 1] * Ss[:, :, :, 1],
                Ys[:, :, :, 0] * Ys[:, :, :, 0] +
                Ys[:, :, :, 1] * Ys[:, :, :, 1])
            loss = loss + spectra_dist

        if config.reconstruction_loss_fft_sizes:
            srcs = torch.cat([xtr_bfr, srcs], dim=2)
            ys, ola_bfr = stft.istft(Ys, win_iva.to(device),
                                     config.iva_hop_size, ola_bfr.detach())
            for fft_size in config.reconstruction_loss_fft_sizes:
                win = stft.coswin(fft_size)
                Ss, _ = stft.stft(srcs, win, fft_size // 2)
                Ys, _ = stft.stft(ys, win.to(device), fft_size // 2)
                Ss = Ss.to(device)
                coherence = losses.di_pi_coh(Ss, Ys)
                loss = loss + 1.0 - coherence
                if config.use_spectra_dist_loss:
                    spectra_dist = losses.di_pi_is_dist(
                        Ss[:, :, :, 0] * Ss[:, :, :, 0] +
                        Ss[:, :, :, 1] * Ss[:, :, :, 1],
                        Ys[:, :, :, 0] * Ys[:, :, :, 0] +
                        Ys[:, :, :, 1] * Ys[:, :, :, 1])
                    loss = loss + spectra_dist

            xtr_bfr = srcs[:, :, -(config.iva_fft_size - config.iva_hop_size):]

        Loss.append(loss.item())
        if config.use_spectra_dist_loss:
            print('Loss: {}; coherence: {}; spectral_distance: {}'.format(
                loss.item(), coherence.item(), spectra_dist.item()))
        else:
            print('Loss: {}; coherence: {}'.format(loss.item(),
                                                   coherence.item()))

        # Preconditioned SGD optimizer for source prior network optimization
        Q_update_gap = max(math.floor(math.log10(bi + 1)), 1)
        if bi % Q_update_gap == 0:  # update preconditioner less frequently
            grads = grad(loss, Ws, create_graph=True)
            v = [torch.randn(W.shape, device=device) for W in Ws]
            Hv = grad(grads, Ws, v)
            with torch.no_grad():
                Qs = [
                    psgd.update_precond_kron(q[0], q[1], dw, dg)
                    for (q, dw, dg) in zip(Qs, v, Hv)
                ]
        else:
            grads = grad(loss, Ws)

        with torch.no_grad():
            pre_grads = [
                psgd.precond_grad_kron(q[0], q[1], g)
                for (q, g) in zip(Qs, grads)
            ]
            for i in range(len(Ws)):
                Ws[i] -= lr * pre_grads[i]

        if bi == int(0.9 * config.psgd_setting['num_iter']):
            lr *= 0.1
        if (bi + 1) % 1000 == 0 or bi + 1 == config.psgd_setting['num_iter']:
            scipy.io.savemat(
                'src_prior.mat',
                dict([('W' + str(i), W.cpu().detach().numpy())
                      for (i, W) in enumerate(Ws)]))

    plt.plot(Loss)
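# The fragment below is from a TensorFlow 1.x graph-mode example: the assign ops
# for Ws and for the preconditioner factors are built once and then executed
# repeatedly through sess.run.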
    grad_norm = tf.sqrt(
        tf.reduce_sum([tf.reduce_sum(g * g) for g in precond_grads]))
    step_size_adjust = tf.minimum(1.0,
                                  grad_norm_clip_thr / (grad_norm + 1.2e-38))
    new_Ws = [
        W - (step_size_adjust * step_size) * g
        for (W, g) in zip(Ws, precond_grads)
    ]
    update_Ws = [tf.assign(W, new_W) for (W, new_W) in zip(Ws, new_Ws)]

    delta_Ws = [tf.random_normal(W.shape, dtype=dtype) for W in Ws]
    grad_deltaw = tf.reduce_sum(
        [tf.reduce_sum(g * v) for (g, v) in zip(grads, delta_Ws)])
    hess_deltaw = tf.gradients(grad_deltaw, Ws)

    new_Qs = [
        psgd.update_precond_kron(ql, qr, dw, dg)
        for (ql, qr, dw, dg) in zip(Qs_left, Qs_right, delta_Ws, hess_deltaw)
    ]
    update_Qs = [[tf.assign(old_ql, new_q[0]),
                  tf.assign(old_qr, new_q[1])]
                 for (old_ql, old_qr, new_q) in zip(Qs_left, Qs_right, new_Qs)]

    test_loss = test_criterion(Ws)

    sess.run(tf.global_variables_initializer())
    avg_train_loss = 0.0
    TrainLoss = list()
    TestLoss = list()
    Time = list()
    for num_iter in range(20000):
        _train_inputs, _train_outputs = get_batches()
def test_loss():
    num_errs = 0
    with torch.no_grad():
        for data, target in test_loader:
            y = lenet5(data)
            _, pred = torch.max(y, dim=1)
            num_errs += torch.sum(pred != target)
    return num_errs.item() / len(test_loader.dataset)
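The training loop below assumes a lenet5 module and a train_loss helper defined elsewhere; one plausible, purely illustrative definition of the helper is a cross-entropy loss on the network outputs:

import torch.nn.functional as F

def train_loss(data, target):
    # illustrative cross-entropy training loss for the LeNet5 classifier
    return F.cross_entropy(lenet5(data), target)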


Qs = [[torch.eye(W.shape[0]), torch.eye(W.shape[1])] for W in lenet5.parameters()]
lr = 0.1
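# trust-region threshold: cap the preconditioned step norm at roughly
# 0.1 * sqrt(number of trainable parameters)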
grad_norm_clip_thr = 0.1*sum(W.numel() for W in lenet5.parameters())**0.5
TrainLosses, best_test_loss = [], 1.0
for epoch in range(10):
    for _, (data, target) in enumerate(train_loader):
        loss = train_loss(data, target)
        grads = torch.autograd.grad(loss, lenet5.parameters(), create_graph=True)
        vs = [torch.randn_like(W) for W in lenet5.parameters()]
        Hvs = torch.autograd.grad(grads, lenet5.parameters(), vs) 
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv) for (Qlr, v, Hv) in zip(Qs, vs, Hvs)]
            pre_grads = [psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            lr_adjust = min(grad_norm_clip_thr/grad_norm, 1.0)
            [W.subtract_(lr_adjust*lr*g) for (W, g) in zip(lenet5.parameters(), pre_grads)]                
            TrainLosses.append(loss.item())
    best_test_loss = min(best_test_loss, test_loss())
    lr *= (0.01)**(1/9)  # exponentially anneal the learning rate after each epoch
    print('Epoch: {}; best test classification error rate: {}'.format(epoch+1, best_test_loss))
plt.plot(TrainLosses)
        ],
        [0.1 * torch.eye(R), torch.ones(1, J)],
        [
            0.1 * torch.stack([torch.ones(R), torch.zeros(R)], dim=0),
            torch.ones(1, K)
        ],
    ]

    # # example 3
    # Qs = [[0.1*torch.eye(w.shape[0]), torch.eye(w.shape[1])] for w in xyz]

    for _ in range(100):
        loss = f()
        f_values.append(loss.item())
        grads = torch.autograd.grad(loss, xyz, create_graph=True)
        vs = [torch.randn_like(w) for w in xyz]
        Hvs = torch.autograd.grad(grads, xyz, vs)
        with torch.no_grad():
            Qs = [
                psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv, step=0.1)
                for (Qlr, v, Hv) in zip(Qs, vs, Hvs)
            ]
            pre_grads = [
                psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
                for (Qlr, g) in zip(Qs, grads)
            ]
            [w.subtract_(0.1 * g) for (w, g) in zip(xyz, pre_grads)]

plt.semilogy(f_values)
plt.xlabel('Iterations')
plt.ylabel('Decomposition losses')
Example #11
for num_iter in range(10000):
    x, y = get_batches()

    # calculate the loss and gradient
    loss = train_criterion(Ws, x, y)
    grads = grad(loss, Ws, create_graph=True)
    Loss.append(loss.data.numpy()[0])

    # update preconditioners
    Q_update_gap = max(int(np.floor(np.log10(num_iter + 1.0))), 1)
    if num_iter % Q_update_gap == 0:  # let us update Q less frequently
        delta = [Variable(torch.randn(W.size())) for W in Ws]
        grad_delta = sum([torch.sum(g * d) for (g, d) in zip(grads, delta)])
        hess_delta = grad(grad_delta, Ws)
        Qs = [
            psgd.update_precond_kron(q[0], q[1], dw.data.numpy(),
                                     dg.data.numpy())
            for (q, dw, dg) in zip(Qs, delta, hess_delta)
        ]

    # update Ws
    pre_grads = [
        psgd.precond_grad_kron(q[0], q[1], g.data.numpy())
        for (q, g) in zip(Qs, grads)
    ]
    grad_norm = np.sqrt(sum([np.sum(g * g) for g in pre_grads]))
    if grad_norm > grad_norm_clip_thr:
        step_adjust = grad_norm_clip_thr / grad_norm
    else:
        step_adjust = 1.0
    for i in range(len(Ws)):
        Ws[i].data = Ws[i].data - step_adjust * step_size * torch.FloatTensor(
            pre_grads[i])


def train_loss(xy_pair):
    return -torch.mean(
        torch.log(torch.sigmoid(xy_pair[1] * lstm_net(xy_pair[0]))))


Qs = [[torch.eye(W.shape[0]), torch.eye(W.shape[1])]
      for W in lstm_net.parameters()]
lr = 0.02
grad_norm_clip_thr = 1.0
Losses = []
for num_iter in range(100000):
    loss = train_loss(generate_train_data())
    grads = torch.autograd.grad(loss, lstm_net.parameters(), create_graph=True)
    vs = [torch.randn_like(W) for W in lstm_net.parameters()]
    Hvs = torch.autograd.grad(grads, lstm_net.parameters(), vs)
    with torch.no_grad():
        Qs = [
            psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
            for (Qlr, v, Hv) in zip(Qs, vs, Hvs)
        ]
        pre_grads = [
            psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
            for (Qlr, g) in zip(Qs, grads)
        ]
        grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
        lr_adjust = min(grad_norm_clip_thr / grad_norm, 1.0)
        [
            W.subtract_(lr_adjust * lr * g)
            for (W, g) in zip(lstm_net.parameters(), pre_grads)
        ]
    Losses.append(loss.item())
    print('Iteration: {}; loss: {}'.format(num_iter, Losses[-1]))
    if Losses[-1] < 0.1:
        break  # stop once the loss is small enough