Exemplo n.º 1
0
    def _create_dataset_from_args(cls, args):
        """Build the train and eval datasets described by ``args``.

        ``args.dataset`` is split on ``ARGS_SPLIT_TOKEN`` and the pieces are
        passed positionally to ``HuggingFaceDataset``.

        If ``args.dataset_train_split`` is unset, the ``"train"`` split is used.
        If ``args.dataset_eval_split`` is unset, the splits ``"dev"``,
        ``"eval"``, ``"validation"`` and ``"test"`` are tried in that order.
        Whichever fallback split loads is written back onto ``args``.

        Args:
            args: Parsed arguments. Reads ``dataset``, ``dataset_train_split``,
                ``dataset_eval_split``, ``filter_train_by_labels`` and
                ``filter_eval_by_labels``; may set the two ``*_split`` fields.

        Returns:
            tuple: ``(train_dataset, eval_dataset)``.

        Raises:
            KeyError: If a fallback split is needed and none of the candidate
                split names can be loaded. (This code assumes
                ``HuggingFaceDataset`` signals an unknown split with
                ``KeyError``.)
        """
        dataset_args = args.dataset.split(ARGS_SPLIT_TOKEN)

        if args.dataset_train_split:
            train_dataset = HuggingFaceDataset(
                *dataset_args, split=args.dataset_train_split
            )
        else:
            try:
                train_dataset = HuggingFaceDataset(*dataset_args, split="train")
            except KeyError as err:
                raise KeyError(
                    f"Error: no `train` split found in `{args.dataset}` dataset"
                ) from err
            args.dataset_train_split = "train"

        if args.dataset_eval_split:
            eval_dataset = HuggingFaceDataset(
                *dataset_args, split=args.dataset_eval_split
            )
        else:
            # Try common dev-split names in priority order; the first that
            # loads wins and is recorded on `args`.
            last_error = None
            for split_name in ("dev", "eval", "validation", "test"):
                try:
                    eval_dataset = HuggingFaceDataset(
                        *dataset_args, split=split_name
                    )
                except KeyError as err:
                    last_error = err
                    continue
                args.dataset_eval_split = split_name
                break
            else:
                raise KeyError(
                    f"Could not find `dev`, `eval`, `validation`, or `test` split in dataset {args.dataset}."
                ) from last_error

        if args.filter_train_by_labels:
            train_dataset.filter_by_labels_(args.filter_train_by_labels)
        if args.filter_eval_by_labels:
            eval_dataset.filter_by_labels_(args.filter_eval_by_labels)

        return train_dataset, eval_dataset
Exemplo n.º 2
0
    default=False,
    action="store_true",
    help="log metrics to Weights & Biases",
)
args = parser.parse_args()

# Each run writes to a timestamped directory:
#   <this file>/../../../outputs/training/<model>-<dataset>-<timestamp>/
date_now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
current_dir = os.path.dirname(os.path.realpath(__file__))
outputs_dir = os.path.normpath(
    os.path.join(current_dir, *([os.pardir] * 3), "outputs", "training")
)
args.output_dir = os.path.join(
    outputs_dir, f"{args.model}-{args.dataset}-{date_now}/"
)

# OCNLI data: the loader's default split feeds training, "dev" feeds
# evaluation; both are wrapped as TextAttack HuggingFaceDataset objects.
train_dataset = load_ocnliDataset()
train_hugdataset = HuggingFaceDataset(train_dataset)
val_dataset = load_ocnliDataset(split="dev")
val_hugdataset = HuggingFaceDataset(val_dataset)

(
    train_text,
    train_labels,
    eval_text,
    eval_labels,
) = dataset_for_training(train_hugdataset, val_hugdataset)

# Config and tokenizer both come from `--tokenizer`
# (e.g. "hfl/chinese-macbert-base").
config = BertConfig.from_pretrained(args.tokenizer)
config.output_attentions = False
config.output_token_type_ids = False
tokenizer = BertTokenizerFast.from_pretrained(
    args.tokenizer, config=config, max_length=35
)
Exemplo n.º 3
0
                outputs.append([1 - score, score])
            else:
                outputs.append([score, 1 - score])
        return np.array(outputs)


# Create the model: a French sentiment analysis model.
# See https://github.com/TheophileBlard/french-sentiment-analysis-with-bert
model = TFAutoModelForSequenceClassification.from_pretrained("tblard/tf-allocine")
tokenizer = AutoTokenizer.from_pretrained("tblard/tf-allocine")
# Bound to `sentiment_pipeline` rather than `pipeline` so the `pipeline`
# factory function in scope is not shadowed by its own return value.
sentiment_pipeline = pipeline(
    "sentiment-analysis", model=model, tokenizer=tokenizer
)

model_wrapper = HuggingFaceSentimentAnalysisPipelineWrapper(sentiment_pipeline)

# Create the recipe: PWWS uses a WordNet transformation.
recipe = PWWSRen2019.build(model_wrapper)
# WordNet defaults to English. Set the language to French ('fra') instead.
#
# See
# B. Sagot and D. Fišer, "Building a free French wordnet from multilingual
# resources", European Language Resources Association (ELRA),
# Proceedings of the Sixth International Conference on Language Resources
# and Evaluation (LREC'08).
recipe.transformation.language = "fra"

dataset = HuggingFaceDataset("allocine", split="test")
for idx, result in enumerate(recipe.attack_dataset(dataset)):
    print("-" * 20, f"Result {idx+1}", "-" * 20)
    print(result.__str__(color_method="ansi"))
    print()
Exemplo n.º 4
0
    #############################################################

    out = {}

    out['run_num'] = run_num
    out['num_train_per_class'] = num_train_per_class
    out['task'] = task
    out['transform'] = t
    out['run'] = checkpoint
    out['model_name'] = MODEL_NAME
    out['transform'] = t

    if loaded_checkpoint:

        mw = CustomModelWrapper(model, tokenizer)
        dataset = HuggingFaceDataset(test_dataset, shuffle=True)
        attack_args = textattack.AttackArgs(num_examples=num_advs,
                                            disable_stdout=True)

        for recipe in recipes:

            attack = recipe.build(mw)
            attacker = Attacker(attack, dataset, attack_args)
            attack_results = attacker.attack_dataset()

            num_results = 0
            num_failures = 0
            num_successes = 0

            for result in attack_results:
Exemplo n.º 5
0
def main():
    """Attack the first ``--num_examples`` OCNLI dev examples with a worker pool.

    Dataset items are pushed onto ``in_queue``. A process pool is started with
    ``attack_from_queue`` as its initializer, receiving
    ``(args, in_queue, out_queue)``. The main process reads
    ``(index, result)`` pairs (or an exception, which is re-raised) from
    ``out_queue``. For each result it logs the result, keeps running
    success/failure counts, and optionally writes checkpoints.

    Returns:
        The ``results`` attribute of the attack log manager.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_examples", default=1, type=int)  # 50485
    parser.add_argument("--model", default="hfl/chinese-roberta-wwm-ext", type=str)
    parser.add_argument("--num_labels", default=3, type=int)
    parser.add_argument("--cuda", default=0, type=int)
    parser.add_argument("--tokenizer", default="hfl/chinese-roberta-wwm-ext", type=str)
    parser.add_argument(
        "--transformation",
        type=str,
        required=False,
        default="word-swap-embedding",
        help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: ',
    )

    parser.add_argument(
        "--constraints",
        type=str,
        required=False,
        nargs="*",
        default=["repeat", "stopword"],
        help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: ',
    )

    parser.add_argument(
        "--log-to-txt",
        "-l",
        nargs="?",
        default=None,
        const="",
        type=str,
        help="Save attack logs to <install-dir>/outputs/~ by default; Include '/' at the end of argument to save "
        "output to specified directory in default naming convention; otherwise enter argument to specify "
        "file name",
    )

    # NOTE(review): machine-specific absolute default path; consider making
    # this relative or required.
    parser.add_argument(
        "--log-to-csv",
        nargs="?",
        default="/home/guest/r09944010/2020MLSECURITY/final/ml-security-proj/attack/OCNLI/roberta/",
        const="",
        type=str,
        help="Save attack logs to <install-dir>/outputs/~ by default; Include '/' at the end of argument to save "
        "output to specified directory in default naming convention; otherwise enter argument to specify "
        "file name",
    )

    parser.add_argument(
        "--csv-style",
        default=None,
        const="fancy",
        nargs="?",
        type=str,
        help="Use --csv-style plain to remove [[]] around words",
    )

    parser.add_argument(
        "--enable-visdom", action="store_true", help="Enable logging to visdom."
    )

    parser.add_argument(
        "--enable-wandb",
        action="store_true",
        help="Enable logging to Weights & Biases.",
    )

    parser.add_argument(
        "--disable-stdout", action="store_true", help="Disable logging to stdout"
    )

    parser.add_argument(
        "--interactive",
        action="store_true",
        default=False,
        help="Whether to run attacks interactively.",
    )

    parser.add_argument(
        "--attack-n",
        action="store_true",
        default=False,
        help="Whether to run attack until `n` examples have been attacked (not skipped).",
    )

    parser.add_argument(
        "--parallel",
        action="store_true",
        default=False,
        help="Run attack using multiple GPUs.",
    )

    parser.add_argument(
        "--goal-function",
        "-g",
        default="untargeted-classification",
    )

    def str_to_int(s):
        # Deterministic seed derived from a string: the sum of its code points.
        return sum(ord(c) for c in s)

    parser.add_argument("--random-seed", default=str_to_int("TEXTATTACK"), type=int)

    parser.add_argument(
        "--checkpoint-dir",
        required=False,
        type=str,
        default=None,
        help="The directory to save checkpoint files.",
    )

    parser.add_argument(
        "--checkpoint-interval",
        required=False,
        type=int,
        help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
    )

    parser.add_argument(
        "--query-budget",
        "-q",
        type=int,
        default=float("inf"),
        help="The maximum number of model queries allowed per example attacked.",
    )
    parser.add_argument(
        "--model-batch-size",
        type=int,
        default=28,
        help="The batch size for making calls to the model.",
    )
    parser.add_argument(
        "--model-cache-size",
        type=int,
        default=2 ** 18,
        help="The maximum number of items to keep in the model results cache at once.",
    )
    parser.add_argument(
        "--constraint-cache-size",
        type=int,
        default=2 ** 18,
        help="The maximum number of items to keep in the constraints cache at once.",
    )

    attack_group = parser.add_mutually_exclusive_group(required=False)
    attack_group.add_argument(
        "--search",
        "--search-method",
        "-s",
        type=str,
        required=False,
        default="greedy-word-wir",
    )
    attack_group.add_argument(
        "--recipe",
        "--attack-recipe",
        "-r",
        type=str,
        required=False,
        default=None,
    )
    attack_group.add_argument(
        "--attack-from-file",
        type=str,
        required=False,
        default=None,
        help="attack to load from file (overrides provided goal function, transformation & constraints)",
    )
    args = parser.parse_args()
    # `worklist[-1]` below needs at least one index.
    if args.num_examples < 1:
        parser.error("--num_examples must be a positive integer")

    dataset = HuggingFaceDataset(load_ocnliDataset(split="dev"))

    num_remaining_attacks = args.num_examples
    worklist = deque(range(args.num_examples))
    # Highest index that has ever been part of the worklist; used to pull the
    # next example when `--attack-n` skips one.
    worklist_tail = worklist[-1]

    pytorch_multiprocessing_workaround()
    # Share the parsed options with the worker processes.
    args = torch.multiprocessing.Manager().Namespace(**vars(args))
    num_gpus = torch.cuda.device_count()
    textattack.shared.logger.info(f"Running on {num_gpus} GPUs")

    start_time = time.time()
    in_queue = torch.multiprocessing.Queue()
    out_queue = torch.multiprocessing.Queue()

    # Queue every requested index; indices past the end of a short dataset
    # are dropped from the worklist.
    missing_datapoints = set()
    for i in worklist:
        try:
            text, output = dataset[i]
        except IndexError:
            missing_datapoints.add(i)
        else:
            in_queue.put((i, text, output))
    for i in missing_datapoints:
        worklist.remove(i)

    # Start workers. Keep a reference so the pool can be shut down explicitly
    # once all results have been collected (or an error aborts the run).
    num_workers = 5
    worker_pool = torch.multiprocessing.Pool(
        num_workers, attack_from_queue, (args, in_queue, out_queue)
    )
    attack_log_manager = parse_logger_from_args(args)

    pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)
    num_results = 0
    num_failures = 0
    num_successes = 0
    try:
        while worklist:
            result = out_queue.get(block=True)
            if isinstance(result, Exception):
                raise result
            idx, result = result
            attack_log_manager.log_result(result)
            worklist.remove(idx)
            if (not args.attack_n) or (
                not isinstance(result, textattack.attack_results.SkippedAttackResult)
            ):
                pbar.update()
                num_results += 1

                # Exact-type checks (not isinstance) are kept deliberately so
                # that subclasses of these result types are not counted here.
                if type(result) in (
                    textattack.attack_results.SuccessfulAttackResult,
                    textattack.attack_results.MaximizedAttackResult,
                ):
                    num_successes += 1
                if type(result) is textattack.attack_results.FailedAttackResult:
                    num_failures += 1
                pbar.set_description(
                    "[Succeeded / Failed / Total] {} / {} / {}".format(
                        num_successes, num_failures, num_results
                    )
                )
            else:
                # Skipped with `--attack-n`: pull in the next dataset element.
                worklist_tail += 1
                try:
                    text, output = dataset[worklist_tail]
                except IndexError:
                    raise IndexError(
                        "Tried adding to worklist, but ran out of datapoints. Size of data is {} but tried to access index {}".format(
                            len(dataset), worklist_tail
                        )
                    )
                worklist.append(worklist_tail)
                in_queue.put((worklist_tail, text, output))

            if (
                args.checkpoint_interval
                and len(attack_log_manager.results) % args.checkpoint_interval == 0
            ):
                new_checkpoint = textattack.shared.Checkpoint(
                    args, attack_log_manager, worklist, worklist_tail
                )
                new_checkpoint.save()
                attack_log_manager.flush()
    finally:
        worker_pool.terminate()
        worker_pool.join()

    pbar.close()
    print()
    # Re-enable stdout so the summary is always shown.
    if args.disable_stdout:
        attack_log_manager.enable_stdout()
    attack_log_manager.log_summary()
    attack_log_manager.flush()
    print()
    textattack.shared.logger.info(f"Attack time: {time.time() - start_time}s")
    return attack_log_manager.results
Exemplo n.º 6
0
def main():
    """Parse TextAttack-style CLI options and run ``AttackCommand`` on OCNLI dev.

    The OCNLI dev split is loaded with ``load_ocnliDataset(split="dev")``. It is
    wrapped in ``HuggingFaceDataset`` and handed to ``AttackCommand().run``
    together with the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        "TextAttack CLI",
        usage="[python -m] textattack <command> [<args>]",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="hfl/chinese-roberta-wwm-ext",
    )
    parser.add_argument(
        "--num_examples",
        type=int,
        default=3000,
    )

    parser.add_argument(
        "--model",
        type=str,
        required=False,
        default="hfl/chinese-roberta-wwm-ext",
    )
    parser.add_argument("--random-seed", default=21, type=int)
    transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
        WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
    )
    parser.add_argument(
        "--transformation",
        type=str,
        required=False,
        default="word-swap-embedding",
        help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
        + str(transformation_names),
    )

    parser.add_argument(
        "--constraints",
        type=str,
        required=False,
        nargs="*",
        default=["repeat", "stopword"],
        help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
        + str(CONSTRAINT_CLASS_NAMES.keys()),
    )

    parser.add_argument(
        "--log-to-txt",
        "-l",
        nargs="?",
        default=None,
        const="",
        type=str,
        help="Save attack logs to <install-dir>/outputs/~ by default; Include '/' at the end of argument to save "
        "output to specified directory in default naming convention; otherwise enter argument to specify "
        "file name",
    )

    parser.add_argument(
        "--log-to-csv",
        nargs="?",
        default="",
        const="",
        type=str,
        help="Save attack logs to <install-dir>/outputs/~ by default; Include '/' at the end of argument to save "
        "output to specified directory in default naming convention; otherwise enter argument to specify "
        "file name",
    )

    parser.add_argument(
        "--csv-style",
        default=None,
        const="fancy",
        nargs="?",
        type=str,
        help="Use --csv-style plain to remove [[]] around words",
    )

    parser.add_argument(
        "--enable-visdom", action="store_true", help="Enable logging to visdom."
    )

    parser.add_argument(
        "--enable-wandb",
        action="store_true",
        help="Enable logging to Weights & Biases.",
    )

    parser.add_argument(
        "--disable-stdout", action="store_true", help="Disable logging to stdout"
    )

    parser.add_argument(
        "--interactive",
        action="store_true",
        default=False,
        help="Whether to run attacks interactively.",
    )

    parser.add_argument(
        "--attack-n",
        action="store_true",
        default=False,
        help="Whether to run attack until `n` examples have been attacked (not skipped).",
    )

    parser.add_argument(
        "--parallel",
        action="store_true",
        default=False,
        help="Run attack using multiple GPUs.",
    )

    goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
    parser.add_argument(
        "--goal-function",
        "-g",
        default="untargeted-classification",
        help=f"The goal function to use. choices: {goal_function_choices}",
    )

    parser.add_argument(
        "--checkpoint-dir",
        required=False,
        type=str,
        default=default_checkpoint_dir(),
        help="The directory to save checkpoint files.",
    )

    parser.add_argument(
        "--checkpoint-interval",
        required=False,
        type=int,
        help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
    )

    parser.add_argument(
        "--query-budget",
        "-q",
        type=int,
        default=float("inf"),
        help="The maximum number of model queries allowed per example attacked.",
    )
    parser.add_argument(
        "--model-batch-size",
        type=int,
        default=32,
        help="The batch size for making calls to the model.",
    )
    parser.add_argument(
        "--model-cache-size",
        type=int,
        default=2**18,
        help="The maximum number of items to keep in the model results cache at once.",
    )
    parser.add_argument(
        "--constraint-cache-size",
        type=int,
        default=2**18,
        help="The maximum number of items to keep in the constraints cache at once.",
    )

    attack_group = parser.add_mutually_exclusive_group(required=False)
    search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
    attack_group.add_argument(
        "--search",
        "--search-method",
        "-s",
        type=str,
        required=False,
        default="greedy-word-wir",
        help=f"The search method to use. choices: {search_choices}",
    )
    attack_group.add_argument(
        "--recipe",
        "--attack-recipe",
        "-r",
        type=str,
        required=False,
        default="alzantot",
        help="full attack recipe (overrides provided goal function, transformation & constraints)",
        choices=ATTACK_RECIPE_NAMES.keys(),
    )
    attack_group.add_argument(
        "--attack-from-file",
        type=str,
        required=False,
        default=None,
        help="attack to load from file (overrides provided goal function, transformation & constraints)",
    )

    # Parse first so `--help` and usage errors return before any data loading.
    args = parser.parse_args()

    val_dataset = load_ocnliDataset(split="dev")
    val_hugdataset = HuggingFaceDataset(val_dataset)

    attack_command = AttackCommand()
    attack_command.run(args, val_hugdataset)
    print("ok")
Exemplo n.º 7
0
def main():
    """Run a genetic-algorithm word-swap attack on the OCNLI dev split.

    A sequence-classification model loaded from a local checkpoint directory is
    paired with a ``hfl/chinese-macbert-base`` tokenizer. It is attacked with
    ``WordSwapEmbedding`` under repeat/stopword/input-column/max-words/
    embedding-distance constraints, searched with ``AlzantotGeneticAlgorithm``.
    Results are logged through ``parse_logger_from_args(args)`` and optionally
    checkpointed every ``--checkpoint-interval`` results.

    Returns:
        The ``results`` attribute of the attack log manager.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_examples", default=3000, type=int)  # 50485
    parser.add_argument("--model",
                        default="hfl/chinese-roberta-wwm-ext",
                        type=str)
    parser.add_argument("--num_labels", default=3, type=int)
    parser.add_argument("--cuda", default=0, type=int)
    parser.add_argument("--tokenizer",
                        default="hfl/chinese-roberta-wwm-ext",
                        type=str)
    parser.add_argument(
        "--transformation",
        type=str,
        required=False,
        default="word-swap-embedding",
        help=
        'The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: ',
    )

    parser.add_argument(
        "--constraints",
        type=str,
        required=False,
        nargs="*",
        default=["repeat", "stopword"],
        help=
        'Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: ',
    )

    parser.add_argument(
        "--log-to-txt",
        "-l",
        nargs="?",
        default=None,
        const="",
        type=str,
        help=
        "Save attack logs to <install-dir>/outputs/~ by default; Include '/' at the end of argument to save "
        "output to specified directory in default naming convention; otherwise enter argument to specify "
        "file name",
    )

    # NOTE(review): machine-specific absolute default path.
    parser.add_argument(
        "--log-to-csv",
        nargs="?",
        default=
        "/home/guest/r09944010/2020MLSECURITY/final/ml-security-proj/attack/OCNLI/roberta/",
        const="",
        type=str,
        help=
        "Save attack logs to <install-dir>/outputs/~ by default; Include '/' at the end of argument to save "
        "output to specified directory in default naming convention; otherwise enter argument to specify "
        "file name",
    )

    parser.add_argument(
        "--csv-style",
        default=None,
        const="fancy",
        nargs="?",
        type=str,
        help="Use --csv-style plain to remove [[]] around words",
    )

    parser.add_argument("--enable-visdom",
                        action="store_true",
                        help="Enable logging to visdom.")

    parser.add_argument(
        "--enable-wandb",
        action="store_true",
        help="Enable logging to Weights & Biases.",
    )

    parser.add_argument("--disable-stdout",
                        action="store_true",
                        help="Disable logging to stdout")

    parser.add_argument(
        "--interactive",
        action="store_true",
        default=False,
        help="Whether to run attacks interactively.",
    )

    parser.add_argument(
        "--attack-n",
        action="store_true",
        default=False,
        help=
        "Whether to run attack until `n` examples have been attacked (not skipped).",
    )

    parser.add_argument(
        "--parallel",
        action="store_true",
        default=False,
        help="Run attack using multiple GPUs.",
    )

    parser.add_argument(
        "--goal-function",
        "-g",
        default="untargeted-classification",
    )

    def str_to_int(s):
        # Deterministic seed derived from a string: the sum of its code points.
        return sum(ord(c) for c in s)

    parser.add_argument("--random-seed",
                        default=str_to_int("TEXTATTACK"),
                        type=int)

    parser.add_argument(
        "--checkpoint-dir",
        required=False,
        type=str,
        default=None,
        help="The directory to save checkpoint files.",
    )

    parser.add_argument(
        "--checkpoint-interval",
        required=False,
        type=int,
        help=
        "If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
    )

    parser.add_argument(
        "--query-budget",
        "-q",
        type=int,
        default=float("inf"),
        help=
        "The maximum number of model queries allowed per example attacked.",
    )
    # Default matches the batch size previously hard-coded for the model
    # wrapper, which now honours this flag.
    parser.add_argument(
        "--model-batch-size",
        type=int,
        default=28,
        help="The batch size for making calls to the model.",
    )
    parser.add_argument(
        "--model-cache-size",
        type=int,
        default=2**18,
        help=
        "The maximum number of items to keep in the model results cache at once.",
    )
    parser.add_argument(
        "--constraint-cache-size",
        type=int,
        default=2**18,
        help=
        "The maximum number of items to keep in the constraints cache at once.",
    )

    attack_group = parser.add_mutually_exclusive_group(required=False)
    attack_group.add_argument(
        "--search",
        "--search-method",
        "-s",
        type=str,
        required=False,
        default="greedy-word-wir",
    )
    attack_group.add_argument(
        "--recipe",
        "--attack-recipe",
        "-r",
        type=str,
        required=False,
        default=None,
    )
    attack_group.add_argument(
        "--attack-from-file",
        type=str,
        required=False,
        default=None,
        help=
        "attack to load from file (overrides provided goal function, transformation & constraints)",
    )
    args = parser.parse_args()
    # `worklist[-1]` below needs at least one index.
    if args.num_examples < 1:
        parser.error("--num_examples must be a positive integer")

    dataset = HuggingFaceDataset(load_ocnliDataset(split="dev"))

    num_remaining_attacks = args.num_examples
    worklist = deque(range(args.num_examples))
    # Highest index that has ever been part of the worklist.
    worklist_tail = worklist[-1]

    # NOTE(review): tokenizer and model locations are hard-coded and ignore
    # `--tokenizer` / `--model`; confirm this pairing is intended.
    tokenizer_config = BertConfig.from_pretrained("hfl/chinese-macbert-base")
    tokenizer_config.output_attentions = False
    tokenizer_config.output_token_type_ids = False
    tokenizer = BertTokenizer.from_pretrained("hfl/chinese-macbert-base",
                                              config=tokenizer_config)

    model_dir = './models/roberta/chinese-roberta-wwm-ext-OCNLI-2021-01-05-23-46-02-975289'
    model_config = AutoConfig.from_pretrained(model_dir,
                                              num_labels=args.num_labels)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        config=model_config,
    )
    model_wrapper = HuggingFaceModelWrapper(model,
                                            tokenizer,
                                            batch_size=args.model_batch_size)

    # goal function
    goal_function = UntargetedClassification(model_wrapper)
    # constraints
    # stopwords = set(
    #     ["个", "关于", "之上", "across", "之后", "afterwards", "再次", "against", "ain", "全部", "几乎", "单独", "along", "早已", "也", "虽然", "是", "among", "amongst", "一个", "和", "其他", "任何", "anyhow", "任何人", "anything", "anyway", "anywhere", "are", "aren", "没有", "around", "as", "at", "后", "been", "之前", "beforehand", "behind", "being", "below", "beside", "besides", "之間", "beyond", "皆是", "但", "by", "可以", "不可以", "是", "不是", "couldn't", "d", "didn", "didn't", "doesn", "doesn't", "don", "don't", "down", "due", "之間", "either", "之外", "elsewhere", "空", "足夠", "甚至", "ever", "任何人", "everything", "everywhere", "except", "first", "for", "former", "formerly", "from", "hadn", "hadn't", "hasn", "hasn't", "haven", "haven't", "he", "hence", "her", "here", "hereafter", "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his", "how", "however", "hundred", "i", "if", "in", "indeed", "into", "is", "isn", "isn't", "it", "it's", "its", "itself", "just", "latter", "latterly", "least", "ll", "may", "me", "meanwhile", "mightn", "mightn't", "mine", "more", "moreover", "most", "mostly", "must", "mustn", "mustn't", "my", "myself", "namely", "needn", "needn't", "neither", "never", "nevertheless", "next", "no", "nobody", "none", "noone", "nor", "not", "nothing", "now", "nowhere", "o", "of", "off", "on", "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our", "ours", "ourselves", "out", "over", "per", "please", "s", "same", "shan", "shan't", "she", "she's", "should've", "shouldn", "shouldn't", "somehow", "something", "sometime", "somewhere", "such", "t", "than", "that", "that'll", "the", "their", "theirs", "them", "themselves", "then", "thence", "there", "thereafter", "thereby", "therefore", "therein", "thereupon", "these", "they", "this", "those", "through", "throughout", "thru", "thus", "to", "too", "toward", "towards", "under", "unless", "until", "up", "upon", "used", "ve", "was", "wasn", "wasn't", "we", "were", "weren", "weren't", "what", "whatever", "when", "whence",
    #     "whenever", "where", "whereafter", "whereas", "whereby", "wherein", "whereupon", "wherever", "whether", "which", "while", "whither", "who", "whoever", "whole", "whom", "whose", "why", "with", "within", "without", "won", "won't", "would", "wouldn", "wouldn't", "y", "yet", "you", "you'd", "you'll", "you're", "you've", "your", "yours", "yourself", "yourselves"]
    # )
    constraints = [RepeatModification(), StopwordModification()]
    # constraints = [RepeatModification(), StopwordModification(stopwords=stopwords)]
    # Only the "hypothesis" column may be perturbed; "premise" is protected.
    input_column_modification = InputColumnModification(
        ["premise", "hypothesis"], {"premise"})
    constraints.append(input_column_modification)
    constraints.append(MaxWordsPerturbed(max_percent=0.2))
    constraints.append(
        WordEmbeddingDistance(max_mse_dist=0.5,
                              compare_against_original=False))
    # constraints.append(
    #     Google1BillionWordsLanguageModel(
    #         top_n_per_index=4, compare_against_original=False
    #     )
    # )
    # use_constraint = UniversalSentenceEncoder(
    #     threshold=0.840845057,
    #     metric="angular",
    #     compare_against_original=False,
    #     window_size=15,
    #     skip_text_shorter_than_window=True,
    # )
    # constraints.append(use_constraint)
    transformation = WordSwapEmbedding(max_candidates=8)
    # transformation = WordDeletion()
    # search methods
    # search_method = GreedyWordSwapWIR(wir_method="delete")
    search_method = AlzantotGeneticAlgorithm(pop_size=60,
                                             max_iters=20,
                                             post_crossover_check=False)

    start_time = time.time()
    textattack.shared.utils.set_seed(args.random_seed)

    # attack
    attack = Attack(goal_function, constraints, transformation, search_method)
    print(attack)
    attack_log_manager = parse_logger_from_args(args)

    pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)
    num_results = 0
    num_failures = 0
    num_successes = 0

    for result in attack.attack_dataset(dataset, indices=worklist):
        attack_log_manager.log_result(result)
        if not args.disable_stdout:
            print("\n")
        if (not args.attack_n) or (not isinstance(
                result, textattack.attack_results.SkippedAttackResult)):
            pbar.update(1)
        else:
            # Skipped with `--attack-n`: queue the next dataset element.
            worklist_tail += 1
            worklist.append(worklist_tail)

        num_results += 1

        # Exact-type checks (not isinstance) are kept deliberately so that
        # subclasses of these result types are not counted here.
        if type(result) in (textattack.attack_results.SuccessfulAttackResult,
                            textattack.attack_results.MaximizedAttackResult):
            num_successes += 1
        if type(result) is textattack.attack_results.FailedAttackResult:
            num_failures += 1
        pbar.set_description(
            "[Succeeded / Failed / Total] {} / {} / {}".format(
                num_successes, num_failures, num_results))

        if (args.checkpoint_interval
                and len(attack_log_manager.results) % args.checkpoint_interval
                == 0):
            new_checkpoint = textattack.shared.Checkpoint(
                args, attack_log_manager, worklist, worklist_tail)
            new_checkpoint.save()
            attack_log_manager.flush()

    pbar.close()
    print()
    # Re-enable stdout so the summary is always shown.
    if args.disable_stdout:
        attack_log_manager.enable_stdout()
    attack_log_manager.log_summary()
    attack_log_manager.flush()
    print()
    textattack.shared.logger.info(f"Attack time: {time.time() - start_time}s")
    return attack_log_manager.results
# One-hot encode the labels. The train labels keep the first `index` entries of
# y_train; the test labels keep y_test from `index` onward.
# NOTE(review): the two slices come from different arrays — confirm this
# asymmetry matches how x_train / x_test were split.
y_train = to_categorical(np.array(y_train[:index]))
y_test = to_categorical(np.array(y_test[index:]))

vocabulary = tf.keras.datasets.imdb.get_word_index(path="imdb_word_index.json")

# A single training epoch, with (x_test, y_test) used as validation data.
results = model.fit(
    x_train,
    y_train,
    batch_size=512,
    epochs=1,
    validation_data=(x_test, y_test),
)


if __name__ == "__main__":
    # Support for frozen Windows executables that spawn worker processes (the
    # attack below runs with parallel=True); it has no effect otherwise.
    torch.multiprocessing.freeze_support()

    model_wrapper = CustomKerasModelWrapper(model)
    dataset = HuggingFaceDataset(
        "rotten_tomatoes", None, "test", shuffle=True
    )

    # PWWS recipe built on top of the Keras model wrapper.
    attack = PWWSRen2019.build(model_wrapper)

    # 10 examples, attacked in parallel with 2 workers per device; checkpoint
    # directory is "checkpoints".
    attack_args = AttackArgs(
        parallel=True,
        num_workers_per_device=2,
        num_examples=10,
        checkpoint_dir="checkpoints",
    )
    attacker = Attacker(attack, dataset, attack_args)
    attacker.attack_dataset()