Пример #1
0
    def load(self):
        """Load cached graphs, labels and auxiliary info from disk.

        Both file paths are built by prefixing the module-level ``path_saves``
        to the two name fragments returned by ``self.get_dataset_name()``.
        Populates ``self.graphs``, ``self.labels``, ``self.data['typemaps']``
        and ``self.data['coordinates']``.

        NOTE(review): assumes ``self.data`` already exists as a mutable mapping
        and that ``get_dataset_name()`` yields exactly two fragments (the
        unpacking below raises ValueError otherwise) — confirm with the class.
        """
        # Generate paths (graphs file first, info file second)
        graphs_path, info_path = tuple((path_saves + x) for x in self.get_dataset_name())

        # Load graphs
        self.graphs, label_dict = load_graphs(graphs_path)
        self.labels = label_dict['labels']

        # Load info once and reuse it — deserializing the same pickle twice
        # doubles the I/O for no benefit.
        info = load_info(info_path)
        self.data['typemaps'] = info['typemaps']
        self.data['coordinates'] = info['coordinates']
Пример #2
0
def train(args):
    """Train a HetGNN model with negative sampling and pickle node embeddings.

    Reads ``neighbor_graph.bin`` (first graph is used) and ``in_feats.pkl``
    from ``args.data_path``. It then optimizes a binary cross-entropy loss
    that separates positive-edge scores from scores on a freshly sampled
    negative graph each epoch. Finally it writes the embeddings produced by
    the trained model to ``args.save_node_embed_path``.

    :param args: namespace providing ``seed``, ``data_path``, ``num_hidden``,
        ``lr``, ``weight_decay``, ``epochs`` and ``save_node_embed_path``.
    """
    set_random_seed(args.seed)
    g = load_graphs(os.path.join(args.data_path, 'neighbor_graph.bin'))[0][0]
    feats = load_info(os.path.join(args.data_path, 'in_feats.pkl'))

    model = HetGNN(feats['author'].shape[-1], args.num_hidden, g.ntypes)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    neg_sampler = RatioNegativeSampler()
    for epoch in range(args.epochs):
        model.train()
        embeds = model(g, feats)
        score = model.calc_score(g, embeds)
        neg_g = construct_neg_graph(g, neg_sampler)
        neg_score = model.calc_score(neg_g, embeds)
        logits = torch.cat([score, neg_score])  # (2A*E,)
        # Build targets with the same shape, dtype and device as the scores
        # so the loss also works when the model/graph is not on CPU float32.
        labels = torch.cat(
            [torch.ones_like(score),
             torch.zeros_like(neg_score)])
        loss = F.binary_cross_entropy_with_logits(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch, loss.item()))

    # Leave training mode before inference so any train-only layers
    # (e.g. dropout, if HetGNN has any) do not perturb the saved embeddings.
    model.eval()
    with torch.no_grad():
        final_embeds = model(g, feats)
    with open(args.save_node_embed_path, 'wb') as f:
        pickle.dump(final_embeds, f)
    print('Final node embeddings saved to', args.save_node_embed_path)
Пример #3
0
 def load(self):
     """Restore the cached graph and name tables from ``self.save_path``.

     Reads ``<name>_dgl_graph.bin`` (only the first graph is kept, as
     ``self.g``) and ``<name>_info.pkl``. The info file supplies
     ``author_names``, ``paper_titles`` and ``conf_names``.
     """
     graphs, _ = load_graphs(os.path.join(self.save_path, self.name + '_dgl_graph.bin'))
     self.g = graphs[0]
     # Use os.path.join consistently: the bare ``join`` used previously is not
     # necessarily imported alongside ``os`` and would raise NameError.
     info = load_info(os.path.join(self.save_path, self.name + '_info.pkl'))
     self.author_names = info['author_names']
     self.paper_titles = info['paper_titles']
     self.conf_names = info['conf_names']
Пример #4
0
    def load(self):
        """Restore TU graphs, their labels and dataset metadata from cache.

        Reads ``tu_<name>.bin`` and ``tu_<name>.pkl`` from ``self.save_path``
        and sets ``graph_lists``, ``graph_labels``, ``max_num_node`` and
        ``num_labels``.
        """
        graph_file, meta_file = (
            os.path.join(self.save_path, pattern.format(self.name))
            for pattern in ('tu_{}.bin', 'tu_{}.pkl')
        )
        loaded_graphs, labels_by_key = load_graphs(str(graph_file))
        meta = load_info(str(meta_file))

        self.graph_lists = loaded_graphs
        self.graph_labels = labels_by_key['labels']
        self.max_num_node = meta['max_num_node']
        self.num_labels = meta['num_labels']
Пример #5
0
 def load(self):
     """Restore the cached graph, class count, split masks and positive pairs.

     The graph comes from ``<save_path>/<name>_dgl_graph.bin``. The positive
     pairs come from ``<raw_path>/<name>_pos.pkl``; note that this path uses
     ``raw_path``, not ``save_path``.
     """
     bin_file = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
     graph_list, _ = load_graphs(bin_file)
     self.g = graph_list[0]

     target = self.predict_ntype
     label_tensor = self.g.nodes[target].data['label']
     self._num_classes = label_tensor.max().item() + 1

     # Cast each split mask to bool (the stored dtype is not visible here).
     for mask_key in ['train_mask', 'val_mask', 'test_mask']:
         stored_mask = self.g.nodes[target].data[mask_key]
         self.g.nodes[target].data[mask_key] = stored_mask.bool()

     pos_file = os.path.join(self.raw_path, self.name + '_pos.pkl')
     pos_info = load_info(pos_file)
     self.pos_i, self.pos_j = pos_info['pos_i'], pos_info['pos_j']
Пример #6
0
def load_dgl(graph_path, info_path=None):
    """Load saved DGL graphs, their labels and, optionally, extra info.

    :param graph_path: path passed to ``load_graphs``; the label dict stored
        with the graphs must contain a ``'labels'`` entry.
    :param info_path: optional path passed to ``load_info``; when given, the
        loaded mapping's ``'info'`` entry is returned as well.
    :return: ``(graphs, labels)`` when ``info_path`` is None, otherwise
        ``(graphs, labels, info)``.
    :raises KeyError: if ``'labels'`` (or, with ``info_path``, ``'info'``)
        is missing from the loaded data.
    """
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info('Loading graph data from: %s', graph_path)
    graphs, label_dict = load_graphs(graph_path)
    labels = label_dict['labels']
    if info_path is None:
        return graphs, labels
    info = load_info(info_path)['info']
    return graphs, labels, info
Пример #7
0
    def load(self):
        """Restore graphs, labels and dataset statistics from the cache.

        Both cache files live in ``self.save_path`` and are keyed by
        ``self.name`` and ``self.hash``. Sets ``graphs`` and ``labels``, plus
        one attribute per statistic stored in the info pickle.
        """
        graph_file, stats_file = (
            os.path.join(self.save_path, pattern.format(self.name, self.hash))
            for pattern in ('dgl_graph_{}_{}.bin', 'dgl_graph_{}_{}.pkl')
        )
        graph_list, labels_by_key = load_graphs(str(graph_file))
        stats = load_info(str(stats_file))

        self.graphs = graph_list
        self.labels = labels_by_key['labels']
        # Copy each stored statistic onto an attribute of the same name.
        for stat_name in ('num_graphs', 'num_labels', 'max_labels',
                          'max_node_id', 'max_num_unique_node',
                          'max_seq_length'):
            setattr(self, stat_name, stats[stat_name])