예제 #1
0
    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
        """Pool one layer of the encoder's hidden states into an embedding.

        The encoder is run with hidden-state output enabled, the layer
        selected by ``self.layer`` is picked out, and ``self.pooler_head``
        reduces it to the value returned under ``logits``.

        Args:
            batch: Input batch passed to the encoder; ``batch.input_mask``
                is additionally read when the pooler is a ``MeanPoolerHead``.
            task: Not referenced in this method.
            tokenizer: Not referenced in this method.
            compute_loss: When true, a placeholder loss is attached to the
                returned output.

        Returns:
            ``LogitsAndLossOutput`` if ``compute_loss`` is true, otherwise
            ``LogitsOutput``. Both carry the encoder's ``other`` payload.

        Raises:
            TypeError: If ``self.pooler_head`` is neither a
                ``MeanPoolerHead`` nor a ``FirstPoolerHead``.
        """
        encoder = self.encoder
        with transformer_utils.output_hidden_states_context(encoder):
            enc_out = get_output_from_encoder_and_batch(encoder=encoder,
                                                        batch=batch)

        # ``other`` wraps a tuple holding one hidden-state entry per layer.
        per_layer_states = take_one(enc_out.other)
        selected_layer = per_layer_states[self.layer]

        pooler = self.pooler_head
        if isinstance(pooler, heads.MeanPoolerHead):
            # Mean pooling also receives the input mask.
            pooled = pooler(unpooled=selected_layer,
                            input_mask=batch.input_mask)
        elif isinstance(pooler, heads.FirstPoolerHead):
            pooled = pooler(selected_layer)
        else:
            raise TypeError(type(pooler))

        # TODO: Abuse of notation - these aren't really logits  (issue #1187)
        if not compute_loss:
            return LogitsOutput(logits=pooled, other=enc_out.other)

        # TODO: make this optional?   (issue #1187)
        # The loss is a constant placeholder (a known hack): nothing is
        # actually computed here.
        placeholder_loss = torch.tensor([0.0])
        return LogitsAndLossOutput(
            logits=pooled,
            loss=placeholder_loss,
            other=enc_out.other,
        )
예제 #2
0
파일: ropes.py 프로젝트: v-mipeng/jiant
    def read_examples(self, path, set_type):
        """Load a ROPES-style JSON file and convert every QA pair to an Example.

        Args:
            path: Location of the JSON file; its top-level ``"data"`` value
                is reduced with ``take_one`` (presumably it must contain
                exactly one entry - confirm against ``take_one``).
            set_type: Phase identifier; compared against ``PHASE.TRAIN`` to
                decide how answers are recorded.

        Returns:
            A list of ``Example`` objects, one per question.

        Notes:
            For training examples only the first answer is used, with its
            character offset taken from ``str.find`` on the combined context
            (``-1`` if the text is absent) and ``answers`` left empty. For
            other phases ``answer_text`` and ``start_position_character``
            are ``None`` and every answer is listed in ``answers`` with its
            own ``str.find`` offset.
        """
        dataset = take_one(read_json(path, encoding="utf-8")["data"])
        training = set_type == PHASE.TRAIN

        collected = []
        for paragraph in maybe_tqdm(dataset["paragraphs"]):
            for qa in paragraph["qas"]:
                qas_id = qa["id"]

                # Answers may be drawn from the question too, so the whole
                # (background +) situation + question text is stored as the
                # "context" and the "question" slot is left as a single space.
                pieces = []
                if self.include_background:
                    pieces.append(paragraph["background"])
                pieces.append(paragraph["situation"])
                pieces.append(qa["question"])
                context = " ".join([piece.strip() for piece in pieces])

                if training:
                    gold_text = qa["answers"][0]["text"]
                    char_start = context.find(gold_text)
                    eval_answers = []
                else:
                    gold_text = None
                    char_start = None
                    eval_answers = []
                    for candidate in qa["answers"]:
                        eval_answers.append({
                            "text": candidate["text"],
                            "answer_start": context.find(candidate["text"]),
                        })

                collected.append(
                    Example(
                        qas_id=qas_id,
                        question_text=" ",
                        context_text=context,
                        answer_text=gold_text,
                        start_position_character=char_start,
                        title="",
                        is_impossible=False,
                        answers=eval_answers,
                        background_text=paragraph["background"],
                        situation_text=paragraph["situation"],
                    ))
        return collected
예제 #3
0
def test_take_one():
    """Check ``take_one`` on single-element inputs and on invalid sizes.

    Single-element containers must yield their only element (for a dict,
    its only key); empty or multi-element inputs must raise ``IndexError``.
    """
    # (input container, expected sole element)
    single_element_cases = [
        ([9], 9),
        ((9, ), 9),
        ({9}, 9),
        ("9", "9"),
        ({9: 10}, 9),
    ]
    for container, expected in single_element_cases:
        assert py_datastructures.take_one(container) == expected

    # Empty and multi-element inputs are rejected.
    wrong_size_inputs = ([], [1, 2], "2342134")
    for container in wrong_size_inputs:
        with pytest.raises(IndexError):
            py_datastructures.take_one(container)