def forward(self, batch, task, tokenizer, compute_loss: bool = False):
    """Pool one layer of encoder hidden states into a per-example output.

    The encoder is run with hidden-state output switched on, the layer
    indexed by ``self.layer`` is selected, and ``self.pooler_head`` reduces
    it. ``task`` and ``tokenizer`` are accepted but not used here.

    Args:
        batch: Input batch; forwarded to the encoder. ``batch.input_mask``
            is additionally read when the pooler is a ``MeanPoolerHead``.
        task: Unused in this method.
        tokenizer: Unused in this method.
        compute_loss: When true, the result also carries a placeholder
            zero loss.

    Returns:
        ``LogitsAndLossOutput`` when ``compute_loss`` is true, otherwise
        ``LogitsOutput``. In both cases ``other`` is the encoder's ``other``
        value passed through unchanged.

    Raises:
        TypeError: If ``self.pooler_head`` is neither a ``MeanPoolerHead``
            nor a ``FirstPoolerHead``.
    """
    encoder = self.encoder
    with transformer_utils.output_hidden_states_context(encoder):
        enc_out = get_output_from_encoder_and_batch(encoder=encoder, batch=batch)

    # ``other`` is expected to wrap exactly one item: a tuple holding the
    # hidden states of every layer.
    per_layer_states = take_one(enc_out.other)
    selected_states = per_layer_states[self.layer]

    pooler = self.pooler_head
    if isinstance(pooler, heads.MeanPoolerHead):
        # Mean pooling is the only variant that is handed the input mask.
        pooled = pooler(unpooled=selected_states, input_mask=batch.input_mask)
    elif isinstance(pooler, heads.FirstPoolerHead):
        pooled = pooler(selected_states)
    else:
        raise TypeError(type(pooler))

    # TODO: Abuse of notation - these aren't really logits  (issue #1187)
    if not compute_loss:
        return LogitsOutput(logits=pooled, other=enc_out.other)

    # TODO: make this optional?  (issue #1187)
    # The zero loss is a known hack: nothing is actually optimised here, the
    # value only fills the ``loss`` field of the output.
    placeholder_loss = torch.tensor([0.0])
    return LogitsAndLossOutput(
        logits=pooled,
        loss=placeholder_loss,
        other=enc_out.other,
    )
def read_examples(self, path, set_type):
    """Load a JSON data file and turn every question into an ``Example``.

    The file's top-level ``"data"`` entry must contain exactly one item
    (enforced through ``take_one``); its ``"paragraphs"`` are walked and one
    ``Example`` is produced per entry of each paragraph's ``"qas"`` list.

    Args:
        path: Location of the JSON file, read as UTF-8.
        set_type: Phase being loaded. ``PHASE.TRAIN`` selects the training
            layout (single gold answer); anything else selects the
            evaluation layout (list of reference answers).

    Returns:
        A list of ``Example`` objects in file order.
    """
    raw_entries = read_json(path, encoding="utf-8")["data"]
    for_training = set_type == PHASE.TRAIN
    dataset = take_one(raw_entries)

    collected = []
    for para in maybe_tqdm(dataset["paragraphs"]):
        for qa in para["qas"]:
            qas_id = qa["id"]

            # Because answers can also come from questions, we abuse notation
            # slightly: the whole background+situation+question goes into the
            # "context", and the "question" is left as a single blank.
            blank_question = " "
            if self.include_background:
                pieces = [para["background"], para["situation"], qa["question"]]
            else:
                pieces = [para["situation"], qa["question"]]
            context = " ".join([piece.strip() for piece in pieces])

            # Defaults below are the evaluation-time values for the gold-answer
            # fields and the training-time value for the reference list.
            start_char = None
            gold_text = None
            references = []
            if for_training:
                # Only the first listed answer is used for training.
                gold = qa["answers"][0]
                # str.find yields -1 when the answer text is not in the context.
                start_char = context.find(gold["text"])
                gold_text = gold["text"]
            else:
                references = [
                    {
                        "text": candidate["text"],
                        "answer_start": context.find(candidate["text"]),
                    }
                    for candidate in qa["answers"]
                ]

            # background/situation are always stored on the example, even when
            # the background was left out of the context above.
            collected.append(
                Example(
                    qas_id=qas_id,
                    question_text=blank_question,
                    context_text=context,
                    answer_text=gold_text,
                    start_position_character=start_char,
                    title="",
                    is_impossible=False,
                    answers=references,
                    background_text=para["background"],
                    situation_text=para["situation"],
                )
            )
    return collected
def test_take_one():
    """Check ``take_one`` on single-element containers and on bad sizes."""
    take_one = py_datastructures.take_one

    # (container, sole element) pairs; for the dict the sole element is its key.
    single_item_cases = [
        ([9], 9),
        ((9,), 9),
        ({9}, 9),
        ("9", "9"),
        ({9: 10}, 9),
    ]
    for container, expected in single_item_cases:
        assert take_one(container) == expected

    # Empty input and inputs with more than one element must raise IndexError.
    wrong_size_inputs = ([], [1, 2], "2342134")
    for container in wrong_size_inputs:
        with pytest.raises(IndexError):
            take_one(container)