def __init__(self, args):
        self.logger = get_logger("minimal-rnr-qa")

        titles_file = get_first_matched_file_path(args.resources_path,
                                                  args.dataset, "*.titles.txt")
        self.logger.info(f"Loading titles from {titles_file}...")
        self.all_titles = read_txt(titles_file)

        docs_file = get_first_matched_file_path(args.resources_path,
                                                args.dataset, "*.docs.txt")
        self.logger.info(f"Loading docs from {docs_file}...")
        self.all_docs = read_txt(docs_file)

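        # Optionally load a prebuilt FAISS index for dense retrieval; numpy is
        # kept on the instance so both heavy imports stay deferred until needed.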
        if args.use_faiss_index:
            import faiss
            index_file = get_first_matched_file_path(args.resources_path,
                                                     args.dataset, "*.index")
            self.logger.info(f"Loading index from {index_file}...")
            self.index = faiss.read_index(index_file)

            import numpy as np
            self.np = np
        else:
            self.index = None

        self.tokenizer = None  # must be overridden by subclasses

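        # Fixed inference hyperparameters: input truncation lengths and how many
        # contexts / per-passage answer candidates to consider.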
        self.max_retriever_input_len = 256
        self.max_reader_input_len = 350
        self.max_answer_len = 10
        self.num_contexts = 10
        self.num_passage_answer_candidates = 5
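
The helper get_first_matched_file_path is not shown on this page. Given the glob-based existence check in Example #2 below, a minimal sketch could look like this (the actual implementation may differ):

import glob
import os

def get_first_matched_file_path(base_path, dataset, pattern):
    # Return the first file under base_path/dataset that matches the glob pattern.
    matches = sorted(glob.glob(os.path.join(base_path, dataset, pattern)))
    assert matches, f"No file matching {pattern} in {os.path.join(base_path, dataset)}"
    return matches[0]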
Example #2
import glob
import os


def main(args):
    logger = get_logger("minimal-rnr-qa")
    logger.info(vars(args))

    if args.use_faiss_index:
        assert glob.glob(os.path.join(args.resources_path, args.dataset, "*.index")), \
            f"Index file does not exist in the path: {os.path.join(args.resources_path, args.dataset)}"

    minimal_rnr = TFServingMinimalRnR(args)
    run_app(args, minimal_rnr)
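
For context, here is a hypothetical argument parser covering the attributes these snippets read off args; the flag names are inferred from usage and may not match the original source:

import argparse

def build_arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resources_path", required=True)
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--model_path")           # Example #5
    parser.add_argument("--use_faiss_index", action="store_true")
    parser.add_argument("--demo_port", type=int)  # Example #3
    parser.add_argument("--examples_path")        # Example #3
    parser.add_argument("--input_path")           # Example #4
    parser.add_argument("--output_path")          # Example #4
    parser.add_argument("--top_k", type=int)
    parser.add_argument("--passage_score_weight", type=float)
    return parser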
Example #3
from time import time

from flask import Flask, jsonify, request
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.wsgi import WSGIContainer


def run_app(args, minimal_rnr):
    logger = get_logger("minimal-rnr-qa")

    inference_api = minimal_rnr.get_inference_api()

    app = Flask(__name__, static_url_path='/static')
    app.config["JSONIFY_PRETTYPRINT_REGULAR"] = False

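    # Wrap the inference call so every response also reports elapsed seconds.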
    def _search(query, top_k, passage_score_weight):
        start = time()
        result = inference_api(query, top_k, passage_score_weight)
        return {"ret": result, "time": int((time() - start))}

    @app.route("/")
    def index():
        return app.send_static_file('index.html')

    @app.route("/files/<path:path>")
    def static_files(path):
        return app.send_static_file('files/' + path)

    @app.route("/api", methods=["GET"])
    def api():
        logger.info(request.args)

        query = request.args["query"]
        top_k = int(request.args["top_k"])

        if request.args["passage_score_weight"] == "null":
            passage_score_weight = None
        else:
            passage_score_weight = float(request.args["passage_score_weight"])

        result = _search(query, top_k, passage_score_weight)
        logger.info(result)
        return jsonify(result)

    @app.route("/get_examples", methods=["GET"])
    def get_examples():
        with open(args.examples_path, "r") as fp:
            examples = [line.strip() for line in fp.readlines()]
        return jsonify(examples)

    @app.route("/quit")
    def quit():
        raise KeyboardInterrupt

    logger.info("Warming up...")
    minimal_rnr.predict_answer("warmup", top_k=5, passage_score_weight=0.8)

    logger.info(f"Starting server at {args.demo_port}")
    http_server = HTTPServer(WSGIContainer(app))
    http_server.listen(args.demo_port)
    IOLoop.instance().start()
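
Once the demo server is running, the /api route can be exercised with a plain GET request; substitute whatever port args.demo_port was, and pass the string "null" to disable passage score weighting:

import requests

resp = requests.get(
    "http://localhost:10001/api",  # hypothetical port
    params={"query": "who wrote hamlet", "top_k": 5, "passage_score_weight": "null"},
)
print(resp.json())  # {"ret": <answers>, "time": <elapsed seconds>}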
Example #4
def main(args):
    logger = get_logger("minimal-rnr-qa")
    logger.info(vars(args))

    minimal_rnr = TFServingMinimalRnR(args)
    minimal_rnr.inference_on_file(args.input_path, args.output_path, args.top_k, args.passage_score_weight)
Example #5
def get_model_tokenizer_device(args):
    # Deferred imports: torch and transformers are only required when this function is called.
    import torch
    from torch import Tensor as T
    from torch import nn
    from transformers import MobileBertModel, MobileBertConfig, AutoTokenizer

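    # BERT-style initialization: normal(0, 0.02) weights, zeroed biases,
    # identity LayerNorm scale.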
    def init_weights(modules):
        for module in modules:
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    class HFMobileBertEncoder(MobileBertModel):
        def __init__(self,
                     config,
                     project_dim: int = 0,
                     ctx_bottleneck: bool = False):
            MobileBertModel.__init__(self, config)
            assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero'
            self.encode_proj = nn.Linear(
                config.hidden_size, project_dim) if project_dim != 0 else None
            self.decode_proj = nn.Sequential(
                nn.Tanh(),
                nn.Linear(project_dim, (config.hidden_size + project_dim) // 2),
                nn.Tanh(),
                nn.Linear((config.hidden_size + project_dim) // 2, config.hidden_size),
            ) if ctx_bottleneck else None
            self.init_weights()

        @classmethod
        def init_encoder(cls, cfg_name: str) -> MobileBertModel:
            cfg = MobileBertConfig.from_pretrained(cfg_name)
            return cls.from_pretrained(cfg_name, config=cfg)

        def forward(self, input_ids: T, token_type_ids: T, attention_mask: T):
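            # The tuple unpacking below assumes a transformers version whose
            # models return plain tuples rather than ModelOutput objects.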
            if self.config.output_hidden_states:
                sequence_output, pooled_output, hidden_states = super().forward(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask)
            else:
                hidden_states = None
                sequence_output, pooled_output = super().forward(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask)

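            # Use the [CLS] token representation as the pooled output instead
            # of the pooler head's output.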
            pooled_output = sequence_output[:, 0, :]
            return sequence_output, pooled_output, hidden_states

        def get_out_size(self):
            if self.encode_proj:
                return self.encode_proj.out_features
            return self.config.hidden_size

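    # A single MobileBERT encoder is shared between the retriever (dense
    # question embedding) and the reader (span extraction + passage reranking).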
    class UnifiedRetrieverReader(nn.Module):
        def __init__(self, encoder: nn.Module):
            super(UnifiedRetrieverReader, self).__init__()

            self.emb_size = 128

            self.question_model = encoder
            hidden_size = encoder.config.hidden_size

            self.qa_outputs = nn.Linear(hidden_size, 2)
            self.qa_classifier = nn.Linear(hidden_size, 1)

            init_weights([self.qa_outputs, self.qa_classifier])

        @staticmethod
        def get_representation(sub_model: nn.Module, ids, segments, attn_mask):
            sequence_output = None
            pooled_output = None
            hidden_states = None
            if ids is not None:
                sequence_output, pooled_output, hidden_states = sub_model(
                    ids, segments, attn_mask)

            return sequence_output, pooled_output, hidden_states

        def forward(self,
                    retriever_input_ids=None,
                    retriever_token_type_ids=None,
                    retriever_attention_mask=None,
                    reader_input_ids=None,
                    reader_attention_mask=None,
                    reader_token_type_ids=None):

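            # Retriever mode: return a dense question embedding, truncated to
            # emb_size dimensions.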
            if retriever_input_ids is not None:
                _, encoding, _ = self.get_representation(
                    self.question_model, retriever_input_ids,
                    retriever_token_type_ids, retriever_attention_mask)

                if self.emb_size is not None:
                    return encoding[:, :self.emb_size]
                return encoding

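            # Reader mode: return answer-span logits and a passage relevance score.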
            if reader_input_ids is not None:
                start_logits, end_logits, relevance_logits = self._read(
                    reader_input_ids, reader_token_type_ids,
                    reader_attention_mask)
                return start_logits, end_logits, relevance_logits

        def _read(self, input_ids, token_type_ids, attention_mask):
            sequence_output, _pooled_output, _hidden_states = self.question_model(
                input_ids, token_type_ids, attention_mask)
            logits = self.qa_outputs(sequence_output)
            start_logits, end_logits = logits.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1)
            end_logits = end_logits.squeeze(-1)

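            # Score passage relevance from the [CLS] token representation.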
            qa_classifier_input = sequence_output[:, 0, :]
            relevance_logits = self.qa_classifier(qa_classifier_input)
            return start_logits, end_logits, relevance_logits

    cfg_name = "google/mobilebert-uncased"
    question_encoder = HFMobileBertEncoder.init_encoder(cfg_name)
    model = UnifiedRetrieverReader(question_encoder)
    tokenizer = AutoTokenizer.from_pretrained(cfg_name, do_lower_case=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_file = get_first_matched_file_path(args.model_path, args.dataset,
                                             "*.bin")
    logger = get_logger("minimal-rnr-qa")
    logger.info(f"Loading model from {model_file}...")
    model.load_state_dict(torch.load(model_file, map_location=device))
    model.to(device)
    model.eval()

    return model, tokenizer, device
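
A hypothetical use of the returned objects, embedding a question for retrieval; this assumes a transformers version that returns plain tuples (matching the unpacking above) and reuses the max_retriever_input_len of 256 from the first snippet:

import torch

model, tokenizer, device = get_model_tokenizer_device(args)

inputs = tokenizer("who wrote hamlet", return_tensors="pt",
                   truncation=True, max_length=256)
with torch.no_grad():
    question_emb = model(
        retriever_input_ids=inputs["input_ids"].to(device),
        retriever_token_type_ids=inputs["token_type_ids"].to(device),
        retriever_attention_mask=inputs["attention_mask"].to(device),
    )
print(question_emb.shape)  # torch.Size([1, 128]) since emb_size = 128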