def classify_text(text_list):
    """Zero-shot classify *text_list* against a fixed set of parking-app labels.

    Builds a BART-large MNLI (Yahoo-Answers fine-tune) analyzer and returns
    the analyzer responses produced by ``analyze_input``.
    """
    # Candidate labels for the zero-shot classifier.
    config = ClassificationAnalyzerConfig(
        labels=["no parking", "registration issue", "app issue", "payment issue"],
    )
    analyzer = ZeroShotClassificationAnalyzer(
        model_name_or_path="joeddav/bart-large-mnli-yahoo-answers",
    )
    return analyzer.analyze_input(
        source_response_list=text_list,
        analyzer_config=config,
    )
def test_text_classification_analyzer(text_classification_analyzer, label_map, expected):
    """One response per input text, and classifier keys stay within *expected*."""
    payloads = []
    for text in BUY_SELL_TEXTS:
        payloads.append(TextPayload(processed_text=text, source_name="sample"))

    responses = text_classification_analyzer.analyze_input(
        source_response_list=payloads,
        analyzer_config=ClassificationAnalyzerConfig(label_map=label_map),
    )

    assert len(responses) == len(BUY_SELL_TEXTS)
    expected_keys = set(expected)
    for response in responses:
        classifier_data = response.segmented_data["classifier_data"]
        assert classifier_data is not None
        # Every emitted label must belong to the expected label set.
        assert classifier_data.keys() <= expected_keys
def test_zero_shot_analyzer(zero_shot_analyzer):
    """Zero-shot analysis returns a score entry for every candidate label."""
    candidate_labels = ["facility", "food", "comfortable", "positive", "negative"]
    requests = [
        AnalyzerRequest(processed_text=text, source_name="sample")
        for text in TEXTS
    ]

    responses = zero_shot_analyzer.analyze_input(
        source_response_list=requests,
        analyzer_config=ClassificationAnalyzerConfig(labels=candidate_labels),
    )

    assert len(responses) == len(TEXTS)
    for response in responses:
        scores = response.segmented_data
        assert len(scores) == len(candidate_labels)
        # Sentiment labels must always be present in the output.
        for sentiment in ("positive", "negative"):
            assert sentiment in scores
def test_classification_analyzer_with_splitter_aggregator(
    aggregate_function, zero_shot_analyzer
):
    """Splitter + aggregator pipeline attaches aggregator_data to each response."""
    labels = ["facility", "food", "comfortable", "positive", "negative"]
    payloads = [TextPayload(processed_text=t, source_name="sample") for t in TEXTS]

    # Split long texts into <=50-char chunks, classify each chunk, then
    # aggregate chunk scores back into one result per original text.
    config = ClassificationAnalyzerConfig(
        labels=labels,
        use_splitter_and_aggregator=True,
        splitter_config=TextSplitterConfig(max_split_length=50),
        aggregator_config=InferenceAggregatorConfig(
            aggregate_function=aggregate_function
        ),
    )

    responses = zero_shot_analyzer.analyze_input(
        source_response_list=payloads,
        analyzer_config=config,
    )

    assert len(responses) == len(TEXTS)
    for response in responses:
        assert "aggregator_data" in response.segmented_data
# NOTE(review): truncated chunk — these keyword arguments close a
# GoogleNewsConfig(...) call whose opening lies above this view
# (presumably `source_config_without_full_text`); confirm against caller.
    max_results=150,
    after_date='2021-10-01',
    before_date='2021-10-31',
)
# Fetch full news article
source_config_with_full_text = GoogleNewsConfig(
    query="bitcoin",
    max_results=5,
    fetch_article=True,  # download and attach the complete article body
    lookup_period="1d",
)
source = GoogleNewsSource()
# Zero-shot labels describing market sentiment for the fetched articles.
analyzer_config = ClassificationAnalyzerConfig(
    labels=["buy", "sell", "going up", "going down"],
)
text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="typeform/mobilebert-uncased-mnli",
    device="auto",
)
# Run both lookups: headline-only vs. full-article text.
news_articles_without_full_text = source.lookup(source_config_without_full_text)
news_articles_with_full_text = source.lookup(source_config_with_full_text)
# Classify the headline-only batch.
analyzer_responses_without_full_text = text_analyzer.analyze_input(
    source_response_list=news_articles_without_full_text,
    analyzer_config=analyzer_config,
)
# Classify the full-text batch (this call continues past this view).
analyzer_responses_with_full_text = text_analyzer.analyze_input(
# NOTE(review): truncated chunk — these keyword arguments close a Twitter
# source-config constructor whose opening lies above this view.
expansions=["author_id"],
place_fields=None,
max_tweets=10,
)
source = TwitterSource()
sink = DailyGetSink()  # HTTP sink that forwards analyzer output downstream
text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="joeddav/bart-large-mnli-yahoo-answers",
    # model_name_or_path="joeddav/xlm-roberta-large-xnli",
)
source_response_list = source.lookup(source_config)
# Log every fetched tweet for debugging.
for idx, source_response in enumerate(source_response_list):
    logger.info(f"source_response#'{idx}'='{source_response.__dict__}'")
# Zero-shot classification with customer-support oriented labels.
analyzer_response_list = text_analyzer.analyze_input(
    source_response_list=source_response_list,
    analyzer_config=ClassificationAnalyzerConfig(
        labels=["service", "delay", "tracking", "no response", "missing items", "delivery", "mask"],
    )
)
for idx, an_response in enumerate(analyzer_response_list):
    logger.info(f"analyzer_response#'{idx}'='{an_response.__dict__}'")
# HTTP Sink
# NOTE(review): `sink_config` is defined above this view — verify it matches
# DailyGetSink's expected config type.
sink_response_list = sink.send_data(analyzer_response_list, sink_config)
for sink_response in sink_response_list:
    if sink_response is not None:
        logger.info(f"sink_response='{sink_response.__dict__}'")
# Jira sink configuration: connects to a local Jira instance and files each
# analyzer response as a Task in project "CUS".
jira_sink_config = JiraSinkConfig(
    url="http://localhost:2990/jira",
    username=SecretStr("admin"),
    password=SecretStr("admin"),
    issue_type={"name": "Task"},
    project={"key": "CUS"},
)
jira_sink = JiraSink()
text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="joeddav/bart-large-mnli-yahoo-answers"
)
# NOTE(review): `source` and `source_config` are defined above this view.
source_response_list = source.lookup(source_config)
# Log every fetched item for debugging.
for idx, source_response in enumerate(source_response_list):
    logger.info(f"source_response#'{idx}'='{source_response.__dict__}'")
# Zero-shot classification with support-oriented labels.
analyzer_response_list = text_analyzer.analyze_input(
    source_response_list=source_response_list,
    analyzer_config=ClassificationAnalyzerConfig(
        labels=["service", "delay", "performance"],
    ),
)
for idx, an_response in enumerate(analyzer_response_list):
    logger.info(f"analyzer_response#'{idx}'='{an_response.__dict__}'")
# Push analyzer output to Jira and log any non-empty sink responses.
sink_response_list = jira_sink.send_data(analyzer_response_list, jira_sink_config)
for sink_response in sink_response_list:
    if sink_response is not None:
        logger.info(f"sink_response='{sink_response}'")
PlayStoreScrapperSource, ) logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, level=logging.INFO) since_time = datetime.utcnow().astimezone(pytz.utc) + timedelta(days=-4) source_config = PlayStoreScrapperConfig( countries=["us"], package_name="com.apcoaconnect", lookup_period=since_time.strftime(DATETIME_STRING_PATTERN), ) source = PlayStoreScrapperSource() text_analyzer = ZeroShotClassificationAnalyzer( model_name_or_path="joeddav/bart-large-mnli-yahoo-answers", device="auto") source_response_list = source.lookup(source_config) for idx, source_response in enumerate(source_response_list): logger.info(f"source_response#'{idx}'='{source_response.__dict__}'") analyzer_response_list = text_analyzer.analyze_input( source_response_list=source_response_list, analyzer_config=ClassificationAnalyzerConfig(labels=[ "no parking", "registration issue", "app issue", "payment issue" ], ), ) for idx, an_response in enumerate(analyzer_response_list): logger.info(f"analyzer_response#'{idx}'='{an_response.__dict__}'")
from obsei.source.playstore_scrapper import (
    PlayStoreScrapperConfig,
    PlayStoreScrapperSource,
)

logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Scrape up to 3 Play Store reviews of the Gmail app, identified by URL.
source_config = PlayStoreScrapperConfig(
    app_url=
    'https://play.google.com/store/apps/details?id=com.google.android.gm&hl=en_IN&gl=US',
    max_count=3)
source = PlayStoreScrapperSource()

text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="typeform/mobilebert-uncased-mnli",
    device="auto",
)

source_response_list = source.lookup(source_config)
# Log every fetched review for debugging.
for idx, source_response in enumerate(source_response_list):
    logger.info(f"source_response#'{idx}'='{source_response.__dict__}'")

# Zero-shot classification with app-review oriented labels.
analyzer_response_list = text_analyzer.analyze_input(
    source_response_list=source_response_list,
    analyzer_config=ClassificationAnalyzerConfig(
        labels=["interface", "slow", "battery"],
    ),
)
for idx, an_response in enumerate(analyzer_response_list):
    logger.info(f"analyzer_response#'{idx}'='{an_response.__dict__}'")
# NOTE(review): truncated chunk — this closes an import whose opening lies
# above this view.
ZeroShotClassificationAnalyzer,
)
from obsei.source import YoutubeScrapperSource, YoutubeScrapperConfig

logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Fetch up to 10 comments (including replies) from one YouTube video,
# looking back one year.
source_config = YoutubeScrapperConfig(
    video_url="https://www.youtube.com/watch?v=uZfns0JIlFk",
    fetch_replies=True,
    max_comments=10,
    lookup_period="1Y",
)
source = YoutubeScrapperSource()

source_response_list = source.lookup(source_config)
# Log every fetched comment for debugging.
for idx, source_response in enumerate(source_response_list):
    logger.info(f"source_response#'{idx}'='{source_response.__dict__}'")

text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="typeform/mobilebert-uncased-mnli",
    device="auto",
)

# Zero-shot classification of comments into two broad categories.
analyzer_response_list = text_analyzer.analyze_input(
    source_response_list=source_response_list,
    analyzer_config=ClassificationAnalyzerConfig(
        labels=["interesting", "enquiring"],
    ),
)
for idx, an_response in enumerate(analyzer_response_list):
    logger.info(f"analyzer_response#'{idx}'='{an_response.__dict__}'")