class BPTTUpdater(training.updaters.StandardUpdater):
    """An updater for a chainer LM.

    :param chainer.dataset.Iterator train_iter: The train iterator
    :param optimizer: The optimizer used to update the model parameters
    :param schedulers: The schedulers applied to the optimizer hyperparameters
    :param int device: The device id
    :param int accum_grad: The number of gradient accumulation steps
    """

    def __init__(self, train_iter, optimizer, schedulers, device, accum_grad):
        super(BPTTUpdater, self).__init__(train_iter, optimizer, device=device)
        self.scheduler = ChainerScheduler(schedulers, optimizer)
        self.accum_grad = accum_grad

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        count = 0
        sum_loss = 0
        optimizer.target.cleargrads()  # Clear the parameter gradients
        for _ in range(self.accum_grad):
            # Progress the dataset iterator for sentences at each iteration.
            batch = train_iter.__next__()
            x, t = convert.concat_examples(batch, device=self.device, padding=(0, -1))
            # Concatenate the token IDs to matrices and send them to the device
            # self.converter does this job
            # (it is chainer.dataset.concat_examples by default)
            xp = chainer.backends.cuda.get_array_module(x)
            loss = 0
            state = None
            batch_size, sequence_length = x.shape
            for i in six.moves.range(sequence_length):
                # Compute the loss at this time step and accumulate it
                state, loss_batch = optimizer.target(
                    state, chainer.Variable(x[:, i]), chainer.Variable(t[:, i]))
                non_zeros = xp.count_nonzero(x[:, i])
                loss += loss_batch * non_zeros
                count += int(non_zeros)
            # backward
            loss /= batch_size * self.accum_grad  # normalized by batch size
            sum_loss += float(loss.data)
            loss.backward()  # Backprop
            loss.unchain_backward()  # Truncate the graph

        reporter.report({'loss': sum_loss}, optimizer.target)
        reporter.report({'count': count}, optimizer.target)
        # update
        optimizer.update()  # Update the parameters
        self.scheduler.step(self.iteration)
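
# Hedged usage sketch (not part of the original file): how BPTTUpdater is
# typically wired into a chainer Trainer.  `RNNLM`, `train_dataset`, the empty
# scheduler list and the hyperparameter values below are placeholders, not
# objects defined in this module.
#
#     model = RNNLM(n_vocab=10000, n_layers=2, n_units=650)
#     optimizer = chainer.optimizers.SGD(lr=1.0)
#     optimizer.setup(model)
#     train_iter = chainer.iterators.SerialIterator(train_dataset, batch_size=64)
#     updater = BPTTUpdater(train_iter, optimizer, schedulers=[],
#                           device=-1, accum_grad=1)
#     trainer = training.Trainer(updater, (20, 'epoch'), out='exp/train_rnnlm')
#     trainer.run()
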
def test_chainer_scheduler():
    warmup = 30000
    s = scheduler.NoamScheduler.build("lr", warmup=warmup)
    net = chainer.links.Linear(2, 1)
    o = chainer.optimizers.SGD(lr=1.0)
    o.setup(net)
    so = ChainerScheduler([s], o)
    so.step(0)
    assert o.lr == s.scale(0)
    so.step(warmup)
    numpy.testing.assert_allclose(o.lr, 1.0, rtol=1e-4)
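
# Hedged sketch (an assumption, not necessarily espnet's exact implementation):
# the Noam schedule checked above follows the formula from "Attention Is All You
# Need", normalized so that the peak scale reached at step == warmup is ~1.0,
# which is why the test compares o.lr against 1.0 with rtol=1e-4 after stepping
# to `warmup`.
def _noam_scale_sketch(step, warmup):
    """Return the learning-rate scale at `step` under the assumed Noam schedule."""
    step = step + 1  # the scheduler is assumed to count steps from 0
    normalize = warmup ** 0.5  # chosen so the scale at step == warmup is ~1.0
    return normalize * min(step ** -0.5, step * warmup ** -1.5)
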