Code Example #1
    def test_penntreebank(self):
        import torch
        import torchtext
        from torchtext.experimental.datasets import PennTreebank
        # smoke test to ensure penn treebank works properly
        train_dataset, valid_dataset, test_dataset = PennTreebank()
        # Concatenate the non-empty per-line tensors into one long sequence of token ids
        train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
        valid_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, valid_dataset)))
        test_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, test_dataset)))
        self._helper_test_func(len(train_data), 924412, train_data[20:25],
                               [9919, 9920, 9921, 9922, 9188])
        self._helper_test_func(len(test_data), 82114, test_data[30:35],
                               [397, 93, 4, 16, 7])
        self._helper_test_func(len(valid_data), 73339, valid_data[40:45],
                               [0, 0, 78, 426, 196])

        vocab = train_dataset.get_vocab()
        tokens_ids = [vocab[token] for token in 'the player characters rest'.split()]
        self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])

        # Add test for the subset of the standard datasets
        train_dataset, test_dataset = PennTreebank(data_select=('train', 'test'))
        train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
        test_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, test_dataset)))
        self._helper_test_func(len(train_data), 924412, train_data[20:25],
                               [9919, 9920, 9921, 9922, 9188])
        self._helper_test_func(len(test_data), 82114, test_data[30:35],
                               [397, 93, 4, 16, 7])
        # The raw datasets yield the untokenized text lines as plain strings
        train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(data_select=('train', 'test'))
        self._helper_test_func(len(train_iter), 42068, next(iter(train_iter))[:15], ' aer banknote b')
        self._helper_test_func(len(test_iter), 3761, next(iter(test_iter))[:25], " no it was n't black mond")
        del train_iter, test_iter
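
Example #1 calls a `_helper_test_func` method that is not included in the excerpt. Judging from how it is invoked above (an actual length, an expected length, a slice of the data, and the expected values), a minimal sketch of such a helper could look like the following; the real helper in the torchtext test suite may differ in detail.

    def _helper_test_func(self, length, target_length, results, target_results):
        import torch
        # Check the total number of elements first, then compare a small slice
        # of the data (a tensor of token ids or a raw string) to the expected values.
        self.assertEqual(length, target_length)
        if torch.is_tensor(results):
            results = results.tolist()
        self.assertEqual(results, target_results)
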
Code Example #2
    def test_penntreebank(self):
        from torchtext.experimental.datasets import PennTreebank
        # smoke test to ensure penn treebank works properly
        train_dataset, test_dataset, valid_dataset = PennTreebank()
        self.assertEqual(len(train_dataset), 924412)
        self.assertEqual(len(test_dataset), 82114)
        self.assertEqual(len(valid_dataset), 73339)

        vocab = train_dataset.get_vocab()
        tokens_ids = [vocab[token] for token in 'the player characters rest'.split()]
        self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])
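
Outside of a test class, the experimental dataset API exercised in these examples can be used roughly as follows. This is a minimal sketch based only on the calls shown above; the experimental API changed between torchtext releases, so the order of the returned splits and the availability of `get_vocab()` depend on the installed version.

from torchtext.experimental.datasets import PennTreebank

# Build the tokenized Penn Treebank splits (downloads the data on first use).
train_dataset, test_dataset, valid_dataset = PennTreebank()

# The vocabulary built from the training split maps tokens to integer ids.
vocab = train_dataset.get_vocab()
token_ids = [vocab[token] for token in 'the player characters rest'.split()]
print(token_ids)
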
Code Example #3
    def test_penntreebank(self):
        import os
        from torchtext.experimental.datasets import PennTreebank
        # smoke test to ensure penn treebank works properly
        train_dataset, test_dataset, valid_dataset = PennTreebank()
        self.assertEqual(len(train_dataset), 924412)
        self.assertEqual(len(test_dataset), 82114)
        self.assertEqual(len(valid_dataset), 73339)

        vocab = train_dataset.get_vocab()
        tokens_ids = [
            vocab[token] for token in 'the player characters rest'.split()
        ]
        self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])

        # Delete the dataset after we're done to save disk space on CI
        datafile = os.path.join(self.project_root, ".data", 'ptb.train.txt')
        conditional_remove(datafile)
        datafile = os.path.join(self.project_root, ".data", 'ptb.test.txt')
        conditional_remove(datafile)
        datafile = os.path.join(self.project_root, ".data", 'ptb.valid.txt')
        conditional_remove(datafile)
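
`conditional_remove` and `self.project_root` come from the surrounding test module and are not part of the excerpt. A minimal sketch of what such a cleanup helper might do, assuming it simply deletes the path when it exists:

import os
import shutil


def conditional_remove(path):
    # Remove a regular file or a directory tree, but only if it is present,
    # so repeated cleanup calls do not raise.
    if os.path.isfile(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)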