Example #1
# Imports needed by this excerpt (the ChainerScheduler import path is
# assumed to be espnet.scheduler.chainer; adjust to your project layout).
import chainer
import six
from chainer import reporter, training
from chainer.dataset import convert
from espnet.scheduler.chainer import ChainerScheduler


class BPTTUpdater(training.updaters.StandardUpdater):
    """An updater for a chainer LM

    :param chainer.dataset.Iterator train_iter: The train iterator
    :param optimizer: The optimizer used to update the LM parameters
    :param schedulers: Learning-rate schedulers wrapped by ChainerScheduler
    :param int device: The device id (a negative value runs on CPU)
    :param int accum_grad: Number of batches over which gradients are
        accumulated before each optimizer update
    """
    def __init__(self, train_iter, optimizer, schedulers, device, accum_grad):
        super(BPTTUpdater, self).__init__(train_iter, optimizer, device=device)
        self.scheduler = ChainerScheduler(schedulers, optimizer)
        self.accum_grad = accum_grad

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        count = 0
        sum_loss = 0
        optimizer.target.cleargrads()  # Clear the parameter gradients
        for _ in range(self.accum_grad):
            # Progress the dataset iterator for sentences at each iteration.
            batch = train_iter.__next__()
            # Concatenate the token IDs into padded matrices and send them to
            # the device (inputs are padded with 0, targets with -1).
            x, t = convert.concat_examples(batch,
                                           device=self.device,
                                           padding=(0, -1))
            xp = chainer.backends.cuda.get_array_module(x)
            loss = 0
            state = None
            batch_size, sequence_length = x.shape
            for i in six.moves.range(sequence_length):
                # Compute the loss at this time step and accumulate it
                state, loss_batch = optimizer.target(state,
                                                     chainer.Variable(x[:, i]),
                                                     chainer.Variable(t[:, i]))
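                # Weight this step's loss by the number of non-padding tokens
                # in x[:, i]; the same count is accumulated for the 'count' report.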
                non_zeros = xp.count_nonzero(x[:, i])
                loss += loss_batch * non_zeros
                count += int(non_zeros)
            # backward
            loss /= batch_size * self.accum_grad  # normalize by batch size and accumulation steps
            sum_loss += float(loss.data)
            loss.backward()  # Backprop
            loss.unchain_backward()  # Truncate the graph

        reporter.report({'loss': sum_loss}, optimizer.target)
        reporter.report({'count': count}, optimizer.target)
        # update
        optimizer.update()  # Update the parameters
        self.scheduler.step(self.iteration)
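
Below is a minimal usage sketch (not part of the original examples) showing how BPTTUpdater could be plugged into a chainer Trainer. The language-model factory build_lm(), the dataset train_dataset, and the scheduler list are illustrative assumptions.

# Hypothetical wiring of BPTTUpdater into a chainer training loop.
import chainer
from chainer import training
from chainer.training import extensions

model = build_lm()          # hypothetical: an LM whose __call__ returns (state, loss)
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

train_iter = chainer.iterators.SerialIterator(train_dataset, 64)  # train_dataset is assumed
schedulers = []             # e.g. [scheduler.NoamScheduler.build("lr", warmup=25000)]

updater = BPTTUpdater(train_iter, optimizer, schedulers, device=-1, accum_grad=2)
trainer = training.Trainer(updater, (10, 'epoch'), out='result')
trainer.extend(extensions.LogReport())
trainer.run()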
Example #2
# Imports needed by this excerpt; the scheduler and ChainerScheduler import
# paths below are assumed (espnet.scheduler) and may need adjusting.
import chainer
import numpy

from espnet.scheduler import scheduler
from espnet.scheduler.chainer import ChainerScheduler


def test_chainer_scheduler():
    warmup = 30000
    s = scheduler.NoamScheduler.build("lr", warmup=warmup)
    net = chainer.links.Linear(2, 1)
    o = chainer.optimizers.SGD(lr=1.0)
    o.setup(net)
    so = ChainerScheduler([s], o)
    so.step(0)
    assert o.lr == s.scale(0)

    so.step(warmup)
    numpy.testing.assert_allclose(o.lr, 1.0, rtol=1e-4)
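
The shape this test relies on matches the standard Noam (Transformer) learning-rate schedule: a linear warm-up followed by an inverse-square-root decay, normalized so that the scale at step == warmup is 1.0. The sketch below illustrates that assumed shape; it is not taken from the library's NoamScheduler implementation.

# Hedged sketch of a Noam-style scale (assumed, not the library's exact code):
# linear warm-up, then step**-0.5 decay, normalized to peak at 1.0 when
# step == warmup, which is what assert_allclose(o.lr, 1.0) checks above.
def noam_scale(step, warmup):
    step = max(step, 1)  # guard against division by zero at step 0
    return warmup ** 0.5 * min(step ** -0.5, step * warmup ** -1.5)

assert abs(noam_scale(30000, 30000) - 1.0) < 1e-6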