Ejemplo n.º 1
0
    def test_penntreebank(self):
        """Smoke-test the experimental PennTreebank datasets.

        Covers the default three-way split, a few vocabulary lookups, the
        ``data_select`` subset of the processed datasets, and the raw variant.
        """
        from torchtext.experimental.datasets import PennTreebank

        def flatten(dataset):
            # Concatenate every non-empty entry of the dataset into one tensor.
            return torch.cat(tuple(entry for entry in dataset if entry.numel() > 0))

        def check_train(data):
            self._helper_test_func(len(data), 924412, data[20:25],
                                   [9919, 9920, 9921, 9922, 9188])

        def check_test(data):
            self._helper_test_func(len(data), 82114, data[30:35],
                                   [397, 93, 4, 16, 7])

        # Full set of splits.
        train_dataset, valid_dataset, test_dataset = PennTreebank()
        train_data = flatten(train_dataset)
        valid_data = flatten(valid_dataset)
        test_data = flatten(test_dataset)
        check_train(train_data)
        check_test(test_data)
        self._helper_test_func(len(valid_data), 73339, valid_data[40:45],
                               [0, 0, 78, 426, 196])

        # Vocabulary built alongside the training split.
        vocab = train_dataset.get_vocab()
        words = 'the player characters rest'.split()
        tokens_ids = [vocab[word] for word in words]
        self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])

        # Requesting only a subset of the splits must yield the same data.
        train_dataset, test_dataset = PennTreebank(data_select=('train', 'test'))
        train_data = flatten(train_dataset)
        test_data = flatten(test_dataset)
        check_train(train_data)
        check_test(test_data)

        # Raw variant: check the sizes and the start of the first item.
        raw_train, raw_test = torchtext.experimental.datasets.raw.PennTreebank(
            data_select=('train', 'test'))
        self._helper_test_func(len(raw_train), 42068,
                               next(iter(raw_train))[:15], ' aer banknote b')
        self._helper_test_func(len(raw_test), 3761,
                               next(iter(raw_test))[:25], " no it was n't black mond")
        del raw_train, raw_test
Ejemplo n.º 2
0
def get_ptb(conf):
    """
    Build the PennTreebank train/test/valid data and its vocabulary.

    The raw splits are requested twice and both sets are handed to
    ``load_dataset`` together with ``conf``.
    NOTE(review): the second request presumably exists because the raw
    iterators can only be consumed once -- confirm against ``load_dataset``.

    :param conf: configuration object forwarded unchanged to ``load_dataset``
    :return: the ``(train, test, valid, vocab)`` values from ``load_dataset``
    """
    splits = ('train', 'test', 'valid')

    # Two independent sets of raw iterators over the same splits.
    raw_train, raw_test, raw_valid = PennTreebank(split=splits)
    raw_train_dup, raw_test_dup, raw_valid_dup = PennTreebank(split=splits)

    # Turn the raw iterators into loaders plus a vocabulary.
    train, test, valid, vocab = load_dataset(
        raw_train, raw_test, raw_valid,
        raw_train_dup, raw_test_dup, raw_valid_dup,
        conf,
    )
    return train, test, valid, vocab
Ejemplo n.º 3
0
    def test_penntreebank(self):
        """Smoke-test the experimental PennTreebank dataset sizes and vocabulary."""
        from torchtext.experimental.datasets import PennTreebank

        train_dataset, test_dataset, valid_dataset = PennTreebank()

        # Expected number of entries per split, checked in train/test/valid order.
        expected_sizes = (
            (train_dataset, 924412),
            (test_dataset, 82114),
            (valid_dataset, 73339),
        )
        for dataset, size in expected_sizes:
            self.assertEqual(len(dataset), size)

        # A few words must map to their known vocabulary ids.
        vocab = train_dataset.get_vocab()
        words = 'the player characters rest'.split()
        tokens_ids = [vocab[word] for word in words]
        self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])
Ejemplo n.º 4
0
def load_iter():
    """
    Build BPTT iterators over the PennTreebank splits.

    Seeds ``random`` and ``torch`` with 1, tokenizes with the spaCy ``'en'``
    model, builds the vocabulary from the training split only, and creates the
    iterators with a batch size of 64 and a BPTT length of 6 on CUDA when
    available (CPU otherwise).

    Returns a ``(train_iter, valid_iter, test_iter, TEXT)`` tuple, where the
    last item is the ``Field`` holding the vocabulary.
    """
    seed = 1
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Fixed seeds for both RNGs.
    random.seed(seed)
    torch.manual_seed(seed)

    nlp = spacy.load('en')

    def tokenize_en(text):
        """
        Split an English string into a list of token strings using spaCy.
        """
        return [token.text for token in nlp.tokenizer(text)]

    text_field = Field(
        tokenize=tokenize_en,
        init_token='<sos>',
        eos_token='<eos>',
        lower=True,
    )

    datasets = PennTreebank.splits(text_field)
    train_set, valid_set, test_set = datasets

    # The vocabulary comes from the training split only.
    text_field.build_vocab(train_set)

    train_iter, valid_iter, test_iter = BPTTIterator.splits(
        (train_set, valid_set, test_set),
        batch_size=64,
        bptt_len=6,
        device=device,
    )

    return train_iter, valid_iter, test_iter, text_field
Ejemplo n.º 5
0
    def test_penntreebank(self):
        """Smoke-test the experimental PennTreebank dataset, then delete its files.

        Checks the size of each split and a few vocabulary lookups. The
        ``ptb.*.txt`` files under ``<project_root>/.data`` are removed afterwards
        even when an assertion fails, so a failing run does not leave them on
        disk.
        """
        from torchtext.experimental.datasets import PennTreebank
        # smoke test to ensure penn treebank works properly
        train_dataset, test_dataset, valid_dataset = PennTreebank()
        try:
            self.assertEqual(len(train_dataset), 924412)
            self.assertEqual(len(test_dataset), 82114)
            self.assertEqual(len(valid_dataset), 73339)

            vocab = train_dataset.get_vocab()
            tokens_ids = [
                vocab[token] for token in 'the player characters rest'.split()
            ]
            self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])
        finally:
            # Delete the dataset after we're done to save disk space on CI.
            # Running this in ``finally`` keeps a failed assertion from
            # skipping the cleanup.
            for filename in ('ptb.train.txt', 'ptb.test.txt', 'ptb.valid.txt'):
                datafile = os.path.join(self.project_root, ".data", filename)
                conditional_remove(datafile)