def _train_epoch(
    self,
    model: L2XModel,
    train_dataloader: torch.utils.data.DataLoader,
    criterion: Union[torch.nn.modules.loss._Loss, Callable],
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    gamma: float,
) -> float:
    """
    Train the model for a single epoch.

    Args:
        model: L2X model to train.
        train_dataloader: Dataloader used for training.
        criterion: Torch loss object or callable. Unnormalized (biased) negative log-likelihood.
        optimizer: Torch optimizer used for the update step.
        device: Device the computation runs on.
        gamma: Weight of the neighbour-token correlation term.

    Returns:
        Mean negative log-likelihood over the epoch.
    """
    model.train()
    if self.verbose:
        train_dataloader = tqdm(train_dataloader, desc="train", disable=False)
    accum_loss = 0.0
    iters = 0
    for data in train_dataloader:
        x = data["text"].to(device)
        y = data["target"].to(device)
        optimizer.zero_grad()
        pred, corr_pred = model(x)
        # Negative log-likelihood up to a constant
        nll_loss = criterion(pred, y)
        # Encourage selection of neighbouring tokens
        corr_loss = torch.mean(
            ((corr_pred[:, 1:]) ** 2 * (corr_pred[:, :-1]) ** 2).sum(-1))
        # The optima of these two losses may not coincide, but this
        # combination gives the best validation score found so far.
        loss = nll_loss - gamma * corr_loss
        loss.backward()
        optimizer.step()
        accum_loss += nll_loss.detach().cpu().item()
        iters += 1
        if self.verbose:
            train_dataloader.set_description(
                "train nll (loss={:.4f})".format(accum_loss / iters))
    return accum_loss / iters

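# A minimal, self-contained sketch (the shapes below are assumptions, not from
# the original code) of the neighbour-token correlation term used in
# `_train_epoch`: squared selection scores of adjacent positions are multiplied,
# so the term grows when neighbouring tokens are selected together, and
# `loss = nll - gamma * corr` rewards exactly that.
import torch

corr_pred = torch.rand(8, 50)  # assumed shape: (batch_size, seq_len) selection scores
corr_term = torch.mean(((corr_pred[:, 1:]) ** 2 * (corr_pred[:, :-1]) ** 2).sum(-1))
print(f"neighbour correlation term: {corr_term.item():.4f}")
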
def replace_loader_dataset(dataloader: torch.utils.data.DataLoader,
                           dataset: torch.utils.data.Dataset,
                           sampler=None):
    # Swap the dataset behind an existing DataLoader while keeping its
    # batch_size and drop_last settings.
    dataloader.dataset = dataset
    if sampler is None:
        print(
            f"* Warning - sampler {dataloader.sampler.__class__.__name__} is being replaced by RandomSampler *"
        )
        sampler = RandomSampler(dataset)
    batch_sampler = BatchSampler(sampler, dataloader.batch_size, dataloader.drop_last)
    dataloader.batch_sampler = batch_sampler

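# Hypothetical usage sketch for `replace_loader_dataset` (the dataset and loader
# names below are assumptions, not from the original code). Note that recent
# PyTorch versions raise a ValueError when `dataset` or `batch_sampler` are
# reassigned after construction, so this pattern assumes an older PyTorch.
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

set_a = TensorDataset(torch.randn(100, 3), torch.zeros(100, dtype=torch.long))
set_b = TensorDataset(torch.randn(200, 3), torch.ones(200, dtype=torch.long))

loader = DataLoader(set_a, batch_size=16, drop_last=False)
# Keeps the loader's batch_size/drop_last, swaps the data, warns about the sampler.
replace_loader_dataset(loader, set_b)
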
def train(epochs: int, model, n_nabels, loader: torch.utils.data.DataLoader,
          optimizer, print_delay=5000, printer=True):
    loader.pin_memory = True
    for epoch in range(epochs):
        running_loss = 0.0
        running_batches = 0
        print_tot = 1
        model.train()
        for i, data in enumerate(loader):
            inputs, target = data[0].to(device), data[1].to(device)
            if mode == Mode.task_il:
                # Mask out logits of classes that do not belong to the current task.
                mask = get_mask(inputs, target, device, n_nabels)
            optimizer.zero_grad()
            outputs = model(inputs)
            if mode == Mode.task_il:
                outputs = outputs + mask
            outputs = F.log_softmax(outputs, dim=1)
            loss = F.nll_loss(outputs, target) + \
                lambda_reg * model.get_regularizer()
            loss.backward()
            optimizer.step()

            # print statistics roughly every `print_delay` samples
            if printer:
                running_loss += loss.item()
                running_batches += 1
                if print_delay is not None and i * loader.batch_size >= print_delay * print_tot:
                    print_tot += 1
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i * loader.batch_size, running_loss / running_batches))
                    running_loss = 0.0
                    running_batches = 0

def test(model, n_nabels, loader: torch.utils.data.DataLoader, printer=True):
    model.eval()
    test_loss = 0
    correct = 0
    loader.pin_memory = True
    with torch.no_grad():
        for d, t in loader:
            data = d.to(device)
            target = t.to(device)
            if mode == Mode.task_il:
                mask = get_mask(data, target, device, n_nabels)
            output = model(data)
            if mode == Mode.task_il:
                output = output + mask
            output = F.log_softmax(output, dim=1)
            # sum up the batch loss; averaged over the whole dataset below
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(loader.dataset)
    accuracy = 100. * correct / len(loader.dataset)
    if printer:
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(loader.dataset), accuracy))
    return test_loss, accuracy

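# `get_mask` is not shown here; below is a minimal sketch of what a
# task-incremental (task_il) logit mask might look like, assuming classes are
# split into contiguous blocks of `labels_per_task` and each sample's task is
# inferred from its own target. The name, arguments, and constant below are
# assumptions, not the original implementation.
def get_mask_sketch(targets: torch.Tensor, n_classes: int, labels_per_task: int,
                    device: torch.device) -> torch.Tensor:
    NEG = -1e9  # large negative value: suppresses logits outside the current task
    task_id = torch.div(targets, labels_per_task, rounding_mode='floor')
    mask = torch.full((targets.size(0), n_classes), NEG, device=device)
    for row, start in enumerate((task_id * labels_per_task).tolist()):
        mask[row, start:start + labels_per_task] = 0.0  # keep this task's logits
    return mask
# Adding such a mask to the raw outputs before log_softmax (as in train/test
# above) effectively zeroes the probability assigned to other tasks' classes.
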
def train(epochs: int, model, loader: torch.utils.data.DataLoader, optimizer,
          device=torch.device("cpu"), print_delay=64, loss_fn=F.nll_loss,
          task=None, regularizer_fn=None, lambda_reg=0):
    loader.pin_memory = True
    for epoch in range(epochs):
        running_loss = 0.0
        model.train()
        for i, data in enumerate(loader):
            inputs, target = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = loss_fn(outputs, target)
            if regularizer_fn is not None:
                loss += lambda_reg * regularizer_fn()
            loss.backward()
            optimizer.step()

            # print statistics every `print_delay` mini-batches
            running_loss += loss.item()
            if print_delay is not None and i % print_delay == print_delay - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / print_delay))
                running_loss = 0.0

def test(model, loader: torch.utils.data.DataLoader, device=torch.device("cpu"),
         loss_fn=F.nll_loss, task=None):
    model.eval()
    test_loss = 0
    correct = 0
    loader.pin_memory = True
    with torch.no_grad():
        for d, t in loader:
            data = d.to(device)
            target = t.to(device)
            output = model(data)
            # sum up the batch loss; averaged over the whole dataset below
            test_loss += loss_fn(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(loader.dataset)
    accuracy = 100. * correct / len(loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(loader.dataset), accuracy))
    return test_loss, accuracy

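# A minimal, hypothetical driver for the generic `train`/`test` pair above.
# The toy data, model, and hyperparameters are assumptions, not from the
# original code.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy data standing in for a real dataset
X = torch.randn(512, 784)
y = torch.randint(0, 10, (512,))
loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10),
                      nn.LogSoftmax(dim=1)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

train(3, model, loader, optimizer, device=device, loss_fn=F.nll_loss)
test(model, loader, device=device, loss_fn=F.nll_loss)
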
def train(epochs: int, model, n_nabels, loader: torch.utils.data.DataLoader,
          optimizer, print_delay=5000, printer=True):
    loader.pin_memory = True
    for epoch in range(epochs):
        running_loss = 0.0
        running_batches = 0
        print_tot = 1
        model.train()
        for i, data in enumerate(loader):
            inputs, targets = data[0].to(device), data[1].to(device)
            # Replay samples drawn from the model's rehearsal buffer
            buf_inputs, buf_targets = model.sample_buffer(n_buf_samples)
            # extended_inputs = torch.cat((inputs, buf_inputs))
            # extended_targets = torch.cat((targets, buf_targets))
            if mode == Mode.task_il:
                mask = get_mask(inputs, targets, device, n_nabels)
                mask_buf = get_mask(buf_inputs, buf_targets, device, n_nabels) if buf_inputs is not None else None
            optimizer.zero_grad()
            outputs = model(inputs)
            outputs_buff = model(buf_inputs) if buf_inputs is not None else None
            if mode == Mode.task_il:
                outputs = outputs + mask
                outputs_buff = outputs_buff + mask_buf if buf_inputs is not None else None
            outputs = F.log_softmax(outputs, dim=1)
            outputs_buff = F.log_softmax(outputs_buff, dim=1) if buf_inputs is not None else None
            # Loss on the current batch plus a weighted replay loss on buffered samples
            loss = F.nll_loss(outputs, targets)
            loss_buf = F.nll_loss(outputs_buff, buf_targets) if buf_inputs is not None else 0
            loss = loss + lambda_reg * loss_buf
            loss.backward()
            optimizer.step()
            model.add_batch_buffer(inputs, targets)

            # print statistics roughly every `print_delay` samples
            if printer:
                running_loss += loss.item()
                running_batches += 1
                if print_delay is not None and i * loader.batch_size >= print_delay * print_tot:
                    print_tot += 1
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i * loader.batch_size, running_loss / running_batches))
                    running_loss = 0.0
                    running_batches = 0

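# `sample_buffer` / `add_batch_buffer` are not shown above; below is a minimal
# sketch of a reservoir-sampling rehearsal buffer they might correspond to.
# The class name, its fields, and the reservoir strategy are assumptions, not
# the original implementation.
import random
import torch

class ReservoirBuffer:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.examples, self.labels = [], []
        self.n_seen = 0

    def add_batch_buffer(self, inputs: torch.Tensor, targets: torch.Tensor):
        for x, y in zip(inputs.detach().cpu(), targets.detach().cpu()):
            if len(self.examples) < self.capacity:
                self.examples.append(x)
                self.labels.append(y)
            else:
                # Reservoir sampling keeps each seen example with equal probability
                j = random.randint(0, self.n_seen)
                if j < self.capacity:
                    self.examples[j], self.labels[j] = x, y
            self.n_seen += 1

    def sample_buffer(self, n: int):
        if not self.examples:
            return None, None
        idx = random.sample(range(len(self.examples)), min(n, len(self.examples)))
        return (torch.stack([self.examples[i] for i in idx]),
                torch.stack([self.labels[i] for i in idx]))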