示例#1
0
    def test(self, input_filepath, args=None):
        """
        Tests the question answering model. Used to obtain results

        input_filepath: a string that contains the location of a csv file
        for testing. Contains the following header values:
        context, question

        args: Either a QATestArgs() object or a dictionary that contains all of
        the same keys as ARGS_QA_TEST. When omitted (or None), a fresh
        QATestArgs() with its default values is used.

        return: A list of dictionaries. Each dictionary
        contains the keys: "score", "start", "end" and "answer"

        raises: ValueError if args is neither a dictionary nor a QATestArgs object
        """
        # Build the default per call: an instance written into the signature is
        # created once at definition time and the same object is shared by every call.
        if args is None:
            args = QATestArgs()

        # isinstance (instead of an exact type() comparison) also accepts
        # subclasses such as OrderedDict or a subclass of QATestArgs.
        if isinstance(args, dict):
            method_dataclass_args = create_args_dataclass(
                default_dic_args=ARGS_QA_TEST,
                input_dic_args=args,
                method_dataclass_args=QATestArgs)
        elif isinstance(args, QATestArgs):
            method_dataclass_args = args
        else:
            raise ValueError(
                "Invalid args type. Use a QATestArgs object or a dictionary")

        return self._trainer.test(input_filepath=input_filepath,
                                  solve=self.answer_question,
                                  dataclass_args=method_dataclass_args)
示例#2
0
    def train(self, input_filepath, args=None):
        """
        Trains the question answering model

        input_filepath: a string that contains the location of a csv file
        for training. Contains the following header values: context,
        question, answer_text, answer_start

        args: Either a QATrainArgs() object or a dictionary that contains all of
        the same keys as ARGS_QA_TRAIN. When omitted (or None), a fresh
        QATrainArgs() with its default values is used.

        return: None

        raises: ValueError if args is neither a dictionary nor a QATrainArgs object
        """
        # Build the default per call: an instance written into the signature is
        # created once at definition time and the same object is shared by every call.
        if args is None:
            args = QATrainArgs()

        # isinstance (instead of an exact type() comparison) also accepts
        # subclasses such as OrderedDict or a subclass of QATrainArgs.
        if isinstance(args, dict):
            method_dataclass_args = create_args_dataclass(
                default_dic_args=ARGS_QA_TRAIN,
                input_dic_args=args,
                method_dataclass_args=QATrainArgs)
        elif isinstance(args, QATrainArgs):
            method_dataclass_args = args
        else:
            raise ValueError(
                "Invalid args type. Use a QATrainArgs object or a dictionary")

        self._trainer.train(input_filepath=input_filepath,
                            dataclass_args=method_dataclass_args)
示例#3
0
    def eval(self, input_filepath, args=None) -> EvalResult:
        """
        Evaluates the question answering model

        input_filepath: a string that contains the location of a csv file
        for evaluation. Contains the following header values:
        context, question, answer_text, answer_start

        args: Either a QAEvalArgs() object or a dictionary that contains all of
        the same keys as ARGS_QA_EVAl. When omitted (or None), a fresh
        QAEvalArgs() with its default values is used.

        return: An EvalResult object, as returned by the trainer's eval()

        raises: ValueError if args is neither a dictionary nor a QAEvalArgs object
        """
        # Build the default per call: an instance written into the signature is
        # created once at definition time and the same object is shared by every call.
        if args is None:
            args = QAEvalArgs()

        # isinstance (instead of an exact type() comparison) also accepts
        # subclasses such as OrderedDict or a subclass of QAEvalArgs.
        if isinstance(args, dict):
            method_dataclass_args = create_args_dataclass(
                default_dic_args=ARGS_QA_EVAl,
                input_dic_args=args,
                method_dataclass_args=QAEvalArgs)
        elif isinstance(args, QAEvalArgs):
            method_dataclass_args = args
        else:
            raise ValueError(
                "Invalid args type. Use a QAEvalArgs object or a dictionary")

        return self._trainer.eval(input_filepath=input_filepath,
                                  dataclass_args=method_dataclass_args)
示例#4
0
    def train(self, input_filepath, args=ARGS_WP_TRAIN):
        """
        Trains the model on the data located at input_filepath.

        input_filepath: a string that contains the location of the training
        data file (its expected format is defined by the trainer -- TODO
        document it here)

        args: Either a WPTrainArgs object or a dictionary that contains the
        same keys as ARGS_WP_TRAIN. Defaults to the ARGS_WP_TRAIN dictionary.
        A dictionary is merged over ARGS_WP_TRAIN via create_args_dataclass.

        return: None

        raises: ValueError if args is neither a dictionary nor a WPTrainArgs object
        """
        # isinstance (instead of an exact type() comparison) also accepts
        # subclasses such as OrderedDict or a subclass of WPTrainArgs.
        if isinstance(args, dict):
            method_dataclass_args = create_args_dataclass(
                default_dic_args=ARGS_WP_TRAIN,
                input_dic_args=args,
                method_dataclass_args=WPTrainArgs)
        elif isinstance(args, WPTrainArgs):
            method_dataclass_args = args
        else:
            raise ValueError(
                "Invalid args type. Use a WPTrainArgs object or a dictionary")

        self._trainer.train(input_filepath=input_filepath,
                            dataclass_args=method_dataclass_args)
示例#5
0
    def eval(self, input_filepath, args=ARGS_WP_EVAl) -> EvalResult:
        """
        Evaluates the model on the data located at input_filepath.

        input_filepath: a string that contains the location of the evaluation
        data file (its expected format is defined by the trainer -- TODO
        document it here)

        args: Either a WPEvalArgs object or a dictionary that contains the
        same keys as ARGS_WP_EVAl. Defaults to the ARGS_WP_EVAl dictionary.
        A dictionary is merged over ARGS_WP_EVAl via create_args_dataclass.

        return: An EvalResult object, as returned by the trainer's eval()

        raises: ValueError if args is neither a dictionary nor a WPEvalArgs object
        """
        # isinstance (instead of an exact type() comparison) also accepts
        # subclasses such as OrderedDict or a subclass of WPEvalArgs.
        if isinstance(args, dict):
            method_dataclass_args = create_args_dataclass(
                default_dic_args=ARGS_WP_EVAl,
                input_dic_args=args,
                method_dataclass_args=WPEvalArgs)
        elif isinstance(args, WPEvalArgs):
            method_dataclass_args = args
        else:
            # The accepted object type is the WPEvalArgs class; ARGS_WP_EVAl is a dict.
            raise ValueError(
                "Invalid args type. Use a WPEvalArgs object or a dictionary")

        return self._trainer.eval(input_filepath=input_filepath,
                                  dataclass_args=method_dataclass_args)
示例#6
0
    def eval(self, input_filepath: str, args=None) -> EvalResult:
        """
        Evaluates the text generation model.

        :param input_filepath: a file path to a text file that contains nothing but evaluating data
        :param args: either a GENEvalArgs() object or a dictionary that contains all of the same
            keys as ARGS_GEN_EVAl. When omitted (or None), a fresh GENEvalArgs() is used.
        :return: an EvalResult object, as returned by the trainer's eval()
        :raises ValueError: if args is neither a dictionary nor a GENEvalArgs object
        """
        # Build the default per call: an instance written into the signature is
        # created once at definition time and the same object is shared by every call.
        if args is None:
            args = GENEvalArgs()

        # isinstance (instead of an exact type() comparison) also accepts
        # subclasses such as OrderedDict or a subclass of GENEvalArgs.
        if isinstance(args, dict):
            method_dataclass_args = create_args_dataclass(
                default_dic_args=ARGS_GEN_EVAl,
                input_dic_args=args,
                method_dataclass_args=GENEvalArgs)
        elif isinstance(args, GENEvalArgs):
            method_dataclass_args = args
        else:
            raise ValueError(
                "Invalid args type. Use a GENEvalArgs object or a dictionary")

        return self._trainer.eval(input_filepath=input_filepath,
                                  dataclass_args=method_dataclass_args)
示例#7
0
    def train(self, input_filepath: str, args=None):
        """
        Trains the text generation model.

        :param input_filepath: a file path to a text file that contains nothing but training data
        :param args: either a GENTrainArgs() object or a dictionary that contains all of the same
            keys as ARGS_GEN_TRAIN. When omitted (or None), a fresh GENTrainArgs() is used.
        :return: None
        :raises ValueError: if args is neither a dictionary nor a GENTrainArgs object
        """
        # Build the default per call: an instance written into the signature is
        # created once at definition time and the same object is shared by every call.
        if args is None:
            args = GENTrainArgs()

        # isinstance (instead of an exact type() comparison) also accepts
        # subclasses such as OrderedDict or a subclass of GENTrainArgs.
        if isinstance(args, dict):
            method_dataclass_args = create_args_dataclass(
                default_dic_args=ARGS_GEN_TRAIN,
                input_dic_args=args,
                method_dataclass_args=GENTrainArgs)
        elif isinstance(args, GENTrainArgs):
            method_dataclass_args = args
        else:
            raise ValueError(
                "Invalid args type. Use a GENTrainArgs object or a dictionary")

        self._trainer.train(input_filepath=input_filepath,
                            dataclass_args=method_dataclass_args)
示例#8
0
    def eval(self, input_filepath, args=None) -> EvalResult:
        """
        Evaluates the text classification model

        input_filepath: a string that contains the location of a csv file
        for evaluation. Contains the following header values:
        text, label

        args: Either a TCEvalArgs() object or a dictionary that contains all of
        the same keys as ARGS_TC_EVAL. When omitted (or None), a fresh
        TCEvalArgs() with its default values is used.

        return: an EvalResult() object

        raises: ValueError if args is neither a dictionary nor a TCEvalArgs object
        """
        # Build the default per call: an instance written into the signature is
        # created once at definition time and the same object is shared by every call.
        if args is None:
            args = TCEvalArgs()

        # isinstance (instead of an exact type() comparison) also accepts
        # subclasses such as OrderedDict or a subclass of TCEvalArgs.
        if isinstance(args, dict):
            method_dataclass_args = create_args_dataclass(
                default_dic_args=ARGS_TC_EVAL,
                input_dic_args=args,
                method_dataclass_args=TCEvalArgs)
        elif isinstance(args, TCEvalArgs):
            method_dataclass_args = args
        else:
            raise ValueError(
                "Invalid args type. Use a TCEvalArgs object or a dictionary")

        return self._trainer.eval(input_filepath=input_filepath,
                                  dataclass_args=method_dataclass_args)
示例#9
0
    def test(self, input_filepath, args=None):
        """
        Tests the text classification model. Used to obtain results

        input_filepath: a string that contains the location of a csv file
        for testing. Contains the following header value:
        text

        args: Either a TCTestArgs() object or a dictionary that contains all of
        the same keys as ARGS_TC_TEST. When omitted (or None), a fresh
        TCTestArgs() with its default values is used.

        return: A list of TextClassificationResult() objects

        raises: ValueError if args is neither a dictionary nor a TCTestArgs object
        """
        # Build the default per call: an instance written into the signature is
        # created once at definition time and the same object is shared by every call.
        if args is None:
            args = TCTestArgs()

        # isinstance (instead of an exact type() comparison) also accepts
        # subclasses such as OrderedDict or a subclass of TCTestArgs.
        if isinstance(args, dict):
            method_dataclass_args = create_args_dataclass(
                default_dic_args=ARGS_TC_TEST,
                input_dic_args=args,
                method_dataclass_args=TCTestArgs)
        elif isinstance(args, TCTestArgs):
            method_dataclass_args = args
        else:
            raise ValueError(
                "Invalid args type. Use a TCTestArgs() object or a dictionary")

        return self._trainer.test(input_filepath=input_filepath,
                                  solve=self.classify_text,
                                  dataclass_args=method_dataclass_args)