def create_and_check_flaubert_qa(
    self,
    config,
    input_ids,
    token_type_ids,
    input_lengths,
    sequence_labels,
    token_labels,
    is_impossible_labels,
    choice_labels,
    input_mask,
):
    """Check ``FlaubertForQuestionAnswering`` outputs with and without labels.

    Without labels, the test expects the top-k outputs to have these shapes:

    * ``start_top_log_probs`` / ``start_top_index``:
      ``(batch_size, start_n_top)``
    * ``end_top_log_probs`` / ``end_top_index``:
      ``(batch_size, start_n_top * end_n_top)``
    * ``cls_logits``: ``(batch_size,)``

    With labels, each of the three labelled forward passes must return
    exactly one non-``None`` output, and its ``loss`` must be a scalar. The
    three passes are: with ``p_mask``, with answerability labels, and with
    only start/end positions.

    ``token_type_ids``, ``input_lengths``, ``token_labels`` and
    ``choice_labels`` are accepted but unused here (presumably so every
    checker can be called with the same prepared inputs -- confirm against
    the caller).
    """
    model = FlaubertForQuestionAnswering(config)
    model.to(torch_device)
    model.eval()

    result = model(input_ids)

    # Every labelled variant is checked, including the p_mask one, so that no
    # forward pass runs without its output being verified.
    labelled_kwargs = (
        dict(
            start_positions=sequence_labels,
            end_positions=sequence_labels,
            cls_index=sequence_labels,
            is_impossible=is_impossible_labels,
            p_mask=input_mask,
        ),
        dict(
            start_positions=sequence_labels,
            end_positions=sequence_labels,
            cls_index=sequence_labels,
            is_impossible=is_impossible_labels,
        ),
        dict(start_positions=sequence_labels, end_positions=sequence_labels),
    )
    for kwargs in labelled_kwargs:
        result_with_labels = model(input_ids, **kwargs)
        # With labels, the only non-None field in the output must be the loss.
        self.parent.assertEqual(len(result_with_labels.to_tuple()), 1)
        self.parent.assertEqual(result_with_labels.loss.shape, ())

    batch_size = self.batch_size
    start_n_top = model.config.start_n_top
    end_n_top = model.config.end_n_top
    self.parent.assertEqual(result.start_top_log_probs.shape, (batch_size, start_n_top))
    self.parent.assertEqual(result.start_top_index.shape, (batch_size, start_n_top))
    self.parent.assertEqual(result.end_top_log_probs.shape, (batch_size, start_n_top * end_n_top))
    self.parent.assertEqual(result.end_top_index.shape, (batch_size, start_n_top * end_n_top))
    self.parent.assertEqual(result.cls_logits.shape, (batch_size,))
def create_and_check_flaubert_qa(
    self,
    config,
    input_ids,
    token_type_ids,
    input_lengths,
    sequence_labels,
    token_labels,
    is_impossible_labels,
    input_mask,
):
    """Check ``FlaubertForQuestionAnswering`` tuple outputs with and without labels.

    Without labels, the model output is unpacked into five tensors, and the
    test expects these sizes:

    * ``start_top_log_probs`` / ``start_top_index``:
      ``[batch_size, start_n_top]``
    * ``end_top_log_probs`` / ``end_top_index``:
      ``[batch_size, start_n_top * end_n_top]``
    * ``cls_logits``: ``[batch_size]``

    With labels, each of the three labelled forward passes must return a
    1-tuple whose single element (the loss) is a scalar. The three passes
    are: with ``p_mask``, with answerability labels, and with only
    start/end positions.

    ``token_type_ids``, ``input_lengths`` and ``token_labels`` are accepted
    but unused here (presumably so every checker can be called with the same
    prepared inputs -- confirm against the caller).
    """
    model = FlaubertForQuestionAnswering(config)
    model.to(torch_device)
    model.eval()

    # Without labels the model returns exactly these five outputs, in order.
    (
        start_top_log_probs,
        start_top_index,
        end_top_log_probs,
        end_top_index,
        cls_logits,
    ) = model(input_ids)

    # Every labelled variant is checked, including the p_mask one, so that no
    # forward pass runs without its output being verified.
    labelled_kwargs = (
        dict(
            start_positions=sequence_labels,
            end_positions=sequence_labels,
            cls_index=sequence_labels,
            is_impossible=is_impossible_labels,
            p_mask=input_mask,
        ),
        dict(
            start_positions=sequence_labels,
            end_positions=sequence_labels,
            cls_index=sequence_labels,
            is_impossible=is_impossible_labels,
        ),
        dict(start_positions=sequence_labels, end_positions=sequence_labels),
    )
    for kwargs in labelled_kwargs:
        outputs = model(input_ids, **kwargs)
        # With labels, the output must be a 1-tuple holding only the loss.
        self.parent.assertEqual(len(outputs), 1)
        self.parent.assertListEqual(list(outputs[0].size()), [])

    batch_size = self.batch_size
    start_n_top = model.config.start_n_top
    end_n_top = model.config.end_n_top
    self.parent.assertListEqual(list(start_top_log_probs.size()), [batch_size, start_n_top])
    self.parent.assertListEqual(list(start_top_index.size()), [batch_size, start_n_top])
    self.parent.assertListEqual(list(end_top_log_probs.size()), [batch_size, start_n_top * end_n_top])
    self.parent.assertListEqual(list(end_top_index.size()), [batch_size, start_n_top * end_n_top])
    self.parent.assertListEqual(list(cls_logits.size()), [batch_size])