Ejemplo n.º 1
0
 def generator():
     """Yield an endless stream of ``(padded_example, label)`` pairs.

     Each round picks a sequence length, asks ``data.create_sequence`` for a
     sequence of that length, turns it into an example/label pair with
     ``data.create_example``, and pads the example to ``max_length`` with
     ``padding_token`` before yielding it. This generator never terminates.
     """
     while True:
         if num_chars > MIN_SEQ_LEN:
             # NOTE(review): numpy's randint excludes ``high``, so lengths fall in
             # [MIN_SEQ_LEN, num_chars - 1]; confirm num_chars itself should be excluded.
             length = np.random.randint(MIN_SEQ_LEN, num_chars)
         else:
             length = MIN_SEQ_LEN
         example, target = data.create_example(data.create_sequence(length))
         yield pad_sequence(example, max_length, padding_token), target
Ejemplo n.º 2
0
def test_example_contains_no_repeated_next_token(length: int):
    """create_example must insert the "NTKN" marker exactly once."""
    tokens = list(map(str, range(length)))
    example, _label = data.create_example(tokens)
    marker_count = collections.Counter(example)["NTKN"]
    assert marker_count == 1
Ejemplo n.º 3
0
def test_example_next_token_is_positioned_between_two_chars(length: int):
    """The "NTKN" marker must have a token on both sides of it.

    It may be neither the first nor the last element of the example.
    """
    seq = [str(token) for token in range(length)]
    output_seq, _ = data.create_example(seq)
    position = output_seq.index("NTKN")
    # ``list.index`` always returns a value < len(output_seq), so the upper bound
    # must be ``len(output_seq) - 1`` to actually exclude the last slot.
    # Otherwise the upper half of the check could never fail.
    assert 0 < position < len(output_seq) - 1
Ejemplo n.º 4
0
def test_example_label_follows_next_token(length: int):
    """Check the label and the token after "NTKN".

    Both must equal the original token at the marker's position.
    """
    tokens = list(map(str, range(length)))
    example, label = data.create_example(tokens)
    marker_at = example.index("NTKN")
    expected = tokens[marker_at]
    assert label == expected
    assert example[marker_at + 1] == expected
Ejemplo n.º 5
0
def test_example_preserves_seq_elements_positions(length: int):
    """Removing the first "NTKN" from the example must restore the input sequence."""
    tokens = list(map(str, range(length)))
    stripped, _label = data.create_example(tokens)
    # Remove the marker in place; everything else must be untouched and in order.
    stripped.remove("NTKN")
    assert stripped == tokens
Ejemplo n.º 6
0
def test_example_contains_next_token(length: int):
    """The generated example must contain the "NTKN" marker."""
    example, _label = data.create_example(sequence(length))
    assert "NTKN" in example