コード例 #1
0
    def setUp(self):
        """Prepare the cached BART checkpoint and CNN/DM files for the tests.

        If the cached `bart.large.cnn` path is missing, the model archive is
        downloaded into `CACHED_BART_MODEL_DIR` and decompressed there. Every
        CNN/DM file in the list below that is not already present in
        `CACHED_CNNDM_DATA_DIR` is downloaded from `CNNDM_URL`. The data
        directory and the checkpoint file path are stored on the instance as
        `self.source_path` and `self.bart_path`.

        Raises:
            FileExistsError: if a download target already exists when it is
                opened in exclusive-creation mode ('xb'), e.g. an archive
                left behind by an interrupted earlier run.
        """

        super(FairseqGenerateCLITest, self).setUp()
        # TODO: create a dummy model instead of loading a large-size model.
        bart_model_path = CACHED_BART_MODEL_PATHS['bart.large.cnn']
        if not os.path.exists(bart_model_path):
            make_dirs(CACHED_BART_MODEL_DIR, exist_ok=True)
            tar_model_path = os.path.join(CACHED_BART_MODEL_DIR,
                                          'bart.large.cnn.tar.gz')
            with open(tar_model_path, 'xb') as tar_model_file:
                wget(BART_MODEL_URLS['bart.large.cnn'], tar_model_file)
            decompress_file(tar_model_path, CACHED_BART_MODEL_DIR)

        self.source_path = CACHED_CNNDM_DATA_DIR
        make_dirs(self.source_path, exist_ok=True)
        file_list = [
            "dict.source.txt", "dict.target.txt",
            "valid.source-target.source.bin", "valid.source-target.target.bin",
            "valid.source-target.source.idx", "valid.source-target.target.idx"
        ]
        for f in file_list:
            f_path = os.path.join(self.source_path, f)
            if not os.path.exists(f_path):
                # URLs always use '/'; os.path.join would insert the platform
                # separator, so the URL is assembled explicitly instead.
                file_url = CNNDM_URL.rstrip('/') + '/' + f
                # The `with` block closes the file; no explicit close needed.
                with open(f_path, 'xb') as new_file:
                    wget(file_url, new_file)

        self.bart_path = os.path.join(bart_model_path, 'model.pt')
コード例 #2
0
    def setUp(self):
        """Prepare the cached ProphetNet model and the expected test outputs.

        Each of the model files ('model.pt', 'dict.src.txt', 'dict.tgt.txt')
        that is missing from the cache directory is downloaded individually,
        so a cache directory left incomplete by an interrupted run is
        completed rather than silently reused. The model is then loaded from
        that directory and the expected hypotheses are read, one stripped
        line per entry, into `self.expected_outputs`.
        """

        super(ProphetNetModelTest, self).setUp()
        prophetnet_dir = CACHED_PROPHETNET_MODEL_PATHS[
            'prophetnet_large_160G_cnndm']
        prophetnet_url_base = PROPHETNET_MODEL_URLS[
            'prophetnet_large_160G_cnndm']
        make_dirs(prophetnet_dir, exist_ok=True)
        for download_file in ['model.pt', 'dict.src.txt', 'dict.tgt.txt']:
            output_path = os.path.join(prophetnet_dir, download_file)
            # Check per file, not per directory: the directory may exist
            # while some of its files were never downloaded.
            if os.path.exists(output_path):
                continue
            download_url = urljoin(prophetnet_url_base, download_file)
            with open(output_path, 'xb') as fout:
                wget(download_url, fout)

        self.prophetnet = NgramTransformerProphetModel.from_pretrained(
            prophetnet_dir, checkpoint_file='model.pt')

        self.source_path = 'tests/models/data/cnn_dm_128_bert.txt'

        # read the expected output.
        self.expected_output_path = 'tests/models/data/cnn_dm_128_bert_expected_output.hypo'  # pylint: disable=line-too-long
        with open(self.expected_output_path, 'rt',
                  encoding="utf-8") as expected_output_file:
            self.expected_outputs = [
                line.strip() for line in expected_output_file
            ]
コード例 #3
0
    def setUp(self):
        """Load the cached BART model and the reference hypotheses.

        When the cached `bart.large.cnn` path does not exist, the model
        archive is downloaded into `CACHED_BART_MODEL_DIR` and decompressed
        there before the model is loaded. The expected outputs are read from
        the reference file with surrounding whitespace stripped from each
        line.
        """

        super(FairseqBeamSearchOptimizerTest, self).setUp()
        # TODO: create a dummy model instead of loading a large-size model.
        model_key = 'bart.large.cnn'
        if not os.path.exists(CACHED_BART_MODEL_PATHS[model_key]):
            make_dirs(CACHED_BART_MODEL_DIR, exist_ok=True)
            archive_path = os.path.join(
                CACHED_BART_MODEL_DIR, 'bart.large.cnn.tar.gz')
            # NOTE(review): 'xb' raises FileExistsError if an archive from an
            # earlier, interrupted run is still on disk.
            with open(archive_path, 'xb') as archive_file:
                wget(BART_MODEL_URLS[model_key], archive_file)
            decompress_file(archive_path, CACHED_BART_MODEL_DIR)

        self.bart = BARTModel.from_pretrained(
            CACHED_BART_MODEL_PATHS[model_key], checkpoint_file='model.pt')

        self.source_path = 'tests/optimizer/fairseq/data/cnndm_128.txt'

        # Reference hypotheses, one per line.
        self.expected_output_path = (
            'tests/optimizer/fairseq/data/expected_output.hypo')
        self.expected_outputs = []
        with open(self.expected_output_path, 'rt',
                  encoding="utf-8") as reference_file:
            self.expected_outputs.extend(
                raw_line.strip() for raw_line in reference_file)
コード例 #4
0
ファイル: test_file_utils.py プロジェクト: microsoft/fastseq
    def test_make_dirs(self, directory, mode, exist_ok):
        """Test `make_dirs()`

        Creates `directory` under `self.parent_dir` and checks that the
        resulting path exists afterwards.

        Args:
            directory (str): file folder.
            mode (int): directory mode.
            exist_ok (bool): indicate whether it is ok if the input directory
                             exists.
        """
        path = os.path.join(self.parent_dir, directory)
        make_dirs(path, mode, exist_ok)
        # assertTrue's second positional argument is the failure message, not
        # an expected value, so only the condition itself is passed.
        self.assertTrue(os.path.exists(path))
コード例 #5
0
    def setUp(self):
        """Load the cached BART model used by the benchmark.

        When the cached `bart.large.cnn` path does not exist, the model
        archive is downloaded into `CACHED_BART_MODEL_DIR` and decompressed
        there before the model is loaded.
        """
        super(FairseqBeamSearchOptimizerBenchmark, self).setUp()
        model_key = 'bart.large.cnn'
        if not os.path.exists(CACHED_BART_MODEL_PATHS[model_key]):
            make_dirs(CACHED_BART_MODEL_DIR, exist_ok=True)
            archive_path = os.path.join(
                CACHED_BART_MODEL_DIR, 'bart.large.cnn.tar.gz')
            # NOTE(review): 'xb' raises FileExistsError if an archive from an
            # earlier, interrupted run is still on disk.
            with open(archive_path, 'xb') as archive_file:
                wget(BART_MODEL_URLS[model_key], archive_file)
            decompress_file(archive_path, CACHED_BART_MODEL_DIR)

        self.bart = BARTModel.from_pretrained(
            CACHED_BART_MODEL_PATHS[model_key], checkpoint_file='model.pt')
        self.source_path = 'tests/optimizer/fairseq/data/cnndm_128.txt'