Example #1
import json
import logging
import os

import torch
from transformers import AlbertConfig, AlbertForQuestionAnswering, AutoTokenizer

# CONFIG_JSON_FILE, WEIGHTS_FILE and albert_args_squad are assumed to be defined
# elsewhere in the original module (file names and default SQuAD encoding arguments).
logger = logging.getLogger(__name__)


class AlbertQA:
    """ Class that uses Albert to answer questions.
    TODO: Update model and checkpoints to work with the latest versions of transformers """

    def __init__(self, path: str, device: str = 'cpu'):
        """ Init the QA Albert """
        if not os.path.exists(path):
            raise NotADirectoryError(
                f"{os.path.abspath(path)} must be a directory containing the model files: config, tokenizer, weights.")

        files = os.listdir(path)
        if CONFIG_JSON_FILE not in files:
            raise FileNotFoundError(f"{CONFIG_JSON_FILE} must be in {path}.")
        if WEIGHTS_FILE not in files:
            raise FileNotFoundError(f"{WEIGHTS_FILE} must be in {path}.")

        with open(os.path.join(path, CONFIG_JSON_FILE), "r") as f:
            config = json.load(f)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        weights = torch.load(os.path.join(path, WEIGHTS_FILE),
                             map_location=lambda storage, loc: storage)
        # Build the model from the saved config and load the fine-tuned weights
        config = AlbertConfig.from_dict(config)
        self.model = AlbertForQuestionAnswering(config)
        self.model.load_state_dict(weights)
        self.model = self.model.eval()
        self.args = albert_args_squad
        if device == "cuda":
            logger.debug("Setting model with CUDA")
            self.args['device'] = 'cuda'
            self.model.to('cuda')

    def answer(self, question: str, context: str, **kwargs: dict) -> str:
        """ Look up the answer to a question within the given context.

        Any keyword argument whose name matches a key in self.args overrides that setting.

        :param question: Question to answer
        :param context: Context to search for the answer in
        :return: Answer to the question
        """
        for key in kwargs:
            if key in self.args:
                self.args[key] = kwargs[key]
        inputs = self.tokenizer.encode_plus(question, context, **self.args)
        for key in inputs.keys():
            inputs[key] = inputs[key].to(self.args['device'])
        input_ids = inputs["input_ids"].tolist()[0]

        # NOTE: recent transformers versions return a QuestionAnsweringModelOutput
        # (use outputs.start_logits / outputs.end_logits) instead of a tuple; the
        # tuple unpacking below assumes the older API, per the class-level TODO.
        answer_start_scores, answer_end_scores = self.model(**inputs)

        answer_start = torch.argmax(answer_start_scores)  # Most likely start index of the answer
        answer_end = torch.argmax(answer_end_scores) + 1  # Most likely end index (exclusive)

        answer = self.tokenizer.convert_tokens_to_string(
            self.tokenizer.convert_ids_to_tokens(
                input_ids[answer_start:answer_end]
            )
        )
        answer = answer.replace("[CLS]", "").replace("[SEP]", " ").replace("<s>", "").replace("</s>", "")
        return answer
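A minimal usage sketch of the class above (not part of the original listing; the checkpoint directory name is hypothetical and must contain the config, tokenizer and weights files checked for in __init__):

# Hypothetical local checkpoint directory with the expected config, tokenizer and weights files.
qa = AlbertQA("albert-qa-checkpoint", device="cpu")
context = ("ALBERT is a lighter variant of BERT that shares parameters "
           "across layers to reduce model size.")
print(qa.answer("What does ALBERT share across layers?", context))
# Prints the answer span extracted from the context (e.g. "parameters").
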
Example #2
def create_and_check_for_question_answering(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    # Fragment of a model-tester class: builds AlbertForQuestionAnswering and
    # verifies the shape of the start/end logits against batch size and sequence length.
    model = AlbertForQuestionAnswering(config=config)
    model.to(torch_device)
    model.eval()
    result = model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        start_positions=sequence_labels,
        end_positions=sequence_labels,
    )
    self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
    self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
Example #3
def create_and_check_albert_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask,
        sequence_labels, token_labels, choice_labels):
    model = AlbertForQuestionAnswering(config=config)
    model.to(torch_device)
    model.eval()
    result = model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        start_positions=sequence_labels,
        end_positions=sequence_labels,
    )
    self.parent.assertListEqual(list(result["start_logits"].size()),
                                [self.batch_size, self.seq_length])
    self.parent.assertListEqual(list(result["end_logits"].size()),
                                [self.batch_size, self.seq_length])
    self.check_loss_output(result)