def test_interpret_fails_when_embedding_layer_not_found(self):
    """
    A plain ``TextClassifierPredictor`` wrapped around the fake model should make
    ``IntegratedGradient.saliency_interpret_from_json`` raise ``RuntimeError``.
    """
    sentence = "It was the ending that I hated"
    inputs = {"sentence": sentence}
    words = sentence.split(" ")

    vocab = Vocabulary()
    vocab.add_tokens_to_namespace(words)
    model = FakeModelForTestingInterpret(vocab, max_tokens=len(words))
    predictor = TextClassifierPredictor(model, TextClassificationJsonReader())

    # Build the interpreter outside the `raises` block so that only the
    # interpretation call itself is expected to fail.
    interpreter = IntegratedGradient(predictor)
    with raises(RuntimeError):
        interpreter.saliency_interpret_from_json(inputs)
def __init__(self, model_path=None, cuda_device=1):
    """
    Load a predictor and set up LIME, integrated-gradient and simple-gradient explainers.

    Parameters
    ----------
    model_path : optional
        Path passed to ``Predictor.from_path``; falls back to ``ROBERTA_MODEL_PATH``
        when falsy.
    cuda_device : int
        Forwarded to ``Predictor.from_path``.
    """
    resolved_path = model_path or ROBERTA_MODEL_PATH
    self.predictor = Predictor.from_path(resolved_path, cuda_device=cuda_device)

    wordpiece_tokenizer = PretrainedTransformerTokenizer(
        model_name="roberta-base", max_length=TRANSFORMER_WORDPIECE_LIMIT)

    # Map the model's raw label tokens ("0"/"1") to human-readable class names,
    # ordered by label index 0 then 1.
    readable_labels = {"0": "Negative", "1": "Positive"}
    model = self.predictor._model
    index_to_label = model.vocab.get_index_to_token_vocabulary(model._label_namespace)
    class_names = [readable_labels[index_to_label.get(idx)] for idx in (0, 1)]

    # Deletes "Ġ", "Ċ" and "ĉ" from each token text (presumably the byte-level BPE
    # markers for space / newline / tab — verify against the tokenizer).
    marker_table = str.maketrans("", "", "ĠĊĉ")

    def strip_markers(text):
        # Drop the first and last tokens, which are assumed to be the separator
        # tokens the transformer tokenizer adds around the sequence.
        pieces = [tok.text.translate(marker_table) for tok in wordpiece_tokenizer.tokenize(text)]
        return pieces[1:-1]

    self.tokenizer = strip_markers
    self.explainer_lime = LimeTextExplainer(
        class_names=class_names, split_expression=self.tokenizer)
    self.explainer_integrate = IntegratedGradient(self.predictor)
    self.explainer_simple = SimpleGradient(self.predictor)
def load_interpreters(self) -> Mapping[str, SaliencyInterpreter]:
    """
    Build the interpreters this model exposes, keyed by a unique identifier.

    Requests to `/interpret/:id` will invoke the interpreter with the provided
    `:id`. Override this method to add or remove interpreters.
    """
    predictor = self.predictor
    return dict(
        simple=SimpleGradient(predictor),
        smooth=SmoothGradient(predictor),
        integrated=IntegratedGradient(predictor),
    )
def test_interpret_works_with_custom_embedding_layer(self):
    """
    ``IntegratedGradient`` produces one gradient per input word when used with
    ``FakePredictorForTestingInterpret``.
    """
    sentence = "It was the ending that I hated"
    inputs = {"sentence": sentence}
    words = sentence.split(" ")

    vocab = Vocabulary()
    vocab.add_tokens_to_namespace(words)
    model = FakeModelForTestingInterpret(vocab, max_tokens=len(words))
    predictor = FakePredictorForTestingInterpret(model, TextClassificationJsonReader())

    interpretation = IntegratedGradient(predictor).saliency_interpret_from_json(inputs)
    assert interpretation is not None
    assert "instance_1" in interpretation

    first_instance = interpretation["instance_1"]
    assert "grad_input_1" in first_instance
    assert len(first_instance["grad_input_1"]) == 7  # 7 words in input
def test_integrated_gradient(self):
    """
    ``IntegratedGradient`` on the basic classifier fixture yields one gradient per
    input word, and repeating the interpretation gives (approximately) equal values.
    """
    inputs = {"sentence": "It was the ending that I hated"}
    archive = load_archive(self.FIXTURES_ROOT / 'basic_classifier' / 'serialization' / 'model.tar.gz')
    predictor = Predictor.from_archive(archive, 'text_classifier')

    interpreter = IntegratedGradient(predictor)
    interpretation = interpreter.saliency_interpret_from_json(inputs)
    assert interpretation is not None
    assert 'instance_1' in interpretation
    assert 'grad_input_1' in interpretation['instance_1']
    grad_input_1 = interpretation['instance_1']['grad_input_1']
    assert len(grad_input_1) == 7  # 7 words in input

    # two interpretations should be identical for integrated gradients
    repeat_interpretation = interpreter.saliency_interpret_from_json(inputs)
    repeat_grad_input_1 = repeat_interpretation['instance_1']['grad_input_1']
    # zip() silently stops at the shorter sequence, so pin the lengths first;
    # otherwise a truncated repeat would still pass the element-wise check.
    assert len(repeat_grad_input_1) == len(grad_input_1)
    for grad, repeat_grad in zip(grad_input_1, repeat_grad_input_1):
        assert grad == approx(repeat_grad)
def load_interpreters(self) -> Dict[str, SaliencyInterpreter]:
    """
    Build only the interpreters listed in ``self.model.interpreters``.

    Requests to `/interpret/:id` will invoke the interpreter with the provided
    `:id`. Override this method to add or remove interpreters.
    """
    # (identifier, interpreter class) pairs, in the order they are registered.
    candidates = (
        ("simple_gradient", SimpleGradient),
        ("smooth_gradient", SmoothGradient),
        ("integrated_gradient", IntegratedGradient),
    )
    enabled = self.model.interpreters
    interpreters: Dict[str, SaliencyInterpreter] = {
        key: interpreter_cls(self.predictor)
        for key, interpreter_cls in candidates
        if key in enabled
    }
    return interpreters
def make_app(models: Dict[str, DemoModel],
             demo_db: Optional[DemoDatabase] = None,
             cache_size: int = 128,
             interpret_cache_size: int = 500,
             attack_cache_size: int = 500) -> Flask:
    """
    Build the Flask application that serves the demo models.

    Parameters
    ----------
    models : ``Dict[str, DemoModel]``
        Demo models keyed by name. Entries whose value is ``None`` are not loaded,
        but every key still gets ``/<name>`` and ``/<name>/<permalink>`` page routes.
    demo_db : ``Optional[DemoDatabase]``
        Database used to record prediction requests/responses and to resolve
        permalinks. When ``None`` nothing is recorded and ``/permadata`` returns 400.
    cache_size, interpret_cache_size, attack_cache_size : ``int``
        ``lru_cache`` sizes for predictions, interpretations and attacks.
        A size of 0 bypasses the corresponding cache entirely.

    Returns
    -------
    The configured ``Flask`` application.
    """
    app = Flask(__name__)  # pylint: disable=invalid-name
    start_time = datetime.now(pytz.utc)
    start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S %Z")

    app.predictors = {}
    # requests longer than these will be rejected to prevent OOME
    app.max_request_lengths = {}
    app.attackers = defaultdict(dict)
    app.interpreters = defaultdict(dict)
    # sets the requester IP with the X-Forwarded-For header
    app.wsgi_app = ProxyFix(app.wsgi_app)

    for name, demo_model in models.items():
        if demo_model is not None:
            logger.info(f"loading {name} model")
            predictor = demo_model.predictor()
            app.predictors[name] = predictor
            app.max_request_lengths[name] = demo_model.max_request_length

            if name in supported_interpret_models:
                app.interpreters[name]['simple_gradient'] = SimpleGradient(predictor)
                app.interpreters[name]['integrated_gradient'] = IntegratedGradient(predictor)
                app.interpreters[name]['smooth_gradient'] = SmoothGradient(predictor)
                app.attackers[name]["input_reduction"] = InputReduction(predictor)
                if name == 'masked-lm':
                    app.attackers[name]["hotflip"] = Hotflip(predictor, 'bert')
                elif name == "next-token-lm":
                    app.attackers[name]["hotflip"] = Hotflip(predictor, 'gpt2')
                elif 'named-entity-recognition' in name:
                    # We haven't implemented hotflip for NER.
                    continue
                elif name == 'textual-entailment':
                    # The SNLI model only has ELMo embeddings, which don't work with hotflip on
                    # their own.
                    continue
                else:
                    app.attackers[name]["hotflip"] = Hotflip(predictor)

                app.attackers[name]["hotflip"].initialize()

    # Disable caching for HTML documents and API responses so that clients
    # always talk to the source (this server).
    @app.after_request
    def set_cache_headers(resp: Response) -> Response:  # pylint: disable=unused-variable
        if resp.mimetype == "text/html" or resp.mimetype == "application/json":
            return with_no_cache_headers(resp)
        return resp

    @app.errorhandler(ServerError)
    def handle_invalid_usage(error: ServerError) -> Response:  # pylint: disable=unused-variable
        response = jsonify(error.to_dict())
        response.status_code = error.status_code
        return response

    @lru_cache(maxsize=cache_size)
    def _caching_prediction(model: Predictor, data: str) -> JsonDict:
        """
        Just a wrapper around ``model.predict_json`` that allows us to use a cache decorator.
        """
        return model.predict_json(json.loads(data))

    @lru_cache(maxsize=interpret_cache_size)
    def _caching_interpret(interpreter: SaliencyInterpreter, data: str) -> JsonDict:
        """
        Just a wrapper around ``interpreter.saliency_interpret_from_json`` that allows us
        to use a cache decorator.
        """
        return interpreter.saliency_interpret_from_json(json.loads(data))

    @lru_cache(maxsize=attack_cache_size)
    def _caching_attack(attacker: Attacker, data: str, input_field_to_attack: str,
                        grad_input_field: str, target: str) -> JsonDict:
        """
        Just a wrapper around ``attacker.attack_from_json`` that allows us to use a cache
        decorator. ``data`` and ``target`` arrive JSON-serialized so they are hashable.
        """
        return attacker.attack_from_json(inputs=json.loads(data),
                                         input_field_to_attack=input_field_to_attack,
                                         grad_input_field=grad_input_field,
                                         target=json.loads(target))

    @app.route('/')
    def index() -> Response:  # pylint: disable=unused-variable
        loaded_modules = {}
        for n, m in models.items():
            loaded_modules[n] = m.__dict__
        return jsonify({"allennlp_version": VERSION, "models": loaded_modules})

    @app.route('/permadata/<model_name>', methods=['POST', 'OPTIONS'])
    def permadata(model_name: str) -> Response:  # pylint: disable=unused-variable
        """
        If the user requests a permalink, the front end will POST here with the payload
            { slug: slug }
        which we convert to an integer id and use to retrieve saved results from the database.
        """
        # This is just CORS boilerplate.
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # If we don't have a database configured, there are no permalinks.
        if demo_db is None:
            raise ServerError('Permalinks are not enabled', 400)

        # Convert the provided slug to an integer id.
        slug = request.get_json()["slug"]
        perma_id = slug_to_int(slug)
        if perma_id is None:
            # Malformed slug
            raise ServerError("Unrecognized permalink: {}".format(slug), 400)

        # Fetch the results from the database.
        try:
            result = demo_db.get_result(perma_id)
        except psycopg2.Error:
            logger.exception("Unable to get results from database: perma_id %s", perma_id)
            raise ServerError('Database trouble', 500)

        if result is None:
            # No data found, invalid id?
            raise ServerError("Unrecognized permalink: {}".format(slug), 400)

        return jsonify({
            "modelName": result.model_name,
            "requestData": result.request_data,
            "responseData": result.response_data
        })

    @app.route('/predict/<model_name>', methods=['POST', 'OPTIONS'])
    def predict(model_name: str) -> Response:  # pylint: disable=unused-variable
        """make a prediction using the specified model and return the results"""
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Do log if no argument is specified
        record_to_database = request.args.get("record", "true").lower() != "false"

        # Do use the cache if no argument is specified
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()
        model = app.predictors.get(lowered_model_name)
        if model is None:
            raise ServerError("unknown model: {}".format(model_name), status_code=400)
        max_request_length = app.max_request_lengths[lowered_model_name]

        data = request.get_json()

        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(f"Max request length exceeded for model {model_name}! " +
                              f"Max: {max_request_length} Actual: {len(serialized_request)}")

        logger.info("request: %s", json.dumps({"model": model_name, "inputs": data}))

        log_blob = {"model": model_name, "inputs": data, "cached": False, "outputs": {}}

        # Record the number of cache hits before we hit the cache so we can tell whether we hit or not.
        # In theory this could result in false positives.
        pre_hits = _caching_prediction.cache_info().hits  # pylint: disable=no-value-for-parameter

        # Stays None unless the request is successfully recorded in the database.
        perma_id = None
        if record_to_database and demo_db is not None:
            try:
                perma_id = demo_db.insert_request(headers=dict(request.headers),
                                                  requester=request.remote_addr,
                                                  model_name=model_name,
                                                  inputs=data)
            except Exception:  # pylint: disable=broad-except
                # TODO(joelgrus): catch more specific errors
                logger.exception("Unable to add request to database", exc_info=True)

        if use_cache and cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            prediction = _caching_prediction(model, json.dumps(data))
        else:
            # if cache_size is 0, skip caching altogether
            prediction = model.predict_json(data)

        post_hits = _caching_prediction.cache_info().hits  # pylint: disable=no-value-for-parameter

        if record_to_database and demo_db is not None and perma_id is not None:
            try:
                demo_db.update_response(perma_id=perma_id, outputs=prediction)
                slug = int_to_slug(perma_id)
                prediction["slug"] = slug
                log_blob["slug"] = slug
            except Exception:  # pylint: disable=broad-except
                # TODO(joelgrus): catch more specific errors
                logger.exception("Unable to add response to database", exc_info=True)

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artifical pause
            log_blob["cached"] = True
            time.sleep(0.25)

        # The model predictions are extremely verbose, so we only log the most human-readable
        # parts of them.
        if "comprehension" in model_name:
            if 'best_span_str' in prediction:
                answer = prediction['best_span_str']
            else:
                answer = prediction['answer']
            log_blob["outputs"]["answer"] = answer
        elif model_name == "nmn-drop":
            answer = prediction['answer']
            log_blob["outputs"]["answer"] = answer
        elif model_name == "coreference-resolution":
            log_blob["outputs"]["clusters"] = prediction["clusters"]
            log_blob["outputs"]["document"] = prediction["document"]
        elif model_name == "textual-entailment":
            log_blob["outputs"]["label_probs"] = prediction["label_probs"]
        elif model_name == "sentiment-analysis":
            log_blob["outputs"]["probs"] = prediction["probs"]
        elif model_name == "named-entity-recognition":
            log_blob["outputs"]["tags"] = prediction["tags"]
        elif model_name == "semantic-role-labeling":
            verbs = []
            for verb in prediction["verbs"]:
                # Don't want to log boring verbs with no semantic parses.
                # "O" (the letter) is the BIO "outside" tag; the previous comparison
                # against "0" (zero) never matched, so nothing was filtered.
                good_tags = [tag for tag in verb["tags"] if tag != "O"]
                if len(good_tags) > 1:
                    verbs.append({"verb": verb["verb"], "description": verb["description"]})
            log_blob["outputs"]["verbs"] = verbs
        elif model_name == "constituency-parsing":
            log_blob["outputs"]["trees"] = prediction["trees"]
        elif model_name == "wikitables-parser":
            log_blob['outputs']['logical_form'] = prediction['logical_form']
            log_blob['outputs']['answer'] = prediction['answer']
        elif model_name == "nlvr-parser":
            log_blob['outputs']['logical_form'] = prediction['logical_form'][0]
            log_blob['outputs']['answer'] = prediction['denotations'][0][0]
        elif model_name == "atis-parser":
            log_blob['outputs']['predicted_sql_query'] = prediction['predicted_sql_query']
        # TODO(brendanr): Add event2mind log_blob here?

        logger.info("prediction: %s", json.dumps(log_blob))

        return jsonify(prediction)

    @app.route('/attack/<model_name>', methods=['POST', 'OPTIONS'])
    def attack(model_name: str) -> Response:  # pylint: disable=unused-variable
        """
        Modify input to change prediction of model
        """
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Do use the cache if no argument is specified
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()

        data = request.get_json()
        attacker_name = data.pop("attacker")
        input_field_to_attack = data.pop("inputToAttack")
        grad_input_field = data.pop("gradInput")
        target = data.pop("target", None)

        model_attackers = app.attackers.get(lowered_model_name)
        if model_attackers is None:
            raise ServerError("unknown model: {}".format(model_name), status_code=400)
        attacker = model_attackers.get(attacker_name)
        if attacker is None:
            raise ServerError("unknown attacker for model: {} {}".format(attacker_name, model_name),
                              status_code=400)

        max_request_length = app.max_request_lengths[lowered_model_name]
        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(f"Max request length exceeded for model {model_name}! " +
                              f"Max: {max_request_length} Actual: {len(serialized_request)}")

        pre_hits = _caching_attack.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and attack_cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            attack_result = _caching_attack(attacker,
                                            json.dumps(data),
                                            input_field_to_attack,
                                            grad_input_field,
                                            json.dumps(target))
        else:
            # if cache_size is 0, skip caching altogether
            attack_result = attacker.attack_from_json(inputs=data,
                                                      input_field_to_attack=input_field_to_attack,
                                                      grad_input_field=grad_input_field,
                                                      target=target)

        post_hits = _caching_attack.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artifical pause
            time.sleep(0.25)

        return jsonify(attack_result)

    @app.route('/interpret/<model_name>', methods=['POST', 'OPTIONS'])
    def interpret(model_name: str) -> Response:  # pylint: disable=unused-variable
        """
        Interpret prediction of the model
        """
        if request.method == "OPTIONS":
            return Response(response="", status=200)

        # Do use the cache if no argument is specified
        use_cache = request.args.get("cache", "true").lower() != "false"

        lowered_model_name = model_name.lower()

        data = request.get_json()
        interpreter_name = data.pop("interpreter")

        model_interpreters = app.interpreters.get(lowered_model_name)
        if model_interpreters is None:
            raise ServerError("no interpreters for model: {}".format(model_name), status_code=400)
        interpreter = model_interpreters.get(interpreter_name)
        if interpreter is None:
            raise ServerError("unknown interpreter for model: {} {}".format(interpreter_name, model_name),
                              status_code=400)

        max_request_length = app.max_request_lengths[lowered_model_name]
        serialized_request = json.dumps(data)
        if len(serialized_request) > max_request_length:
            raise ServerError(f"Max request length exceeded for interpreter {model_name}! " +
                              f"Max: {max_request_length} Actual: {len(serialized_request)}")

        pre_hits = _caching_interpret.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and interpret_cache_size > 0:
            # lru_cache insists that all function arguments be hashable,
            # so unfortunately we have to stringify the data.
            interpretation = _caching_interpret(interpreter, json.dumps(data))
        else:
            # if cache_size is 0, skip caching altogether
            interpretation = interpreter.saliency_interpret_from_json(data)

        # Must read the same cache as `pre_hits`; reading the prediction cache here
        # made cache-hit detection (and the artificial pause) meaningless.
        post_hits = _caching_interpret.cache_info().hits  # pylint: disable=no-value-for-parameter

        if use_cache and post_hits > pre_hits:
            # Cache hit, so insert an artifical pause
            time.sleep(0.25)

        return jsonify(interpretation)

    @app.route('/models')
    def list_models() -> Response:  # pylint: disable=unused-variable
        """list the available models"""
        return jsonify({"models": list(app.predictors.keys())})

    @app.route('/info')
    def info() -> Response:  # pylint: disable=unused-variable
        """List metadata about the running webserver"""
        uptime = str(datetime.now(pytz.utc) - start_time)
        git_version = os.environ.get('ALLENNLP_DEMO_SOURCE_COMMIT') or ""
        return jsonify({
            "start_time": start_time_str,
            "uptime": uptime,
            "git_version": git_version,
            "peak_memory_mb": peak_memory_mb(),
            "githubUrl": "http://github.com/allenai/allennlp-demo/commit/" + git_version
        })

    @app.route('/health')
    def health() -> Response:  # pylint: disable=unused-variable
        return "healthy"

    # As an SPA, we need to return index.html for /model-name and /model-name/permalink
    def return_page(permalink: str = None) -> Response:  # pylint: disable=unused-argument, unused-variable
        """return the page"""
        return send_file(os.path.join(build_dir, 'index.html'))

    for model_name in models:
        logger.info(f"setting up default routes for {model_name}")
        app.add_url_rule(f"/{model_name}", view_func=return_page)
        app.add_url_rule(f"/{model_name}/<permalink>", view_func=return_page)

    return app