Ejemplo n.º 1
0
 def register_subcommand(main_parser: ArgumentParser):
     """Register the ``peek-dataset`` subcommand on ``main_parser``.

     A subparser named ``peek-dataset`` is created, the shared dataset
     options are attached to it through ``add_dataset_args``, and its
     ``func`` default is set to a new ``PeekDatasetCommand`` instance.

     Args:
         main_parser: Object whose ``add_parser`` method creates the
             subcommand parser. NOTE(review): annotated as
             ``ArgumentParser``, but ``add_parser`` belongs to the object
             returned by ``ArgumentParser.add_subparsers()`` -- presumably
             that is what callers pass; confirm against the caller.
     """
     subcommand_options = {
         "help": "show main statistics about a dataset",
         "formatter_class": ArgumentDefaultsHelpFormatter,
     }
     peek_parser = main_parser.add_parser("peek-dataset", **subcommand_options)

     add_dataset_args(peek_parser)

     # The command object is stored under ``func`` in the parsed namespace.
     peek_parser.set_defaults(func=PeekDatasetCommand())
Ejemplo n.º 2
0
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``eval`` subcommand on ``main_parser``.

        The subparser receives the shared model and dataset options (via
        ``add_model_args`` and ``add_dataset_args``), a ``--batch-size``
        integer option defaulting to 256, and a ``func`` default set to a
        new ``EvalModelCommand`` instance.

        Args:
            main_parser: Object whose ``add_parser`` method creates the
                subcommand parser. NOTE(review): annotated as
                ``ArgumentParser``, but ``add_parser`` belongs to the object
                returned by ``ArgumentParser.add_subparsers()`` -- presumably
                that is what callers pass; confirm against the caller.
        """
        subcommand_options = {
            "help": "evaluate a model with TextAttack",
            "formatter_class": ArgumentDefaultsHelpFormatter,
        }
        eval_parser = main_parser.add_parser("eval", **subcommand_options)

        # Shared option groups are attached before the eval-specific flag.
        add_model_args(eval_parser)
        add_dataset_args(eval_parser)

        eval_parser.add_argument(
            "--batch-size",
            default=256,
            type=int,
            help="Batch size for model inference.",
        )

        # The command object is stored under ``func`` in the parsed namespace.
        eval_parser.set_defaults(func=EvalModelCommand())
Ejemplo n.º 3
0
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``attack`` subcommand and all of its options.

        A subparser named ``attack`` is created on ``main_parser`` and is
        populated, in order, with: the transformation option, the shared
        model and dataset options (``add_model_args`` / ``add_dataset_args``),
        constraint, logging, execution-mode, goal-function, seed, checkpoint,
        query-budget and cache-size options, and finally a mutually exclusive
        group selecting how the attack is built (``--search``, ``--recipe`` or
        ``--attack-from-file``). The subparser's ``func`` default is set to a
        new ``AttackCommand`` instance.

        Lists of valid names shown in help texts are rendered as plain
        comma-separated strings.

        Args:
            main_parser: Object whose ``add_parser`` method creates the
                subcommand parser. NOTE(review): annotated as
                ``ArgumentParser``, but ``add_parser`` belongs to the object
                returned by ``ArgumentParser.add_subparsers()`` -- presumably
                that is what callers pass; confirm against the caller.
        """
        parser = main_parser.add_parser(
            "attack",
            help="run an attack on an NLP model",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )

        # Union of black-box and white-box names, sorted so the help text is
        # the same on every run (set iteration order is not guaranteed).
        transformation_choices = ", ".join(
            sorted(
                set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys())
                | set(WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys())
            )
        )
        parser.add_argument(
            "--transformation",
            type=str,
            required=False,
            default="word-swap-embedding",
            help=
            'The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + transformation_choices,
        )

        add_model_args(parser)
        add_dataset_args(parser)

        # Joined rather than str()-ed so the help does not show the
        # "dict_keys([...])" wrapper.
        constraint_choices = ", ".join(CONSTRAINT_CLASS_NAMES.keys())
        parser.add_argument(
            "--constraints",
            type=str,
            required=False,
            nargs="*",
            default=["repeat", "stopword"],
            help=
            'Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + constraint_choices,
        )

        # For the nargs="?" options below: flag absent -> ``default``;
        # flag given without a value -> ``const``; otherwise the given value.
        parser.add_argument(
            "--log-to-txt",
            "-l",
            nargs="?",
            default=None,
            const="",
            type=str,
            help=
            "Save attack logs to <install-dir>/outputs/~ by default; Include '/' at the end of argument to save "
            "output to specified directory in default naming convention; otherwise enter argument to specify "
            "file name",
        )

        parser.add_argument(
            "--log-to-csv",
            nargs="?",
            default=None,
            const="",
            type=str,
            help=
            "Save attack logs to <install-dir>/outputs/~ by default; Include '/' at the end of argument to save "
            "output to specified directory in default naming convention; otherwise enter argument to specify "
            "file name",
        )

        parser.add_argument(
            "--csv-style",
            default=None,
            const="fancy",
            nargs="?",
            type=str,
            help="Use --csv-style plain to remove [[]] around words",
        )

        parser.add_argument(
            "--enable-visdom",
            action="store_true",
            help="Enable logging to visdom.",
        )

        parser.add_argument(
            "--enable-wandb",
            action="store_true",
            help="Enable logging to Weights & Biases.",
        )

        parser.add_argument(
            "--disable-stdout",
            action="store_true",
            help="Disable logging to stdout",
        )

        parser.add_argument(
            "--interactive",
            action="store_true",
            default=False,
            help="Whether to run attacks interactively.",
        )

        parser.add_argument(
            "--attack-n",
            action="store_true",
            default=False,
            help=
            "Whether to run attack until `n` examples have been attacked (not skipped).",
        )

        parser.add_argument(
            "--parallel",
            action="store_true",
            default=False,
            help="Run attack using multiple GPUs.",
        )

        goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
        parser.add_argument(
            "--goal-function",
            "-g",
            default="untargeted-classification",
            help=f"The goal function to use. choices: {goal_function_choices}",
        )

        # Default seed is the sum of the code points of "TEXTATTACK" (765).
        parser.add_argument(
            "--random-seed",
            default=sum(map(ord, "TEXTATTACK")),
            type=int,
            help="Random seed to set.",
        )

        parser.add_argument(
            "--checkpoint-dir",
            required=False,
            type=str,
            default=default_checkpoint_dir(),
            help="The directory to save checkpoint files.",
        )

        parser.add_argument(
            "--checkpoint-interval",
            required=False,
            type=int,
            help=
            "If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
        )

        # The default is a float ("inf") even though supplied values are
        # converted with int; argparse does not apply ``type`` to non-string
        # defaults.
        parser.add_argument(
            "--query-budget",
            "-q",
            type=int,
            default=float("inf"),
            help=
            "The maximum number of model queries allowed per example attacked.",
        )
        parser.add_argument(
            "--model-batch-size",
            type=int,
            default=32,
            help="The batch size for making calls to the model.",
        )
        parser.add_argument(
            "--model-cache-size",
            type=int,
            default=2**18,
            help=
            "The maximum number of items to keep in the model results cache at once.",
        )
        parser.add_argument(
            "--constraint-cache-size",
            type=int,
            default=2**18,
            help=
            "The maximum number of items to keep in the constraints cache at once.",
        )

        # At most one of --search / --recipe / --attack-from-file may be given
        # on the command line.
        attack_group = parser.add_mutually_exclusive_group(required=False)
        search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
        attack_group.add_argument(
            "--search",
            "--search-method",
            "-s",
            type=str,
            required=False,
            default="greedy-word-wir",
            help=f"The search method to use. choices: {search_choices}",
        )
        attack_group.add_argument(
            "--recipe",
            "--attack-recipe",
            "-r",
            type=str,
            required=False,
            default=None,
            help=
            "full attack recipe (overrides provided goal function, transformation & constraints)",
            choices=ATTACK_RECIPE_NAMES.keys(),
        )
        attack_group.add_argument(
            "--attack-from-file",
            type=str,
            required=False,
            default=None,
            help=
            "attack to load from file (overrides provided goal function, transformation & constraints)",
        )

        # The command object is stored under ``func`` in the parsed namespace.
        parser.set_defaults(func=AttackCommand())