def test_splits_1(self):
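        # Excerpt from a unittest TestCase; os, shutil, urljoin, pathname2url,
        # FastText, process, MatchingIterator and test_dir_path are assumed to
        # be imported at module scope in the surrounding test file.
        # Start from a clean vector cache so FastText re-extracts the vectors.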
        vectors_cache_dir = '.cache'
        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)

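        # The same sample table serves as the train, validation and test split;
        # any stale processed-data cache is removed so process() rebuilds it.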
        data_dir = os.path.join(test_dir_path, 'test_datasets')
        train_path = 'sample_table_large.csv'
        valid_path = 'sample_table_large.csv'
        test_path = 'sample_table_large.csv'
        cache_file = 'cache.pth'
        cache_path = os.path.join(data_dir, cache_file)
        if os.path.exists(cache_path):
            os.remove(cache_path)

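        # Serve the zipped sample vectors over a local file:// URL so the test
        # needs no network access.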
        pathdir = os.path.abspath(os.path.join(test_dir_path, 'test_datasets'))
        filename = 'fasttext_sample.vec.zip'
        # pathname2url() already produces '/'-separated paths, so append '/'
        # rather than os.path.sep, which would break the URL on Windows.
        url_base = urljoin('file:', pathname2url(pathdir)) + '/'
        ft = FastText(filename, url_base=url_base, cache=vectors_cache_dir)

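        # process() parses the CSVs, attaches the embeddings, and returns the
        # (train, validation, test) datasets.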
        datasets = process(data_dir,
                           train=train_path,
                           validation=valid_path,
                           test=test_path,
                           cache=cache_file,
                           embeddings=ft,
                           id_attr='_id',
                           left_prefix='ltable_',
                           right_prefix='rtable_',
                           embeddings_cache_path='',
                           pca=False)

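        # A single batch_size applies to all three splits; batch_sizes instead
        # assigns each split its own size.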
        splits = MatchingIterator.splits(datasets, batch_size=16)
        self.assertEqual(splits[0].batch_size, 16)
        self.assertEqual(splits[1].batch_size, 16)
        self.assertEqual(splits[2].batch_size, 16)
        splits_sorted = MatchingIterator.splits(datasets,
                                                batch_sizes=[16, 32, 64],
                                                sort_in_buckets=False)
        self.assertEqual(splits_sorted[0].batch_size, 16)
        self.assertEqual(splits_sorted[1].batch_size, 32)
        self.assertEqual(splits_sorted[2].batch_size, 64)

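        # Clean up the vector cache and the processed-data cache created above.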
        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)

        if os.path.exists(cache_path):
            os.remove(cache_path)
Example 2
def test_create_batches_1():
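    # Same setup as above, but exercising create_batches(); the `embeddings`
    # object and test_dir_path are assumed to be defined at module scope in
    # the surrounding test file.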
    vectors_cache_dir = ".cache"
    if os.path.exists(vectors_cache_dir):
        shutil.rmtree(vectors_cache_dir)

    data_dir = os.path.join(test_dir_path, "test_datasets")
    train_path = "sample_table_large.csv"
    valid_path = "sample_table_large.csv"
    test_path = "sample_table_large.csv"
    cache_file = "cache.pth"
    cache_path = os.path.join(data_dir, cache_file)
    if os.path.exists(cache_path):
        os.remove(cache_path)

    datasets = process(
        data_dir,
        train=train_path,
        validation=valid_path,
        test=test_path,
        cache=cache_file,
        embeddings=embeddings,
        id_attr="_id",
        left_prefix="ltable_",
        right_prefix="rtable_",
        embeddings_cache_path="",
        pca=False,
    )

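    # create_batches() returns the batch generator for each split; the test
    # only checks that one is produced for every split.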
    splits = MatchingIterator.splits(datasets, batch_size=16)
    batch_splits = [split.create_batches() for split in splits]
    assert batch_splits

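    # Per-split batch sizes; sort_in_buckets=False disables grouping examples
    # of similar length into the same batch.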
    sorted_splits = MatchingIterator.splits(datasets,
                                            batch_sizes=[16, 32, 64],
                                            sort_in_buckets=False)
    batch_sorted_splits = [
        sorted_split.create_batches() for sorted_split in sorted_splits
    ]
    assert batch_sorted_splits

    if os.path.exists(vectors_cache_dir):
        shutil.rmtree(vectors_cache_dir)

    if os.path.exists(cache_path):
        os.remove(cache_path)