def test(self, input_filepath, args=QATestArgs()):
    """
    Tests the question answering model. Used to obtain results

    input_filepath: a string that contains the location of a csv file
        to run the model on. Contains the following header values:
        context, question
    args: Either a QATestArgs() object or a dictionary that contains all
        of the same keys as ARGS_QA_TEST. Subclasses of either type are
        accepted as well.
    return: The value returned by self._trainer.test(). Documented as a
        list of dictionaries, each containing the keys: "score", "start",
        "end" and "answer"
    raises: ValueError if args is neither a dictionary nor a QATestArgs
        object
    """
    # NOTE: the default QATestArgs() instance is built once, when the
    # method is defined, and is shared by every call that omits `args`.
    if isinstance(args, dict):
        # Dictionary settings are converted into a QATestArgs dataclass.
        method_dataclass_args = create_args_dataclass(
            default_dic_args=ARGS_QA_TEST,
            input_dic_args=args,
            method_dataclass_args=QATestArgs)
    elif isinstance(args, QATestArgs):
        method_dataclass_args = args
    else:
        raise ValueError(
            "Invalid args type. Use a QATestArgs object or a dictionary")

    return self._trainer.test(input_filepath=input_filepath,
                              solve=self.answer_question,
                              dataclass_args=method_dataclass_args)
def train(self, input_filepath, args=QATrainArgs()):
    """
    Trains the question answering model

    input_filepath: a string that contains the location of a csv file
        for training. Contains the following header values: context,
        question, answer_text, answer_start
    args: Either a QATrainArgs() object or a dictionary that contains all
        of the same keys as ARGS_QA_TRAIN. Subclasses of either type are
        accepted as well.
    return: None
    raises: ValueError if args is neither a dictionary nor a QATrainArgs
        object
    """
    # NOTE: the default QATrainArgs() instance is built once, when the
    # method is defined, and is shared by every call that omits `args`.
    if isinstance(args, dict):
        # Dictionary settings are converted into a QATrainArgs dataclass.
        method_dataclass_args = create_args_dataclass(
            default_dic_args=ARGS_QA_TRAIN,
            input_dic_args=args,
            method_dataclass_args=QATrainArgs)
    elif isinstance(args, QATrainArgs):
        method_dataclass_args = args
    else:
        raise ValueError(
            "Invalid args type. Use a QATrainArgs object or a dictionary")

    self._trainer.train(input_filepath=input_filepath,
                        dataclass_args=method_dataclass_args)
def eval(self, input_filepath, args=QAEvalArgs()) -> EvalResult:
    """
    Evaluates the question answering model

    input_filepath: a string that contains the location of a csv file
        for evaluating. Contains the following header values: context,
        question, answer_text, answer_start
    args: Either a QAEvalArgs() object or a dictionary that contains all
        of the same keys as ARGS_QA_EVAl. Subclasses of either type are
        accepted as well.
    return: The EvalResult returned by self._trainer.eval()
        (presumably carries the evaluation loss; confirm against the
        EvalResult definition)
    raises: ValueError if args is neither a dictionary nor a QAEvalArgs
        object
    """
    # NOTE: the default QAEvalArgs() instance is built once, when the
    # method is defined, and is shared by every call that omits `args`.
    if isinstance(args, dict):
        # Dictionary settings are converted into a QAEvalArgs dataclass.
        method_dataclass_args = create_args_dataclass(
            default_dic_args=ARGS_QA_EVAl,
            input_dic_args=args,
            method_dataclass_args=QAEvalArgs)
    elif isinstance(args, QAEvalArgs):
        method_dataclass_args = args
    else:
        raise ValueError(
            "Invalid args type. Use a QAEvalArgs object or a dictionary")

    return self._trainer.eval(input_filepath=input_filepath,
                              dataclass_args=method_dataclass_args)
def train(self, input_filepath, args=ARGS_WP_TRAIN):
    """
    Trains the model by delegating to self._trainer.train()

    input_filepath: the location of the training data; passed through
        unchanged to the trainer
    args: Either a WPTrainArgs object or a dictionary that is merged with
        the ARGS_WP_TRAIN defaults by create_args_dataclass(). Subclasses
        of either type are accepted as well. Defaults to ARGS_WP_TRAIN
        itself.
    return: None
    raises: ValueError if args is neither a dictionary nor a WPTrainArgs
        object
    """
    if isinstance(args, dict):
        # Dictionary settings are converted into a WPTrainArgs dataclass.
        method_dataclass_args = create_args_dataclass(
            default_dic_args=ARGS_WP_TRAIN,
            input_dic_args=args,
            method_dataclass_args=WPTrainArgs)
    elif isinstance(args, WPTrainArgs):
        method_dataclass_args = args
    else:
        raise ValueError(
            "Invalid args type. Use a WPTrainArgs object or a dictionary")

    self._trainer.train(input_filepath=input_filepath,
                        dataclass_args=method_dataclass_args)
def eval(self, input_filepath, args=ARGS_WP_EVAl) -> EvalResult:
    """
    Evaluates the model by delegating to self._trainer.eval()

    input_filepath: the location of the evaluation data; passed through
        unchanged to the trainer
    args: Either a WPEvalArgs object or a dictionary that is merged with
        the ARGS_WP_EVAl defaults by create_args_dataclass(). Subclasses
        of either type are accepted as well. Defaults to ARGS_WP_EVAl
        itself.
    return: The EvalResult returned by self._trainer.eval()
    raises: ValueError if args is neither a dictionary nor a WPEvalArgs
        object
    """
    if isinstance(args, dict):
        # Dictionary settings are converted into a WPEvalArgs dataclass.
        method_dataclass_args = create_args_dataclass(
            default_dic_args=ARGS_WP_EVAl,
            input_dic_args=args,
            method_dataclass_args=WPEvalArgs)
    elif isinstance(args, WPEvalArgs):
        method_dataclass_args = args
    else:
        # The message names the accepted dataclass type (WPEvalArgs),
        # not the ARGS_WP_EVAl defaults constant.
        raise ValueError(
            "Invalid args type. Use a WPEvalArgs object or a dictionary")

    return self._trainer.eval(input_filepath=input_filepath,
                              dataclass_args=method_dataclass_args)
def eval(self, input_filepath: str, args=GENEvalArgs()) -> EvalResult:
    """
    Evaluates the model by delegating to self._trainer.eval()

    :param input_filepath: a file path to a text file that contains
        nothing but evaluating data
    :param args: either a GENEvalArgs() object or a dictionary that
        contains all of the same keys as ARGS_GEN_EVAl. Subclasses of
        either type are accepted as well.
    :return: the EvalResult returned by self._trainer.eval()
    :raises ValueError: if args is neither a dictionary nor a
        GENEvalArgs object
    """
    # NOTE: the default GENEvalArgs() instance is built once, when the
    # method is defined, and is shared by every call that omits `args`.
    if isinstance(args, dict):
        # Dictionary settings are converted into a GENEvalArgs dataclass.
        method_dataclass_args = create_args_dataclass(
            default_dic_args=ARGS_GEN_EVAl,
            input_dic_args=args,
            method_dataclass_args=GENEvalArgs)
    elif isinstance(args, GENEvalArgs):
        method_dataclass_args = args
    else:
        raise ValueError(
            "Invalid args type. Use a GENEvalArgs object or a dictionary")

    return self._trainer.eval(input_filepath=input_filepath,
                              dataclass_args=method_dataclass_args)
def train(self, input_filepath: str, args=GENTrainArgs()):
    """
    Trains the model by delegating to self._trainer.train()

    :param input_filepath: a file path to a text file that contains
        nothing but training data
    :param args: either a GENTrainArgs() object or a dictionary that
        contains all of the same keys as ARGS_GEN_TRAIN. Subclasses of
        either type are accepted as well.
    :return: None
    :raises ValueError: if args is neither a dictionary nor a
        GENTrainArgs object
    """
    # NOTE: the default GENTrainArgs() instance is built once, when the
    # method is defined, and is shared by every call that omits `args`.
    if isinstance(args, dict):
        # Dictionary settings are converted into a GENTrainArgs dataclass.
        method_dataclass_args = create_args_dataclass(
            default_dic_args=ARGS_GEN_TRAIN,
            input_dic_args=args,
            method_dataclass_args=GENTrainArgs)
    elif isinstance(args, GENTrainArgs):
        method_dataclass_args = args
    else:
        raise ValueError(
            "Invalid args type. Use a GENTrainArgs object or a dictionary")

    self._trainer.train(input_filepath=input_filepath,
                        dataclass_args=method_dataclass_args)
def eval(self, input_filepath, args=TCEvalArgs()) -> EvalResult:
    """
    Evaluates the text classification model

    input_filepath: a string that contains the location of a csv file
        for evaluating. Contains the following header values: text, label
    args: Either a TCEvalArgs() object or a dictionary that is merged
        with the ARGS_TC_EVAL defaults by create_args_dataclass().
        Subclasses of either type are accepted as well.
    return: an EvalResult() object
    raises: ValueError if args is neither a dictionary nor a TCEvalArgs
        object
    """
    # NOTE: the default TCEvalArgs() instance is built once, when the
    # method is defined, and is shared by every call that omits `args`.
    if isinstance(args, dict):
        # Dictionary settings are converted into a TCEvalArgs dataclass.
        method_dataclass_args = create_args_dataclass(
            default_dic_args=ARGS_TC_EVAL,
            input_dic_args=args,
            method_dataclass_args=TCEvalArgs)
    elif isinstance(args, TCEvalArgs):
        method_dataclass_args = args
    else:
        raise ValueError(
            "Invalid args type. Use a TCEvalArgs object or a dictionary")

    return self._trainer.eval(input_filepath=input_filepath,
                              dataclass_args=method_dataclass_args)
def test(self, input_filepath, args=TCTestArgs()):
    """
    Tests the text classification model. Used to obtain results

    input_filepath: a string that contains the location of a csv file
        to run the model on. Contains the following header value: text
    args: Either a TCTestArgs() object or a dictionary that is merged
        with the ARGS_TC_TEST defaults by create_args_dataclass().
        Subclasses of either type are accepted as well.
    return: The value returned by self._trainer.test(). Documented as a
        list of TextClassificationResult() objects
    raises: ValueError if args is neither a dictionary nor a TCTestArgs
        object
    """
    # NOTE: the default TCTestArgs() instance is built once, when the
    # method is defined, and is shared by every call that omits `args`.
    if isinstance(args, dict):
        # Dictionary settings are converted into a TCTestArgs dataclass.
        method_dataclass_args = create_args_dataclass(
            default_dic_args=ARGS_TC_TEST,
            input_dic_args=args,
            method_dataclass_args=TCTestArgs)
    elif isinstance(args, TCTestArgs):
        method_dataclass_args = args
    else:
        raise ValueError(
            "Invalid args type. Use a TCTestArgs() object or a dictionary")

    return self._trainer.test(input_filepath=input_filepath,
                              solve=self.classify_text,
                              dataclass_args=method_dataclass_args)