def _bert_score_ddp(rank, world_size, preds, targets, original_score):
    """Define a DDP process for BERTScore."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3, max_length=128)
    scorer.update(preds, targets)
    metrics_score = scorer.compute()
    for metric in _METRICS:
        _assert_list(metrics_score[metric], original_score[metric])
    dist.destroy_process_group()
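# A minimal sketch of how the DDP helper above could be launched. The
# `torch.multiprocessing.spawn` wiring, the world size, and the reference-score
# setup here are illustrative assumptions, not part of the original test file.
import torch.multiprocessing as mp


def test_score_ddp(preds, targets):
    """Tests for metric run across two DDP processes over a gloo process group."""
    world_size = 2
    original_score = original_bert_score(
        preds, targets, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3
    )
    original_score = _parse_original_bert_score(original_score)
    # spawn calls `_bert_score_ddp(rank, *args)` once per process
    mp.spawn(_bert_score_ddp, args=(world_size, preds, targets, original_score), nprocs=world_size)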
def test_score_all_layers_with_idf(preds, targets):
    """Tests for metric with all layers and IDF rescaling."""
    original_score = original_bert_score(
        preds, targets, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3
    )
    original_score = _parse_original_bert_score(original_score)

    scorer = BERTScore(model_name_or_path=MODEL_NAME, all_layers=True, idf=True, batch_size=3)
    scorer.update(preds=preds, target=targets)
    metrics_score = scorer.compute()

    for metric in _METRICS:
        _assert_list(metrics_score[metric], original_score[metric])
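# A hypothetical sketch of the `_parse_original_bert_score` helper these tests
# rely on, assuming the reference `bert_score.score` function returns a
# (precision, recall, f1) tuple of tensors and `_METRICS` lists those names in
# the same order. The actual helper in the test suite may differ.
def _parse_original_bert_score(score):
    """Convert the reference implementation's (P, R, F1) tuple into a dict of lists keyed by metric name."""
    return {metric: value.tolist() for metric, value in zip(_METRICS, score)}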
def test_score(preds, targets):
    """Tests for metric."""
    original_score = original_bert_score(
        preds, targets, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3
    )
    original_score = _parse_original_bert_score(original_score)

    scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
    scorer.update(preds=preds, target=targets)
    metrics_score = scorer.compute()

    for metric in _METRICS:
        _assert_list(metrics_score[metric], original_score[metric])
def test_accumulation(preds, targets):
    """Tests that the metric works with accumulation over multiple update calls."""
    original_score = original_bert_score(
        sum(preds, []), sum(targets, []), model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3
    )
    original_score = _parse_original_bert_score(original_score)

    scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
    for p, r in zip(preds, targets):
        scorer.update(preds=p, target=r)
    metrics_score = scorer.compute()

    for metric in _METRICS:
        _assert_list(metrics_score[metric], original_score[metric])
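# The accumulation test above expects `preds` and `targets` to arrive
# pre-batched as lists of lists of strings, which is why `sum(preds, [])` is
# used to flatten them for the single-call reference score. A hypothetical
# fixture of that shape (not the actual test data):
_BATCHED_PREDS = [["hello there", "general kenobi"], ["you are a bot"]]
_BATCHED_TARGETS = [["hello there", "master kenobi"], ["you are a droid"]]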
    input/output argument structure described below.

    Args:
        model: `torch.nn.Module`
        batch: `Dict[str, torch.Tensor]`

    Return:
        The model output. `torch.Tensor`
    """
    return model(batch["input_ids"])


_PREDS = ["hello", "hello world", "world world world"]
_REFS = ["hello", "hello hello", "hello world hello"]


if __name__ == "__main__":
    tokenizer = UserTokenizer()
    model = get_user_model_encoder()
    bs = BERTScore(
        model=model, user_tokenizer=tokenizer, user_forward_fn=user_forward_fn, max_length=_MAX_LEN, return_hash=False
    )
    bs.update(_PREDS, _REFS)
    print(f"Predictions:\n {bs.preds_input_ids}\n {bs.preds_attention_mask}")
    pprint(bs.compute())
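# A minimal sketch of the `UserTokenizer` interface the example above relies
# on: a callable that maps a list of sentences to a dict with "input_ids" and
# "attention_mask" tensors padded to `max_len`. The whitespace splitting, the
# toy vocabulary, and the pad id used here are illustrative assumptions; the
# example's actual tokenizer (defined outside this excerpt) may differ.
import torch


class UserTokenizer:
    def __init__(self):
        self.pad_token_id = 0
        self.word2id = {"hello": 1, "world": 2}

    def __call__(self, sentences, max_len=_MAX_LEN):
        # start from an all-padding batch and fill in known word ids
        input_ids = torch.full((len(sentences), max_len), self.pad_token_id, dtype=torch.long)
        for i, sentence in enumerate(sentences):
            ids = [self.word2id.get(word, self.pad_token_id) for word in sentence.split()][:max_len]
            input_ids[i, : len(ids)] = torch.tensor(ids, dtype=torch.long)
        # mask out the padding positions
        attention_mask = (input_ids != self.pad_token_id).long()
        return {"input_ids": input_ids, "attention_mask": attention_mask}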