Example #1
0
 def test_invalid_line_format(self, line):
     """A malformed input line must make the reader raise ConfigurationError."""
     with tempfile.NamedTemporaryFile("w") as tmp:
         # Persist the bad line so the reader can open it by path.
         tmp.write(line)
         tmp.flush()
         reader = Seq2SeqDatasetReader()
         with pytest.raises(ConfigurationError):
             reader.read(tmp.name)
Example #2
0
    def test_default_format(self, lazy):
        """Default TSV format yields three copy instances; source and target
        token texts are identical and wrapped in @start@/@end@ sentinels."""
        reader = Seq2SeqDatasetReader(lazy=lazy)
        read_instances = ensure_list(
            reader.read(str(FIXTURES_ROOT / "generation" / "seq2seq_copy.tsv"))
        )

        # Expected token sequences, one entry per fixture line; the task is
        # copying, so source and target tokens match exactly.
        expected = [
            ["@start@", "this", "is", "a", "sentence", "@end@"],
            ["@start@", "this", "is", "another", "@end@"],
            ["@start@", "all", "these", "sentences", "should", "get", "copied", "@end@"],
        ]
        assert len(read_instances) == 3
        for instance, tokens in zip(read_instances, expected):
            fields = instance.fields
            assert [t.text for t in fields["source_tokens"].tokens] == tokens
            assert [t.text for t in fields["target_tokens"].tokens] == tokens
Example #3
0
    def test_delimiter_parameter(self):
        """With delimiter=",", the CSV fixture parses into the same instances
        as the TSV one; the first and last instances are spot-checked."""
        reader = Seq2SeqDatasetReader(delimiter=",")
        parsed = ensure_list(
            reader.read(str(FIXTURES_ROOT / "generation" / "seq2seq_copy.csv"))
        )

        assert len(parsed) == 3
        # Copy task: source and target token texts are identical per instance.
        spot_checks = {
            0: ["@start@", "this", "is", "a", "sentence", "@end@"],
            2: ["@start@", "all", "these", "sentences", "should", "get", "copied", "@end@"],
        }
        for index, tokens in spot_checks.items():
            fields = parsed[index].fields
            assert [t.text for t in fields["source_tokens"].tokens] == tokens
            assert [t.text for t in fields["target_tokens"].tokens] == tokens
Example #4
0
 def test_correct_quote_handling(self, line):
     """Quoted fields in the input line are parsed into clean token texts."""
     with tempfile.NamedTemporaryFile("w") as tmp:
         # Persist the line so the reader can open it by path.
         tmp.write(line)
         tmp.flush()
         parsed = ensure_list(Seq2SeqDatasetReader().read(tmp.name))
         assert len(parsed) == 1
         fields = parsed[0].fields
         source = [t.text for t in fields["source_tokens"].tokens]
         target = [t.text for t in fields["target_tokens"].tokens]
         assert source == ["@start@", "a", "b", "@end@"]
         assert target == ["@start@", "c", "d", "@end@"]
Example #5
0
    def test_source_add_end_token(self):
        """source_add_end_token=False omits @end@ from the source tokens only;
        the target still carries both sentinels."""
        reader = Seq2SeqDatasetReader(source_add_end_token=False)
        data = ensure_list(
            reader.read(str(FIXTURES_ROOT / "generation" / "seq2seq_copy.tsv"))
        )

        assert len(data) == 3
        first = data[0].fields
        source = [t.text for t in first["source_tokens"].tokens]
        target = [t.text for t in first["target_tokens"].tokens]
        assert source == ["@start@", "this", "is", "a", "sentence"]
        assert target == ["@start@", "this", "is", "a", "sentence", "@end@"]
Example #6
0
 def test_max_length_truncation(self):
     """Token limits truncate sequences and the reader counts how often
     each limit was exceeded across the fixture."""
     reader = Seq2SeqDatasetReader(source_max_tokens=3, target_max_tokens=5)
     data = ensure_list(
         reader.read(str(FIXTURES_ROOT / "generation" / "seq2seq_copy.tsv"))
     )
     # Per the fixture: two sources and one target exceed their limits.
     assert reader._source_max_exceeded == 2
     assert reader._target_max_exceeded == 1
     assert len(data) == 3
     first = data[0].fields
     source = [t.text for t in first["source_tokens"].tokens]
     target = [t.text for t in first["target_tokens"].tokens]
     # Source is cut to 3 content tokens (plus sentinels); target fits.
     assert source == ["@start@", "this", "is", "a", "@end@"]
     assert target == ["@start@", "this", "is", "a", "sentence", "@end@"]
Example #7
0
 def test_bad_start_or_end_symbol(self):
     """An unusable start symbol is rejected at construction time."""
     expected_message = r"Bad start or end symbol \('BAD SYMBOL"
     with pytest.raises(ValueError, match=expected_message):
         Seq2SeqDatasetReader(start_symbol="BAD SYMBOL")