# The helpers used below (get_mnli_test_dataset, get_bert_model, TEXT_COL) and
# the imports (torch, numpy as np, and dot/norm from numpy) are assumed to be
# provided by the surrounding test module; the methods are excerpts from a
# test class, which is why they take self.
def test_explain_model_BERT_seq_classification(self):
    # Run on GPU when available, otherwise fall back to CPU.
    device = torch.device(
        "cpu" if not torch.cuda.is_available() else "cuda")
    mnli_test_dataset = get_mnli_test_dataset()
    text = "rare bird has more than enough charm to make it memorable."
    model = get_bert_model()
    model.to(device)
    interpreter_unified = UnifiedInformationExplainer(
        model=model,
        train_dataset=list(mnli_test_dataset),
        device=device,
        target_layer=14,
    )
    explanation_unified = interpreter_unified.explain_local(text)
    # Reference local importance values used to validate the explanation.
    valid_imp_vals = np.array([
        0.16004231572151184, 0.17308972775936127, 0.18205846846103668,
        0.26146841049194336, 0.25957807898521423, 0.3549807369709015,
        0.23873654007911682, 0.2826242744922638, 0.2700383961200714,
        0.3673151433467865, 0.3899800479412079, 0.20173774659633636
    ])
    print(explanation_unified.local_importance_values)
    local_importance_values = np.array(
        explanation_unified.local_importance_values)
    # The explanation must point in roughly the same direction as the
    # reference values: cosine similarity of at least 0.80.
    cos_sim = dot(valid_imp_vals, local_importance_values) / (
        norm(valid_imp_vals) * norm(local_importance_values))
    assert cos_sim >= 0.80
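
# The cosine-similarity check above can be factored out into a small helper.
# A minimal sketch using only numpy; the helper name is illustrative and is
# not part of the library under test.
def cosine_similarity(a, b):
    """Cosine similarity between two 1-D arrays of importance values."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))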

def test_make_bert_embeddings(self):
    model = get_bert_model()
    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
    # Embed the raw training texts with BERT.
    train_dataset = get_mnli_test_dataset("train")
    train_dataset = list(train_dataset[TEXT_COL])
    training_embeddings = make_bert_embeddings(train_dataset, model, device)
    assert training_embeddings is not None

def test_get_single_embedding(self):
    model = get_bert_model()
    device = torch.device(
        "cpu" if not torch.cuda.is_available() else "cuda")
    text = "rare bird has more than enough charm to make it memorable."
    embedded_input = get_single_embedding(model, text, device)
    assert embedded_input is not None