# Example No. 1
import os

import pandas as pd
import torch
from torch_geometric.data import Data

# Input directory with the tab-separated edge lists of the freebase
# "book_lo_business" split, and the directory the result is written to.
DATA_DIR = 'data/freebase/book_lo_business/index_separated'
OUT_DIR = 'datasets-freebase'


def _read_edges(path):
    """Load the first two tab-separated columns of *path* as a 2 x E tensor.

    Row 0 of the result is file column 0 and row 1 is file column 1, so each
    tensor column is one edge. ``.contiguous()`` materialises the transpose,
    which would otherwise be a strided view of the NumPy array.
    """
    frame = pd.read_csv(path, usecols=[0, 1], sep='\t', header=None)
    return torch.from_numpy(frame.to_numpy().T).contiguous()


bb = _read_edges(f'{DATA_DIR}/link_book-book_1.dat')
bp = _read_edges(f'{DATA_DIR}/link_book-business_4.dat')
pp = _read_edges(f'{DATA_DIR}/link_business-business_5.dat')
# The two location&organization relations are loaded but not stored in
# ``data`` below.
# NOTE(review): unlike the other files, this name has no "<src>-<dst>" pair
# (e.g. "location&organization-location&organization"); confirm the path.
bq = _read_edges(f'{DATA_DIR}/link_book-location&organization_2.dat')
qq = _read_edges(f'{DATA_DIR}/link_location&organization_3.dat')


# train = torch.from_numpy(pd.read_csv('data/freebase_org/book_{}/index_separated/label.dat.train'.format(l[n]), sep='\t', header=None).to_numpy().T)
# test = torch.from_numpy(pd.read_csv('data/freebase/book_{}/index_separated/label.dat.test'.format(l[n]), sep='\t', header=None).to_numpy().T)


# In this layout the "a" node type is book and the "p" node type is business:
# "aa" edges come from book-book, "pp" from business-business.
data = Data()
data.aa_edge_idx = bb
# Stored as business -> book, i.e. the two rows of ``bp`` swapped.
data.pa_edge_idx = bp.flip(0)
data.pp_edge_idx = pp

# Node counts: largest id seen for each node type plus one
# (presumably ids are 0-based and dense — TODO confirm).
data.n_a_node = int(max(bb.max(), bp[0].max())) + 1
data.n_p_node = int(max(pp.max(), bp[1].max())) + 1

# NOTE(review): magic constant, presumably the number of "a"-node label
# classes for this split — confirm.
data.n_a_type = 8

data.n_aa_edge = bb.shape[1]
data.n_pp_edge = pp.shape[1]
data.n_pa_edge = bp.shape[1]

# torch.save does not create missing parent directories.
os.makedirs(OUT_DIR, exist_ok=True)
torch.save(data, f'{OUT_DIR}/business.pt')

# Example No. 2
# aa.n_a_linked = len(linked_author)
# pp.n_p_linked = len(paper_with_author)
#
# pp.n_edge = pp.edge_index.shape[1]
#
# torch.save(aa, 'datasets-auta/aa.pt')
# torch.save(pa, 'datasets-auta/pa.pt')
# torch.save(pp, 'datasets-auta/pp.pt')

# Reload the three per-relation graph objects written by the (commented-out)
# save step above, in the order aa, pa, pp.
# NOTE(review): on recent PyTorch releases torch.load defaults to
# weights_only=True, which may refuse to unpickle these objects — confirm and
# pass weights_only=False if needed (files are locally produced).
aa, pa, pp = (
    torch.load(path)
    for path in ('datasets-auta/aa.pt', 'datasets-auta/pa.pt', 'datasets-auta/pp.pt')
)

# auta-2 -- all the data
data = Data()

# One edge-index tensor per relation, taken as-is from the loaded objects.
data.aa_edge_idx, data.pa_edge_idx, data.pp_edge_idx = (
    aa.edge_index,
    pa.edge_index,
    pp.edge_index,
)
# Per-type node counts and per-relation edge counts, copied from the inputs.
data.n_a_node, data.n_p_node = aa.n_node, pp.n_node
data.n_aa_edge, data.n_pa_edge, data.n_pp_edge = aa.n_edge, pa.n_edge, pp.n_edge

# Train/test split derived from aa.node_label_list by process_node_multilabel
# (defined elsewhere); its six return values are unpacked in this order.
(
    data.train_node_idx,
    data.train_node_class,
    data.train_range,
    data.test_node_idx,
    data.test_node_class,
    data.test_range,
) = process_node_multilabel(aa.node_label_list)

# NOTE(review): magic constant, presumably the number of "a"-node label
# classes — confirm.
data.n_a_type = 7

torch.save(data, 'datasets-auta/auta-2.pt')

# auta-0 -- no unlinked paper
p_linked = set(range(pp.n_p_linked))