def refine(dataloader: DataLoader, model: RNN, optimizer: torch.optim.Optimizer,
           loss_function: Union[SplitCrossEntropyLoss, CrossEntropyLoss], prior: Prior, bptt: int,
           use_apex: bool = False, amp=None, alpha: float = 0, beta: float = 0,
           importance: Union[int, float] = 100000, device: Union[torch.device, str] = 'cpu', **kwargs):
    model.train()
    batch = 0

    with tqdm(dataloader, total=len(dataloader)) as pbar:
        for data, targets, seq_len, lang in pbar:
            data = data.squeeze(0).to(device)
            targets = targets.squeeze(0).to(device)
            lang = lang.to(device)

            # Temporarily rescale the learning rate by the actual sequence length,
            # so shorter (e.g. final) batches do not get disproportionate updates.
            lr2 = optimizer.param_groups[0]['lr']
            optimizer.param_groups[0]['lr'] = lr2 * seq_len.item() / bptt

            hidden = model.init_hidden(batchsize=data.size(-1))
            optimizer.zero_grad()

            output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, lang, return_h=True)

            if isinstance(loss_function, SplitCrossEntropyLoss):
                loss = loss_function(model.decoder.weight, model.decoder.bias, output, targets)
            else:
                loss = loss_function(output, targets)

            # Prior penalty term, scaled by `importance`.
            penalty = importance * prior.penalty(model)
            loss += penalty

            # Activation Regularization
            if alpha:
                loss = loss + sum(alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])

            # Temporal Activation Regularization (slowness)
            if beta:
                loss = loss + sum(beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])

            if use_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            # Restore the base learning rate for the next batch.
            optimizer.param_groups[0]['lr'] = lr2

            batch += 1
            pbar.set_description('Loss {:5.2f} | bpc {:9.3f} | penalty {} |'.format(
                loss.item(), loss.item() / math.log(2), penalty.item()))

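
# Illustrative sketch (not part of the original training code): the AR/TAR terms used in
# `refine` and `train`, isolated on dummy tensors. The tensor shapes and the alpha/beta
# values are assumptions chosen only to make the example runnable on its own.
def _ar_tar_example(alpha: float = 2.0, beta: float = 1.0):
    """Return the activation-regularization (AR) and temporal-activation-regularization
    (TAR) penalties for a random (seq_len, batch, hidden) activation tensor."""
    import torch

    rnn_h = torch.randn(35, 20, 400)          # raw activations of the last RNN layer
    dropped_rnn_h = torch.randn(35, 20, 400)  # the same activations after dropout

    ar = alpha * dropped_rnn_h.pow(2).mean()             # penalises large activations
    tar = beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean()  # penalises fast changes over time
    return ar, tar
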
def evaluate(dataloader: DataLoader, model: RNN, loss_function: Union[SplitCrossEntropyLoss, CrossEntropyLoss],
             only_l: Union[torch.Tensor, int] = None, device: Union[torch.device, str] = 'cpu', **kwargs):
    model.eval()

    languages = dataloader.dataset.data.keys()
    if only_l:
        if only_l not in languages:
            raise ValueError(f'Language {only_l} does not exist in the dataset')
        local_losses = {only_l: 0}
    else:
        local_losses = {lang: 0 for lang in languages}

    batch = 0
    prev_lang = ""

    with tqdm(dataloader, total=len(dataloader)) as pbar:
        for data, targets, seq_len, lang in pbar:
            data = data.squeeze(0).to(device)
            targets = targets.squeeze(0).to(device)
            lang = lang.to(device)

            if only_l and only_l != lang:
                continue

            # Re-initialise the hidden state whenever the language changes,
            # otherwise carry it over (detached) from the previous batch.
            if prev_lang != lang:
                prev_lang = lang
                hidden = model.init_hidden(batchsize=data.size(-1))
            else:
                hidden = detach(hidden)

            with torch.no_grad():
                output, hidden = model(data, hidden, lang)

                if isinstance(loss_function, SplitCrossEntropyLoss):
                    loss = loss_function(model.decoder.weight, model.decoder.bias, output, targets)
                else:
                    loss = loss_function(output, targets)

                local_losses[lang.item()] += len(data) * loss.data

            batch += 1
            pbar.set_description('Evaluation, finished batch {} | loss {}'.format(batch, loss.data))

    # Average the accumulated loss per language by the size of that language's data.
    if only_l is None:
        avg_loss = {lang: local_losses[lang].item() / len(dataloader.dataset.data[lang]) for lang in languages}
    else:
        avg_loss = {only_l: local_losses[only_l].item() / len(dataloader.dataset.data[only_l])}
    total_loss = sum(avg_loss.values())

    # Macro-average over the languages that were actually evaluated.
    return total_loss / len(avg_loss), avg_loss

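
# Illustrative sketch (not part of the original evaluation code): converting the average
# loss in nats returned by `evaluate` into perplexity and bits-per-character, the same
# quantities logged inside `train`. The helper name is an assumption.
def _nats_to_ppl_bpc(avg_nats: float):
    """Return (perplexity, bits-per-character) for a mean cross-entropy given in nats."""
    import math

    return math.exp(avg_nats), avg_nats / math.log(2)
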
def train(dataloader: DataLoader, model: RNN, optimizer: torch.optim.Optimizer,
          loss_function: Union[SplitCrossEntropyLoss, CrossEntropyLoss], use_apex=False, amp=None,
          lr_weights: dict = None, prior: str = 'ninf', scaling: str = None, total_steps: int = 0,
          steps: int = 0, bptt: int = 125, alpha: float = 0., beta: float = 0., log_interval: int = 200,
          n_samples: int = 4, device: Union[torch.device, str] = 'cpu', tb_writer=None, **kwargs):
    total_loss = 0
    batch = 0

    tr_kl = 0.
    logging_kl = 0.
    tr_loss = 0.
    logging_loss = 0.

    model.train()

    log.info('Starting training loop')
    start_time = time.time()

    with tqdm(dataloader, total=len(dataloader)) as pbar:
        for data, targets, seq_len, lang in pbar:
            data = data.squeeze(0).to(device)
            targets = targets.squeeze(0).to(device)
            lang = lang.to(device)

            hidden = model.init_hidden(batchsize=data.size(-1))

            # Temporarily rescale the learning rate by the actual sequence length
            # (and, if given, a per-language weight).
            lr2 = optimizer.param_groups[0]['lr']
            if lr_weights is not None:
                optimizer.param_groups[0]['lr'] = lr2 * seq_len.item() / bptt * lr_weights[lang.item()]
            else:
                optimizer.param_groups[0]['lr'] = lr2 * seq_len.item() / bptt

            hidden = detach(hidden)
            optimizer.zero_grad()

            loss = 0

            # Only a variational prior needs multiple weight samples per batch.
            if not isinstance(prior, VIPrior):
                n_samples = 1

            for s in range(n_samples):
                output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, lang, return_h=True)

                if isinstance(loss_function, SplitCrossEntropyLoss):
                    raw_loss = loss_function(model.decoder.weight, model.decoder.bias, output, targets)
                else:
                    raw_loss = loss_function(output, targets)

                # Activation Regularization
                if alpha:
                    raw_loss = raw_loss + sum(alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])

                # Temporal Activation Regularization (slowness)
                if beta:
                    raw_loss = raw_loss + sum(beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])

                loss += raw_loss

            loss /= n_samples
            log_loss = loss

            # For a variational prior, add the (annealed) KL divergence term.
            if isinstance(prior, VIPrior):
                kl_term = prior.kl_div()
                if scaling == "uniform":
                    scale = 1. / total_steps
                elif scaling == "linear_annealing":
                    scale = ((total_steps - steps - 1) * 2. + 1.) / total_steps ** 2
                elif scaling == "logistic_annealing":
                    steepness = 0.0025
                    scale = 1. / (1 + np.exp(-steepness * (steps - total_steps / 2.)))
                else:
                    scale = 1.
                loss = loss + scale * kl_term

                tr_kl += kl_term.item()

            if use_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if tb_writer is not None:
                tb_writer.add_scalar('train/loss', log_loss.item(), steps)
                if isinstance(prior, VIPrior):
                    tb_writer.add_scalar('train/kl', kl_term.item(), steps)
                    tb_writer.add_scalar('train/loss+kl', loss.item(), steps)

            logging_kl += tr_kl
            logging_loss += tr_loss

            optimizer.step()

            total_loss += raw_loss.data
            batch += 1
            steps += 1

            # reset lr to optimiser default
            optimizer.param_groups[0]['lr'] = lr2

            if batch % log_interval == 0 and batch > 0:
                cur_loss = total_loss.item() / log_interval
                elapsed = time.time() - start_time
                log.debug(
                    '| {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                        batch, len(dataloader), optimizer.param_groups[0]['lr'], elapsed * 1000 / log_interval,
                        cur_loss, math.exp(cur_loss), cur_loss / math.log(2)))
                total_loss = 0
                start_time = time.time()

            pbar.set_description('Training, end of batch {} | Loss {}'.format(batch, loss.data))

    return steps
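

# Illustrative sketch (not part of the original training code): the KL scaling schedules
# used in `train` when the prior is a `VIPrior`, pulled out into a stand-alone function.
# The steepness constant matches the value hard-coded above; the helper name and call
# pattern are assumptions.
def _kl_scale(steps: int, total_steps: int, scaling: str = None) -> float:
    """Return the factor by which the KL term is scaled at the given step."""
    import numpy as np

    if scaling == "uniform":
        # constant scale; summed over all steps it weights the total KL contribution by 1
        return 1. / total_steps
    elif scaling == "linear_annealing":
        # decreases linearly from roughly 2 / total_steps down to 1 / total_steps ** 2
        return ((total_steps - steps - 1) * 2. + 1.) / total_steps ** 2
    elif scaling == "logistic_annealing":
        # increases smoothly from near 0 to near 1, centred at total_steps / 2
        steepness = 0.0025
        return float(1. / (1 + np.exp(-steepness * (steps - total_steps / 2.))))
    return 1.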