Exemplo n.º 1
0
# Training graph as per-node adjacency lists, indexed by node id in
# 0..max_idx. Each bucket holds (neighbor, edge index, timestamp) tuples
# taken from the training edges only.
adj_list = [[] for _ in range(max_idx + 1)]
for src, dst, eidx, ts in zip(
        train_src_l, train_dst_l, train_e_idx_l, train_ts_l):
    # Record the interaction under both endpoints, source side first, so the
    # edge can be looked up from either node.
    for endpoint, neighbor in ((src, dst), (dst, src)):
        adj_list[endpoint].append((neighbor, eidx, ts))
# NOTE(review): `uniform` presumably selects the neighbor sampling strategy;
# confirm against NeighborFinder.
train_ngh_finder = NeighborFinder(adj_list, uniform=UNIFORM)

# Full graph (every edge in src_l / dst_l, not only the training ones), used
# for validation and testing. Same layout as the training adjacency lists:
# one bucket of (neighbor, edge index, timestamp) tuples per node id.
full_adj_list = [[] for _ in range(max_idx + 1)]
for src, dst, eidx, ts in zip(src_l, dst_l, e_idx_l, ts_l):
    # Record the interaction under both endpoints, source side first.
    for endpoint, neighbor in ((src, dst), (dst, src)):
        full_adj_list[endpoint].append((neighbor, eidx, ts))
full_ngh_finder = NeighborFinder(full_adj_list, uniform=UNIFORM)

# Random edge samplers: the training sampler is built from the training
# edges only, while validation and test each get their own, separate sampler
# built from the full edge lists.
train_rand_sampler = RandEdgeSampler(train_src_l, train_dst_l)
val_rand_sampler, test_rand_sampler = (
    RandEdgeSampler(src_l, dst_l) for _ in range(2)
)

# Model initialization: pick the compute device.
# CUDA_DEVICE_ORDER=PCI_BUS_ID asks CUDA to number devices by PCI bus id;
# it is set before the device object is created so that the GPU index below
# is interpreted under that ordering.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID")
device = torch.device("cuda:{}".format(GPU))
tgan = TGAN(train_ngh_finder,
            n_feat,
            e_feat,
            n_feat_freeze=args.freeze,
            num_layers=NUM_LAYER,
            use_time=USE_TIME,
            agg_method=AGG_METHOD,
            attn_mode=ATTN_MODE,
            seq_len=SEQ_LEN,