Example #1
 def load_attackers(self) -> Mapping[str, Attacker]:
     """
     Returns a mapping of attackers keyed by a unique identifier. Requests to `/attack/:id`
     will invoke the attacker with the provided `:id`. Override this method to add or remove
     attackers.
     """
     hotflip = Hotflip(self.predictor)
     hotflip.initialize()
     return { "hotflip": hotflip,
              "input-reduction": InputReduction(self.predictor) }
Example #2
 def load_attackers(self) -> Dict[str, Attacker]:
     """
     Returns a mapping of attackers keyed by a unique identifier. Requests to `/attack/:id`
     will invoke the attacker with the provided `:id`. Override this method to add or remove
     attackers.
     """
     attackers: Dict[str, Attacker] = {}
     if "hotflip" in self.model.attackers:
         hotflip = Hotflip(self.predictor)
         hotflip.initialize()
         attackers["hotflip"] = hotflip
     if "input_reduction" in self.model.attackers:
         attackers["input_reduction"] = InputReduction(self.predictor)
     return attackers
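Both variants override the same hook described in their docstrings. For illustration, a subclass might register only hotflip, as in the sketch below; the MyModelEndpoint class and its ModelEndpoint parent are assumptions standing in for whatever base class defines load_attackers and exposes self.predictor.

from typing import Mapping

from allennlp.interpret.attackers import Attacker, Hotflip


class MyModelEndpoint(ModelEndpoint):  # ModelEndpoint is an assumed base class exposing self.predictor
    def load_attackers(self) -> Mapping[str, Attacker]:
        # Serve only hotflip, so the demo answers /attack/hotflip and nothing else.
        hotflip = Hotflip(self.predictor)
        hotflip.initialize()  # precompute the embedding matrix hotflip uses to search for token swaps
        return {"hotflip": hotflip}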
Example #3
    def test_input_reduction(self):
        # test using entailment model
        inputs = {
            "premise": "I always write unit tests for my code.",
            "hypothesis": "One time I didn't write any unit tests for my code."
        }

        archive = load_archive(self.FIXTURES_ROOT / 'decomposable_attention' /
                               'serialization' / 'model.tar.gz')
        predictor = Predictor.from_archive(archive, 'textual-entailment')

        reducer = InputReduction(predictor)
        reduced = reducer.attack_from_json(inputs, 'hypothesis',
                                           'grad_input_1')
        assert reduced is not None
        assert 'final' in reduced
        assert 'original' in reduced
        assert reduced['final'][0]  # always at least one token
        assert len(reduced['final'][0]) <= len(reduced['original'])  # input reduction removes tokens
        for word in reduced['final'][0]:  # no new words entered
            assert word in reduced['original']

        # test using NER model (tests different underlying logic)
        inputs = {
            "sentence": "Eric Wallace was an intern at AI2",
        }

        archive = load_archive(self.FIXTURES_ROOT / 'simple_tagger' /
                               'serialization' / 'model.tar.gz')
        predictor = Predictor.from_archive(archive, 'sentence-tagger')

        reducer = InputReduction(predictor)
        reduced = reducer.attack_from_json(inputs, 'tokens', 'grad_input_1')
        assert reduced is not None
        assert 'final' in reduced
        assert 'original' in reduced
        for reduced_input in reduced['final']:
            assert reduced_input  # always at least one token
            assert len(reduced_input) <= len(reduced['original'])  # input reduction removes tokens
            for word in reduced_input:  # no new words entered
                assert word in reduced['original']
Example #4
    def test_input_reduction(self):
        # test using classification model
        inputs = {"sentence": "I always write unit tests for my code."}

        archive = load_archive(
            self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
        )
        predictor = Predictor.from_archive(archive)

        reducer = InputReduction(predictor)
        reduced = reducer.attack_from_json(inputs, "tokens", "grad_input_1")
        assert reduced is not None
        assert "final" in reduced
        assert "original" in reduced
        assert reduced["final"][0]  # always at least one token
        assert len(reduced["final"][0]) <= len(
            reduced["original"]
        )  # input reduction removes tokens
        for word in reduced["final"][0]:  # no new words entered
            assert word in reduced["original"]

        # test using NER model (tests different underlying logic)
        inputs = {"sentence": "Eric Wallace was an intern at AI2"}

        archive = load_archive(
            self.FIXTURES_ROOT / "simple_tagger" / "serialization" / "model.tar.gz"
        )
        predictor = Predictor.from_archive(archive, "sentence-tagger")

        reducer = InputReduction(predictor)
        reduced = reducer.attack_from_json(inputs, "tokens", "grad_input_1")
        assert reduced is not None
        assert "final" in reduced
        assert "original" in reduced
        for reduced_input in reduced["final"]:
            assert reduced_input  # always at least one token
            assert len(reduced_input) <= len(reduced["original"])  # input reduction removes tokens
            for word in reduced_input:  # no new words entered
                assert word in reduced["original"]
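Based on the assertions in these tests, the attacker output is roughly shaped as follows; the token values below are illustrative, not actual model output.

# Illustrative shape only; real tokens depend on the model, the input, and its gradients.
reduced = {
    "original": ["I", "always", "write", "unit", "tests", "for", "my", "code", "."],
    "final": [["write", "tests"]],  # one or more reduced token lists, each a subset of "original"
}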
Example #5
def make_app(models: Dict[str, DemoModel],
             demo_db: Optional[DemoDatabase] = None,
             cache_size: int = 128,
             interpret_cache_size: int = 500,
             attack_cache_size: int = 500) -> Flask:

    app = Flask(__name__)  # pylint: disable=invalid-name
    start_time = datetime.now(pytz.utc)
    start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S %Z")

    app.predictors = {}
    app.max_request_lengths = {}  # requests longer than these will be rejected to prevent OOME
    app.attackers = defaultdict(dict)
    app.interpreters = defaultdict(dict)
    app.wsgi_app = ProxyFix(app.wsgi_app)  # sets the requester IP with the X-Forwarded-For header

    for name, demo_model in models.items():
        if demo_model is not None:
            logger.info(f"loading {name} model")
            predictor = demo_model.predictor()
            app.predictors[name] = predictor
            app.max_request_lengths[name] = demo_model.max_request_length

            if name in supported_interpret_models:
                app.interpreters[name]['simple_gradient'] = SimpleGradient(predictor)
                app.interpreters[name]['integrated_gradient'] = IntegratedGradient(predictor)
                app.interpreters[name]['smooth_gradient'] = SmoothGradient(predictor)
                app.attackers[name]["input_reduction"] = InputReduction(predictor)
                if name == 'masked-lm':
                    app.attackers[name]["hotflip"] = Hotflip(predictor, 'bert')
                elif name == "next-token-lm":
                    app.attackers[name]["hotflip"] = Hotflip(predictor, 'gpt2')
                elif 'named-entity-recognition' in name:
                    # We haven't implemented hotflip for NER.
                    continue
                elif name == 'textual-entailment':
                    # The SNLI model only has ELMo embeddings, which don't work with hotflip on
                    # their own.
                    continue
                else:
                    app.attackers[name]["hotflip"] = Hotflip(predictor)
                    app.attackers[name]["hotflip"].initialize()

    # Disable caching for HTML documents and API responses so that clients
    # always talk to the source (this server).
    @app.after_request
    def set_cache_headers(resp: Response) -> Response:
        if resp.mimetype == "text/html" or resp.mimetype == "application/json":
            return with_no_cache_headers(resp)
        else:
            return resp

    @app.errorhandler(ServerError)
    def handle_invalid_usage(error: ServerError) -> Response:  # pylint: disable=unused-variable
        response = jsonify(error.to_dict())
        response.status_code = error.status_code
        return response

    @lru_cache(maxsize=cache_size)
    def _caching_prediction(model: Predictor, data: str) -> JsonDict:
        """
        Just a wrapper around ``model.predict_json`` that allows us to use a cache decorator.
        """
        return model.predict_json(json.loads(data))

    @lru_cache(maxsize=interpret_cache_size)
    def _caching_interpret(interpreter: SaliencyInterpreter,
                           data: str) -> JsonDict:
        """
        Just a wrapper around ``model.interpret_from_json`` that allows us to use a cache decorator.
        """
        return interpreter.saliency_interpret_from_json(json.loads(data))

    @lru_cache(maxsize=attack_cache_size)
    def _caching_attack(attacker: Attacker, data: str,
                        input_field_to_attack: str, grad_input_field: str,
                        target: str) -> JsonDict:
        """
        Just a wrapper around ``model.attack_from_json`` that allows us to use a cache decorator.
        """
        return attacker.attack_from_json(
            inputs=json.loads(data),
            input_field_to_attack=input_field_to_attack,
            grad_input_field=grad_input_field,
            target=json.loads(target))

    @app.route('/')
    def index() -> Response:  # pylint: disable=unused-variable
        loaded_modules = {}
        for n, m in models.items():
            loaded_modules[n] = m.__dict__
        return jsonify({"allennlp_version": VERSION, "models": loaded_modules})

    @app.route('/permadata/<model_name>', methods=['POST', 'OPTIONS'])
    def permadata(model_name: str) -> Response:  # pylint: disable=unused-variable
        """
        If the user requests a permalink, the front end will POST here with the payload
            { slug: slug }
        which we convert to an integer id and use to retrieve saved results from the database.
        """
        # This is just CORS boilerplate.
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # If we don't have a database configured, there are no permalinks.
        if demo_db is None:
            raise ServerError('Permalinks are not enabled', 400)

        # Convert the provided slug to an integer id.
        slug = request.get_json()["slug"]
        perma_id = slug_to_int(slug)
        if perma_id is None:
            # Malformed slug
            raise ServerError("Unrecognized permalink: {}".format(slug), 400)

        # Fetch the results from the database.
        try:
            permadata = demo_db.get_result(perma_id)
        except psycopg2.Error:
            logger.exception(
                "Unable to get results from database: perma_id %s", perma_id)
            raise ServerError('Database trouble', 500)

        if permadata is None:
            # No data found, invalid id?
            raise ServerError("Unrecognized permalink: {}".format(slug), 400)

        return jsonify({
            "modelName": permadata.model_name,
            "requestData": permadata.request_data,
            "responseData": permadata.response_data
        })

    @app.route('/predict/<model_name>', methods=['POST', 'OPTIONS'])
    def predict(model_name: str) -> Response:  # pylint: disable=unused-variable
        """make a prediction using the specified model and return the results"""
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Log to the database unless the request explicitly passes ?record=false
        record_to_database = request.args.get("record", "true").lower() != "false"

        # Use the cache unless the request explicitly passes ?cache=false
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()
        model = app.predictors.get(lowered_model_name)
        if model is None:
            raise ServerError("unknown model: {}".format(model_name),
                              status_code=400)
        max_request_length = app.max_request_lengths[lowered_model_name]

        data = request.get_json()

        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(
                f"Max request length exceeded for model {model_name}! " +
                f"Max: {max_request_length} Actual: {len(serialized_request)}")

        logger.info("request: %s",
                    json.dumps({
                        "model": model_name,
                        "inputs": data
                    }))

        log_blob = {
            "model": model_name,
            "inputs": data,
            "cached": False,
            "outputs": {}
        }

        # Record the number of cache hits before we hit the cache so we can tell whether we hit or not.
        # In theory this could result in false positives.
        pre_hits = _caching_prediction.cache_info().hits  # pylint: disable=no-value-for-parameter

        if record_to_database and demo_db is not None:
            try:
                perma_id = None
                perma_id = demo_db.insert_request(
                    headers=dict(request.headers),
                    requester=request.remote_addr,
                    model_name=model_name,
                    inputs=data)

            except Exception:  # pylint: disable=broad-except
                # TODO(joelgrus): catch more specific errors
                logger.exception("Unable to add request to database",
                                 exc_info=True)

        if use_cache and cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            prediction = _caching_prediction(model, json.dumps(data))
        else:
            # if cache_size is 0, skip caching altogether
            prediction = model.predict_json(data)

        post_hits = _caching_prediction.cache_info().hits  # pylint: disable=no-value-for-parameter

        if record_to_database and demo_db is not None and perma_id is not None:
            try:
                demo_db.update_response(perma_id=perma_id, outputs=prediction)
                slug = int_to_slug(perma_id)
                prediction["slug"] = slug
                log_blob["slug"] = slug

            except Exception:  # pylint: disable=broad-except
                # TODO(joelgrus): catch more specific errors
                logger.exception("Unable to add response to database",
                                 exc_info=True)

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artificial pause
            log_blob["cached"] = True
            time.sleep(0.25)

        # The model predictions are extremely verbose, so we only log the most human-readable
        # parts of them.
        if "comprehension" in model_name:
            if 'best_span_str' in prediction:
                answer = prediction['best_span_str']
            else:
                answer = prediction['answer']
            log_blob["outputs"]["answer"] = answer
        elif model_name == "nmn-drop":
            answer = prediction['answer']
            log_blob["outputs"]["answer"] = answer
        elif model_name == "coreference-resolution":
            log_blob["outputs"]["clusters"] = prediction["clusters"]
            log_blob["outputs"]["document"] = prediction["document"]
        elif model_name == "textual-entailment":
            log_blob["outputs"]["label_probs"] = prediction["label_probs"]
        elif model_name == "sentiment-analysis":
            log_blob["outputs"]["probs"] = prediction["probs"]
        elif model_name == "named-entity-recognition":
            log_blob["outputs"]["tags"] = prediction["tags"]
        elif model_name == "semantic-role-labeling":
            verbs = []
            for verb in prediction["verbs"]:
                # Don't want to log boring verbs with no semantic parses.
                good_tags = [tag for tag in verb["tags"] if tag != "O"]
                if len(good_tags) > 1:
                    verbs.append({
                        "verb": verb["verb"],
                        "description": verb["description"]
                    })
            log_blob["outputs"]["verbs"] = verbs

        elif model_name == "constituency-parsing":
            log_blob["outputs"]["trees"] = prediction["trees"]
        elif model_name == "wikitables-parser":
            log_blob['outputs']['logical_form'] = prediction['logical_form']
            log_blob['outputs']['answer'] = prediction['answer']
        elif model_name == "nlvr-parser":
            log_blob['outputs']['logical_form'] = prediction['logical_form'][0]
            log_blob['outputs']['answer'] = prediction['denotations'][0][0]
        elif model_name == "atis-parser":
            log_blob['outputs']['predicted_sql_query'] = prediction['predicted_sql_query']
        # TODO(brendanr): Add event2mind log_blob here?

        logger.info("prediction: %s", json.dumps(log_blob))

        return jsonify(prediction)

    @app.route('/attack/<model_name>', methods=['POST', 'OPTIONS'])
    def attack(model_name: str) -> Response:
        """
        Modify input to change prediction of model
        """
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Use the cache unless the request explicitly passes ?cache=false
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()

        data = request.get_json()
        attacker_name = data.pop("attacker")
        input_field_to_attack = data.pop("inputToAttack")
        grad_input_field = data.pop("gradInput")
        target = data.pop("target", None)
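        # For illustration, a request body might look like (values hypothetical):
        #   {"attacker": "hotflip", "inputToAttack": "hypothesis", "gradInput": "grad_input_1",
        #    "target": null, "premise": "...", "hypothesis": "..."}
        # Whatever remains in `data` after the pops above is forwarded to the attacker as the model input.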

        model_attackers = app.attackers.get(lowered_model_name)
        if model_attackers is None:
            raise ServerError("unknown model: {}".format(model_name),
                              status_code=400)
        attacker = model_attackers.get(attacker_name)
        if attacker is None:
            raise ServerError("unknown attacker for model: {} {}".format(
                attacker_name, model_name),
                              status_code=400)

        max_request_length = app.max_request_lengths[lowered_model_name]
        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(
                f"Max request length exceeded for model {model_name}! " +
                f"Max: {max_request_length} Actual: {len(serialized_request)}")

        pre_hits = _caching_attack.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and attack_cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            attack = _caching_attack(attacker, json.dumps(data),
                                     input_field_to_attack, grad_input_field,
                                     json.dumps(target))

        else:
            # if cache_size is 0, skip caching altogether
            attack = attacker.attack_from_json(
                inputs=data,
                input_field_to_attack=input_field_to_attack,
                grad_input_field=grad_input_field,
                target=target)

        post_hits = _caching_attack.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artificial pause
            time.sleep(0.25)

        return jsonify(attack)

    @app.route('/interpret/<model_name>', methods=['POST', 'OPTIONS'])
    def interpret(model_name: str) -> Response:
        """
        Interpret prediction of the model
        """
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Use the cache unless the request explicitly passes ?cache=false
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()

        data = request.get_json()
        interpreter_name = data.pop("interpreter")
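        # For illustration, a request body might look like (values hypothetical):
        #   {"interpreter": "simple_gradient", "sentence": "..."}
        # where the interpreter id is one of the keys registered above
        # (simple_gradient, integrated_gradient, or smooth_gradient).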

        model_interpreters = app.interpreters.get(lowered_model_name)
        if model_interpreters is None:
            raise ServerError(
                "no interpreters for model: {}".format(model_name),
                status_code=400)
        interpreter = model_interpreters.get(interpreter_name)
        if interpreter is None:
            raise ServerError("unknown interpreter for model: {} {}".format(
                interpreter_name, model_name),
                              status_code=400)

        max_request_length = app.max_request_lengths[lowered_model_name]

        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(
                f"Max request length exceeded for interpreter {model_name}! " +
                f"Max: {max_request_length} Actual: {len(serialized_request)}")

        pre_hits = _caching_interpret.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and interpret_cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            interpretation = _caching_interpret(interpreter, json.dumps(data))
        else:
            # if cache_size is 0, skip caching altogether
            interpretation = interpreter.saliency_interpret_from_json(data)

        post_hits = _caching_interpret.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artificial pause
            time.sleep(0.25)

        return jsonify(interpretation)

    @app.route('/models')
    def list_models() -> Response:  # pylint: disable=unused-variable
        """list the available models"""
        return jsonify({"models": list(app.predictors.keys())})

    @app.route('/info')
    def info() -> Response:  # pylint: disable=unused-variable
        """List metadata about the running webserver"""
        uptime = str(datetime.now(pytz.utc) - start_time)
        git_version = os.environ.get('ALLENNLP_DEMO_SOURCE_COMMIT') or ""
        return jsonify({
            "start_time": start_time_str,
            "uptime": uptime,
            "git_version": git_version,
            "peak_memory_mb": peak_memory_mb(),
            "githubUrl": "http://github.com/allenai/allennlp-demo/commit/" + git_version
        })

    @app.route('/health')
    def health() -> str:  # pylint: disable=unused-variable
        return "healthy"

    # As an SPA, we need to return index.html for /model-name and /model-name/permalink
    def return_page(permalink: str = None) -> Response:  # pylint: disable=unused-argument, unused-variable
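        # Note: build_dir is assumed to be defined elsewhere in this module (the path to the built front-end assets).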
        """return the page"""
        return send_file(os.path.join(build_dir, 'index.html'))

    for model_name in models:
        logger.info(f"setting up default routes for {model_name}")
        app.add_url_rule(f"/{model_name}", view_func=return_page)
        app.add_url_rule(f"/{model_name}/<permalink>", view_func=return_page)

    return app
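Finally, a minimal launch sketch, assuming a MODELS mapping of DemoModel instances defined elsewhere; the import path and mapping name are assumptions, not part of the example above.

from server.models import MODELS  # hypothetical module exposing Dict[str, DemoModel]

app = make_app(models=MODELS, demo_db=None)  # no database configured, so permalinks are disabled
app.run(host="0.0.0.0", port=8000)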