def test_train_tiny():
    """Run a tiny adversarial-training job via `textattack train` and check it.

    Trains an LSTM on GLUE CoLA for 2 epochs (1 clean epoch) with
    DeepWordBug adversarial examples. Then asserts that:
      * the command exits with status 0,
      * the "Wrote original training args to <path>." message appears on
        stderr and <path> exists,
      * the reported train accuracy and eval accuracy both exceed 60%.
    """
    command = "textattack train --model-name-or-path lstm --attack deepwordbug --dataset glue^cola --model-max-length 32 --num-epochs 2 --num-clean-epochs 1 --num-train-adv-examples 200"

    # Run command and validate outputs.
    result = run_command_and_get_result(command)
    assert result.stdout is not None
    assert result.stderr is not None
    assert result.returncode == 0

    stdout = result.stdout.decode().strip()
    print("stdout =>", stdout)
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)

    # The trailing `\.` is the sentence period; greedy `\S+` backtracks so the
    # captured path excludes it.
    train_args_json_path = re.findall(
        r"Wrote original training args to (\S+)\.", stderr)
    assert len(train_args_json_path) and os.path.exists(
        train_args_json_path[0])

    train_acc = re.findall(r"Train accuracy: (\S+)", stderr)
    assert train_acc
    train_acc = float(train_acc[0][:-1])  # [:-1] removes percent sign
    assert train_acc > 60

    eval_acc = re.findall(r"Eval accuracy: (\S+)", stderr)
    assert eval_acc
    eval_acc = float(eval_acc[0][:-1])  # [:-1] removes percent sign
    # Previously re-asserted train_acc here, so eval accuracy was never checked.
    assert eval_acc > 60
def test_command_line_eval(name, command, sample_output_file):
    """Tests the command-line function, `textattack eval`.

    Different from other tests, this one compares the sample output file to
    *stderr* output of the evaluation.

    Args:
        name: test-case label; not used in the body (presumably supplied by
            test parametrization).
        command: command passed to `run_command_and_get_result`.
        sample_output_file: path to the file holding the expected stderr text.
    """
    # Context manager so the fixture file handle is closed deterministically.
    with open(sample_output_file) as f:
        desired_text = f.read().strip()
    desired_text_lines = desired_text.split("\n")

    # Run command and validate outputs.
    result = run_command_and_get_result(command)
    assert result.stdout is not None
    assert result.stderr is not None

    stdout = result.stdout.decode().strip()
    print("stdout =>", stdout)
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)
    print("desired_text =>", desired_text)

    stderr_lines = stderr.split("\n")
    # NOTE(review): `<=` on lists is a lexicographic comparison, not a subset
    # test. It passes when the expected lines are a prefix of stderr, or when
    # the first differing expected line sorts before the stderr line. If a
    # containment check was intended, confirm and switch to
    # `set(desired_text_lines) <= set(stderr_lines)`.
    assert desired_text_lines <= stderr_lines
    assert result.returncode == 0
def test_command_line_attack(name, command, sample_output_file):
    """Runs attack tests and compares their outputs to a reference file.

    The reference file is treated as literal text, except that each `/.*/`
    "magic string" in it acts as a wildcard. The text is regex-escaped, and
    the escaped form of `/.*/` is then turned back into `.*`. The result must
    match stdout from its start (`re.match`, with `re.S` so `.` also matches
    newlines); trailing stdout beyond the pattern is allowed.

    Args:
        name: test-case label; not used in the body (presumably supplied by
            test parametrization).
        command: command passed to `run_command_and_get_result`.
        sample_output_file: path to the expected-output reference file.
    """
    # read in file and create regex
    with open(sample_output_file, "r") as f:
        desired_output = f.read().strip()
    print("desired_output.encoded =>", desired_output.encode())
    print("desired_output =>", desired_output)
    # regex in sample file look like /.*/
    # / is escaped in python 3.6, but not 3.7+, so we support both
    desired_re = (
        re.escape(desired_output)
        .replace("/\\.\\*/", ".*")
        .replace("\\/\\.\\*\\/", ".*")
    )

    result = run_command_and_get_result(command)

    # get output and check match
    assert result.stdout is not None
    stdout = result.stdout.decode().strip()
    print("stdout.encoded =>", result.stdout)
    print("stdout =>", stdout)
    assert result.stderr is not None
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)

    # Drop into the debugger on mismatch when the module-level DEBUG flag is set.
    if DEBUG and not re.match(desired_re, stdout, flags=re.S):
        pdb.set_trace()
    assert re.match(desired_re, stdout, flags=re.S)

    assert result.returncode == 0
def test_command_line_list(name, command, sample_output_file):
    """Run a `textattack list`-style command and compare stdout exactly.

    Args:
        name: test-case label; not used in the body (presumably supplied by
            test parametrization).
        command: command passed to `run_command_and_get_result`.
        sample_output_file: path to the file whose stripped contents must
            equal the command's stripped stdout.
    """
    # Context manager so the fixture file handle is closed deterministically.
    with open(sample_output_file) as f:
        desired_text = f.read().strip()

    # Run command and validate outputs.
    result = run_command_and_get_result(command)
    assert result.stdout is not None
    assert result.stderr is not None

    stdout = result.stdout.decode().strip()
    print("stdout =>", stdout)
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)

    assert stdout == desired_text
def update_test(command, outfile, add_magic_str=False):
    """Re-run a command and save its stripped stdout as a sample-output file.

    Args:
        command: a single command string, or an iterable of command strings
            (each is echoed with a "> " prefix before running).
        outfile: path of the sample-output file to (over)write.
        add_magic_str: when True, insert the `/.*/` wildcard marker at the
            start of the output and just before the
            "--------------------------------------------- Result 1" banner,
            so those variable regions are ignored when the file is later
            compared as a regex.
    """
    if isinstance(command, str):
        print(">", command)
    else:
        print("\n".join(f"> {c}" for c in command))
    result = run_command_and_get_result(command)
    stdout = result.stdout.decode().strip()
    if add_magic_str:
        # add magic string to beginning
        magic_str = "/.*/"
        stdout = magic_str + stdout
        # add magic string after attack
        mid_attack_str = "\n--------------------------------------------- Result 1"
        # str.replace returns a new string; the original discarded it, so
        # this marker was never actually inserted.
        stdout = stdout.replace(mid_attack_str, magic_str + mid_attack_str)
    # write to file (context manager guarantees flush/close)
    with open(outfile, "w") as f:
        f.write(stdout + "\n")
def test_command_line_augmentation(name, command, outfile, sample_output_file):
    """Run `textattack augment` and check its messages and output CSV.

    Asserts that stdout is empty, that stderr reports
    "Wrote 9 augmentations to augment_test.csv", and that `outfile` was
    created. `outfile` is removed afterwards, even when an assertion fails,
    so a failed run does not leave stale output behind.

    Args:
        name: test-case label; not used in the body (presumably supplied by
            test parametrization).
        command: command passed to `run_command_and_get_result`.
        outfile: path of the CSV the command is expected to write.
        sample_output_file: path to a reference file. It is read (so a
            missing fixture still fails the test) but its contents are not
            compared against anything.
    """
    import os

    # NOTE(review): desired_text is unused; kept so a missing fixture still errors.
    with open(sample_output_file) as f:
        desired_text = f.read().strip()

    # Run command and validate outputs.
    result = run_command_and_get_result(command)
    try:
        assert result.stdout is not None
        stdout = result.stdout.decode().strip()
        print("stdout =>", stdout)
        assert stdout == ""

        assert result.stderr is not None
        stderr = result.stderr.decode().strip()
        print("stderr =>", stderr)
        assert "Wrote 9 augmentations to augment_test.csv" in stderr

        # Ensure CSV file exists; it is deleted in the finally clause below.
        assert os.path.exists(outfile)
    finally:
        if os.path.exists(outfile):
            os.remove(outfile)
def update_test(command, outfile):
    """Re-run a command, piping its output through `tee` into `outfile`.

    Args:
        command: a single command string, or any sequence of command strings
            (tuple or list). The commands are passed as a tuple with
            `tee {outfile}` appended as the last element; presumably
            `run_command_and_get_result` chains them as a pipeline (confirm).
        outfile: path that `tee` writes the captured output to.
    """
    # Normalize to a tuple. The original only handled str or tuple; a list
    # raised TypeError on `list + tuple`.
    commands = (command,) if isinstance(command, str) else tuple(command)
    commands = (*commands, f"tee {outfile}")
    print("\n".join(f"> {c}" for c in commands))
    run_command_and_get_result(commands)