def inference(args, model, eval_dataloader, device, tokenizer):
    model.eval()
    eval_examples = eval_dataloader.dataset.examples
    eval_drop_features = eval_dataloader.dataset.drop_features
    [start_tok_id, end_tok_id] = tokenizer.convert_tokens_to_ids([START_TOK, END_TOK]) # [1030, 1032]
    all_dec_ids, all_label_ids, all_type_preds, all_start_preds, all_end_preds, all_type_logits = [], [], [], [], [], []
    all_input_ids = []
    nb_eval_examples, eval_accuracy, eval_err_sum = 0, 0, 0
    
    for batch in tqdm(eval_dataloader, desc="Inference"):
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids, head_type, q_spans, p_spans = batch
        
        with torch.no_grad():
            out = model(input_ids, segment_ids, task='inference', max_decoding_steps=eval_dataloader.dataset.max_dec_steps)
            # here segment_ids are only used to get the best span prediction
            dec_preds, type_preds, start_preds, end_preds, type_logits = tuple(t.cpu() for t in out)
            # dec_preds: [bsz, max_decoding_steps], begins with start_tok
            # type_preds, start_preds, end_preds, type_logits: [bsz], [bsz], [bsz], [bsz, 2]
        assert dec_preds.size() == label_ids.size()
        assert dec_preds.dim() == 2
        
#         bch_errs = ((dec_preds != label_ids).float() * (label_ids != IGNORE_IDX).float()).sum(dim=-1)
#         bch_eval_accuracy = (bch_errs == 0).sum().item()
#         eval_accuracy += bch_eval_accuracy
#         eval_err_sum += bch_errs.sum().item() # all errors
        nb_eval_examples += input_ids.size(0)
        all_dec_ids.append(dec_preds); all_label_ids.append(label_ids); all_type_preds.append(type_preds)
        all_start_preds.append(start_preds); all_end_preds.append(end_preds); all_type_logits.append(type_logits)
        all_input_ids.append(input_ids)
        #break
#     eval_accuracy /= nb_eval_examples
#     eval_err_sum /= nb_eval_examples
#     result = {'eval_accuracy': eval_accuracy,
#               'eval_err_sum': eval_err_sum}

#     logger.info("***** Eval results *****")
#     for key in sorted(result.keys()):
#         logger.info("  %s = %s", key, str(result[key]))
    
    tup = all_dec_ids, all_label_ids, all_type_preds, all_start_preds, all_end_preds, all_type_logits, all_input_ids
    all_dec_ids, all_label_ids, all_type_preds, all_start_preds, all_end_preds, all_type_logits, all_input_ids = \
                                                                tuple(torch.cat(t, dim=0).tolist() for t in tup)
    def trim(ids):
        # drop the leading start token if present
        ids = ids[1:] if ids[0] == start_tok_id else ids
        # keep predictions only up to the first pad/end token
        trimmed = []
        for tok_id in ids:
            if tok_id in (IGNORE_IDX, end_tok_id):
                break
            trimmed.append(tok_id)
        return trimmed
    def process(text):
        # remove spaces around '.' only when the result parses as a number
        processed = '.'.join(x.strip() for x in text.split('.'))
        try:
            float(processed)  # '.' is treated as a decimal point only if the joined string is numeric
        except ValueError:
            processed = text
        # spaces around '-' are always removed
        return '-'.join(x.strip() for x in processed.split('-'))
        
    predictions, ems, drop_ems = [], [], []
    for i in range(len(all_dec_ids)):
        example = eval_examples[i]
        drop_feature = eval_drop_features[i]
        answer_text = (SPAN_SEP+' ').join(example.answer_texts).strip().lower()
        processed_answer_text = process(answer_text)
        # generator prediction
        dec_ids = trim(all_dec_ids[i])
        dec_toks = tokenizer.convert_ids_to_tokens(dec_ids)
        dec_text = detokenize(dec_toks)
        dec_processed = process(dec_text)
        # span prediction
        start_pred, end_pred, input_ids = all_start_preds[i], all_end_preds[i], all_input_ids[i]
        [start_pred, end_pred] = sorted([start_pred, end_pred])
        span_ids = [x for x in input_ids[start_pred:end_pred+1] if x != 0]  # drop [PAD] tokens (id 0)
        span_toks = tokenizer.convert_ids_to_tokens(span_ids)
        span_text = detokenize(span_toks)
        span_processed = process(span_text)
        
        span_pred, used_orig = wrapped_get_final_text(example, drop_feature, start_pred, end_pred)
        if not used_orig:
            span_pred = process(span_pred)
    
        prediction = span_pred if all_type_preds[i] else dec_processed
        head_pred = 'span_extraction' if all_type_preds[i] else 'generator'
        
        # compute drop em and f1
        drp = DropEmAndF1()
        drp(prediction, example.answer_annotations)
        drop_em, drop_f1 = drp.get_metric()
        em = exact_match_score(prediction, processed_answer_text)
        
        predictions.append({'query_id': example.qas_id, 'passage_id':example.passage_id, 
                            'processed_dec_out': dec_processed, 'prediction': prediction, 
                            'ans_used': processed_answer_text, 'type_logits': all_type_logits[i],
                            'head_pred': head_pred, 'processed_span_out': span_processed,
                            'dec_out': dec_text, 'span_out': span_text, 'span_pred': span_pred,
                            'drop_em': drop_em, 'drop_f1': drop_f1, 'em': em})
        ems.append(em); drop_ems.append(drop_em)
        if i < 20:  # print a few predictions for quick inspection
            print(prediction, processed_answer_text, end=' || ')
    logger.info(f'EM: {np.mean(ems)}, Drop EM: {np.mean(drop_ems)}')
    logger.info('saving predictions.jsonl in ' + args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)
    write_file(predictions, os.path.join(args.output_dir, "predictions.jsonl"))
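
For reference, the `process` normalization above is easy to sanity-check in isolation. A standalone copy with a few worked cases (the example strings are illustrative, not taken from the dataset):

def process(text):
    # join around '.' only when the result parses as a number
    processed = '.'.join(x.strip() for x in text.split('.'))
    try:
        float(processed)
    except ValueError:
        processed = text
    # spaces around '-' are always removed
    return '-'.join(x.strip() for x in processed.split('-'))

assert process('3 . 14') == '3.14'            # numeric, so the spaces go
assert process('st . louis') == 'st . louis'  # not a number, left alone
assert process('12 - 7') == '12-7'            # hyphen spacing always removed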
Example #2
File: aluqa.py  Project: orperel/ALUQANet
    def __init__(self,
                 vocab: Vocabulary,
                 bert_pretrained_model: str,
                 dropout_prob: float = 0.1,
                 max_count: int = 10,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 answering_abilities: List[str] = None,
                 number_rep: str = 'first',
                 arithmetic: str = 'base',
                 special_numbers: List[int] = None) -> None:
        super().__init__(vocab, regularizer)

        self.ner_tagger = fine_grained_named_entity_recognition_with_elmo_peters_2018(
        )

        if answering_abilities is None:
            self.answering_abilities = [
                "passage_span_extraction", "question_span_extraction",
                "arithmetic", "counting"
            ]
        else:
            self.answering_abilities = answering_abilities
        self.number_rep = number_rep

        self.BERT = BertModel.from_pretrained(bert_pretrained_model)
        self.tokenizer = BertTokenizer.from_pretrained(bert_pretrained_model)
        bert_dim = self.BERT.pooler.dense.out_features

        self.dropout = dropout_prob

        self._passage_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._question_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._number_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._arithmetic_weights_predictor = torch.nn.Linear(bert_dim, 1)

        if len(self.answering_abilities) > 1:
            self._answer_ability_predictor = \
                self.ff(2 * bert_dim, bert_dim, len(self.answering_abilities))

        if "passage_span_extraction" in self.answering_abilities:
            self._passage_span_extraction_index = self.answering_abilities.index(
                "passage_span_extraction")
            self._passage_span_start_predictor = torch.nn.Linear(bert_dim, 1)
            self._passage_span_end_predictor = torch.nn.Linear(bert_dim, 1)

        if "question_span_extraction" in self.answering_abilities:
            self._question_span_extraction_index = self.answering_abilities.index(
                "question_span_extraction")
            self._question_span_start_predictor = \
                self.ff(2 * bert_dim, bert_dim, 1)
            self._question_span_end_predictor = \
                self.ff(2 * bert_dim, bert_dim, 1)

        if "arithmetic" in self.answering_abilities:
            self.arithmetic = arithmetic
            self._arithmetic_index = self.answering_abilities.index(
                "arithmetic")
            self.special_numbers = special_numbers
            self.num_special_numbers = len(self.special_numbers)  # assumes special_numbers is provided when "arithmetic" is enabled
            self.special_embedding = torch.nn.Embedding(
                self.num_special_numbers, bert_dim)
            if self.arithmetic == "base":
                self._number_sign_predictor = \
                    self.ff(2 * bert_dim, bert_dim, 3)
            else:
                self.init_arithmetic(bert_dim,
                                     bert_dim,
                                     bert_dim,
                                     layers=2,
                                     dropout=dropout_prob)

        if "counting" in self.answering_abilities:
            self._counting_index = self.answering_abilities.index("counting")
            self._count_number_predictor = \
                self.ff(bert_dim, bert_dim, max_count + 1)

        self._drop_metrics = DropEmAndF1()
        initializer(self)
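
The `self.ff(...)` helper these heads call is defined elsewhere in aluqa.py and is not shown in this snippet. A plausible reconstruction, assuming it mirrors the two-layer ReLU-then-linear FeedForward pattern of the later examples (a guess, not the project's actual definition):

    # Assumed reconstruction of ff(...) -- not the project's actual code.
    # Mirrors the Linear -> ReLU -> Dropout -> Linear heads in Examples #3/#4.
    def ff(self, input_dim: int, hidden_dim: int, output_dim: int):
        return torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(self.dropout),
            torch.nn.Linear(hidden_dim, output_dim),
        )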
Example #3
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoding_in_dim:int,
                 encoding_out_dim:int,
                 modeling_in_dim:int,
                 modeling_out_dim:int,
                 dropout_prob: float = 0.1,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 external_number: List[int] = None,
                 answering_abilities: List[str] = None) -> None:
        super().__init__(vocab, regularizer)

        if answering_abilities is None:
            self.answering_abilities = ["span_extraction",
                                        "addition_subtraction", "counting"]
        else:
            self.answering_abilities = answering_abilities


        self.W = torch.nn.Linear(768 * 2, 768)  # hard-coded 768, presumably the BERT-base hidden size
        
        
        text_embed_dim = text_field_embedder.get_output_dim()
        
        self._text_field_embedder = text_field_embedder

        #self._embedding_proj_layer = torch.nn.Linear(text_embed_dim, encoding_in_dim)

        """
            为了用于self attention
        """

        if len(self.answering_abilities) > 1:
            self._answer_ability_predictor = FeedForward(text_embed_dim,
                                                         activations=[Activation.by_name('relu')(inplace=True),
                                                                      Activation.by_name('linear')()],
                                                         hidden_dims=[encoding_out_dim,
                                                                      len(self.answering_abilities)],
                                                         num_layers=2,
                                                         dropout=dropout_prob)

        if "span_extraction" in self.answering_abilities:
            self._span_extraction_index = self.answering_abilities.index("span_extraction")
            self._span_start_predictor = FeedForward(text_embed_dim,
                                                      activations=[Activation.by_name('relu')(inplace=True),
                                                                   Activation.by_name('linear')()],
                                                      hidden_dims=[encoding_out_dim,1],
                                                      num_layers=2,
                                                      dropout=dropout_prob)
            self._span_end_predictor = FeedForward(text_embed_dim ,
                                                      activations=[Activation.by_name('relu')(inplace=True),
                                                                   Activation.by_name('linear')()],
                                                      hidden_dims=[encoding_out_dim,1],
                                                      num_layers=2,
                                                      dropout=dropout_prob)


        if "addition_subtraction" in self.answering_abilities:
            self._addition_subtraction_index = self.answering_abilities.index("addition_subtraction")
            self._number_sign_predictor = FeedForward(text_embed_dim*2,
                                                      activations=[Activation.by_name('relu')(inplace=True),
                                                                   Activation.by_name('linear')()],
                                                      hidden_dims=[encoding_out_dim,3],
                                                      num_layers=2,
                                                      dropout=dropout_prob)

        if "counting" in self.answering_abilities:
            self._counting_index = self.answering_abilities.index("counting")
            self._count_number_predictor = FeedForward(text_embed_dim,
                                                       activations=[Activation.by_name('relu')(inplace=True),
                                                                    Activation.by_name('linear')()],
                                                       hidden_dims=[encoding_out_dim, 10],
                                                       num_layers=2,
                                                       dropout=dropout_prob)
        
        self._drop_metrics = DropEmAndF1()
        self._dropout = torch.nn.Dropout(p=dropout_prob)

        initializer(self)
Example #4
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 matrix_attention_layer: MatrixAttention,
                 modeling_layer: Seq2SeqEncoder,
                 dropout_prob: float = 0.1,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 answering_abilities: List[str] = None,
                 gnn_steps: int = 1) -> None:
        super().__init__(vocab, regularizer)

        if answering_abilities is None:
            self.answering_abilities = [
                "passage_span_extraction", "question_span_extraction",
                "addition_subtraction", "counting"
            ]
        else:
            self.answering_abilities = answering_abilities

        text_embed_dim = text_field_embedder.get_output_dim()
        encoding_in_dim = phrase_layer.get_input_dim()
        encoding_out_dim = phrase_layer.get_output_dim()
        modeling_in_dim = modeling_layer.get_input_dim()
        modeling_out_dim = modeling_layer.get_output_dim()

        self._text_field_embedder = text_field_embedder

        self._embedding_proj_layer = torch.nn.Linear(text_embed_dim,
                                                     encoding_in_dim)
        self._highway_layer = Highway(encoding_in_dim, num_highway_layers)

        self._encoding_proj_layer = torch.nn.Linear(encoding_in_dim,
                                                    encoding_in_dim)
        self._phrase_layer = phrase_layer

        self._matrix_attention = matrix_attention_layer

        self._modeling_proj_layer = torch.nn.Linear(encoding_out_dim * 4,
                                                    modeling_in_dim)
        self._modeling_layer = modeling_layer

        self._modeling_layer_0 = copy.deepcopy(modeling_layer)

        self._passage_weights_predictor = torch.nn.Linear(modeling_out_dim, 1)
        self._question_weights_predictor = torch.nn.Linear(encoding_out_dim, 1)

        self._passage_weights_predictor_0_for_100 = torch.nn.Linear(
            modeling_out_dim, 1)
        self._passage_weights_predictor_3_for_100 = torch.nn.Linear(
            modeling_out_dim, 1)

        if len(self.answering_abilities) > 1:
            self._answer_ability_predictor = FeedForward(
                modeling_out_dim + encoding_out_dim,
                activations=[
                    Activation.by_name('relu')(),
                    Activation.by_name('linear')()
                ],
                hidden_dims=[modeling_out_dim,
                             len(self.answering_abilities)],
                num_layers=2,
                dropout=dropout_prob)

        if "passage_span_extraction" in self.answering_abilities:
            self._passage_span_extraction_index = self.answering_abilities.index(
                "passage_span_extraction")
            self._passage_span_start_predictor = FeedForward(
                modeling_out_dim * 2,
                activations=[
                    Activation.by_name('relu')(),
                    Activation.by_name('linear')()
                ],
                hidden_dims=[modeling_out_dim, 1],
                num_layers=2)
            self._passage_span_end_predictor = FeedForward(
                modeling_out_dim * 2,
                activations=[
                    Activation.by_name('relu')(),
                    Activation.by_name('linear')()
                ],
                hidden_dims=[modeling_out_dim, 1],
                num_layers=2)

        if "question_span_extraction" in self.answering_abilities:
            self._question_span_extraction_index = self.answering_abilities.index(
                "question_span_extraction")
            self._question_span_start_predictor = FeedForward(
                modeling_out_dim * 2,
                activations=[
                    Activation.by_name('relu')(),
                    Activation.by_name('linear')()
                ],
                hidden_dims=[modeling_out_dim, 1],
                num_layers=2)
            self._question_span_end_predictor = FeedForward(
                modeling_out_dim * 2,
                activations=[
                    Activation.by_name('relu')(),
                    Activation.by_name('linear')()
                ],
                hidden_dims=[modeling_out_dim, 1],
                num_layers=2)

        if "addition_subtraction" in self.answering_abilities:
            self._addition_subtraction_index = self.answering_abilities.index(
                "addition_subtraction")
            self._number_sign_predictor = FeedForward(
                modeling_out_dim * 3,
                activations=[
                    Activation.by_name('relu')(),
                    Activation.by_name('linear')()
                ],
                hidden_dims=[modeling_out_dim, 3],
                num_layers=2)

        if "counting" in self.answering_abilities:
            self._counting_index = self.answering_abilities.index("counting")
            self._count_number_predictor = FeedForward(
                modeling_out_dim,
                activations=[
                    Activation.by_name('relu')(),
                    Activation.by_name('linear')()
                ],
                hidden_dims=[modeling_out_dim, 10],
                num_layers=2)

        self._drop_metrics = DropEmAndF1()
        self._dropout = torch.nn.Dropout(p=dropout_prob)

        node_dim = modeling_out_dim
        self._gnn = NumGNN(node_dim=node_dim, iteration_steps=gnn_steps)
        self._proj_fc = torch.nn.Linear(node_dim * 2, node_dim, bias=True)

        initializer(self)
    def __init__(self, 
                 vocab: Vocabulary, 
                 bert_pretrained_model: str, 
                 dropout_prob: float = 0.1, 
                 max_count: int = 10,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 answering_abilities: List[str] = None,
                 number_rep: str = 'first',
                 special_numbers : List[int] = None) -> None:
        super().__init__(vocab, regularizer)

        if answering_abilities is None:
            self.answering_abilities = ["passage_span_extraction", "question_span_extraction",
                                        "arithmetic", "counting"]
        else:
            self.answering_abilities = answering_abilities
        self.number_rep = number_rep
        
        self.BERT = BertModel.from_pretrained(bert_pretrained_model)
        self.tokenizer = BertTokenizer.from_pretrained(bert_pretrained_model)
        bert_dim = self.BERT.pooler.dense.out_features
        
        self.dropout = dropout_prob

        self._passage_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._question_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._number_weights_predictor = torch.nn.Linear(bert_dim, 1)
            
        if len(self.answering_abilities) > 1:
            self._answer_ability_predictor = \
                self.ff(2 * bert_dim, bert_dim, len(self.answering_abilities))

        if "passage_span_extraction" in self.answering_abilities:
            self._passage_span_extraction_index = self.answering_abilities.index("passage_span_extraction")
            self._passage_span_start_predictor = torch.nn.Linear(bert_dim, 1)
            self._passage_span_end_predictor = torch.nn.Linear(bert_dim, 1)

        if "question_span_extraction" in self.answering_abilities:
            self._question_span_extraction_index = self.answering_abilities.index("question_span_extraction")
            self._question_span_start_predictor = \
                self.ff(2 * bert_dim, bert_dim, 1)
            self._question_span_end_predictor = \
                self.ff(2 * bert_dim, bert_dim, 1)
            self._qspan_passage_weight_predictor = torch.nn.Linear(bert_dim, 1)

        if "arithmetic" in self.answering_abilities:
            self._arithmetic_index = self.answering_abilities.index("arithmetic")
            self.special_numbers = special_numbers
            self.num_special_numbers = len(self.special_numbers)  # assumes special_numbers is provided when "arithmetic" is enabled
            self.special_embedding = torch.nn.Embedding(self.num_special_numbers, bert_dim)
            self.num_arithmetic_templates = 5
            self.num_template_slots = 3
            self._arithmetic_template_predictor = self.ff(2 * bert_dim, bert_dim, self.num_arithmetic_templates)
            self._arithmetic_template_slot_predictor = \
                torch.nn.Linear(2 * bert_dim, self.num_arithmetic_templates * self.num_template_slots)
            
            self._arithmetic_passage_weight_predictor = torch.nn.Linear(bert_dim, 1)
            self._arithmetic_question_weight_predictor = torch.nn.Linear(bert_dim, 1)
            
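            # five fixed three-operand templates; _arithmetic_template_predictor
            # picks a template and _arithmetic_template_slot_predictor scores the
            # x, y, z operand slots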
            self.templates = [lambda x,y,z: (x + y) * z,
                              lambda x,y,z: (x - y) * z,
                              lambda x,y,z: (x + y) / z,
                              lambda x,y,z: (x - y) / z,
                              lambda x,y,z: x * y / z]

        if "counting" in self.answering_abilities:
            self._counting_index = self.answering_abilities.index("counting")
            self._count_number_predictor = \
                self.ff(bert_dim, bert_dim, max_count + 1) 
            self._count_passage_weight_predictor = torch.nn.Linear(bert_dim, 1)

        self._drop_metrics = DropEmAndF1()
        initializer(self)
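
A self-contained sketch of how the five templates above could be applied at prediction time, assuming a template index and three operand values have already been decoded from the two predictors (that decoding is not part of this snippet):

# Hedged sketch, not project code: apply one fixed template to decoded operands.
templates = [lambda x, y, z: (x + y) * z,
             lambda x, y, z: (x - y) * z,
             lambda x, y, z: (x + y) / z,
             lambda x, y, z: (x - y) / z,
             lambda x, y, z: x * y / z]

best_template = 1           # e.g. argmax of _arithmetic_template_predictor logits
x, y, z = 45.0, 13.0, 1.0   # e.g. numbers chosen via the slot predictor
answer = templates[best_template](x, y, z)  # (45 - 13) * 1 = 32.0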
    def __init__(self,
                 vocab: Vocabulary,
                 experimental_mode: str,
                 bert_pretrained_model: str,
                 dropout_prob: float = 0.1,
                 entropy_loss_weight: float = 10.0,
                 max_count: int = 10,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 answering_abilities: List[str] = None,
                 number_rep: str = 'first',
                 arithmetic: str = 'base',
                 special_numbers : List[int] = None) -> None:
        super().__init__(vocab, regularizer)

        # Options:
        # 'original'
        # 'expected_count_per_sentence'
        self.experimental_mode = experimental_mode
        self.entropy_loss_weight = entropy_loss_weight

        if answering_abilities is None:
            self.answering_abilities = ["passage_span_extraction", "question_span_extraction",
                                        "arithmetic", "counting"]
        else:
            self.answering_abilities = answering_abilities
        self.number_rep = number_rep

        self.BERT = BertModel.from_pretrained(bert_pretrained_model)
        self.tokenizer = BertTokenizer.from_pretrained(bert_pretrained_model)
        bert_dim = self.BERT.pooler.dense.out_features

        self.dropout = dropout_prob

        self._passage_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._question_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._number_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._arithmetic_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._sentence_weights_predictor = torch.nn.Linear(bert_dim, 1)

        if len(self.answering_abilities) > 1:
            self._answer_ability_predictor = \
                self.ff(2 * bert_dim, bert_dim, len(self.answering_abilities))

        if "passage_span_extraction" in self.answering_abilities:
            self._passage_span_extraction_index = self.answering_abilities.index("passage_span_extraction")
            self._passage_span_start_predictor = torch.nn.Linear(bert_dim, 1)
            self._passage_span_end_predictor = torch.nn.Linear(bert_dim, 1)

        if "question_span_extraction" in self.answering_abilities:
            self._question_span_extraction_index = self.answering_abilities.index("question_span_extraction")
            self._question_span_start_predictor = \
                self.ff(2 * bert_dim, bert_dim, 1)
            self._question_span_end_predictor = \
                self.ff(2 * bert_dim, bert_dim, 1)

        if "arithmetic" in self.answering_abilities:
            self.arithmetic = arithmetic
            self._arithmetic_index = self.answering_abilities.index("arithmetic")
            self.special_numbers = special_numbers
            self.num_special_numbers = len(self.special_numbers)  # assumes special_numbers is provided when "arithmetic" is enabled
            self.special_embedding = torch.nn.Embedding(self.num_special_numbers, bert_dim)
            if self.arithmetic == "base":
                self._number_sign_predictor = \
                    self.ff(2 * bert_dim, bert_dim, 3)
            else:
                self.init_arithmetic(bert_dim, bert_dim, bert_dim, layers=2, dropout=dropout_prob)

        if "counting" in self.answering_abilities:
            self._counting_index = self.answering_abilities.index("counting")

            if self.experimental_mode == 'original':
                self._count_number_predictor = \
                    self.ff(bert_dim, bert_dim, max_count + 1)

            elif self.experimental_mode == 'expected_count_per_sentence':
                self._count_number_predictor = \
                    self.ff(2 * bert_dim, bert_dim, max_count + 1)
                self.count_classes = torch.arange(max_count + 1).float()

        self._drop_metrics = DropEmAndF1()
        initializer(self)
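
In the 'expected_count_per_sentence' branch, `count_classes` suggests the count is read out as an expectation over count-class probabilities rather than an argmax. A minimal, self-contained sketch of that readout (shapes and names here are assumptions):

import torch

max_count = 10
count_logits = torch.randn(1, max_count + 1)          # stand-in for the count head output
count_classes = torch.arange(max_count + 1).float()   # 0, 1, ..., max_count
probs = torch.softmax(count_logits, dim=-1)
expected_count = (probs * count_classes).sum(dim=-1)  # differentiable scalar estimate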