Example no. 1
    def _train_epoch(
        self,
        model: L2XModel,
        train_dataloader: torch.utils.data.DataLoader,
        criterion: Union[torch.nn.modules.loss._Loss, Callable],
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        gamma: float,
    ) -> float:
        """
        Train only one epoch.

        Args:
            model: Trained L2X model.
            train_dataloader: Dataloader, used for training.
            loss: Torch loss object. Unnormalized (biased) negative log-likelihood.
            optimizer: Optimizer that should be one or nothing.

        """
        model.train()
        if self.verbose:
            train_dataloader = tqdm(train_dataloader,
                                    desc="train",
                                    disable=False)

        accum_loss = 0.0
        iters = 0
        for data in train_dataloader:
            x = data["text"]
            x = x.to(device)
            y = data["target"]
            y = y.to(device)
            optimizer.zero_grad()
            pred, corr_pred = model(x)
            # Negative log-likelihood up to a constant
            nll_loss = criterion(pred, y)
            # Encourage neighbouring tokens to be selected together
            corr_loss = torch.mean(
                ((corr_pred[:, 1:])**2 * (corr_pred[:, :-1])**2).sum(-1))
            # The optima of the two terms may not coincide, but no better
            # way to improve the validation score was found.
            loss = nll_loss - gamma * corr_loss
            loss.backward()
            optimizer.step()
            accum_loss += nll_loss.item()
            iters += 1

            if self.verbose:
                train_dataloader.set_description(
                    "train nll (loss={:.4f})".format(accum_loss / iters))

        return accum_loss / iters
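A minimal sketch of how _train_epoch might be driven from an outer fit loop on the same trainer object; the method name fit, the Adam optimizer, the NLLLoss criterion, and the default hyperparameters below are illustrative assumptions, not part of the original class.

    def fit(self, model, train_dataloader, n_epochs=10, gamma=0.1):
        # Hypothetical driver; optimizer, criterion and device choices are
        # assumptions made for illustration only.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        criterion = torch.nn.NLLLoss()
        for epoch in range(n_epochs):
            epoch_nll = self._train_epoch(model, train_dataloader, criterion,
                                          optimizer, device, gamma)
            if self.verbose:
                print(f"epoch {epoch}: train nll {epoch_nll:.4f}")
        return model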
Example no. 2
from torch.utils.data import BatchSampler, RandomSampler


def replace_loader_dataset(dataloader: torch.utils.data.DataLoader,
                           dataset: torch.utils.data.Dataset,
                           sampler=None):
    """Swap the dataset of an existing DataLoader in place and rebuild its batch sampler."""
    dataloader.dataset = dataset
    if sampler is None:
        print(
            f"* Warning - sampler {dataloader.sampler.__class__.__name__} is being replaced by RandomSampler *"
        )
        sampler = RandomSampler(dataset)
    batch_sampler = BatchSampler(sampler, dataloader.batch_size,
                                 dataloader.drop_last)
    dataloader.batch_sampler = batch_sampler
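A usage sketch for replace_loader_dataset with toy TensorDatasets (the data here is illustrative). Note that recent PyTorch versions forbid reassigning DataLoader attributes such as dataset and batch_sampler after construction, so on such versions the helper raises and a fresh DataLoader has to be built instead.

import torch
from torch.utils.data import DataLoader, TensorDataset

old_dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
new_dataset = TensorDataset(torch.randn(50, 8), torch.randint(0, 2, (50,)))
loader = DataLoader(old_dataset, batch_size=16, shuffle=True)

try:
    # Reuse the existing loader configuration with a different dataset.
    replace_loader_dataset(loader, new_dataset)
except ValueError:
    # Newer PyTorch versions guard these attributes after construction;
    # fall back to constructing a new DataLoader.
    loader = DataLoader(new_dataset, batch_size=loader.batch_size,
                        shuffle=True, drop_last=loader.drop_last)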
Example no. 3
def train(epochs: int,
          model,
          n_nabels,
          loader: torch.utils.data.DataLoader,
          optimizer,
          print_delay=5000,
          printer=True):
    loader.pin_memory = True
    for epoch in range(epochs):
        running_loss = 0.0
        print_tot = 1
        batches_since_print = 0
        model.train()
        for i, data in enumerate(loader):
            inputs, target = data[0].to(device), data[1].to(device)

            if mode == Mode.task_il:
                mask = get_mask(inputs, target, device, n_nabels)

            optimizer.zero_grad()

            outputs = model(inputs)

            if mode == Mode.task_il:
                outputs = outputs+mask

            outputs = F.log_softmax(outputs, dim=1)

            loss = F.nll_loss(outputs, target) + \
                lambda_reg * model.get_regularizer()

            loss.backward()
            optimizer.step()

            # print running statistics roughly every print_delay samples
            if printer:
                running_loss += loss.item()
                batches_since_print += 1
                if print_delay is not None and i * loader.batch_size >= print_delay * print_tot:
                    print_tot += 1
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i * loader.batch_size,
                           running_loss / batches_since_print))
                    running_loss = 0.0
                    batches_since_print = 0
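get_mask is not defined in these snippets. Because the mask is added to the logits before log_softmax, a plausible reading is that it blocks classes outside the current task (task-incremental learning). The sketch below is an assumption about that behaviour, not the original code; classes_per_task and the way the task is inferred from the label are hypothetical.

import torch

def get_mask(inputs, target, device, n_labels, classes_per_task=2):
    # Hypothetical task-IL mask: infer each sample's task from its label and
    # add a large negative value to logits of classes from other tasks.
    mask = torch.full((target.size(0), n_labels), -1e9, device=device)
    task_id = torch.div(target, classes_per_task, rounding_mode='floor')
    start = task_id * classes_per_task
    rows = torch.arange(target.size(0), device=device)
    for offset in range(classes_per_task):
        mask[rows, start + offset] = 0.0
    return mask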
Example no. 4
def test(model, n_nabels, loader: torch.utils.data.DataLoader, printer=True):
    model.eval()
    test_loss = 0
    correct = 0
    loader.pin_memory = True
    with torch.no_grad():
        for d, t in loader:
            data = d.to(device)
            target = t.to(device)

            if mode == Mode.task_il:
                mask = get_mask(data, target, device, n_nabels)

            output = model(data)

            if mode == Mode.task_il:
                output = output+mask

            output = F.log_softmax(output, dim=1)

            # Sum per-sample losses; they are averaged over the dataset below
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(loader.dataset)
    accuracy = 100. * correct / len(loader.dataset)
    if printer:
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(loader.dataset),
            100. * correct / len(loader.dataset)))
    return test_loss, accuracy
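An illustrative driver for the train/test pair from Examples no. 3 and no. 4; model, the dataloaders and the module-level globals they rely on (device, mode, Mode, get_mask, lambda_reg) are assumed to be defined elsewhere in the original project.

# Hypothetical setup; only the call pattern is the point here.
n_labels = 10
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

train(epochs=5, model=model, n_nabels=n_labels, loader=train_loader,
      optimizer=optimizer, print_delay=5000, printer=True)
test_loss, accuracy = test(model, n_labels, test_loader, printer=True)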
Example no. 5
def train(epochs: int,
          model,
          loader: torch.utils.data.DataLoader,
          optimizer,
          device=torch.device("cpu"),
          print_delay=64,
          loss_fn=F.nll_loss,
          task=None,
          regularizer_fn=None,
          lambda_reg=0):
    loader.pin_memory = True
    for epoch in range(epochs):
        running_loss = 0.0
        model.train()
        for i, data in enumerate(loader):
            inputs, target = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = loss_fn(outputs, target)
            if regularizer_fn is not None:
                loss += lambda_reg * regularizer_fn()
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if print_delay is not None and i % print_delay == print_delay - 1:  # print every print_delay mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / print_delay))
                running_loss = 0.0
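The regularizer_fn hook in this variant is called with no arguments, so one way to use it is a closure over the model's parameters. The L2 penalty below is a sketch under that assumption; model, train_loader and optimizer are placeholders.

import torch
import torch.nn.functional as F

def make_l2_regularizer(model):
    # Hypothetical regularizer: L2 penalty over all trainable parameters,
    # captured in a closure so it matches the zero-argument hook.
    def l2_reg():
        return sum(p.pow(2).sum() for p in model.parameters() if p.requires_grad)
    return l2_reg

train(epochs=3, model=model, loader=train_loader, optimizer=optimizer,
      device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
      loss_fn=F.cross_entropy, regularizer_fn=make_l2_regularizer(model),
      lambda_reg=1e-4)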
Example no. 6
def test(model,
         loader: torch.utils.data.DataLoader,
         device=torch.device("cpu"),
         loss_fn=F.nll_loss,
         task=None):
    model.eval()
    test_loss = 0
    correct = 0
    loader.pin_memory = True
    with torch.no_grad():
        for d, t in loader:
            data = d.to(device)
            target = t.to(device)

            output = model(data)
            # Sum per-sample losses; they are averaged over the dataset below
            test_loss += loss_fn(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(loader.dataset)
    accuracy = 100. * correct / len(loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(loader.dataset),
        100. * correct / len(loader.dataset)))
    return test_loss, accuracy
Example no. 7
def train(epochs: int,
          model,
          n_nabels,
          loader: torch.utils.data.DataLoader,
          optimizer,
          print_delay=5000,
          printer=True):
    loader.pin_memory = True
    for epoch in range(epochs):
        running_loss = 0.0
        print_tot = 1
        batches_since_print = 0
        model.train()
        for i, data in enumerate(loader):
            inputs, targets = data[0].to(device), data[1].to(device)

            buf_inputs, buf_targets = model.sample_buffer(n_buf_samples)

            # extended_inputs = torch.cat((inputs, buf_inputs))
            # extended_targets = torch.cat((targets, buf_targets))

            if mode == Mode.task_il:
                mask = get_mask(inputs, targets, device, n_nabels)
                mask_buf = get_mask(buf_inputs, buf_targets, device,
                                    n_nabels) if buf_inputs is not None else None

            optimizer.zero_grad()

            outputs = model(inputs)
            outputs_buff = model(buf_inputs) if buf_inputs is not None else None

            if mode == Mode.task_il:
                outputs = outputs + mask
                outputs_buff = outputs_buff + mask_buf if buf_inputs is not None else None

            outputs = F.log_softmax(outputs, dim=1)
            outputs_buff = F.log_softmax(outputs_buff,
                                         dim=1) if buf_inputs is not None else None

            loss = F.nll_loss(outputs, targets)
            loss_buf = F.nll_loss(outputs_buff,
                                  buf_targets) if buf_inputs is not None else 0

            loss = loss + lambda_reg * loss_buf

            loss.backward()
            optimizer.step()

            model.add_batch_buffer(inputs, targets)

            # print running statistics roughly every print_delay samples
            if printer:
                running_loss += loss.item()
                batches_since_print += 1
                if print_delay is not None and i * loader.batch_size >= print_delay * print_tot:
                    print_tot += 1
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i * loader.batch_size,
                           running_loss / batches_since_print))
                    running_loss = 0.0
                    batches_since_print = 0
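Example no. 7 relies on model.sample_buffer and model.add_batch_buffer, which are not shown. Below is a minimal reservoir-sampling buffer sketched as a standalone class with the same two method names; everything else (capacity handling, CPU storage) is an assumption rather than the original implementation.

import random
import torch

class ReplayBuffer:
    # Hypothetical rehearsal buffer with reservoir sampling; not the
    # original sample_buffer / add_batch_buffer implementation.
    def __init__(self, capacity):
        self.capacity = capacity
        self.inputs, self.targets = [], []
        self.n_seen = 0

    def add_batch_buffer(self, inputs, targets):
        for x, y in zip(inputs, targets):
            if len(self.inputs) < self.capacity:
                self.inputs.append(x.detach().cpu())
                self.targets.append(y.detach().cpu())
            else:
                j = random.randint(0, self.n_seen)
                if j < self.capacity:
                    self.inputs[j] = x.detach().cpu()
                    self.targets[j] = y.detach().cpu()
            self.n_seen += 1

    def sample_buffer(self, n_samples):
        if not self.inputs:
            return None, None
        idx = random.sample(range(len(self.inputs)), min(n_samples, len(self.inputs)))
        return (torch.stack([self.inputs[i] for i in idx]),
                torch.stack([self.targets[i] for i in idx]))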