def val_step(x, y, model, criterion, bleu, device, distributed=False):
    """Single validation step."""

    # get masks and targets
    y_inp, y_tar = y[:, :-1], y[:, 1:]
    enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(x, y_inp)

    # devices
    x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask = to_devices(
        (x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask),
        device, non_blocking=distributed)

    # forward
    model.eval()
    with torch.no_grad():
        y_pred, _ = model(x, y_inp, enc_mask, look_ahead_mask, dec_mask)
        loss = loss_fn(y_pred.permute(0, 2, 1), y_tar, criterion)

    # metrics
    batch_loss = loss.detach()
    batch_acc = accuracy_fn(y_pred.detach(), y_tar)
    bleu(torch.argmax(y_pred, axis=-1), y_tar)

    return batch_loss, batch_acc
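# loss_fn and accuracy_fn are used by the base_transformer steps here but
# defined elsewhere in the repo. A minimal sketch of what they might look
# like, assuming padding id 0 is ignored and `criterion` is the per-token
# torch.nn.CrossEntropyLoss(reduction='none') built in the training loop
# below. This is a hypothetical reconstruction, not the repo's actual
# implementation (note the pretrained train_step further down uses a loss_fn
# with a different signature).
import torch

def loss_fn(y_pred, y_tar, criterion):
    # y_pred: (batch, vocab, seq_len), y_tar: (batch, seq_len)
    mask = (y_tar != 0).type(y_pred.dtype)
    per_token = criterion(y_pred, y_tar)
    return (per_token * mask).sum() / mask.sum()

def accuracy_fn(y_pred, y_tar):
    # y_pred: (batch, seq_len, vocab), y_tar: (batch, seq_len)
    mask = (y_tar != 0)
    correct = (torch.argmax(y_pred, dim=-1) == y_tar) & mask
    return correct.sum().float() / mask.sum().float()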
def aux_train_step(x, y, model, criterion, aux_criterion, aux_strength,
                   frozen_layers, optimizer, scheduler, device,
                   distributed=False):
    """Single training step using an auxiliary loss on the encoder outputs."""

    # get masks and targets
    y_inp, y_tar = y[:, :-1], y[:, 1:]
    enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(x, y_inp)

    # mask for the target language encoded representation.
    enc_mask_aux = base_transformer.create_mask(y_inp)

    # devices
    x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask, enc_mask_aux = to_devices(
        (x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask, enc_mask_aux),
        device, non_blocking=distributed)

    model.train()
    optimizer.zero_grad()

    x_enc = model.encode(x, enc_mask)
    y_pred = model.final_layer(
        model.decode(y_inp, x_enc, look_ahead_mask, dec_mask)[0])
    y_enc = model.encode(y_inp, enc_mask_aux)

    # main loss.
    loss_main = loss_fn(y_pred.permute(0, 2, 1), y_tar, criterion)
    loss_main.backward(retain_graph=True)

    # aux loss
    model = param_freeze(model, frozen_layers)
    loss_aux = auxiliary_loss_fn(x_enc, y_enc, aux_criterion,
                                 x_mask=enc_mask, y_mask=enc_mask_aux)
    scaled_loss_aux = loss_aux * aux_strength
    scaled_loss_aux.backward()

    optimizer.step()
    scheduler.step()
    model = param_freeze(model, frozen_layers, unfreeze=True)

    # metrics
    batch_loss = loss_main.detach()
    batch_aux = loss_aux.detach()
    batch_acc = accuracy_fn(y_pred.detach(), y_tar)

    return batch_loss, batch_aux, batch_acc
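# auxiliary_loss_fn and param_freeze are referenced above but live elsewhere
# in the repo. Hedged sketches only: the mask convention ((batch, seq_len)
# with 1 marking real tokens) and the model attribute names are assumptions,
# and the pooling mirrors the inline computation in the pretrained train_step
# further down.
import torch

def auxiliary_loss_fn(x_enc, y_enc, aux_criterion, x_mask=None, y_mask=None):
    # push padded positions to a large negative value, then max-pool over time
    if x_mask is not None:
        x_enc = x_enc + -999 * (1 - x_mask.type(x_enc.dtype)).unsqueeze(-1)
    if y_mask is not None:
        y_enc = y_enc + -999 * (1 - y_mask.type(y_enc.dtype)).unsqueeze(-1)
    return aux_criterion(torch.max(x_enc, dim=1)[0], torch.max(y_enc, dim=1)[0])

def param_freeze(model, frozen_layers, unfreeze=False):
    # toggle requires_grad on the encoder layers listed in frozen_layers;
    # model.encoder.enc_layers is a guess at the repo's module layout.
    for i in frozen_layers:
        for p in model.encoder.enc_layers[i].parameters():
            p.requires_grad = unfreeze
    return model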
def evaluate(x, y, y_code, bleu):
    """Generate translations with beam search and update the BLEU metric.
    Uses `model`, `params` and `device` from the surrounding scope."""
    y_inp, y_tar = y[:, :-1].contiguous(), y[:, 1:].contiguous()
    enc_mask = (x != 0)
    x, y_inp, y_tar, enc_mask = to_devices((x, y_inp, y_tar, enc_mask), device)

    model.eval()
    y_pred = model.generate(input_ids=x, decoder_start_token_id=y_code,
                            attention_mask=enc_mask,
                            max_length=params.max_len + 1,
                            num_beams=params.num_beams,
                            length_penalty=params.length_penalty,
                            early_stopping=True)
    bleu(y_pred[:, 1:], y_tar)
def train_step(x, y, model, criterion, aux_criterion, optimizer, scheduler,
               device, distributed=False):
    """Single training step. The auxiliary loss is computed under
    torch.no_grad() and returned for monitoring only; it does not
    contribute to the parameter update."""

    # get masks and targets
    y_inp, y_tar = y[:, :-1], y[:, 1:]
    enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(x, y_inp)

    # mask for the target language encoded representation.
    enc_mask_aux = base_transformer.create_mask(y_inp)

    # devices
    x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask, enc_mask_aux = to_devices(
        (x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask, enc_mask_aux),
        device, non_blocking=distributed)

    # forward
    model.train()
    x_enc = model.encode(x, enc_mask)
    y_enc = model.encode(y_inp, enc_mask_aux)
    y_pred = model.final_layer(
        model.decode(y_inp, x_enc, look_ahead_mask, dec_mask)[0])
    loss = loss_fn(y_pred.permute(0, 2, 1), y_tar, criterion)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # auxiliary loss is tracked but not optimised here
    with torch.no_grad():
        loss_aux = auxiliary_loss_fn(x_enc, y_enc, aux_criterion,
                                     x_mask=enc_mask, y_mask=enc_mask_aux)

    # metrics
    batch_loss = loss.detach()
    batch_aux = loss_aux
    batch_acc = accuracy_fn(y_pred.detach(), y_tar)

    return batch_loss, batch_aux, batch_acc
def evaluate(x, y, y_code, bleu):
    """Pivot evaluation: translate the source into English with `model`,
    then into the target language with `model_2`, and score with BLEU."""
    en_code = tokenizer.lang_code_to_id[LANG_CODES['en']]
    y_inp, y_tar = y[:, :-1].contiguous(), y[:, 1:].contiguous()
    enc_mask = (x != 0)
    x, y_inp, y_tar, enc_mask = to_devices((x, y_inp, y_tar, enc_mask), device)

    model.eval()
    pivot_pred = model.generate(input_ids=x, decoder_start_token_id=en_code,
                                attention_mask=enc_mask,
                                max_length=x.size(1) + 1,
                                num_beams=params.num_beams,
                                length_penalty=params.length_penalty,
                                early_stopping=True)

    # mask out everything after the stop token (id 2) so the pivot sentence
    # can be re-encoded as padded input
    pivot_pred = mask_after_stop(pivot_pred, 2)
    pivot_mask = (pivot_pred != 0)

    y_pred = model_2.generate(input_ids=pivot_pred, decoder_start_token_id=y_code,
                              attention_mask=pivot_mask,
                              max_length=x.size(1) + 1,
                              num_beams=params.num_beams,
                              length_penalty=params.length_penalty,
                              early_stopping=True)
    bleu(y_pred[:, 1:], y_tar)
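# mask_after_stop is not shown in this section. A minimal sketch, assuming it
# zeroes out (pads) every token after the first stop token so the pivot
# sentence can be fed back in as encoder input; the stop id (2) and pad id (0)
# follow the indices excluded from the BLEU metric elsewhere in this code.
import torch

def mask_after_stop(ids, stop_token):
    stops_seen = (ids == stop_token).cumsum(dim=1)
    keep = (stops_seen == 0) | ((stops_seen == 1) & (ids == stop_token))
    return ids * keep.long()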
def train_step(x, y, aux=False):
    """Single training step for the pretrained model. When `aux` is True the
    cosine auxiliary loss on the pooled encoder states is also backpropagated;
    otherwise it is computed without gradients and returned for monitoring."""

    y_inp, y_tar = y[:, :-1].contiguous(), y[:, 1:].contiguous()
    enc_mask, dec_mask = (x != 0), (y_inp != 0)
    x, y_inp, y_tar, enc_mask, dec_mask = to_devices(
        (x, y_inp, y_tar, enc_mask, dec_mask), device)

    model.train()
    if aux:
        freeze_layers(params.frozen_layers, unfreeze=True)
    output = model(input_ids=x, decoder_input_ids=y_inp, labels=y_tar,
                   attention_mask=enc_mask, decoder_attention_mask=dec_mask)
    optimizer.zero_grad()

    loss = loss_fn(output, y_tar)
    loss.backward(retain_graph=aux)
    if aux:
        freeze_layers(params.frozen_layers)

    # auxiliary loss: max-pool the encoder states (padded positions pushed
    # to -999) and pull the source and target representations together
    torch.set_grad_enabled(aux)
    x_enc = output.encoder_last_hidden_state
    y_enc = model.model.encoder(y_inp, attention_mask=dec_mask)['last_hidden_state']
    x_enc = torch.max(x_enc + -999 * (1 - enc_mask.type(x_enc.dtype)).unsqueeze(-1), dim=1)[0]
    y_enc = torch.max(y_enc + -999 * (1 - dec_mask.type(y_enc.dtype)).unsqueeze(-1), dim=1)[0]
    aux_loss = F.cosine_embedding_loss(x_enc, y_enc, _target)
    scaled_aux_loss = params.aux_strength * aux_loss
    torch.set_grad_enabled(True)

    if aux:
        scaled_aux_loss.backward()

    optimizer.step()
    scheduler.step()

    accuracy = accuracy_fn(output.logits, y_tar)

    return loss.item(), aux_loss.item(), accuracy.item()
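# freeze_layers is defined elsewhere. A hedged sketch, assuming `model` is a
# Hugging Face Bart/MBart-style seq2seq (consistent with the use of
# model.model.encoder above) and params.frozen_layers lists encoder layer
# indices; all it does is toggle requires_grad on those layers.
def freeze_layers(layers, unfreeze=False):
    for i in layers:
        for p in model.model.encoder.layers[i].parameters():
            p.requires_grad = unfreeze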
def train(rank, device, logger, params, train_dataloader, val_dataloader=None,
          tokenizer=None, verbose=50):
    """Training Loop"""

    multi = False
    if len(params.langs) > 2:
        assert tokenizer is not None
        multi = True
        add_targets = preprocess.AddTargetTokens(params.langs, tokenizer)

    model = initialiser.initialise_model(params, device)
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = WarmupDecay(optimizer, params.warmup_steps, params.d_model,
                            lr_scale=params.lr_scale)
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    _aux_criterion = torch.nn.CosineEmbeddingLoss(reduction='mean')
    _target = torch.tensor(1.0).to(device)
    aux_criterion = lambda x, y: _aux_criterion(x, y, _target)

    epoch = 0
    if params.checkpoint:
        model, optimizer, epoch, scheduler = logging.load_checkpoint(
            logger.checkpoint_path, device, model,
            optimizer=optimizer, scheduler=scheduler)

    if params.distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[device.index], find_unused_parameters=True)

    if rank == 0:
        if params.wandb:
            wandb.watch(model)

    batch_losses, batch_auxs, batch_accs = [], [], []
    epoch_losses, epoch_auxs, epoch_accs = [], [], []
    val_epoch_losses, val_epoch_accs, val_epoch_bleus = [], [], []

    while epoch < params.epochs:
        start_ = time.time()

        # train
        if params.FLAGS:
            print('training')

        epoch_loss = 0.0
        epoch_aux = 0.0
        epoch_acc = 0.0
        for i, data in enumerate(train_dataloader):

            if multi:
                # sample a translation direction and add target tokens
                (x, y), (x_lang, y_lang) = sample_direction(
                    data, params.langs, excluded=params.excluded)
                x = add_targets(x, y_lang)
            else:
                x, y = data

            if params.auxiliary:
                batch_loss, batch_aux, batch_acc = aux_train_step(
                    x, y, model, criterion, aux_criterion, params.aux_strength,
                    params.frozen_layers, optimizer, scheduler, device,
                    distributed=params.distributed)
            else:
                batch_loss, batch_aux, batch_acc = train_step(
                    x, y, model, criterion, aux_criterion, optimizer,
                    scheduler, device, distributed=params.distributed)

            if rank == 0:
                batch_loss = batch_loss.item()
                batch_aux = batch_aux.item()
                batch_acc = batch_acc.item()

                batch_losses.append(batch_loss)
                batch_auxs.append(batch_aux)
                batch_accs.append(batch_acc)

                # running means over the epoch
                epoch_loss += (batch_loss - epoch_loss) / (i + 1)
                epoch_aux += (batch_aux - epoch_aux) / (i + 1)
                epoch_acc += (batch_acc - epoch_acc) / (i + 1)

                if verbose is not None:
                    if i % verbose == 0:
                        print('Batch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} '
                              'in {:.4f} s per batch'.format(
                                  i, epoch_loss, epoch_aux, epoch_acc,
                                  (time.time() - start_) / (i + 1)))

                if params.wandb:
                    wandb.log({'loss': batch_loss,
                               'aux_loss': batch_aux,
                               'accuracy': batch_acc})

        if rank == 0:
            epoch_losses.append(epoch_loss)
            epoch_auxs.append(epoch_aux)
            epoch_accs.append(epoch_acc)

        # val only on rank 0
        if rank == 0:
            if params.FLAGS:
                print('validating')

            val_epoch_loss = 0.0
            val_epoch_acc = 0.0
            val_bleu = 0.0
            test_bleu = 0.0
            if val_dataloader is not None:
                bleu = BLEU()
                bleu.set_excluded_indices([0, 2])
                for i, data in enumerate(val_dataloader):

                    if multi:
                        # sample a translation direction and add target tokens
                        (x, y), (x_lang, y_lang) = sample_direction(
                            data, params.langs, excluded=params.excluded)
                        x = add_targets(x, y_lang)
                    else:
                        x, y = data

                    batch_loss, batch_acc = val_step(
                        x, y, model, criterion, bleu, device,
                        distributed=params.distributed)
                    batch_loss = batch_loss.item()
                    batch_acc = batch_acc.item()

                    val_epoch_loss += (batch_loss - val_epoch_loss) / (i + 1)
                    val_epoch_acc += (batch_acc - val_epoch_acc) / (i + 1)

                val_epoch_losses.append(val_epoch_loss)
                val_epoch_accs.append(val_epoch_acc)
                val_bleu = bleu.get_metric()

                # evaluate without teacher forcing
                if params.test_freq is not None:
                    if epoch % params.test_freq == 0:
                        bleu_no_tf = BLEU()
                        bleu_no_tf.set_excluded_indices([0, 2])
                        for i, data in enumerate(val_dataloader):
                            if i > params.test_batches:
                                break
                            else:
                                if multi:
                                    # sample a translation direction and add target tokens
                                    (x, y), (x_lang, y_lang) = sample_direction(
                                        data, params.langs, excluded=params.excluded)
                                    x = add_targets(x, y_lang)
                                else:
                                    x, y = data

                                y, y_tar = y[:, 0].unsqueeze(-1), y[:, 1:]
                                enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(
                                    x, y_tar)

                                # devices
                                x, y, y_tar, enc_mask = to_devices(
                                    (x, y, y_tar, enc_mask), device)

                                y_pred = beam_search(
                                    x, y, y_tar, model, enc_mask=enc_mask,
                                    beam_length=params.beam_length,
                                    alpha=params.alpha, beta=params.beta)
                                bleu_no_tf(y_pred, y_tar)

                        test_bleu = bleu_no_tf.get_metric()
                        print(test_bleu)

                if verbose is not None:
                    print('Epoch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} '
                          'Val Loss {:.4f} Val Accuracy {:.4f} Val Bleu {:.4f}'
                          ' Test Bleu {:.4f} in {:.4f} secs \n'.format(
                              epoch, epoch_loss, epoch_aux, epoch_acc,
                              val_epoch_loss, val_epoch_acc, val_bleu,
                              test_bleu, time.time() - start_))
                if params.wandb:
                    wandb.log({'loss': epoch_loss,
                               'aux_loss': epoch_aux,
                               'accuracy': epoch_acc,
                               'val_loss': val_epoch_loss,
                               'val_accuracy': val_epoch_acc,
                               'val_bleu': val_bleu,
                               'test_bleu': test_bleu})
            else:
                if verbose is not None:
                    print('Epoch {} Loss {:.4f} Aux Loss {:.4f} Accuracy {:.4f} '
                          'in {:.4f} secs \n'.format(
                              epoch, epoch_loss, epoch_aux, epoch_acc,
                              time.time() - start_))
                if params.wandb:
                    wandb.log({'loss': epoch_loss,
                               'aux_loss': epoch_aux,
                               'accuracy': epoch_acc})

            if params.FLAGS:
                print('logging results')
            logger.save_model(epoch, model, optimizer, scheduler=scheduler)
            logger.log_results([epoch_loss, epoch_aux, epoch_acc, val_epoch_loss,
                                val_epoch_acc, val_bleu, test_bleu])

        epoch += 1

    return epoch_losses, epoch_accs, val_epoch_losses, val_epoch_accs
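# WarmupDecay is constructed in train() but defined elsewhere. A minimal
# sketch, assuming the inverse-square-root schedule from "Attention Is All
# You Need": lr = lr_scale * d_model**-0.5 * min(step**-0.5,
# step * warmup_steps**-1.5). The class name and constructor signature follow
# the call above; the exact formula is an assumption.
import torch

class WarmupDecay(torch.optim.lr_scheduler._LRScheduler):

    def __init__(self, optimizer, warmup_steps, d_model, lr_scale=1.0, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.d_model = d_model
        self.lr_scale = lr_scale
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(1, self.last_epoch)  # scheduler.step() is called once per batch
        scale = self.lr_scale * self.d_model ** -0.5 * min(
            step ** -0.5, step * self.warmup_steps ** -1.5)
        return [scale for _ in self.base_lrs]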
def inference_step(x, y, model, logger, tokenizer, device, bleu=None,
                   teacher_forcing=False, pivot_mode=False, beam_length=1,
                   alpha=0.0, beta=0.0):
    """
    Inference step.
    x: source language
    y: target language
    """
    if teacher_forcing:
        y_inp, y_tar = y[:, :-1], y[:, 1:]
        enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(x, y_inp)

        # devices
        x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask = to_devices(
            (x, y_inp, y_tar, enc_mask, look_ahead_mask, dec_mask), device)

        # inference
        model.eval()
        with torch.no_grad():
            y_pred, _ = model(x, y_inp, enc_mask, look_ahead_mask, dec_mask)

        if not pivot_mode:
            batch_acc = accuracy_fn(y_pred.detach(), y_tar).cpu().item()
            bleu(torch.argmax(y_pred, axis=-1), y_tar)
            logger.log_examples(x, y_tar, torch.argmax(y_pred, axis=-1), tokenizer)
            return batch_acc
        else:
            return torch.argmax(y_pred, axis=-1)

    else:
        # retrieve the start of sequence token and the target translation
        y, y_tar = y[:, 0].unsqueeze(-1), y[:, 1:]
        enc_mask, look_ahead_mask, dec_mask = base_transformer.create_masks(x, y_tar)

        # devices
        x, y, y_tar, enc_mask = to_devices((x, y, y_tar, enc_mask), device)

        # inference
        model.eval()
        if beam_length == 1:
            y_pred = greedy_search(x, y, y_tar, model, enc_mask=enc_mask)
        else:
            y_pred = beam_search(x, y, y_tar, model, enc_mask=enc_mask,
                                 beam_length=beam_length, alpha=alpha, beta=beta)

        if not pivot_mode:
            batch_acc = 0
            if bleu is not None:
                bleu(y_pred, y_tar)
            logger.log_examples(x, y_tar, y_pred, tokenizer)
            return batch_acc
        else:
            return y_pred
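# Example of how inference_step might be driven over a held-out set
# (illustrative only; `test_dataloader` is an assumed loader yielding (x, y)
# batches like the non-multilingual case in train() above, and model, logger,
# tokenizer, device and params are assumed to be set up as in that loop):
bleu = BLEU()
bleu.set_excluded_indices([0, 2])
for x, y in test_dataloader:
    inference_step(x, y, model, logger, tokenizer, device, bleu=bleu,
                   teacher_forcing=False, beam_length=params.beam_length,
                   alpha=params.alpha, beta=params.beta)
print('Test BLEU {:.4f}'.format(bleu.get_metric()))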