Exemplo n.º 1
0
def run_training(savename: str,
                 train_config: TrainConfig,
                 dataset_oversampling: Dict[str, int],
                 n_processes: int,
                 use_cudnn: bool
                 ):
    """Build, prepare and train a Cape-Flavoured DocumentQA model.

    The model is built from `train_config`, the training data is prepared, and
    training is started with the model stored in a directory named by
    `savename`. Logging (Tensorboard) can be found in the log subdirectory of
    the model directory.

    The datasets to train on are given by the `dataset_oversampling` dictionary.
    E.g. {'squad': 2, 'wiki':1} will train a model on one equivalence of
    triviaqa wiki and two equivalences of squad.

    :param savename: Name of model
    :param train_config: cape_config.TrainConfig object containing hyperparameters etc
    :param dataset_oversampling: dictionary mapping dataset names to integer counts of how much
       to oversample them
    :param n_processes: Number of processes to parallelize preprocessing on
    :param use_cudnn: Whether to train with GRU's optimized for Cudnn (recommended)
    """
    qa_model = build_model(WithIndicators(), train_config, use_cudnn=use_cudnn)
    training_data = prepare_data(qa_model, train_config, dataset_oversampling, n_processes)
    evaluators = get_evaluators(train_config)
    training_params = get_training_params(train_config)

    # The notes handed to the trainer are this script's own source, framed by
    # the TriviaQA mode and the oversampling setup used for the run.
    with open(__file__, "r", encoding='utf8') as source_file:
        script_source = source_file.read()
    notes = "".join([
        "Mode: ",
        train_config.trivia_qa_mode,
        "\n",
        script_source,
        '\nDataset oversampling : ',
        str(dataset_oversampling),
    ])

    # pull the trigger
    output_dir = model_dir.ModelDir(savename)
    trainer.start_training(training_data, qa_model, training_params, evaluators,
                           output_dir, notes)
Exemplo n.º 2
0
def main():
    """Preprocess an XQA corpus and cache the result to disk.

    Command-line arguments:
      * ``corpus``: language code of the XQA corpus to load.
      * ``mode``: model mode; selects which train/test dataset builders are used.
      * ``-t/--n_tokens``: paragraph size passed to ``MergeParagraphs``.

    The preprocessed data is written to ``dev_data_<corpus>.pkl`` in the
    current working directory via ``PreprocessedData.cache_preprocess``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "corpus",
        choices=["en", "fr", "de", "ru", "pt", "zh", "pl", "uk", "ta"])
    parser.add_argument(
        'mode',
        choices=["confidence", "merge", "shared-norm", "sigmoid", "paragraph"])
    # Note I haven't tested modes other than `shared-norm` on this corpus, so
    # some things might need adjusting
    parser.add_argument("-t",
                        "--n_tokens",
                        default=400,
                        type=int,
                        help="Paragraph size")
    args = parser.parse_args()
    mode = args.mode

    # The model is only built here to obtain its preprocessor for extraction.
    model = get_model(100, 140, mode, WithIndicators())

    extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(args.n_tokens),
                                                ShallowOpenWebRanker(16),
                                                model.preprocessor,
                                                intern=True)

    oversample = [
        1
    ] * 2  # Sample the top two answer-containing paragraphs twice

    # No training happens in this script, so only the dataset builders depend
    # on the mode (no epoch count is needed).
    if mode == "paragraph":
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True,
                                                oversample)
        train = StratifyParagraphsBuilder(ClusteredBatcher(
            60, ContextLenBucketedKey(3), True),
                                          oversample,
                                          only_answers=True)
    elif mode in ("confidence", "sigmoid"):
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True,
                                                oversample)
        train = StratifyParagraphsBuilder(
            ClusteredBatcher(60, ContextLenBucketedKey(3), True), oversample)
    else:
        test = RandomParagraphSetDatasetBuilder(
            120, "merge" if mode == "merge" else "group", True, oversample)
        train = StratifyParagraphSetsBuilder(30, mode == "merge", True,
                                             oversample)

    data = PreprocessedData(XQADataset(args.corpus),
                            extract,
                            train,
                            test,
                            eval_on_verified=False)

    data.preprocess(1, 1000)

    # dump preprocessed dev data for bert
    data.cache_preprocess("dev_data_%s.pkl" % args.corpus)
Exemplo n.º 3
0
def main():
    """Train a model on an XQA corpus.

    Command-line arguments:
      * ``corpus``: which XQA corpus variant to train on.
      * ``mode``: model mode; selects the dataset builders and epoch count.
      * ``name``: prefix of the output model directory (a timestamp is appended).
      * ``-t/--n_tokens``: paragraph size passed to ``MergeParagraphs``.
      * ``-n/--n_processes``: number of processes used for preprocessing.

    The source of this script, prefixed with the chosen mode, is handed to the
    trainer as the run's notes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("corpus", choices=["en", "en_trans_de", "en_trans_zh"])
    parser.add_argument(
        'mode',
        choices=["confidence", "merge", "shared-norm", "sigmoid", "paragraph"])
    # Note I haven't tested modes other than `shared-norm` on this corpus, so
    # some things might need adjusting
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument("-t",
                        "--n_tokens",
                        default=400,
                        type=int,
                        help="Paragraph size")
    parser.add_argument(
        '-n',
        '--n_processes',
        type=int,
        default=2,
        help="Number of processes (i.e., select which paragraphs to train on) "
        "the data with")
    args = parser.parse_args()
    mode = args.mode
    corpus = args.corpus

    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")

    model = get_model(100, 140, mode, WithIndicators())

    extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(args.n_tokens),
                                                ShallowOpenWebRanker(16),
                                                model.preprocessor,
                                                intern=True)

    # Named `evaluators` rather than `eval` to avoid shadowing the builtin.
    evaluators = [
        LossEvaluator(),
        MultiParagraphSpanEvaluator(8,
                                    "triviaqa",
                                    mode != "merge",
                                    per_doc=False)
    ]
    oversample = [
        1
    ] * 2  # Sample the top two answer-containing paragraphs twice

    if mode == "paragraph":
        n_epochs = 120
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True,
                                                oversample)
        train = StratifyParagraphsBuilder(ClusteredBatcher(
            60, ContextLenBucketedKey(3), True),
                                          oversample,
                                          only_answers=True)
    elif mode in ("confidence", "sigmoid"):
        n_epochs = 640 if mode == "sigmoid" else 160
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True,
                                                oversample)
        train = StratifyParagraphsBuilder(
            ClusteredBatcher(60, ContextLenBucketedKey(3), True), oversample)
    else:
        n_epochs = 80
        test = RandomParagraphSetDatasetBuilder(
            120, "merge" if mode == "merge" else "group", True, oversample)
        train = StratifyParagraphSetsBuilder(30, mode == "merge", True,
                                             oversample)

    data = XQADataset(corpus)

    params = TrainParams(SerializableOptimizer("Adadelta",
                                               dict(learning_rate=1)),
                         num_epochs=n_epochs,
                         ema=0.999,
                         max_checkpoints_to_keep=2,
                         async_encoding=10,
                         log_period=30,
                         eval_period=1800,
                         save_period=1800,
                         best_weights=("dev", "b8/question-text-f1"),
                         eval_samples=dict(dev=None, train=6000))

    data = PreprocessedData(data, extract, train, test, eval_on_verified=False)

    data.preprocess(args.n_processes, 1000)

    # Read the script source with an explicit encoding so the result does not
    # depend on the platform's locale default.
    with open(__file__, "r", encoding="utf8") as f:
        notes = f.read()
    notes = "Mode: " + args.mode + "\n" + notes
    trainer.start_training(data, model, params, evaluators,
                           model_dir.ModelDir(out), notes)
Exemplo n.º 4
0
def main():
    """Train a model on TriviaQA web.

    Command-line arguments:
      * ``mode``: model mode; selects the paragraph extractor, the dataset
        builders, the evaluators and the epoch count.
      * ``name``: prefix of the output model directory (a timestamp is appended).
      * ``-n/--n_processes``: number of processes used for preprocessing.

    The source of this script, prefixed with a banner naming the mode, is
    handed to the trainer as the run's notes.
    """
    parser = argparse.ArgumentParser(description='Train a model on TriviaQA web')
    parser.add_argument('mode', choices=["paragraph-level", "confidence", "merge",
                                         "shared-norm", "sigmoid", "shared-norm-600"])
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument('-n', '--n_processes', type=int, default=2,
                        help="Number of processes (i.e., select which paragraphs to train on) "
                             "the data with")
    args = parser.parse_args()
    mode = args.mode

    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")

    model = get_model(100, 140, mode, WithIndicators())

    stop = NltkPlusStopWords(True)

    if mode == "paragraph-level":
        extract = ExtractSingleParagraph(MergeParagraphs(400), TopTfIdf(stop, 1), model.preprocessor, intern=True)
        n_epochs = 16
        train = ParagraphAndQuestionsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True))
        test = ParagraphAndQuestionsBuilder(ClusteredBatcher(60, ContextLenKey(), False))
        n_dev, n_train = 21000, 12000
        evaluators = [LossEvaluator(), SpanEvaluator([4, 8], "triviaqa")]
    else:
        # Only "shared-norm-600" merges paragraphs up to 600 tokens; every
        # other multi-paragraph mode uses 400.
        paragraph_size = 600 if mode == "shared-norm-600" else 400
        extract = ExtractMultiParagraphs(MergeParagraphs(paragraph_size), TopTfIdf(stop, 4),
                                         model.preprocessor, intern=True)

        evaluators = [LossEvaluator(), MultiParagraphSpanEvaluator(8, "triviaqa", mode != "merge")]
        # we sample two paragraphs per a (question, doc) pair, so evaluate on fewer questions
        n_dev, n_train = 15000, 8000

        if mode in ("confidence", "sigmoid"):
            # "sigmoid" trains very slowly, do this at your own risk
            n_epochs = 71 if mode == "sigmoid" else 28
            test = RandomParagraphSetDatasetBuilder(120, "flatten", True, 1)
            train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True), 0, 1)
        else:
            n_epochs = 14
            test = RandomParagraphSetDatasetBuilder(120, "merge" if mode == "merge" else "group", True, 1)
            train = StratifyParagraphSetsBuilder(35, mode == "merge", True, 1)

    data = TriviaQaWebDataset()

    params = get_triviaqa_train_params(n_epochs, n_dev, n_train)

    data = PreprocessedData(data, extract, train, test, eval_on_verified=False)

    data.preprocess(args.n_processes, 1000)

    # Read the script source with an explicit encoding so the result does not
    # depend on the platform's locale default.
    with open(__file__, "r", encoding="utf8") as f:
        notes = f.read()
    notes = "*" * 10 + "\nMode: " + args.mode + "\n" + "*"*10 + "\n" + notes

    trainer.start_training(data, model, params, evaluators, model_dir.ModelDir(out), notes)
Exemplo n.º 5
0
def main():
    """Train a model on document-level SQuAD.

    Command-line arguments:
      * ``mode``: model mode; "paragraph" trains in the known-paragraph
        setting, the other modes train on paragraphs selected by
        ``SquadTfIdfRanker``.
      * ``name``: prefix of the output model directory (a timestamp is appended).

    The source of this script, prefixed with the mode, is handed to the
    trainer as the run's notes.

    :raises NotImplementedError: in "paragraph" mode if the model has a
        preprocessor.
    """
    parser = argparse.ArgumentParser(
        description='Train a model on document-level SQuAD')
    parser.add_argument(
        'mode',
        choices=["paragraph", "confidence", "shared-norm", "merge", "sigmoid"])
    parser.add_argument("name", help="Output directory")
    args = parser.parse_args()
    mode = args.mode
    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")

    corpus = SquadCorpus()
    if mode == "merge":
        # Adds paragraph start tokens, since we will be concatenating paragraphs together
        pre = WithIndicators(True, para_tokens=False, doc_start_token=False)
    else:
        pre = None

    model = get_model(50, 100, args.mode, pre)

    if mode == "paragraph":
        # Run in the "standard" known-paragraph setting
        if model.preprocessor is not None:
            raise NotImplementedError()
        n_epochs = 26
        train_batching = ClusteredBatcher(45, ContextLenBucketedKey(3), True,
                                          False)
        eval_batching = ClusteredBatcher(45, ContextLenKey(), False, False)
        data = DocumentQaTrainingData(corpus, None, train_batching,
                                      eval_batching)
        evaluators = [
            LossEvaluator(),
            SpanEvaluator(bound=[17], text_eval="squad")
        ]
    else:
        eval_set_mode = {
            "confidence": "flatten",
            "sigmoid": "flatten",
            "shared-norm": "group",
            "merge": "merge"
        }[mode]
        eval_dataset = RandomParagraphSetDatasetBuilder(
            100, eval_set_mode, True, 0)

        if mode in ("confidence", "sigmoid"):
            if mode == "sigmoid":
                # needs to be trained for a really long time for reasons unknown, even this might be too small
                n_epochs = 100
            else:
                n_epochs = 50  # more epochs since we only "see" the label every other epoch-ish
            train_batching = ClusteredBatcher(45, ContextLenBucketedKey(3),
                                              True, False)
            train_dataset = StratifyParagraphsBuilder(train_batching, 1)
        else:
            n_epochs = 26
            train_dataset = StratifyParagraphSetsBuilder(
                25, args.mode == "merge", True, 1)

        # Both multi-paragraph variants share everything but the training
        # dataset builder.
        data = PreprocessedData(
            SquadCorpus(),
            SquadTfIdfRanker(NltkPlusStopWords(True), 4, True,
                             model.preprocessor),
            train_dataset,
            eval_dataset,
            eval_on_verified=False,
        )

        evaluators = [LossEvaluator(), MultiParagraphSpanEvaluator(17, "squad")]
        data.preprocess(1)

    # Read the script source with an explicit encoding so the result does not
    # depend on the platform's locale default.
    with open(__file__, "r", encoding="utf8") as f:
        notes = f.read()
    notes = args.mode + "\n" + notes

    trainer.start_training(data, model, train_params(n_epochs), evaluators,
                           model_dir.ModelDir(out), notes)
Exemplo n.º 6
0
def main():
    """Build question/paragraph pairs for an XQA split and pickle them.

    The ``-c/--corpus`` argument has the form ``<corpus>_<dev|test>``; the part
    before the last underscore names the XQA corpus and the part after it
    selects the split. Paragraphs are selected per question with the chosen
    filter, paired with their question, sorted largest-first, and written to
    ``<corpus>_<n_paragraphs>.pkl`` in the current working directory.

    :raises AssertionError: if the split suffix is neither "dev" nor "test".
    :raises ValueError: if the paragraph filter name is not recognised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_processes',
        type=int,
        default=1,
        help=
        "Number of processes to do the preprocessing (selecting paragraphs+loading context) with"
    )
    parser.add_argument('-a', '--async', type=int, default=10)
    parser.add_argument('-t',
                        '--tokens',
                        type=int,
                        default=400,
                        help="Max tokens per a paragraph")
    parser.add_argument('-n',
                        '--n_sample',
                        type=int,
                        default=None,
                        help="Number of questions to evaluate on")
    parser.add_argument('-g',
                        '--n_paragraphs',
                        type=int,
                        default=15,
                        help="Number of paragraphs to run the model on")
    parser.add_argument('-f',
                        '--filter',
                        type=str,
                        default=None,
                        choices=["tfidf", "truncate", "linear"],
                        help="How to select paragraphs")
    parser.add_argument(
        '-c',
        '--corpus',
        choices=[
            "en_dev", "en_test", "fr_dev", "fr_test", "de_dev", "de_test",
            "ru_dev", "ru_test", "pt_dev", "pt_test", "zh_dev", "zh_test",
            "pl_dev", "pl_test", "uk_dev", "uk_test", "ta_dev", "ta_test",
            "fr_trans_en_dev", "fr_trans_en_test", "de_trans_en_dev",
            "de_trans_en_test", "ru_trans_en_dev", "ru_trans_en_test",
            "pt_trans_en_dev", "pt_trans_en_test", "zh_trans_en_dev",
            "zh_trans_en_test", "pl_trans_en_dev", "pl_trans_en_test",
            "uk_trans_en_dev", "uk_trans_en_test", "ta_trans_en_dev",
            "ta_trans_en_test"
        ],
        required=True)
    args = parser.parse_args()

    # Every accepted choice contains an underscore, so this always splits into
    # "<corpus name>" and "<dev|test>".
    corpus_name, _, eval_set = args.corpus.rpartition("_")
    dataset = XQADataset(corpus_name)
    if eval_set == "dev":
        test_questions = dataset.get_dev()
    elif eval_set == "test":
        test_questions = dataset.get_test()
    else:
        raise AssertionError()

    corpus = dataset.evidence
    splitter = MergeParagraphs(args.tokens)

    # NOTE(review): none of the accepted --corpus choices starts with "web",
    # so with the current choices this is always False and the per-question
    # code paths below are the ones taken.
    per_document = args.corpus.startswith("web")

    filter_name = args.filter
    if filter_name is None:
        # Pick default depending on the kind of data we are using
        filter_name = "tfidf" if per_document else "linear"

    print("Selecting %d paragraphs using method \"%s\" per %s" %
          (args.n_paragraphs, filter_name,
           ("question-document pair" if per_document else "question")))

    if filter_name == "tfidf":
        para_filter = TopTfIdf(NltkPlusStopWords(punctuation=True),
                               args.n_paragraphs)
    elif filter_name == "truncate":
        para_filter = FirstN(args.n_paragraphs)
    elif filter_name == "linear":
        para_filter = ShallowOpenWebRanker(args.n_paragraphs)
    else:
        raise ValueError()

    n_questions = args.n_sample
    if n_questions is not None:
        # Sort first so the seeded shuffle picks the same sample regardless of
        # the order the questions were loaded in.
        test_questions.sort(key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(test_questions)
        test_questions = test_questions[:n_questions]

    preprocessor = WithIndicators()
    print("Building question/paragraph pairs...")
    # Loads the relevant questions/documents, selects the right paragraphs, and runs the model's preprocessor
    if per_document:
        prep = ExtractMultiParagraphs(splitter,
                                      para_filter,
                                      preprocessor,
                                      require_an_answer=False)
    else:
        prep = ExtractMultiParagraphsPerQuestion(splitter,
                                                 para_filter,
                                                 preprocessor,
                                                 require_an_answer=False)
    prepped_data = preprocess_par(test_questions, corpus, prep,
                                  args.n_processes, 1000)

    # Flatten to one entry per (question, paragraph) pair; `i` is the
    # paragraph's position within its question's paragraph list.
    data = []
    for q in prepped_data.data:
        for i, p in enumerate(q.paragraphs):
            if q.answer_text is None:
                ans = None
            else:
                ans = TokenSpans(q.answer_text, p.answer_spans)
            data.append(
                DocumentParagraphQuestion(q.question_id, p.doc_id,
                                          (p.start, p.end), q.question, p.text,
                                          ans, i))

    # Reverse so our first batch will be the largest (so OOMs happen early)
    questions = sorted(data,
                       key=lambda x: (x.n_context_words, len(x.question)),
                       reverse=True)

    # dump eval data for bert
    import pickle
    # Use a context manager so the output file is flushed and closed even if
    # pickling fails part-way through.
    with open("%s_%d.pkl" % (args.corpus, args.n_paragraphs), "wb") as out_file:
        pickle.dump(questions, out_file)
Exemplo n.º 7
0
def main():
    """Run the Sanic demo server for question answering.

    Parses the command line, configures logging for the ``qa_system``,
    ``downloader`` and ``server`` loggers, and starts a Sanic app on
    ``0.0.0.0:5000`` exposing:

      * ``/answer``: answers the ``question`` query argument using web search.
      * ``/answer-from`` (POST): answers the ``question`` field of a JSON body
        against the ``document`` field of that same body.
      * ``/`` and ``/about.html``: static pages.

    :raises ValueError: if no Bing API key is available while ``--n_web`` > 0.
    """
    parser = argparse.ArgumentParser(description='Run the demo server')
    parser.add_argument(
        'model',
        default=
        "/home/antriv/conversation_ai/ALLENAI_DocumentQA/document-qa/models/triviaqa-unfiltered-shared-norm/best-weights",
        help='Models to use')

    parser.add_argument(
        '-v',
        '--voc',
        default=
        "/home/antriv/conversation_ai/ALLENAI_DocumentQA/document-qa/data/triviaqa/evidence/vocab.txt",
        help='vocab to use, only words from this file will be used')
    parser.add_argument('-t',
                        '--tokens',
                        type=int,
                        default=400,
                        help='Number of tokens to use per paragraph')
    parser.add_argument('--vec_dir',
                        default="/home/antriv/data/glove",
                        help='Location to find word vectors')
    parser.add_argument('--n_paragraphs',
                        type=int,
                        default=15,
                        help="Number of paragraphs to run the model on")
    parser.add_argument('--paragraphs_to_return',
                        type=int,
                        default=10,
                        help="Number of paragraphs return to the frontend")
    parser.add_argument('--span_bound',
                        type=int,
                        default=8,
                        help="Max span size to return as an answer")

    # NOTE(review): the two API keys below are hard-coded credentials and
    # should be rotated and moved out of source control. While these defaults
    # are non-None, the TAGME_API_KEY / BING_API_KEY environment fallbacks
    # further down can never be reached.
    parser.add_argument(
        '--tagme_api_key',
        default="1cdc0067-b2de-4774-afbe-38703b11a365-843339462",
        help="Key to use for TAGME (tagme.d4science.org/tagme)")
    parser.add_argument('--bing_api_key',
                        default="413239df9faa4f1494a914e0c9cec78e",
                        help="Key to use for bing searches")
    parser.add_argument(
        '--bing_version',
        choices=["v5.0", "v7.0"],
        default="v7.0",
        help='Version of Bing API to use (must be compatible with the API key)'
    )
    parser.add_argument(
        '--tagme_thresh',
        default=0.2,
        type=float,
        help="TAGME threshold for when to use the identified docs")
    parser.add_argument('--n_web',
                        type=int,
                        default=10,
                        help='Number of web docs to fetch')
    parser.add_argument('--blacklist_trivia_sites',
                        action="store_true",
                        help="Don't use trivia websites")
    parser.add_argument(
        '-c',
        '--wiki_cache',
        default=
        "/home/antriv/conversation_ai/ALLENAI_DocumentQA/document-qa/data/triviaqa/evidence/wikipedia",
        help="Cache wiki articles in this directory")

    parser.add_argument('--n_dl_threads',
                        type=int,
                        default=5,
                        help="Number of threads to download documents with")
    parser.add_argument('--request_timeout', type=int, default=60)
    parser.add_argument('--download_timeout',
                        type=int,
                        default=25,
                        help="how long to wait before timing out downloads")
    parser.add_argument('--workers',
                        type=int,
                        default=1,
                        help="Number of server workers")
    parser.add_argument('--debug',
                        default=None,
                        choices=["random_model", "dummy_qa"])

    args = parser.parse_args()
    span_bound = args.span_bound
    n_to_return = args.paragraphs_to_return

    if args.tagme_api_key is not None:
        tagme_api_key = args.tagme_api_key
    else:
        tagme_api_key = environ.get("TAGME_API_KEY")

    if args.bing_api_key is not None:
        bing_api_key = args.bing_api_key
    else:
        bing_api_key = environ.get("BING_API_KEY")
        if bing_api_key is None and args.n_web > 0:
            raise ValueError("If n_web > 0 you must give a BING_API_KEY")

    # Any --debug value replaces the stored model with a random predictor.
    if args.debug is None:
        model = ModelDir(args.model)
    else:
        model = RandomPredictor(5, WithIndicators())

    if args.vec_dir is not None:
        loader = LoadFromPath(args.vec_dir)
    else:
        loader = ResourceLoader()

    # Update Sanic's logging to register our class's loggers
    # (`log_config` aliases LOGGING, so the entries below mutate it in place).
    log_config = LOGGING
    formatter = "%(asctime)s: %(levelname)s: %(message)s"
    log_config["formatters"]['my_formatter'] = {
        'format': formatter,
        'datefmt': '%Y-%m-%d %H:%M:%S',
    }
    log_config['handlers']['stream_handler'] = {
        'class': "logging.StreamHandler",
        'formatter': 'my_formatter',
        'stream': sys.stderr
    }
    log_config['handlers']['file_handler'] = {
        'class': "logging.FileHandler",
        'formatter': 'my_formatter',
        'filename': 'logging.log'
    }

    # It looks like we have to go and name every logger our own code might
    # use in order to register it with Sanic
    for logger_name in ('qa_system', 'downloader', 'server'):
        log_config["loggers"][logger_name] = {
            'level': 'INFO',
            'handlers': ['stream_handler', 'file_handler'],
        }

    app = Sanic()
    app.config.REQUEST_TIMEOUT = args.request_timeout
    log = logging.getLogger('server')

    @app.listener('before_server_start')
    async def setup_qa(app, loop):
        # To play nice with iohttp's async ClientSession objects, we need to construct the QaSystem
        # inside the event loop.
        if args.debug == "dummy_qa":
            qa = DummyQa()
        else:
            qa = QaSystem(
                args.wiki_cache,
                MergeParagraphs(args.tokens),
                ShallowOpenWebRanker(args.n_paragraphs),
                args.voc,
                model,
                loader,
                bing_api_key,
                bing_version=args.bing_version,
                tagme_api_key=tagme_api_key,
                n_dl_threads=args.n_dl_threads,
                blacklist_trivia_sites=args.blacklist_trivia_sites,
                download_timeout=args.download_timeout,
                span_bound=span_bound,
                tagme_threshold=None if
                (tagme_api_key is None) else args.tagme_thresh,
                n_web_docs=args.n_web,
            )
        app.qa = qa

    # Distinct name: this listener previously reused `setup_qa`, silently
    # rebinding the local name of the start-up hook defined just above.
    @app.listener('after_server_stop')
    async def close_qa(app, loop):
        app.qa.close()

    @app.route("/answer")
    async def answer(request):
        try:
            question = request.args["question"][0]
            if question == "":
                return response.json({'message': 'No question given'},
                                     status=400)
            spans, paras = await app.qa.answer_question(question)
            answers = select_answers(paras, spans, span_bound, 10)
            answers = answers[:n_to_return]
            best_span = max(answers[0].answers, key=lambda x: x.conf)
            log.info("Answered \"%s\" (with web search): \"%s\"", question,
                     answers[0].original_text[best_span.start:best_span.end])
            return json([x.to_json() for x in answers])
        except Exception as e:
            log.info("Error: " + str(e))
            raise ServerError(e, status_code=500)

    @app.route('/answer-from', methods=['POST'])
    async def answer_from(request):
        try:
            # Named `payload` so it does not shadow the CLI `args` namespace.
            payload = ujson.loads(request.body.decode("utf-8"))
            question = payload.get("question")
            if question is None or question == "":
                return response.json({'message': 'No question given'},
                                     status=400)
            doc = payload["document"]
            if len(doc) > 500000:
                raise ServerError("Document too large", status_code=400)
            spans, paras = app.qa.answer_with_doc(question, doc)
            answers = select_answers(paras, spans, span_bound, 10)
            answers = answers[:n_to_return]
            best_span = max(answers[0].answers, key=lambda x: x.conf)
            log.info("Answered \"%s\" (with user doc): \"%s\"", question,
                     answers[0].original_text[best_span.start:best_span.end])
            return json([x.to_json() for x in answers])
        except ServerError:
            # Let deliberately raised errors (e.g. the 400 for an oversized
            # document) through unchanged; the generic handler below would
            # otherwise re-wrap them as a 500.
            raise
        except Exception as e:
            log.info("Error: " + str(e))
            raise ServerError(e, status_code=500)

    app.static('/', './docqa//server/static/index.html')
    app.static('/about.html', './docqa/server/static/about.html')
    app.run(host="0.0.0.0",
            port=5000,
            workers=args.workers,
            debug=False,
            log_config=LOGGING)
Exemplo n.º 8
0
def main():
    """Train a model on TriviaQA unfiltered.

    Command-line arguments select the mode, the output name, paragraph size,
    preprocessing parallelism, input directory, and several hyper-parameters.
    The model is stored under ``models/<name>``, with a suffix appended for
    each of ``--char_th``, ``--hl_dim``, ``--n_epochs`` and ``--LR`` that was
    overridden. With ``--init_from`` the model definition and the initial
    weights are taken from an existing model directory (best weights if
    available, otherwise the latest checkpoint).

    The source of this script, prefixed with the mode, is handed to the
    trainer as the run's notes.
    """
    parser = argparse.ArgumentParser(
        description='Train a model on TriviaQA unfiltered')
    parser.add_argument(
        'mode',
        choices=["confidence", "merge", "shared-norm", "sigmoid", "paragraph"])
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument("-t",
                        "--n_tokens",
                        default=400,
                        type=int,
                        help="Paragraph size")
    parser.add_argument(
        '-n',
        '--n_processes',
        type=int,
        default=2,
        help="Number of processes (i.e., select which paragraphs to train on) "
        "the data with")
    parser.add_argument("-s",
                        "--source_dir",
                        type=str,
                        default=None,
                        help="where to take input files")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=None,
                        help="Max number of epoches to train on ")
    parser.add_argument("--char_th",
                        type=int,
                        default=None,
                        help="char level embeddings")
    parser.add_argument("--hl_dim",
                        type=int,
                        default=None,
                        help="hidden layer dim size")
    # NOTE(review): this option is parsed but never read below; TrainParams is
    # always given regularization_weight=None.
    parser.add_argument("--regularization",
                        type=int,
                        default=None,
                        help="regularization weight (currently unused)")
    parser.add_argument("--LR",
                        type=float,
                        default=1.0,
                        help="learning rate")
    parser.add_argument("--save_every",
                        type=int,
                        default=1800,
                        help="save period")

    parser.add_argument("--init_from",
                        type=str,
                        default=None,
                        help="model to init from")
    args = parser.parse_args()
    mode = args.mode

    out = join('models', args.name)

    char_th = 100
    hl_dim = 140
    if args.char_th is not None:
        print(args.char_th)
        char_th = args.char_th  # already an int thanks to type=int
        out += '--th' + str(char_th)
    if args.hl_dim is not None:
        print(args.hl_dim)
        hl_dim = args.hl_dim  # already an int thanks to type=int
        out += '--hl' + str(hl_dim)

    # One handle on the source model directory serves both the model
    # definition here and the initial weights further down.
    if args.init_from is None:
        init_dir = None
        model = get_model(char_th, hl_dim, mode, WithIndicators())
    else:
        init_dir = model_dir.ModelDir(args.init_from)
        model = init_dir.get_model()

    extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(args.n_tokens),
                                                ShallowOpenWebRanker(16),
                                                model.preprocessor,
                                                intern=True)

    # Named `evaluators` rather than `eval` to avoid shadowing the builtin.
    evaluators = [
        LossEvaluator(),
        MultiParagraphSpanEvaluator(8,
                                    "triviaqa",
                                    mode != "merge",
                                    per_doc=False)
    ]
    oversample = [1] * 4

    if mode == "paragraph":
        n_epochs = 120
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True,
                                                oversample)
        train = StratifyParagraphsBuilder(ClusteredBatcher(
            60, ContextLenBucketedKey(3), True),
                                          oversample,
                                          only_answers=True)
    elif mode in ("confidence", "sigmoid"):
        n_epochs = 640 if mode == "sigmoid" else 160
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True,
                                                oversample)
        train = StratifyParagraphsBuilder(
            ClusteredBatcher(60, ContextLenBucketedKey(3), True), oversample)
    else:
        n_epochs = 80
        test = RandomParagraphSetDatasetBuilder(
            120, "merge" if mode == "merge" else "group", True, oversample)
        train = StratifyParagraphSetsBuilder(30, mode == "merge", True,
                                             oversample)

    # An explicit --n_epochs overrides the per-mode default chosen above.
    if args.n_epochs is not None:
        n_epochs = args.n_epochs
        out += '--' + str(n_epochs)

    if args.LR != 1.0:
        out += '--' + str(args.LR)

    data = TriviaQaOpenDataset(args.source_dir)

    async_encoding = 10
    params = TrainParams(SerializableOptimizer("Adadelta",
                                               dict(learning_rate=args.LR)),
                         num_epochs=n_epochs,
                         num_of_steps=250000,
                         ema=0.999,
                         max_checkpoints_to_keep=2,
                         async_encoding=async_encoding,
                         log_period=30,
                         eval_period=1800,
                         save_period=args.save_every,
                         eval_samples=dict(dev=None, train=6000),
                         regularization_weight=None)

    data = PreprocessedData(data, extract, train, test, eval_on_verified=False)

    data.preprocess(args.n_processes, 1000)

    # Read the script source with an explicit encoding so the result does not
    # depend on the platform's locale default.
    with open(__file__, "r", encoding="utf8") as f:
        notes = f.read()
    notes = "Mode: " + args.mode + "\n" + notes

    if init_dir is not None:
        # Prefer the best weights; fall back to the most recent checkpoint.
        init_from = init_dir.get_best_weights()
        if init_from is None:
            init_from = init_dir.get_latest_checkpoint()
    else:
        init_from = None

    trainer.start_training(data,
                           model,
                           params,
                           evaluators,
                           model_dir.ModelDir(out),
                           notes,
                           initialize_from=init_from)
Exemplo n.º 9
0
def main():
    """Parse command-line options and run the question-answering demo server.

    The TAGME and Bing API keys are taken from the command line when given,
    otherwise from the ``TAGME_API_KEY`` / ``BING_API_KEY`` environment
    variables. Two routes are registered on a Sanic app:

    * ``GET /answer?question=...`` answers a question through
      ``qa.answer_question``.
    * ``POST /answer-from`` takes a JSON body with ``question`` and
      ``document`` keys and answers through ``qa.answer_with_doc``.

    The app is then served on ``0.0.0.0:8000``.

    :raises ValueError: if no Bing API key is available (neither flag nor
        environment variable) while ``--n_web`` is greater than zero
    """
    parser = argparse.ArgumentParser(description='Run the demo server')
    parser.add_argument('model', help='Models to use')

    parser.add_argument(
        '-v',
        '--voc',
        help='vocab to use, only words from this file will be used')
    parser.add_argument('-t',
                        '--tokens',
                        type=int,
                        default=400,
                        help='Number of tokens to use per paragraph')
    parser.add_argument('--vec_dir', help='Location to find word vectors')
    parser.add_argument('--n_paragraphs',
                        type=int,
                        default=12,
                        help="Number of paragraphs to run the model on")
    parser.add_argument('--span_bound',
                        type=int,
                        default=8,
                        help="Max span size to return as an answer")

    parser.add_argument(
        '--tagme_api_key',
        help="Key to use for TAGME (tagme.d4science.org/tagme)")
    parser.add_argument('--bing_api_key', help="Key to use for bing searches")
    parser.add_argument('--tagme_thresh', default=0.2, type=float)
    parser.add_argument('--no_wiki',
                        action="store_true",
                        help="Dont use TAGME")
    parser.add_argument('--n_web',
                        type=int,
                        default=10,
                        help='Number of web docs to fetch')
    parser.add_argument('--blacklist_trivia_sites',
                        action="store_true",
                        help="Don't use trivia websites")
    parser.add_argument('-c',
                        '--wiki_cache',
                        help="Cache wiki articles in this directory")

    parser.add_argument('--n_dl_threads',
                        type=int,
                        default=5,
                        help="Number of threads to download documents with")
    parser.add_argument('--request_timeout', type=int, default=60)
    parser.add_argument('--download_timeout', type=int, default=25)
    parser.add_argument('--workers',
                        type=int,
                        default=1,
                        help="Number of server workers")
    parser.add_argument('--debug',
                        default=None,
                        choices=["random_model", "dummy_qa"])

    args = parser.parse_args()
    span_bound = args.span_bound

    # Command-line keys take precedence over the environment.
    if args.tagme_api_key is not None:
        tagme_api_key = args.tagme_api_key
    else:
        tagme_api_key = environ.get("TAGME_API_KEY")

    if args.bing_api_key is not None:
        bing_api_key = args.bing_api_key
    else:
        bing_api_key = environ.get("BING_API_KEY")
        if bing_api_key is None and args.n_web > 0:
            raise ValueError("If n_web > 0 you must give a BING_API_KEY")

    # Any --debug value replaces the trained model with a random predictor.
    if args.debug is None:
        model = ModelDir(args.model)
    else:
        model = RandomPredictor(5, WithIndicators())

    if args.vec_dir is not None:
        loader = LoadFromPath(args.vec_dir)
    else:
        loader = ResourceLoader()

    if args.debug == "dummy_qa":
        qa = DummyQa()
    else:
        qa = QaSystem(
            args.wiki_cache,
            MergeParagraphs(args.tokens),
            ShallowOpenWebRanker(args.n_paragraphs),
            args.voc,
            model,
            loader,
            bing_api_key,
            tagme_api_key=tagme_api_key,
            n_dl_threads=args.n_dl_threads,
            blacklist_trivia_sites=args.blacklist_trivia_sites,
            download_timeout=args.download_timeout,
            span_bound=span_bound,
            # A threshold of None is passed when --no_wiki is set.
            tagme_threshold=None if args.no_wiki else args.tagme_thresh,
            n_web_docs=args.n_web)

    # NOTE(review): this assigns an attribute on the `logging` module itself,
    # which presumably has no effect on any logger -- confirm the intent.
    logging.propagate = False
    formatter = logging.Formatter("%(asctime)s: %(levelname)s: %(message)s")
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logging.root.addHandler(handler)
    logging.root.setLevel(logging.DEBUG)

    app = Sanic()
    app.config.REQUEST_TIMEOUT = args.request_timeout

    @app.route("/answer")
    async def answer(request):
        # A missing `question` parameter is a client error, not a server
        # failure, so it is mapped to the same 400 as an empty question.
        try:
            question = request.args["question"][0]
        except (KeyError, IndexError):
            question = ""
        if question == "":
            return response.json({'message': 'No question given'},
                                 status=400)
        try:
            spans, paras = await qa.answer_question(question)
            answers = select_answers(paras, spans, span_bound, 10)
            return json([x.to_json() for x in answers])
        except ServerError:
            # Keep the status code of errors that were raised on purpose.
            raise
        except Exception as e:
            log.exception("Error: " + str(e))
            raise ServerError("Server Error", status_code=500) from e

    @app.route('/answer-from', methods=['POST'])
    async def answer_from(request):
        try:
            payload = ujson.loads(request.body.decode("utf-8"))
            question = payload.get("question")
            if question is None or question == "":
                return response.json({'message': 'No question given'},
                                     status=400)
            doc = payload["document"]
            if len(doc) > 500000:
                raise ServerError("Document too large", status_code=400)
            spans, paras = qa.answer_with_doc(question, doc)
            answers = select_answers(paras, spans, span_bound, 10)
            return json([x.to_json() for x in answers])
        except ServerError:
            # Without this the "Document too large" 400 above would be
            # caught below and reported to the client as a generic 500.
            raise
        except Exception as e:
            log.exception("Error: " + str(e))
            raise ServerError("Server Error", status_code=500) from e

    app.static('/', './docqa//server/static/index.html')
    # NOTE(review): this path uses `service/` while the index page above uses
    # `server/` -- possibly a typo; verify against the repository layout.
    app.static('/about.html', './docqa//service/static/about.html')
    app.run(host="0.0.0.0", port=8000, workers=args.workers, debug=False)
Exemplo n.º 10
0
# Script entry point: rebuild the preprocessing pipeline objects and load
# already-preprocessed training data from the file given by --input_file.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file',
                        required=True,
                        help="input file, e.g. train_data.pkl")
    # --output_train_file and --num_epoch are parsed here but are not used in
    # the visible part of this script.
    parser.add_argument('--output_train_file',
                        required=True,
                        help="output train file, e.g. train_output.json")
    parser.add_argument('--num_epoch',
                        required=True,
                        type=int,
                        help="num_epoch, e.g. 10")
    args = parser.parse_args()

    # The mode is fixed, so the `mode == "merge"` checks below are always False
    # and the test builder always receives "group".
    mode = "shared-norm"
    # NOTE(review): the meaning of 100 and 140 is defined by get_model, which is
    # not visible here -- confirm against its signature.
    model = get_model(100, 140, mode, WithIndicators())
    extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(400),
                                                ShallowOpenWebRanker(16),
                                                model.preprocessor,
                                                intern=True)

    # The same two-element list of ones is passed to both dataset builders.
    oversample = [1] * 2
    train = StratifyParagraphSetsBuilder(30, mode == "merge", True, oversample)
    test = RandomParagraphSetDatasetBuilder(
        120, "merge" if mode == "merge" else "group", True, oversample)

    # No corpus object is supplied (None); the data comes from
    # load_preprocess on the input file instead.
    data = PreprocessedData(None, extract, train, test, eval_on_verified=False)
    data.load_preprocess(args.input_file)

    outputs = []
    training_data = data.get_train()
Exemplo n.º 11
0
def _encode_ranked_paragraph(corpus, pre, q, group, ix):
    """Extract and encode the paragraph behind answer row `ix` of `group`.

    :param corpus: dataset whose `evidence.get_document` returns the document
    :param pre: preprocessor providing `encode_extracted_paragraph`
    :param q: the question the answer row belongs to
    :param group: answer rows for this question
    :param ix: index label of the row within `group`
    :return: the result of `pre.encode_extracted_paragraph` for that paragraph
    """
    s, e = group.para_start[ix], group.para_end[ix]
    doc = [d for d in q.all_docs if d.doc_id == group.doc_id[ix]][0]
    para = extract_paragraph(corpus.evidence.get_document(doc.doc_id), s, e)
    # Keep only answer spans that lie inside [s, e), shifted so they are
    # relative to the start of the paragraph.
    inside = np.logical_and(doc.answer_spans[:, 0] >= s,
                            doc.answer_spans[:, 1] < e)
    answers = doc.answer_spans[inside] - s
    return pre.encode_extracted_paragraph(
        q.question, ExtractedParagraphWithAnswers(para, s, e, answers))


def main():
    """Interactively show cases where confidence picked a worse answer.

    Reads an answer CSV and walks each question's answers in rank order,
    tracking the most confident answer so far. Whenever a later answer is
    more confident but has a lower `text_f1`, the question, both paragraphs
    and their predicted spans are printed and the script waits for Enter.

    Command-line arguments:
      answers          path to the answer CSV file
      question_source  one of "open", "web" or "squad"; "squad" delegates to
                       `show_squad_errors` and returns

    :raises ValueError: if `question_source` is not one of the above
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('answers', help='answer file')
    parser.add_argument('question_source')
    args = parser.parse_args()

    print("Loading answers..")
    answer_df = pd.read_csv(args.answers)

    print("Loading questions..")
    if args.question_source == "open":
        corpus = TriviaQaOpenDataset()
        questions = {q.question_id: q for q in corpus.get_dev()}
    elif args.question_source == "web":
        corpus = TriviaQaWebDataset()
        questions = {}
        for q in corpus.get_dev():
            for d in q.all_docs:
                questions[(q.question_id, d.doc_id)] = q
    elif args.question_source == "squad":
        show_squad_errors(args.answers)
        return
    else:
        raise ValueError(
            "Unknown question source: %r" % (args.question_source,))

    pre = WithIndicators()

    answer_df.sort_values(["question_id", "rank"], inplace=True)

    if args.question_source == "open":
        # Group by a scalar key so group keys are plain question ids; a
        # one-element list yields 1-tuples on newer pandas versions, which
        # would not match the keys of `questions`.
        groups = answer_df.groupby("question_id")
    else:
        groups = answer_df.groupby(["question_id", "doc_id"])

    grouped = list(groups)
    np.random.shuffle(grouped)

    for key, group in grouped:
        q = questions[key]
        # The first row of the group is the starting "best" answer.
        cur_best_ix = group.index[0]
        cur_best_score = group.text_f1.iloc[0]
        cur_best_conf = group.predicted_score.iloc[0]
        for ix in group.index[1:]:
            conf = group.predicted_score[ix]
            if conf > cur_best_conf:
                score = group.text_f1[ix]
                if score < cur_best_score:
                    # We hurt ourselves: higher confidence, lower F1.
                    print("Oh no!")
                    print(" ".join(q.question))
                    print(q.answer.all_answers)
                    print("Best score was %.4f (conf=%.4f), but now is %.4f (conf=%.4f)" % (
                        cur_best_score, cur_best_conf, score, conf
                    ))
                    p1 = _encode_ranked_paragraph(corpus, pre, q, group, cur_best_ix)
                    p2 = _encode_ranked_paragraph(corpus, pre, q, group, ix)

                    p1_s, p1_e = group.predicted_start[cur_best_ix], group.predicted_end[cur_best_ix]
                    p2_s, p2_e = group.predicted_start[ix], group.predicted_end[ix]
                    print(" ".join(display_para(p1.text, p1.answer_spans, q.question, p1_s, p1_e)))
                    print()
                    print(" ".join(display_para(p2.text, p2.answer_spans, q.question, p2_s, p2_e)))
                    # Pause until the user presses Enter.
                    input()
                else:
                    cur_best_score = score
                    cur_best_ix = ix
                    cur_best_conf = conf