def setUp(self):
    """Set up the test environment.

    Downloads and caches the ``bart.large.cnn`` model archive and the
    pre-processed CNN/DailyMail validation data files, skipping any
    item that is already present in the local cache directories.
    """
    super(FairseqGenerateCLITest, self).setUp()
    # TODO: create a dummy model instead of loading a large-size model.
    if not os.path.exists(CACHED_BART_MODEL_PATHS['bart.large.cnn']):
        make_dirs(CACHED_BART_MODEL_DIR, exist_ok=True)
        tar_model_path = os.path.join(CACHED_BART_MODEL_DIR,
                                      'bart.large.cnn.tar.gz')
        # 'xb' fails fast if the archive already exists, so a cached
        # file is never silently overwritten by a partial download.
        with open(tar_model_path, 'xb') as tar_model_file:
            wget(BART_MODEL_URLS['bart.large.cnn'], tar_model_file)
        decompress_file(tar_model_path, CACHED_BART_MODEL_DIR)

    self.source_path = CACHED_CNNDM_DATA_DIR
    make_dirs(self.source_path, exist_ok=True)
    file_list = [
        "dict.source.txt",
        "dict.target.txt",
        "valid.source-target.source.bin",
        "valid.source-target.target.bin",
        "valid.source-target.source.idx",
        "valid.source-target.target.idx",
    ]
    for f in file_list:
        f_path = os.path.join(self.source_path, f)
        if not os.path.exists(f_path):
            # The `with` statement closes the file on exit; no explicit
            # close() is needed (the original called it redundantly).
            with open(f_path, 'xb') as new_file:
                wget(os.path.join(CNNDM_URL, f), new_file)
    self.bart_path = CACHED_BART_MODEL_PATHS['bart.large.cnn'] + '/model.pt'
def setUp(self):
    """Set up the test environment.

    Ensures the cached ProphetNet checkpoint and dictionaries are on
    disk (downloading them on first use), loads the model, and reads
    the expected generation outputs used by the tests.
    """
    super(ProphetNetModelTest, self).setUp()
    model_dir = CACHED_PROPHETNET_MODEL_PATHS['prophetnet_large_160G_cnndm']
    url_base = PROPHETNET_MODEL_URLS['prophetnet_large_160G_cnndm']
    if not os.path.exists(model_dir):
        make_dirs(model_dir)
        # Fetch the checkpoint and both dictionaries into the cache dir.
        for fname in ('model.pt', 'dict.src.txt', 'dict.tgt.txt'):
            dest = os.path.join(model_dir, fname)
            with open(dest, 'xb') as fout:
                wget(urljoin(url_base, fname), fout)
    self.prophetnet = NgramTransformerProphetModel.from_pretrained(
        model_dir, checkpoint_file='model.pt')
    self.source_path = 'tests/models/data/cnn_dm_128_bert.txt'
    # read the expected output.
    self.expected_output_path = 'tests/models/data/cnn_dm_128_bert_expected_output.hypo'  # pylint: disable=line-too-long
    with open(self.expected_output_path, 'rt', encoding="utf-8") as expected_output_file:
        self.expected_outputs = [line.strip() for line in expected_output_file]
def setUp(self):
    """Set up the test environment.

    Downloads and extracts the cached ``bart.large.cnn`` checkpoint if
    needed, loads it, and reads the expected beam-search outputs.
    """
    super(FairseqBeamSearchOptimizerTest, self).setUp()
    # TODO: create a dummy model instead of loading a large-size model.
    cached_model_dir = CACHED_BART_MODEL_PATHS['bart.large.cnn']
    if not os.path.exists(cached_model_dir):
        make_dirs(CACHED_BART_MODEL_DIR, exist_ok=True)
        archive_path = os.path.join(CACHED_BART_MODEL_DIR,
                                    'bart.large.cnn.tar.gz')
        with open(archive_path, 'xb') as archive_file:
            wget(BART_MODEL_URLS['bart.large.cnn'], archive_file)
        decompress_file(archive_path, CACHED_BART_MODEL_DIR)
    self.bart = BARTModel.from_pretrained(
        cached_model_dir, checkpoint_file='model.pt')
    self.source_path = 'tests/optimizer/fairseq/data/cnndm_128.txt'
    # read the expected output.
    self.expected_output_path = 'tests/optimizer/fairseq/data/expected_output.hypo'  # pylint: disable=line-too-long
    with open(self.expected_output_path, 'rt', encoding="utf-8") as expected_output_file:
        self.expected_outputs = [line.strip() for line in expected_output_file]
def disable_test_wget(self, url, target_file_name):
    """Test `wget()`.

    It is disabled because it is time consuming to download the model
    file. Once we find a small test file, it will be enabled. Currently,
    `wget()` tests are covered by `test_wget_and_decompress_file()`.

    Args:
        url (str): download url.
        target_file_name (str): the expected file name.
    """
    target_file = os.path.join(self.parent_dir, target_file_name)
    with open(target_file, "xb") as output_file:
        wget(url, output_file)
    # assertTrue takes (expr, msg); the original passed True as the msg
    # argument, which was misleading — assert the existence check alone.
    self.assertTrue(os.path.exists(target_file))
def setUp(self):
    """Set up the benchmark environment.

    Downloads and extracts the cached ``bart.large.cnn`` checkpoint on
    first run, then loads it for benchmarking.
    """
    super(FairseqBeamSearchOptimizerBenchmark, self).setUp()
    cached_model_dir = CACHED_BART_MODEL_PATHS['bart.large.cnn']
    if not os.path.exists(cached_model_dir):
        make_dirs(CACHED_BART_MODEL_DIR, exist_ok=True)
        archive_path = os.path.join(CACHED_BART_MODEL_DIR,
                                    'bart.large.cnn.tar.gz')
        with open(archive_path, 'xb') as archive_file:
            wget(BART_MODEL_URLS['bart.large.cnn'], archive_file)
        decompress_file(archive_path, CACHED_BART_MODEL_DIR)
    self.bart = BARTModel.from_pretrained(
        cached_model_dir, checkpoint_file='model.pt')
    self.source_path = 'tests/optimizer/fairseq/data/cnndm_128.txt'
def test_wget_and_decompress_file(self, tar_file_url, tar_file_name,
                                  output_folder):
    """Test `wget()` and `decompress_file()`.

    Args:
        tar_file_url (str): download url for tar file.
        tar_file_name (str): tar file name.
        output_folder (str): directory for decompressing the tar file.
    """
    # download the tar file.
    tar_file_path = os.path.join(self.parent_dir, tar_file_name)
    with open(tar_file_path, "xb") as tar_file:
        wget(tar_file_url, tar_file)
    # decompress the tar file.
    output_dir = os.path.join(self.parent_dir, output_folder)
    output_file = decompress_file(tar_file_path, output_dir)
    # assertTrue takes (expr, msg); the original passed True as the msg
    # argument, which was misleading — assert the existence check alone.
    self.assertTrue(os.path.exists(output_file))