Code example #1
File: data_handler.py  Project: a-domingu/tbcnn
 def get_predict_iter(self,
                      data: Iterable[Dict[str, Any]],
                      batch_size: Optional[int] = None):
     ds = self.gen_dataset(data, include_label_fields=False)
     num_batches = (1 if batch_size is None else math.ceil(
         len(ds) / float(batch_size)))
     it = BatchIterator(
         textdata.Iterator(
             ds,
             batch_size=len(ds) if batch_size is None else batch_size,
             device="cuda:{}".format(torch.cuda.current_device())
             if cuda.CUDA_ENABLED else "cpu",
             sort=True,
             repeat=False,
             train=False,
             sort_key=self.sort_key,
             sort_within_batch=self.sort_within_batch,
             shuffle=self.shuffle,
         ),
         self._postprocess_batch,
         include_target=False,
         is_train=False,
         num_batches=num_batches,
     )
     if batch_size is not None:
         return it
     else:
         for input, _, context in it:
             # only return the first batch since there is only one
             return input, context
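Because `get_predict_iter` leans on pytext internals (`BatchIterator`, `cuda.CUDA_ENABLED`, `self.gen_dataset`), here is a self-contained sketch of the same `batch_size=None` single-batch pattern using the legacy torchtext API alone (torchtext <= 0.8); the field and the sentences are illustrative assumptions, not part of the project above.

from torchtext import data

TEXT = data.Field(sequential=True, lower=True)
fields = [("text", TEXT)]
examples = [data.Example.fromlist([s], fields)
            for s in ["hello world", "good night moon"]]
ds = data.Dataset(examples, fields)
TEXT.build_vocab(ds)

# batch_size=None above means: a single batch spanning the whole dataset
it = data.Iterator(ds, batch_size=len(ds), train=False,
                   sort=False, repeat=False)
batch = next(iter(it))          # the one and only batch
print(batch.text.shape)         # (max_seq_len, 2); batch_first defaults to False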
Code example #2
File: test_batch.py  Project: pjamil21/text
    def test_batch_iter(self):
        self.write_test_numerical_features_dataset()
        FLOAT = data.Field(use_vocab=False,
                           sequential=False,
                           dtype=torch.float)
        INT = data.Field(use_vocab=False, sequential=False, is_target=True)
        TEXT = data.Field(sequential=False)

        dst = data.TabularDataset(
            path=self.test_numerical_features_dataset_path,
            format="tsv",
            skip_header=False,
            fields=[("float", FLOAT), ("int", INT), ("text", TEXT)])
        TEXT.build_vocab(dst)
        itr = data.Iterator(dst, batch_size=2, device="cpu", shuffle=False)
        fld_order = [
            k for k, v in dst.fields.items()
            if v is not None and not v.is_target
        ]
        batch = next(iter(itr))
        (x1, x2), y = batch
        x = (x1, x2)[fld_order.index("float")]
        self.assertEqual(y.data[0], 1)
        self.assertEqual(y.data[1], 12)
        self.assertAlmostEqual(x.data[0], 0.1, places=4)
        self.assertAlmostEqual(x.data[1], 0.5, places=4)
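The unpacking `(x1, x2), y = batch` works because `INT` was declared with `is_target=True`: a legacy torchtext `Batch` is iterable and yields the non-target attributes grouped together followed by the target, in field-declaration order, which is also why `fld_order` is rebuilt from `dst.fields` to find which input slot holds the float column.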
Code example #3
 def test_subword_trec(self):
     TEXT = data.SubwordField()
     LABEL = data.Field(sequential=False)
     RAW = data.Field(sequential=False, use_vocab=False)
     raw, _ = TREC.splits(RAW, LABEL)
     cooked, _ = TREC.splits(TEXT, LABEL)
     LABEL.build_vocab(cooked)
     TEXT.build_vocab(cooked, max_size=100)
     TEXT.segment(cooked)
     print(cooked[0].text)
     batch = next(iter(data.Iterator(cooked, 1, shuffle=False)))
     self.assertEqual(TEXT.reverse(batch.text.data)[0], raw[0].text)
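The round trip under test: after `build_vocab`, `TEXT.segment(cooked)` re-tokenizes every example into subword units, and `TEXT.reverse` detokenizes a numericalized batch back into strings; with `shuffle=False`, the first reversed batch should reproduce the text of the first `RAW` example verbatim.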
Code example #4
 def sentencelist2iterator(self, sentences):
     examples = list()
     for sentence in sentences:
         example = self.sent2example(sentence)
         examples.append(example)
     dataset = data.Dataset(examples,
                            fields=[('src', self.SRC), ('rsrc', self.rSRC)])
     self.iterator = data.Iterator(dataset,
                                   batch_size=1,
                                   sort_key=lambda x: len(x.src),
                                   sort=True,
                                   sort_within_batch=True,
                                   device=self.device)
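With `batch_size=1`, `sort_within_batch` is effectively a no-op; the meaningful setting is `sort=True` with `sort_key=lambda x: len(x.src)`, which orders the whole dataset by source length so downstream decoding sees the shortest sentences first. Note that the method stores the iterator on `self.iterator` rather than returning it, so callers are expected to read the attribute afterwards.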
Code example #5
File: test_batch.py  Project: pjamil21/text
    def test_batch_with_missing_field(self):
        # smoke test to see if batches with missing attributes are shown properly
        with open(self.test_missing_field_dataset_path, "wt") as f:
            f.write("text,label\n1,0")

        dst = data.TabularDataset(path=self.test_missing_field_dataset_path,
                                  format="csv",
                                  skip_header=True,
                                  fields=[("text",
                                           data.Field(use_vocab=False,
                                                      sequential=False)),
                                          ("label", None)])
        itr = data.Iterator(dst, batch_size=64)
        str(next(iter(itr)))
Code example #6
File: data_handler.py  Project: a-domingu/tbcnn
 def _get_test_iter(self, test_dataset: textdata.Dataset,
                    batch_size: int) -> BatchIterator:
     return BatchIterator(
         textdata.Iterator(
             test_dataset,
             batch_size=batch_size,
             device="cuda:{}".format(torch.cuda.current_device())
             if cuda.CUDA_ENABLED else "cpu",
             sort=True,
             repeat=False,
             train=False,
             sort_key=self.sort_key,
         ),
         self._postprocess_batch,
         is_train=False,
         num_batches=math.ceil(len(test_dataset) / float(batch_size)),
     )
Code example #7
File: test_dataset.py  Project: zkneupper/text
    def test_csv_file_with_header(self):
        example_with_header = [("text", "label"), ("HELLO WORLD", "0"),
                               ("goodbye world", "1")]

        TEXT = data.Field(lower=True, tokenize=lambda x: x.split())
        fields = {
            "label": ("label", data.Field(use_vocab=False, sequential=False)),
            "text": ("text", TEXT)
        }

        for format_, delim in zip(["csv", "tsv"], [",", "\t"]):
            with open(self.test_has_header_dataset_path, "wt") as f:
                for line in example_with_header:
                    f.write("{}\n".format(delim.join(line)))

            # check that an error is raised here if a non-existent field is specified
            with self.assertRaises(ValueError):
                data.TabularDataset(
                    path=self.test_has_header_dataset_path,
                    format=format_,
                    fields={"non_existent": ("label", data.Field())})

            dataset = data.TabularDataset(
                path=self.test_has_header_dataset_path,
                format=format_,
                skip_header=False,
                fields=fields)

            TEXT.build_vocab(dataset)

            for i, example in enumerate(dataset):
                self.assertEqual(example.text,
                                 example_with_header[i + 1][0].lower().split())
                self.assertEqual(example.label, example_with_header[i + 1][1])

            # check that the vocabulary is built correctly (#225)
            expected_freqs = {"hello": 1, "world": 2, "goodbye": 1, "text": 0}
            for k, v in expected_freqs.items():
                self.assertEqual(TEXT.vocab.freqs[k], v)

            data_iter = data.Iterator(dataset,
                                      batch_size=1,
                                      sort_within_batch=False,
                                      repeat=False)
            next(iter(data_iter))
Code example #8
def test_intent():
    config = tm.Config()

    text_field = data.Field(lower=True, tokenize=tokenize)
    label_field = data.Field(sequential=False)
    fields = [('text', text_field), ('label', label_field)]

    train_dataset, val_dataset = data.TabularDataset.splits(
        path='./',
        format='csv',
        skip_header=True,
        train=train_data_path,
        test=val_data_path,
        fields=fields)
    vectors = Vectors(name="./model/word2vec")
    text_field.build_vocab(train_dataset,
                           val_dataset,
                           min_freq=1,
                           vectors=vectors)

    label_field.build_vocab(train_dataset, val_dataset)

    test_dataset = data.TabularDataset(path=test_data_path,
                                       format='csv',
                                       fields=fields,
                                       skip_header=True)
    test_iter = data.Iterator(test_dataset,
                              batch_size=config.batch_size,
                              sort_key=lambda x: len(x.text))

    config.snapshot = './model/snapshot/best_steps_200.pt'
    print('Loading model from {}...'.format(config.snapshot))

    embed_num = len(text_field.vocab)
    class_num = len(label_field.vocab) - 1  # minus 1: drop the <unk> entry
    kernel_sizes = [int(k) for k in config.kernel_sizes.split(',')]

    cnn = tm.TextCnn(embed_num, config.embed_dim, class_num, config.kernel_num,
                     kernel_sizes, config.dropout)
    cnn.load_state_dict(tm.torch.load(config.snapshot))

    summary_predict(cnn, text_field, label_field)
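`summary_predict` is project code not shown here; below is a minimal sketch of what its prediction loop over `test_iter` could look like. It assumes `TextCnn` expects a `(batch, seq_len)` tensor (the `.t()` compensates for the default `batch_first=False` field) and that `tm` re-exports `torch`, as the `tm.torch.load` call above suggests.

cnn.eval()
with tm.torch.no_grad():
    for batch in test_iter:
        logits = cnn(batch.text.t())
        preds = logits.argmax(dim=1)
        # vocab index 0 is <unk>, hence the +1 when mapping back to labels
        print([label_field.vocab.itos[p.item() + 1] for p in preds])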
Code example #9
def get_IMDb_DataLoaders_and_TEXT(max_length=256, batch_size=24):
    """IMDbのDataLoaderとTEXTオブジェクトを取得する。 """

    # Build the training-data tsv file
    f = open('./data/IMDb_train.tsv', 'w', encoding='utf-8')

    path = './data/aclImdb/train/pos/'
    for fname in glob.glob(os.path.join(path, '*.txt')):
        with io.open(fname, 'r', encoding="utf-8") as ff:
            text = ff.readline()

            # Strip any tabs, since tab is the field delimiter
            text = text.replace('\t', " ")

            text = text+'\t'+'1'+'\t'+'\n'
            f.write(text)

    path = './data/aclImdb/train/neg/'
    for fname in glob.glob(os.path.join(path, '*.txt')):
        with io.open(fname, 'r', encoding="utf-8") as ff:
            text = ff.readline()

            # Strip any tabs, since tab is the field delimiter
            text = text.replace('\t', " ")

            text = text+'\t'+'0'+'\t'+'\n'
            f.write(text)

    f.close()

    # Build the test-data tsv file
    f = open('./data/IMDb_test.tsv', 'w', encoding='utf-8')

    path = './data/aclImdb/test/pos/'
    for fname in glob.glob(os.path.join(path, '*.txt')):
        with io.open(fname, 'r', encoding="utf-8") as ff:
            text = ff.readline()

            # Strip any tabs, since tab is the field delimiter
            text = text.replace('\t', " ")

            text = text+'\t'+'1'+'\t'+'\n'
            f.write(text)

    path = './data/aclImdb/test/neg/'
    for fname in glob.glob(os.path.join(path, '*.txt')):
        with io.open(fname, 'r', encoding="utf-8") as ff:
            text = ff.readline()

            # Strip any tabs, since tab is the field delimiter
            text = text.replace('\t', " ")

            text = text+'\t'+'0'+'\t'+'\n'
            f.write(text)
    f.close()

    def preprocessing_text(text):
        # Remove HTML line breaks (<br />)
        text = re.sub('<br />', '', text)

        # Replace every symbol except commas and periods with a space
        for p in string.punctuation:
            if (p == ".") or (p == ","):
                continue
            else:
                text = text.replace(p, " ")

        # Pad periods and commas with surrounding spaces
        text = text.replace(".", " . ")
        text = text.replace(",", " , ")
        return text

    # Tokenization (the data is English, so a simple whitespace split suffices)
    def tokenizer_punctuation(text):
        return text.strip().split()


    # Combine preprocessing and tokenization into a single function
    def tokenizer_with_preprocessing(text):
        text = preprocessing_text(text)
        ret = tokenizer_punctuation(text)
        return ret


    # Define how each field is processed as the data is read;
    # fix_length pads/truncates every example to max_length tokens
    TEXT = data.Field(sequential=True, tokenize=tokenizer_with_preprocessing,
                      use_vocab=True, lower=True, include_lengths=True,
                      batch_first=True, fix_length=max_length,
                      init_token="<cls>", eos_token="<eos>")
    LABEL = data.Field(sequential=False, use_vocab=False)

    # Read each tsv file from the "data" folder
    train_val_ds, test_ds = data.TabularDataset.splits(
        path='./data/', train='IMDb_train.tsv',
        test='IMDb_test.tsv', format='tsv',
        fields=[('Text', TEXT), ('Label', LABEL)])

    # Split train/validation with data.Dataset's split method
    train_ds, val_ds = train_val_ds.split(
        split_ratio=0.8, random_state=random.seed(1234))

    # Load pretrained English word vectors through torchtext
    english_fasttext_vectors = Vectors(name='data/wiki-news-300d-1M.vec')

    # Build the vocabulary with the pretrained vectors attached
    TEXT.build_vocab(train_ds, vectors=english_fasttext_vectors, min_freq=10)

    # Create the DataLoaders (in torchtext these are simply called iterators)
    train_dl = data.Iterator(
        train_ds, batch_size=batch_size, train=True)

    val_dl = data.Iterator(
        val_ds, batch_size=batch_size, train=False, sort=False)

    test_dl = data.Iterator(
        test_ds, batch_size=batch_size, train=False, sort=False)

    return train_dl, val_dl, test_dl, TEXT
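A sketch of how the returned objects might be consumed, assuming the aclImdb data and the wiki-news vectors are already on disk where the function expects them:

train_dl, val_dl, test_dl, TEXT = get_IMDb_DataLoaders_and_TEXT(
    max_length=256, batch_size=24)

batch = next(iter(train_dl))
texts, lengths = batch.Text   # include_lengths=True yields a (tensor, lengths) pair
labels = batch.Label
print(texts.shape)                   # (24, 256): batch_first=True, fix_length=256
print(TEXT.vocab.vectors.shape)      # (vocab_size, 300) fastText vectors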