# Example no. 1
def ids_to_clean_text(
    model,
    tokenizer,
    generated_ids,
):
    """Decode a batch of generated token ids into stripped text strings.

    Note: ``model`` is accepted for call-site compatibility but is not
    used in the decoding itself.
    """
    decoded = tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    # strip leading/trailing whitespace from every decoded string
    return lmap(str.strip, decoded)
# Example no. 2
def val_routine_MT(model,
                   tokenizer,
                   val_loader,
                   loss_fn,
                   device,
                   val_stepcount,
                   return_val_predictions=False,
                   limit_val_batches=None):
    """Run one validation pass for a machine-translation model.

    For every batch: generate translations, decode predictions and
    references, compute BLEU and the model's internally reported loss,
    and accumulate per-batch statistics into averaged metrics.

    Args:
        model: seq2seq model exposing ``.generate(...)`` and a forward
            call returning a dict-like output with a ``'loss'`` entry.
        tokenizer: tokenizer used to decode generated ids into text.
        val_loader: iterable of batches; each batch is a dict of tensors
            with at least ``input_ids``, ``attention_mask``, ``labels``.
        loss_fn: unused; kept for backward compatibility with callers
            (the loss comes from the model's own forward pass).
        device: device the batch tensors are moved to.
        val_stepcount: step counter recorded under the ``'STEP'`` key.
        return_val_predictions: when True, also return the decoded
            predictions and targets.
        limit_val_batches: if an int, stop early once more than this many
            batches have been accumulated (note: one extra batch is
            processed before the check fires — preserved original
            behavior).

    Returns:
        dict of averaged validation metrics (BLEU, loss, generation time
        and generated length), optionally with ``preds``/``target`` lists.
    """
    val_losses, gen_times, summ_lens, bleus = [], [], [], []
    predictions, targets = [], []
    for batch in val_loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        # s1. forward — no gradients needed during validation
        with torch.no_grad():
            t0 = time.time()
            generated_ids = model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                use_cache=True)
            # per-example generation time for this batch
            gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
            # TODO: revisit this routine — batch_decode may not handle
            # source-language sentences correctly
            preds: List[str] = ids_to_clean_text(model, tokenizer,
                                                 generated_ids)
            target: List[str] = ids_to_clean_text(model, tokenizer,
                                                  batch["labels"])

            # s2. objective: use the loss the model computes internally
            # from the batch's labels
            outputs = model(return_dict=True, **batch)
            loss = outputs['loss']

        # mean generated-sequence length for this batch
        summ_len = np.mean(lmap(len, generated_ids))
        bleu = calculate_bleu(preds, target)
        val_losses.append(loss.item())
        gen_times.append(gen_time)
        summ_lens.append(summ_len)
        bleus.append(bleu['bleu'])
        predictions += preds
        targets += target

        # NOTE(review): '>' means limit_val_batches + 1 batches run before
        # breaking; kept as-is since callers may depend on it.
        if isinstance(limit_val_batches,
                      int) and len(val_losses) > limit_val_batches:
            print(f'\nreached limit_val_batches={limit_val_batches}')
            break

    base_metrics = {
        'STEP': val_stepcount,
        'mean_val_bleu': np.mean(bleus).round(decimals=8)
    }
    base_metrics.update(
        mean_val_loss=np.mean(val_losses).round(decimals=8),
        mean_gen_time=np.mean(gen_times).round(decimals=3),
        mean_gen_len=np.mean(summ_lens).round(decimals=3),
    )
    if return_val_predictions:
        base_metrics.update(preds=predictions, target=targets)

    return base_metrics