    def test_interpret_fails_when_embedding_layer_not_found(self):
        inputs = {"sentence": "It was the ending that I hated"}
        vocab = Vocabulary()
        vocab.add_tokens_to_namespace(
            [w for w in inputs["sentence"].split(" ")])
        model = FakeModelForTestingInterpret(
            vocab, max_tokens=len(inputs["sentence"].split(" ")))
        predictor = TextClassifierPredictor(model,
                                            TextClassificationJsonReader())

        interpreter = IntegratedGradient(predictor)
        with raises(RuntimeError):
            interpreter.saliency_interpret_from_json(inputs)
Example #2
    def __init__(self, model_path=None, cuda_device=1):
        # model_path = model_path or LSTM_MODEL_PATH
        model_path = model_path or ROBERTA_MODEL_PATH
        self.predictor = Predictor.from_path(model_path,
                                             cuda_device=cuda_device)

        _tokenizer = PretrainedTransformerTokenizer(
            model_name="roberta-base", max_length=TRANSFORMER_WORDPIECE_LIMIT)
        class_name_mapper = {"0": "Negative", "1": "Positive"}
        _model = self.predictor._model
        _label_namespace = _model._label_namespace
        index_to_token = _model.vocab.get_index_to_token_vocabulary(
            _label_namespace)
        class_names = [
            class_name_mapper[index_to_token.get(0)],
            class_name_mapper[index_to_token.get(1)],
        ]
        # Wrap the tokenizer: strip byte-level BPE markers and the leading/trailing special tokens
        self.tokenizer = lambda s: [
            t.text.replace("Ġ", "").replace('Ċ', '').replace('ĉ', "")
            for t in _tokenizer.tokenize(s)
        ][1:-1]
        self.explainer_lime = LimeTextExplainer(
            class_names=class_names, split_expression=self.tokenizer)
        self.explainer_integrate = IntegratedGradient(self.predictor)
        self.explainer_simple = SimpleGradient(self.predictor)
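A minimal usage sketch for the wrapper above, assuming the enclosing class is named `SentimentExplainer` and that a separate `predict_probs(texts)` helper (not shown here) returns class probabilities for LIME; both names are hypothetical.

# Sketch only: exercise the three explainers configured in __init__ above.
# `SentimentExplainer` and `predict_probs` are assumed names, not part of
# the original code.
explainer = SentimentExplainer(cuda_device=-1)
sentence = "It was the ending that I hated"

# AllenNLP saliency interpreters take a JSON-style dict and return
# per-token gradient scores keyed by instance.
ig_scores = explainer.explainer_integrate.saliency_interpret_from_json(
    {"sentence": sentence})
simple_scores = explainer.explainer_simple.saliency_interpret_from_json(
    {"sentence": sentence})

# LimeTextExplainer needs a classifier_fn that maps a list of strings to an
# array of class probabilities.
lime_explanation = explainer.explainer_lime.explain_instance(
    sentence, predict_probs, num_features=5)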
Example #3
    def load_interpreters(self) -> Mapping[str, SaliencyInterpreter]:
        """
        Returns a mapping of interpreters keyed by a unique identifier. Requests to
        `/interpret/:id` will invoke the interpreter with the provided `:id`. Override this method
        to add or remove interpreters.
        """
        return {
            "simple": SimpleGradient(self.predictor),
            "smooth": SmoothGradient(self.predictor),
            "integrated": IntegratedGradient(self.predictor),
        }
Example #4
    def test_interpret_works_with_custom_embedding_layer(self):
        inputs = {"sentence": "It was the ending that I hated"}
        vocab = Vocabulary()
        vocab.add_tokens_to_namespace(
            [w for w in inputs["sentence"].split(" ")])
        model = FakeModelForTestingInterpret(
            vocab, max_tokens=len(inputs["sentence"].split(" ")))
        predictor = FakePredictorForTestingInterpret(
            model, TextClassificationJsonReader())
        interpreter = IntegratedGradient(predictor)

        interpretation = interpreter.saliency_interpret_from_json(inputs)

        assert interpretation is not None
        assert "instance_1" in interpretation
        assert "grad_input_1" in interpretation["instance_1"]
        grad_input_1 = interpretation["instance_1"]["grad_input_1"]
        assert len(grad_input_1) == 7  # 7 words in input
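For reference, the interpretation asserted above has roughly the following shape (the numbers are illustrative placeholders, not real model output):

# Illustrative structure only: one saliency score per input token
# ("It was the ending that I hated" -> 7 tokens); values are placeholders.
example_interpretation = {
    "instance_1": {
        "grad_input_1": [0.12, 0.05, 0.08, 0.31, 0.10, 0.04, 0.30],
    },
}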
Example #5
    def test_integrated_gradient(self):
        inputs = {"sentence": "It was the ending that I hated"}
        archive = load_archive(self.FIXTURES_ROOT / 'basic_classifier' /
                               'serialization' / 'model.tar.gz')
        predictor = Predictor.from_archive(archive, 'text_classifier')

        interpreter = IntegratedGradient(predictor)
        interpretation = interpreter.saliency_interpret_from_json(inputs)
        assert interpretation is not None
        assert 'instance_1' in interpretation
        assert 'grad_input_1' in interpretation['instance_1']
        grad_input_1 = interpretation['instance_1']['grad_input_1']
        assert len(grad_input_1) == 7  # 7 words in input

        # two interpretations should be identical for integrated gradients
        repeat_interpretation = interpreter.saliency_interpret_from_json(
            inputs)
        repeat_grad_input_1 = repeat_interpretation['instance_1']['grad_input_1']
        for grad, repeat_grad in zip(grad_input_1, repeat_grad_input_1):
            assert grad == approx(repeat_grad)
Example #6
    def load_interpreters(self) -> Dict[str, SaliencyInterpreter]:
        """
        Returns a mapping of interpreters keyed by a unique identifier. Requests to
        `/interpret/:id` will invoke the interpreter with the provided `:id`. Override this method
        to add or remove interpreters.
        """
        interpreters: Dict[str, SaliencyInterpreter] = {}
        if "simple_gradient" in self.model.interpreters:
            interpreters["simple_gradient"] = SimpleGradient(self.predictor)
        if "smooth_gradient" in self.model.interpreters:
            interpreters["smooth_gradient"] = SmoothGradient(self.predictor)
        if "integrated_gradient" in self.model.interpreters:
            interpreters["integrated_gradient"] = IntegratedGradient(self.predictor)
        return interpreters
Example #7
def make_app(models: Dict[str, DemoModel],
             demo_db: Optional[DemoDatabase] = None,
             cache_size: int = 128,
             interpret_cache_size: int = 500,
             attack_cache_size: int = 500) -> Flask:

    app = Flask(__name__)  # pylint: disable=invalid-name
    start_time = datetime.now(pytz.utc)
    start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S %Z")

    app.predictors = {}
    # Requests longer than these limits will be rejected to prevent out-of-memory errors.
    app.max_request_lengths = {}
    app.attackers = defaultdict(dict)
    app.interpreters = defaultdict(dict)
    # ProxyFix sets the requester IP from the X-Forwarded-For header.
    app.wsgi_app = ProxyFix(app.wsgi_app)

    for name, demo_model in models.items():
        if demo_model is not None:
            logger.info(f"loading {name} model")
            predictor = demo_model.predictor()
            app.predictors[name] = predictor
            app.max_request_lengths[name] = demo_model.max_request_length

            if name in supported_interpret_models:
                app.interpreters[name]['simple_gradient'] = SimpleGradient(predictor)
                app.interpreters[name]['integrated_gradient'] = IntegratedGradient(predictor)
                app.interpreters[name]['smooth_gradient'] = SmoothGradient(predictor)
                app.attackers[name]["input_reduction"] = InputReduction(predictor)
                if name == 'masked-lm':
                    app.attackers[name]["hotflip"] = Hotflip(predictor, 'bert')
                elif name == "next-token-lm":
                    app.attackers[name]["hotflip"] = Hotflip(predictor, 'gpt2')
                elif 'named-entity-recognition' in name:
                    # We haven't implemented hotflip for NER.
                    continue
                elif name == 'textual-entailment':
                    # The SNLI model only has ELMo embeddings, which don't work with hotflip on
                    # their own.
                    continue
                else:
                    app.attackers[name]["hotflip"] = Hotflip(predictor)
                    app.attackers[name]["hotflip"].initialize()

    # Disable caching for HTML documents and API responses so that clients
    # always talk to the source (this server).
    @app.after_request
    def set_cache_headers(resp: Response) -> Response:
        if resp.mimetype == "text/html" or resp.mimetype == "application/json":
            return with_no_cache_headers(resp)
        else:
            return resp

    @app.errorhandler(ServerError)
    def handle_invalid_usage(error: ServerError) -> Response:  # pylint: disable=unused-variable
        response = jsonify(error.to_dict())
        response.status_code = error.status_code
        return response

    @lru_cache(maxsize=cache_size)
    def _caching_prediction(model: Predictor, data: str) -> JsonDict:
        """
        Just a wrapper around ``model.predict_json`` that allows us to use a cache decorator.
        """
        return model.predict_json(json.loads(data))

    @lru_cache(maxsize=interpret_cache_size)
    def _caching_interpret(interpreter: SaliencyInterpreter,
                           data: str) -> JsonDict:
        """
        Just a wrapper around ``model.interpret_from_json`` that allows us to use a cache decorator.
        """
        return interpreter.saliency_interpret_from_json(json.loads(data))

    @lru_cache(maxsize=attack_cache_size)
    def _caching_attack(attacker: Attacker, data: str,
                        input_field_to_attack: str, grad_input_field: str,
                        target: str) -> JsonDict:
        """
        Just a wrapper around ``model.attack_from_json`` that allows us to use a cache decorator.
        """
        return attacker.attack_from_json(
            inputs=json.loads(data),
            input_field_to_attack=input_field_to_attack,
            grad_input_field=grad_input_field,
            target=json.loads(target))

    @app.route('/')
    def index() -> Response:  # pylint: disable=unused-variable
        loaded_modules = {}
        for n, m in models.items():
            loaded_modules[n] = m.__dict__
        return jsonify({"allennlp_version": VERSION, "models": loaded_modules})

    @app.route('/permadata/<model_name>', methods=['POST', 'OPTIONS'])
    def permadata(model_name: str) -> Response:  # pylint: disable=unused-variable
        """
        If the user requests a permalink, the front end will POST here with the payload
            { slug: slug }
        which we convert to an integer id and use to retrieve saved results from the database.
        """
        # This is just CORS boilerplate.
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # If we don't have a database configured, there are no permalinks.
        if demo_db is None:
            raise ServerError('Permalinks are not enabled', 400)

        # Convert the provided slug to an integer id.
        slug = request.get_json()["slug"]
        perma_id = slug_to_int(slug)
        if perma_id is None:
            # Malformed slug
            raise ServerError("Unrecognized permalink: {}".format(slug), 400)

        # Fetch the results from the database.
        try:
            permadata = demo_db.get_result(perma_id)
        except psycopg2.Error:
            logger.exception(
                "Unable to get results from database: perma_id %s", perma_id)
            raise ServerError('Database trouble', 500)

        if permadata is None:
            # No data found, invalid id?
            raise ServerError("Unrecognized permalink: {}".format(slug), 400)

        return jsonify({
            "modelName": permadata.model_name,
            "requestData": permadata.request_data,
            "responseData": permadata.response_data
        })

    @app.route('/predict/<model_name>', methods=['POST', 'OPTIONS'])
    def predict(model_name: str) -> Response:  # pylint: disable=unused-variable
        """make a prediction using the specified model and return the results"""
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Do log if no argument is specified
        record_to_database = request.args.get("record",
                                              "true").lower() != "false"

        # Do use the cache if no argument is specified
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()
        model = app.predictors.get(lowered_model_name)
        if model is None:
            raise ServerError("unknown model: {}".format(model_name),
                              status_code=400)
        max_request_length = app.max_request_lengths[lowered_model_name]

        data = request.get_json()

        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(
                f"Max request length exceeded for model {model_name}! " +
                f"Max: {max_request_length} Actual: {len(serialized_request)}")

        logger.info("request: %s",
                    json.dumps({
                        "model": model_name,
                        "inputs": data
                    }))

        log_blob = {
            "model": model_name,
            "inputs": data,
            "cached": False,
            "outputs": {}
        }

        # Record the number of cache hits before we hit the cache so we can tell whether we hit or not.
        # In theory this could result in false positives.
        pre_hits = _caching_prediction.cache_info().hits  # pylint: disable=no-value-for-parameter

        if record_to_database and demo_db is not None:
            try:
                perma_id = None
                perma_id = demo_db.insert_request(
                    headers=dict(request.headers),
                    requester=request.remote_addr,
                    model_name=model_name,
                    inputs=data)

            except Exception:  # pylint: disable=broad-except
                # TODO(joelgrus): catch more specific errors
                logger.exception("Unable to add request to database",
                                 exc_info=True)

        if use_cache and cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            prediction = _caching_prediction(model, json.dumps(data))
        else:
            # if cache_size is 0, skip caching altogether
            prediction = model.predict_json(data)

        post_hits = _caching_prediction.cache_info().hits  # pylint: disable=no-value-for-parameter

        if record_to_database and demo_db is not None and perma_id is not None:
            try:
                demo_db.update_response(perma_id=perma_id, outputs=prediction)
                slug = int_to_slug(perma_id)
                prediction["slug"] = slug
                log_blob["slug"] = slug

            except Exception:  # pylint: disable=broad-except
                # TODO(joelgrus): catch more specific errors
                logger.exception("Unable to add response to database",
                                 exc_info=True)

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artificial pause
            log_blob["cached"] = True
            time.sleep(0.25)

        # The model predictions are extremely verbose, so we only log the most human-readable
        # parts of them.
        if "comprehension" in model_name:
            if 'best_span_str' in prediction:
                answer = prediction['best_span_str']
            else:
                answer = prediction['answer']
            log_blob["outputs"]["answer"] = answer
        elif model_name == "nmn-drop":
            answer = prediction['answer']
            log_blob["outputs"]["answer"] = answer
        elif model_name == "coreference-resolution":
            log_blob["outputs"]["clusters"] = prediction["clusters"]
            log_blob["outputs"]["document"] = prediction["document"]
        elif model_name == "textual-entailment":
            log_blob["outputs"]["label_probs"] = prediction["label_probs"]
        elif model_name == "sentiment-analysis":
            log_blob["outputs"]["probs"] = prediction["probs"]
        elif model_name == "named-entity-recognition":
            log_blob["outputs"]["tags"] = prediction["tags"]
        elif model_name == "semantic-role-labeling":
            verbs = []
            for verb in prediction["verbs"]:
                # Don't want to log boring verbs with no semantic parses.
                good_tags = [tag for tag in verb["tags"] if tag != "O"]
                if len(good_tags) > 1:
                    verbs.append({
                        "verb": verb["verb"],
                        "description": verb["description"]
                    })
            log_blob["outputs"]["verbs"] = verbs

        elif model_name == "constituency-parsing":
            log_blob["outputs"]["trees"] = prediction["trees"]
        elif model_name == "wikitables-parser":
            log_blob['outputs']['logical_form'] = prediction['logical_form']
            log_blob['outputs']['answer'] = prediction['answer']
        elif model_name == "nlvr-parser":
            log_blob['outputs']['logical_form'] = prediction['logical_form'][0]
            log_blob['outputs']['answer'] = prediction['denotations'][0][0]
        elif model_name == "atis-parser":
            log_blob['outputs']['predicted_sql_query'] = prediction['predicted_sql_query']
        # TODO(brendanr): Add event2mind log_blob here?

        logger.info("prediction: %s", json.dumps(log_blob))

        return jsonify(prediction)

    @app.route('/attack/<model_name>', methods=['POST', 'OPTIONS'])
    def attack(model_name: str) -> Response:
        """
        Modify input to change prediction of model
        """
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Do use the cache if no argument is specified
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()

        data = request.get_json()
        attacker_name = data.pop("attacker")
        input_field_to_attack = data.pop("inputToAttack")
        grad_input_field = data.pop("gradInput")
        target = data.pop("target", None)

        model_attackers = app.attackers.get(lowered_model_name)
        if model_attackers is None:
            raise ServerError("unknown model: {}".format(model_name),
                              status_code=400)
        attacker = model_attackers.get(attacker_name)
        if attacker is None:
            raise ServerError("unknown attacker for model: {} {}".format(
                attacker_name, model_name),
                              status_code=400)

        max_request_length = app.max_request_lengths[lowered_model_name]
        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(
                f"Max request length exceeded for model {model_name}! " +
                f"Max: {max_request_length} Actual: {len(serialized_request)}")

        pre_hits = _caching_attack.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and attack_cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            attack = _caching_attack(attacker, json.dumps(data),
                                     input_field_to_attack, grad_input_field,
                                     json.dumps(target))

        else:
            # if cache_size is 0, skip caching altogether
            attack = attacker.attack_from_json(
                inputs=data,
                input_field_to_attack=input_field_to_attack,
                grad_input_field=grad_input_field,
                target=target)

        post_hits = _caching_attack.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artificial pause
            time.sleep(0.25)

        return jsonify(attack)

    @app.route('/interpret/<model_name>', methods=['POST', 'OPTIONS'])
    def interpret(model_name: str) -> Response:
        """
        Interpret prediction of the model
        """
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Do use the cache if no argument is specified
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()

        data = request.get_json()
        interpreter_name = data.pop("interpreter")

        model_interpreters = app.interpreters.get(lowered_model_name)
        if model_interpreters is None:
            raise ServerError(
                "no interpreters for model: {}".format(model_name),
                status_code=400)
        interpreter = model_interpreters.get(interpreter_name)
        if interpreter is None:
            raise ServerError("unknown interpreter for model: {} {}".format(
                interpreter_name, model_name),
                              status_code=400)

        max_request_length = app.max_request_lengths[lowered_model_name]

        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(
                f"Max request length exceeded for interpreter {model_name}! " +
                f"Max: {max_request_length} Actual: {len(serialized_request)}")

        pre_hits = _caching_interpret.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and interpret_cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            interpretation = _caching_interpret(interpreter, json.dumps(data))
        else:
            # if cache_size is 0, skip caching altogether
            interpretation = interpreter.saliency_interpret_from_json(data)

        post_hits = _caching_interpret.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artificial pause
            time.sleep(0.25)

        return jsonify(interpretation)

    @app.route('/models')
    def list_models() -> Response:  # pylint: disable=unused-variable
        """list the available models"""
        return jsonify({"models": list(app.predictors.keys())})

    @app.route('/info')
    def info() -> Response:  # pylint: disable=unused-variable
        """List metadata about the running webserver"""
        uptime = str(datetime.now(pytz.utc) - start_time)
        git_version = os.environ.get('ALLENNLP_DEMO_SOURCE_COMMIT') or ""
        return jsonify({
            "start_time": start_time_str,
            "uptime": uptime,
            "git_version": git_version,
            "peak_memory_mb": peak_memory_mb(),
            "githubUrl": "http://github.com/allenai/allennlp-demo/commit/" + git_version
        })

    @app.route('/health')
    def health() -> Response:  # pylint: disable=unused-variable
        return "healthy"

    # As an SPA, we need to return index.html for /model-name and /model-name/permalink

    def return_page(permalink: Optional[str] = None) -> Response:  # pylint: disable=unused-argument, unused-variable
        """return the page"""
        return send_file(os.path.join(build_dir, 'index.html'))

    for model_name in models:
        logger.info(f"setting up default routes for {model_name}")
        app.add_url_rule(f"/{model_name}", view_func=return_page)
        app.add_url_rule(f"/{model_name}/<permalink>", view_func=return_page)

    return app
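A hedged sketch of standing up the server built by `make_app`; `demo_model` is an assumed, already-constructed `DemoModel` exposing the `predictor()` and `max_request_length` members used above.

# Sketch only: build and run the demo Flask app locally.
app = make_app(models={"sentiment-analysis": demo_model},
               demo_db=None,     # no database, so permalinks are disabled
               cache_size=128)
app.run(host="0.0.0.0", port=8000)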