예제 #1
0
def classify_text(
    text_list,
    labels=None,
    model_name_or_path="joeddav/bart-large-mnli-yahoo-answers",
):
    """Run zero-shot classification over ``text_list``.

    Args:
        text_list: Items passed unchanged as ``source_response_list`` to
            ``ZeroShotClassificationAnalyzer.analyze_input``.
        labels: Candidate labels for the zero-shot classifier. When omitted,
            the original fixed set is used: "no parking", "registration issue",
            "app issue", "payment issue".
        model_name_or_path: Model used to build the analyzer. Defaults to the
            model the original hard-coded.

    Returns:
        Whatever ``analyze_input`` returns for ``text_list``.
    """
    if labels is None:
        labels = ["no parking", "registration issue", "app issue", "payment issue"]

    # NOTE(review): a fresh analyzer (and presumably a fresh model load) is
    # created on every call; callers classifying repeatedly may want to build
    # the analyzer once and reuse it.
    text_analyzer = ZeroShotClassificationAnalyzer(
        model_name_or_path=model_name_or_path,
    )

    return text_analyzer.analyze_input(
        source_response_list=text_list,
        # Copy so a caller-supplied list/tuple is never shared with the config.
        analyzer_config=ClassificationAnalyzerConfig(labels=list(labels)),
    )
예제 #2
0
def test_text_classification_analyzer(text_classification_analyzer, label_map,
                                      expected):
    """Classify BUY_SELL_TEXTS using ``label_map`` and check the result shape.

    There must be one response per input text. Each response must carry
    non-None ``classifier_data`` whose keys are all contained in ``expected``.
    """
    payloads = []
    for sample_text in BUY_SELL_TEXTS:
        payloads.append(
            TextPayload(processed_text=sample_text, source_name="sample"))

    responses = text_classification_analyzer.analyze_input(
        source_response_list=payloads,
        analyzer_config=ClassificationAnalyzerConfig(label_map=label_map),
    )

    assert len(responses) == len(BUY_SELL_TEXTS)

    for response in responses:
        classifier_data = response.segmented_data["classifier_data"]
        assert classifier_data is not None
        # Only labels listed in ``expected`` may appear in the output.
        assert classifier_data.keys() <= set(expected)
예제 #3
0
def test_zero_shot_analyzer(zero_shot_analyzer):
    """Zero-shot classify TEXTS against five labels and check each response.

    There must be one response per text. Each response's ``segmented_data``
    must have as many entries as there are labels and must contain the
    "positive" and "negative" keys.
    """
    candidate_labels = ["facility", "food", "comfortable", "positive", "negative"]

    requests = []
    for sample_text in TEXTS:
        requests.append(
            AnalyzerRequest(processed_text=sample_text, source_name="sample"))

    responses = zero_shot_analyzer.analyze_input(
        source_response_list=requests,
        analyzer_config=ClassificationAnalyzerConfig(labels=candidate_labels),
    )

    assert len(responses) == len(TEXTS)

    for response in responses:
        scores = response.segmented_data
        assert len(scores) == len(candidate_labels)
        assert "positive" in scores
        assert "negative" in scores
예제 #4
0
def test_classification_analyzer_with_splitter_aggregator(
        aggregate_function, zero_shot_analyzer):
    """Zero-shot classify TEXTS with splitting and aggregation enabled.

    Inputs are split with ``max_split_length=50``. Per-split results are
    combined with the parametrized ``aggregate_function``. Every response
    must expose an "aggregator_data" entry in ``segmented_data``.
    """
    candidate_labels = ["facility", "food", "comfortable", "positive", "negative"]

    payloads = []
    for sample_text in TEXTS:
        payloads.append(
            TextPayload(processed_text=sample_text, source_name="sample"))

    splitter = TextSplitterConfig(max_split_length=50)
    aggregator = InferenceAggregatorConfig(
        aggregate_function=aggregate_function)
    config = ClassificationAnalyzerConfig(
        labels=candidate_labels,
        use_splitter_and_aggregator=True,
        splitter_config=splitter,
        aggregator_config=aggregator,
    )

    responses = zero_shot_analyzer.analyze_input(
        source_response_list=payloads,
        analyzer_config=config,
    )

    assert len(responses) == len(TEXTS)

    for response in responses:
        assert "aggregator_data" in response.segmented_data
예제 #5
0
    max_results=150,
    after_date='2021-10-01',
    before_date='2021-10-31',
)

# Fetch full news article: a second Google News query for "bitcoin" with
# fetch_article=True. It is limited to 5 results over a "1d" lookup period;
# presumably kept small because fetching full articles is slower than
# fetching headlines only -- TODO confirm.
source_config_with_full_text = GoogleNewsConfig(
    query="bitcoin",
    max_results=5,
    fetch_article=True,
    lookup_period="1d",
)

# A single source instance is queried with both configs below.
source = GoogleNewsSource()

# Zero-shot candidate labels shared by both analyses.
analyzer_config = ClassificationAnalyzerConfig(
    labels=["buy", "sell", "going up", "going down"], )

# device="auto" leaves device selection to the analyzer.
text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="typeform/mobilebert-uncased-mnli", device="auto")

# source_config_without_full_text is presumably the config whose call ends
# just above this section (max_results=150, October 2021 date range).
news_articles_without_full_text = source.lookup(
    source_config_without_full_text)

news_articles_with_full_text = source.lookup(source_config_with_full_text)

# Classify the headline-only results; the full-text results are analyzed next.
analyzer_responses_without_full_text = text_analyzer.analyze_input(
    source_response_list=news_articles_without_full_text,
    analyzer_config=analyzer_config,
)

analyzer_responses_with_full_text = text_analyzer.analyze_input(
예제 #6
0
    expansions=["author_id"],
    place_fields=None,
    max_tweets=10,
)

# Pipeline: fetch tweets -> zero-shot classify -> send results to the HTTP sink.
# NOTE(review): source_config (its tail is visible just above) and sink_config
# are assumed to be defined earlier in the original script -- TODO confirm.
source = TwitterSource()
sink = DailyGetSink()
text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="joeddav/bart-large-mnli-yahoo-answers",
    # Alternative model (commented out in the original): "joeddav/xlm-roberta-large-xnli"
)

# Fetch tweets and log each response's attributes for inspection.
source_response_list = source.lookup(source_config)
for idx, source_response in enumerate(source_response_list):
    logger.info(f"source_response#'{idx}'='{source_response.__dict__}'")

# Score every tweet against the delivery-related candidate labels.
analyzer_response_list = text_analyzer.analyze_input(
    source_response_list=source_response_list,
    analyzer_config=ClassificationAnalyzerConfig(
            labels=["service", "delay", "tracking", "no response", "missing items", "delivery", "mask"],
        )
)
for idx, an_response in enumerate(analyzer_response_list):
    logger.info(f"analyzer_response#'{idx}'='{an_response.__dict__}'")

# HTTP Sink
sink_response_list = sink.send_data(analyzer_response_list, sink_config)
for sink_response in sink_response_list:
    # The returned list may contain None entries; only log actual responses.
    if sink_response is not None:
        logger.info(f"sink_response='{sink_response.__dict__}'")
예제 #7
0
import os

# Jira connection settings. The defaults are the values this example always
# used (a local Jira at http://localhost:2990/jira with admin/admin). Set
# JIRA_URL / JIRA_USERNAME / JIRA_PASSWORD to point at another instance
# instead of editing credentials into the source.
jira_sink_config = JiraSinkConfig(
    url=os.environ.get("JIRA_URL", "http://localhost:2990/jira"),
    username=SecretStr(os.environ.get("JIRA_USERNAME", "admin")),
    password=SecretStr(os.environ.get("JIRA_PASSWORD", "admin")),
    issue_type={"name": "Task"},
    project={"key": "CUS"},
)
jira_sink = JiraSink()

text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="joeddav/bart-large-mnli-yahoo-answers"
)

# NOTE(review): source and source_config are not defined in this snippet;
# they are presumably set up earlier in the original script -- TODO confirm.
source_response_list = source.lookup(source_config)
for idx, source_response in enumerate(source_response_list):
    logger.info(f"source_response#'{idx}'='{source_response.__dict__}'")

# Score each fetched item against the candidate labels.
analyzer_response_list = text_analyzer.analyze_input(
    source_response_list=source_response_list,
    analyzer_config=ClassificationAnalyzerConfig(
        labels=["service", "delay", "performance"],
    ),
)
for idx, an_response in enumerate(analyzer_response_list):
    logger.info(f"analyzer_response#'{idx}'='{an_response.__dict__}'")

# Push the analyzed items to Jira as issues of the configured type/project.
sink_response_list = jira_sink.send_data(analyzer_response_list, jira_sink_config)
for sink_response in sink_response_list:
    # The returned list may contain None entries; only log actual responses.
    if sink_response is not None:
        logger.info(f"sink_response='{sink_response}'")
예제 #8
0
    PlayStoreScrapperSource,
)

logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Start of the lookup window: four days before "now", as an aware UTC datetime.
# The previous expression, datetime.utcnow().astimezone(pytz.utc), was wrong on
# any host not running in UTC. utcnow() returns a *naive* datetime, and
# astimezone() interprets a naive value as local time. The UTC wall-clock
# reading was therefore shifted by the host's UTC offset a second time.
# datetime.now(pytz.utc) returns the current instant already tagged as UTC.
since_time = datetime.now(pytz.utc) - timedelta(days=4)
source_config = PlayStoreScrapperConfig(
    countries=["us"],
    package_name="com.apcoaconnect",
    # The window start is passed as a string formatted with DATETIME_STRING_PATTERN.
    lookup_period=since_time.strftime(DATETIME_STRING_PATTERN),
)

source = PlayStoreScrapperSource()

# device="auto" leaves device selection to the analyzer.
text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="joeddav/bart-large-mnli-yahoo-answers", device="auto")

# Fetch items from the Play Store source and log each one's attributes.
source_response_list = source.lookup(source_config)
for idx, source_response in enumerate(source_response_list):
    logger.info(f"source_response#'{idx}'='{source_response.__dict__}'")

# Score each item against the issue-category labels.
analyzer_response_list = text_analyzer.analyze_input(
    source_response_list=source_response_list,
    analyzer_config=ClassificationAnalyzerConfig(labels=[
        "no parking", "registration issue", "app issue", "payment issue"
    ]),
)
for idx, an_response in enumerate(analyzer_response_list):
    logger.info(f"analyzer_response#'{idx}'='{an_response.__dict__}'")
from obsei.source.playstore_scrapper import (
    PlayStoreScrapperConfig,
    PlayStoreScrapperSource,
)

logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Scrape the Play Store listing identified by app_url (package
# com.google.android.gm). max_count=3 presumably caps how many items
# (reviews?) are fetched -- TODO confirm.
source_config = PlayStoreScrapperConfig(
    app_url=
    'https://play.google.com/store/apps/details?id=com.google.android.gm&hl=en_IN&gl=US',
    max_count=3)

source = PlayStoreScrapperSource()

# device="auto" leaves device selection to the analyzer.
text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="typeform/mobilebert-uncased-mnli", device="auto")

# Fetch items and log each response's attributes for inspection.
source_response_list = source.lookup(source_config)
for idx, source_response in enumerate(source_response_list):
    logger.info(f"source_response#'{idx}'='{source_response.__dict__}'")

# Score each item against the three candidate labels.
analyzer_response_list = text_analyzer.analyze_input(
    source_response_list=source_response_list,
    analyzer_config=ClassificationAnalyzerConfig(
        labels=["interface", "slow", "battery"], ),
)
for idx, an_response in enumerate(analyzer_response_list):
    logger.info(f"analyzer_response#'{idx}'='{an_response.__dict__}'")
예제 #10
0
    ZeroShotClassificationAnalyzer,
)
from obsei.source import YoutubeScrapperSource, YoutubeScrapperConfig

logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Scrape up to 10 comments (including replies, per fetch_replies=True) from a
# single YouTube video. lookup_period="1Y" presumably limits comments to the
# last year -- TODO confirm.
source_config = YoutubeScrapperConfig(
    video_url="https://www.youtube.com/watch?v=uZfns0JIlFk",
    fetch_replies=True,
    max_comments=10,
    lookup_period="1Y",
)

source = YoutubeScrapperSource()

# Fetch comments and log each response's attributes for inspection.
source_response_list = source.lookup(source_config)
for idx, source_response in enumerate(source_response_list):
    logger.info(f"source_response#'{idx}'='{source_response.__dict__}'")

# The analyzer is created only after the lookup; device="auto" leaves device
# selection to the analyzer.
text_analyzer = ZeroShotClassificationAnalyzer(
    model_name_or_path="typeform/mobilebert-uncased-mnli", device="auto")

# Score each comment against the two candidate labels.
analyzer_response_list = text_analyzer.analyze_input(
    source_response_list=source_response_list,
    analyzer_config=ClassificationAnalyzerConfig(
        labels=["interesting", "enquiring"], ),
)
for idx, an_response in enumerate(analyzer_response_list):
    logger.info(f"analyzer_response#'{idx}'='{an_response.__dict__}'")