def test_run_detext_cnn_ranking(self):
    """This method tests run_detext with CNN ranking models"""
    output = os.path.join(DataSetup.out_dir, "cnn_model")
    self._cleanUp(output)
    args = self.ranking_args + ["--out_dir", output]
    sys.argv[1:] = args
    main(sys.argv)
    self._cleanUp(output)
def test_run_detext_multitask_ranking(self):
    """This method tests run_detext with multitask ranking models"""
    output = os.path.join(DataSetup.out_dir, "multitask_model")
    args = self.multitask_args + [
        "--task_ids", "0", "1",
        "--task_weights", "0.2", "0.8",
        "--out_dir", output
    ]
    sys.argv[1:] = args
    main(sys.argv)
    self._cleanUp(output)
def test_run_detext_lstm(self):
    """This method tests run_detext with LSTM models"""
    output = os.path.join(out_dir, "lstm")
    args = self.args + \
        ["--filter_window_size", "0",
         "--ftr_ext", "lstm",
         "--num_hidden", "10",
         "--out_dir", output]
    sys.argv[1:] = args
    main(sys.argv)
    self._cleanUp(output)
def test_run_detext_cnn(self):
    """This method tests run_detext with CNN models"""
    output = os.path.join(out_dir, "cnn_model")
    args = self.args + \
        ["--filter_window_size", "1", "2", "3",
         "--ftr_ext", "cnn",
         "--num_hidden", "10", "10", "5",
         "--out_dir", output]
    sys.argv[1:] = args
    main(sys.argv)
    self._cleanUp(output)
def test_run_detext_bert(self):
    """This method tests run_detext with BERT models"""
    output = os.path.join(out_dir, "bert")
    args = self.args + \
        ["--ftr_ext", "bert",
         "--bert_config_file", bert_config,
         "--use_bert_dropout", "True",
         "--out_dir", output]
    sys.argv[1:] = args
    main(sys.argv)
    self._cleanUp(output)
def test_run_detext_multitask(self):
    """This method tests run_detext with multitask models"""
    output = os.path.join(out_dir, "multitask_model")
    args = self.multitask_args + \
        ["--filter_window_size", "3",
         "--ftr_ext", "cnn",
         "--num_hidden", "10",
         "--task_ids", "0,1",
         "--task_weights", "0.2,0.8",
         "--out_dir", output]
    sys.argv[1:] = args
    main(sys.argv)
    self._cleanUp(output)
def test_run_detext_libert_space_ranking(self):
    """This method tests run_detext with LiBERT space tokenization models"""
    output = os.path.join(DataSetup.out_dir, "libert_space_model")
    self._cleanUp(output)
    args = self.bert_args + [
        "--bert_hub_url", DataSetup.libert_space_hub_url,  # pretrained bert hidden size: 16
        "--out_dir", output
    ]
    sys.argv[1:] = args
    main(sys.argv)
    self._cleanUp(output)
def test_run_detext_libert_classification(self):
    """This method tests run_detext for LiBERT classification fine-tuning"""
    output = os.path.join(DataSetup.out_dir, "cls_libert_model")
    args = self.base_args + [
        "--task_type", "classification",
        "--ftr_ext", "bert",
        "--lr_bert", "0.00001",
        "--bert_hub_url", DataSetup.libert_sp_hub_url,
        "--num_units", "16",
        f"--{InputFtrType.LABEL_COLUMN_NAME}", "label",
        f"--{InputFtrType.DOC_TEXT_COLUMN_NAMES}", "query_text",
        f"--{InputFtrType.USER_TEXT_COLUMN_NAMES}", "user_headline",
        f"--{InputFtrType.DENSE_FTRS_COLUMN_NAMES}", "dense_ftrs",
        f"--{InputFtrType.SHALLOW_TOWER_SPARSE_FTRS_COLUMN_NAMES}", "sparse_ftrs",
        "--nums_shallow_tower_sparse_ftrs", "30",
        "--nums_dense_ftrs", "8",
        "--num_classes", "6",
        "--pmetric", "accuracy",
        "--all_metrics", "accuracy", "confusion_matrix",
        "--test_file", DataSetup.cls_data_dir,
        "--dev_file", DataSetup.cls_data_dir,
        "--train_file", DataSetup.cls_data_dir,
        "--out_dir", output
    ]
    sys.argv[1:] = args
    main(sys.argv)
    self._cleanUp(output)
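# Assumption: the _cleanUp helper called by every test above is not shown in
# this section. Below is a minimal sketch of what it might do, assuming it
# simply removes the model output directory tree if present; the actual
# implementation in the test class may differ.
import shutil  # needed only by the sketch below

def _cleanUp(self, path):
    # Delete the output directory; ignore_errors avoids failing when the
    # directory does not exist yet (e.g. before the first run).
    shutil.rmtree(path, ignore_errors=True)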
def train(self, training_data_path, validation_data_path, metadata_file, checkpoint_path,
          execution_context, schema_params):
    # Delegate training to the DeText driver's entry point
    detext_driver.main(None)