コード例 #1
0
    def insert(self, model_type: str, image_url: str,
               data: Tuple[float, float]) -> None:
        """
        Inserts image prediction result into images table.

        Updates the probability columns of the ``images`` row whose ``image_url``
        matches. Database errors are logged, not raised.

        :param model_type: vgg or resnet for image classification model.
        :param image_url: image's url that identify each distinct image.
        :param data: image prediction result -- (probability of not being wildfire,
                     probability of being wildfire).
        """
        # extract data to be dumped
        prob_not_wildfire, prob_wildfire = data

        # pick the pair of columns that belongs to the requested model
        if model_type == ImageClassifier.VGG_MODEL:
            not_wildfire_column, wildfire_column = 'not_wildfire_prob', 'wildfire_prob'
        elif model_type == ImageClassifier.RESNET_MODEL:
            not_wildfire_column, wildfire_column = 'resnet_not_wildfire', 'resnet_wildfire'
        else:
            logger.error(
                "Insertion fail. Please specify the model type to be vgg or resnet."
            )
            return

        try:
            # Quote the url as an SQL string literal by doubling embedded single
            # quotes. repr() produced a *Python* literal: it switches to double
            # quotes (an SQL identifier) when the url contains "'" and doubles
            # backslashes. Assumes standard_conforming_strings is on (PostgreSQL
            # default) -- TODO move to bound parameters if sql_execute_commit
            # supports them.
            quoted_url = "'" + image_url.replace("'", "''") + "'"
            Connection().sql_execute_commit(
                f"UPDATE images SET {not_wildfire_column} = {prob_not_wildfire}, "
                f"{wildfire_column} = {prob_wildfire} "
                f"WHERE image_url = {quoted_url}")
        except Exception:
            logger.error("error: " + traceback.format_exc())
コード例 #2
0
class Dumper:
    """Dump tweets loaded from a JSON file into database tables."""

    # iteration targets accepted by __iter__ / dump_all
    RECORD = 0
    LOCATION = 1

    def __init__(self, json_filename):
        with open(json_filename, 'rb') as file:
            self.tweets = json.load(file)
        self.target_fields = ["id", "bounding_box"]
        self.conn = Connection()()

    def __iter__(self, target=RECORD):
        """
        Yield one value tuple per tweet for the requested target.

        ``Dumper.RECORD`` (the default, used by plain ``for ... in dumper``) yields
        ``(id, create_at, text)``. ``Dumper.LOCATION`` yields
        ``(id, a, b, c, d)`` from the two points of ``tweet['place']['bounding_box']``.

        :raises ValueError: for any other target (previously nothing was yielded).
        """
        if target == Dumper.LOCATION:
            for tweet in self.tweets:
                (a, b), (c, d) = tweet['place']["bounding_box"]
                yield (tweet['id'], a, b, c, d)
        elif target == Dumper.RECORD:
            for tweet in self.tweets:
                yield tuple(tweet[field] for field in ("id", "create_at", "text"))
        else:
            raise ValueError(f"unknown dump target: {target!r}")

    @staticmethod
    def _build_insert_sql(table, value_count):
        """Return ``INSERT INTO <table> VALUES(%s, ..., %s);`` with ``value_count`` placeholders."""
        # placeholders must be comma-separated; "%s %s %s" is an SQL syntax error
        placeholders = ", ".join(["%s"] * value_count)
        return f'INSERT INTO {table} VALUES({placeholders});'

    def dump_all(self, table, value_count, target=RECORD):
        """
        Insert every tuple produced for ``target`` into ``table``, committing per row.

        :param table: destination table name (interpolated, must be trusted).
        :param value_count: number of values per row, i.e. placeholders in the INSERT.
        :param target: ``Dumper.RECORD`` (default) or ``Dumper.LOCATION``.
        """
        try:
            sql = self._build_insert_sql(table, value_count)
            for record in self.__iter__(target):
                print(record)
                cur = self.conn.cursor()
                try:
                    cur.execute(sql, record)
                finally:
                    cur.close()
                self.conn.commit()

        except psycopg2.DatabaseError as error:
            # an aborted transaction blocks every later statement until rolled back
            self.conn.rollback()
            print(error)

    def get_location(self):
        pass
コード例 #3
0
    def run(self,
            target: Optional[int] = None,
            model: Union[object, str] = None,
            batch_insert: bool = False):
        """
        Get records from database and dump prediction results into database.

        Only records that have no stored prediction yet are classified.

        :param target: ``Event2MindClassifier.X_INTENT``, ``X_REACTION`` or
                       ``Y_REACTION``; any other value (default None) selects records
                       missing from all three prediction tables.
        :param model: passed unchanged to ``Event2MindClassifier.set_model``.
        :param batch_insert: if True, predict every record first and insert with one
                             batch call (page size ``Event2MindClassification.PAGE_SIZE``);
                             otherwise insert each prediction as soon as it is made.
        """
        # set up event2mind classifier
        event2mind_classifier = Event2MindClassifier()
        event2mind_classifier.set_model(model)

        # set up event2mind dumper
        event2mind_dumper = Event2MindDumper()

        # table that stores the predictions of each target
        prediction_tables = {
            Event2MindClassifier.X_INTENT: 'intent_in_records',
            Event2MindClassifier.X_REACTION: 'reaction_x_in_records',
            Event2MindClassifier.Y_REACTION: 'reaction_y_in_records',
        }
        if target in prediction_tables:
            tables = [prediction_tables[target]]
        else:
            tables = list(prediction_tables.values())

        # do not select records which have already been classified; joining with
        # " and " keeps the whitespace between conditions correct
        sql = "SELECT id, text from records where " + " and ".join(
            f"id not in (SELECT record_id from {table})" for table in tables)

        if batch_insert:
            # insert all records by batch
            predictions = []
            record_ids = []
            print("Begin selecting records and making prediction...")
            for record_id, text in Connection().sql_execute(sql):
                # get prediction result of text
                predictions.append(event2mind_classifier.predict(text, target))
                record_ids.append(record_id)
            print("Prediction done. Batch insertion begins...")
            event2mind_dumper.batch_insert(
                predictions,
                record_ids,
                page_size=Event2MindClassification.PAGE_SIZE)
        else:
            # insert each record one by one
            for record_id, text in Connection().sql_execute(sql):
                prediction_dict = event2mind_classifier.predict(text, target)
                event2mind_dumper.insert(prediction_dict, record_id)
コード例 #4
0
 def batch_traverse_tokens(self, token_type: str, probability_type: str,
                           table_name: str, data_list: list,
                           record_id_list: list, page_size: int):
     """
     Flatten the tokens of a whole batch of predictions and dump them in two
     batched inserts.

     Every token sequence in ``data[token_type]`` is joined with spaces and paired
     with the probability at the same index in ``data[probability_type]`` and with
     the record id at the same index as ``data`` in ``record_id_list``. The tokens
     go through ``self.batch_insert_into_tokens``, whose returned ids are then passed
     with the record ids and probabilities to ``self.batch_insert_into_pairs``. Both
     calls share one ``Connection``.

     :param token_type: which one of X's reaction, Y's reaction and X's intent.
     :param probability_type: key of the probabilities matching ``token_type``.
     :param table_name: the target table to dump data into.
     :param data_list: list of the prediction dictionary.
     :param record_id_list: list of tweet record id corresponding to prediction result.
     :param page_size: an integer page size for batch insertion into database.
     """
     flat_tokens, flat_probabilities, flat_record_ids = [], [], []
     for position, prediction in enumerate(data_list):
         token_sequences = prediction[token_type]
         probability_values = prediction[probability_type]
         for index, words in enumerate(token_sequences):
             flat_tokens.append(' '.join(words))
             flat_probabilities.append(probability_values[index])
             flat_record_ids.append(record_id_list[position])

     with Connection() as conn:
         token_ids = self.batch_insert_into_tokens(flat_tokens, token_type,
                                                   page_size, conn)
         self.batch_insert_into_pairs(flat_record_ids, token_ids,
                                      flat_probabilities, table_name,
                                      page_size, conn)
コード例 #5
0
    def batch_traverse_tokens(self, token_type: str, probability_type: str,
                              table_name: str, data_list: list,
                              record_id_list: list, page_size: int):
        """
        Flatten the tokens of a whole batch of predictions and dump them in two
        batched inserts.

        Each token sequence in ``data[token_type]`` is joined with spaces and paired
        with the probability at the same index in ``data[probability_type]`` and with
        ``record_id_list[j]`` for the j-th dict of ``data_list``. Tokens go through
        ``self.batch_insert_into_tokens``; its returned ids, the record ids and the
        probabilities then go to ``self.batch_insert_into_pairs``. Both calls share
        one ``Connection``.

        :param token_type: key of the token sequences in each prediction dict.
        :param probability_type: key of the matching probabilities.
        :param table_name: the target table for the pair rows.
        :param data_list: list of prediction dicts.
        :param record_id_list: record ids, parallel to ``data_list``.
        :param page_size: page size for the batched inserts.
        """
        flat_tokens, flat_probabilities, flat_record_ids = [], [], []
        for position, prediction in enumerate(data_list):
            token_sequences = prediction[token_type]
            probability_values = prediction[probability_type]
            for index, words in enumerate(token_sequences):
                flat_tokens.append(' '.join(words))
                flat_probabilities.append(probability_values[index])
                flat_record_ids.append(record_id_list[position])

        with Connection() as conn:
            token_ids = self.batch_insert_into_tokens(flat_tokens, token_type,
                                                      page_size, conn)
            self.batch_insert_into_pairs(flat_record_ids, token_ids,
                                         flat_probabilities, table_name,
                                         page_size, conn)
コード例 #6
0
 def get_exists(self):
     """
     Get how far we went last time.

     Runs ``self.select_exists`` and returns what the cursor's ``fetchall()``
     produced.
     """
     with Connection() as conn:
         cur = conn.cursor()
         try:
             cur.execute(self.select_exists)
             exists_list = cur.fetchall()
         finally:
             # close the cursor even when the query fails
             cur.close()
     return exists_list
コード例 #7
0
    def f1_score(self):
        """
        Train the classifier and score it against the hand-labeled records.

        A record counts as actually positive under the rule the original count
        query used: ``(label1 = label2 and label1 = true) or judge = true``. Only
        records with ``label1`` set are evaluated. Precision, recall and F1 are
        printed. A ratio whose denominator is zero is reported as 0.0.

        :return: the F1 score.
        """
        self.train()

        true_positive = false_positive = false_negative = 0
        # The database evaluates the ground-truth predicate exactly as before.
        # coalesce maps SQL NULL (e.g. a missing label2/judge) to false, matching
        # how the old WHERE clause treated it.
        for text, is_positive in Connection().sql_execute(
                "select text, coalesce((label1 = label2 and label1 = true) or judge = true, false) "
                "from records where label1 is not null"):
            predicted = bool(self.predict(text))
            if predicted and is_positive:
                true_positive += 1
            elif predicted:
                false_positive += 1
            elif is_positive:
                false_negative += 1

        predicted_positive = true_positive + false_positive
        actual_positive = true_positive + false_negative
        precision = true_positive / predicted_positive if predicted_positive else 0.0
        recall = true_positive / actual_positive if actual_positive else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)

        print(f"precision = {precision}, recall = {recall}, f1 = {f1}")
        return f1
コード例 #8
0
 def traverse_tokens(self, token_type: str, probability_type: str,
                     table_name: str, data: dict, record_id):
     """
     Traverse each token in the prediction dict and insert it into the database.

     Each token sequence in ``data[token_type]`` is joined with spaces. Inside its
     own ``Connection`` block it is passed to ``self.insert_into_tokens``, and the
     returned id is passed with ``record_id`` and the probability at the same index
     in ``data[probability_type]`` to ``self.insert_into_pairs``.
     """
     token_sequences = data[token_type]
     probability_values = data[probability_type]
     for index, words in enumerate(token_sequences):
         joined = ' '.join(words)
         with Connection() as conn:
             token_id = self.insert_into_tokens(joined, token_type, conn)
             self.insert_into_pairs(record_id, token_id,
                                    probability_values[index], table_name, conn)
コード例 #9
0
 def fetch_status_id_from_db(batch_size: int = 100, sleep_seconds: float = 20):
     """
     A generator which yields lists of up to ``batch_size`` record ids at a time.

     Selects ids of records whose ``user_id`` is NULL, newest ``create_at`` first.
     After each full batch has been handed out it sleeps ``sleep_seconds`` before
     collecting the next one, to avoid getting banned by the Twitter API. The
     final, possibly shorter, batch is yielded as well. Every yielded list is a new
     object, so callers may keep it.

     :param batch_size: maximum ids per yielded list (default 100).
     :param sleep_seconds: pause after each full batch (default 20).
     """
     batch = []
     # call through an instance, as elsewhere in the code base; this works whether
     # sql_execute is an instance method or a class/static method
     for status_id, in Connection().sql_execute(
             "SELECT id FROM records WHERE user_id IS NULL order by create_at desc"
     ):
         batch.append(status_id)
         if len(batch) >= batch_size:
             yield batch
             # set sleep time to prevent the twitter api from being banned
             time.sleep(sleep_seconds)
             # start a new list instead of clear(): the consumer may still hold
             # the one just yielded
             batch = []
     if batch:
         yield batch
コード例 #10
0
    def insert(self, model_type: str, image_url: str, data: Tuple[float, float]):
        """
        Insert image prediction result into images table.

        Updates the probability columns of the ``images`` row whose ``image_url``
        matches. Database errors are logged, not raised.

        :param model_type: ``ImageClassifier.VGG_MODEL`` or ``ImageClassifier.RESNET_MODEL``.
        :param image_url: url identifying the image row.
        :param data: image prediction result -- (probability of not being wildfire,
                     probability of being wildfire).
        """
        prob_not_wildfire, prob_wildfire = data

        # pick the pair of columns that belongs to the requested model
        if model_type == ImageClassifier.VGG_MODEL:
            not_wildfire_column, wildfire_column = 'not_wildfire_prob', 'wildfire_prob'
        elif model_type == ImageClassifier.RESNET_MODEL:
            not_wildfire_column, wildfire_column = 'resnet_not_wildfire', 'resnet_wildfire'
        else:
            logger.error("Insertion fail. Please specify the model type to be vgg or resnet.")
            return

        try:
            # Quote the url as an SQL string literal by doubling embedded single
            # quotes. repr() produced a *Python* literal: it switches to double
            # quotes (an SQL identifier) when the url contains "'" and doubles
            # backslashes. Assumes standard_conforming_strings is on (PostgreSQL
            # default).
            quoted_url = "'" + image_url.replace("'", "''") + "'"
            Connection().sql_execute_commit(
                f"UPDATE images SET {not_wildfire_column} = {prob_not_wildfire}, "
                f"{wildfire_column} = {prob_wildfire} "
                f"WHERE image_url = {quoted_url}")
        except Exception:
            logger.error("error: " + traceback.format_exc())
コード例 #11
0
    def get_labeled_data(self):
        """
        Load every record labeled by both labelers into ``self.labeled_data``.

        Each entry is ``(features, label)``. ``features`` is ``self.pre_process(text)``.
        ``label`` is the shared label when the two labelers agree and the judge's
        value otherwise. Words whose ``self.low_fre_words`` count is below 35 are then
        deleted from each ``features`` in place, except the ``THISISALINK`` /
        ``THISISAUSER`` tokens. Finally the list is shuffled.
        """
        query = ("select id, text, label1, label2, judge from records "
                 "where label1 is not null and label2 is not null")
        self.labeled_data = list()
        for _record_id, text, first_label, second_label, judge in Connection().sql_execute(query):
            agreed = first_label == second_label
            self.labeled_data.append((self.pre_process(text),
                                      first_label if agreed else judge))

        kept_tokens = ["THISISALINK", "THISISAUSER"]
        for features, _label in self.labeled_data:
            # iterate over a snapshot of the keys because entries are deleted
            for word in list(features):
                if self.low_fre_words[word] < 35 and word not in kept_tokens:
                    del features[word]

        random.shuffle(self.labeled_data)
コード例 #12
0
 def traverse_tokens(self, token_type: str, probability_type: str,
                     table_name: str, data: dict, record_id):
     """
     Traverses each token in data dictionary, and calls respective inserting functions to dump them into database.

     Each token sequence is joined with spaces and written inside its own
     ``Connection`` block: ``self.insert_into_tokens`` first, then
     ``self.insert_into_pairs`` with the id it returned.

     :param token_type: which one of X's reaction, Y's reaction and X's intent.
     :param probability_type: type of respective probability.
     :param table_name: the target table to dump data into.
     :param data: the prediction dictionary.
     :param record_id: id of tweet record corresponding to prediction result.
     """
     token_sequences = data[token_type]
     probability_values = data[probability_type]
     for index, words in enumerate(token_sequences):
         joined = ' '.join(words)
         with Connection() as conn:
             token_id = self.insert_into_tokens(joined, token_type, conn)
             self.insert_into_pairs(record_id, token_id,
                                    probability_values[index], table_name, conn)
コード例 #13
0
    def insert(self, date: datetime.date, unflattened_data: np.ndarray, var_type: str):
        """
        Bulk-insert one date's values for ``var_type``, then run the matching info
        statement. Both are committed together.

        :param date: the date the data belongs to.
        :param unflattened_data: array of values; flattened with ``ndarray.flatten()``
                                 before being turned into rows by
                                 ``PRISMDumper.record_generator``.
        :param var_type: key into ``PRISMDumper.INSERT_SQLS`` / ``PRISMDumper.INSERT_INFOS``.
        :return: None
        """
        values = unflattened_data.flatten()
        # 'usgs' info statement takes only the date; the others take (date, 1)
        # NOTE(review): meaning of the constant 1 is not visible here
        info_args = (date,) if var_type == 'usgs' else (date, 1)
        with Connection() as conn:
            cursor = conn.cursor()
            psycopg2.extras.execute_values(
                cursor,
                PRISMDumper.INSERT_SQLS[var_type],
                PRISMDumper.record_generator(date, values),
                template=None,
                page_size=10000,
            )
            cursor.execute(PRISMDumper.INSERT_INFOS[var_type], info_args)
            conn.commit()
            cursor.close()
コード例 #14
0
    def run(self, model_type: str = ImageClassifier.RESNET_MODEL):
        """
        Get image_id and image_url from database and dump prediction results into database.

        Every row of the ``images`` table is classified with ``model_type`` and the
        result is written through ``ImgClassificationDumper.insert``. The first
        exception is logged with its traceback and ends the run.

        :param model_type: image model to use, ``ImageClassifier.RESNET_MODEL`` by default.
        """
        # set up image classifier
        image_classifier = ImageClassifier(model_type)
        image_classifier.set_model()

        # set up image classification dumper
        img_classification_dumper = ImgClassificationDumper()

        # loop every image in database
        try:
            for image_id, image_url in Connection().sql_execute(
                    "select id, image_url from images"):
                # get prediction result of image
                prediction_tuple = image_classifier.predict(image_url)
                # dump prediction result into database
                img_classification_dumper.insert(model_type, image_url,
                                                 prediction_tuple)
                logger.info("id " + str(image_id) + " is done!")
        except Exception:
            # Exception (not a bare except) so Ctrl-C / SystemExit still stop the run
            logger.error('error: ' + traceback.format_exc())
コード例 #15
0
    def read_train_test_data(self):
        """
        Read labeled tweets from the database and split them into train, test and
        validation sets.

        Up to 468 records with ``label1 = 0`` form the positive class (label 0,
        wildfire). Up to 532 records with ``label1 = 1`` form the negative class
        (label 1). Text is reduced to ASCII and stripped, and newlines become ``'. '``.
        Each class list is doubled until it has at least ``2 * n + n / 10 * 2``
        tweets (``n`` is ``self.pos_num`` or ``self.neg_num``). It is then sliced into
        ``[0, n)`` train, ``[n, 2n)`` test and ``[2n, int(2n + n / 10 * 2))``
        validation tweets. Negatives precede positives before each set is shuffled.
        Every shuffle is seeded with 1, so the split is reproducible.

        :return: (train_texts, train_labels, test_texts, test_labels,
                  validate_texts, validate_labels), each a tuple.
        :raises ValueError: if a class has no records but a non-zero size was
                            requested for it, or if a resulting set is empty.
        """

        def fetch_cleaned_texts(cursor, sql, limit):
            # Read up to `limit` rows; keep ASCII only, trim, turn newlines into '. '.
            cursor.execute(sql)
            return [str(record[0]).encode('ascii', 'ignore').decode('ascii')
                    .strip().replace('\n', '. ')
                    for record in cursor.fetchmany(limit)]

        def seeded_shuffle(items):
            # Re-seed before every shuffle so each one is reproducible on its own.
            random.seed(1)
            random.shuffle(items)

        def oversample(tweets, labels, num):
            # Double both lists until they are long enough for all three slices.
            while len(tweets) < num * 2 + num / 10 * 2:
                if not tweets:
                    # doubling an empty list never terminates
                    raise ValueError(
                        f"no labeled tweets available to draw {num} samples from")
                tweets = tweets + tweets
                labels = labels + labels
            return tweets, labels

        def split(tweets, labels, num):
            # Slice [0, n) train, [n, 2n) test, [2n, int(2n + n/10*2)) validate.
            validate_end = int(num * 2 + num / 10 * 2)
            bounds = ((0, num), (num, num * 2), (num * 2, validate_end))
            return [(tweets[lo:hi], labels[lo:hi]) for lo, hi in bounds]

        def merge(neg_part, pos_part):
            # Negatives then positives, shuffled together as (text, label) pairs.
            pairs = list(zip(neg_part[0] + pos_part[0], neg_part[1] + pos_part[1]))
            seeded_shuffle(pairs)
            texts, labels = zip(*pairs)
            return texts, labels

        with Connection() as conn:
            cur = conn.cursor()
            tweets_pos = fetch_cleaned_texts(
                cur, "SELECT text from records where label1 = 0 ", 468)
            tweets_neg = fetch_cleaned_texts(
                cur, "SELECT text from records where label1 = 1 ", 532)

        # 0 for true(wildfire), 1 for false
        labels_pos = [0] * len(tweets_pos)
        labels_neg = [1] * len(tweets_neg)

        seeded_shuffle(tweets_pos)
        seeded_shuffle(tweets_neg)
        tweets_neg, labels_neg = oversample(tweets_neg, labels_neg, self.neg_num)
        tweets_pos, labels_pos = oversample(tweets_pos, labels_pos, self.pos_num)
        seeded_shuffle(tweets_pos)
        seeded_shuffle(tweets_neg)

        neg_train, neg_test, neg_validate = split(tweets_neg, labels_neg, self.neg_num)
        pos_train, pos_test, pos_validate = split(tweets_pos, labels_pos, self.pos_num)

        tweet_texts_Train, tweet_labels_Train = merge(neg_train, pos_train)
        tweet_texts_Test, tweet_labels_Test = merge(neg_test, pos_test)
        tweet_texts_Validate, tweet_labels_Validate = merge(neg_validate, pos_validate)

        print("neg_num = " + str(self.neg_num) + " pos_num = " +
              str(self.pos_num))
        print("tweet_train = " + str(len(tweet_texts_Train)) + " label_train = " +
              str(len(tweet_labels_Train)))
        print("tweet_test = " + str(len(tweet_texts_Test)) + " label_test = " +
              str(len(tweet_labels_Test)))
        print("tweet_validate = " + str(len(tweet_texts_Validate)) +
              " label_validate = " + str(len(tweet_labels_Validate)))

        return tweet_texts_Train, tweet_labels_Train, tweet_texts_Test, tweet_labels_Test, tweet_texts_Validate, tweet_labels_Validate
コード例 #16
0
        self.train()
        for id, text in Connection().sql_execute(
                "select id, text from records where label1 is not null"):
            print(text)
            if self.predict(text):
                l.append(id)
        print(l)
        print(len(l))

        precision = len(l) / total_positive
        recall = total_positive

        pass


if __name__ == '__main__':
    # Train the classifier, then store it (serialize) for later use.
    classifier = NLTKTest()
    classifier.train()

    model_path = '../models/nltk_model.pickle'
    with open(model_path, 'wb') as model_file:
        pickle.dump(classifier, model_file, protocol=pickle.HIGHEST_PROTOCOL)

    # Step through unlabeled tweets: show each text and its prediction, then wait
    # for the user to press enter.
    for (tweet_text,) in Connection().sql_execute(
            "select text from records where label1 is NULL"):
        print(tweet_text)
        print(classifier.predict(tweet_text))
        input()
コード例 #17
0
    def insert(self, data_list: List[Dict]) -> None:
        """
        Inserts the given list into the database.

        First every tweet is written to ``records``. On an id conflict the text,
        profile and follower fields are replaced from the new row (ON CONFLICT DO
        UPDATE), so re-processed data updates the row and freshly crawled data is
        simply inserted. Then tweets whose ``top_left`` and ``bottom_right`` are both
        set get a row in ``locations``, where existing ids are left untouched.

        Each table is written through its own ``Connection``. An error is logged with
        its traceback and does not prevent the other table from being written.
        ``self.inserted_count`` / ``self.inserted_locations_count`` grow by one per
        row prepared for the respective table, whether or not the write succeeds.

        :param data_list: tweet dicts carrying the keys read below.
        """
        records_sql = (
            "insert into records (id,create_at, text, hash_tag,profile_pic,created_date_time,screen_name,"
            "user_name,followers_count,favourites_count,friends_count,user_id,user_location,statuses_count"
            ") values %s "
            "ON CONFLICT(id) DO UPDATE set text = excluded.text, profile_pic = excluded.profile_pic, "
            "screen_name = excluded.screen_name, user_name = excluded.user_name, "
            "followers_count = excluded.followers_count, favourites_count = excluded.favourites_count, "
            "friends_count= excluded.friends_count, user_id= excluded.user_id, user_location= excluded.user_location, "
            " statuses_count= excluded.statuses_count;"
        )
        locations_sql = (
            "insert into locations (id, top_left_lat, top_left_long, bottom_right_lat,"
            "bottom_right_long) values %s "
            "ON CONFLICT(id) DO NOTHING;"
        )

        def write_rows(sql, rows, success_message):
            # One execute_values batch in a fresh connection; errors are logged, not raised.
            try:
                with Connection() as connection:
                    cursor = connection.cursor()
                    if rows:
                        extras.execute_values(cursor, sql, rows)
                    connection.commit()
                    cursor.close()
            except Exception as err:
                logger.error(str(err) + traceback.format_exc())
            else:
                logger.info(success_message())

        record_rows = []
        for tweet in data_list:
            record_rows.append((
                tweet['id'], tweet['date_time'], tweet['full_text'],
                ', '.join(tweet['hashtags']) if tweet['hashtags'] else None,
                tweet['profile_pic'], tweet['created_date_time'],
                tweet['screen_name'], tweet['user_name'],
                tweet['followers_count'], tweet['favourites_count'],
                tweet['friends_count'], tweet['user_id'], tweet['user_location'],
                tweet['statuses_count'],
            ))
            self.inserted_count += 1
        write_rows(records_sql, record_rows,
                   lambda: f'data inserted into records {self.inserted_count}')

        location_rows = []
        for tweet in data_list:
            if tweet['top_left'] is None or tweet['bottom_right'] is None:
                continue
            # index 1 of each point goes to the *_lat column, index 0 to *_long
            location_rows.append((
                tweet['id'], tweet['top_left'][1], tweet['top_left'][0],
                tweet['bottom_right'][1], tweet['bottom_right'][0],
            ))
            self.inserted_locations_count += 1
        write_rows(locations_sql, location_rows,
                   lambda: f'data inserted into locations {self.inserted_locations_count}')
コード例 #18
0
ファイル: labeler.py プロジェクト: Yicong-Huang/Wildfires
 def __init__(self, role):
     """
     Set up a labeler that fills the ``label<role>`` column of ``records``.

     :param role: suffix of the label column this labeler writes (e.g. 1 -> ``label1``).
     """
     self.role = role
     # Connection()() appears to be how this code base obtains a live connection
     self.conn = Connection()()
     self.unlabeled = None
コード例 #19
0
 def insert(self, id: int, wildfire_prob: float, not_wildfire_prob: float):
     """
     Store the text-CNN probabilities on record ``id`` and count the update in
     ``self.inserted_count``.
     """
     statement = (
         "UPDATE records SET "
         f"text_cnn_wildfire_prob = {wildfire_prob}, "
         f"text_cnn_not_wildfire_prob = {not_wildfire_prob} "
         f"WHERE id = {id}"
     )
     Connection().sql_execute_commit(statement)
     self.inserted_count += 1
コード例 #20
0
ファイル: labeler.py プロジェクト: Yicong-Huang/Wildfires
class Labeler:
    """Interactive command-line tool for hand-labeling tweets in ``records``.

    Each labeler writes its own column, ``label<role>``.
    """
    # numeric label stored in the database -> name shown to the user
    labels = {0: "TRUE", 1: "FALSE", 2: "NOT_SURE"}

    def __init__(self, role):
        self.role = role
        self.conn = Connection()()
        self.unlabeled = None

    def _mark_sql(self) -> str:
        """Return the UPDATE statement for this labeler's column."""
        # A column name cannot be a bound parameter, so only the role is
        # interpolated; the value and the id are passed to the driver.
        return f'UPDATE records SET label{self.role} = %s WHERE id = %s;'

    def mark(self, tweet_id, value) -> None:
        """Store ``value`` in this labeler's column of record ``tweet_id``; errors are printed."""
        try:
            cur = self.conn.cursor()
            # The query used to have no placeholders while params were passed, so
            # psycopg2 raised "not all arguments converted" and nothing was stored.
            cur.execute(self._mark_sql(), (value, tweet_id))
            cur.close()
            self.conn.commit()
        except (Exception, psycopg2.DatabaseError) as error:
            # an aborted transaction would make every later statement fail
            self.conn.rollback()
            print(error)

    def get_next_unlabeled(self):
        """Yield one random ``(id, text)`` record not yet labeled by this labeler, if any."""
        cur = self.conn.cursor()
        sql = f'SELECT id, text FROM records WHERE label{self.role} IS NULL order by random() LIMIT 1;'
        cur.execute(sql)

        row = cur.fetchone()
        while row:
            yield row
            row = cur.fetchone()
        cur.close()
        self.conn.commit()

    def start(self):
        """
        Label tweets until none is left unlabeled for this role.

        ``1``/``2``/``3`` store TRUE/FALSE/NOT_SURE. ``r`` rotates the previously
        stored label through the three values. Enter skips the current tweet.
        """
        prev_id = None
        prev_text = None
        prev_label = None
        while True:
            fetched_any = False
            for tweet_id, text in self.get_next_unlabeled():
                fetched_any = True
                char = self.get_next_char(text)

                while char == 'r':
                    if prev_id is not None:
                        prev_label = (prev_label + 1) % 3
                        print(
                            f"[{prev_text} is changed to {self.labels[prev_label]}]"
                        )
                        self.mark(prev_id, prev_label)
                    char = self.get_next_char(text)

                if not char:
                    # enter: skip this tweet without storing a label
                    continue

                label = int(char) - 1

                print(self.labels[label])
                self.mark(tweet_id, label)
                prev_label = label
                prev_id = tweet_id
                prev_text = text
            if not fetched_any:
                # the generator object itself is always truthy; stop once a
                # query returns no unlabeled record instead of polling forever
                break

    @staticmethod
    def get_next_char(text):
        """
        Show ``text`` and prompt until the user enters 1, 2, 3, r or nothing.

        :return: the entered character, or '' when the user just pressed enter (skip).
        """
        while True:
            print(
                f'================================================\n\n\n\n{text}\n\n\n\n\n\n\n\n\n\n([1] for True, [2] for False, [3] for not sure, [r] for reverse previous (rotate in three values), enter for skip to next) ->'
            )
            char = input().strip()
            if char in ('', '1', '2', '3', 'r'):
                return char
コード例 #21
0
#
# model.save_model("model_1.0.bin")
# #
# #
# def print_results(N, p, r):
#     print("N\t" + str(N))
#     print("P@{}\t{:.3f}".format(1, p))
#     print("R@{}\t{:.3f}".format(1, r))
#
# print_results(*model.test('test_1.txt'))

# print(model.predict("they are not wildfire they are campus fire", k=2))

trainset = []

with Connection() as conn:
    cur = conn.cursor()
    # up to 468 tweets whose first label is 0
    cur.execute("SELECT text from records where label1 = 0 ")
    text_label1 = cur.fetchmany(468)

    # keep ASCII characters only, trim, and turn line breaks into sentence breaks
    trainset.extend(
        str(row[0]).encode('ascii', 'ignore').decode('ascii').strip().replace('\n', '. ')
        for row in text_label1
    )
コード例 #22
0
ファイル: url_dumper.py プロジェクト: Yicong-Huang/Wildfires
 def insert(self, data: Union[List, Dict]):
     """
     Insert ``(id, image_url)`` pairs into the ``images`` table.

     :param data: a dict, or a list of dicts, each accepted by
                  ``self._gen_id_url_pair``. NOTE(review): list elements are
                  assumed to be the same kind of dict as the single-dict case --
                  confirm with callers. Values of any other type are ignored.
     """
     if isinstance(data, dict):
         Connection().sql_execute_values(
             "insert into images(id,image_url) values ",
             self._gen_id_url_pair(data))
     elif isinstance(data, list):
         # the annotation admits a list, which used to be silently ignored
         pairs = [pair for item in data for pair in self._gen_id_url_pair(item)]
         # skip the database round trip when there is nothing to insert
         if pairs:
             Connection().sql_execute_values(
                 "insert into images(id,image_url) values ", pairs)
コード例 #23
0
 def __init__(self, json_filename):
     """
     Load the tweets stored in ``json_filename`` and open a database connection.

     :param json_filename: path of the JSON file holding the tweets.
     """
     with open(json_filename, mode='rb') as json_file:
         loaded_tweets = json.load(json_file)
     self.tweets = loaded_tweets
     self.target_fields = ["id", "bounding_box"]
     self.conn = Connection()()
コード例 #24
0
def read_test_data():
    """
    Build a shuffled test set from records labeled by exactly one labeler.

    Up to 264 rows are read for each of four queries, in this order:
    ``label2 = 0`` and ``label1 = 0`` (class 0, wildfire), then ``label2 = 1`` and
    ``label1 = 1`` (class 1). Each query requires the other label column to be null.
    Text is reduced to ASCII and stripped, and newlines become ``'. '``. Each class
    list is shuffled, positives are placed before negatives, and the pairs are
    shuffled again. Every shuffle is seeded with 1.

    :return: (texts, labels) as two tuples of equal length.
    """
    # (query, class): 0 for true (wildfire), 1 for false
    queries = (
        ("SELECT text from records where label1 is null and label2 = 0 ", 0),
        ("SELECT text from records where label1 = 0 and label2 is null", 0),
        ("SELECT text from records where label1 is null and label2 = 1 ", 1),
        ("SELECT text from records where label1 = 1 and label2 is null", 1),
    )
    texts_by_class = {0: [], 1: []}
    with Connection() as conn:
        cur = conn.cursor()
        for query, label in queries:
            cur.execute(query)
            for record in cur.fetchmany(264):
                cleaned = str(record[0]).encode('ascii', 'ignore').decode('ascii')
                texts_by_class[label].append(cleaned.strip().replace('\n', '. '))

        tweets_pos = texts_by_class[0]
        tweets_neg = texts_by_class[1]
        labels_pos = [0] * len(tweets_pos)
        labels_neg = [1] * len(tweets_neg)

        for tweets in (tweets_pos, tweets_neg):
            random.seed(1)
            random.shuffle(tweets)

        tweet_label_pair_test = list(zip(tweets_pos + tweets_neg,
                                         labels_pos + labels_neg))
        random.seed(1)
        random.shuffle(tweet_label_pair_test)
        tweet_texts_Test, tweet_labels_Test = zip(*tweet_label_pair_test)

        print("tweet_test = " + str(len(tweet_texts_Test)) + " label_test = " +
              str(len(tweet_labels_Test)))

        return tweet_texts_Test, tweet_labels_Test
コード例 #25
0
 def get_exists(self) -> set:
     """
     Gets how far we went last time.

     :return: the set of rows produced by running ``self.select_exists``.
     """
     # Call through an instance, as the rest of the code base does. This works
     # whether sql_execute is an instance method or a class/static method, while
     # calling it on the class breaks for an instance method.
     return set(Connection().sql_execute(self.select_exists))