def load_raw_data(dataset_dir, only_binary=True):
    """Load tweet ids, contents and labels from a dataset directory.

    Args:
        dataset_dir: directory containing "label.txt" (one "label:news_id"
            per line), "source_tweets.txt" and the tree files.
        only_binary: if True, keep only news ids labeled 'false' or 'true'.

    Returns:
        ids: list of tweet ids that are in tree file names and pass the
            label filter.
        contents: tweet text corresponding to each entry of `ids`.
        labels: dict mapping every news id to its label (unfiltered).
    """
    trees_to_parse = get_tree_file_names(dataset_dir)

    labels = {}
    with open(os.path.join(dataset_dir, "label.txt")) as label_file:
        for line in label_file:  # iterate lazily instead of readlines()
            label, news_id = line.split(":")
            labels[int(news_id)] = label

    # Use a set: membership is tested once per tree file in the loop below,
    # so a list would make this loop quadratic.
    news_ids_to_consider = set(labels.keys())
    if only_binary:
        news_ids_to_consider = {
            news_id for news_id in news_ids_to_consider
            if labels[news_id] in ('false', 'true')
        }

    id_tweet_dict = load_tweet_content(os.path.join(dataset_dir, "source_tweets.txt"))

    ids = []
    contents = []
    for tree_file_name in trees_to_parse:
        news_id = get_root_id(tree_file_name)
        if news_id in news_ids_to_consider:
            ids.append(news_id)
            contents.append(id_tweet_dict[news_id])
    return ids, contents, labels
def get_user_and_tweet_ids_in_train(self, trees_to_parse, train_ids):
    """Return the sets of user ids and tweet ids that appear in the train set.

    Args:
        trees_to_parse: iterable of tree file paths.
        train_ids: news ids belonging to the train split.

    Returns:
        (user_ids_in_train, tweet_ids_in_train): two sets collected from the
        edge lines of every train tree file (ROOT lines are skipped).
    """
    # Convert once: `in` on a list would be O(n) for every tree file.
    train_ids = set(train_ids)
    user_ids_in_train = set()
    tweet_ids_in_train = set()
    for tree_file_name in trees_to_parse:
        if get_root_id(tree_file_name) not in train_ids:
            continue
        with open(tree_file_name, "rt") as tree_file:
            for line in tree_file:  # iterate lazily instead of readlines()
                if "ROOT" in line:
                    continue
                tweet_in, tweet_out, user_in, user_out, _, _ = parse_edge_line(line)
                # user_ids_in_train may be bigger than tweet_ids_in_train
                user_ids_in_train.add(user_in)
                user_ids_in_train.add(user_out)
                tweet_ids_in_train.add(tweet_in)
                tweet_ids_in_train.add(tweet_out)
    return user_ids_in_train, tweet_ids_in_train
def create_dataset(self, dataset_type="graph", standardize_features=True,
                   on_gpu=False, oversampling_ratio=1):
    """Build train/val/test datasets from the tree files.

    Args:
        dataset_type: "graph", "sequential" or "raw".
        standardize_features: whether to standardize user features
            (statistics are computed from the train split only).
        on_gpu: if True, move graph data points to CUDA.
        oversampling_ratio: ratio forwarded to self.oversample.

    Returns:
        dict with keys "train", "val", "test":
            - "graph": lists of torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)
            - "sequential": lists of [sequential_data, y]
            - "raw": lists of per-edge records [label, news_id, *edge, *node_features]

    Raises:
        ValueError: if dataset_type is not supported.
    """
    if dataset_type not in ["graph", "sequential", "raw"]:
        raise ValueError("Supported dataset types are: 'graph', 'sequential', 'raw'.")

    start_time = time.time()

    trees_to_parse = utils.get_tree_file_names(self.dataset_dir)
    labels = self.load_labels()

    # Create train-val-test split.
    # Remove useless trees (i.e. with labels that we don't consider).
    news_ids_to_consider = list(labels.keys())
    if self.only_binary:
        news_ids_to_consider = [news_id for news_id in news_ids_to_consider
                                if labels[news_id] in ['false', 'true']]

    train_ids, val_ids = train_test_split(news_ids_to_consider,
                                          test_size=0.1, random_state=self.seed)
    train_ids, test_ids = train_test_split(train_ids,
                                           test_size=0.25, random_state=self.seed * 7)
    print(f"Len train/val/test {len(train_ids)} {len(val_ids)} {len(test_ids)}")

    user_ids_in_train, tweet_ids_in_train = \
        self.get_user_and_tweet_ids_in_train(trees_to_parse, train_ids)

    tweet_features = self.load_tweet_features()
    user_features = self.load_user_features()

    if standardize_features:
        print("Standardizing features")
    # Preprocessing statistics use train ids/users only, to avoid leaking
    # val/test information into the features.
    preprocessed_tweet_fts = self.preprocess_tweet_features(tweet_features, tweet_ids_in_train)
    preprocessed_user_fts = self.preprocess_user_features(user_features, user_ids_in_train,
                                                          standardize_features)
    # basic_tests.test_user_preprocessed_features(preprocessed_user_fts)

    ids_to_dataset = {news_id: 'train' for news_id in train_ids}
    ids_to_dataset.update({news_id: 'val' for news_id in val_ids})
    ids_to_dataset.update({news_id: 'test' for news_id in test_ids})

    dataset = {'train': [], 'val': [], 'test': []}

    trees = []
    for tree_file_name in trees_to_parse:
        news_id = utils.get_root_id(tree_file_name)
        label = labels[news_id]
        if (not self.only_binary) or (label in ['false', 'true']):
            node_features, edges = self.build_tree(tree_file_name,
                                                   tweet_fts=preprocessed_tweet_fts,
                                                   user_fts=preprocessed_user_fts)
            trees.append((news_id, label, node_features, edges))

    self.oversample(trees, ids_to_dataset, ratio=oversampling_ratio)

    if dataset_type == "graph":
        # Hoisted out of the per-tree loop: one import is enough.
        import torch_geometric

    for news_id, label, node_features, edges in trees:
        if dataset_type == "graph":
            x = torch.tensor(node_features, dtype=torch.float32)
            y = torch.tensor(utils.to_label(label))
            # Keep only [node_in, node_out]; change if you want the time somewhere.
            edge_index = np.array([edge[:2] for edge in edges], dtype=int)
            edge_index = torch.tensor(edge_index).t().contiguous()
            data_point = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)
            if on_gpu:
                # BUG FIX: the old code called x.to(cuda)/y.to(cuda)/
                # edge_index.to(cuda) and discarded the results —
                # Tensor.to() is not in-place. Moving the Data object (and
                # keeping the returned value) moves all its tensors at once.
                data_point = data_point.to(torch.device("cuda"))
            dataset[ids_to_dataset[news_id]].append(data_point)
            # Uncomment for test, to see if graphs are well created
            # if news_id in [580320684305416192, 387021726007042051]:
            #     basic_tests.inspect_graph(dataset[ids_to_dataset[news_id]][-1], news_id)
        elif dataset_type == "sequential":
            y = utils.to_label(label)
            # Features of the successive new tweet-user tuples encountered over time.
            sequential_data = np.array(node_features)
            dataset[ids_to_dataset[news_id]].append([sequential_data, y])
        elif dataset_type == "raw":
            # edge = [node_index_in, node_index_out, time_out, uid_in, uid_out]
            dataset[ids_to_dataset[news_id]].append(
                [[label, news_id] + edge + list(node_features[edge[1]])
                 for edge in edges])

    print(f"Dataset loaded in {time.time() - start_time:.3f}s")
    return dataset
def create_dataset(self, dataset_type="id_index", standardize_features=True,
                   on_gpu=False, oversampling_ratio=1):
    """Build retweet-based datasets keyed by split or by news id.

    Args:
        dataset_type: "train_val" (dict keyed by split, arrays stacked over
            samples) or "id_index" (dict keyed by news id, one record each).
            "raw" is accepted by the validation check but was never
            implemented — see the explicit error below.
        standardize_features: unused here (feature preprocessing is disabled).
        on_gpu: unused here; kept for signature compatibility.
        oversampling_ratio: unused here; kept for signature compatibility.

    Returns:
        dict with keys "train"/"val"/"test" (train_val) or news ids
        (id_index); each value holds 'data_all', 'padded_docs', 'cos',
        'label'.

    Raises:
        ValueError: if dataset_type is not one of the supported strings.
        NotImplementedError: for dataset_type == "raw".
    """
    if dataset_type not in ["train_val", "id_index", "raw"]:
        raise ValueError("Supported dataset types are: 'train_val', 'id_index', 'raw'.")
    if dataset_type == "raw":
        # BUG FIX: "raw" passed the validation check but no branch below ever
        # created `dataset`, so the function previously crashed with a
        # NameError at return. Fail fast with an explicit message instead.
        raise NotImplementedError("dataset_type 'raw' is not implemented.")

    start_time = time.time()

    trees_to_parse = get_tree_file_names(self.dataset_dir)
    labels = self.load_labels()

    # Create train-val-test split.
    # Remove useless trees (i.e. with labels that we don't consider).
    news_ids_to_consider = list(labels.keys())
    if self.only_binary:
        news_ids_to_consider = [news_id for news_id in news_ids_to_consider
                                if labels[news_id] in ['false', 'true']]
    train_ids, val_ids = train_test_split(news_ids_to_consider,
                                          test_size=0.01, random_state=self.seed)
    train_ids, test_ids = train_test_split(train_ids,
                                           test_size=0.3, random_state=self.seed * 7)
    print(f"Len train/val/test {len(train_ids)} {len(val_ids)} {len(test_ids)}")

    user_ids_in_train, tweet_ids_in_train = \
        self.get_user_and_tweet_ids_in_train(trees_to_parse, train_ids)

    # NOTE(review): BERT tweet features and tweet/user preprocessing were
    # disabled in the original code; one-hot features are used as-is.
    tweet_features = self.load_tweet_features_one_hot()
    user_features = self.load_user_features()

    print("User features:")
    for key in user_features:
        for k in user_features[key]:
            print('\t' + k)
        break  # presumably all users share the same feature names; print once

    ids_to_dataset = {news_id: 'train' for news_id in train_ids}
    ids_to_dataset.update({news_id: 'val' for news_id in val_ids})
    ids_to_dataset.update({news_id: 'test' for news_id in test_ids})

    print("Parsing trees...")
    trees = []
    for tree_file_name in trees_to_parse:
        news_id = get_root_id(tree_file_name)
        label = labels[news_id]
        if (not self.only_binary) or (label in ['false', 'true']):
            retweeters, retweet_lens, time_outs = \
                self.get_retweet_list(tree_file_name, user_features)
            adj, retweeter_fts = self.get_retweeter_adj(news_id, retweeters,
                                                        retweet_lens, time_outs,
                                                        user_features)
            trees.append((news_id, label, retweeter_fts,
                          tweet_features[news_id], adj))
    print("trees num: {}".format(len(trees)))

    print("Generating dataset...")
    if dataset_type == "train_val":
        dataset = {split: {'data_all': [], 'padded_docs': [], 'cos': [], 'label': []}
                   for split in ('train', 'val', 'test')}
    else:  # "id_index"
        dataset = {news_id: {} for news_id, *_ in trees}

    for news_id, label, retweeter_fts, tweet_feature, adj in trees:
        x = tweet_feature
        y = np.array(to_label(label))
        retweeter_fts = retweeter_fts.astype('float')
        if dataset_type == "train_val":
            split = dataset[ids_to_dataset[news_id]]
            split['data_all'].append(retweeter_fts)
            split['padded_docs'].append(x)
            split['cos'].append(adj)
            split['label'].append(y)
        else:  # "id_index": one record per news id
            dataset[news_id]['data_all'] = retweeter_fts
            dataset[news_id]['padded_docs'] = x
            dataset[news_id]['cos'] = adj
            dataset[news_id]['label'] = y

    if dataset_type == 'train_val':
        # Stack the per-sample lists into numpy arrays, one array per field.
        for key in dataset:
            dataset[key]['data_all'] = np.array(dataset[key]['data_all'])
            dataset[key]['padded_docs'] = np.array(dataset[key]['padded_docs'])
            dataset[key]['cos'] = np.array(dataset[key]['cos'])
            dataset[key]['label'] = np.array(dataset[key]['label'])

    print(f"Dataset loaded in {time.time() - start_time:.3f}s")
    return dataset