def run_training(savename: str, train_config: TrainConfig,
                 dataset_oversampling: Dict[str, int], n_processes: int,
                 use_cudnn: bool):
    """Train a Cape-flavoured DocumentQA model.

    After preparing the datasets for training, a model will be created and saved
    in a directory specified by `savename`. Logging (Tensorboard) can be found in
    the log subdirectory of the model directory.

    The datasets to train the model on are specified in the `dataset_oversampling`
    dictionary. E.g. {'squad': 2, 'wiki': 1} will train a model on one part
    TriviaQA wiki to two parts SQuAD.

    :param savename: Name of model
    :param train_config: cape_config.TrainConfig object containing hyperparameters etc.
    :param dataset_oversampling: dictionary mapping dataset names to integer counts
        of how much to oversample them
    :param n_processes: Number of processes to parallelize preprocessing over
    :param use_cudnn: Whether to train with GRUs optimized for cuDNN (recommended)
    """
    model = build_model(WithIndicators(), train_config, use_cudnn=use_cudnn)
    data = prepare_data(model, train_config, dataset_oversampling, n_processes)
    eval = get_evaluators(train_config)
    params = get_training_params(train_config)
    with open(__file__, "r", encoding='utf8') as f:
        notes = f.read()
    notes = "Mode: " + train_config.trivia_qa_mode + "\n" + notes
    notes += '\nDataset oversampling : ' + str(dataset_oversampling)
    # pull the trigger
    trainer.start_training(data, model, params, eval, model_dir.ModelDir(savename), notes)
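# Illustrative only -- a minimal sketch of how run_training might be invoked,
# assuming a default-constructible TrainConfig and the 'squad'/'wiki' dataset keys
# mentioned in the docstring above. All values are placeholders, not recommendations.
if __name__ == "__main__":
    run_training(
        savename="cape-docqa-model",
        train_config=TrainConfig(),
        dataset_oversampling={"squad": 2, "wiki": 1},  # two parts SQuAD to one part TriviaQA wiki
        n_processes=8,
        use_cudnn=True,
    )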
def main(): parser = argparse.ArgumentParser() parser.add_argument( "corpus", choices=["en", "fr", "de", "ru", "pt", "zh", "pl", "uk", "ta"]) parser.add_argument( 'mode', choices=["confidence", "merge", "shared-norm", "sigmoid", "paragraph"]) # Note I haven't tested modes other than `shared-norm` on this corpus, so # some things might need adjusting parser.add_argument("-t", "--n_tokens", default=400, type=int, help="Paragraph size") args = parser.parse_args() mode = args.mode corpus = args.corpus model = get_model(100, 140, mode, WithIndicators()) extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(args.n_tokens), ShallowOpenWebRanker(16), model.preprocessor, intern=True) oversample = [ 1 ] * 2 # Sample the top two answer-containing paragraphs twice if mode == "paragraph": n_epochs = 120 test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample) train = StratifyParagraphsBuilder(ClusteredBatcher( 60, ContextLenBucketedKey(3), True), oversample, only_answers=True) elif mode == "confidence" or mode == "sigmoid": if mode == "sigmoid": n_epochs = 640 else: n_epochs = 160 test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample) train = StratifyParagraphsBuilder( ClusteredBatcher(60, ContextLenBucketedKey(3), True), oversample) else: n_epochs = 80 test = RandomParagraphSetDatasetBuilder( 120, "merge" if mode == "merge" else "group", True, oversample) train = StratifyParagraphSetsBuilder(30, mode == "merge", True, oversample) data = XQADataset(corpus) data = PreprocessedData(data, extract, train, test, eval_on_verified=False) data.preprocess(1, 1000) # dump preprocessed dev data for bert data.cache_preprocess("dev_data_%s.pkl" % args.corpus)
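# Sketch: the pickle written by cache_preprocess() above can later be reloaded
# without re-running preprocessing, in the same way the BERT preparation script
# further down does with load_preprocess(). The extract/train/test builders passed
# in are assumed to match the ones used when the cache was written; the filename
# is a placeholder (e.g. the file produced for corpus "en").
def load_cached_dev_data(pkl_file, extract, train, test):
    """Reload dev data cached by cache_preprocess(), e.g. "dev_data_en.pkl"."""
    data = PreprocessedData(None, extract, train, test, eval_on_verified=False)
    data.load_preprocess(pkl_file)
    return data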
def main(): parser = argparse.ArgumentParser() parser.add_argument("corpus", choices=["en", "en_trans_de", "en_trans_zh"]) parser.add_argument( 'mode', choices=["confidence", "merge", "shared-norm", "sigmoid", "paragraph"]) # Note I haven't tested modes other than `shared-norm` on this corpus, so # some things might need adjusting parser.add_argument("name", help="Where to store the model") parser.add_argument("-t", "--n_tokens", default=400, type=int, help="Paragraph size") parser.add_argument( '-n', '--n_processes', type=int, default=2, help="Number of processes (i.e., select which paragraphs to train on) " "the data with") args = parser.parse_args() mode = args.mode corpus = args.corpus out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S") model = get_model(100, 140, mode, WithIndicators()) extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(args.n_tokens), ShallowOpenWebRanker(16), model.preprocessor, intern=True) eval = [ LossEvaluator(), MultiParagraphSpanEvaluator(8, "triviaqa", mode != "merge", per_doc=False) ] oversample = [ 1 ] * 2 # Sample the top two answer-containing paragraphs twice if mode == "paragraph": n_epochs = 120 test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample) train = StratifyParagraphsBuilder(ClusteredBatcher( 60, ContextLenBucketedKey(3), True), oversample, only_answers=True) elif mode == "confidence" or mode == "sigmoid": if mode == "sigmoid": n_epochs = 640 else: n_epochs = 160 test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample) train = StratifyParagraphsBuilder( ClusteredBatcher(60, ContextLenBucketedKey(3), True), oversample) else: n_epochs = 80 test = RandomParagraphSetDatasetBuilder( 120, "merge" if mode == "merge" else "group", True, oversample) train = StratifyParagraphSetsBuilder(30, mode == "merge", True, oversample) data = XQADataset(corpus) params = TrainParams(SerializableOptimizer("Adadelta", dict(learning_rate=1)), num_epochs=n_epochs, ema=0.999, max_checkpoints_to_keep=2, async_encoding=10, log_period=30, eval_period=1800, save_period=1800, best_weights=("dev", "b8/question-text-f1"), eval_samples=dict(dev=None, train=6000)) data = PreprocessedData(data, extract, train, test, eval_on_verified=False) data.preprocess(args.n_processes, 1000) with open(__file__, "r") as f: notes = f.read() notes = "Mode: " + args.mode + "\n" + notes trainer.start_training(data, model, params, eval, model_dir.ModelDir(out), notes)
def main():
    parser = argparse.ArgumentParser(description='Train a model on TriviaQA web')
    parser.add_argument('mode', choices=["paragraph-level", "confidence", "merge",
                                         "shared-norm", "sigmoid", "shared-norm-600"])
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument('-n', '--n_processes', type=int, default=2,
                        help="Number of processes to preprocess the data with "
                             "(i.e., to select which paragraphs to train on)")
    args = parser.parse_args()
    mode = args.mode

    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")

    model = get_model(100, 140, mode, WithIndicators())

    stop = NltkPlusStopWords(True)

    if mode == "paragraph-level":
        extract = ExtractSingleParagraph(MergeParagraphs(400), TopTfIdf(stop, 1),
                                         model.preprocessor, intern=True)
    elif mode == "shared-norm-600":
        extract = ExtractMultiParagraphs(MergeParagraphs(600), TopTfIdf(stop, 4),
                                         model.preprocessor, intern=True)
    else:
        extract = ExtractMultiParagraphs(MergeParagraphs(400), TopTfIdf(stop, 4),
                                         model.preprocessor, intern=True)

    if mode == "paragraph-level":
        n_epochs = 16
        train = ParagraphAndQuestionsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True))
        test = ParagraphAndQuestionsBuilder(ClusteredBatcher(60, ContextLenKey(), False))
        n_dev, n_train = 21000, 12000
        eval = [LossEvaluator(), SpanEvaluator([4, 8], "triviaqa")]
    else:
        eval = [LossEvaluator(), MultiParagraphSpanEvaluator(8, "triviaqa", mode != "merge")]
        # we sample two paragraphs per (question, doc) pair, so evaluate on fewer questions
        n_dev, n_train = 15000, 8000

        if mode == "confidence" or mode == "sigmoid":
            if mode == "sigmoid":
                # Trains very slowly, do this at your own risk
                n_epochs = 71
            else:
                n_epochs = 28
            test = RandomParagraphSetDatasetBuilder(120, "flatten", True, 1)
            train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True), 0, 1)
        else:
            n_epochs = 14
            test = RandomParagraphSetDatasetBuilder(120, "merge" if mode == "merge" else "group", True, 1)
            train = StratifyParagraphSetsBuilder(35, mode == "merge", True, 1)

    data = TriviaQaWebDataset()

    params = get_triviaqa_train_params(n_epochs, n_dev, n_train)

    data = PreprocessedData(data, extract, train, test, eval_on_verified=False)

    data.preprocess(args.n_processes, 1000)

    with open(__file__, "r") as f:
        notes = f.read()
    notes = "*" * 10 + "\nMode: " + args.mode + "\n" + "*" * 10 + "\n" + notes

    trainer.start_training(data, model, params, eval, model_dir.ModelDir(out), notes)
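# Example invocation (illustrative; the script filename is an assumption, flag
# values are placeholders):
#   python ablate_triviaqa.py shared-norm triviaqa-web-model -n 8
# Per the `out` construction above, the model is written to a timestamped
# directory such as triviaqa-web-model-MMDD-HHMMSS.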
def main():
    parser = argparse.ArgumentParser(description='Train a model on document-level SQuAD')
    parser.add_argument('mode', choices=["paragraph", "confidence", "shared-norm",
                                         "merge", "sigmoid"])
    parser.add_argument("name", help="Output directory")
    args = parser.parse_args()
    mode = args.mode
    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")

    corpus = SquadCorpus()
    if mode == "merge":
        # Adds paragraph start tokens, since we will be concatenating paragraphs together
        pre = WithIndicators(True, para_tokens=False, doc_start_token=False)
    else:
        pre = None

    model = get_model(50, 100, args.mode, pre)

    if mode == "paragraph":
        # Run in the "standard" known-paragraph setting
        if model.preprocessor is not None:
            raise NotImplementedError()
        n_epochs = 26
        train_batching = ClusteredBatcher(45, ContextLenBucketedKey(3), True, False)
        eval_batching = ClusteredBatcher(45, ContextLenKey(), False, False)
        data = DocumentQaTrainingData(corpus, None, train_batching, eval_batching)
        eval = [LossEvaluator(), SpanEvaluator(bound=[17], text_eval="squad")]
    else:
        eval_set_mode = {
            "confidence": "flatten",
            "sigmoid": "flatten",
            "shared-norm": "group",
            "merge": "merge"
        }[mode]
        eval_dataset = RandomParagraphSetDatasetBuilder(100, eval_set_mode, True, 0)

        if mode == "confidence" or mode == "sigmoid":
            if mode == "sigmoid":
                # needs to be trained for a really long time for reasons unknown, even this might be too small
                n_epochs = 100
            else:
                n_epochs = 50  # more epochs since we only "see" the label every other epoch or so
            train_batching = ClusteredBatcher(45, ContextLenBucketedKey(3), True, False)
            data = PreprocessedData(
                SquadCorpus(),
                SquadTfIdfRanker(NltkPlusStopWords(True), 4, True, model.preprocessor),
                StratifyParagraphsBuilder(train_batching, 1),
                eval_dataset,
                eval_on_verified=False,
            )
        else:
            n_epochs = 26
            data = PreprocessedData(
                SquadCorpus(),
                SquadTfIdfRanker(NltkPlusStopWords(True), 4, True, model.preprocessor),
                StratifyParagraphSetsBuilder(25, args.mode == "merge", True, 1),
                eval_dataset,
                eval_on_verified=False,
            )

        eval = [LossEvaluator(), MultiParagraphSpanEvaluator(17, "squad")]

    data.preprocess(1)

    with open(__file__, "r") as f:
        notes = f.read()
    notes = args.mode + "\n" + notes

    trainer.start_training(data, model, train_params(n_epochs), eval,
                           model_dir.ModelDir(out), notes)
def main(): parser = argparse.ArgumentParser() parser.add_argument( '--n_processes', type=int, default=1, help= "Number of processes to do the preprocessing (selecting paragraphs+loading context) with" ) parser.add_argument('-a', '--async', type=int, default=10) parser.add_argument('-t', '--tokens', type=int, default=400, help="Max tokens per a paragraph") parser.add_argument('-n', '--n_sample', type=int, default=None, help="Number of questions to evaluate on") parser.add_argument('-g', '--n_paragraphs', type=int, default=15, help="Number of paragraphs to run the model on") parser.add_argument('-f', '--filter', type=str, default=None, choices=["tfidf", "truncate", "linear"], help="How to select paragraphs") parser.add_argument( '-c', '--corpus', choices=[ "en_dev", "en_test", "fr_dev", "fr_test", "de_dev", "de_test", "ru_dev", "ru_test", "pt_dev", "pt_test", "zh_dev", "zh_test", "pl_dev", "pl_test", "uk_dev", "uk_test", "ta_dev", "ta_test", "fr_trans_en_dev", "fr_trans_en_test", "de_trans_en_dev", "de_trans_en_test", "ru_trans_en_dev", "ru_trans_en_test", "pt_trans_en_dev", "pt_trans_en_test", "zh_trans_en_dev", "zh_trans_en_test", "pl_trans_en_dev", "pl_trans_en_test", "uk_trans_en_dev", "uk_trans_en_test", "ta_trans_en_dev", "ta_trans_en_test" ], required=True) args = parser.parse_args() corpus_name = args.corpus[:args.corpus.rfind("_")] eval_set = args.corpus[args.corpus.rfind("_") + 1:] dataset = XQADataset(corpus_name) if eval_set == "dev": test_questions = dataset.get_dev() elif eval_set == "test": test_questions = dataset.get_test() else: raise AssertionError() corpus = dataset.evidence splitter = MergeParagraphs(args.tokens) per_document = args.corpus.startswith( "web") # wiki and web are both multi-document filter_name = args.filter if filter_name is None: # Pick default depending on the kind of data we are using if per_document: filter_name = "tfidf" else: filter_name = "linear" print("Selecting %d paragraphs using method \"%s\" per %s" % (args.n_paragraphs, filter_name, ("question-document pair" if per_document else "question"))) if filter_name == "tfidf": para_filter = TopTfIdf(NltkPlusStopWords(punctuation=True), args.n_paragraphs) elif filter_name == "truncate": para_filter = FirstN(args.n_paragraphs) elif filter_name == "linear": para_filter = ShallowOpenWebRanker(args.n_paragraphs) else: raise ValueError() n_questions = args.n_sample if n_questions is not None: test_questions.sort(key=lambda x: x.question_id) np.random.RandomState(0).shuffle(test_questions) test_questions = test_questions[:n_questions] preprocessor = WithIndicators() print("Building question/paragraph pairs...") # Loads the relevant questions/documents, selects the right paragraphs, and runs the model's preprocessor if per_document: prep = ExtractMultiParagraphs(splitter, para_filter, preprocessor, require_an_answer=False) else: prep = ExtractMultiParagraphsPerQuestion(splitter, para_filter, preprocessor, require_an_answer=False) prepped_data = preprocess_par(test_questions, corpus, prep, args.n_processes, 1000) data = [] for q in prepped_data.data: for i, p in enumerate(q.paragraphs): if q.answer_text is None: ans = None else: ans = TokenSpans(q.answer_text, p.answer_spans) data.append( DocumentParagraphQuestion(q.question_id, p.doc_id, (p.start, p.end), q.question, p.text, ans, i)) # Reverse so our first batch will be the largest (so OOMs happen early) questions = sorted(data, key=lambda x: (x.n_context_words, len(x.question)), reverse=True) # dump eval data for bert import pickle pickle.dump(questions, 
open("%s_%d.pkl" % (args.corpus, args.n_paragraphs), "wb"))
def main(): parser = argparse.ArgumentParser(description='Run the demo server') parser.add_argument( 'model', default= "/home/antriv/conversation_ai/ALLENAI_DocumentQA/document-qa/models/triviaqa-unfiltered-shared-norm/best-weights", help='Models to use') parser.add_argument( '-v', '--voc', default= "/home/antriv/conversation_ai/ALLENAI_DocumentQA/document-qa/data/triviaqa/evidence/vocab.txt", help='vocab to use, only words from this file will be used') parser.add_argument('-t', '--tokens', type=int, default=400, help='Number of tokens to use per paragraph') parser.add_argument('--vec_dir', default="/home/antriv/data/glove", help='Location to find word vectors') parser.add_argument('--n_paragraphs', type=int, default=15, help="Number of paragraphs to run the model on") parser.add_argument('--paragraphs_to_return', type=int, default=10, help="Number of paragraphs return to the frontend") parser.add_argument('--span_bound', type=int, default=8, help="Max span size to return as an answer") parser.add_argument( '--tagme_api_key', default="1cdc0067-b2de-4774-afbe-38703b11a365-843339462", help="Key to use for TAGME (tagme.d4science.org/tagme)") parser.add_argument('--bing_api_key', default="413239df9faa4f1494a914e0c9cec78e", help="Key to use for bing searches") parser.add_argument( '--bing_version', choices=["v5.0", "v7.0"], default="v7.0", help='Version of Bing API to use (must be compatible with the API key)' ) parser.add_argument( '--tagme_thresh', default=0.2, type=float, help="TAGME threshold for when to use the identified docs") parser.add_argument('--n_web', type=int, default=10, help='Number of web docs to fetch') parser.add_argument('--blacklist_trivia_sites', action="store_true", help="Don't use trivia websites") parser.add_argument( '-c', '--wiki_cache', default= "/home/antriv/conversation_ai/ALLENAI_DocumentQA/document-qa/data/triviaqa/evidence/wikipedia", help="Cache wiki articles in this directory") parser.add_argument('--n_dl_threads', type=int, default=5, help="Number of threads to download documents with") parser.add_argument('--request_timeout', type=int, default=60) parser.add_argument('--download_timeout', type=int, default=25, help="how long to wait before timing out downloads") parser.add_argument('--workers', type=int, default=1, help="Number of server workers") parser.add_argument('--debug', default=None, choices=["random_model", "dummy_qa"]) args = parser.parse_args() span_bound = args.span_bound n_to_return = args.paragraphs_to_return if args.tagme_api_key is not None: tagme_api_key = args.tagme_api_key else: tagme_api_key = environ.get("TAGME_API_KEY") if args.bing_api_key is not None: bing_api_key = args.bing_api_key else: bing_api_key = environ.get("BING_API_KEY") if bing_api_key is None and args.n_web > 0: raise ValueError("If n_web > 0 you must give a BING_API_KEY") if args.debug is None: model = ModelDir(args.model) else: model = RandomPredictor(5, WithIndicators()) if args.vec_dir is not None: loader = LoadFromPath(args.vec_dir) else: loader = ResourceLoader() # Update Sanic's logging to register our class's loggers log_config = LOGGING formatter = "%(asctime)s: %(levelname)s: %(message)s" log_config["formatters"]['my_formatter'] = { 'format': formatter, 'datefmt': '%Y-%m-%d %H:%M:%S', } log_config['handlers']['stream_handler'] = { 'class': "logging.StreamHandler", 'formatter': 'my_formatter', 'stream': sys.stderr } log_config['handlers']['file_handler'] = { 'class': "logging.FileHandler", 'formatter': 'my_formatter', 'filename': 'logging.log' } # It looks like 
we have to go and name every logger our own code might # use in order to register it with Sanic log_config["loggers"]['qa_system'] = { 'level': 'INFO', 'handlers': ['stream_handler', 'file_handler'], } log_config["loggers"]['downloader'] = { 'level': 'INFO', 'handlers': ['stream_handler', 'file_handler'], } log_config["loggers"]['server'] = { 'level': 'INFO', 'handlers': ['stream_handler', 'file_handler'], } app = Sanic() app.config.REQUEST_TIMEOUT = args.request_timeout log = logging.getLogger('server') @app.listener('before_server_start') async def setup_qa(app, loop): # To play nice with iohttp's async ClientSession objects, we need to construct the QaSystem # inside the event loop. if args.debug == "dummy_qa": qa = DummyQa() else: qa = QaSystem( args.wiki_cache, MergeParagraphs(args.tokens), ShallowOpenWebRanker(args.n_paragraphs), args.voc, model, loader, bing_api_key, bing_version=args.bing_version, tagme_api_key=tagme_api_key, n_dl_threads=args.n_dl_threads, blacklist_trivia_sites=args.blacklist_trivia_sites, download_timeout=args.download_timeout, span_bound=span_bound, tagme_threshold=None if (tagme_api_key is None) else args.tagme_thresh, n_web_docs=args.n_web, ) app.qa = qa @app.listener('after_server_stop') async def setup_qa(app, loop): app.qa.close() @app.route("/answer") async def answer(request): try: question = request.args["question"][0] if question == "": return response.json({'message': 'No question given'}, status=400) spans, paras = await app.qa.answer_question(question) answers = select_answers(paras, spans, span_bound, 10) answers = answers[:n_to_return] best_span = max(answers[0].answers, key=lambda x: x.conf) log.info("Answered \"%s\" (with web search): \"%s\"", question, answers[0].original_text[best_span.start:best_span.end]) return json([x.to_json() for x in answers]) except Exception as e: log.info("Error: " + str(e)) raise ServerError(e, status_code=500) @app.route('/answer-from', methods=['POST']) async def answer_from(request): try: args = ujson.loads(request.body.decode("utf-8")) question = args.get("question") if question is None or question == "": return response.json({'message': 'No question given'}, status=400) doc = args["document"] if len(doc) > 500000: raise ServerError("Document too large", status_code=400) spans, paras = app.qa.answer_with_doc(question, doc) answers = select_answers(paras, spans, span_bound, 10) answers = answers[:n_to_return] best_span = max(answers[0].answers, key=lambda x: x.conf) log.info("Answered \"%s\" (with user doc): \"%s\"", question, answers[0].original_text[best_span.start:best_span.end]) return json([x.to_json() for x in answers]) except Exception as e: log.info("Error: " + str(e)) raise ServerError(e, status_code=500) app.static('/', './docqa//server/static/index.html') app.static('/about.html', './docqa/server/static/about.html') app.run(host="0.0.0.0", port=5000, workers=args.workers, debug=False, log_config=LOGGING)
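# Sketch of querying the demo server defined above from a separate client process.
# The /answer and /answer-from routes and their payload shapes are taken from the
# handlers; host and port match the app.run() call (0.0.0.0:5000). The `requests`
# library is used here purely for illustration.
import requests

resp = requests.get("http://localhost:5000/answer",
                    params={"question": "Who wrote Hamlet?"})
print(resp.json())

resp = requests.post("http://localhost:5000/answer-from",
                     json={"question": "Who wrote Hamlet?",
                           "document": "Hamlet is a tragedy written by William Shakespeare ..."})
print(resp.json())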
def main():
    parser = argparse.ArgumentParser(description='Train a model on TriviaQA unfiltered')
    parser.add_argument('mode', choices=["confidence", "merge", "shared-norm",
                                         "sigmoid", "paragraph"])
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument("-t", "--n_tokens", default=400, type=int,
                        help="Paragraph size")
    parser.add_argument('-n', '--n_processes', type=int, default=2,
                        help="Number of processes to preprocess the data with "
                             "(i.e., to select which paragraphs to train on)")
    parser.add_argument("-s", "--source_dir", type=str, default=None,
                        help="where to take input files")
    parser.add_argument("--n_epochs", type=int, default=None,
                        help="Max number of epochs to train for")
    parser.add_argument("--char_th", type=int, default=None,
                        help="char level embedding threshold")
    parser.add_argument("--hl_dim", type=int, default=None,
                        help="hidden layer dim size")
    parser.add_argument("--regularization", type=int, default=None,
                        help="regularization weight")
    parser.add_argument("--LR", type=float, default=1.0,
                        help="learning rate")
    parser.add_argument("--save_every", type=int, default=1800,
                        help="save period")
    parser.add_argument("--init_from", type=str, default=None,
                        help="model to init from")
    args = parser.parse_args()
    mode = args.mode

    #out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")
    out = join('models', args.name)

    char_th = 100
    hl_dim = 140
    if args.char_th is not None:
        print(args.char_th)
        char_th = int(args.char_th)
        out += '--th' + str(char_th)
    if args.hl_dim is not None:
        print(args.hl_dim)
        hl_dim = int(args.hl_dim)
        out += '--hl' + str(hl_dim)

    if args.init_from is None:
        model = get_model(char_th, hl_dim, mode, WithIndicators())
    else:
        md = model_dir.ModelDir(args.init_from)
        model = md.get_model()

    extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(args.n_tokens),
                                                ShallowOpenWebRanker(16),
                                                model.preprocessor, intern=True)

    eval = [
        LossEvaluator(),
        MultiParagraphSpanEvaluator(8, "triviaqa", mode != "merge", per_doc=False)
    ]
    oversample = [1] * 4

    if mode == "paragraph":
        n_epochs = 120
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample)
        train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True),
                                          oversample, only_answers=True)
    elif mode == "confidence" or mode == "sigmoid":
        if mode == "sigmoid":
            n_epochs = 640
        else:
            n_epochs = 160
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample)
        train = StratifyParagraphsBuilder(
            ClusteredBatcher(60, ContextLenBucketedKey(3), True), oversample)
    else:
        n_epochs = 80
        test = RandomParagraphSetDatasetBuilder(
            120, "merge" if mode == "merge" else "group", True, oversample)
        train = StratifyParagraphSetsBuilder(30, mode == "merge", True, oversample)

    if args.n_epochs is not None:
        n_epochs = args.n_epochs
        out += '--' + str(n_epochs)

    if args.LR != 1.0:
        out += '--' + str(args.LR)

    data = TriviaQaOpenDataset(args.source_dir)

    async_encoding = 10
    #async_encoding = 0
    params = TrainParams(SerializableOptimizer("Adadelta", dict(learning_rate=args.LR)),
                         num_epochs=n_epochs, num_of_steps=250000, ema=0.999,
                         max_checkpoints_to_keep=2, async_encoding=async_encoding,
                         log_period=30, eval_period=1800, save_period=args.save_every,
                         eval_samples=dict(dev=None, train=6000),
                         regularization_weight=None)

    data = PreprocessedData(data, extract, train, test, eval_on_verified=False)

    data.preprocess(args.n_processes, 1000)

    with open(__file__, "r") as f:
        notes = f.read()
    notes = "Mode: " + args.mode + "\n" + notes

    if args.init_from is not None:
        init_from = model_dir.ModelDir(args.init_from).get_best_weights()
        if init_from is None:
            init_from = model_dir.ModelDir(args.init_from).get_latest_checkpoint()
    else:
        init_from = None

    trainer.start_training(data, model, params, eval, model_dir.ModelDir(out), notes,
                           initialize_from=init_from)
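# Example invocation (illustrative; the script filename and flag values are
# placeholders). Given the suffix logic above, this run would write its model to
# models/triviaqa-open--80 (the "--80" suffix comes from --n_epochs):
#   python train_triviaqa_unfiltered.py shared-norm triviaqa-open -n 8 --n_epochs 80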
def main(): parser = argparse.ArgumentParser(description='Run the demo server') parser.add_argument('model', help='Models to use') parser.add_argument( '-v', '--voc', help='vocab to use, only words from this file will be used') parser.add_argument('-t', '--tokens', type=int, default=400, help='Number of tokens to use per paragraph') parser.add_argument('--vec_dir', help='Location to find word vectors') parser.add_argument('--n_paragraphs', type=int, default=12, help="Number of paragraphs to run the model on") parser.add_argument('--span_bound', type=int, default=8, help="Max span size to return as an answer") parser.add_argument( '--tagme_api_key', help="Key to use for TAGME (tagme.d4science.org/tagme)") parser.add_argument('--bing_api_key', help="Key to use for bing searches") parser.add_argument('--tagme_thresh', default=0.2, type=float) parser.add_argument('--no_wiki', action="store_true", help="Dont use TAGME") parser.add_argument('--n_web', type=int, default=10, help='Number of web docs to fetch') parser.add_argument('--blacklist_trivia_sites', action="store_true", help="Don't use trivia websites") parser.add_argument('-c', '--wiki_cache', help="Cache wiki articles in this directory") parser.add_argument('--n_dl_threads', type=int, default=5, help="Number of threads to download documents with") parser.add_argument('--request_timeout', type=int, default=60) parser.add_argument('--download_timeout', type=int, default=25) parser.add_argument('--workers', type=int, default=1, help="Number of server workers") parser.add_argument('--debug', default=None, choices=["random_model", "dummy_qa"]) args = parser.parse_args() span_bound = args.span_bound if args.tagme_api_key is not None: tagme_api_key = args.tagme_api_key else: tagme_api_key = environ.get("TAGME_API_KEY") if args.bing_api_key is not None: bing_api_key = args.bing_api_key else: bing_api_key = environ.get("BING_API_KEY") if bing_api_key is None and args.n_web > 0: raise ValueError("If n_web > 0 you must give a BING_API_KEY") if args.debug is None: model = ModelDir(args.model) else: model = RandomPredictor(5, WithIndicators()) if args.vec_dir is not None: loader = LoadFromPath(args.vec_dir) else: loader = ResourceLoader() if args.debug == "dummy_qa": qa = DummyQa() else: qa = QaSystem( args.wiki_cache, MergeParagraphs(args.tokens), ShallowOpenWebRanker(args.n_paragraphs), args.voc, model, loader, bing_api_key, tagme_api_key=tagme_api_key, n_dl_threads=args.n_dl_threads, blacklist_trivia_sites=args.blacklist_trivia_sites, download_timeout=args.download_timeout, span_bound=span_bound, tagme_threshold=None if args.no_wiki else args.tagme_thresh, n_web_docs=args.n_web) logging.propagate = False formatter = logging.Formatter("%(asctime)s: %(levelname)s: %(message)s") handler = logging.StreamHandler() handler.setFormatter(formatter) logging.root.addHandler(handler) logging.root.setLevel(logging.DEBUG) app = Sanic() app.config.REQUEST_TIMEOUT = args.request_timeout @app.route("/answer") async def answer(request): try: question = request.args["question"][0] if question == "": return response.json({'message': 'No question given'}, status=400) spans, paras = await qa.answer_question(question) answers = select_answers(paras, spans, span_bound, 10) return json([x.to_json() for x in answers]) except Exception as e: log.info("Error: " + str(e)) raise ServerError("Server Error", status_code=500) @app.route('/answer-from', methods=['POST']) async def answer_from(request): try: args = ujson.loads(request.body.decode("utf-8")) question = 
args.get("question") if question is None or question == "": return response.json({'message': 'No question given'}, status=400) doc = args["document"] if len(doc) > 500000: raise ServerError("Document too large", status_code=400) spans, paras = qa.answer_with_doc(question, doc) answers = select_answers(paras, spans, span_bound, 10) return json([x.to_json() for x in answers]) except Exception as e: log.info("Error: " + str(e)) raise ServerError("Server Error", status_code=500) app.static('/', './docqa//server/static/index.html') app.static('/about.html', './docqa//service/static/about.html') app.run(host="0.0.0.0", port=8000, workers=args.workers, debug=False)
if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--input_file', required=True, help="input file, e.g. train_data.pkl") parser.add_argument('--output_train_file', required=True, help="output train file, e.g. train_output.json") parser.add_argument('--num_epoch', required=True, type=int, help="num_epoch, e.g. 10") args = parser.parse_args() mode = "shared-norm" model = get_model(100, 140, mode, WithIndicators()) extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(400), ShallowOpenWebRanker(16), model.preprocessor, intern=True) oversample = [1] * 2 train = StratifyParagraphSetsBuilder(30, mode == "merge", True, oversample) test = RandomParagraphSetDatasetBuilder( 120, "merge" if mode == "merge" else "group", True, oversample) data = PreprocessedData(None, extract, train, test, eval_on_verified=False) data.load_preprocess(args.input_file) outputs = [] training_data = data.get_train()
def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('answers', help='answer file')
    parser.add_argument('question_source')
    args = parser.parse_args()

    print("Loading answers..")
    answer_df = pd.read_csv(args.answers)

    print("Loading questions..")
    if args.question_source == "open":
        corpus = TriviaQaOpenDataset()
        questions = {q.question_id: q for q in corpus.get_dev()}
    elif args.question_source == "web":
        corpus = TriviaQaWebDataset()
        questions = {}
        for q in corpus.get_dev():
            for d in q.all_docs:
                questions[(q.question_id, d.doc_id)] = q
    elif args.question_source == "squad":
        show_squad_errors(args.answers)
        return
    else:
        raise ValueError()

    pre = WithIndicators()

    answer_df.sort_values(["question_id", "rank"], inplace=True)

    if args.question_source == "open":
        iter = answer_df.groupby(["question_id"])
    else:
        iter = answer_df.groupby(["question_id", "doc_id"])

    grouped = list(iter)
    np.random.shuffle(grouped)

    for key, group in grouped:
        print(list(questions.keys())[:10])
        q = questions[key]
        cur_best_score = group.text_f1.iloc[0]
        cur_best_conf = group.predicted_score.iloc[0]
        cur_best_ix = group.index[0]
        for i in range(1, len(group)):
            ix = group.index[i]
            conf = group.predicted_score[ix]
            if conf > cur_best_conf:
                score = group.text_f1[ix]
                if score < cur_best_score:
                    # We hurt ourselves!
                    print("Oh no!")
                    print(" ".join(q.question))
                    print(q.answer.all_answers)
                    print("Best score was %.4f (conf=%.4f), but now is %.4f (conf=%.4f)" % (
                        cur_best_score, cur_best_conf, score, conf
                    ))
                    d1 = [d for d in q.all_docs if d.doc_id == group.doc_id[cur_best_ix]][0]
                    p1 = extract_paragraph(corpus.evidence.get_document(d1.doc_id),
                                           group.para_start[cur_best_ix], group.para_end[cur_best_ix])
                    s, e = group.para_start[cur_best_ix], group.para_end[cur_best_ix]
                    # keep only answer spans inside this paragraph, re-indexed to paragraph offsets
                    answers = d1.answer_spans[np.logical_and(d1.answer_spans[:, 0] >= s,
                                                             d1.answer_spans[:, 1] < e)] - s
                    p1 = pre.encode_extracted_paragraph(q.question, ExtractedParagraphWithAnswers(
                        p1, group.para_start[cur_best_ix], group.para_end[cur_best_ix], answers))

                    d2 = [d for d in q.all_docs if d.doc_id == group.doc_id[ix]][0]
                    p2 = extract_paragraph(corpus.evidence.get_document(d2.doc_id),
                                           group.para_start[ix], group.para_end[ix])
                    s, e = group.para_start[ix], group.para_end[ix]
                    answers = d2.answer_spans[np.logical_and(d2.answer_spans[:, 0] >= s,
                                                             d2.answer_spans[:, 1] < e)] - s
                    p2 = pre.encode_extracted_paragraph(q.question, ExtractedParagraphWithAnswers(
                        p2, group.para_start[ix], group.para_end[ix], answers))

                    p1_s, p1_e = group.predicted_start[cur_best_ix], group.predicted_end[cur_best_ix]
                    p2_s, p2_e = group.predicted_start[ix], group.predicted_end[ix]

                    print(" ".join(display_para(p1.text, p1.answer_spans, q.question, p1_s, p1_e)))
                    print()
                    print(" ".join(display_para(p2.text, p2.answer_spans, q.question, p2_s, p2_e)))
                    input()
                else:
                    cur_best_score = score
                    cur_best_ix = ix
                    cur_best_conf = conf
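# Sketch of the answer file this script expects: a CSV with one row per predicted
# paragraph, using the columns accessed above (question_id, doc_id, rank, text_f1,
# predicted_score, para_start, para_end, predicted_start, predicted_end). The
# values below are illustrative only.
import pandas as pd

example = pd.DataFrame([{
    "question_id": "tc_0", "doc_id": "wiki_doc_0", "rank": 0,
    "text_f1": 0.75, "predicted_score": 3.2,
    "para_start": 0, "para_end": 400,
    "predicted_start": 17, "predicted_end": 19,
}])
example.to_csv("answers.csv", index=False)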