예제 #1
0
 def test_run_detext_cnn_ranking(self):
     """Run ``main`` end to end with the ranking arguments (CNN model test).

     The output directory is cleaned before the run so stale artifacts
     cannot influence it, and cleaned again afterwards.
     """
     output = os.path.join(DataSetup.out_dir, "cnn_model")
     self._cleanUp(output)
     sys.argv[1:] = self.ranking_args + ["--out_dir", output]
     try:
         main(sys.argv)
     finally:
         # Clean up even when main() raises, so a failing run does not
         # leave its output directory behind.
         self._cleanUp(output)
예제 #2
0
 def test_run_detext_multitask_ranking(self):
     """Run ``main`` end to end with the multitask arguments.

     Two tasks (ids ``0`` and ``1``) are configured with weights ``0.2``
     and ``0.8``; results are written under ``multitask_model``.
     """
     output = os.path.join(DataSetup.out_dir, "multitask_model")
     sys.argv[1:] = self.multitask_args + [
         "--task_ids", "0", "1",
         "--task_weights", "0.2", "0.8",
         "--out_dir", output,
     ]
     try:
         main(sys.argv)
     finally:
         # Clean up even when main() raises; there is no pre-run cleanup
         # here, so leftovers would be picked up by the next run.
         self._cleanUp(output)
예제 #3
0
 def test_run_detext_lstm(self):
     """Run ``main`` end to end with the LSTM feature extractor.

     Uses ``--ftr_ext lstm`` with a filter window size of ``0`` and
     ``10`` hidden units; results are written under ``lstm``.
     """
     output = os.path.join(out_dir, "lstm")
     sys.argv[1:] = self.args + [
         "--filter_window_size", "0",
         "--ftr_ext", "lstm",
         "--num_hidden", "10",
         "--out_dir", output,
     ]
     try:
         main(sys.argv)
     finally:
         # Clean up even when main() raises, so a failing run does not
         # leave its output directory behind.
         self._cleanUp(output)
예제 #4
0
 def test_run_detext_cnn(self):
     """Run ``main`` end to end with the CNN feature extractor.

     Uses ``--ftr_ext cnn`` with filter window sizes ``1 2 3`` and hidden
     sizes ``10 10 5``; results are written under ``cnn_model``.
     """
     output = os.path.join(out_dir, "cnn_model")
     sys.argv[1:] = self.args + [
         "--filter_window_size", "1", "2", "3",
         "--ftr_ext", "cnn",
         "--num_hidden", "10", "10", "5",
         "--out_dir", output,
     ]
     try:
         main(sys.argv)
     finally:
         # Clean up even when main() raises, so a failing run does not
         # leave its output directory behind.
         self._cleanUp(output)
예제 #5
0
 def test_run_detext_bert(self):
     """Run ``main`` end to end with the BERT feature extractor.

     Uses ``--ftr_ext bert`` with the ``bert_config`` file and BERT
     dropout enabled; results are written under ``bert``.
     """
     output = os.path.join(out_dir, "bert")
     sys.argv[1:] = self.args + [
         "--ftr_ext", "bert",
         "--bert_config_file", bert_config,
         "--use_bert_dropout", "True",
         "--out_dir", output,
     ]
     try:
         main(sys.argv)
     finally:
         # Clean up even when main() raises, so a failing run does not
         # leave its output directory behind.
         self._cleanUp(output)
예제 #6
0
 def test_run_detext_multitask(self):
     """Run ``main`` end to end with the multitask arguments and a CNN.

     Task ids and weights are passed as comma-separated strings
     (``0,1`` and ``0.2,0.8``); results are written under
     ``multitask_model``.
     """
     output = os.path.join(out_dir, "multitask_model")
     sys.argv[1:] = self.multitask_args + [
         "--filter_window_size", "3",
         "--ftr_ext", "cnn",
         "--num_hidden", "10",
         "--task_ids", "0,1",
         "--task_weights", "0.2,0.8",
         "--out_dir", output,
     ]
     try:
         main(sys.argv)
     finally:
         # Clean up even when main() raises, so a failing run does not
         # leave its output directory behind.
         self._cleanUp(output)
예제 #7
0
    def test_run_detext_libert_space_ranking(self):
        """Run ``main`` end to end with the LiBERT space-tokenization model.

        The BERT arguments are extended with the LiBERT space hub URL. The
        output directory is cleaned before the run so stale artifacts
        cannot influence it, and cleaned again afterwards.
        """
        output = os.path.join(DataSetup.out_dir, "libert_space_model")
        self._cleanUp(output)
        sys.argv[1:] = self.bert_args + [
            # pretrained bert hidden size: 16
            "--bert_hub_url", DataSetup.libert_space_hub_url,
            "--out_dir", output,
        ]
        try:
            main(sys.argv)
        finally:
            # Clean up even when main() raises, so a failing run does not
            # leave its output directory behind.
            self._cleanUp(output)
예제 #8
0
 def test_run_detext_libert_classification(self):
     """Run ``main`` end to end for LiBERT classification fine-tuning.

     The base arguments are extended with a classification task using the
     BERT feature extractor, six classes, and accuracy / confusion-matrix
     metrics. The same ``DataSetup.cls_data_dir`` is used for train, dev
     and test files.
     """
     output = os.path.join(DataSetup.out_dir, "cls_libert_model")
     # One flag and its value(s) per line; order matches the original list.
     sys.argv[1:] = self.base_args + [
         "--task_type", "classification",
         "--ftr_ext", "bert",
         "--lr_bert", "0.00001",
         "--bert_hub_url", DataSetup.libert_sp_hub_url,
         "--num_units", "16",
         f"--{InputFtrType.LABEL_COLUMN_NAME}", "label",
         f"--{InputFtrType.DOC_TEXT_COLUMN_NAMES}", "query_text",
         f"--{InputFtrType.USER_TEXT_COLUMN_NAMES}", "user_headline",
         f"--{InputFtrType.DENSE_FTRS_COLUMN_NAMES}", "dense_ftrs",
         f"--{InputFtrType.SHALLOW_TOWER_SPARSE_FTRS_COLUMN_NAMES}", "sparse_ftrs",
         "--nums_shallow_tower_sparse_ftrs", "30",
         "--nums_dense_ftrs", "8",
         "--num_classes", "6",
         "--pmetric", "accuracy",
         "--all_metrics", "accuracy", "confusion_matrix",
         "--test_file", DataSetup.cls_data_dir,
         "--dev_file", DataSetup.cls_data_dir,
         "--train_file", DataSetup.cls_data_dir,
         "--out_dir", output,
     ]
     try:
         main(sys.argv)
     finally:
         # Clean up even when main() raises, so a failing run does not
         # leave its output directory behind.
         self._cleanUp(output)
 def train(self, training_data_path, validation_data_path, metadata_file,
           checkpoint_path, execution_context, schema_params):
     """Start training by invoking ``detext_driver.main``.

     None of the parameters are read or forwarded by this method; they are
     accepted only as part of the signature.
     """
     # Hand off to detext_driver.main, passing None rather than an argv list.
     # NOTE(review): the earlier comment said "Delegate to super class", but
     # no super() call is made here. Presumably the driver picks up its
     # configuration elsewhere (e.g. sys.argv) -- confirm.
     detext_driver.main(None)