Ejemplo n.º 1
0
def parse_model_from_args(args):
    """Load a model (and its tokenizer) described by command-line arguments.

    The model source is chosen from ``args`` in this order:

    1. ``args.model_from_file`` -- a Python file to import, optionally followed
       by ``:model_name[:tokenizer_name]`` naming the module attributes that
       hold the model and tokenizer (defaults: ``model`` and ``tokenizer``).
    2. ``args.model`` being a key of ``HUGGINGFACE_DATASET_BY_MODEL``, or
       ``args.model_from_huggingface`` -- a HuggingFace model name, optionally
       prefixed with ``ClassName:`` to pick a ``transformers`` class other than
       ``AutoModelForSequenceClassification``.
    3. ``args.model`` being a key of ``TEXTATTACK_DATASET_BY_MODEL``, or an
       existing directory containing a ``train_args.json`` file.

    Args:
        args: parsed command-line namespace exposing the attributes above.

    Returns:
        The loaded model. In cases 1 and 2 it is moved to
        ``textattack.shared.utils.device`` and given a ``tokenizer`` attribute.

    Raises:
        ValueError: the model file cannot be imported, ``model_from_file`` has
            more than three ``:``-separated parts, or no supported source
            matches ``args.model``.
        AttributeError: the model/tokenizer name is missing from the file, or
            the requested ``transformers`` class does not exist.
        FileNotFoundError: ``args.model`` is a path lacking ``train_args.json``.
    """
    if args.model_from_file:
        colored_model_name = textattack.shared.utils.color_text(
            args.model_from_file, color="blue", method="ansi")
        textattack.shared.logger.info(
            f"Loading model and tokenizer from file: {colored_model_name}")
        # Accept ``file``, ``file:model`` or ``file:model:tokenizer``.
        model_file, *attr_names = args.model_from_file.split(":")
        if len(attr_names) > 2:
            raise ValueError(
                "Expected `file[:model[:tokenizer]]`, got "
                f"{args.model_from_file}")
        model_name = attr_names[0] if len(attr_names) > 0 else "model"
        tokenizer_name = attr_names[1] if len(attr_names) > 1 else "tokenizer"
        try:
            # Import only the file part; the ``:name`` suffixes are not part
            # of the path.
            model_module = load_module_from_file(model_file)
        except Exception as err:
            raise ValueError(f"Failed to import file {model_file}") from err
        try:
            model = getattr(model_module, model_name)
        except AttributeError as err:
            raise AttributeError(
                f"``{model_name}`` not found in module {args.model_from_file}"
            ) from err
        try:
            tokenizer = getattr(model_module, tokenizer_name)
        except AttributeError as err:
            raise AttributeError(
                f"``{tokenizer_name}`` not found in module {args.model_from_file}"
            ) from err
        model = model.to(textattack.shared.utils.device)
        model.tokenizer = tokenizer
    elif (args.model
          in HUGGINGFACE_DATASET_BY_MODEL) or args.model_from_huggingface:
        import transformers

        model_name = (HUGGINGFACE_DATASET_BY_MODEL[args.model][0] if
                      (args.model in HUGGINGFACE_DATASET_BY_MODEL) else
                      args.model_from_huggingface)

        if ":" in model_name:
            # ``ClassName:model_name`` selects a specific transformers class.
            # Split on the first ':' and resolve the (possibly dotted) class
            # name by attribute lookup rather than eval() on user input.
            class_path, model_name = model_name.split(":", 1)
            model_class = transformers
            for attr in class_path.split("."):
                model_class = getattr(model_class, attr)
        else:
            model_class = transformers.AutoModelForSequenceClassification
        colored_model_name = textattack.shared.utils.color_text(model_name,
                                                                color="blue",
                                                                method="ansi")
        textattack.shared.logger.info(
            f"Loading pre-trained model from HuggingFace model repository: {colored_model_name}"
        )
        model = model_class.from_pretrained(model_name)
        model = model.to(textattack.shared.utils.device)
        try:
            tokenizer = textattack.models.tokenizers.AutoTokenizer(model_name)
        except OSError:
            # Report the name actually requested; args.model_from_huggingface
            # is unset when the model came from HUGGINGFACE_DATASET_BY_MODEL.
            textattack.shared.logger.warn(
                f"AutoTokenizer {model_name} not found. Defaulting to `bert-base-uncased`"
            )
            tokenizer = textattack.models.tokenizers.AutoTokenizer(
                "bert-base-uncased")
        model.tokenizer = tokenizer
    else:
        if args.model in TEXTATTACK_DATASET_BY_MODEL:
            model_path, _ = TEXTATTACK_DATASET_BY_MODEL[args.model]
            model = textattack.shared.utils.load_textattack_model_from_path(
                args.model, model_path)
        elif args.model and os.path.exists(args.model):
            # If `args.model` is a path/directory, assume it is a model
            # trained with textattack and rebuild it from its train args.
            model_args_json_path = os.path.join(args.model, "train_args.json")
            if not os.path.exists(model_args_json_path):
                raise FileNotFoundError(
                    f"Tried to load model from path {args.model} - could not find train_args.json."
                )
            with open(model_args_json_path) as args_file:
                model_train_args = json.load(args_file)
            model_train_args["model"] = args.model
            num_labels = model_train_args["num_labels"]
            from textattack.commands.train_model.train_args_helpers import (
                model_from_args, )

            model = model_from_args(argparse.Namespace(**model_train_args),
                                    num_labels)
        else:
            raise ValueError(
                f"Error: unsupported TextAttack model {args.model}")
    return model
def parse_model_from_args(args):
    """Build a model wrapper from command-line arguments.

    The model source is chosen from ``args`` in this order:

    1. ``args.model_from_file`` -- a Python file to import, optionally followed
       by ``ARGS_SPLIT_TOKEN`` and the name of the module attribute holding an
       already-instantiated ``textattack.models.wrappers.ModelWrapper``
       (default attribute name: ``model``).
    2. ``args.model`` being a key of ``HUGGINGFACE_DATASET_BY_MODEL``, or
       ``args.model_from_huggingface`` -- a HuggingFace model name, optionally
       prefixed with ``ClassName`` + ``ARGS_SPLIT_TOKEN`` to pick a
       ``transformers`` class other than ``AutoModelForSequenceClassification``.
    3. ``args.model`` being a key of ``TEXTATTACK_DATASET_BY_MODEL``.
    4. ``args.model`` being an existing directory with ``train_args.json``.

    Args:
        args: parsed command-line namespace exposing the attributes above and
            ``args.model_batch_size`` (passed to the wrappers built here).

    Returns:
        The ModelWrapper found in the file (case 1), otherwise a
        ``HuggingFaceModelWrapper`` or ``PyTorchModelWrapper`` around the
        loaded model.

    Raises:
        ValueError: the model file cannot be imported, or no supported source
            matches ``args.model``.
        AttributeError: the model name is missing from the file, or the
            requested ``transformers`` class does not exist.
        TypeError: the object loaded from the file is not a ``ModelWrapper``.
        FileNotFoundError: ``args.model`` is a path lacking ``train_args.json``.
    """
    if args.model_from_file:
        # Support loading the model from a .py file where a model wrapper
        # is instantiated.
        colored_model_name = textattack.shared.utils.color_text(
            args.model_from_file, color="blue", method="ansi")
        textattack.shared.logger.info(
            f"Loading model and tokenizer from file: {colored_model_name}")
        if ARGS_SPLIT_TOKEN in args.model_from_file:
            model_file, model_name = args.model_from_file.split(
                ARGS_SPLIT_TOKEN)
        else:
            model_file, model_name = args.model_from_file, "model"
        try:
            # Import only the file part; the attribute-name suffix is not part
            # of the path.
            model_module = load_module_from_file(model_file)
        except Exception as err:
            raise ValueError(f"Failed to import file {model_file}") from err
        try:
            model = getattr(model_module, model_name)
        except AttributeError as err:
            raise AttributeError(
                f"``{model_name}`` not found in module {args.model_from_file}"
            ) from err

        if not isinstance(model, textattack.models.wrappers.ModelWrapper):
            raise TypeError(
                "Model must be of type "
                f"``textattack.models.ModelWrapper``, got type {type(model)}")
    elif (args.model
          in HUGGINGFACE_DATASET_BY_MODEL) or args.model_from_huggingface:
        # Support loading models automatically from the HuggingFace model hub.
        import transformers

        model_name = (HUGGINGFACE_DATASET_BY_MODEL[args.model][0] if
                      (args.model in HUGGINGFACE_DATASET_BY_MODEL) else
                      args.model_from_huggingface)

        if ARGS_SPLIT_TOKEN in model_name:
            # ``ClassName<TOKEN>model_name`` selects a specific transformers
            # class. Split on the first token and resolve the (possibly dotted)
            # class name by attribute lookup rather than eval() on user input.
            class_path, model_name = model_name.split(ARGS_SPLIT_TOKEN, 1)
            model_class = transformers
            for attr in class_path.split("."):
                model_class = getattr(model_class, attr)
        else:
            model_class = transformers.AutoModelForSequenceClassification
        colored_model_name = textattack.shared.utils.color_text(model_name,
                                                                color="blue",
                                                                method="ansi")
        textattack.shared.logger.info(
            f"Loading pre-trained model from HuggingFace model repository: {colored_model_name}"
        )
        model = model_class.from_pretrained(model_name)
        tokenizer = textattack.models.tokenizers.AutoTokenizer(model_name)
        model = textattack.models.wrappers.HuggingFaceModelWrapper(
            model, tokenizer, batch_size=args.model_batch_size)
    elif args.model in TEXTATTACK_DATASET_BY_MODEL:
        # Support loading TextAttack pre-trained models via just a keyword.
        model_path, _ = TEXTATTACK_DATASET_BY_MODEL[args.model]
        model = textattack.shared.utils.load_textattack_model_from_path(
            args.model, model_path)
        # Choose the appropriate model wrapper (based on whether or not this is
        # a HuggingFace model).
        huggingface_model_types = (
            textattack.models.helpers.BERTForClassification,
            textattack.models.helpers.T5ForTextToText,
        )
        if isinstance(model, huggingface_model_types):
            model = textattack.models.wrappers.HuggingFaceModelWrapper(
                model, model.tokenizer, batch_size=args.model_batch_size)
        else:
            model = textattack.models.wrappers.PyTorchModelWrapper(
                model, model.tokenizer, batch_size=args.model_batch_size)
    elif args.model and os.path.exists(args.model):
        # Support loading TextAttack-trained models via just their folder path.
        # If `args.model` is a path/directory, assume it was a model trained
        # with textattack and rebuild it from its saved training arguments.
        model_args_json_path = os.path.join(args.model, "train_args.json")
        if not os.path.exists(model_args_json_path):
            raise FileNotFoundError(
                f"Tried to load model from path {args.model} - could not find train_args.json."
            )
        with open(model_args_json_path) as args_file:
            model_train_args = json.load(args_file)
        if model_train_args["model"] not in {"cnn", "lstm"}:
            # for huggingface models, set args.model to the path of the model
            model_train_args["model"] = args.model
        num_labels = model_train_args["num_labels"]
        from textattack.commands.train_model.train_args_helpers import model_from_args

        model = model_from_args(
            argparse.Namespace(**model_train_args),
            num_labels,
            model_path=args.model,
        )
        model = textattack.models.wrappers.PyTorchModelWrapper(
            model, model.tokenizer, batch_size=args.model_batch_size)
    else:
        raise ValueError(f"Error: unsupported TextAttack model {args.model}")

    return model