def _create_dataset_from_args(cls, args):
    """Given ``DatasetArgs``, return specified
    ``textattack.datasets.Dataset`` object.

    Resolution order: a dataset implied by ``args.model`` (the ``--model``
    shortcut), then ``args.dataset_from_file``, then
    ``args.dataset_from_huggingface``. Raises ``ValueError`` if none is
    available.
    """
    assert isinstance(
        args, cls
    ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."
    # Automatically detect dataset for huggingface & textattack models.
    # This allows us to use the --model shortcut without specifying a dataset.
    if hasattr(args, "model"):
        args.dataset_by_model = args.model
    if args.dataset_by_model in HUGGINGFACE_DATASET_BY_MODEL:
        args.dataset_from_huggingface = HUGGINGFACE_DATASET_BY_MODEL[
            args.dataset_by_model
        ]
    elif args.dataset_by_model in TEXTATTACK_DATASET_BY_MODEL:
        dataset = TEXTATTACK_DATASET_BY_MODEL[args.dataset_by_model]
        if dataset[0].startswith("textattack"):
            # unsavory way to pass custom dataset classes
            # ex: dataset = ('textattack.datasets.helpers.TedMultiTranslationDataset', 'en', 'de')
            dataset = eval(f"{dataset[0]}")(*dataset[1:])
            return dataset
        else:
            args.dataset_from_huggingface = dataset

    # Get dataset from args.
    if args.dataset_from_file:
        # Fix: previous message incorrectly claimed to load a model/tokenizer
        # and referenced ``args.model_from_file``.
        textattack.shared.logger.info(
            f"Loading dataset from file: {args.dataset_from_file}"
        )
        if ARGS_SPLIT_TOKEN in args.dataset_from_file:
            dataset_file, dataset_name = args.dataset_from_file.split(
                ARGS_SPLIT_TOKEN
            )
        else:
            dataset_file, dataset_name = args.dataset_from_file, "dataset"
        try:
            dataset_module = load_module_from_file(dataset_file)
        except Exception:
            raise ValueError(f"Failed to import file {args.dataset_from_file}")
        try:
            dataset = getattr(dataset_module, dataset_name)
        except AttributeError:
            raise AttributeError(
                f"Variable ``dataset`` not found in module {args.dataset_from_file}"
            )
    elif args.dataset_from_huggingface:
        dataset_args = args.dataset_from_huggingface
        if isinstance(dataset_args, str):
            if ARGS_SPLIT_TOKEN in dataset_args:
                # Fix: convert to tuple — ``str.split`` returns a list, and the
                # split-override branch below concatenates with tuples, which
                # would raise ``TypeError`` for a list.
                dataset_args = tuple(dataset_args.split(ARGS_SPLIT_TOKEN))
            else:
                dataset_args = (dataset_args,)
        if args.dataset_split:
            if len(dataset_args) > 1:
                # Override the split component (index 1) with the
                # user-requested split.
                dataset_args = (
                    dataset_args[:1] + (args.dataset_split,) + dataset_args[2:]
                )
                dataset = textattack.datasets.HuggingFaceDataset(
                    *dataset_args, shuffle=False
                )
            else:
                dataset = textattack.datasets.HuggingFaceDataset(
                    *dataset_args, split=args.dataset_split, shuffle=False
                )
        else:
            dataset = textattack.datasets.HuggingFaceDataset(
                *dataset_args, shuffle=False
            )
    else:
        raise ValueError("Must supply pretrained model or dataset")

    assert isinstance(
        dataset, textattack.datasets.Dataset
    ), "Loaded `dataset` must be of type `textattack.datasets.Dataset`."

    if args.filter_by_labels:
        dataset.filter_by_labels_(args.filter_by_labels)

    return dataset
def _create_model_from_args(cls, args):
    """Given ``ModelArgs``, return specified
    ``textattack.models.wrappers.ModelWrapper`` object.

    Supports, in order: a model wrapper defined in a ``.py`` file, a
    HuggingFace hub model, a TextAttack pre-trained model keyword, or a
    local folder containing a TextAttack-trained model. Raises
    ``ValueError`` for anything else.
    """
    assert isinstance(
        args, cls
    ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."

    if args.model_from_file:
        # Support loading the model from a .py file where a model wrapper
        # is instantiated.
        colored_model_name = textattack.shared.utils.color_text(
            args.model_from_file, color="blue", method="ansi"
        )
        textattack.shared.logger.info(
            f"Loading model and tokenizer from file: {colored_model_name}"
        )
        if ARGS_SPLIT_TOKEN in args.model_from_file:
            model_file, model_name = args.model_from_file.split(ARGS_SPLIT_TOKEN)
        else:
            model_file, model_name = args.model_from_file, "model"
        try:
            # Fix: load from ``model_file`` — loading ``args.model_from_file``
            # directly would include the split token and variable name when
            # the ``path^name`` form is used, which is not a valid file path.
            model_module = load_module_from_file(model_file)
        except Exception:
            raise ValueError(f"Failed to import file {args.model_from_file}.")
        try:
            model = getattr(model_module, model_name)
        except AttributeError:
            raise AttributeError(
                f"Variable `{model_name}` not found in module {args.model_from_file}."
            )

        if not isinstance(model, textattack.models.wrappers.ModelWrapper):
            raise TypeError(
                f"Variable `{model_name}` must be of type "
                f"``textattack.models.ModelWrapper``, got type {type(model)}."
            )
    elif (args.model in HUGGINGFACE_MODELS) or args.model_from_huggingface:
        # Support loading models automatically from the HuggingFace model hub.
        model_name = (
            HUGGINGFACE_MODELS[args.model]
            if (args.model in HUGGINGFACE_MODELS)
            else args.model_from_huggingface
        )
        colored_model_name = textattack.shared.utils.color_text(
            model_name, color="blue", method="ansi"
        )
        textattack.shared.logger.info(
            f"Loading pre-trained model from HuggingFace model repository: {colored_model_name}"
        )
        model = transformers.AutoModelForSequenceClassification.from_pretrained(
            model_name
        )
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name, use_fast=True
        )
        model = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
    elif args.model in TEXTATTACK_MODELS:
        # Support loading TextAttack pre-trained models via just a keyword.
        colored_model_name = textattack.shared.utils.color_text(
            args.model, color="blue", method="ansi"
        )
        if args.model.startswith("lstm"):
            textattack.shared.logger.info(
                f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
            )
            model = textattack.models.helpers.LSTMForClassification.from_pretrained(
                args.model
            )
        elif args.model.startswith("cnn"):
            textattack.shared.logger.info(
                f"Loading pre-trained TextAttack CNN: {colored_model_name}"
            )
            model = (
                textattack.models.helpers.WordCNNForClassification.from_pretrained(
                    args.model
                )
            )
        elif args.model.startswith("t5"):
            model = textattack.models.helpers.T5ForTextToText.from_pretrained(
                args.model
            )
        else:
            raise ValueError(f"Unknown textattack model {args.model}")

        # Choose the appropriate model wrapper (based on whether or not this is
        # a HuggingFace model).
        if isinstance(model, textattack.models.helpers.T5ForTextToText):
            model = textattack.models.wrappers.HuggingFaceModelWrapper(
                model, model.tokenizer
            )
        else:
            model = textattack.models.wrappers.PyTorchModelWrapper(
                model, model.tokenizer
            )
    elif args.model and os.path.exists(args.model):
        # Support loading TextAttack-trained models via just their folder path.
        # If `args.model` is a path/directory, let's assume it was a model
        # trained with textattack, and try and load it.
        if os.path.exists(os.path.join(args.model, "t5-wrapper-config.json")):
            model = textattack.models.helpers.T5ForTextToText.from_pretrained(
                args.model
            )
            model = textattack.models.wrappers.HuggingFaceModelWrapper(
                model, model.tokenizer
            )
        elif os.path.exists(os.path.join(args.model, "config.json")):
            with open(os.path.join(args.model, "config.json")) as f:
                config = json.load(f)
            # NOTE(review): assumes TextAttack-saved configs store
            # "architectures" as a single class-name string — confirm against
            # the training save path.
            model_class = config["architectures"]
            if (
                model_class == "LSTMForClassification"
                or model_class == "WordCNNForClassification"
            ):
                # Fix: resolve the helper class with ``getattr`` and call it
                # directly — the old ``eval`` interpolated ``args.model``
                # (a filesystem path) unquoted into evaluated source, which
                # could never run correctly.
                model = getattr(
                    textattack.models.helpers, model_class
                ).from_pretrained(args.model)
                model = textattack.models.wrappers.PyTorchModelWrapper(
                    model, model.tokenizer
                )
            else:
                # assume the model is from HuggingFace.
                model = (
                    transformers.AutoModelForSequenceClassification.from_pretrained(
                        args.model
                    )
                )
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    args.model, use_fast=True
                )
                model = textattack.models.wrappers.HuggingFaceModelWrapper(
                    model, tokenizer
                )
    else:
        raise ValueError(f"Error: unsupported TextAttack model {args.model}")

    assert isinstance(
        model, textattack.models.wrappers.ModelWrapper
    ), "`model` must be of type `textattack.models.wrappers.ModelWrapper`."
    return model
def _create_attack_from_args(cls, args, model_wrapper):
    """Given ``CommandLineArgs`` and ``ModelWrapper``, return specified
    ``Attack`` object.

    Builds the attack from, in order of precedence: a named recipe
    (``args.attack_recipe``), a ``.py`` file (``args.attack_from_file``),
    or the individual goal-function / transformation / constraint /
    search-method args.
    """
    assert isinstance(
        args, cls
    ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."

    if args.attack_recipe:
        if ARGS_SPLIT_TOKEN in args.attack_recipe:
            recipe_name, params = args.attack_recipe.split(ARGS_SPLIT_TOKEN)
            if recipe_name not in ATTACK_RECIPE_NAMES:
                raise ValueError(f"Error: unsupported recipe {recipe_name}")
            # NOTE(review): ``params`` comes straight from the command line
            # and is eval'd — acceptable for a local CLI, but never expose
            # this path to untrusted input.
            recipe = eval(
                f"{ATTACK_RECIPE_NAMES[recipe_name]}.build(model_wrapper, {params})"
            )
        elif args.attack_recipe in ATTACK_RECIPE_NAMES:
            recipe = eval(
                f"{ATTACK_RECIPE_NAMES[args.attack_recipe]}.build(model_wrapper)"
            )
        else:
            raise ValueError(f"Invalid recipe {args.attack_recipe}")
        if args.query_budget:
            recipe.goal_function.query_budget = args.query_budget
        recipe.goal_function.model_cache_size = args.model_cache_size
        recipe.constraint_cache_size = args.constraint_cache_size
        return recipe
    elif args.attack_from_file:
        if ARGS_SPLIT_TOKEN in args.attack_from_file:
            attack_file, attack_name = args.attack_from_file.split(ARGS_SPLIT_TOKEN)
        else:
            attack_file, attack_name = args.attack_from_file, "attack"
        attack_module = load_module_from_file(attack_file)
        if not hasattr(attack_module, attack_name):
            raise ValueError(
                f"Loaded `{attack_file}` but could not find `{attack_name}`."
            )
        attack_func = getattr(attack_module, attack_name)
        return attack_func(model_wrapper)
    else:
        # Assemble the attack piecewise from its four components.
        goal_function = cls._create_goal_function_from_args(args, model_wrapper)
        transformation = cls._create_transformation_from_args(args, model_wrapper)
        constraints = cls._create_constraints_from_args(args)
        if ARGS_SPLIT_TOKEN in args.search_method:
            search_name, params = args.search_method.split(ARGS_SPLIT_TOKEN)
            if search_name not in SEARCH_METHOD_CLASS_NAMES:
                raise ValueError(f"Error: unsupported search {search_name}")
            search_method = eval(f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})")
        elif args.search_method in SEARCH_METHOD_CLASS_NAMES:
            search_method = eval(
                f"{SEARCH_METHOD_CLASS_NAMES[args.search_method]}()"
            )
        else:
            # Fix: this branch rejects an unknown *search method*; the old
            # message said "unsupported attack", inconsistent with the
            # "unsupported search" message in the split-token branch above.
            raise ValueError(f"Error: unsupported search {args.search_method}")

    return Attack(
        goal_function,
        constraints,
        transformation,
        search_method,
        constraint_cache_size=args.constraint_cache_size,
    )