def test_interpret_fails_when_embedding_layer_not_found(self):
        """
        Interpreting with a predictor that does not expose the model's embedding
        layer must raise a ``RuntimeError``.
        """
        inputs = {"sentence": "It was the ending that I hated"}
        tokens = inputs["sentence"].split(" ")
        vocab = Vocabulary()
        vocab.add_tokens_to_namespace(tokens)
        model = FakeModelForTestingInterpret(vocab, max_tokens=len(tokens))
        # The stock predictor is used here (not the fake one with a custom embedding
        # lookup), so the interpreter is expected to fail.
        predictor = TextClassifierPredictor(model, TextClassificationJsonReader())

        interpreter = SmoothGradient(predictor)
        with raises(RuntimeError):
            # The ``with`` block previously had no body, so nothing was exercised.
            interpreter.saliency_interpret_from_json(inputs)
    def test_smooth_gradient(self):
        """
        Run SmoothGradient on the serialized basic classifier fixture and check that
        the interpretation carries one gradient entry per input word.
        """
        archive_path = (self.FIXTURES_ROOT / "basic_classifier" /
                        "serialization" / "model.tar.gz")
        predictor = Predictor.from_archive(load_archive(archive_path), "text_classifier")
        payload = {"sentence": "It was the ending that I hated"}

        result = SmoothGradient(predictor).saliency_interpret_from_json(payload)

        assert result is not None
        assert "instance_1" in result
        first_instance = result["instance_1"]
        assert "grad_input_1" in first_instance
        assert len(first_instance["grad_input_1"]) == 7  # 7 words in input
# Example #3
    def test_interpret_works_with_custom_embedding_layer(self):
        """
        With the fake predictor for interpret tests, SmoothGradient should return
        one gradient entry per word of the input sentence.
        """
        inputs = {"sentence": "It was the ending that I hated"}
        words = inputs["sentence"].split(" ")

        vocab = Vocabulary()
        vocab.add_tokens_to_namespace(words)
        model = FakeModelForTestingInterpret(vocab, max_tokens=len(words))
        predictor = FakePredictorForTestingInterpret(model, TextClassificationJsonReader())

        result = SmoothGradient(predictor).saliency_interpret_from_json(inputs)

        assert result is not None
        assert "instance_1" in result
        first_instance = result["instance_1"]
        assert "grad_input_1" in first_instance
        assert len(first_instance["grad_input_1"]) == 7  # 7 words in input
 def load_interpreters(self) -> Mapping[str, SaliencyInterpreter]:
     """
     Returns a mapping of interpreters keyed by a unique identifier. Requests to
     `/interpret/:id` will invoke the interpreter with the provided `:id`. Override this method
     to add or remove interpreters.
     """
     # Every interpreter wraps the same predictor instance.
     return {
         "simple": SimpleGradient(self.predictor),
         "smooth": SmoothGradient(self.predictor),
         "integrated": IntegratedGradient(self.predictor),
     }
 def load_interpreters(self) -> Dict[str, SaliencyInterpreter]:
     """
     Returns a mapping of interpreters keyed by a unique identifier. Requests to
     `/interpret/:id` will invoke the interpreter with the provided `:id`. Override this method
     to add or remove interpreters.

     Only the interpreters named in ``self.model.interpreters`` are created.
     """
     interpreters: Dict[str, SaliencyInterpreter] = {}
     if "simple_gradient" in self.model.interpreters:
         interpreters["simple_gradient"] = SimpleGradient(self.predictor)
     if "smooth_gradient" in self.model.interpreters:
         interpreters["smooth_gradient"] = SmoothGradient(self.predictor)
     if "integrated_gradient" in self.model.interpreters:
         # The constructor call was left unclosed; it takes the predictor like its siblings.
         interpreters["integrated_gradient"] = IntegratedGradient(self.predictor)
     return interpreters
def make_app(models: Dict[str, DemoModel],
             demo_db: Optional[DemoDatabase] = None,
             cache_size: int = 128,
             interpret_cache_size: int = 500,
             attack_cache_size: int = 500) -> Flask:
    """
    Build the Flask application that serves the demo API and the single-page front end.

    ``models`` maps a model name to its ``DemoModel``; entries whose value is ``None`` are
    not loaded. ``demo_db`` enables permalinks and request recording when it is not ``None``.
    ``cache_size``, ``interpret_cache_size`` and ``attack_cache_size`` bound the LRU caches
    for predictions, interpretations and attacks; a size of 0 bypasses the matching cache.

    Returns the configured ``Flask`` app.
    """
    # Imported locally so that this function does not rely on module-level imports
    # that are not visible from here.
    import time
    from datetime import datetime, timezone
    from functools import lru_cache

    # NOTE(review): the pause statement was missing from this copy of the file;
    # 0.25 seconds is assumed - confirm against the deployed service.
    cache_hit_pause_seconds = 0.25

    app = Flask(__name__)  # pylint: disable=invalid-name
    # Timezone-aware so that the %Z directive below renders a zone name.
    start_time = datetime.now(timezone.utc)
    start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S %Z")

    app.predictors = {}
    # Requests longer than these will be rejected to prevent OOME.
    app.max_request_lengths = {}
    app.attackers = defaultdict(dict)
    app.interpreters = defaultdict(dict)
    # Sets the requester IP with the X-Forwarded-For header.
    app.wsgi_app = ProxyFix(app.wsgi_app)

    for name, demo_model in models.items():
        if demo_model is None:
            continue

        logger.info("loading %s model", name)
        predictor = demo_model.predictor()
        app.predictors[name] = predictor
        app.max_request_lengths[name] = demo_model.max_request_length

        if name not in supported_interpret_models:
            continue

        app.interpreters[name]['simple_gradient'] = SimpleGradient(predictor)
        app.interpreters[name]['integrated_gradient'] = IntegratedGradient(predictor)
        app.interpreters[name]['smooth_gradient'] = SmoothGradient(predictor)
        app.attackers[name]["input_reduction"] = InputReduction(predictor)
        if name == 'masked-lm':
            app.attackers[name]["hotflip"] = Hotflip(predictor, 'bert')
        elif name == "next-token-lm":
            app.attackers[name]["hotflip"] = Hotflip(predictor, 'gpt2')
        elif 'named-entity-recognition' in name:
            # We haven't implemented hotflip for NER.
            pass
        elif name == 'textual-entailment':
            # The SNLI model only has ELMo embeddings, which don't work with hotflip on
            # their own.
            pass
        else:
            app.attackers[name]["hotflip"] = Hotflip(predictor)

    # Disable caching for HTML documents and API responses so that clients
    # always talk to the source (this server).
    @app.after_request
    def set_cache_headers(resp: Response) -> Response:  # pylint: disable=unused-variable
        if resp.mimetype == "text/html" or resp.mimetype == "application/json":
            return with_no_cache_headers(resp)
        else:
            return resp

    @app.errorhandler(ServerError)
    def handle_invalid_usage(error: ServerError) -> Response:  # pylint: disable=unused-variable
        """Turn a ``ServerError`` into a JSON response carrying the error's status code."""
        response = jsonify(error.to_dict())
        response.status_code = error.status_code
        return response

    @lru_cache(maxsize=cache_size)
    def _caching_prediction(model: Predictor, data: str) -> JsonDict:
        """
        Just a wrapper around ``model.predict_json`` that allows us to use a cache decorator.
        """
        return model.predict_json(json.loads(data))

    @lru_cache(maxsize=interpret_cache_size)
    def _caching_interpret(interpreter: SaliencyInterpreter,
                           data: str) -> JsonDict:
        """
        Just a wrapper around ``model.interpret_from_json`` that allows us to use a cache decorator.
        """
        return interpreter.saliency_interpret_from_json(json.loads(data))

    @lru_cache(maxsize=attack_cache_size)
    def _caching_attack(attacker: Attacker, data: str,
                        input_field_to_attack: str, grad_input_field: str,
                        target: str) -> JsonDict:
        """
        Just a wrapper around ``model.attack_from_json`` that allows us to use a cache decorator.

        ``data`` and ``target`` arrive JSON-encoded because cache keys must be hashable.
        """
        return attacker.attack_from_json(
            inputs=json.loads(data),
            input_field_to_attack=input_field_to_attack,
            grad_input_field=grad_input_field,
            target=json.loads(target))

    # NOTE(review): the route decorator was missing from this copy; '/' is assumed.
    @app.route('/')
    def index() -> Response:  # pylint: disable=unused-variable
        """Report the AllenNLP version and the attributes of every configured model."""
        loaded_modules = {n: m.__dict__ for n, m in models.items()}
        return jsonify({"allennlp_version": VERSION, "models": loaded_modules})

    @app.route('/permadata/<model_name>', methods=['POST', 'OPTIONS'])
    def permadata(model_name: str) -> Response:  # pylint: disable=unused-variable
        """
        If the user requests a permalink, the front end will POST here with the payload
            { slug: slug }
        which we convert to an integer id and use to retrieve saved results from the database.
        """
        # This is just CORS boilerplate.
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # If we don't have a database configured, there are no permalinks.
        if demo_db is None:
            raise ServerError('Permalinks are not enabled', 400)

        # Convert the provided slug to an integer id.
        slug = request.get_json()["slug"]
        perma_id = slug_to_int(slug)
        if perma_id is None:
            # Malformed slug
            raise ServerError("Unrecognized permalink: {}".format(slug), 400)

        # Fetch the results from the database.
        try:
            permadata = demo_db.get_result(perma_id)
        except psycopg2.Error:
            logger.exception(
                "Unable to get results from database: perma_id %s", perma_id)
            raise ServerError('Database trouble', 500)

        if permadata is None:
            # No data found, invalid id?
            raise ServerError("Unrecognized permalink: {}".format(slug), 400)

        return jsonify({
            "modelName": permadata.model_name,
            "requestData": permadata.request_data,
            "responseData": permadata.response_data
        })

    @app.route('/predict/<model_name>', methods=['POST', 'OPTIONS'])
    def predict(model_name: str) -> Response:  # pylint: disable=unused-variable
        """make a prediction using the specified model and return the results"""
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Do log if no argument is specified
        record_to_database = request.args.get("record",
                                              "true").lower() != "false"

        # Do use the cache if no argument is specified
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()
        model = app.predictors.get(lowered_model_name)
        if model is None:
            raise ServerError("unknown model: {}".format(model_name),
                              status_code=400)
        max_request_length = app.max_request_lengths[lowered_model_name]

        data = request.get_json()

        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(
                f"Max request length exceeded for model {model_name}! " +
                f"Max: {max_request_length} Actual: {len(serialized_request)}")

        logger.info("request: %s",
                    json.dumps({
                        "model": model_name,
                        "inputs": data
                    }))

        log_blob = {
            "model": model_name,
            "inputs": data,
            "cached": False,
            "outputs": {}
        }

        # Record the number of cache hits before we hit the cache so we can tell whether we hit or not.
        # In theory this could result in false positives.
        pre_hits = _caching_prediction.cache_info().hits  # pylint: disable=no-value-for-parameter

        # Stays None when recording is off or the insert fails, so the response
        # update below is skipped.
        perma_id = None
        if record_to_database and demo_db is not None:
            try:
                # NOTE(review): the argument list was missing from this copy of the file;
                # these keywords are assumed - confirm against DemoDatabase.insert_request.
                perma_id = demo_db.insert_request(
                    headers=dict(request.headers),
                    requester=request.remote_addr,
                    model_name=model_name,
                    inputs=data)

            except Exception:  # pylint: disable=broad-except
                # TODO(joelgrus): catch more specific errors
                logger.exception("Unable to add request to database",
                                 exc_info=True)

        if use_cache and cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            prediction = _caching_prediction(model, json.dumps(data))
        else:
            # if cache_size is 0, skip caching altogether
            prediction = model.predict_json(data)

        post_hits = _caching_prediction.cache_info().hits  # pylint: disable=no-value-for-parameter

        if record_to_database and demo_db is not None and perma_id is not None:
            try:
                demo_db.update_response(perma_id=perma_id, outputs=prediction)
                slug = int_to_slug(perma_id)
                prediction["slug"] = slug
                log_blob["slug"] = slug

            except Exception:  # pylint: disable=broad-except
                # TODO(joelgrus): catch more specific errors
                logger.exception("Unable to add response to database",
                                 exc_info=True)

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artificial pause
            log_blob["cached"] = True
            time.sleep(cache_hit_pause_seconds)

        # The model predictions are extremely verbose, so we only log the most human-readable
        # parts of them.
        if "comprehension" in model_name:
            if 'best_span_str' in prediction:
                answer = prediction['best_span_str']
            else:
                answer = prediction['answer']
            log_blob["outputs"]["answer"] = answer
        elif model_name == "nmn-drop":
            answer = prediction['answer']
            log_blob["outputs"]["answer"] = answer
        elif model_name == "coreference-resolution":
            log_blob["outputs"]["clusters"] = prediction["clusters"]
            log_blob["outputs"]["document"] = prediction["document"]
        elif model_name == "textual-entailment":
            log_blob["outputs"]["label_probs"] = prediction["label_probs"]
        elif model_name == "sentiment-analysis":
            log_blob["outputs"]["probs"] = prediction["probs"]
        elif model_name == "named-entity-recognition":
            log_blob["outputs"]["tags"] = prediction["tags"]
        elif model_name == "semantic-role-labeling":
            verbs = []
            for verb in prediction["verbs"]:
                # Don't want to log boring verbs with no semantic parses.
                good_tags = [tag for tag in verb["tags"] if tag != "0"]
                if len(good_tags) > 1:
                    verbs.append({
                        "verb": verb["verb"],
                        "description": verb["description"]
                    })
            log_blob["outputs"]["verbs"] = verbs

        elif model_name == "constituency-parsing":
            log_blob["outputs"]["trees"] = prediction["trees"]
        elif model_name == "wikitables-parser":
            log_blob['outputs']['logical_form'] = prediction['logical_form']
            log_blob['outputs']['answer'] = prediction['answer']
        elif model_name == "nlvr-parser":
            log_blob['outputs']['logical_form'] = prediction['logical_form'][0]
            log_blob['outputs']['answer'] = prediction['denotations'][0][0]
        elif model_name == "atis-parser":
            log_blob['outputs']['predicted_sql_query'] = prediction[
                'predicted_sql_query']
        # TODO(brendanr): Add event2mind log_blob here?

        logger.info("prediction: %s", json.dumps(log_blob))

        return jsonify(prediction)

    @app.route('/attack/<model_name>', methods=['POST', 'OPTIONS'])
    def attack(model_name: str) -> Response:  # pylint: disable=unused-variable
        """
        Modify input to change prediction of model
        """
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Do use the cache if no argument is specified
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()

        data = request.get_json()
        attacker_name = data.pop("attacker")
        input_field_to_attack = data.pop("inputToAttack")
        grad_input_field = data.pop("gradInput")
        target = data.pop("target", None)

        model_attackers = app.attackers.get(lowered_model_name)
        if model_attackers is None:
            raise ServerError("unknown model: {}".format(model_name),
                              status_code=400)
        attacker = model_attackers.get(attacker_name)
        if attacker is None:
            raise ServerError("unknown attacker for model: {} {}".format(
                attacker_name, model_name),
                              status_code=400)

        max_request_length = app.max_request_lengths[lowered_model_name]
        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(
                f"Max request length exceeded for model {model_name}! " +
                f"Max: {max_request_length} Actual: {len(serialized_request)}")

        pre_hits = _caching_attack.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and attack_cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            attack_result = _caching_attack(attacker, json.dumps(data),
                                            input_field_to_attack, grad_input_field,
                                            json.dumps(target))
        else:
            # if cache_size is 0, skip caching altogether
            attack_result = attacker.attack_from_json(
                inputs=data,
                input_field_to_attack=input_field_to_attack,
                grad_input_field=grad_input_field,
                target=target)

        post_hits = _caching_attack.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artificial pause
            time.sleep(cache_hit_pause_seconds)

        return jsonify(attack_result)

    @app.route('/interpret/<model_name>', methods=['POST', 'OPTIONS'])
    def interpret(model_name: str) -> Response:  # pylint: disable=unused-variable
        """
        Interpret prediction of the model
        """
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Do use the cache if no argument is specified
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()

        data = request.get_json()
        interpreter_name = data.pop("interpreter")

        model_interpreters = app.interpreters.get(lowered_model_name)
        if model_interpreters is None:
            raise ServerError(
                "no interpreters for model: {}".format(model_name),
                status_code=400)
        interpreter = model_interpreters.get(interpreter_name)
        if interpreter is None:
            raise ServerError("unknown interpreter for model: {} {}".format(
                interpreter_name, model_name),
                              status_code=400)

        max_request_length = app.max_request_lengths[lowered_model_name]

        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(
                f"Max request length exceeded for interpreter {model_name}! " +
                f"Max: {max_request_length} Actual: {len(serialized_request)}")

        pre_hits = _caching_interpret.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and interpret_cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            interpretation = _caching_interpret(interpreter, json.dumps(data))
        else:
            # if cache_size is 0, skip caching altogether
            interpretation = interpreter.saliency_interpret_from_json(data)

        # Read the hit counter of the interpret cache; this previously read the
        # prediction cache, so the comparison below mixed two unrelated counters.
        post_hits = _caching_interpret.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artificial pause
            time.sleep(cache_hit_pause_seconds)

        return jsonify(interpretation)

    # NOTE(review): the route decorator was missing from this copy; '/models' is assumed.
    @app.route('/models')
    def list_models() -> Response:  # pylint: disable=unused-variable
        """list the available models"""
        return jsonify({"models": list(app.predictors.keys())})

    # NOTE(review): the route decorator was missing from this copy; '/info' is assumed.
    @app.route('/info')
    def info() -> Response:  # pylint: disable=unused-variable
        """List metadata about the running webserver"""
        uptime = str(datetime.now(timezone.utc) - start_time)
        git_version = os.environ.get('ALLENNLP_DEMO_SOURCE_COMMIT') or ""
        # NOTE(review): the payload keys and the URL prefix were missing from this copy
        # of the file; they are reconstructed here - confirm against the front end.
        return jsonify({
            "start_time": start_time_str,
            "uptime": uptime,
            "git_version": git_version,
            "githubUrl": "http://github.com/allenai/allennlp-demo/commit/" + git_version
        })

    # NOTE(review): the route decorator was missing from this copy; '/health' is assumed.
    @app.route('/health')
    def health() -> Response:  # pylint: disable=unused-variable
        """Liveness probe; always answers with the literal body ``healthy``."""
        return "healthy"

    # As an SPA, we need to return index.html for /model-name and /model-name/permalink
    def return_page(permalink: str = None) -> Response:  # pylint: disable=unused-argument, unused-variable
        """return the page"""
        return send_file(os.path.join(build_dir, 'index.html'))

    for model_name in models:
        logger.info("setting up default routes for %s", model_name)
        app.add_url_rule(f"/{model_name}", view_func=return_page)
        app.add_url_rule(f"/{model_name}/<permalink>", view_func=return_page)

    return app