def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): model = XLMForQuestionAnswering(config) model.eval() outputs = model(input_ids) start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels, cls_index=sequence_labels, is_impossible=is_impossible_labels, p_mask=input_mask) outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels, cls_index=sequence_labels, is_impossible=is_impossible_labels) (total_loss, ) = outputs outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels) (total_loss, ) = outputs result = { "loss": total_loss, "start_top_log_probs": start_top_log_probs, "start_top_index": start_top_index, "end_top_log_probs": end_top_log_probs, "end_top_index": end_top_index, "cls_logits": cls_logits, } self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual( list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]) self.parent.assertListEqual( list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]) self.parent.assertListEqual( list(result["end_top_log_probs"].size()), [ self.batch_size, model.config.start_n_top * model.config.end_n_top ]) self.parent.assertListEqual(list(result["end_top_index"].size()), [ self.batch_size, model.config.start_n_top * model.config.end_n_top ]) self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
def create_and_check_xlm_qa( self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, choice_labels, input_mask, ): model = XLMForQuestionAnswering(config) model.to(torch_device) model.eval() result = model(input_ids) result_with_labels = model( input_ids, start_positions=sequence_labels, end_positions=sequence_labels, cls_index=sequence_labels, is_impossible=is_impossible_labels, p_mask=input_mask, ) result_with_labels = model( input_ids, start_positions=sequence_labels, end_positions=sequence_labels, cls_index=sequence_labels, is_impossible=is_impossible_labels, ) (total_loss, ) = result_with_labels.to_tuple() result_with_labels = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels) (total_loss, ) = result_with_labels.to_tuple() self.parent.assertEqual(result_with_labels.loss.shape, ()) self.parent.assertEqual(result.start_top_log_probs.shape, (self.batch_size, model.config.start_n_top)) self.parent.assertEqual(result.start_top_index.shape, (self.batch_size, model.config.start_n_top)) self.parent.assertEqual(result.end_top_log_probs.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)) self.parent.assertEqual(result.end_top_index.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)) self.parent.assertEqual(result.cls_logits.shape, (self.batch_size, ))