def get_most_relevant_doc_based_on_config(config, query_string, target_index):
    """Query ``target_index`` with ``query_string`` using components built from ``config``.

    Steps:
      1. Instantiate the pipeline, index, tokenizer and embedding holder
         described by ``config``.
      2. Delegate to ``NeuralQueryView.do_query`` and return its result.

    Args:
        config: Dict-like settings. ``"reranker"`` is required; ``"embeddings"``
            falls back to ``"glove6b"`` when absent.
        query_string: The query forwarded to ``NeuralQueryView.do_query``.
        target_index: Path of the index to search.

    Returns:
        Whatever ``NeuralQueryView.do_query`` returns (presumably the most
        relevant document -- confirm against that method).
    """
    # The config gets modified further down, so operate on a copy rather than
    # on the caller's object.
    cfg = config.copy()

    # A pipeline is still required here: it pre-processes some config params,
    # and we rely on that to construct paths etc.
    pipeline = Pipeline(cfg)
    pipeline.initialize(cfg)
    paths = pipeline.get_paths(cfg)

    index_cls = Index.get_index_from_index_path(target_index)
    # TODO: Pass a proper index_key
    index = index_cls(pipeline.collection, target_index, None)

    reranker_cls = Reranker.ALL[cfg["reranker"]]
    tokenizer = NeuralQueryView.get_tokenizer(pipeline, cfg, index_cls.name)
    embeddings = EmbeddingHolder.get_instance(cfg.get("embeddings", "glove6b"))
    weights_path = paths["trained_weight_path"]

    # The tokenizer above was built from the config as it was *before* the
    # model-specific parameters are added here.
    cfg = NeuralQueryView.add_model_required_params_to_config(cfg, embeddings)

    return NeuralQueryView.do_query(
        cfg,
        query_string,
        pipeline,
        index,
        tokenizer,
        embeddings,
        reranker_cls,
        trained_weight_path=weights_path,
    )
def _train(_config):
    """Train the pipeline's reranker and return the initialized pipeline.

    A ``Pipeline`` is built from ``_config`` and its reranker is trained for
    ``pipeline.cfg["niters"]`` iterations. Each iteration consumes at most
    ``itersize // batch`` batches from the benchmark's training tuples for
    the configured fold. The optimizer is stepped once every ``gradacc``
    batches (default 1), so gradients from the batches in between are
    accumulated (assuming a torch-style optimizer and loss).

    When ``earlystopping`` is truthy, the pipeline is evaluated after every
    iteration; whenever the best metric improves, the reranker is saved to
    ``<reranker_path>/<fold>/weights/dev``. After training, that checkpoint
    is loaded back so the returned pipeline holds the best weights seen.

    Args:
        _config: Dict-like pipeline settings. Keys read directly here are
            ``"earlystopping"``, ``"itersize"``, ``"batch"`` and, optionally,
            ``"gradacc"``. ``"fold"``, ``"niters"`` and ``"batch"`` are also
            read through ``pipeline.cfg``.

    Returns:
        The ``Pipeline`` instance whose reranker was trained.

    Raises:
        ValueError: If the configured fold is not present in
            ``benchmark.folds``.
    """
    pipeline_config = _config
    early_stopping = pipeline_config["earlystopping"]

    pipeline = Pipeline(pipeline_config)
    pipeline.initialize(pipeline_config)
    reranker = pipeline.reranker
    benchmark = pipeline.benchmark

    fold_name = pipeline.cfg["fold"]
    fold = benchmark.folds.get(fold_name)
    if fold is None:
        # Fail with a clear message instead of the opaque
        # "'NoneType' object is not subscriptable" raised further down.
        raise ValueError("Unknown fold: {0!r}".format(fold_name))
    datagen = benchmark.training_tuples(fold["train_qids"])

    run_path = os.path.join(pipeline.reranker_path, fold_name)
    weight_path = os.path.join(run_path, "weights")
    best_weights_path = os.path.join(weight_path, "dev")

    prepare_batch = functools.partial(
        _prepare_batch_with_strings,
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    )

    batches_per_epoch = pipeline_config["itersize"] // pipeline_config["batch"]
    batches_per_step = pipeline_config.get("gradacc", 1)

    optimizer = reranker.get_optimizer()
    best_accuracy = 0
    # Tracks whether *this run* wrote a "dev" checkpoint, so we never try to
    # load a file that does not exist (or a stale one from an earlier run).
    checkpoint_saved = False

    for _ in range(pipeline.cfg["niters"]):
        reranker.model.train()
        reranker.next_iteration()

        for bi, data in enumerate(datagen):
            data = prepare_batch(data)

            tag_scores = reranker.score(data)
            loss = pipeline.lossf(tag_scores[0], tag_scores[1], pipeline.cfg["batch"])
            loss.backward()

            # Step only after a full group of `batches_per_step` batches has
            # been accumulated. Testing `bi % batches_per_step` instead would
            # fire on the very first batch (bi == 0) of every iteration.
            if (bi + 1) % batches_per_step == 0:
                optimizer.step()
                optimizer.zero_grad()

            # `datagen` is not restarted per iteration here; cap the number of
            # batches consumed in one iteration.
            if (bi + 1) % batches_per_epoch == 0:
                break

        if early_stopping:
            current_accuracy = max(evaluate_pipeline(pipeline))
            if current_accuracy > best_accuracy:
                logger.debug(
                    "Current accuracy: {0} is greater than best so far: {1}".format(
                        current_accuracy, best_accuracy
                    )
                )
                best_accuracy = current_accuracy
                reranker.save(best_weights_path)
                checkpoint_saved = True

    # Restore the best-scoring weights seen during training. If no iteration
    # ever improved on the initial best (or niters == 0) there is nothing to
    # restore, and the reranker keeps its final weights.
    if early_stopping and checkpoint_saved:
        reranker.load(best_weights_path)

    return pipeline