def ids_to_clean_text(model, tokenizer, generated_ids):
    """Decode batches of token ids into stripped plain-text strings.

    Args:
        model: unused here; kept for interface compatibility with callers.
        tokenizer: tokenizer providing ``batch_decode``.
        generated_ids: batch of token-id sequences to decode.

    Returns:
        List of decoded strings with special tokens removed and surrounding
        whitespace stripped.
    """
    decoded = tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return [text.strip() for text in decoded]
def val_routine_MT(model, tokenizer, val_loader, loss_fn, device, val_stepcount,
                   return_val_predictions=False, limit_val_batches=None):
    """Run one validation pass for a machine-translation model.

    For each batch: generates translations, decodes predictions and targets,
    runs a forward pass to obtain the model's own LM loss, and accumulates
    BLEU / loss / timing statistics.

    Args:
        model: seq2seq model exposing ``generate`` and a ``loss``-returning
            forward when labels are present in the batch.
        tokenizer: tokenizer used to decode generated ids.
        val_loader: iterable of batches; each batch is a dict of tensors with
            at least ``input_ids``, ``attention_mask`` and ``labels``.
        loss_fn: unused — the loss is taken from the model's own output;
            kept in the signature for interface compatibility.
        device: device the batch tensors are moved to.
        val_stepcount: step counter recorded in the returned metrics.
        return_val_predictions: if True, include decoded predictions and
            targets in the returned dict.
        limit_val_batches: if an int, stop after exactly that many batches.

    Returns:
        Dict of aggregated metrics (mean BLEU, mean loss, mean generation
        time per sample, mean generated length), plus ``preds``/``target``
        lists when ``return_val_predictions`` is set.
    """
    val_losses = list()
    gen_times = list()
    summ_lens = list()
    bleus = list()
    predictions = list()
    targets = list()
    for batch in val_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # s1. forward (no gradients needed during validation)
        with torch.no_grad():
            t0 = time.time()
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                use_cache=True,
            )
            # Per-sample generation time for this batch.
            gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
            # TODO: batch_decode may not handle source-language sentences
            # correctly here — revisit this decoding routine.
            preds: List[str] = ids_to_clean_text(model, tokenizer, generated_ids)
            target: List[str] = ids_to_clean_text(model, tokenizer, batch["labels"])
            outputs = model(return_dict=True, **batch)
            # s2. objective: the model computes the LM loss internally from
            # batch["labels"], so no external loss_fn call is needed.
            loss = outputs['loss']
        summ_len = np.mean(lmap(len, generated_ids))
        bleu = calculate_bleu(preds, target)
        val_losses.append(loss.item())
        gen_times.append(gen_time)
        summ_lens.append(summ_len)
        bleus.append(bleu['bleu'])
        predictions += preds
        targets += target
        # Stop after exactly `limit_val_batches` batches. The original used
        # `>` which processed one batch too many (off-by-one).
        if isinstance(limit_val_batches, int) and len(val_losses) >= limit_val_batches:
            print(f'\nreached limit_val_batches={limit_val_batches}')
            break
    base_metrics = {
        'STEP': val_stepcount,
        'mean_val_bleu': np.mean(bleus).round(decimals=8),
    }
    base_metrics.update(
        mean_val_loss=np.mean(val_losses).round(decimals=8),
        mean_gen_time=np.mean(gen_times).round(decimals=3),
        mean_gen_len=np.mean(summ_lens).round(decimals=3),
    )
    if return_val_predictions:
        base_metrics.update(preds=predictions, target=targets)
    return base_metrics