def test_collect_training_data(self):
    """Collect training data in one batch call, then point-by-point.

    Both paths must yield 12 rows (2 queries x 2 relevant docs x
    (1 relevant + 2 additional)) and more than the three base columns
    (document_id, query_id, label), i.e. at least one rank feature.
    """
    app = Vespa(url="https://api.cord19.vespa.ai")
    query_model = QueryModel(
        match_phase=OR(),
        rank_profile=Ranking(name="bm25", list_features=True),
    )
    labeled_data = [
        {
            "query_id": 0,
            "query": "Intrauterine virus infections and congenital heart disease",
            "relevant_docs": [{"id": 0, "score": 1}, {"id": 3, "score": 1}],
        },
        {
            "query_id": 1,
            "query": "Clinical and immunologic studies in identical twins discordant for systemic lupus erythematosus",
            "relevant_docs": [{"id": 1, "score": 1}, {"id": 5, "score": 1}],
        },
    ]

    # Batch collection.
    training_data_batch = app.collect_training_data(
        labeled_data=labeled_data,
        id_field="id",
        query_model=query_model,
        number_additional_docs=2,
        fields=["rankfeatures"],
    )
    self.assertEqual(training_data_batch.shape[0], 12)
    # At least one rank feature beyond document_id, query_id and label.
    self.assertTrue(training_data_batch.shape[1] > 3)

    # Point-by-point collection must agree with the batch shape.
    collected = []
    for labeled_query in labeled_data:
        for relevant in labeled_query["relevant_docs"]:
            collected.extend(
                app.collect_training_data_point(
                    query=labeled_query["query"],
                    query_id=labeled_query["query_id"],
                    relevant_id=relevant["id"],
                    id_field="id",
                    query_model=query_model,
                    number_additional_docs=2,
                    fields=["rankfeatures"],
                )
            )
    training_data = DataFrame.from_records(collected)
    self.assertEqual(training_data.shape[0], 12)
    # At least one rank feature beyond document_id, query_id and label.
    self.assertTrue(training_data.shape[1] > 3)
def run(self):
    """Scan recent tweets from a fixed list of news accounts and update
    the matching Vespa documents with their tweet engagement data.

    Uses the instance's ``api_key``/``api_secret`` attributes for app-only
    Twitter auth. Per-account failures are logged and do not stop the run.
    """
    self.vespa = Vespa(url="http://vespa-search", port=8080)
    auth = tweepy.AppAuthHandler(self.api_key, self.api_secret)
    self.api = tweepy.API(auth)
    updated = 0
    for userid in [
            'abcnews', 'GuardianAus', 'smh', 'iTnews_au', 'theage',
            'canberratimes', 'zdnetaustralia', 'newscomauHQ',
            'westaustralian'
    ]:
        try:
            for status in tweepy.Cursor(self.api.user_timeline,
                                        id=userid,
                                        include_entities=True).items(60):
                if len(status.entities['urls']) == 0:
                    continue
                url = status.entities['urls'][0]['expanded_url']
                # Drop the query string (tracking parameters etc.).
                url = url.split('?')[0]
                if url.startswith("https://twitter.com"):
                    continue
                if url.startswith(("https://zd.net", "https://bit.ly")):
                    # Resolve shortener redirects to the final URL.
                    # FIX: close the HTTP response (the original leaked it).
                    with urlopen(url) as response:
                        url = response.geturl()
                article = self.get_article(url)
                if article:
                    self.update_document(article, status)
                    updated += 1
        except Exception as e:
            # Best effort per account: log and move to the next one.
            logger.error(e)
    print("Completed run, updated {} tweets".format(updated))
def run(self):
    """Scan recent tweets from a fixed list of news accounts and update
    the matching Vespa documents with their tweet engagement data.

    Twitter credentials are read from the TWITTER_API_KEY and
    TWITTER_API_SECRET environment variables.

    :raises RuntimeError: if either credential variable is missing, so the
        failure is explicit instead of a confusing tweepy auth error later.
    """
    api_key = os.getenv('TWITTER_API_KEY')
    api_secret = os.getenv('TWITTER_API_SECRET')
    # FIX: fail fast with a clear message when credentials are absent.
    if not api_key or not api_secret:
        raise RuntimeError(
            "TWITTER_API_KEY and TWITTER_API_SECRET environment variables must be set"
        )
    self.vespa = Vespa(url="http://vespa-search", port=8080)
    auth = tweepy.AppAuthHandler(api_key, api_secret)
    self.api = tweepy.API(auth)
    updated = 0
    for userid in [
            'abcnews', 'GuardianAus', 'smh', 'iTnews_au', 'theage',
            'canberratimes', 'zdnetaustralia', 'newscomauHQ',
            'westaustralian', 'SBSNews', 'australian', 'crikey_news',
            '9NewsAUS', 'BBCNewsAus'
    ]:
        try:
            for status in tweepy.Cursor(self.api.user_timeline,
                                        id=userid,
                                        include_entities=True,
                                        tweet_mode="extended").items(60):
                if len(status.entities['urls']) == 0:
                    continue
                url = status.entities['urls'][0]['expanded_url']
                url = self.get_url(url)
                if (url.startswith("https://twitter.com")
                        or url.startswith("https://www.reddit.com")):
                    continue
                article = self.get_article(url)
                if article:
                    self.update_document(article, status)
                    updated += 1
        except Exception as e:
            # Best effort per account: report and continue with the next.
            print("exception! {}".format(e))
            continue
    print("Completed run, updated {} tweets".format(updated))
def test_end_point(self):
    """end_point appends the port when given and strips a trailing slash."""
    cases = [
        (Vespa(url="https://cord19.vespa.ai"), "https://cord19.vespa.ai"),
        (Vespa(url="http://localhost", port=8080), "http://localhost:8080"),
        (Vespa(url="http://localhost/", port=8080), "http://localhost:8080"),
    ]
    for app, expected in cases:
        self.assertEqual(app.end_point, expected)
def deploy(self, disk_folder: str, container_memory: str = "4G"):
    """
    Deploy the application into a Vespa container.

    :param disk_folder: Disk folder to save the required Vespa config files.
    :param container_memory: Docker container memory available to the application.
    :return: a Vespa connection instance.
    :raises RuntimeError: if vespa-deploy does not report a config generation,
        i.e. the prepare/activate step failed.
    """
    # Write the application package (schemas, services.xml, ...) to disk
    # so it can be mounted into the container.
    self.application_package.create_application_package_files(dir_path=disk_folder)
    self.run_vespa_engine_container(
        disk_folder=disk_folder, container_memory=container_memory
    )
    # Poll until the config server inside the container responds.
    # NOTE(review): no timeout — this loops forever if the container never
    # becomes healthy; consider bounding the wait.
    while not self.check_configuration_server():
        print("Waiting for configuration server.")
        sleep(5)
    deployment = self.container.exec_run(
        "bash -c '/opt/vespa/bin/vespa-deploy prepare /app/application && /opt/vespa/bin/vespa-deploy activate'"
    )
    deployment_message = deployment.output.decode("utf-8").split("\n")
    # vespa-deploy prints "Generation: <n>" on success; anything else means
    # the deployment failed, so surface the full output.
    if not any(re.match("Generation: [0-9]+", line) for line in deployment_message):
        raise RuntimeError(deployment_message)
    return Vespa(
        url="http://localhost",
        port=self.local_port,
        deployment_message=deployment_message,
    )
def test_query(self):
    """query() must build the request body from the QueryModel and merge in
    extra parameters (hits, recall) without touching the rest of the body."""
    app = Vespa(url="http://localhost", port=8080)
    body = {"yql": "select * from sources * where test"}
    # debug_request=True returns the built request instead of sending it,
    # so a raw body must pass through untouched.
    self.assertDictEqual(
        app.query(body=body, debug_request=True).request_body, body)
    # Query built from a Query model: OR match phase becomes grammar "any",
    # default rank profile with listFeatures off.
    self.assertDictEqual(
        app.query(
            query="this is a test",
            query_model=Query(match_phase=OR(), rank_profile=RankProfile()),
            debug_request=True,
            hits=10,
        ).request_body,
        {
            "yql": 'select * from sources * where ([{"grammar": "any"}]userInput("this is a test"));',
            "ranking": {"profile": "default", "listFeatures": "false"},
            "hits": 10,
        },
    )
    # recall=("id", [1, 5]) must be translated into Vespa recall syntax.
    self.assertDictEqual(
        app.query(
            query="this is a test",
            query_model=Query(match_phase=OR(), rank_profile=RankProfile()),
            debug_request=True,
            hits=10,
            recall=("id", [1, 5]),
        ).request_body,
        {
            "yql": 'select * from sources * where ([{"grammar": "any"}]userInput("this is a test"));',
            "ranking": {"profile": "default", "listFeatures": "false"},
            "hits": 10,
            "recall": "+(id:1 id:5)",
        },
    )
def run(self):
    """Scan recent tweets from a fixed list of news accounts and update
    the matching Vespa documents with their tweet engagement data.

    Twitter credentials are read from the 'twitter-secrets' Kubernetes
    secret. Works both in-cluster and with a local kubeconfig.
    """
    try:
        config.load_incluster_config()
    # FIX: was a bare `except:`, which also swallows KeyboardInterrupt
    # and SystemExit; only fall back on ordinary errors.
    except Exception:
        config.load_kube_config()
    v1 = client.CoreV1Api()
    twitter_secrets = v1.read_namespaced_secret(name='twitter-secrets',
                                                namespace='default').data
    # Kubernetes secret values are base64-encoded.
    api_key = base64.b64decode(twitter_secrets["api-key"]).decode('utf-8')
    api_secret = base64.b64decode(
        twitter_secrets["api-secret"]).decode('utf-8')
    self.vespa = Vespa(url="http://vespa-search", port=8080)
    auth = tweepy.AppAuthHandler(api_key, api_secret)
    self.api = tweepy.API(auth)
    updated = 0
    for userid in [
            'abcnews', 'GuardianAus', 'smh', 'iTnews_au', 'theage',
            'canberratimes', 'zdnetaustralia', 'newscomauHQ',
            'westaustralian', 'SBSNews', 'australian', 'crikey_news',
            '9NewsAUS', 'BBCNewsAus'
    ]:
        try:
            for status in tweepy.Cursor(self.api.user_timeline,
                                        id=userid,
                                        include_entities=True,
                                        tweet_mode="extended").items(60):
                if len(status.entities['urls']) == 0:
                    continue
                url = status.entities['urls'][0]['expanded_url']
                url = self.get_url(url)
                if (url.startswith("https://twitter.com")
                        or url.startswith("https://www.reddit.com")):
                    continue
                article = self.get_article(url)
                if article:
                    self.update_document(article, status)
                    updated += 1
        except Exception as e:
            # Best effort per account: report and continue with the next.
            print("exception! {}".format(e))
            continue
    print("Completed run, updated {} tweets".format(updated))
def test_query_with_body_function(self):
    """A QueryModel body_function fully defines the request body; query()
    only merges in hits and recall on top of it."""
    app = Vespa(url="http://localhost", port=8080)

    def make_body(user_query):
        # The body function receives the raw query string.
        return {
            "yql": "select * from sources * where userQuery();",
            "query": user_query,
            "type": "any",
            "ranking": {"profile": "bm25", "listFeatures": "true"},
        }

    query_model = QueryModel(body_function=make_body)
    result = app.query(
        query="this is a test",
        query_model=query_model,
        debug_request=True,
        hits=10,
        recall=("id", [1, 5]),
    )
    expected = {
        "yql": "select * from sources * where userQuery();",
        "query": "this is a test",
        "type": "any",
        "ranking": {"profile": "bm25", "listFeatures": "true"},
        "hits": 10,
        "recall": "+(id:1 id:5)",
    }
    self.assertDictEqual(result.request_body, expected)
def deploy(self, instance: str, disk_folder: str) -> Vespa:
    """
    Deploy the given application package as the given instance in the Vespa Cloud dev environment.

    :param instance: Name of this instance of the application, in the Vespa Cloud.
    :param disk_folder: Disk folder to save the required Vespa config files.

    :return: a Vespa connection instance.
    """
    region = self._get_dev_region()
    job = "dev-{}".format(region)
    run_id = self._start_deployment(instance, job, disk_folder)
    self._follow_deployment(instance, job, run_id)
    endpoint_url = self._get_endpoint(instance=instance, region=region)
    cert_path = os.path.join(disk_folder, self.private_cert_file_name)
    return Vespa(url=endpoint_url, cert=cert_path)
class VespaWrite:
    """Scrapy item pipeline that upserts crawled news articles into the
    'newsarticle' schema of a Vespa application, keyed by SHA-256 of the URL."""

    def open_spider(self, spider):
        # One Vespa connection per spider run.
        self.vespa = Vespa(url="http://vespa-search", port=8080)

    def process_item(self, item, spider):
        """Map item fields to Vespa fields and upsert the document.

        :param item: crawled article; required keys: url, bodytext,
            firstpubtime, wordcount, headline, sentiment, source.
        :param spider: the spider that produced the item (unused).
        :return: the item on success, None when a required key is missing.
        """
        try:
            vespa_fields = {}
            # Required fields: a missing key raises KeyError and the item
            # is skipped below.
            vespa_fields['url'] = item['url']
            vespa_fields['bodytext'] = item['bodytext']
            vespa_fields['firstpubtime'] = item['firstpubtime']
            if 'modtime' in item:
                vespa_fields['modtime'] = item['modtime']
            vespa_fields['wordcount'] = item['wordcount']
            vespa_fields['headline'] = item['headline']
            vespa_fields['sentiment'] = item['sentiment']
            # Optional fields; note the item's 'summary' maps to the Vespa
            # 'abstract' field.
            if 'summary' in item:
                vespa_fields['abstract'] = item['summary']
            for key in ('keywords', 'bylines', 'section'):
                if key in item:
                    vespa_fields[key] = item[key]
            vespa_fields['source'] = item['source']
            for key in ('twitter_retweet_count', 'twitter_favourite_count'):
                if key in item:
                    vespa_fields[key] = item[key]
            # create=True makes this an upsert.
            self.vespa.update_data(
                schema="newsarticle",
                data_id=hashlib.sha256(item['url'].encode()).hexdigest(),
                fields=vespa_fields,
                create=True,
            )
            return item
        except (KeyError, TypeError):
            # FIX: the original did `"error: " + item`, which itself raises
            # TypeError when item is not a str; lazy %s formatting is safe.
            logger.debug("error: %s", item)
def open_spider(self, spider):
    """Open one Vespa connection for the lifetime of the spider run."""
    self.vespa = Vespa(url="http://vespa-search", port=8080)
def setUp(self) -> None:
    """Build an offline Vespa client plus two canned raw query results:
    one for the recall query (the single relevant doc) and one for the
    additional-documents query (two non-relevant docs), both including
    per-hit rankfeatures."""
    self.app = Vespa(url="http://localhost", port=8080)
    # Recall-query result: one hit whose vespa_id_field is "abc".
    self.raw_vespa_result_recall = {
        "root": {
            "id": "toplevel",
            "relevance": 1.0,
            "fields": {"totalCount": 1083},
            "coverage": {"coverage": 100, "documents": 62529, "full": True,
                         "nodes": 2, "results": 1, "resultsFull": 1},
            "children": [{
                "id": "id:covid-19:doc::40215",
                "relevance": 30.368213170494712,
                "source": "content",
                "fields": {
                    "vespa_id_field": "abc",
                    "sddocname": "doc",
                    "body_text": "this is a body",
                    "title": "this is a title",
                    "rankfeatures": {"a": 1, "b": 2},
                },
            }],
        }
    }
    # Additional-documents result: two hits, "def" and "ghi".
    self.raw_vespa_result_additional = {
        "root": {
            "id": "toplevel",
            "relevance": 1.0,
            "fields": {"totalCount": 1083},
            "coverage": {"coverage": 100, "documents": 62529, "full": True,
                         "nodes": 2, "results": 1, "resultsFull": 1},
            "children": [
                {
                    "id": "id:covid-19:doc::40216",
                    "relevance": 10,
                    "source": "content",
                    "fields": {
                        "vespa_id_field": "def",
                        "sddocname": "doc",
                        "body_text": "this is a body 2",
                        "title": "this is a title 2",
                        "rankfeatures": {"a": 3, "b": 4},
                    },
                },
                {
                    "id": "id:covid-19:doc::40217",
                    "relevance": 8,
                    "source": "content",
                    "fields": {
                        "vespa_id_field": "ghi",
                        "sddocname": "doc",
                        "body_text": "this is a body 3",
                        "title": "this is a title 3",
                        "rankfeatures": {"a": 5, "b": 6},
                    },
                },
            ],
        }
    }
class TestVespaCollectData(unittest.TestCase):
    """Unit tests for Vespa.collect_training_data[_point] using mocked
    query responses (no server involved)."""

    def setUp(self) -> None:
        """Offline Vespa client plus two canned raw query results: one for
        the recall query (the relevant doc) and one for the
        additional-documents query (two non-relevant docs)."""
        self.app = Vespa(url="http://localhost", port=8080)
        # Recall-query result: one hit whose vespa_id_field is "abc".
        self.raw_vespa_result_recall = {
            "root": {
                "id": "toplevel",
                "relevance": 1.0,
                "fields": {"totalCount": 1083},
                "coverage": {"coverage": 100, "documents": 62529, "full": True,
                             "nodes": 2, "results": 1, "resultsFull": 1},
                "children": [{
                    "id": "id:covid-19:doc::40215",
                    "relevance": 30.368213170494712,
                    "source": "content",
                    "fields": {
                        "vespa_id_field": "abc",
                        "sddocname": "doc",
                        "body_text": "this is a body",
                        "title": "this is a title",
                        "rankfeatures": {"a": 1, "b": 2},
                    },
                }],
            }
        }
        # Additional-documents result: two hits, "def" and "ghi".
        self.raw_vespa_result_additional = {
            "root": {
                "id": "toplevel",
                "relevance": 1.0,
                "fields": {"totalCount": 1083},
                "coverage": {"coverage": 100, "documents": 62529, "full": True,
                             "nodes": 2, "results": 1, "resultsFull": 1},
                "children": [
                    {
                        "id": "id:covid-19:doc::40216",
                        "relevance": 10,
                        "source": "content",
                        "fields": {
                            "vespa_id_field": "def",
                            "sddocname": "doc",
                            "body_text": "this is a body 2",
                            "title": "this is a title 2",
                            "rankfeatures": {"a": 3, "b": 4},
                        },
                    },
                    {
                        "id": "id:covid-19:doc::40217",
                        "relevance": 8,
                        "source": "content",
                        "fields": {
                            "vespa_id_field": "ghi",
                            "sddocname": "doc",
                            "body_text": "this is a body 3",
                            "title": "this is a title 3",
                            "rankfeatures": {"a": 5, "b": 6},
                        },
                    },
                ],
            }
        }

    def test_disable_rank_features(self):
        """Collecting data without list_features enabled must be rejected."""
        with self.assertRaises(AssertionError):
            self.app.collect_training_data_point(
                query="this is a query",
                query_id="123",
                relevant_id="abc",
                id_field="vespa_id_field",
                query_model=Query(),
                number_additional_docs=2,
            )

    def test_collect_training_data_point(self):
        """Happy path: one recall query then one additional-docs query;
        rows carry relevant=1 for the recall hit and 0 for the extras,
        with rankfeatures flattened into columns."""
        # First query call serves the recall result, second the additional docs.
        self.app.query = Mock(side_effect=[
            VespaResult(self.raw_vespa_result_recall),
            VespaResult(self.raw_vespa_result_additional),
        ])
        query_model = Query(rank_profile=RankProfile(list_features=True))
        data = self.app.collect_training_data_point(
            query="this is a query",
            query_id="123",
            relevant_id="abc",
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            timeout="15s",
        )
        self.assertEqual(self.app.query.call_count, 2)
        self.app.query.assert_has_calls([
            call(
                query="this is a query",
                query_model=query_model,
                recall=("vespa_id_field", ["abc"]),
                timeout="15s",
            ),
            call(
                query="this is a query",
                query_model=query_model,
                hits=2,
                timeout="15s",
            ),
        ])
        expected_data = [
            {"document_id": "abc", "query_id": "123", "relevant": 1, "a": 1, "b": 2},
            {"document_id": "def", "query_id": "123", "relevant": 0, "a": 3, "b": 4},
            {"document_id": "ghi", "query_id": "123", "relevant": 0, "a": 5, "b": 6},
        ]
        self.assertEqual(data, expected_data)

    def test_collect_training_data_point_0_recall_hits(self):
        """If the recall query returns no hits, no additional-docs query is
        issued and the result is empty."""
        # Override the recall result with an empty (totalCount 0) response.
        self.raw_vespa_result_recall = {
            "root": {
                "id": "toplevel",
                "relevance": 1.0,
                "fields": {"totalCount": 0},
                "coverage": {"coverage": 100, "documents": 62529, "full": True,
                             "nodes": 2, "results": 1, "resultsFull": 1},
            }
        }
        self.app.query = Mock(side_effect=[
            VespaResult(self.raw_vespa_result_recall),
            VespaResult(self.raw_vespa_result_additional),
        ])
        query_model = Query(rank_profile=RankProfile(list_features=True))
        data = self.app.collect_training_data_point(
            query="this is a query",
            query_id="123",
            relevant_id="abc",
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            timeout="15s",
        )
        # Only the recall query should have been issued.
        self.assertEqual(self.app.query.call_count, 1)
        self.app.query.assert_has_calls([
            call(
                query="this is a query",
                query_model=query_model,
                recall=("vespa_id_field", ["abc"]),
                timeout="15s",
            ),
        ])
        expected_data = []
        self.assertEqual(data, expected_data)

    def test_collect_training_data(self):
        """Batch collection delegates to collect_training_data_point with
        default relevant/default scores and returns the rows as a DataFrame."""
        mock_return_value = [
            {"document_id": "abc", "query_id": "123", "relevant": 1, "a": 1, "b": 2,},
            {"document_id": "def", "query_id": "123", "relevant": 0, "a": 3, "b": 4,},
            {"document_id": "ghi", "query_id": "123", "relevant": 0, "a": 5, "b": 6,},
        ]
        self.app.collect_training_data_point = Mock(
            return_value=mock_return_value)
        labelled_data = [{
            "query_id": 123,
            "query": "this is a query",
            "relevant_docs": [{"id": "abc", "score": 1}],
        }]
        query_model = Query(rank_profile=RankProfile(list_features=True))
        data = self.app.collect_training_data(
            labelled_data=labelled_data,
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            timeout="15s",
        )
        self.app.collect_training_data_point.assert_has_calls([
            call(
                query="this is a query",
                query_id=123,
                relevant_id="abc",
                id_field="vespa_id_field",
                query_model=query_model,
                number_additional_docs=2,
                relevant_score=1,
                default_score=0,
                timeout="15s",
            )
        ])
        assert_frame_equal(data, DataFrame.from_records(mock_return_value))
def setUp(self) -> None:
    """Offline Vespa client, one labelled query with two relevant docs
    ("def" and "abc"), and a canned query result containing hits "ghi"
    and "def" — i.e. exactly one of the two relevant docs is retrieved."""
    self.app = Vespa(url="http://localhost", port=8080)
    self.labelled_data = [
        {
            "query_id": 0,
            "query": "Intrauterine virus infections and congenital heart disease",
            "relevant_docs": [{"id": "def", "score": 1}, {"id": "abc", "score": 1}],
        },
    ]
    self.query_results = {
        "root": {
            "id": "toplevel",
            "relevance": 1.0,
            "fields": {"totalCount": 1083},
            "coverage": {"coverage": 100, "documents": 62529, "full": True,
                         "nodes": 2, "results": 1, "resultsFull": 1},
            "children": [
                {
                    "id": "id:covid-19:doc::40216",
                    "relevance": 10,
                    "source": "content",
                    "fields": {
                        "vespa_id_field": "ghi",
                        "sddocname": "doc",
                        "body_text": "this is a body 2",
                        "title": "this is a title 2",
                        "rankfeatures": {"a": 3, "b": 4},
                    },
                },
                {
                    "id": "id:covid-19:doc::40217",
                    "relevance": 8,
                    "source": "content",
                    "fields": {
                        "vespa_id_field": "def",
                        "sddocname": "doc",
                        "body_text": "this is a body 3",
                        "title": "this is a title 3",
                        "rankfeatures": {"a": 5, "b": 6},
                    },
                },
            ],
        }
    }
class TestVespaEvaluate(unittest.TestCase):
    """Unit tests for Vespa.evaluate_query / Vespa.evaluate with mocked
    queries and metrics (no server involved)."""

    def setUp(self) -> None:
        """Offline Vespa client, one labelled query with two relevant docs
        ("def" and "abc"), and a canned query result containing hits "ghi"
        and "def"."""
        self.app = Vespa(url="http://localhost", port=8080)
        self.labelled_data = [
            {
                "query_id": 0,
                "query": "Intrauterine virus infections and congenital heart disease",
                "relevant_docs": [{"id": "def", "score": 1}, {"id": "abc", "score": 1}],
            },
        ]
        self.query_results = {
            "root": {
                "id": "toplevel",
                "relevance": 1.0,
                "fields": {"totalCount": 1083},
                "coverage": {"coverage": 100, "documents": 62529, "full": True,
                             "nodes": 2, "results": 1, "resultsFull": 1},
                "children": [
                    {
                        "id": "id:covid-19:doc::40216",
                        "relevance": 10,
                        "source": "content",
                        "fields": {
                            "vespa_id_field": "ghi",
                            "sddocname": "doc",
                            "body_text": "this is a body 2",
                            "title": "this is a title 2",
                            "rankfeatures": {"a": 3, "b": 4},
                        },
                    },
                    {
                        "id": "id:covid-19:doc::40217",
                        "relevance": 8,
                        "source": "content",
                        "fields": {
                            "vespa_id_field": "def",
                            "sddocname": "doc",
                            "body_text": "this is a body 3",
                            "title": "this is a title 3",
                            "rankfeatures": {"a": 5, "b": 6},
                        },
                    },
                ],
            }
        }

    def test_evaluate_query(self):
        """evaluate_query issues exactly one query, forwards its result to
        each metric's evaluate_query, and merges the metric dicts with the
        query_id into one evaluation dict."""
        self.app.query = Mock(return_value={})
        eval_metric = Mock()
        eval_metric.evaluate_query = Mock(return_value={"metric": 1})
        eval_metric2 = Mock()
        eval_metric2.evaluate_query = Mock(return_value={"metric_2": 2})
        query_model = Query()
        evaluation = self.app.evaluate_query(
            eval_metrics=[eval_metric, eval_metric2],
            query_model=query_model,
            query_id="0",
            query="this is a test",
            id_field="vespa_id_field",
            relevant_docs=self.labelled_data[0]["relevant_docs"],
            default_score=0,
            hits=10,
        )
        self.assertEqual(self.app.query.call_count, 1)
        self.app.query.assert_has_calls([
            call(query="this is a test", query_model=query_model, hits=10),
        ])
        self.assertEqual(eval_metric.evaluate_query.call_count, 1)
        # Metrics receive: query result, relevant docs, id field, default score.
        eval_metric.evaluate_query.assert_has_calls([
            call({}, self.labelled_data[0]["relevant_docs"], "vespa_id_field", 0),
        ])
        self.assertDictEqual(evaluation, {"query_id": "0", "metric": 1, "metric_2": 2})

    def test_evaluate(self):
        """evaluate aggregates the per-query dicts from evaluate_query into
        a DataFrame with one row per labelled query."""
        self.app.evaluate_query = Mock(side_effect=[
            {"query_id": "0", "metric": 1},
        ])
        evaluation = self.app.evaluate(
            labelled_data=self.labelled_data,
            eval_metrics=[Mock()],
            query_model=Mock(),
            id_field="mock",
            default_score=0,
        )
        assert_frame_equal(
            evaluation,
            DataFrame.from_records([{"query_id": "0", "metric": 1}]))
class TwitterInserter:
    """Enriches Vespa 'newsarticle' documents with tweet engagement data
    (retweet/favourite counts and a link to the tweet).

    Twitter credentials come from the 'twitter-secrets' Kubernetes secret."""

    def run(self):
        """Scan recent tweets of a fixed list of news accounts and update
        matching Vespa documents. Per-account failures are reported and
        skipped."""
        try:
            config.load_incluster_config()
        # FIX: was a bare `except:`, which also swallows KeyboardInterrupt
        # and SystemExit; only fall back on ordinary errors.
        except Exception:
            config.load_kube_config()
        v1 = client.CoreV1Api()
        twitter_secrets = v1.read_namespaced_secret(name='twitter-secrets',
                                                    namespace='default').data
        # Kubernetes secret values are base64-encoded.
        api_key = base64.b64decode(twitter_secrets["api-key"]).decode('utf-8')
        api_secret = base64.b64decode(
            twitter_secrets["api-secret"]).decode('utf-8')
        self.vespa = Vespa(url="http://vespa-search", port=8080)
        auth = tweepy.AppAuthHandler(api_key, api_secret)
        self.api = tweepy.API(auth)
        updated = 0
        for userid in [
                'abcnews', 'GuardianAus', 'smh', 'iTnews_au', 'theage',
                'canberratimes', 'zdnetaustralia', 'newscomauHQ',
                'westaustralian', 'SBSNews', 'australian', 'crikey_news',
                '9NewsAUS', 'BBCNewsAus'
        ]:
            try:
                for status in tweepy.Cursor(self.api.user_timeline,
                                            id=userid,
                                            include_entities=True,
                                            tweet_mode="extended").items(60):
                    if len(status.entities['urls']) == 0:
                        continue
                    url = status.entities['urls'][0]['expanded_url']
                    url = self.get_url(url)
                    if (url.startswith("https://twitter.com")
                            or url.startswith("https://www.reddit.com")):
                        continue
                    article = self.get_article(url)
                    if article:
                        self.update_document(article, status)
                        updated += 1
            except Exception as e:
                print("exception! {}".format(e))
                continue
        print("Completed run, updated {} tweets".format(updated))

    def get_url(self, url):
        """Resolve known shortener domains to their final URL (recursively,
        in case one shortener redirects to another), then strip the query
        string.

        :param url: expanded URL taken from a tweet entity.
        :return: canonical article URL without query parameters.
        """
        if (re.match(r'https?://zd.net', url)
                or url.startswith("https://trib.al")
                or url.startswith("https://bit.ly")
                or url.startswith("https://bbc.in")):
            # FIX: close the HTTP response (the original leaked it).
            with urlopen(url) as response:
                resolved = response.geturl()
            return self.get_url(resolved)
        else:
            return url.split('?')[0]

    def update_document(self, article, status):
        """Partially update the Vespa document for `article` with the
        engagement counts and permalink of `status`."""
        vespa_fields = {}
        vespa_fields['twitter_favourite_count'] = status.favorite_count
        vespa_fields['twitter_retweet_count'] = status.retweet_count
        vespa_fields[
            'twitter_link'] = 'https://twitter.com/{}/status/{}'.format(
                status.user.screen_name, status.id)
        response = self.vespa.update_data(
            schema="newsarticle",
            data_id=hashlib.sha256(
                article['fields']['url'].encode()).hexdigest(),
            fields=vespa_fields)
        print("Updated {} with {} {}: {}".format(article['fields']['url'],
                                                 status.favorite_count,
                                                 status.retweet_count,
                                                 response))

    def get_article(self, url):
        """Look up the article document for `url` in Vespa.

        :return: the first hit, or None when no document matches.
        """
        # FIX: removed unused `article_time` local from the original.
        body = {
            'yql': 'select url from sources newsarticle where userQuery();',
            'query': "url:{}".format(url),
            'hits': 1,
        }
        results = self.vespa.query(body=body)
        if len(results.hits) > 0:
            return results.hits[0]
class TwitterInserter:
    """Enriches Vespa 'newsarticle' documents with tweet engagement data
    (retweet/favourite counts) for a fixed list of news accounts.

    NOTE(review): api_key/api_secret look like placeholder values meant to
    be replaced — confirm how credentials are injected before deploying.
    """
    api_key = "TWITTER_API_KEY"
    api_secret = "TWITTER_API_SECRET"

    def run(self):
        """Scan recent tweets of the configured accounts and update the
        matching Vespa documents. Per-account failures are logged and
        skipped."""
        self.vespa = Vespa(url="http://vespa-search", port=8080)
        auth = tweepy.AppAuthHandler(self.api_key, self.api_secret)
        self.api = tweepy.API(auth)
        updated = 0
        for userid in [
                'abcnews', 'GuardianAus', 'smh', 'iTnews_au', 'theage',
                'canberratimes', 'zdnetaustralia', 'newscomauHQ',
                'westaustralian'
        ]:
            try:
                for status in tweepy.Cursor(self.api.user_timeline,
                                            id=userid,
                                            include_entities=True).items(60):
                    if len(status.entities['urls']) == 0:
                        continue
                    url = status.entities['urls'][0]['expanded_url']
                    # Drop the query string (tracking parameters etc.).
                    url = url.split('?')[0]
                    if url.startswith("https://twitter.com"):
                        continue
                    if url.startswith(("https://zd.net", "https://bit.ly")):
                        # Resolve shortener redirects; FIX: close the HTTP
                        # response (the original leaked it).
                        with urlopen(url) as response:
                            url = response.geturl()
                    article = self.get_article(url)
                    if article:
                        self.update_document(article, status)
                        updated += 1
            except Exception as e:
                logger.error(e)
        print("Completed run, updated {} tweets".format(updated))

    def update_document(self, article, status):
        """Partially update the Vespa document for `article` with the
        engagement counts of `status`."""
        vespa_fields = {}
        vespa_fields['twitter_favourite_count'] = status.favorite_count
        vespa_fields['twitter_retweet_count'] = status.retweet_count
        response = self.vespa.update_data(
            schema="newsarticle",
            data_id=hashlib.sha256(
                article['fields']['url'].encode()).hexdigest(),
            fields=vespa_fields)
        #print("Updated {} with {} {}: {}".format(article['fields']['url'], status.favorite_count, status.retweet_count, response))

    def get_article(self, url):
        """Look up the article document for `url` in Vespa.

        :return: the first hit, or None when no document matches.
        """
        # FIX: removed unused `article_time` local from the original.
        body = {
            'yql': 'select url from sources newsarticle where userQuery();',
            'query': "url:{}".format(url),
            'hits': 1,
        }
        results = self.vespa.query(body=body)
        if len(results.hits) > 0:
            return results.hits[0]

    def get_twitter_user(self, url):
        """Map an article URL prefix to the publisher's Twitter handle.

        :return: the handle, or None for an unrecognized site.
        """
        if url.startswith("https://www.abc.net.au"):
            return "abcnews"
        if url.startswith("https://www.theguardian.com/"):
            return "GuardianAus"
        if url.startswith("https://www.smh.com.au"):
            return "smh"
        if url.startswith("https://www.itnews.com.au"):
            return "iTnews_au"
        if url.startswith("https://www.theage.com.au"):
            return "theage"
        if url.startswith("https://www.canberratimes.com.au"):
            return "canberratimes"
        if url.startswith("https://www.zdnet.com"):
            return "zdnetaustralia"
        if url.startswith("https://www.news.com.au"):
            return "newscomauHQ"
        if url.startswith("https://thewest.com.au"):
            return "westaustralian"
def test_workflow(self):
    """End-to-end workflow against the live cord19 endpoint: build a query
    model, query, collect training data, and evaluate (list / detailed /
    per-query variants), checking result shapes at each step."""
    #
    # Connect to a running Vespa Application
    #
    app = Vespa(url="https://api.cord19.vespa.ai")
    #
    # Define a query model
    #
    match_phase = Union(
        WeakAnd(hits=10),
        ANN(
            doc_vector="title_embedding",
            query_vector="title_vector",
            hits=10,
            label="title",
        ),
    )
    rank_profile = Ranking(name="bm25", list_features=True)
    query_model = QueryModel(
        name="ANN_bm25",
        query_properties=[
            QueryRankingFeature(
                name="title_vector",
                # Random 768-dim query vector; only the shape matters here.
                mapping=lambda x: [random() for x in range(768)],
            )
        ],
        match_phase=match_phase,
        rank_profile=rank_profile,
    )
    #
    # Query Vespa app
    #
    query_result = app.query(
        query="Is remdesivir an effective treatment for COVID-19?",
        query_model=query_model,
    )
    self.assertTrue(query_result.number_documents_retrieved > 0)
    self.assertEqual(len(query_result.hits), 10)
    #
    # Define labelled data
    #
    labeled_data = [
        {
            "query_id": 0,
            "query": "Intrauterine virus infections and congenital heart disease",
            "relevant_docs": [{"id": 0, "score": 1}, {"id": 3, "score": 1}],
        },
        {
            "query_id": 1,
            "query": "Clinical and immunologic studies in identical twins discordant for systemic lupus erythematosus",
            "relevant_docs": [{"id": 1, "score": 1}, {"id": 5, "score": 1}],
        },
    ]
    # equivalent data in df format
    labeled_data_df = DataFrame(
        data={
            "qid": [0, 0, 1, 1],
            "query": ["Intrauterine virus infections and congenital heart disease"] * 2 +
            [
                "Clinical and immunologic studies in identical twins discordant for systemic lupus erythematosus"
            ] * 2,
            "doc_id": [0, 3, 1, 5],
            "relevance": [1, 1, 1, 1],
        })
    #
    # Collect training data
    #
    training_data_batch = app.collect_training_data(
        labeled_data=labeled_data,
        id_field="id",
        query_model=query_model,
        number_additional_docs=2,
        fields=["rankfeatures"],
    )
    self.assertTrue(training_data_batch.shape[0] > 0)
    # The three base columns must all be present.
    self.assertEqual(
        len({"document_id", "query_id", "label"}.intersection(set(training_data_batch.columns))),
        3,
    )
    #
    # Evaluate a query model
    #
    eval_metrics = [MatchRatio(), Recall(at=10), ReciprocalRank(at=10)]
    # 9 rows = 3 metrics x 3 aggregations; 1 column per query model.
    evaluation = app.evaluate(
        labeled_data=labeled_data,
        eval_metrics=eval_metrics,
        query_model=query_model,
        id_field="id",
    )
    self.assertEqual(evaluation.shape, (9, 1))
    #
    # AssertionError - two models with the same name
    #
    with self.assertRaises(AssertionError):
        _ = app.evaluate(
            labeled_data=labeled_data,
            eval_metrics=eval_metrics,
            query_model=[QueryModel(), QueryModel(), query_model],
            id_field="id",
        )
    evaluation = app.evaluate(
        labeled_data=labeled_data,
        eval_metrics=eval_metrics,
        query_model=[QueryModel(), query_model],
        id_field="id",
    )
    self.assertEqual(evaluation.shape, (9, 2))
    # detailed_metrics adds per-metric breakdown rows (15 x 1).
    evaluation = app.evaluate(
        labeled_data=labeled_data_df,
        eval_metrics=eval_metrics,
        query_model=query_model,
        id_field="id",
        detailed_metrics=True,
    )
    self.assertEqual(evaluation.shape, (15, 1))
    # per_query returns one row per query instead of aggregates (2 x 7).
    evaluation = app.evaluate(
        labeled_data=labeled_data_df,
        eval_metrics=eval_metrics,
        query_model=query_model,
        id_field="id",
        detailed_metrics=True,
        per_query=True,
    )
    self.assertEqual(evaluation.shape, (2, 7))