    @classmethod
    def _create_dataset_from_args(cls, args):
        """Given ``DatasetArgs``, return specified
        ``textattack.dataset.Dataset`` object."""

        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."

        # Automatically detect dataset for huggingface & textattack models.
        # This allows us to use the --model shortcut without specifying a dataset.
        if hasattr(args, "model"):
            args.dataset_by_model = args.model
        if args.dataset_by_model in HUGGINGFACE_DATASET_BY_MODEL:
            args.dataset_from_huggingface = HUGGINGFACE_DATASET_BY_MODEL[
                args.dataset_by_model]
        elif args.dataset_by_model in TEXTATTACK_DATASET_BY_MODEL:
            dataset = TEXTATTACK_DATASET_BY_MODEL[args.dataset_by_model]
            if dataset[0].startswith("textattack"):
                # unsavory way to pass custom dataset classes
                # ex: dataset = ('textattack.datasets.helpers.TedMultiTranslationDataset', 'en', 'de')
                dataset = eval(f"{dataset[0]}")(*dataset[1:])
                return dataset
            else:
                args.dataset_from_huggingface = dataset

        # Get dataset from args.
        if args.dataset_from_file:
            textattack.shared.logger.info(
                f"Loading dataset from file: {args.dataset_from_file}"
            )
            if ARGS_SPLIT_TOKEN in args.dataset_from_file:
                dataset_file, dataset_name = args.dataset_from_file.split(
                    ARGS_SPLIT_TOKEN)
            else:
                dataset_file, dataset_name = args.dataset_from_file, "dataset"
            try:
                dataset_module = load_module_from_file(dataset_file)
            except Exception:
                raise ValueError(
                    f"Failed to import file {args.dataset_from_file}")
            try:
                dataset = getattr(dataset_module, dataset_name)
            except AttributeError:
                raise AttributeError(
                    f"Variable ``{dataset_name}`` not found in module {args.dataset_from_file}"
                )
        elif args.dataset_from_huggingface:
            dataset_args = args.dataset_from_huggingface
            if isinstance(dataset_args, str):
                if ARGS_SPLIT_TOKEN in dataset_args:
                    dataset_args = tuple(dataset_args.split(ARGS_SPLIT_TOKEN))
                else:
                    dataset_args = (dataset_args, )
            if args.dataset_split:
                if len(dataset_args) > 1:
                    dataset_args = (dataset_args[:1] + (args.dataset_split, ) +
                                    dataset_args[2:])
                    dataset = textattack.datasets.HuggingFaceDataset(
                        *dataset_args, shuffle=False)
                else:
                    dataset = textattack.datasets.HuggingFaceDataset(
                        *dataset_args, split=args.dataset_split, shuffle=False)
            else:
                dataset = textattack.datasets.HuggingFaceDataset(*dataset_args,
                                                                 shuffle=False)
        else:
            raise ValueError("Must supply pretrained model or dataset")

        assert isinstance(
            dataset, textattack.datasets.Dataset
        ), "Loaded `dataset` must be of type `textattack.datasets.Dataset`."

        if args.filter_by_labels:
            dataset.filter_by_labels_(args.filter_by_labels)

        return dataset
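    # A minimal usage sketch (illustrative only; it assumes this is the
    # ``@classmethod`` on TextAttack's ``DatasetArgs`` and that the attribute
    # names match the ones referenced above; the values below are made up):
    #
    #     args = DatasetArgs(
    #         dataset_by_model=None,
    #         dataset_from_file=None,
    #         dataset_from_huggingface="rotten_tomatoes",
    #         dataset_split="test",
    #         filter_by_labels=None,
    #     )
    #     dataset = DatasetArgs._create_dataset_from_args(args)
    #     print(len(dataset), dataset[0])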
    @classmethod
    def _create_model_from_args(cls, args):
        """Given ``ModelArgs``, return specified
        ``textattack.models.wrappers.ModelWrapper`` object."""

        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."

        if args.model_from_file:
            # Support loading the model from a .py file where a model wrapper
            # is instantiated.
            colored_model_name = textattack.shared.utils.color_text(
                args.model_from_file, color="blue", method="ansi")
            textattack.shared.logger.info(
                f"Loading model and tokenizer from file: {colored_model_name}")
            if ARGS_SPLIT_TOKEN in args.model_from_file:
                model_file, model_name = args.model_from_file.split(
                    ARGS_SPLIT_TOKEN)
            else:
                model_file, model_name = args.model_from_file, "model"
            try:
                model_module = load_module_from_file(model_file)
            except Exception:
                raise ValueError(
                    f"Failed to import file {args.model_from_file}.")
            try:
                model = getattr(model_module, model_name)
            except AttributeError:
                raise AttributeError(
                    f"Variable `{model_name}` not found in module {args.model_from_file}."
                )

            if not isinstance(model, textattack.models.wrappers.ModelWrapper):
                raise TypeError(
                    f"Variable `{model_name}` must be of type "
                    f"``textattack.models.wrappers.ModelWrapper``, got type {type(model)}."
                )
        elif (args.model in HUGGINGFACE_MODELS) or args.model_from_huggingface:
            # Support loading models automatically from the HuggingFace model hub.

            model_name = (HUGGINGFACE_MODELS[args.model] if
                          (args.model in HUGGINGFACE_MODELS) else
                          args.model_from_huggingface)
            colored_model_name = textattack.shared.utils.color_text(
                model_name, color="blue", method="ansi")
            textattack.shared.logger.info(
                f"Loading pre-trained model from HuggingFace model repository: {colored_model_name}"
            )
            model = transformers.AutoModelForSequenceClassification.from_pretrained(
                model_name)
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_name, use_fast=True)
            model = textattack.models.wrappers.HuggingFaceModelWrapper(
                model, tokenizer)
        elif args.model in TEXTATTACK_MODELS:
            # Support loading TextAttack pre-trained models via just a keyword.
            colored_model_name = textattack.shared.utils.color_text(
                args.model, color="blue", method="ansi")
            if args.model.startswith("lstm"):
                textattack.shared.logger.info(
                    f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
                )
                model = textattack.models.helpers.LSTMForClassification.from_pretrained(
                    args.model)
            elif args.model.startswith("cnn"):
                textattack.shared.logger.info(
                    f"Loading pre-trained TextAttack CNN: {colored_model_name}"
                )
                model = (textattack.models.helpers.WordCNNForClassification.
                         from_pretrained(args.model))
            elif args.model.startswith("t5"):
                model = textattack.models.helpers.T5ForTextToText.from_pretrained(
                    args.model)
            else:
                raise ValueError(f"Unknown textattack model {args.model}")

            # Choose the appropriate model wrapper (based on whether or not this is
            # a HuggingFace model).
            if isinstance(model, textattack.models.helpers.T5ForTextToText):
                model = textattack.models.wrappers.HuggingFaceModelWrapper(
                    model, model.tokenizer)
            else:
                model = textattack.models.wrappers.PyTorchModelWrapper(
                    model, model.tokenizer)
        elif args.model and os.path.exists(args.model):
            # Support loading TextAttack-trained models via just their folder path.
            # If `args.model` is a path/directory, let's assume it was a model
            # trained with textattack, and try and load it.
            if os.path.exists(
                    os.path.join(args.model, "t5-wrapper-config.json")):
                model = textattack.models.helpers.T5ForTextToText.from_pretrained(
                    args.model)
                model = textattack.models.wrappers.HuggingFaceModelWrapper(
                    model, model.tokenizer)
            elif os.path.exists(os.path.join(args.model, "config.json")):
                with open(os.path.join(args.model, "config.json")) as f:
                    config = json.load(f)
                model_class = config["architectures"]
                if (model_class == "LSTMForClassification"
                        or model_class == "WordCNNForClassification"):
                    model = eval(
                        f"textattack.models.helpers.{model_class}"
                    ).from_pretrained(args.model)
                    model = textattack.models.wrappers.PyTorchModelWrapper(
                        model, model.tokenizer)
                else:
                    # assume the model is from HuggingFace.
                    model = (transformers.AutoModelForSequenceClassification.
                             from_pretrained(args.model))
                    tokenizer = transformers.AutoTokenizer.from_pretrained(
                        args.model, use_fast=True)
                    model = textattack.models.wrappers.HuggingFaceModelWrapper(
                        model, tokenizer)
            else:
                raise ValueError(
                    f"Could not find `config.json` or `t5-wrapper-config.json` in {args.model}."
                )
        else:
            raise ValueError(
                f"Error: unsupported TextAttack model {args.model}")

        assert isinstance(
            model, textattack.models.wrappers.ModelWrapper
        ), "`model` must be of type `textattack.models.wrappers.ModelWrapper`."
        return model
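    # A minimal usage sketch (illustrative only; it assumes this is the
    # ``@classmethod`` on TextAttack's ``ModelArgs`` and that ``args.model``
    # names one of the pre-trained TextAttack model keywords, e.g. an LSTM):
    #
    #     args = ModelArgs(model="lstm-imdb")
    #     model_wrapper = ModelArgs._create_model_from_args(args)
    #     print(model_wrapper(["the movie was great"]))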

    @classmethod
    def _create_attack_from_args(cls, args, model_wrapper):
        """Given ``CommandLineArgs`` and ``ModelWrapper``, return specified
        ``Attack`` object."""

        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."

        if args.attack_recipe:
            if ARGS_SPLIT_TOKEN in args.attack_recipe:
                recipe_name, params = args.attack_recipe.split(
                    ARGS_SPLIT_TOKEN)
                if recipe_name not in ATTACK_RECIPE_NAMES:
                    raise ValueError(
                        f"Error: unsupported recipe {recipe_name}")
                recipe = eval(
                    f"{ATTACK_RECIPE_NAMES[recipe_name]}.build(model_wrapper, {params})"
                )
            elif args.attack_recipe in ATTACK_RECIPE_NAMES:
                recipe = eval(
                    f"{ATTACK_RECIPE_NAMES[args.attack_recipe]}.build(model_wrapper)"
                )
            else:
                raise ValueError(f"Invalid recipe {args.attack_recipe}")
            if args.query_budget:
                recipe.goal_function.query_budget = args.query_budget
            recipe.goal_function.model_cache_size = args.model_cache_size
            recipe.constraint_cache_size = args.constraint_cache_size
            return recipe
        elif args.attack_from_file:
            if ARGS_SPLIT_TOKEN in args.attack_from_file:
                attack_file, attack_name = args.attack_from_file.split(
                    ARGS_SPLIT_TOKEN)
            else:
                attack_file, attack_name = args.attack_from_file, "attack"
            attack_module = load_module_from_file(attack_file)
            if not hasattr(attack_module, attack_name):
                raise ValueError(
                    f"Loaded `{attack_file}` but could not find `{attack_name}`."
                )
            attack_func = getattr(attack_module, attack_name)
            return attack_func(model_wrapper)
        else:
            goal_function = cls._create_goal_function_from_args(
                args, model_wrapper)
            transformation = cls._create_transformation_from_args(
                args, model_wrapper)
            constraints = cls._create_constraints_from_args(args)
            if ARGS_SPLIT_TOKEN in args.search_method:
                search_name, params = args.search_method.split(
                    ARGS_SPLIT_TOKEN)
                if search_name not in SEARCH_METHOD_CLASS_NAMES:
                    raise ValueError(
                        f"Error: unsupported search {search_name}")
                search_method = eval(
                    f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})")
            elif args.search_method in SEARCH_METHOD_CLASS_NAMES:
                search_method = eval(
                    f"{SEARCH_METHOD_CLASS_NAMES[args.search_method]}()")
            else:
                raise ValueError(
                    f"Error: unsupported search method {args.search_method}")

        return Attack(
            goal_function,
            constraints,
            transformation,
            search_method,
            constraint_cache_size=args.constraint_cache_size,
        )
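    # A minimal end-to-end sketch tying the three helpers together
    # (illustrative only; ``CommandLineAttackArgs`` is assumed to be the class
    # that owns ``_create_attack_from_args`` and to carry the fields referenced
    # above, e.g. ``attack_recipe``, ``query_budget``, and the cache sizes):
    #
    #     model_wrapper = ModelArgs._create_model_from_args(args)
    #     dataset = DatasetArgs._create_dataset_from_args(args)
    #     attack = CommandLineAttackArgs._create_attack_from_args(args, model_wrapper)
    #     for example, ground_truth in dataset:
    #         result = attack.attack(example, ground_truth)
    #         print(result)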