Example #1
    def __init__(self,
                 train_iter,
                 model,
                 optimizer,
                 schedulers,
                 device_id,
                 gradclip=None,
                 use_apex=False,
                 accum_grad=1):
        """Initialize class.

        Args:
            train_iter (chainer.dataset.Iterator): The train iterator
            model (LMInterface): The model to update
            optimizer (torch.optim.Optimizer): The optimizer for training
            schedulers (list[espnet.scheduler.scheduler.SchedulerInterface]): The schedulers of `optimizer`
            device_id (list[int]): The list of GPU device ids (use [-1] for CPU)
            gradclip (float): The gradient clipping value to use
            use_apex (bool): The flag to use Apex AMP in backprop
            accum_grad (int): The number of gradient accumulation steps

        """
        super(BPTTUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.device_id = device_id
        self.gradclip = gradclip
        self.use_apex = use_apex
        self.scheduler = PyTorchScheduler(schedulers, optimizer)
        self.accum_grad = accum_grad
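
For orientation, here is a minimal construction sketch, not taken from the source: the toy dataset/iterator and the linear stand-in model are placeholders (a real setup passes an LMInterface implementation), and the espnet import paths are assumptions based on the usual package layout.

import chainer
import numpy as np
import torch

from espnet.lm.pytorch_backend.lm import BPTTUpdater  # assumed module path
from espnet.scheduler.scheduler import NoamScheduler  # assumed module path

# Placeholder data and model, only to exercise the constructor.
dataset = [np.array([0, 1, 2, 3], dtype=np.int32) for _ in range(8)]
train_iter = chainer.iterators.SerialIterator(dataset, batch_size=2)
model = torch.nn.Linear(4, 4)  # stand-in for an LMInterface model
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
schedulers = [NoamScheduler.build("lr", warmup=25000)]

updater = BPTTUpdater(
    train_iter, model, optimizer, schedulers,
    device_id=[-1],   # [-1] keeps everything on CPU; a list of GPU ids enables data_parallel
    gradclip=5.0,
    accum_grad=2,
)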
Example #2
# Import paths below assume the standard espnet package layout.
import numpy
import torch

from espnet.scheduler import scheduler
from espnet.scheduler.pytorch import PyTorchScheduler


def test_pytorch_scheduler():
    warmup = 30000
    s = scheduler.NoamScheduler.build("lr", warmup=warmup)
    net = torch.nn.Linear(2, 1)
    o = torch.optim.SGD(net.parameters(), lr=1.0)
    so = PyTorchScheduler([s], o)
    so.step(0)
    for g in o.param_groups:
        assert g["lr"] == s.scale(0)

    so.step(warmup)
    for g in o.param_groups:
        numpy.testing.assert_allclose(g["lr"], 1.0, rtol=1e-4)
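
Why the two assertions hold: the Noam-style rule scales the base learning rate by a factor that warms up roughly linearly and peaks near 1.0 once n_iter reaches warmup, then decays with the inverse square root of the step. Below is a standalone sketch of that rule for illustration; the exact espnet formula is assumed and may differ in small details such as the +1 offset.

def noam_scale(n_iter, warmup):
    # lr multiplier: linear warmup to ~1.0 at n_iter == warmup, inverse-sqrt decay afterwards
    step = n_iter + 1
    return warmup ** 0.5 * min(step ** -0.5, step * warmup ** -1.5)

print(noam_scale(0, 30000))      # ~3.3e-05: the tiny scale checked right after step(0)
print(noam_scale(30000, 30000))  # ~0.99998: within the rtol=1e-4 tolerance of 1.0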
Example #3
# Imports needed by this excerpt (module paths assume the standard espnet layout);
# `concat_examples` is assumed to be the torch-aware padding helper defined in the same module.
from chainer import reporter
from chainer import training
import torch.nn as nn
from torch.nn.parallel import data_parallel

from espnet.scheduler.pytorch import PyTorchScheduler


class BPTTUpdater(training.StandardUpdater):
    """An updater for a PyTorch LM."""
    def __init__(self,
                 train_iter,
                 model,
                 optimizer,
                 schedulers,
                 device_id,
                 gradclip=None,
                 use_apex=False,
                 accum_grad=1):
        """Initialize class.

        Args:
            train_iter (chainer.dataset.Iterator): The train iterator
            model (LMInterface): The model to update
            optimizer (torch.optim.Optimizer): The optimizer for training
            schedulers (list[espnet.scheduler.scheduler.SchedulerInterface]): The schedulers of `optimizer`
            device_id (list[int]): The list of GPU device ids (use [-1] for CPU)
            gradclip (float): The gradient clipping value to use
            use_apex (bool): The flag to use Apex AMP in backprop
            accum_grad (int): The number of gradient accumulation steps

        """
        super(BPTTUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.device_id = device_id
        self.gradclip = gradclip
        self.use_apex = use_apex
        self.scheduler = PyTorchScheduler(schedulers, optimizer)
        self.accum_grad = accum_grad

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Update the model."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')
        # Progress the dataset iterator for sentences at each iteration.
        self.model.zero_grad()  # Clear the parameter gradients
        accum = {"loss": 0.0, "nll": 0.0, "count": 0}
        for _ in range(self.accum_grad):
            batch = train_iter.__next__()
            # Concatenate the token IDs to matrices and send them to the device
            # self.converter does this job
            # (it is chainer.dataset.concat_examples by default)
            x, t = concat_examples(batch,
                                   device=self.device_id[0],
                                   padding=(0, -100))
            if self.device_id[0] == -1:
                loss, nll, count = self.model(x, t)
            else:
                # apex does not support torch.nn.DataParallel
                loss, nll, count = data_parallel(self.model, (x, t),
                                                 self.device_id)

            # backward
            loss = loss.mean() / self.accum_grad
            if self.use_apex:
                from apex import amp
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()  # Backprop
            # accumulate stats
            accum["loss"] += float(loss)
            accum["nll"] += float(nll.sum())
            accum["count"] += int(count.sum())

        for k, v in accum.items():
            reporter.report({k: v}, optimizer.target)
        if self.gradclip is not None:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
        optimizer.step()  # Update the parameters
        self.scheduler.step(n_iter=self.iteration)
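
Stripped of the chainer/espnet plumbing, update_core follows the standard gradient-accumulation pattern: divide each mini-batch loss by accum_grad, let backward() sum the gradients across the loop, then clip once and take a single optimizer step. Here is a dependency-free sketch of that pattern; the toy model and the make_batch helper are hypothetical stand-ins, not part of the source.

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accum_grad, gradclip = 2, 5.0

def make_batch():
    # hypothetical stand-in for the iterator + concat_examples step
    x = torch.randn(8, 4)
    t = torch.randn(8, 1)
    return x, t

model.zero_grad()                       # clear gradients once per update
for _ in range(accum_grad):
    x, t = make_batch()
    loss = nn.functional.mse_loss(model(x), t) / accum_grad  # average over accumulation steps
    loss.backward()                     # gradients accumulate across iterations
nn.utils.clip_grad_norm_(model.parameters(), gradclip)       # same clipping as in BPTTUpdater
optimizer.step()                        # single parameter update per accumulated batch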