コード例 #1
0
ファイル: heco.py プロジェクト: ZZy979/GNN-Recommendation
 def save(self):
     """Persist the processed graph and the positive-sample index arrays.

     The graph ``self.g`` is written as a one-element list to
     ``<save_path>/<name>_dgl_graph.bin``. ``self.pos_i`` and ``self.pos_j``
     are stored with ``save_info`` in ``<name>_pos.pkl``. That file goes under
     ``raw_path``, NOT ``save_path``; any loader must look there.
     """
     graph_file = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
     save_graphs(graph_file, [self.g])
     pos_file = os.path.join(self.raw_path, self.name + '_pos.pkl')
     pos_info = dict(pos_i=self.pos_i, pos_j=self.pos_j)
     save_info(pos_file, pos_info)
コード例 #2
0
 def save(self):
     """Write the graph and its name lookup tables to ``save_path``.

     ``self.g`` goes to ``<name>_dgl_graph.bin``. The author, paper-title and
     conference name collections are bundled into ``<name>_info.pkl`` with
     ``save_info``.
     """
     prefix = join(self.save_path, self.name)
     save_graphs(prefix + '_dgl_graph.bin', [self.g])
     names = dict(
         author_names=self.author_names,
         paper_titles=self.paper_titles,
         conf_names=self.conf_names,
     )
     save_info(prefix + '_info.pkl', names)
コード例 #3
0
 def save(self):
     """Save the graph list, graph labels and size metadata.

     Graphs go to ``<save_path>/tu_<name>.bin`` with the labels attached under
     the ``'labels'`` key. ``max_num_node`` and ``num_labels`` go to
     ``<save_path>/tu_<name>.pkl``.
     """
     stem = 'tu_{}'.format(self.name)
     bin_file = os.path.join(self.save_path, stem + '.bin')
     pkl_file = os.path.join(self.save_path, stem + '.pkl')
     labels = dict(labels=self.graph_labels)
     meta = dict(max_num_node=self.max_num_node, num_labels=self.num_labels)
     # str() kept: guarantees a text path is handed to the DGL savers
     save_graphs(str(bin_file), self.graph_lists, labels)
     save_info(str(pkl_file), meta)
コード例 #4
0
 def save(self):
     """Save ``self._g`` and the node/relation counts.

     The graph data is written to ``<save_path>/<save_name>.bin``.
     ``num_nodes`` and ``num_rels`` are written to
     ``<save_path>/<save_name>.pkl``. No labels are saved here.
     """
     bin_file, pkl_file = (
         os.path.join(self.save_path, self.save_name + ext)
         for ext in ('.bin', '.pkl')
     )
     save_graphs(str(bin_file), self._g)
     counts = dict(num_nodes=self.num_nodes, num_rels=self.num_rels)
     save_info(str(pkl_file), counts)
コード例 #5
0
ファイル: socnavData.py プロジェクト: gnns4hri/sonata
    def save(self):
        """Write graphs, labels and extra metadata under ``path_saves``.

        Nothing is written in debug mode. Otherwise the two file names
        returned by ``get_dataset_name()`` are prefixed with ``path_saves``.
        The parent directory is created if needed, and then:

        - the graphs plus their labels go to the first path;
        - the ``typemaps`` and ``coordinates`` entries of ``self.data`` go to
          the second path.
        """
        if self.debug:
            return

        graphs_path, info_path = [path_saves + fname for fname in self.get_dataset_name()]
        target_dir = os.path.dirname(path_saves)
        os.makedirs(target_dir, exist_ok=True)

        label_map = {'labels': self.labels}
        save_graphs(graphs_path, self.graphs, label_map)

        extras = {
            'typemaps': self.data['typemaps'],
            'coordinates': self.data['coordinates'],
        }
        save_info(info_path, extras)
コード例 #6
0
 def save(self):
     """Save graphs, labels and dataset statistics, keyed by name and hash.

     Both files share the stem ``dgl_graph_<name>_<hash>`` inside
     ``save_path``:

     - ``.bin`` holds the graphs with labels under ``'labels'``;
     - ``.pkl`` holds the size/count statistics listed below.
     """
     stem = os.path.join(
         self.save_path, 'dgl_graph_{}_{}'.format(self.name, self.hash))
     graph_file = stem + '.bin'
     info_file = stem + '.pkl'
     stats = dict(
         num_graphs=self.num_graphs,
         num_labels=self.num_labels,
         max_labels=self.max_labels,
         max_node_id=self.max_node_id,
         max_num_unique_node=self.max_num_unique_node,
         max_seq_length=self.max_seq_length,
     )
     save_graphs(str(graph_file), self.graphs, dict(labels=self.labels))
     save_info(str(info_file), stats)
コード例 #7
0
def save_dgl(graphs, labels, graph_path, info=None, info_path=None):
    """Save DGL graphs with their labels, and optionally an extra info object.

    The graphs and labels are written together to ``graph_path`` via
    ``save_graphs``, with the labels stored under the ``'labels'`` key.
    When ``info_path`` is given, ``info`` is wrapped as ``{'info': info}``
    and written there via ``save_info``.

    :param graphs: graphs to save, passed unchanged to ``save_graphs``.
    :param labels: labels saved alongside the graphs under ``'labels'``.
    :param graph_path: destination file for the graphs and labels.
    :param info: optional object to persist; only saved when ``info_path``
        is provided.
    :param info_path: destination file for ``info``; if None, no info file
        is written.
    """
    # save graphs and labels
    logger.info('Saving graph data: %s', graph_path)
    save_graphs(graph_path, graphs, {'labels': labels})
    # save other information in python dict
    if info_path is not None:
        save_info(info_path, {'info': info})
    elif info is not None:
        # Previously this case dropped ``info`` silently; surface it instead.
        logger.warning('info was provided but info_path is None; info not saved')
コード例 #8
0
def preprocess(args):
    """Load the neighbor graph and input features, then save both.

    Reads ``args.neighbor_path`` and ``args.pretrained_node_embed_path``
    through ``load_data``. The graph is stored as ``neighbor_graph.bin`` and
    the features as ``in_feats.pkl``, both inside ``args.save_path``.
    """
    neighbor_g, in_feats = load_data(
        args.neighbor_path, args.pretrained_node_embed_path)
    out_dir = args.save_path
    graph_file = os.path.join(out_dir, 'neighbor_graph.bin')
    save_graphs(graph_file, [neighbor_g])
    feats_file = os.path.join(out_dir, 'in_feats.pkl')
    save_info(feats_file, in_feats)
    print('Neighbor graph and input features saved to', out_dir)