def generate_subgraph_datasets(params, splits=('train', 'valid'), saved_relation2id=None, max_label_value=None):
    """Sample negative links for each split and hand everything to ``links2subgraphs``.

    Args:
        params: Configuration object. This function reads ``file_paths``,
            ``main_dir``, ``dataset``, ``max_links``,
            ``num_neg_samples_per_link`` and ``constrained_neg_prob`` from it,
            plus ``test_file`` when the ``'test'`` split is requested. It is
            also passed through to ``links2subgraphs``.
        splits: Iterable of split names to process. Each name must be a key of
            the triplets mapping returned by ``process_files``. Including
            ``'test'`` switches the function to testing mode.
        saved_relation2id: Forwarded unchanged to ``process_files``.
        max_label_value: Forwarded unchanged to ``links2subgraphs``.

    Returns:
        None. The results are the side effects: ``relation2id.json`` is written
        when not testing, the sampled test negatives are saved when testing,
        and ``links2subgraphs`` is called with the sampled links.
    """
    testing = 'test' in splits

    adj_list, triplets, _, relation2id, id2entity, id2relation = process_files(
        params.file_paths, saved_relation2id)

    # plot_rel_dist(adj_list, os.path.join(params.main_dir, f'data/{params.dataset}/rel_dist.png'))

    data_path = os.path.join(params.main_dir, f'data/{params.dataset}/relation2id.json')
    # NOTE(review): data_path names a file, so the isdir() check is true in
    # practice and the mapping is rewritten on every non-test run. Kept as-is
    # to preserve behaviour; confirm whether "skip if the file exists" was meant.
    if not os.path.isdir(data_path) and not testing:
        with open(data_path, 'w') as f:
            json.dump(relation2id, f)

    graphs = {
        split_name: {'triplets': triplets[split_name], 'max_size': params.max_links}
        for split_name in splits
    }

    # Sample train and valid/test links
    for split_name, split in graphs.items():
        logging.info(f"Sampling negative links for {split_name}")
        split['pos'], split['neg'] = sample_neg(
            adj_list,
            split['triplets'],
            params.num_neg_samples_per_link,
            max_size=split['max_size'],
            constrained_neg_prob=params.constrained_neg_prob,
        )

    if testing:
        # Persist the sampled test negatives next to the dataset files.
        directory = os.path.join(params.main_dir, 'data/{}/'.format(params.dataset))
        save_to_file(
            directory,
            f'neg_{params.test_file}_{params.constrained_neg_prob}.txt',
            graphs['test']['neg'],
            id2entity,
            id2relation,
        )

    links2subgraphs(adj_list, graphs, params, max_label_value)
def __init__(self, db_path, db_name_pos, db_name_neg, raw_data_paths, included_relations=None,
             add_traspose_rels=False, num_neg_samples_per_link=1, use_kge_embeddings=False,
             dataset='', kge_model='', file_name=''):
    """Open the subgraph LMDB read-only and load the graph plus stored statistics.

    Args:
        db_path: Path passed to ``lmdb.open``.
        db_name_pos: Name of the sub-database opened as ``self.db_pos``.
        db_name_neg: Accepted but not used here; the negative sub-database is
            not opened.
        raw_data_paths: Forwarded to ``process_files``.
        included_relations: Forwarded to ``process_files``.
        add_traspose_rels: When true, the transpose of every adjacency matrix
            is appended to the relation graph list.
        num_neg_samples_per_link: Stored on the instance unchanged.
        use_kge_embeddings: When true, ``get_kge_embeddings(dataset, kge_model)``
            supplies ``node_features`` and ``kge_entity2id``; otherwise both
            are ``None``.
        dataset: Passed to ``get_kge_embeddings``.
        kge_model: Passed to ``get_kge_embeddings``.
        file_name: Stored on the instance unchanged.
    """
    self.main_env = lmdb.open(db_path, readonly=True, max_dbs=3, lock=False)
    self.db_pos = self.main_env.open_db(db_name_pos.encode())
    # The negative-sample sub-database (db_name_neg) is intentionally not opened.

    if use_kge_embeddings:
        self.node_features, self.kge_entity2id = get_kge_embeddings(dataset, kge_model)
    else:
        self.node_features, self.kge_entity2id = None, None
    self.num_neg_samples_per_link = num_neg_samples_per_link
    self.file_name = file_name

    ssp_graph, _, _, _, id2entity, id2relation = process_files(raw_data_paths, included_relations)
    self.relation_list = list(id2relation.keys())
    # Relation count before any transposed copies are appended.
    self.num_rels = len(ssp_graph)

    # Append transposed matrices so both directions of each relation exist.
    # The right-hand side is fully built before the in-place extension.
    if add_traspose_rels:
        ssp_graph += [adj.T for adj in ssp_graph]

    # Effective relation count after the optional augmentation above.
    self.aug_num_rels = len(ssp_graph)
    self.graph = ssp_multigraph_to_dgl(ssp_graph)
    self.ssp_graph = ssp_graph
    self.id2entity = id2entity
    self.id2relation = id2relation

    self.max_n_label = np.array([0, 0])
    with self.main_env.begin() as txn:
        # Slot 0 is the subject-side maximum, slot 1 the object-side maximum.
        for slot, key in enumerate(('max_n_label_sub', 'max_n_label_obj')):
            self.max_n_label[slot] = int.from_bytes(txn.get(key.encode()), byteorder='little')

        # Each statistic is stored under a key equal to its attribute name and
        # is kept as the 1-tuple that struct.unpack returns.
        for quantity in ('subgraph_size', 'enc_ratio', 'num_pruned_nodes'):
            for stat in ('avg', 'min', 'max', 'std'):
                attr_name = f'{stat}_{quantity}'
                setattr(self, attr_name, struct.unpack('f', txn.get(attr_name.encode())))

    logging.info(f"Max distance from sub : {self.max_n_label[0]}, Max distance from obj : {self.max_n_label[1]}")

    with self.main_env.begin(db=self.db_pos) as txn:
        self.num_graphs_pos = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little')
    # No count is read for the negative sub-database because it is not opened.

    # Fetch the first item once during construction; its result is discarded.
    self.__getitem__(0)