Exemple #1
0
# Append one CSV row describing this run to the results file, then stop.
# Row layout: literal "tgat", DATA, time_batch and time_epoch (each formatted
# to four decimals), and the hyper-parameter summary. The summary contains
# commas, so it is wrapped in double quotes to stay inside one CSV field.
params = f"batchsize={BATCH_SIZE},layers={NUM_LAYER},neighbors={NUM_NEIGHBORS},uniform={UNIFORM}"
row = "tgat,{},{:.4f},{:.4f},\"{}\"".format(DATA, time_batch, time_epoch, params)
with open(res_path, "a") as file:
    # Append mode keeps rows written by earlier runs.
    file.write(row + "\n")

# Terminate with status 0. The bare `exit` helper is injected by the `site`
# module for interactive sessions and is unavailable under `python -S`;
# raising SystemExit directly has the same effect without that dependency.
raise SystemExit(0)

# NOTE(review): idx_list is shuffled here and again at the start of every
# epoch, but the visible batch code takes contiguous slices of train_*_l and
# never indexes through idx_list — confirm whether the shuffle is meant to
# drive batch selection.
np.random.shuffle(idx_list)

early_stopper = EarlyStopMonitor()
epoch_bar = trange(NUM_EPOCH)
for epoch in epoch_bar:
    # Training: neighbor lookups use only the training graph.
    tgan.ngh_finder = train_ngh_finder
    np.random.shuffle(idx_list)
    batch_bar = trange(num_batch)
    for k in batch_bar:
        s_idx = k * BATCH_SIZE
        # Slice ends are exclusive, so the upper bound is num_instance itself.
        # Clamping to num_instance - 1 silently dropped the last instance and
        # produced an empty final batch when num_instance % BATCH_SIZE == 1.
        # Assumes num_instance == len(train_src_l) — TODO confirm.
        e_idx = min(num_instance, s_idx + BATCH_SIZE)
        src_l_cut = train_src_l[s_idx:e_idx]
        dst_l_cut = train_dst_l[s_idx:e_idx]
        ts_l_cut = train_ts_l[s_idx:e_idx]
        size = len(src_l_cut)
        # Draw as many random (source, destination) pairs as there are real
        # edges in this batch; presumably used as negative examples.
        src_l_fake, dst_l_fake = train_rand_sampler.sample(size)

        # Label tensors are constants, so build them without autograd tracking.
        with torch.no_grad():
            pos_label = torch.ones(size, dtype=torch.float, device=device)
            neg_label = torch.zeros(size, dtype=torch.float, device=device)
# Report the size of the training set and the number of batches per epoch.
logger.debug('num of training instances: {}'.format(num_instance))
logger.debug('num of batches per epoch: {}'.format(num_batch))
# Index array over all instances, shuffled in place.
# NOTE(review): idx_list is rebound further down before any visible use.
idx_list = np.arange(num_instance)
np.random.shuffle(idx_list) 

# Restore previously saved TGAN weights and switch the model to eval mode.
logger.info('loading saved TGAN model')
model_path = f'./saved_models/{args.prefix}-{args.agg_method}-{args.attn_mode}-{DATA}.pth'
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source. No map_location is passed, so tensors are restored to the
# device they were saved from; confirm this matches the runtime device.
tgan.load_state_dict(torch.load(model_path))
tgan.eval()
logger.info('TGAN models loaded')
logger.info('Start training node classification task')

# Classifier head built from the second dimension of n_feat
# (presumably the node-feature width — verify against LR's constructor).
lr_model = LR(n_feat.shape[1])
lr_optimizer = torch.optim.Adam(lr_model.parameters(), lr=args.lr)
# NOTE(review): moving lr_model to `device` is disabled here — confirm intended:
######################## lr_model = lr_model.to(device)
# Neighbor lookups now go through full_ngh_finder instead of the finder set earlier.
tgan.ngh_finder = full_ngh_finder
# Rebinds idx_list to an unshuffled index over the training edges,
# discarding the shuffled array created above.
idx_list = np.arange(len(train_src_l))
# Separate binary cross-entropy criteria for training and evaluation.
lr_criterion = torch.nn.BCELoss()
lr_criterion_eval = torch.nn.BCELoss()

def eval_epoch(src_l, dst_l, ts_l, label_l, batch_size, lr_model, tgan, num_layer=NODE_LAYER):
    pred_prob = np.zeros(len(src_l))
    loss = 0
    num_instance = len(src_l)
    num_batch = math.ceil(num_instance / batch_size)
    with torch.no_grad():
        lr_model.eval()
        tgan.eval()
        for k in range(num_batch):          
            s_idx = k * batch_size
            e_idx = min(num_instance - 1, s_idx + batch_size)