예제 #1
0
    def setUp(self):
        """Set up the test environment.

        Downloads and unpacks the ``bart.large.cnn`` checkpoint if it is not
        already cached, fetches any missing CNN/DM validation files into
        ``CACHED_CNNDM_DATA_DIR``, and records ``self.source_path`` (the data
        directory) and ``self.bart_path`` (the checkpoint file).
        """
        import posixpath

        super(FairseqGenerateCLITest, self).setUp()

        def _download(url, path):
            """Download ``url`` into a newly created file at ``path``.

            The file is opened with mode 'xb', so an already existing file
            raises FileExistsError. If the download itself fails, the
            partially written file is removed so a later run can retry
            instead of failing on 'xb' or reusing a truncated file.
            """
            with open(path, 'xb') as out_file:
                try:
                    wget(url, out_file)
                except BaseException:
                    # Close before removing: open files cannot be deleted on
                    # Windows.
                    out_file.close()
                    os.remove(path)
                    raise

        # TODO: create a dummy model instead of loading a large-size model.
        bart_model_path = CACHED_BART_MODEL_PATHS['bart.large.cnn']
        if not os.path.exists(bart_model_path):
            make_dirs(CACHED_BART_MODEL_DIR, exist_ok=True)
            tar_model_path = os.path.join(CACHED_BART_MODEL_DIR,
                                          'bart.large.cnn.tar.gz')
            _download(BART_MODEL_URLS['bart.large.cnn'], tar_model_path)
            decompress_file(tar_model_path, CACHED_BART_MODEL_DIR)

        self.source_path = CACHED_CNNDM_DATA_DIR
        make_dirs(self.source_path, exist_ok=True)
        file_list = [
            "dict.source.txt", "dict.target.txt",
            "valid.source-target.source.bin", "valid.source-target.target.bin",
            "valid.source-target.source.idx", "valid.source-target.target.idx"
        ]
        for file_name in file_list:
            f_path = os.path.join(self.source_path, file_name)
            if not os.path.exists(f_path):
                # URLs always use '/' separators; os.path.join would insert a
                # backslash on Windows.
                _download(posixpath.join(CNNDM_URL, file_name), f_path)

        self.bart_path = bart_model_path + '/model.pt'
예제 #2
0
    def setUp(self):
        """Set up the test environment.

        Ensures the ``prophetnet_large_160G_cnndm`` checkpoint and its
        ``dict.src.txt`` / ``dict.tgt.txt`` files are cached locally,
        downloading any that are missing. It then loads the model into
        ``self.prophetnet`` and reads the expected hypotheses (one per line,
        stripped) into ``self.expected_outputs``.
        """
        super(ProphetNetModelTest, self).setUp()
        model_name = 'prophetnet_large_160G_cnndm'
        prophetnet_dir = CACHED_PROPHETNET_MODEL_PATHS[model_name]
        prophetnet_url_base = PROPHETNET_MODEL_URLS[model_name]

        # Check each file rather than only the directory: an interrupted run
        # can leave the directory behind with some files missing, which
        # would otherwise never be downloaded again.
        make_dirs(prophetnet_dir, exist_ok=True)
        for download_file in ['model.pt', 'dict.src.txt', 'dict.tgt.txt']:
            output_path = os.path.join(prophetnet_dir, download_file)
            if os.path.exists(output_path):
                continue
            download_url = urljoin(prophetnet_url_base, download_file)
            with open(output_path, 'xb') as fout:
                try:
                    wget(download_url, fout)
                except BaseException:
                    # Remove the partial file so a later run retries it;
                    # close first because open files cannot be deleted on
                    # Windows.
                    fout.close()
                    os.remove(output_path)
                    raise

        self.prophetnet = NgramTransformerProphetModel.from_pretrained(
            prophetnet_dir, checkpoint_file='model.pt')

        self.source_path = 'tests/models/data/cnn_dm_128_bert.txt'

        # read the expected output.
        self.expected_output_path = 'tests/models/data/cnn_dm_128_bert_expected_output.hypo'  # pylint: disable=line-too-long
        with open(self.expected_output_path, 'rt',
                  encoding="utf-8") as expected_output_file:
            self.expected_outputs = [
                line.strip() for line in expected_output_file
            ]
    def setUp(self):
        """Set up the test environment.

        Downloads and unpacks the ``bart.large.cnn`` checkpoint if it is not
        already cached and loads it into ``self.bart``. It also reads the
        expected hypotheses (one per line, stripped) into
        ``self.expected_outputs``.
        """
        super(FairseqBeamSearchOptimizerTest, self).setUp()
        # TODO: create a dummy model instead of loading a large-size model.
        bart_model_path = CACHED_BART_MODEL_PATHS['bart.large.cnn']
        if not os.path.exists(bart_model_path):
            make_dirs(CACHED_BART_MODEL_DIR, exist_ok=True)
            tar_model_path = os.path.join(CACHED_BART_MODEL_DIR,
                                          'bart.large.cnn.tar.gz')
            with open(tar_model_path, 'xb') as tar_model_file:
                try:
                    wget(BART_MODEL_URLS['bart.large.cnn'], tar_model_file)
                except BaseException:
                    # A leftover partial tarball would make every later run
                    # fail on mode 'xb' (FileExistsError), so remove it. Close
                    # first because open files cannot be deleted on Windows.
                    tar_model_file.close()
                    os.remove(tar_model_path)
                    raise
            decompress_file(tar_model_path, CACHED_BART_MODEL_DIR)

        self.bart = BARTModel.from_pretrained(
            bart_model_path,
            checkpoint_file='model.pt')

        self.source_path = 'tests/optimizer/fairseq/data/cnndm_128.txt'

        # read the expected output.
        self.expected_output_path = 'tests/optimizer/fairseq/data/expected_output.hypo'  # pylint: disable=line-too-long
        with open(self.expected_output_path, 'rt',
                  encoding="utf-8") as expected_output_file:
            self.expected_outputs = [
                line.strip() for line in expected_output_file
            ]
예제 #4
0
    def disable_test_wget(self, url, target_file_name):
        """Test `wget()`.

        It is disabled because downloading the model file is time consuming.
        Once we find a small test file, it will be enabled. Currently, `wget()`
        is covered by `test_wget_and_decompress_file()`.

        Args:
            url (str): download url.
            target_file_name (str): the expected file name.
        """
        target_file = os.path.join(self.parent_dir, target_file_name)
        with open(target_file, "xb") as output_file:
            wget(url, output_file)
        self.assertTrue(os.path.exists(target_file))
        # open(..., "xb") creates the file before wget() runs, so existence
        # alone proves nothing; require that some data was actually written.
        self.assertGreater(os.path.getsize(target_file), 0)
    def setUp(self):
        """Set up the test environment.

        Downloads and unpacks the ``bart.large.cnn`` checkpoint if it is not
        already cached, loads it into ``self.bart``, and sets
        ``self.source_path`` to the CNN/DM benchmark input file.
        """
        super(FairseqBeamSearchOptimizerBenchmark, self).setUp()
        bart_model_path = CACHED_BART_MODEL_PATHS['bart.large.cnn']
        if not os.path.exists(bart_model_path):
            make_dirs(CACHED_BART_MODEL_DIR, exist_ok=True)
            tar_model_path = os.path.join(CACHED_BART_MODEL_DIR,
                                          'bart.large.cnn.tar.gz')
            with open(tar_model_path, 'xb') as tar_model_file:
                try:
                    wget(BART_MODEL_URLS['bart.large.cnn'], tar_model_file)
                except BaseException:
                    # A leftover partial tarball would make every later run
                    # fail on mode 'xb' (FileExistsError), so remove it. Close
                    # first because open files cannot be deleted on Windows.
                    tar_model_file.close()
                    os.remove(tar_model_path)
                    raise
            decompress_file(tar_model_path, CACHED_BART_MODEL_DIR)

        self.bart = BARTModel.from_pretrained(
            bart_model_path,
            checkpoint_file='model.pt')
        self.source_path = 'tests/optimizer/fairseq/data/cnndm_128.txt'
예제 #6
0
    def test_wget_and_decompress_file(self, tar_file_url, tar_file_name,
                                      output_folder):
        """Test `wget()` and `decompress_file()`.

        Downloads a tar file into ``self.parent_dir``, decompresses it into
        ``output_folder`` under the same directory, and checks that the path
        returned by `decompress_file()` exists.

        Args:
            tar_file_url (str): download url for tar file.
            tar_file_name (str): tar file name.
            output_folder (str): directory for decompressing the tar file.
        """
        # download the tar file.
        tar_file_path = os.path.join(self.parent_dir, tar_file_name)
        with open(tar_file_path, "xb") as tar_file:
            wget(tar_file_url, tar_file)

        # decompress the tar file.
        output_dir = os.path.join(self.parent_dir, output_folder)
        output_file = decompress_file(tar_file_path, output_dir)
        # assertTrue's second positional argument is the failure message, not
        # an expected value, so only the condition is passed.
        self.assertTrue(os.path.exists(output_file))