예제 #1
0
def yield_tree_information(tree_file_name):
    """Yield one dict per non-ROOT edge line of a tree file.

    Lines containing the substring "ROOT" are skipped. Every other line is
    passed to ``parse_edge_line`` and its six returned fields are re-packed
    into a dict.

    Args:
        tree_file_name: path of the tree file to read (opened in text mode).

    Yields:
        dict with the keys "tweet_in", "user_in", "time_in", "tweet_out",
        "user_out" and "time_out", holding the values ``parse_edge_line``
        returned for that line.
    """
    with open(tree_file_name, "rt") as tree_file:
        # Iterate the file object directly: lines are streamed as the consumer
        # advances the generator instead of being loaded up front.
        for line in tree_file:
            if "ROOT" in line:
                continue
            tweet_in, tweet_out, user_in, user_out, time_in, time_out = parse_edge_line(line)

            yield {
                "tweet_in": tweet_in,
                "user_in": user_in,
                "time_in": time_in,
                "tweet_out": tweet_out,
                "user_out": user_out,
                "time_out": time_out,
            }
예제 #2
0
 def get_user_and_tweet_ids_in_train(self, trees_to_parse, train_ids):
     """Collect every user id and tweet id appearing in the training trees.

     A tree file is considered only when ``get_root_id(tree_file_name)`` is
     contained in ``train_ids``. Within such a file, lines containing "ROOT"
     are skipped; both endpoints of every other edge are recorded.

     Args:
         trees_to_parse: iterable of tree file paths.
         train_ids: container of root ids that belong to the train set.

     Returns:
         (user_ids_in_train, tweet_ids_in_train): two sets.
     """
     user_ids_in_train = set()
     tweet_ids_in_train = set()
     for tree_file_name in trees_to_parse:
         if get_root_id(tree_file_name) not in train_ids:
             continue
         with open(tree_file_name, "rt") as tree_file:
             # Stream the file line by line rather than materialising it.
             for line in tree_file:
                 if "ROOT" in line:
                     continue
                 tweet_in, tweet_out, user_in, user_out, _, _ = parse_edge_line(line)
                 # Original author's note: "user_ids_in_train may be bigger"
                 # (presumably than the set of users that have features --
                 # TODO confirm against callers).
                 user_ids_in_train.update((user_in, user_out))
                 tweet_ids_in_train.update((tweet_in, tweet_out))
     return user_ids_in_train, tweet_ids_in_train
예제 #3
0
    def build_tree(self, tree_file_name, tweet_fts, user_fts):
        """Parse a tree file into node features and a de-duplicated edge list.

        The file is scanned twice. The first scan stops at the first line
        containing "ROOT": that line's destination becomes node 0, and a
        time shift is recorded from the first negative ``time_out`` seen up to
        (and including) that line. The second scan walks all non-ROOT lines,
        registers unseen destination nodes and collects edges.

        Args:
            tree_file_name: str, path to the file storing the tree.
            tweet_fts: tweet features, forwarded to ``add_node_features_to_x``
                (documented by the original author as
                dict[tweet_id:int -> np array]).
            user_fts: user features, forwarded to ``add_node_features_to_x``
                (documented by the original author as
                dict[user_id:int -> np array]).

        Returns:
            x: list filled by ``self.add_node_features_to_x`` (presumably one
                feature vector per node -- confirm in that helper).
            edges: list of ``[node_in_id, node_out_id, time_out, user_in,
                user_out]`` lists, where ``time_out`` is the shifted time.

        Raises:
            ValueError: if no line containing "ROOT" is found.
            KeyError: if an edge's source ``(tweet_in, user_in)`` has not been
                registered as a node before the edge is reached.
            AssertionError: if a shifted ``time_out`` is still negative.

        Side effects:
            Sets ``self.num_node_features`` to ``len(x[-1])``.
        """
        x = []
        edges = []
        # Mirror of `edges` as hashable tuples, so duplicate detection is a
        # constant-time lookup instead of a scan over every stored edge.
        seen_edges = set()
        # (tweet id, user id) -> node id, starting at 0. Keyed on the pair
        # because a tweet can first be seen with one user id and later with a
        # different one.
        node_id_to_count = {}
        count = 0

        # First run to get the ROOT line and shift in time (if there is one)
        time_shift = 0
        with open(tree_file_name, "rt") as tree_file:
            for line in tree_file:
                _, tweet_out, _, user_out, _, time_out = utils.parse_edge_line(line)
                if time_out < 0 and time_shift == 0:
                    # if buggy dataset, and we haven't found the time_shift yet
                    time_shift = -time_out
                if "ROOT" in line:
                    node_id_to_count[(tweet_out, user_out)] = 0
                    # The ROOT node receives its raw (unshifted) time_out.
                    self.add_node_features_to_x(x, node_id_to_count, tweet_out, user_out,
                                                tweet_fts, user_fts, time_out)
                    count += 1
                    break

        if count == 0:
            raise ValueError(f"Didn't find ROOT... File {tree_file_name} is corrupted")

        with open(tree_file_name, "rt") as tree_file:

            # Largest time_out accepted so far; edges going back in time are dropped.
            current_time_out = 0
            for line in tree_file:

                if 'ROOT' in line:
                    continue

                tweet_in, tweet_out, user_in, user_out, _, time_out = utils.parse_edge_line(line)
                time_out += time_shift  # fix buggy dataset
                assert time_out >= 0

                if (self.time_cut is None) or (time_out <= self.time_cut):

                    # Add dest if unseen. The ROOT line already added the original tweet.
                    dest_key = (tweet_out, user_out)
                    if dest_key not in node_id_to_count:
                        node_id_to_count[dest_key] = count
                        self.add_node_features_to_x(x, node_id_to_count, tweet_out, user_out,
                                                    tweet_fts, user_fts, time_out)
                        count += 1

                    # Remove some buggy lines (i.e. duplicated or make no sense)
                    if time_out >= current_time_out:
                        edge_key = (
                            node_id_to_count[(tweet_in, user_in)],
                            node_id_to_count[dest_key],
                            time_out,
                            user_in,
                            user_out,
                        )
                        if edge_key not in seen_edges:
                            seen_edges.add(edge_key)
                            current_time_out = time_out
                            # Edges are stored as lists, as callers received before.
                            edges.append(list(edge_key))

                if (self.time_cut is not None) and (time_out > self.time_cut):
                    # We've seen all interesting edges
                    break

        self.num_node_features = len(x[-1])

        return x, edges
예제 #4
0
    def get_retweet_list(self, tree_file_name, user_features):
        """Collect a fixed-size list of retweeting users from a tree file.

        The file is scanned twice. The first scan stops at the first line
        containing "ROOT": its ``user_out`` is recorded with path length 0,
        and a time shift is taken from the first negative ``time_out`` seen
        up to (and including) that line. The second scan walks the non-ROOT
        lines and accepts ``user_out`` as a retweeter when it is new, is
        present in ``user_features`` and its ``user_in`` already has a
        recorded path length. Scanning stops once ``self.retweet_user_size``
        retweeters are collected. If fewer were found, the list is padded by
        sampling with replacement from the collected retweeters.

        Args:
            tree_file_name: str, path to the file storing the tree.
            user_features: container tested with ``in``; only users present
                in it are accepted as retweeters.

        Returns:
            retweeters: the retweeting users, ``self.retweet_user_size`` long.
                A plain list when enough retweeters were found, a numpy array
                when padding was needed.
            retweet_lens: dict user_id -> length of the retweet path between
                the user and the source tweet (the ROOT user maps to 0).
            time_outs: dict user_id -> shifted ``time_out`` at which the user
                was accepted as a retweeter.

        Raises:
            ValueError: if no line containing "ROOT" is found, or if padding
                is required but no retweeter was collected.
            AssertionError: if a shifted ``time_out`` is still negative.
        """
        retweeters = []
        # Same content as `retweeters`, kept for constant-time membership tests.
        known_retweeters = set()
        retweet_lens = {}
        time_outs = {}
        root_found = False

        # First run to get the ROOT line and shift in time (if there is one)
        time_shift = 0
        with open(tree_file_name, "rt") as tree_file:
            for line in tree_file:
                _, _, _, user_out, _, time_out = parse_edge_line(line)

                if time_out < 0 and time_shift == 0:
                    # if buggy dataset, and we haven't found the time_shift yet
                    time_shift = -time_out
                if "ROOT" in line:
                    retweet_lens[user_out] = 0
                    root_found = True
                    break

        if not root_found:
            raise ValueError(f"Didn't find ROOT... File {tree_file_name} is corrupted")

        with open(tree_file_name, "rt") as tree_file:
            for line in tree_file:

                if 'ROOT' in line:
                    continue

                _, _, user_in, user_out, _, time_out = parse_edge_line(line)
                time_out += time_shift  # fix buggy dataset
                assert time_out >= 0

                if (self.time_cut is None) or (time_out <= self.time_cut):
                    # Accept only users that have features and whose parent is
                    # already connected to the source tweet.
                    if (user_out not in known_retweeters
                            and user_out in user_features
                            and user_in in retweet_lens):
                        retweeters.append(user_out)
                        known_retweeters.add(user_out)
                        retweet_lens[user_out] = retweet_lens[user_in] + 1
                        time_outs[user_out] = time_out
                if len(retweeters) == self.retweet_user_size:
                    break

        # Let all the retweet user lists have the same size.
        if len(retweeters) == self.retweet_user_size:
            return retweeters, retweet_lens, time_outs

        if not retweeters:
            # There is nothing to sample from; fail with an actionable message
            # instead of the opaque error np.random.choice raises here.
            raise ValueError(
                f"No eligible retweeter found in {tree_file_name}; "
                f"cannot pad to retweet_user_size={self.retweet_user_size}"
            )

        # Fewer retweeters than required: pad by random sampling (with
        # replacement) from the retweeters found. The full
        # `retweet_user_size` is drawn and then truncated, which keeps RNG
        # consumption unchanged for seeded runs.
        src_retweeters = np.array(retweeters)
        sampled = src_retweeters[np.random.choice(src_retweeters.shape[0],
                                                  self.retweet_user_size, replace=True)]
        padded = src_retweeters.tolist()
        padded.extend(sampled.tolist())
        retweeters = np.array(padded[:self.retweet_user_size])

        return retweeters, retweet_lens, time_outs