Example #1
    def test_generate_sp_model(self):
        """
        Test the function to train a SentencePiece tokenizer.
        """

        asset_name = 'text_normalization_ag_news_test.csv'
        asset_path = get_asset_path(asset_name)
        # We use a temporary directory for two reasons:
        # 1. buck (fb internal) generates a test environment that contains ',' in its path.
        #    SentencePieceTrainer treats such a path as a comma-delimited file list,
        #    so as a workaround we copy the asset data to a temporary directory and load it from there.
        # 2. When fb infra performs stress tests, multiple instances of this test run.
        #    The names of the generated models have to be unique, and they need to be cleaned up.
        with tempfile.TemporaryDirectory() as dir_name:
            data_path = os.path.join(dir_name, asset_name)
            shutil.copy(asset_path, data_path)

            model_prefix = os.path.join(dir_name, f'spm_user_{uuid.uuid4()}')
            model_file = f'{model_prefix}.model'
            generate_sp_model(data_path,
                              vocab_size=23456,
                              model_prefix=model_prefix)

            sp_user = spm.SentencePieceProcessor()
            sp_user.Load(model_file)

            self.assertEqual(len(sp_user), 23456)
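
Once training finishes, the .model file can be loaded into a plain SentencePieceProcessor and used to segment raw text. A minimal sketch of that round trip, assuming a model trained as above and saved as 'spm_user.model' (the file name and sample sentence are illustrative):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load('spm_user.model')

sentence = 'sentencepiece encode as pieces'  # illustrative input
print(sp.EncodeAsPieces(sentence))  # a list of subword strings
print(sp.EncodeAsIds(sentence))     # the corresponding integer ids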
Example #2
def setup_datasets(dataset_name,
                   root='.data',
                   vocab_size=20000,
                   include_unk=False):
    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    # Generate the pretrained SentencePiece tokenizer if it does not exist yet.
    # generate_sp_model's default model_prefix is 'm_user', hence 'm_user.model' below.
    if not path.exists('m_user.model'):
        logging.info('Generate SentencePiece pretrained tokenizer...')
        generate_sp_model(train_csv_path, vocab_size)

    sp_model = load_sp_model("m_user.model")
    sp_generator = sentencepiece_numericalizer(sp_model)
    train_data, train_labels = _create_data_with_sp_transform(
        sp_generator, train_csv_path)
    test_data, test_labels = _create_data_with_sp_transform(
        sp_generator, test_csv_path)

    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (text_classification.TextClassificationDataset(
        None, train_data, train_labels),
            text_classification.TextClassificationDataset(
                None, test_data, test_labels))
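
Example #2 calls a _create_data_with_sp_transform helper that is not shown. The following is only a plausible reconstruction, assuming an AG_NEWS-style CSV whose first column is a 1-based label and whose remaining columns hold the text; the body is an assumption, not the original implementation:

import csv

import torch


def _create_data_with_sp_transform(sp_generator, data_path):
    # Hypothetical sketch: numericalize each row's text with the
    # sentencepiece_numericalizer generator and collect the label set.
    data, labels = [], set()
    with open(data_path, encoding='utf-8') as f:
        rows = list(csv.reader(f))
    texts = [' '.join(row[1:]) for row in rows]
    for row, ids in zip(rows, sp_generator(texts)):
        label = int(row[0]) - 1  # assume 1-based labels, as in AG_NEWS
        data.append((label, torch.tensor(ids)))
        labels.add(label)
    return data, labels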
Example #3
    def train(self, train_dset, vocab_size, prefix='sample', **kwargs):
        self.vocab_size = vocab_size
        # Dump the text column of the dataset to a CSV file that
        # generate_sp_model can consume as its training corpus.
        # Naming it after `prefix` (rather than hardcoding 'sample.csv')
        # keeps concurrent runs with different prefixes from colliding.
        df = pd.DataFrame([x[0] for x in train_dset])
        df.to_csv(f'{prefix}.csv', index=False, header=False)
        generate_sp_model(f'{prefix}.csv',
                          vocab_size=vocab_size,
                          model_prefix=prefix)
        vocab_tokenizer = load_sp_model(prefix + ".model")
        self.tokenizer = sentencepiece_tokenizer(sp_model=vocab_tokenizer)
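
The tokenizer stored on self is a generator transform: it takes an iterable of strings and yields a list of subword tokens per string. A short usage sketch (the instance name, dataset, and sentence are illustrative):

# Assume `model` is an instance of the class defining train() above,
# and `train_dset` yields (text, ...) tuples.
model.train(train_dset, vocab_size=8000, prefix='sample')
tokens = list(model.tokenizer(['sentencepiece encode as pieces']))
print(tokens[0])  # e.g. ['▁sentence', 'piece', '▁encode', '▁as', '▁pieces']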
Example #4
    def test_generate_sp_model(self):
        # Test the function to train a SentencePiece tokenizer

        data_path = 'test/asset/text_normalization_ag_news_test.csv'
        generate_sp_model(data_path, vocab_size=23456, model_prefix='spm_user')

        sp_user = spm.SentencePieceProcessor()
        sp_user.Load('spm_user.model')

        self.assertEqual(len(sp_user), 23456)

        if os.path.isfile('spm_user.model'):
            os.remove('spm_user.model')
        if os.path.isfile('spm_user.vocab'):
            os.remove('spm_user.vocab')
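
One caveat with this version: if the assertEqual fails, the test raises before the os.remove calls run and the generated files leak into the working directory. A sketch of a more robust variant using unittest's addCleanup, registered right after training (this pattern is a suggestion, not from the original test suite):

    def test_generate_sp_model(self):
        # Test the function to train a SentencePiece tokenizer

        data_path = 'test/asset/text_normalization_ag_news_test.csv'
        generate_sp_model(data_path, vocab_size=23456, model_prefix='spm_user')
        # Register cleanup immediately so the files are removed even if
        # an assertion below fails.
        for fname in ('spm_user.model', 'spm_user.vocab'):
            self.addCleanup(lambda f=fname: os.path.isfile(f) and os.remove(f))

        sp_user = spm.SentencePieceProcessor()
        sp_user.Load('spm_user.model')

        self.assertEqual(len(sp_user), 23456)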
Example #5
    def test_generate_sp_model(self):
        # Test the function to train a SentencePiece tokenizer

        # buck (fb internal) generates a test environment that contains ',' in its path.
        # SentencePieceTrainer treats such a path as a comma-delimited file list,
        # so as a workaround we copy the asset data to a temporary directory and load it from there.
        data_path = get_asset_path('text_normalization_ag_news_test.csv',
                                   use_temp_dir=True)
        generate_sp_model(data_path, vocab_size=23456, model_prefix='spm_user')

        sp_user = spm.SentencePieceProcessor()
        sp_user.Load('spm_user.model')

        self.assertEqual(len(sp_user), 23456)

        if os.path.isfile('spm_user.model'):
            os.remove('spm_user.model')
        if os.path.isfile('spm_user.vocab'):
            os.remove('spm_user.vocab')
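
Examples #2 and #3 use the two transform factories from torchtext.data.functional: sentencepiece_numericalizer yields lists of token ids, while sentencepiece_tokenizer yields lists of subword strings. A small side-by-side sketch, assuming a model trained with model_prefix='spm_user' as in the tests above (the input sentence is illustrative):

from torchtext.data.functional import (load_sp_model,
                                       sentencepiece_numericalizer,
                                       sentencepiece_tokenizer)

sp_model = load_sp_model('spm_user.model')
texts = ['sentencepiece encode as pieces']

ids = list(sentencepiece_numericalizer(sp_model)(texts))
pieces = list(sentencepiece_tokenizer(sp_model)(texts))
print(ids[0])     # a list of integer ids
print(pieces[0])  # the matching subword strings, e.g. ['▁sentence', 'piece', ...]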