Code Example #1
    def test_default_format(self, lazy):
        reader = Seq2SeqDatasetReader(lazy=lazy)
        instances = reader.read('tests/fixtures/data/seq2seq_copy.tsv')
        instances = ensure_list(instances)

        assert len(instances) == 3
        fields = instances[0].fields
        assert [t.text for t in fields["source_tokens"].tokens
                ] == ["@@START@@", "this", "is", "a", "sentence", "@@END@@"]
        assert [t.text for t in fields["target_tokens"].tokens
                ] == ["@@START@@", "this", "is", "a", "sentence", "@@END@@"]
        fields = instances[1].fields
        assert [t.text for t in fields["source_tokens"].tokens
                ] == ["@@START@@", "this", "is", "another", "@@END@@"]
        assert [t.text for t in fields["target_tokens"].tokens
                ] == ["@@START@@", "this", "is", "another", "@@END@@"]
        fields = instances[2].fields
        assert [t.text for t in fields["source_tokens"].tokens] == [
            "@@START@@", "all", "these", "sentences", "should", "get",
            "copied", "@@END@@"
        ]
        assert [t.text for t in fields["target_tokens"].tokens] == [
            "@@START@@", "all", "these", "sentences", "should", "get",
            "copied", "@@END@@"
        ]
Code Example #2
 def test_invalid_line_format(self, line):
     with tempfile.NamedTemporaryFile("w") as fp_tmp:
         fp_tmp.write(line)
         fp_tmp.flush()
         reader = Seq2SeqDatasetReader()
         with pytest.raises(ConfigurationError):
             reader.read(fp_tmp.name)
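A note on the snippet above: the `line` argument comes from a `pytest.mark.parametrize` decorator that was cut off from this excerpt. Below is a minimal sketch of how it could be parametrized, assuming any row that does not split into exactly two delimiter-separated fields triggers the `ConfigurationError`; the concrete example lines are hypothetical, not the project's real fixture values.

 # Hypothetical parametrization for the test above (values are illustrative only).
 @pytest.mark.parametrize("line", [
     "a single column with no separator",   # fewer than two fields
     "too\tmany\tcolumns",                  # more than two fields
 ])
 def test_invalid_line_format(self, line):
     ...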
Code Example #3
    def test_default_format(self, lazy):
        reader = Seq2SeqDatasetReader(lazy=lazy)
        instances = reader.read(
            str(AllenNlpTestCase.FIXTURES_ROOT / "data" / "seq2seq_copy.tsv"))
        instances = ensure_list(instances)

        assert len(instances) == 3
        fields = instances[0].fields
        assert [t.text for t in fields["source_tokens"].tokens] == [
            "@start@",
            "this",
            "is",
            "a",
            "sentence",
            "@end@",
        ]
        assert [t.text for t in fields["target_tokens"].tokens] == [
            "@start@",
            "this",
            "is",
            "a",
            "sentence",
            "@end@",
        ]
        fields = instances[1].fields
        assert [t.text for t in fields["source_tokens"].tokens] == [
            "@start@",
            "this",
            "is",
            "another",
            "@end@",
        ]
        assert [t.text for t in fields["target_tokens"].tokens] == [
            "@start@",
            "this",
            "is",
            "another",
            "@end@",
        ]
        fields = instances[2].fields
        assert [t.text for t in fields["source_tokens"].tokens] == [
            "@start@",
            "all",
            "these",
            "sentences",
            "should",
            "get",
            "copied",
            "@end@",
        ]
        assert [t.text for t in fields["target_tokens"].tokens] == [
            "@start@",
            "all",
            "these",
            "sentences",
            "should",
            "get",
            "copied",
            "@end@",
        ]
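The assertions in examples #1 and #3 pin down what the `seq2seq_copy.tsv` fixture must contain: three tab-separated lines whose source and target sides are identical. A minimal sketch, assuming those contents and standard AllenNLP 0.x imports, that rebuilds an equivalent fixture in a temporary file and reads it back:

import tempfile

from allennlp.common.util import ensure_list
from allennlp.data.dataset_readers import Seq2SeqDatasetReader

# Fixture contents inferred from the assertions above; the real file may differ in details.
lines = [
    "this is a sentence\tthis is a sentence",
    "this is another\tthis is another",
    "all these sentences should get copied\tall these sentences should get copied",
]
with tempfile.NamedTemporaryFile("w", suffix=".tsv") as fp_tmp:
    fp_tmp.write("\n".join(lines) + "\n")
    fp_tmp.flush()
    instances = ensure_list(Seq2SeqDatasetReader().read(fp_tmp.name))
assert len(instances) == 3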
Code Example #4
    def test_source_add_start_token(self):
        reader = Seq2SeqDatasetReader(source_add_start_token=False)
        instances = reader.read('tests/fixtures/data/seq2seq_copy.tsv')
        instances = ensure_list(instances)

        assert len(instances) == 3
        fields = instances[0].fields
        assert [t.text for t in fields["source_tokens"].tokens] == ["this", "is", "a", "sentence", "@@END@@"]
        assert [t.text for t in fields["target_tokens"].tokens] == ["@@START@@", "this", "is",
                                                                    "a", "sentence", "@@END@@"]
Code Example #5
    def test_source_add_start_token(self):
        reader = Seq2SeqDatasetReader(source_add_start_token=False)
        instances = reader.read(str(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'seq2seq_copy.tsv'))
        instances = ensure_list(instances)

        assert len(instances) == 3
        fields = instances[0].fields
        assert [t.text for t in fields["source_tokens"].tokens] == ["this", "is", "a", "sentence", "@end@"]
        assert [t.text for t in fields["target_tokens"].tokens] == ["@start@", "this", "is",
                                                                    "a", "sentence", "@end@"]
Code Example #6
 def test_correct_quote_handling(self, line):
     with tempfile.NamedTemporaryFile("w") as fp_tmp:
         fp_tmp.write(line)
         fp_tmp.flush()
         reader = Seq2SeqDatasetReader()
         instances = reader.read(fp_tmp.name)
         instances = ensure_list(instances)
         assert len(instances) == 1
         fields = instances[0].fields
         assert [t.text for t in fields["source_tokens"].tokens] == ["@start@", "a", "b", "@end@"]
         assert [t.text for t in fields["target_tokens"].tokens] == ["@start@", "c", "d", "@end@"]
Code Example #7
 def test_max_length_truncation(self):
     reader = Seq2SeqDatasetReader(source_max_tokens=3, target_max_tokens=5)
     instances = reader.read(str(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'seq2seq_copy.tsv'))
     instances = ensure_list(instances)
     assert reader._source_max_exceeded == 2 # pylint: disable=protected-access
     assert reader._target_max_exceeded == 1 # pylint: disable=protected-access
     assert len(instances) == 3
     fields = instances[0].fields
     assert [t.text for t in fields["source_tokens"].tokens] == ["@start@", "this", "is",
                                                                 "a", "@end@"]
     assert [t.text for t in fields["target_tokens"].tokens] == ["@start@", "this", "is",
                                                                 "a", "sentence", "@end@"]
Code Example #8
    def test_delimiter_parameter(self):
        reader = Seq2SeqDatasetReader(delimiter=",")
        instances = reader.read(str(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'seq2seq_copy.csv'))
        instances = ensure_list(instances)

        assert len(instances) == 3
        fields = instances[0].fields
        assert [t.text for t in fields["source_tokens"].tokens] == ["@start@", "this", "is",
                                                                    "a", "sentence", "@end@"]
        assert [t.text for t in fields["target_tokens"].tokens] == ["@start@", "this", "is",
                                                                    "a", "sentence", "@end@"]
        fields = instances[2].fields
        assert [t.text for t in fields["source_tokens"].tokens] == ["@start@", "all", "these", "sentences",
                                                                    "should", "get", "copied", "@end@"]
        assert [t.text for t in fields["target_tokens"].tokens] == ["@start@", "all", "these", "sentences",
                                                                    "should", "get", "copied", "@end@"]
Code Example #9
File: seq2seq.py Project: uhauha2929/AllennlpSeq2Seq
INSTANCES_PER_EPOCH = batch_size * 5000  # the training set is too large, so one epoch is capped at 5000 batches
num_epochs = 20

embedding_dim = 200
hidden_dim = 256
learning_rate = 1e-4
grad_clipping = 10

max_decoding_steps = 20
beam_size = 5

serialization_dir = 'checkpoints/seq2seq/'

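# Source and target sentences are indexed under separate vocabulary namespaces
# ('source_tokens' / 'target_tokens'); the source-side embedding below is sized
# from the 'source_tokens' namespace accordingly.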
reader = Seq2SeqDatasetReader(
    source_tokenizer=WordTokenizer(word_splitter=JustSpacesWordSplitter()),
    target_tokenizer=WordTokenizer(word_splitter=JustSpacesWordSplitter()),
    source_token_indexers={'tokens': SingleIdTokenIndexer(namespace='source_tokens')},
    target_token_indexers={'tokens': SingleIdTokenIndexer(namespace='target_tokens')},
    lazy=True)

train_dataset = reader.read(train_file)
validation_dataset = reader.read(valid_file)

if os.path.exists(vocab_dir):
    vocab = Vocabulary.from_files(vocab_dir)
else:
    vocab = Vocabulary.from_instances(train_dataset,
                                      min_count={'source_tokens': min_count, 'target_tokens': min_count},
                                      max_vocab_size=max_vocab_size)
    vocab.save_to_files(vocab_dir)

en_embedding = Embedding(num_embeddings=vocab.get_vocab_size('source_tokens'),