# Imports assume ESPnet's package layout (espnet.scheduler.scheduler and
# espnet.scheduler.pytorch).
import numpy
import torch

from espnet.scheduler import scheduler
from espnet.scheduler.pytorch import PyTorchScheduler


def test_pytorch_scheduler():
    warmup = 30000
    s = scheduler.NoamScheduler.build("lr", warmup=warmup)
    net = torch.nn.Linear(2, 1)
    o = torch.optim.SGD(net.parameters(), lr=1.0)
    so = PyTorchScheduler([s], o)
    so.step(0)
    for g in o.param_groups:
        assert g["lr"] == s.scale(0)
    so.step(warmup)
    for g in o.param_groups:
        numpy.testing.assert_allclose(g["lr"], 1.0, rtol=1e-4)
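# The test above checks that PyTorchScheduler rescales each param group's lr by
# the scheduler's scale factor for a given iteration. As a rough sketch of that
# mechanism (not ESPnet's actual implementation; SimpleLRScaler and noam_scale
# below are hypothetical names), the following multiplies a stored base learning
# rate by a normalized Noam-style factor on every step call.
import torch


def noam_scale(n_iter, warmup=30000):
    # Normalized Noam schedule: ramps up to ~1.0 at `warmup`, then decays as 1/sqrt(step).
    step = n_iter + 1
    return warmup ** 0.5 * min(step ** -0.5, step * warmup ** -1.5)


class SimpleLRScaler:
    """Scale each param group's lr by a schedule, as PyTorchScheduler does for its schedulers."""

    def __init__(self, optimizer, scale_fn):
        self.optimizer = optimizer
        self.scale_fn = scale_fn
        # Remember the base learning rates set by the user.
        self.base_lrs = [g["lr"] for g in optimizer.param_groups]

    def step(self, n_iter):
        s = self.scale_fn(n_iter)
        for g, base in zip(self.optimizer.param_groups, self.base_lrs):
            g["lr"] = base * s


# Usage mirroring the test: lr is tiny at iteration 0 and back to ~1.0 at warmup.
_net = torch.nn.Linear(2, 1)
_opt = torch.optim.SGD(_net.parameters(), lr=1.0)
_sched = SimpleLRScaler(_opt, noam_scale)
_sched.step(0)
_sched.step(30000)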
class BPTTUpdater(training.StandardUpdater):
    """An updater for a pytorch LM."""

    def __init__(self, train_iter, model, optimizer, schedulers, device_id,
                 gradclip=None, use_apex=False, accum_grad=1):
        """Initialize class.

        Args:
            train_iter (chainer.dataset.Iterator): The train iterator
            model (LMInterface): The model to update
            optimizer (torch.optim.Optimizer): The optimizer for training
            schedulers (espnet.scheduler.scheduler.SchedulerInterface):
                The schedulers of `optimizer`
            device_id (list[int]): The GPU device ids (``[-1]`` for CPU)
            gradclip (float): The gradient clipping value to use
            use_apex (bool): The flag to use Apex in backprop
            accum_grad (int): The number of gradient accumulation steps

        """
        super(BPTTUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.device_id = device_id
        self.gradclip = gradclip
        self.use_apex = use_apex
        self.scheduler = PyTorchScheduler(schedulers, optimizer)
        self.accum_grad = accum_grad

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Update the model."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        # Progress the dataset iterator for sentences at each iteration.
        self.model.zero_grad()  # Clear the parameter gradients
        accum = {"loss": 0.0, "nll": 0.0, "count": 0}
        for _ in range(self.accum_grad):
            batch = train_iter.__next__()
            # Concatenate the token IDs to matrices and send them to the device
            # self.converter does this job
            # (it is chainer.dataset.concat_examples by default)
            x, t = concat_examples(batch, device=self.device_id[0], padding=(0, -100))
            if self.device_id[0] == -1:
                loss, nll, count = self.model(x, t)
            else:
                # apex does not support torch.nn.DataParallel
                loss, nll, count = data_parallel(self.model, (x, t), self.device_id)

            # backward
            loss = loss.mean() / self.accum_grad
            if self.use_apex:
                from apex import amp

                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()  # Backprop
            # accumulate stats
            accum["loss"] += float(loss)
            accum["nll"] += float(nll.sum())
            accum["count"] += int(count.sum())

        for k, v in accum.items():
            reporter.report({k: v}, optimizer.target)
        if self.gradclip is not None:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
        optimizer.step()  # Update the parameters
        self.scheduler.step(n_iter=self.iteration)
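# update_core above assumes the model's forward returns a (loss, nll, count)
# triple, and that concat_examples pads inputs with 0 and targets with -100
# (the default ignore_index of PyTorch's cross entropy). ToyLM below is a
# hypothetical sketch of that contract for illustration only; it is not
# ESPnet's LMInterface implementation.
import torch
import torch.nn.functional as F


class ToyLM(torch.nn.Module):
    """Minimal LM-like module matching the (loss, nll, count) return contract."""

    def __init__(self, n_vocab=100, n_units=32):
        super().__init__()
        self.embed = torch.nn.Embedding(n_vocab, n_units)
        self.out = torch.nn.Linear(n_units, n_vocab)

    def forward(self, x, t):
        # x, t: (batch, seq_len) token-id matrices; padded targets are -100.
        logits = self.out(self.embed(x))
        nll = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            t.view(-1),
            ignore_index=-100,
            reduction="none",
        )
        mask = (t.view(-1) != -100).float()
        count = mask.sum()
        loss = (nll * mask).sum() / torch.clamp(count, min=1.0)
        return loss, nll * mask, count


# A padded toy batch: the second sequence is two tokens shorter.
_x = torch.randint(0, 100, (2, 5))
_t = torch.randint(0, 100, (2, 5))
_x[1, 3:] = 0
_t[1, 3:] = -100
_loss, _nll, _count = ToyLM()(_x, _t)  # count == 8 real target tokens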