Пример #1
0
def predict():
    """Predict which of two stored users more likely wrote a submitted tweet.

    Reads ``screen_name_a``, ``screen_name_b`` and ``tweet_text`` from the
    submitted form, fits a logistic regression on both users' stored tweet
    embeddings (labelled by screen name), embeds the submitted text through
    the Basilica client and renders ``results.html`` with the predicted
    screen name.
    """
    print("PREDICT ROUTE...")
    print("FORM DATA:", dict(request.form))
    screen_name_a = request.form["screen_name_a"]
    screen_name_b = request.form["screen_name_b"]
    tweet_text = request.form["tweet_text"]

    print("-----------------")
    print("FETCHING TWEETS FROM THE DATABASE...")
    # TODO (carried over): these lookups are unguarded; wrap them in case
    # a requested user does not exist in the database.
    user_a = User.query.filter(User.screen_name == screen_name_a).one()
    user_b = User.query.filter(User.screen_name == screen_name_b).one()
    user_a_tweets = user_a.tweets
    user_b_tweets = user_b.tweets
    print("USER A", user_a.screen_name, len(user_a_tweets))
    print("USER B", user_b.screen_name, len(user_b_tweets))

    print("-----------------")
    print("TRAINING THE MODEL...")
    embeddings = []  # X: one embedding per stored tweet
    labels = []  # y: screen name of the tweet's author
    for author, authored in ((user_a, user_a_tweets), (user_b, user_b_tweets)):
        for tweet in authored:
            labels.append(author.screen_name)
            embeddings.append(tweet.embedding)

    # TODO (carried over): make sure there are an even number of tweets
    # for each user.
    classifier = LogisticRegression(random_state=0, solver='lbfgs')
    classifier.fit(embeddings, labels)

    print("-----------------")
    print("MAKING A PREDICTION...")
    basilica_conn = basilica_api_client()
    example_embedding = basilica_conn.embed_sentence(tweet_text,
                                                     model="twitter")
    most_likely = classifier.predict([example_embedding])[0]

    return render_template("results.html",
                           screen_name_a=screen_name_a,
                           screen_name_b=screen_name_b,
                           tweet_text=tweet_text,
                           screen_name_most_likely=most_likely)
def store_twitter_user_data(screen_name):
    """Fetch a Twitter user plus recent tweets and upsert both into the db.

    The user's profile fields are copied onto an existing or new ``User``
    row, all tweet texts are embedded in a single Basilica request, and each
    tweet is stored (existing or new ``Tweet`` row) with its embedding.

    Returns:
        A ``(db_user, statuses)`` tuple: the stored user row and the
        statuses returned by the Twitter client.
    """
    api = twitter_api_client()
    twitter_user = api.get_user(screen_name)
    statuses = api.user_timeline(screen_name, tweet_mode="extended", count=150)

    # Find-or-create the user row, then refresh its profile columns.
    db_user = User.query.get(twitter_user.id)
    if not db_user:
        db_user = User(id=twitter_user.id)
    for column in ("screen_name", "name", "location", "followers_count"):
        setattr(db_user, column, getattr(twitter_user, column))
    db.session.add(db_user)
    db.session.commit()

    print("STATUS COUNT:", len(statuses))
    basilica_api = basilica_api_client()
    all_tweet_texts = [status.full_text for status in statuses]
    embeddings = list(
        basilica_api.embed_sentences(all_tweet_texts, model="twitter"))
    print("NUMBER OF EMBEDDINGS", len(embeddings))

    for position, status in enumerate(statuses):
        print(status.full_text)
        print("----")
        # Find-or-create the tweet row.
        db_tweet = Tweet.query.get(status.id)
        if not db_tweet:
            db_tweet = Tweet(id=status.id)
        db_tweet.user_id = status.author.id
        db_tweet.full_text = status.full_text
        # Embeddings are looked up by the tweet's position in `statuses`.
        embedding = embeddings[position]
        print(len(embedding))
        db_tweet.embedding = embedding
        db.session.add(db_tweet)
    db.session.commit()

    return db_user, statuses
Пример #3
0
def results():
    """Render which of two stored users more likely wrote the submitted tweet.

    Form fields ``screen_name_a``, ``screen_name_b`` and ``tweet_text`` are
    read from the request; a logistic regression is fitted on both users'
    stored tweet embeddings and applied to the embedding of ``tweet_text``.
    """
    screen_name_a = request.form["screen_name_a"]
    screen_name_b = request.form["screen_name_b"]
    tweet_text = request.form["tweet_text"]

    user_a = User.query.filter(User.screen_name == screen_name_a).one()
    user_b = User.query.filter(User.screen_name == screen_name_b).one()

    # Training data: every stored tweet embedding, labelled with its
    # author's screen name (user A's tweets first, then user B's).
    embeddings = []
    labels = []
    for author, authored in ((user_a, user_a.tweets), (user_b, user_b.tweets)):
        for tweet in authored:
            labels.append(author.screen_name)
            embeddings.append(tweet.embedding)

    classifier = LogisticRegression(random_state=0, solver='lbfgs')
    classifier.fit(embeddings, labels)

    basilica_conn = basilica_api_client()
    example_embedding = basilica_conn.embed_sentence(tweet_text,
                                                     model="twitter")
    prediction = classifier.predict([example_embedding])

    template_context = {
        "screen_name_a": screen_name_a,
        "screen_name_b": screen_name_b,
        "tweet_text": tweet_text,
        "screen_name_most_likely": prediction[0],
    }
    return render_template("results.html", **template_context)
Пример #4
0
def predict():
    """Predict which of two stored users more likely wrote a submitted tweet.

    Reads ``screen_name_a``, ``screen_name_b`` and ``tweet_text`` from the
    submitted form, fits a logistic regression on both users' stored tweet
    embeddings (labelled by screen name), embeds the submitted text through
    the Basilica client and renders ``results.html`` with the predicted
    screen name.
    """
    print("PREDICT ROUTE...")
    print("FORM DATA:", dict(request.form))
    #> {'screen_name_a': 'elonmusk', 'screen_name_b': 's2t2', 'tweet_text': 'Example tweet text here'}
    screen_name_a = request.form["screen_name_a"]
    screen_name_b = request.form["screen_name_b"]
    tweet_text = request.form["tweet_text"]

    print("-----------------")
    print("FETCHING TWEETS FROM THE DATABASE...")

    # TODO: wrap in a try block in case the users don't exist in the database

    # `filter()` takes SQL expressions; `filter_by()` only accepts keyword
    # arguments, so passing it an expression positionally raises TypeError.
    user_a = User.query.filter(User.screen_name == screen_name_a).one()
    user_b = User.query.filter(User.screen_name == screen_name_b).one()
    user_a_tweets = user_a.tweets
    user_b_tweets = user_b.tweets
    print("USER A", user_a.screen_name, len(user_a_tweets))
    print("USER B", user_b.screen_name, len(user_b_tweets))

    print("-----------------")
    print("TRAINING THE MODEL...")
    # X: one embedding per stored tweet; y: the author's screen name.
    embeddings = []
    labels = []
    for author, authored in ((user_a, user_a_tweets), (user_b, user_b_tweets)):
        for tweet in authored:
            labels.append(author.screen_name)
            embeddings.append(tweet.embedding)

    classifier = LogisticRegression()  # for example
    classifier.fit(embeddings, labels)

    print("-----------------")
    print("MAKING A PREDICTION...")

    basilica_api = basilica_api_client()
    example_embedding = basilica_api.embed_sentence(tweet_text,
                                                    model='twitter')
    result = classifier.predict([example_embedding])

    return render_template("results.html",
                           screen_name_a=screen_name_a,
                           screen_name_b=screen_name_b,
                           tweet_text=tweet_text,
                           screen_name_most_likely=result[0])
Пример #5
0
def get_user(screen_name=None):
    """Fetch a Twitter user and tweets, store them, and render ``user.html``.

    Replies and retweets are excluded from the timeline request. The user's
    profile fields are written to an existing or new ``User`` row; every
    tweet is stored with the embedding returned by a single batched Basilica
    request.
    """
    print(screen_name)

    api = twitter_api()

    twitter_user = api.get_user(screen_name)
    timeline_options = {
        "tweet_mode": "extended",
        "count": 150,
        "exclude_replies": True,
        "include_rts": False,
    }
    statuses = api.user_timeline(screen_name, **timeline_options)
    print("STATUSES COUNT:", len(statuses))

    # Reuse the stored user when there is one, otherwise start a new row.
    db_user = User.query.get(twitter_user.id)
    if not db_user:
        db_user = User(id=twitter_user.id)
    db_user.screen_name = twitter_user.screen_name
    db_user.name = twitter_user.name
    db_user.location = twitter_user.location
    db_user.followers_count = twitter_user.followers_count
    db.session.add(db_user)
    db.session.commit()

    basilica_api = basilica_api_client()

    # One batched embedding request for all tweet texts.
    all_tweet_texts = [status.full_text for status in statuses]
    embeddings = list(
        basilica_api.embed_sentences(all_tweet_texts, model="twitter"))
    print("NUMBER OF EMBEDDINGS", len(embeddings))

    for position, status in enumerate(statuses):
        print(status.full_text)
        print("----")
        # Reuse the stored tweet when there is one, otherwise start a new row.
        db_tweet = Tweet.query.get(status.id)
        if not db_tweet:
            db_tweet = Tweet(id=status.id)
        db_tweet.user_id = status.author.id  # or db_user.id
        db_tweet.full_text = status.full_text
        embedding = embeddings[position]
        print(len(embedding))
        db_tweet.embedding = embedding
        db.session.add(db_tweet)
    db.session.commit()

    return render_template("user.html", user=db_user,
                           tweets=statuses)  # tweets=db_tweets
Пример #6
0
def create_user():
    """Store the Twitter user named in the form and render ``user.html``.

    Reads ``screen_name`` from the submitted form, fetches the profile and
    timeline through ``twitter_api_client``, writes the profile to an
    existing or new ``User`` row, and stores every tweet together with the
    embedding returned by one batched Basilica request.
    """
    print("FORM DATA", dict(request.form))
    screen_name = request.form["screen_name"]
    print(screen_name)

    # NOTE(review): `twitter_api_client` is used as a ready client object
    # here (not called) - confirm against its definition.
    twitter_user = twitter_api_client.get_user(screen_name)
    statuses = twitter_api_client.user_timeline(
        screen_name, tweet_mode="extended", count=150)
    print("STATUSES COUNT:", len(statuses))

    # Get the existing user from the db or initialize a new one.
    db_user = User.query.get(twitter_user.id) or User(id=twitter_user.id)
    db_user.screen_name = twitter_user.screen_name
    db_user.name = twitter_user.name
    db_user.location = twitter_user.location
    db_user.followers_count = twitter_user.followers_count
    db.session.add(db_user)
    db.session.commit()

    # One batched embedding request for all tweet texts.
    all_tweet_texts = [status.full_text for status in statuses]
    embeddings = list(basilica_api_client().embed_sentences(
        all_tweet_texts, model="twitter"))
    print("NUMBER OF EMBEDDINGS", len(embeddings))

    for position, status in enumerate(statuses):
        print(status.full_text)
        print("----")
        # Get the existing tweet from the db or initialize a new one.
        db_tweet = Tweet.query.get(status.id) or Tweet(id=status.id)
        db_tweet.user_id = status.author.id  # or db_user.id
        db_tweet.full_text = status.full_text
        # Indexing (rather than zip) keeps a count mismatch loud.
        embedding = embeddings[position]
        print(len(embedding))
        db_tweet.embedding = embedding
        db.session.add(db_tweet)
    db.session.commit()

    return render_template("user.html", user=db_user, tweets=statuses)
Пример #7
0
def fetch_user(screen_name=None):
    """Fetch a Twitter user and tweets and store both in the database.

    The profile is written to an existing or new ``User`` row and committed
    first. The timeline is then fetched, all tweet texts are embedded in one
    Basilica request, and every tweet is stored with its embedding in a
    single commit.

    Returns:
        The string ``"OK"``.
    """
    print(screen_name)

    api = twitter_api()

    twitter_user = api.get_user(screen_name)

    # Get the user from the database if it exists, otherwise initialize one.
    db_user = User.query.get(twitter_user.id) or User(id=twitter_user.id)

    db_user.screen_name = twitter_user.screen_name
    db_user.name = twitter_user.name
    db_user.location = twitter_user.location
    db_user.followers_count = twitter_user.followers_count

    # Store the user in the database.
    db.session.add(db_user)
    db.session.commit()

    # Get tweets.
    basilica_api = basilica_api_client()
    tweets = api.user_timeline(screen_name, tweet_mode="extended", count=150)
    # exclude_replies=True, include_rts=False)

    all_tweet_texts = [status.full_text for status in tweets]
    embeddings = list(
        basilica_api.embed_sentences(all_tweet_texts, model="twitter"))
    print("Number Of Embeddings", len(embeddings))

    for index, status in enumerate(tweets):
        print(index)
        print(status.full_text)
        print("----")

        # Indexing (rather than zip) keeps a count mismatch loud.
        embedding = embeddings[index]

        # Get the existing tweet from the db or initialize a new one.
        db_tweet = Tweet.query.get(status.id) or Tweet(id=status.id)

        db_tweet.user_id = status.author.id  # or db_user.id
        db_tweet.full_text = status.full_text

        # Report this tweet's embedding size, not the total embedding count.
        print(len(embedding))
        db_tweet.embedding = embedding
        db.session.add(db_tweet)

    # Commit all tweets in one transaction instead of once per tweet.
    db.session.commit()

    return "OK"
Пример #8
0
def store_twitter_user_data(screen_name):
    """Fetch a Twitter user plus tweets and upsert both into the database.

    Up to 200 statuses are requested, excluding replies and retweets. All
    tweet texts are embedded in one Basilica request and each tweet is
    stored with its embedding.

    Returns:
        A ``(db_user, statuses)`` tuple: the stored user row and the
        statuses returned by the Twitter client.
    """
    api = twitter_api_client()

    twitter_user = api.get_user(screen_name)
    timeline_options = {
        "tweet_mode": "extended",
        "count": 200,
        "exclude_replies": True,
        "include_rts": False,
    }
    statuses = api.user_timeline(screen_name, **timeline_options)

    # Store the user: find-or-create the row, then refresh profile columns.
    db_user = User.query.get(twitter_user.id)
    if not db_user:
        db_user = User(id=twitter_user.id)
    for column in ("screen_name", "name", "location", "followers_count"):
        setattr(db_user, column, getattr(twitter_user, column))
    db.session.add(db_user)
    db.session.commit()

    # Store the tweets together with their Basilica embeddings.
    print("STATUS COUNT:", len(statuses))
    basilica_api = basilica_api_client()
    all_tweet_texts = [status.full_text for status in statuses]
    embeddings = list(
        basilica_api.embed_sentences(all_tweet_texts, model="twitter"))
    print("NUMBER OF EMBEDDINGS", len(embeddings))

    for position, status in enumerate(statuses):
        print(status.full_text)
        print("----")

        # Find or create the database tweet.
        db_tweet = Tweet.query.get(status.id)
        if not db_tweet:
            db_tweet = Tweet(id=status.id)
        db_tweet.user_id = status.author.id  # or db_user.id
        db_tweet.full_text = status.full_text
        embedding = embeddings[position]
        print(len(embedding))
        db_tweet.embedding = embedding
        db.session.add(db_tweet)
    db.session.commit()

    return db_user, statuses
Пример #9
0
def store_twitter_user_data(screen_name):
    """Fetch a Twitter user plus tweets and upsert both into the database.

    Up to 150 statuses are requested, excluding replies and retweets. All
    tweet texts are embedded in one Basilica request and each tweet is
    stored with its embedding.

    Returns:
        A ``(db_user, statuses)`` tuple: the stored user row and the
        statuses returned by the Twitter client.
    """
    # Pull the profile and timeline from the Twitter API.
    api = twitter_api_client()
    twitter_user = api.get_user(screen_name)
    statuses = api.user_timeline(
        screen_name,
        tweet_mode="extended",
        count=150,
        exclude_replies=True,
        include_rts=False,
    )

    # Reuse the user's row when it already exists, otherwise create it.
    existing_user = User.query.get(twitter_user.id)
    db_user = existing_user if existing_user else User(id=twitter_user.id)
    db_user.screen_name = twitter_user.screen_name
    db_user.name = twitter_user.name
    db_user.location = twitter_user.location
    db_user.followers_count = twitter_user.followers_count
    db.session.add(db_user)
    db.session.commit()

    print("STATUS COUNT:", len(statuses))
    # Embed every tweet text in one Basilica request.
    basilica_api = basilica_api_client()
    all_tweet_texts = [status.full_text for status in statuses]
    embeddings = list(
        basilica_api.embed_sentences(all_tweet_texts, model="twitter"))
    print("NUMBER OF EMBEDDINGS", len(embeddings))

    # Write every tweet to the tweet table.
    for position, status in enumerate(statuses):
        print(status.full_text)
        print("----")

        existing_tweet = Tweet.query.get(status.id)
        db_tweet = existing_tweet if existing_tweet else Tweet(id=status.id)
        db_tweet.user_id = status.author.id  # or db_user.id
        db_tweet.full_text = status.full_text
        embedding = embeddings[position]
        print(len(embedding))
        db_tweet.embedding = embedding
        db.session.add(db_tweet)
    db.session.commit()

    return db_user, statuses
Пример #10
0
def predict():
    """Predict which of two stored users more likely wrote a submitted tweet.

    Reads ``screen_name_a``, ``screen_name_b`` and ``tweet_text`` from the
    submitted form, fits a logistic regression on both users' stored tweet
    embeddings (labelled by screen name), embeds the submitted text through
    the Basilica client and renders ``results.html`` with the predicted
    screen name.
    """
    print("PREDICT ROUTE...")
    print("FORM DATA:", dict(request.form))
    screen_name_a = request.form["screen_name_a"]
    screen_name_b = request.form["screen_name_b"]
    tweet_text = request.form["tweet_text"]

    print("-----------------")
    print("FETCHING TWEETS FROM THE DATABASE...")
    user_a = User.query.filter(User.screen_name == screen_name_a).one()
    user_b = User.query.filter(User.screen_name == screen_name_b).one()
    user_a_tweets = user_a.tweets
    user_b_tweets = user_b.tweets
    print("USER A", user_a.screen_name, len(user_a_tweets))
    print("USER B", user_b.screen_name, len(user_b_tweets))

    print("-----------------")
    print("TRAINING THE MODEL...")
    # Inputs are tweet embeddings; targets are the authors' screen names.
    embeddings = []
    labels = []
    for author, authored in ((user_a, user_a_tweets), (user_b, user_b_tweets)):
        for tweet in authored:
            labels.append(author.screen_name)
            embeddings.append(tweet.embedding)

    classifier = LogisticRegression(random_state=0, solver='lbfgs')
    classifier.fit(embeddings, labels)

    print("-----------------")
    print("MAKING A PREDICTION...")

    basilica_conn = basilica_api_client()
    example_embedding = basilica_conn.embed_sentence(tweet_text,
                                                     model="twitter")
    most_likely = classifier.predict([example_embedding])[0]

    return render_template("results.html",
                           screen_name_a=screen_name_a,
                           screen_name_b=screen_name_b,
                           tweet_text=tweet_text,
                           screen_name_most_likely=most_likely)
Пример #11
0
def twitoff_prediction():
    """Render which of two stored users more likely wrote the submitted tweet.

    Form fields ``screen_name_a``, ``screen_name_b`` and ``tweet_text`` are
    read from the request. A logistic regression (``max_iter=1000``) is
    fitted on both users' stored tweet embeddings, labelled with each
    tweet's ``user.screen_name``, then applied to the embedding of the
    submitted text. The result is rendered with ``prediction_results.html``.
    """
    print("FORM DATA:", dict(request.form))
    screen_name_a = request.form["screen_name_a"]
    screen_name_b = request.form["screen_name_b"]
    tweet_text = request.form["tweet_text"]

    # --- Train the model ---
    model = LogisticRegression(max_iter=1000)

    user_a = User.query.filter(User.screen_name == screen_name_a).one()
    user_b = User.query.filter(User.screen_name == screen_name_b).one()

    all_tweets = user_a.tweets + user_b.tweets

    # Collect (embedding, label) per tweet in one pass, then split columns.
    training_pairs = [(tweet.embedding, tweet.user.screen_name)
                      for tweet in all_tweets]
    embeddings = [pair[0] for pair in training_pairs]
    labels = [pair[1] for pair in training_pairs]

    model.fit(embeddings, labels)

    # --- Make the prediction ---
    basilica_connection = basilica_api_client()
    example_embedding = basilica_connection.embed_sentence(tweet_text, model="twitter")
    screen_name_most_likely = model.predict([example_embedding])[0]

    return render_template(
        "prediction_results.html",
        screen_name_a=screen_name_a,
        screen_name_b=screen_name_b,
        tweet_text=tweet_text,
        screen_name_most_likely=screen_name_most_likely,
    )
def get_user(screen_name=None):
    """Fetch a Twitter user and tweets, store them, and render ``user.html``.

    The user's profile fields are written to an existing or new ``User``
    row; every tweet is stored with the embedding returned by a single
    batched Basilica request.
    """
    api = twitter_api_client()

    twitter_user = api.get_user(screen_name)
    statuses = api.user_timeline(screen_name, tweet_mode="extended", count=150)

    # Reuse the stored user when there is one, otherwise start a new row.
    db_user = User.query.get(twitter_user.id)
    if not db_user:
        db_user = User(id=twitter_user.id)
    for column in ("screen_name", "name", "location", "followers_count"):
        setattr(db_user, column, getattr(twitter_user, column))
    db.session.add(db_user)
    db.session.commit()

    basilica_api = basilica_api_client()

    # One batched embedding request for all tweet texts.
    all_tweet_texts = [status.full_text for status in statuses]
    embeddings = list(
        basilica_api.embed_sentences(all_tweet_texts, model="twitter"))

    for position, status in enumerate(statuses):
        # Reuse the stored tweet when there is one, otherwise start a new row.
        db_tweet = Tweet.query.get(status.id)
        if not db_tweet:
            db_tweet = Tweet(id=status.id)
        db_tweet.user_id = status.author.id  # or db_user.id
        db_tweet.full_text = status.full_text
        db_tweet.embedding = embeddings[position]
        db.session.add(db_tweet)
    db.session.commit()

    return render_template("user.html", user=db_user, tweets=statuses)
Пример #13
0
def seed_db():
    """Seed the database with three fixed Twitter accounts and their tweets.

    For each screen name the profile is upserted and committed, the tweet
    texts are embedded in one Basilica request, and the tweets are added to
    the session. A final commit runs after all accounts are processed.

    Returns:
        A JSON response ``{"message": "DB SEEDED OK"}``.
    """
    api = twitter_api_client()

    seed_screen_names = ('elonmusk', 'justinbieber', 's2t2')
    for screen_name in seed_screen_names:
        twitter_user = api.get_user(screen_name)
        statuses = api.user_timeline(screen_name, tweet_mode="extended", count=150)

        # Reuse the stored user when there is one, otherwise start a new row.
        db_user = User.query.get(twitter_user.id)
        if not db_user:
            db_user = User(id=twitter_user.id)
        db_user.screen_name = twitter_user.screen_name
        db_user.name = twitter_user.name
        db_user.location = twitter_user.location
        db_user.followers_count = twitter_user.followers_count
        db.session.add(db_user)
        db.session.commit()

        basilica_api = basilica_api_client()

        all_tweet_texts = [status.full_text for status in statuses]
        embeddings = list(basilica_api.embed_sentences(all_tweet_texts, model="twitter"))

        for position, status in enumerate(statuses):
            # Reuse the stored tweet when there is one, otherwise start a new row.
            db_tweet = Tweet.query.get(status.id)
            if not db_tweet:
                db_tweet = Tweet(id=status.id)
            db_tweet.user_id = status.author.id  # or db_user.id
            db_tweet.full_text = status.full_text
            db_tweet.embedding = embeddings[position]
            db.session.add(db_tweet)

    # Tweets still pending in the session are committed once, after the loop.
    db.session.commit()

    return jsonify({"message": "DB SEEDED OK"})
Пример #14
0
def store_twitter_user_data(screen_name):
    """Fetch a Twitter user plus recent tweets and upsert both into the db.

    All tweet texts are embedded in one Basilica request and each tweet is
    stored with its embedding.

    Returns:
        A ``(db_user, statuses)`` tuple: the stored user row and the
        statuses returned by the Twitter client.
    """
    print(screen_name)

    api = twitter_api_client()
    twitter_user = api.get_user(screen_name)
    statuses = api.user_timeline(screen_name, tweet_mode="extended", count=150)

    # Find-or-create the user row, then refresh its profile columns.
    existing_user = User.query.get(twitter_user.id)
    db_user = existing_user if existing_user else User(id=twitter_user.id)
    db_user.screen_name = twitter_user.screen_name
    db_user.name = twitter_user.name
    db_user.location = twitter_user.location
    db_user.followers_count = twitter_user.followers_count
    db.session.add(db_user)
    db.session.commit()

    print("STATUS COUNT:", len(statuses))
    basilica_api = basilica_api_client()
    all_tweet_texts = [status.full_text for status in statuses]
    embeddings = list(
        basilica_api.embed_sentences(all_tweet_texts, model="twitter"))
    print("NUMBER OF EMBEDDINGS", len(embeddings))

    for position, status in enumerate(statuses):
        print(status.full_text)
        print("----")

        # Find or create the database tweet.
        existing_tweet = Tweet.query.get(status.id)
        db_tweet = existing_tweet if existing_tweet else Tweet(id=status.id)
        db_tweet.user_id = status.author.id  # or db_user.id
        db_tweet.full_text = status.full_text
        embedding = embeddings[position]
        print(len(embedding))
        db_tweet.embedding = embedding
        db.session.add(db_tweet)
    db.session.commit()

    return db_user, statuses
def fetch_user_data(screen_name=None):
    """Fetch a Twitter user and tweets and store both in the database.

    Up to 150 statuses are requested, excluding replies and retweets. The
    profile is written to an existing or new ``User`` row; every tweet is
    stored with the embedding returned by one batched Basilica request.

    Returns:
        The string ``"OK"``.
    """
    print(screen_name)

    # Pull the profile and timeline for the requested Twitter user.
    api = twitter_api()
    twitter_user = api.get_user(screen_name)
    timeline_options = {
        "tweet_mode": "extended",
        "count": 150,
        "exclude_replies": True,
        "include_rts": False,
    }
    statuses = api.user_timeline(screen_name, **timeline_options)

    # --- Store user ---
    # Use the existing row, or initialize a new one if the query finds none.
    db_user = User.query.get(twitter_user.id)
    if not db_user:
        db_user = User(id=twitter_user.id)

    # Copy the Twitter profile values onto the row's columns.
    for column in ("screen_name", "name", "location", "followers_count"):
        setattr(db_user, column, getattr(twitter_user, column))

    db.session.add(db_user)
    db.session.commit()

    # --- Store tweets ---
    basilica_api = basilica_api_client()

    # Embed every tweet text with the "twitter" model in one request.
    all_tweet_texts = [status.full_text for status in statuses]
    embeddings = list(
        basilica_api.embed_sentences(all_tweet_texts, model="twitter"))
    print("NUMBER OF EMBEDDINGS", len(embeddings))

    for position, status in enumerate(statuses):
        print(status.full_text)
        print("----")
        # Use the existing tweet row, or initialize a new one.
        db_tweet = Tweet.query.get(status.id)
        if not db_tweet:
            db_tweet = Tweet(id=status.id)
        db_tweet.user_id = status.author.id  # or db_user.id
        db_tweet.full_text = status.full_text
        embedding = embeddings[position]
        print(len(embedding))
        db_tweet.embedding = embedding
        db.session.add(db_tweet)
    db.session.commit()
    return "OK"
Пример #16
0
def predict():
    """Predict which of two stored users more likely wrote a submitted tweet.

    Reads ``screen_name_a``, ``screen_name_b`` and ``tweet_text`` from the
    submitted form. Both users' tweet lists are truncated to the shorter
    length so each user contributes the same number of samples, a logistic
    regression is fitted on the embeddings (labelled by screen name), and
    the embedding of the submitted text is classified. The result is
    rendered with ``results.html``.
    """
    print("PREDICT ROUTE...")
    print("FORM DATA:", dict(request.form))
    #> {'screen_name_a': 'elonmusk', 'screen_name_b': 's2t2', 'tweet_text': 'Example tweet text here'}
    screen_name_a = request.form["screen_name_a"]
    screen_name_b = request.form["screen_name_b"]
    tweet_text = request.form["tweet_text"]

    print("-----------------")
    print("FETCHING TWEETS FROM THE DATABASE...")
    # TODO: wrap in a try block in case the users don't exist in the database
    user_a = User.query.filter(User.screen_name == screen_name_a).one()
    user_b = User.query.filter(User.screen_name == screen_name_b).one()
    user_a_tweets = user_a.tweets
    user_b_tweets = user_b.tweets
    # Take the same number of tweets for each user.
    min_tweets = min(len(user_a_tweets), len(user_b_tweets))
    user_a_tweets = user_a_tweets[:min_tweets]
    user_b_tweets = user_b_tweets[:min_tweets]

    print("-----------------")
    print("TRAINING THE MODEL...")
    embeddings = []
    labels = []
    for author, authored in ((user_a, user_a_tweets), (user_b, user_b_tweets)):
        for tweet in authored:
            labels.append(author.screen_name)
            embeddings.append(tweet.embedding)

    # Keep the targets 1-D: the classifier expects y of shape (n_samples,);
    # a column vector (n_samples, 1) makes scikit-learn emit a
    # DataConversionWarning before flattening it.
    labels_array = np.array(labels)
    embeddings_array = np.array(embeddings)

    classifier = LogisticRegression(random_state=42, solver='lbfgs') # for example
    classifier.fit(embeddings_array, labels_array)
    print("labels_array type: ",type(labels_array))
    print("labels_array shape: ",labels_array.shape)
    print("labels_array: ", labels_array)
    print("embeddings_array type: ",type(embeddings_array))
    print("embeddings_array shape: ",embeddings_array.shape)

    print("-----------------")
    print("MAKING A PREDICTION...")

    basilica_client = basilica_api_client()
    example_embedding = basilica_client.embed_sentence(tweet_text, model="twitter")
    result = classifier.predict([example_embedding])

    return render_template("results.html",
        screen_name_a=screen_name_a,
        screen_name_b=screen_name_b,
        tweet_text=tweet_text,
        screen_name_most_likely= result[0]
    )