Example #1
def train_step(xy_pair, lr):
    with tf.GradientTape() as g2nd:
        with tf.GradientTape() as g1st:
            loss = train_loss(xy_pair)
        grads = g1st.gradient(loss, lenet5_vars)
        vs = [tf.random.normal(W.shape)
              for W in lenet5_vars]  # a random vector
    hess_vs = g2nd.gradient(grads, lenet5_vars, vs)  # Hessian-vector products
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
        for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)
    ]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])]
     for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
        for (Qlr, g) in zip(Qs, grads)
    ]
    grad_norm = tf.sqrt(sum([tf.reduce_sum(g * g) for g in pre_grads]))
    lr_adjust = tf.minimum(grad_norm_clip_thr / grad_norm, 1.0)
    [
        W.assign_sub(lr_adjust * lr * g)
        for (W, g) in zip(lenet5_vars, pre_grads)
    ]
    return loss
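
# A minimal driver sketch, not taken from the original example: it assumes `lenet5_vars`
# (a list of 2-D tf.Variables), a `train_dataset` of (images, labels) batches, and the
# `psgd` module are already defined. `Qs` keeps one (left, right) Kronecker factor pair
# per weight matrix as non-trainable tf.Variables so that train_step can refresh them in
# place; the clipping threshold follows the 0.1*sqrt(num_params) choice used in the
# PyTorch LeNet5 example further below.
Qs = [[tf.Variable(tf.eye(W.shape[0]), trainable=False),
       tf.Variable(tf.eye(W.shape[1]), trainable=False)] for W in lenet5_vars]
grad_norm_clip_thr = 0.1 * sum(int(tf.size(W)) for W in lenet5_vars) ** 0.5

for epoch in range(10):
    for xy_pair in train_dataset:
        loss = train_step(xy_pair, lr=tf.constant(0.1))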
def opt_step():
    with tf.GradientTape() as g2nd:  # second order derivative
        with tf.GradientTape() as g1st:  # first order derivative
            cost = f()
        grads = g1st.gradient(cost, xyz)  # gradient
        vs = [tf.random.normal(w.shape) for w in xyz]  # a random vector
    hess_vs = g2nd.gradient(grads, xyz, vs)  # Hessian-vector products
    new_Qs = [psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv, step=0.1) for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])] for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    pre_grads = [psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)]
    [w.assign_sub(0.1 * g) for (w, g) in zip(xyz, pre_grads)]
    return cost
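
# A toy setup, assumed rather than taken from the original example, that makes opt_step
# runnable: fit a random 20x10 matrix A with a rank-5 factorization xyz[0] @ xyz[1].
# `psgd` is the same module used in the snippets above.
import tensorflow as tf

A = tf.random.normal([20, 10])
xyz = [tf.Variable(0.2 * tf.random.normal([20, 5])),
       tf.Variable(0.2 * tf.random.normal([5, 10]))]

def f():  # decomposition loss
    return tf.reduce_sum(tf.square(A - xyz[0] @ xyz[1]))

Qs = [[tf.Variable(0.1 * tf.eye(w.shape[0]), trainable=False),
       tf.Variable(tf.eye(w.shape[1]), trainable=False)] for w in xyz]

costs = [opt_step() for _ in range(100)]  # the cost should trend downward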
def train_step_apprx_Hvp(inp, targ, lr):
    """
    Demonstrate the usage of the approximated Hessian-vector product (Hvp).
    Typically a smaller graph and faster execution, but the accuracy of the Hvp may be questionable.
    """
    delta = tf.constant(
        2**(-23 / 2)
    )  # sqrt(float32.eps), the scale of perturbation for Hessian-vector product approximation

    def eval_loss_grads():
        with tf.GradientTape() as gt:
            enc_output = encoder(inp)
            dec_input = targ[:, 0]
            dec_h = enc_output[:, -1, :]  # the last hidden state of the encoder as the initial one for the decoder
            loss = 0.0
            for t in range(1, targ.shape[1]):
                pred, dec_h = decoder(dec_input, dec_h, enc_output)
                loss += loss_function(targ[:, t], pred)
                dec_input = targ[:, t]
            loss /= targ.shape[1]
        return loss, gt.gradient(loss, trainable_vars)

    # loss and gradients
    loss, grads = eval_loss_grads()

    # calculate the perturbed gradients
    vs = [delta * tf.random.normal(W.shape) for W in trainable_vars]
    [W.assign_add(v) for (W, v) in zip(trainable_vars, vs)]
    _, perturbed_grads = eval_loss_grads()
    # update the preconditioners
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v,
                                 tf.subtract(perturbed_g, g))
        for (Qlr, v, perturbed_g, g) in zip(Qs, vs, perturbed_grads, grads)
    ]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])]
     for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    # calculate the preconditioned gradients
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
        for (Qlr, g) in zip(Qs, grads)
    ]
    # update variables; do not forget to remove the perturbations on variables
    [
        W.assign_sub(lr * g + v)
        for (W, g, v) in zip(trainable_vars, pre_grads, vs)
    ]
    return loss
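
# A standalone sanity check, not from the original code, for the finite-difference
# Hessian-vector product used above: for the quadratic loss 0.5 * w^T A w the exact
# product is A @ v, and (grad(w + delta*v) - grad(w)) / delta should match it up to
# float32 rounding, which is why delta = sqrt(float32 eps) is a sensible perturbation scale.
import tensorflow as tf

A = tf.constant([[2.0, 0.3], [0.3, 1.0]])  # a small symmetric Hessian
delta = tf.constant(2 ** (-23 / 2))

def grad_at(w):
    with tf.GradientTape() as gt:
        gt.watch(w)
        loss = 0.5 * tf.reduce_sum(w * tf.linalg.matvec(A, w))
    return gt.gradient(loss, w)

w = tf.constant([1.0, -1.0])
v = tf.random.normal(w.shape)
approx_Hv = (grad_at(w + delta * v) - grad_at(w)) / delta
exact_Hv = tf.linalg.matvec(A, v)
print(approx_Hv.numpy(), exact_Hv.numpy())  # the two should agree to a few significant digits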
def train_step_exact_Hvp(inp, targ, lr):
    """
    Demonstrate the usage of the exact Hessian-vector product (Hvp).
    Typically a larger graph and slower execution, but the Hvp is exact and the code is cleaner.
    """
    with tf.GradientTape() as g2nd:
        with tf.GradientTape() as g1st:
            enc_output = encoder(inp)
            dec_input = targ[:, 0]
            dec_h = enc_output[:, -1, :]  # the last hidden state of the encoder as the initial one for the decoder
            loss = 0.0
            for t in range(1, targ.shape[1]):
                pred, dec_h = decoder(dec_input, dec_h, enc_output)
                loss += loss_function(targ[:, t], pred)
                dec_input = targ[:, t]
            loss /= targ.shape[1]
        grads = g1st.gradient(loss, trainable_vars)
        grads = [
            tf.convert_to_tensor(g) if isinstance(g, tf.IndexedSlices) else g
            for g in grads
        ]
        vs = [tf.random.normal(W.shape)
              for W in trainable_vars]  # a random vector
    hess_vs = g2nd.gradient(grads, trainable_vars,
                            vs)  # Hessian-vector products
    new_Qs = [
        psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv)
        for (Qlr, v, Hv) in zip(Qs, vs, hess_vs)
    ]
    [[Qlr[0].assign(new_Qlr[0]), Qlr[1].assign(new_Qlr[1])]
     for (Qlr, new_Qlr) in zip(Qs, new_Qs)]
    pre_grads = [
        psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
        for (Qlr, g) in zip(Qs, grads)
    ]
    [W.assign_sub(lr * g) for (W, g) in zip(trainable_vars, pre_grads)]
    return loss
Example #5
        for label in labels:
            new_labels.append(U.remove_repetitive_labels(label))

        loss = train_loss(new_images, new_labels)
        grads = grad(loss, Ws, create_graph=True)
        TrainLoss.append(loss.item())

        v = [torch.randn(W.shape, device=device) for W in Ws]
        Hv = grad(grads, Ws, v)
        with torch.no_grad():
            Qs = [
                psgd.update_precond_kron(q[0], q[1], dw, dg)
                for (q, dw, dg) in zip(Qs, v, Hv)
            ]
            pre_grads = [
                psgd.precond_grad_kron(q[0], q[1], g)
                for (q, g) in zip(Qs, grads)
            ]
            grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
            lr_adjust = min(grad_norm_clip_thr / (grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= lr_adjust * lr * pre_grads[i]

        if num_iter % 100 == 0:
            print(
                'epoch: {}; iter: {}; train loss: {}; elapsed time: {}'.format(
                    epoch, num_iter, TrainLoss[-1],
                    time.time() - t0))

    TestLoss.append(test_loss())
    if TestLoss[-1] < BestTestLoss:
Example #6
Qs = [[0.1*torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.1  # normalized step size
grad_norm_clip_thr = 100.0  # the size of the trust region
# begin iteration here
TrainLoss = []
TestLoss = []
echo_every = 100  # iterations
for num_iter in range(5000):
    x, y = get_batches()
    # calculate the loss and gradient
    loss = train_criterion(Ws, x, y)
    grads = grad(loss, Ws, create_graph=True)
    TrainLoss.append(loss.item())
    
    v = [torch.randn(W.shape).to(device) for W in Ws]
    #grad_v = sum([torch.sum(g*d) for (g, d) in zip(grads, v)])
    #hess_v = grad(grad_v, Ws)
    Hv = grad(grads, Ws, v)  # replaces the above two lines, due to Bulatov
    # Hv = grads  # just let Hv=grads if you use a Fisher type preconditioner
    with torch.no_grad():
        Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
        pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
        grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
        step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
        for i in range(len(Ws)):
            Ws[i] -= step_adjust*step_size*pre_grads[i]
            
        if (num_iter+1) % echo_every == 0:
            loss = test_criterion(Ws)
            TestLoss.append(loss.item())
            print('train loss: {}; test loss: {}'.format(TrainLoss[-1], TestLoss[-1]))
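
# A sketch of the Fisher-type variant mentioned in the commented line above ("just let
# Hv=grads"): the preconditioner is then fitted to the pair (v, grads) instead of
# (v, Hessian*v), so no second-order graph (create_graph=True) is needed. This illustrates
# that comment and is not code from the original example.
for num_iter in range(5000):
    x, y = get_batches()
    loss = train_criterion(Ws, x, y)
    grads = grad(loss, Ws)  # first-order gradients only
    v = [torch.randn(W.shape).to(device) for W in Ws]
    with torch.no_grad():
        Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, grads)]
        pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
        grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
        step_adjust = min(grad_norm_clip_thr / (grad_norm + 1.2e-38), 1.0)
        for i in range(len(Ws)):
            Ws[i] -= step_adjust * step_size * pre_grads[i]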
Example #7
def main():
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")
    print("using device ", device)
    torch.manual_seed(args.seed)

    u.set_runs_directory('runs3')
    logger = u.TensorboardLogger(args.run)
    batch_size = 64
    shuffle = True
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '/tmp/data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()])),
                                               batch_size=batch_size,
                                               shuffle=shuffle,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '/tmp/data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor()])),
                                              batch_size=1000,
                                              shuffle=shuffle,
                                              **kwargs)
    """input image size for the original LeNet5 is 32x32, here is 28x28"""

    #  W1 = 0.1 * torch.randn(1 * 5 * 5 + 1, 6)

    net = LeNet5().to(device)

    def train_loss(data, target):
        y = net(data)
        y = F.log_softmax(y, dim=1)
        loss = F.nll_loss(y, target)
        for w in net.W:
            loss += 0.0002 * torch.sum(w * w)

        return loss

    def test_loss():
        num_errs = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)

                y = net(data)
                _, pred = torch.max(y, dim=1)
                num_errs += torch.sum(pred != target)

        return num_errs.item() / len(test_loader.dataset)

    Qs = [[torch.eye(w.shape[0]), torch.eye(w.shape[1])] for w in net.W]
    for i in range(len(Qs)):
        for j in range(len(Qs[i])):
            Qs[i][j] = Qs[i][j].to(device)

    step_size = 0.1  # tried 0.15, diverges
    grad_norm_clip_thr = 1e10
    TrainLoss, TestLoss = [], []
    example_count = 0
    step_time_ms = 0

    for epoch in range(10):
        for batch_idx, (data, target) in enumerate(train_loader):
            step_start = time.perf_counter()
            data, target = data.to(device), target.to(device)

            loss = train_loss(data, target)

            with u.timeit('grad'):
                grads = autograd.grad(loss, net.W, create_graph=True)
            TrainLoss.append(loss.item())
            logger.set_step(example_count)
            logger('loss/train', TrainLoss[-1])
            if batch_idx % 10 == 0:
                print(
                    f'Epoch: {epoch}; batch: {batch_idx}; train loss: {TrainLoss[-1]:.2f}, step time: {step_time_ms:.0f}'
                )

            with u.timeit('Hv'):
                #        noise.normal_()
                # torch.manual_seed(args.seed)
                v = [torch.randn(w.shape).to(device) for w in net.W]
                # v = grads
                Hv = autograd.grad(grads, net.W, v)

            if args.verbose:
                print("v", v[0].mean())
                print("data", data.mean())
                print("Hv", Hv[0].mean())

            n = len(net.W)
            with torch.no_grad():
                with u.timeit('P_update'):
                    for i in range(num_updates):
                        psteps = []
                        for j in range(n):
                            q = Qs[j]
                            dw = v[j]
                            dg = Hv[j]
                            Qs[j][0], Qs[j][1], pstep = psgd.update_precond_kron_with_step(
                                q[0], q[1], dw, dg)
                            psteps.append(pstep)

                            #          print(np.array(psteps).mean())
                    logger('p_residual', np.array(psteps).mean())

                with u.timeit('g_update'):
                    pre_grads = [
                        psgd.precond_grad_kron(q[0], q[1], g)
                        for (q, g) in zip(Qs, grads)
                    ]
                    grad_norm = torch.sqrt(
                        sum([torch.sum(g * g) for g in pre_grads]))

                with u.timeit('gradstep'):
                    step_adjust = min(
                        grad_norm_clip_thr / (grad_norm + 1.2e-38), 1.0)
                    for i in range(len(net.W)):
                        net.W[i] -= step_adjust * step_size * pre_grads[i]

                total_step = step_adjust * step_size
                logger('step/adjust', step_adjust)
                logger('step/size', step_size)
                logger('step/total', total_step)
                logger('grad_norm', grad_norm)

                if args.verbose:
                    print(data.mean())
                    import pdb
                    pdb.set_trace()
                if args.early_stop:
                    sys.exit()

            example_count += batch_size
            step_time_ms = 1000 * (time.perf_counter() - step_start)
            logger('time/step', step_time_ms)

            if args.test and batch_idx >= 100:
                break
        if args.test and batch_idx >= 100:
            break

        test_loss0 = test_loss()
        TestLoss.append(test_loss0)
        logger('loss/test', test_loss0)
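        # anneal: multiplying by 0.1**0.1 each epoch compounds to a 10x step-size reduction over the 10 epochs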
        step_size = (0.1**0.1) * step_size
        print('Epoch: {}; best test loss: {}'.format(epoch, min(TestLoss)))

    if args.test:
        step_times = logger.d['time/step']
        assert step_times[-1] < 30, step_times  # should be around 20ms
        losses = logger.d['loss/train']
        assert losses[0] > 2  # around 2.3887393474578857
        assert losses[-1] < 0.5, losses
        print("Test passed")
Example #8
def main():
    """DNN speech prior training"""
    device = config.device
    wavloader = artificial_mixture_generator.WavLoader(config.wav_dir)
    mixergenerator = artificial_mixture_generator.MixerGenerator(
        wavloader, config.batch_size, config.num_mic, config.Lh,
        config.iva_hop_size * config.num_frame, config.prb_mix_change)

    Ws = [
        W.to(device)
        for W in initW(config.iva_fft_size // 2 - 1,
                       config.src_prior['num_layer'],
                       config.src_prior['num_state'],
                       config.src_prior['dim_h'])
    ]
    hs = [
        torch.zeros(config.batch_size * config.num_mic,
                    config.src_prior['num_state']).to(device)
        for _ in range(config.src_prior['num_layer'] - 1)
    ]
    for W in Ws:
        W.requires_grad = True

    # W_iva initialization
    W_iva = (100.0 *
             torch.randn(config.batch_size, config.iva_fft_size // 2 - 1,
                         config.num_mic, config.num_mic, 2)).to(device)

    # STFT window for IVA
    win_iva = stft.pre_def_win(config.iva_fft_size, config.iva_hop_size)

    # preconditioners for the source prior neural network optimization
    Qs = [[
        torch.eye(W.shape[0], device=device),
        torch.eye(W.shape[1], device=device)
    ] for W in Ws]  # preconditioners for SGD

    # buffers for STFT
    s_bfr = torch.zeros(config.batch_size, config.num_mic,
                        config.iva_fft_size - config.iva_hop_size)
    x_bfr = torch.zeros(config.batch_size, config.num_mic,
                        config.iva_fft_size - config.iva_hop_size)

    # buffer for overlap-add synthesis and reconstruction loss calculation
    ola_bfr = torch.zeros(config.batch_size, config.num_mic,
                          config.iva_fft_size).to(device)
    xtr_bfr = torch.zeros(config.batch_size, config.num_mic,
                          config.iva_fft_size - config.iva_hop_size
                          )  # extra buffer for reconstruction loss calculation

    Loss, lr = [], config.psgd_setting['lr']
    for bi in range(config.psgd_setting['num_iter']):
        srcs, xs = mixergenerator.get_mixtures()

        Ss, s_bfr = stft.stft(srcs[:, :, config.Lh:-config.Lh], win_iva,
                              config.iva_hop_size, s_bfr)
        Xs, x_bfr = stft.stft(xs, win_iva, config.iva_hop_size, x_bfr)
        Ss, Xs = Ss.to(device), Xs.to(device)

        Ys, W_iva, hs = iva(Xs, W_iva.detach(), [h.detach() for h in hs], Ws,
                            config.iva_lr)

        # loss calculation
        coherence = losses.di_pi_coh(Ss, Ys)
        loss = 1.0 - coherence
        if config.use_spectra_dist_loss:
            spectra_dist = losses.di_pi_is_dist(
                Ss[:, :, :, 0] * Ss[:, :, :, 0] +
                Ss[:, :, :, 1] * Ss[:, :, :, 1],
                Ys[:, :, :, 0] * Ys[:, :, :, 0] +
                Ys[:, :, :, 1] * Ys[:, :, :, 1])
            loss = loss + spectra_dist

        if config.reconstruction_loss_fft_sizes:
            srcs = torch.cat([xtr_bfr, srcs], dim=2)
            ys, ola_bfr = stft.istft(Ys, win_iva.to(device),
                                     config.iva_hop_size, ola_bfr.detach())
            for fft_size in config.reconstruction_loss_fft_sizes:
                win = stft.coswin(fft_size)
                Ss, _ = stft.stft(srcs, win, fft_size // 2)
                Ys, _ = stft.stft(ys, win.to(device), fft_size // 2)
                Ss = Ss.to(device)
                coherence = losses.di_pi_coh(Ss, Ys)
                loss = loss + 1.0 - coherence
                if config.use_spectra_dist_loss:
                    spectra_dist = losses.di_pi_is_dist(
                        Ss[:, :, :, 0] * Ss[:, :, :, 0] +
                        Ss[:, :, :, 1] * Ss[:, :, :, 1],
                        Ys[:, :, :, 0] * Ys[:, :, :, 0] +
                        Ys[:, :, :, 1] * Ys[:, :, :, 1])
                    loss = loss + spectra_dist

            xtr_bfr = srcs[:, :, -(config.iva_fft_size - config.iva_hop_size):]

        Loss.append(loss.item())
        if config.use_spectra_dist_loss:
            print('Loss: {}; coherence: {}; spectral_distance: {}'.format(
                loss.item(), coherence.item(), spectra_dist.item()))
        else:
            print('Loss: {}; coherence: {}'.format(loss.item(),
                                                   coherence.item()))

        # Preconditioned SGD optimizer for source prior network optimization
        Q_update_gap = max(math.floor(math.log10(bi + 1)), 1)
        if bi % Q_update_gap == 0:  # update preconditioner less frequently
            grads = grad(loss, Ws, create_graph=True)
            v = [torch.randn(W.shape, device=device) for W in Ws]
            Hv = grad(grads, Ws, v)
            with torch.no_grad():
                Qs = [
                    psgd.update_precond_kron(q[0], q[1], dw, dg)
                    for (q, dw, dg) in zip(Qs, v, Hv)
                ]
        else:
            grads = grad(loss, Ws)

        with torch.no_grad():
            pre_grads = [
                psgd.precond_grad_kron(q[0], q[1], g)
                for (q, g) in zip(Qs, grads)
            ]
            for i in range(len(Ws)):
                Ws[i] -= lr * pre_grads[i]

        if bi == int(0.9 * config.psgd_setting['num_iter']):
            lr *= 0.1
        if (bi + 1) % 1000 == 0 or bi + 1 == config.psgd_setting['num_iter']:
            scipy.io.savemat(
                'src_prior.mat',
                dict([('W' + str(i), W.cpu().detach().numpy())
                      for (i, W) in enumerate(Ws)]))

    plt.plot(Loss)
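
# A small illustration, not part of the original script, of the Q_update_gap schedule used
# in main(): with gap = max(floor(log10(bi + 1)), 1), the preconditioner is refreshed every
# iteration for roughly the first 100 iterations, every 2nd iteration up to ~1000, every
# 3rd up to ~10000, and so on.
import math

def q_updated(bi):
    return bi % max(math.floor(math.log10(bi + 1)), 1) == 0

print([bi for bi in range(8, 14) if q_updated(bi)])        # still every iteration
print([bi for bi in range(100, 110) if q_updated(bi)])     # every other iteration
print([bi for bi in range(1000, 1010) if q_updated(bi)])   # every third iteration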
Example #9
    def adjust_grads(self, grads):
        assert len(grads) == len(self.net.W)
        return [
            psgd.precond_grad_kron(q[0], q[1], g)
            for (q, g) in zip(self.Qs, grads)
        ]
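
    # A companion sketch, assumed rather than taken from the original class: a method that
    # refreshes self.Qs from a (random vector, Hessian-vector product) pair, using the same
    # psgd API as the other examples on this page.
    def update_preconditioners(self, vs, Hvs):
        assert len(vs) == len(Hvs) == len(self.Qs)
        self.Qs = [
            psgd.update_precond_kron(q[0], q[1], v, Hv)
            for (q, v, Hv) in zip(self.Qs, vs, Hvs)
        ]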
Example #10
with tf.Session() as sess:
    Qs_left = [
        tf.Variable(tf.eye(W.shape.as_list()[0], dtype=dtype), trainable=False)
        for W in Ws
    ]
    Qs_right = [
        tf.Variable(tf.eye(W.shape.as_list()[1], dtype=dtype), trainable=False)
        for W in Ws
    ]

    train_loss = train_criterion(Ws)
    grads = tf.gradients(train_loss, Ws)

    precond_grads = [
        psgd.precond_grad_kron(ql, qr, g)
        for (ql, qr, g) in zip(Qs_left, Qs_right, grads)
    ]
    grad_norm = tf.sqrt(
        tf.reduce_sum([tf.reduce_sum(g * g) for g in precond_grads]))
    step_size_adjust = tf.minimum(1.0,
                                  grad_norm_clip_thr / (grad_norm + 1.2e-38))
    new_Ws = [
        W - (step_size_adjust * step_size) * g
        for (W, g) in zip(Ws, precond_grads)
    ]
    update_Ws = [tf.assign(W, new_W) for (W, new_W) in zip(Ws, new_Ws)]

    delta_Ws = [tf.random_normal(W.shape, dtype=dtype) for W in Ws]
    grad_deltaw = tf.reduce_sum(
        [tf.reduce_sum(g * v) for (g, v) in zip(grads, delta_Ws)])
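
    # A plausible continuation, assumed rather than shown in the snippet: differentiate
    # grad_deltaw once more to obtain the Hessian-vector products (the same double-backward
    # trick as in the PyTorch examples on this page), then build the ops that refresh the
    # Kronecker preconditioner factors.
    hess_deltaw = tf.gradients(grad_deltaw, Ws)
    new_Qs = [
        psgd.update_precond_kron(ql, qr, v, h)
        for (ql, qr, v, h) in zip(Qs_left, Qs_right, delta_Ws, hess_deltaw)
    ]
    update_Qs = [[tf.assign(ql, new_q[0]), tf.assign(qr, new_q[1])]
                 for (ql, qr, new_q) in zip(Qs_left, Qs_right, new_Qs)]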
Example #11
def test_loss():
    num_errs = 0
    with torch.no_grad():
        for data, target in test_loader:
            y = lenet5(data)
            _, pred = torch.max(y, dim=1)
            num_errs += torch.sum(pred!=target)            
    return num_errs.item()/len(test_loader.dataset)


Qs = [[torch.eye(W.shape[0]), torch.eye(W.shape[1])] for W in lenet5.parameters()]
lr = 0.1
grad_norm_clip_thr = 0.1*sum(W.numel() for W in lenet5.parameters())**0.5
TrainLosses, best_test_loss = [], 1.0
for epoch in range(10):
    for _, (data, target) in enumerate(train_loader):
        loss = train_loss(data, target)
        grads = torch.autograd.grad(loss, lenet5.parameters(), create_graph=True)
        vs = [torch.randn_like(W) for W in lenet5.parameters()]
        Hvs = torch.autograd.grad(grads, lenet5.parameters(), vs) 
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv) for (Qlr, v, Hv) in zip(Qs, vs, Hvs)]
            pre_grads = [psgd.precond_grad_kron(Qlr[0], Qlr[1], g) for (Qlr, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            lr_adjust = min(grad_norm_clip_thr/grad_norm, 1.0)
            [W.subtract_(lr_adjust*lr*g) for (W, g) in zip(lenet5.parameters(), pre_grads)]                
            TrainLosses.append(loss.item())
    best_test_loss = min(best_test_loss, test_loss())
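    # with the 1/9 exponent, lr shrinks by 100x across the run, so the final epoch trains with lr = 0.001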
    lr *= (0.01)**(1/9)
    print('Epoch: {}; best test classification error rate: {}'.format(epoch+1, best_test_loss))
plt.plot(TrainLosses)
Example #12
        ],
        [0.1 * torch.eye(R), torch.ones(1, J)],
        [
            0.1 * torch.stack([torch.ones(R), torch.zeros(R)], dim=0),
            torch.ones(1, K)
        ],
    ]

    # # example 3
    # Qs = [[0.1*torch.eye(w.shape[0]), torch.eye(w.shape[1])] for w in xyz]

    for _ in range(100):
        loss = f()
        f_values.append(loss.item())
        grads = torch.autograd.grad(loss, xyz, create_graph=True)
        vs = [torch.randn_like(w) for w in xyz]
        Hvs = torch.autograd.grad(grads, xyz, vs)
        with torch.no_grad():
            Qs = [
                psgd.update_precond_kron(Qlr[0], Qlr[1], v, Hv, step=0.1)
                for (Qlr, v, Hv) in zip(Qs, vs, Hvs)
            ]
            pre_grads = [
                psgd.precond_grad_kron(Qlr[0], Qlr[1], g)
                for (Qlr, g) in zip(Qs, grads)
            ]
            [w.subtract_(0.1 * g) for (w, g) in zip(xyz, pre_grads)]

plt.semilogy(f_values)
plt.xlabel('Iterations')
plt.ylabel('Decomposition losses')
Example #13
    # update preconditioners
    Q_update_gap = max(int(np.floor(np.log10(num_iter + 1.0))), 1)
    if num_iter % Q_update_gap == 0:  # let us update Q less frequently
        delta = [Variable(torch.randn(W.size())) for W in Ws]
        grad_delta = sum([torch.sum(g * d) for (g, d) in zip(grads, delta)])
        hess_delta = grad(grad_delta, Ws)
        Qs = [
            psgd.update_precond_kron(q[0], q[1], dw.data.numpy(),
                                     dg.data.numpy())
            for (q, dw, dg) in zip(Qs, delta, hess_delta)
        ]

    # update Ws
    pre_grads = [
        psgd.precond_grad_kron(q[0], q[1], g.data.numpy())
        for (q, g) in zip(Qs, grads)
    ]
    grad_norm = np.sqrt(sum([np.sum(g * g) for g in pre_grads]))
    if grad_norm > grad_norm_clip_thr:
        step_adjust = grad_norm_clip_thr / grad_norm
    else:
        step_adjust = 1.0
    for i in range(len(Ws)):
        Ws[i].data = Ws[i].data - step_adjust * step_size * torch.FloatTensor(
            pre_grads[i])

    if num_iter % 100 == 0:
        print('training loss: {}'.format(Loss[-1]))

plt.semilogy(Loss)