Beispiel #1
0
def test_train_tiny():
    """Run a small `textattack train` job and sanity-check what it logs.

    The command trains for two epochs with a 32-token max length. The test
    then checks that the process exits with code 0, that the training-args
    file named on stderr exists on disk, and that both the train and eval
    accuracies reported on stderr are above 60%.
    """
    command = "textattack train --model-name-or-path lstm --attack deepwordbug --dataset glue^cola --model-max-length 32 --num-epochs 2 --num-clean-epochs 1 --num-train-adv-examples 200"

    # Run command and validate outputs.
    result = run_command_and_get_result(command)

    assert result.stdout is not None
    assert result.stderr is not None
    assert result.returncode == 0

    stdout = result.stdout.decode().strip()
    print("stdout =>", stdout)
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)

    # The trailing `\.` keeps the sentence-ending period out of the captured path.
    train_args_json_path = re.findall(
        r"Wrote original training args to (\S+)\.", stderr)
    assert train_args_json_path and os.path.exists(train_args_json_path[0])

    train_acc = re.findall(r"Train accuracy: (\S+)", stderr)
    assert train_acc
    train_acc = float(train_acc[0][:-1])  # [:-1] removes percent sign
    assert train_acc > 60

    eval_acc = re.findall(r"Eval accuracy: (\S+)", stderr)
    assert eval_acc
    eval_acc = float(eval_acc[0][:-1])  # [:-1] removes percent sign
    # Check the eval accuracy itself; previously train_acc was re-checked here,
    # leaving the parsed eval accuracy unvalidated.
    assert eval_acc > 60
Beispiel #2
0
def test_command_line_eval(name, command, sample_output_file):
    """Tests the command-line function, `textattack eval`.

    Different from other tests, this one compares the sample output file
    to *stderr* output of the evaluation: every line of the sample file
    must appear as a line of stderr (order and extra stderr lines are
    ignored).
    """
    # Close the sample file deterministically instead of leaking the handle.
    with open(sample_output_file) as sample_file:
        desired_text = sample_file.read().strip()
    desired_text_lines = set(desired_text.split("\n"))

    # Run command and validate outputs.
    result = run_command_and_get_result(command)

    assert result.stdout is not None
    assert result.stderr is not None

    stdout = result.stdout.decode().strip()
    print("stdout =>", stdout)
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)

    print("desired_text =>", desired_text)
    stderr_lines = set(stderr.split("\n"))
    # Sets make this a containment check. With lists, `<=` is a lexicographic
    # comparison and would pass or fail based only on the first differing line.
    missing_lines = desired_text_lines - stderr_lines
    assert not missing_lines, f"lines missing from stderr: {sorted(missing_lines)}"

    assert result.returncode == 0
Beispiel #3
0
def test_command_line_attack(name, command, sample_output_file):
    """Runs attack tests and compares their outputs to a reference file.

    The reference file is treated as literal text, except that each `/.*/`
    marker in it becomes a `.*` regex wildcard. The command's stdout must
    match the resulting pattern from its start, and the command must exit
    with code 0.
    """
    # read in file and create regex
    # (context manager so the file handle is closed instead of leaked)
    with open(sample_output_file, "r") as sample_file:
        desired_output = sample_file.read().strip()
    print("desired_output.encoded =>", desired_output.encode())
    print("desired_output =>", desired_output)
    # regex in sample file look like /.*/
    # / is escaped in python 3.6, but not 3.7+, so we support both
    desired_re = (
        re.escape(desired_output)
        .replace("/\\.\\*/", ".*")
        .replace("\\/\\.\\*\\/", ".*")
    )
    result = run_command_and_get_result(command)
    # get output and check match
    assert result.stdout is not None
    stdout = result.stdout.decode().strip()
    print("stdout.encoded =>", result.stdout)
    print("stdout =>", stdout)
    assert result.stderr is not None
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)

    # Match once and reuse the result; re.S lets `.*` span newlines.
    matched = re.match(desired_re, stdout, flags=re.S)
    if DEBUG and not matched:
        # Drop into the debugger before the assertion fails.
        pdb.set_trace()
    assert matched

    assert result.returncode == 0
Beispiel #4
0
def test_command_line_list(name, command, sample_output_file):
    """Run `command` and require its stdout to equal the sample output file.

    Both the sample file contents and the captured stdout are stripped of
    leading/trailing whitespace before the exact comparison. stderr is only
    printed for diagnostics.
    """
    # Close the sample file deterministically instead of leaking the handle.
    with open(sample_output_file) as sample_file:
        desired_text = sample_file.read().strip()

    # Run command and validate outputs.
    result = run_command_and_get_result(command)

    assert result.stdout is not None
    assert result.stderr is not None

    stdout = result.stdout.decode().strip()
    print("stdout =>", stdout)
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)

    assert stdout == desired_text
Beispiel #5
0
def update_test(command, outfile, add_magic_str=False):
    """Run `command` and save its stdout to `outfile` as new sample output.

    Args:
        command: A single command string, or an iterable of command strings.
            It is echoed and then passed unchanged to
            `run_command_and_get_result`.
        outfile: Path of the file to (over)write with the captured stdout.
        add_magic_str: When true, put the `/.*/` marker at the very start of
            the output and in front of the first-result separator line.
    """
    if isinstance(command, str):
        print(">", command)
    else:
        print("\n".join(f"> {c}" for c in command))
    result = run_command_and_get_result(command)
    stdout = result.stdout.decode().strip()
    if add_magic_str:
        # add magic string to beginning
        magic_str = "/.*/"
        stdout = magic_str + stdout
        # add magic string after attack
        mid_attack_str = "\n--------------------------------------------- Result 1"
        # str.replace returns a new string; the result must be re-bound or
        # the marker is silently never inserted.
        stdout = stdout.replace(mid_attack_str, magic_str + mid_attack_str)
    # write to file (context manager guarantees flush + close)
    with open(outfile, "w") as out_file:
        out_file.write(stdout + "\n")
Beispiel #6
0
def test_command_line_augmentation(name, command, outfile, sample_output_file):
    """Run an augmentation command and check its log line and output file.

    The command must print nothing to stdout, must report
    "Wrote 9 augmentations to augment_test.csv" on stderr, and must leave a
    file at `outfile`, which is deleted once its existence is confirmed.
    """
    import os

    # The sample file is still read (so a missing or unreadable fixture fails
    # the test), but its contents are not compared against anything below.
    with open(sample_output_file) as sample_file:
        sample_file.read()

    # Run command and validate outputs.
    result = run_command_and_get_result(command)

    assert result.stdout is not None
    stdout = result.stdout.decode().strip()
    print("stdout =>", stdout)
    assert stdout == ""

    assert result.stderr is not None
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)
    assert "Wrote 9 augmentations to augment_test.csv" in stderr

    # Ensure CSV file exists, then delete it.
    assert os.path.exists(outfile)
    os.remove(outfile)
def update_test(command, outfile):
    """Run `command` with an extra `tee {outfile}` step appended.

    A single command string is first wrapped into a one-element tuple. Each
    step is echoed with a "> " prefix before the whole tuple is handed to
    `run_command_and_get_result`.
    """
    # NOTE(review): presumably the runner chains the tuple entries so that
    # `tee` captures the preceding command's output - confirm against
    # run_command_and_get_result.
    steps = (command, ) if isinstance(command, str) else command
    steps = steps + (f"tee {outfile}", )
    print(*(f"> {step}" for step in steps), sep="\n")
    run_command_and_get_result(steps)