# The opening of this call is missing from the excerpt; the positional
# arguments below (neighbor finder plus node/edge feature matrices) are
# assumed from the reference TGAT implementation.
tgan = TGAN(ngh_finder, n_feat, e_feat,
            num_layers=NUM_LAYER, use_time=USE_TIME, agg_method=AGG_METHOD, attn_mode=ATTN_MODE,
            seq_len=SEQ_LEN, n_head=NUM_HEADS, drop_out=DROP_OUT, node_dim=NODE_DIM, time_dim=TIME_DIM)
# optimizer = torch.optim.Adam(tgan.parameters(), lr=LEARNING_RATE)
# criterion = torch.nn.BCELoss()
tgan = tgan.to(device)

#logger.info('loading saved TGAN model')
model_path = f'./saved_models/{args.prefix}-{args.agg_method}-{args.attn_mode}-{DATA}.pth'
tgan.load_state_dict(torch.load(model_path, map_location=device))  # map to the active device in case the checkpoint was saved on GPU
tgan.eval()
#logger.info('TGAN models loaded')
#logger.info('TGAN starts generating representations')

# Generate temporal representations for every interaction: tem_conv
# aggregates each node's time-aware neighborhood over NODE_LAYER hops
# at the recorded interaction timestamps.
with torch.no_grad():
    src_embed = tgan.tem_conv(src_l, ts_l, NODE_LAYER)
    tgt_embed = tgan.tem_conv(dst_l, ts_l, NODE_LAYER)

# Keep only the first embedding dimension per interaction as a scalar
# score; the full vectors remain available in src_embed / tgt_embed.
g_df['user_embed'] = [emb[0] for emb in src_embed.tolist()]
g_df['tgt_embed'] = [emb[0] for emb in tgt_embed.tolist()]

# Map target page paths back to internal item ids, then pull each
# target page's embedding from the interaction table (the boolean mask
# built on orig_df assumes g_df still shares its index).
inv_i_dict = {v: k for k, v in i_dict.items()}
tgtpg_item = [inv_i_dict.get(i) for i in target_pages]
tgt_emb = [g_df.loc[orig_df['item'] == item, 'tgt_embed'].iloc[0] for item in tgtpg_item]
g_df = g_df.groupby('u').tail(1)  # keep each user's most recent interaction only
username = g_df['clientId'].tolist()
user = pd.DataFrame(g_df['user_embed']).set_index(g_df['clientId'])
page = pd.DataFrame(tgt_emb, columns=['page embed']).set_index(
    pd.Index(target_pages, name='pagePath'))
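# Score every user against each of the two target pages; since the
# embeddings above were reduced to scalars, the affinity is a plain product.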
s1 = [u * page['page embed'].iloc[0] for u in user['user_embed']]
s2 = [u * page['page embed'].iloc[1] for u in user['user_embed']]
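
# ----------------------------------------------------------------------
# Node-classification probe: train a logistic-regression head on frozen
# TGAN embeddings. `lr_model`, `lr_optimizer`, `lr_criterion`, and
# `eval_epoch` are defined elsewhere in the script; a minimal sketch of
# the head, assuming a single-logit linear probe, would be:
#
#     class LR(torch.nn.Module):
#         def __init__(self, dim, drop=0.3):
#             super().__init__()
#             self.fc = torch.nn.Linear(dim, 1)
#             self.dropout = torch.nn.Dropout(drop)
#         def forward(self, x):
#             return self.fc(self.dropout(x)).squeeze(dim=1)
#
# The enclosing epoch loop is not part of this excerpt; NUM_EPOCH below
# is a hypothetical constant standing in for the original loop header.
# ----------------------------------------------------------------------
for epoch in range(NUM_EPOCH):  # NUM_EPOCH: hypothetical stand-in for the original loop bound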
    tgan = tgan.eval()           # keep the TGAN backbone frozen
    lr_model = lr_model.train()  # only the LR head is being trained
    for k in range(num_batch):
        s_idx = k * BATCH_SIZE
        e_idx = min(num_instance, s_idx + BATCH_SIZE)  # not num_instance - 1, which would always drop the last instance
        src_l_cut = train_src_l[s_idx:e_idx]
        dst_l_cut = train_dst_l[s_idx:e_idx]
        ts_l_cut = train_ts_l[s_idx:e_idx]
        label_l_cut = train_label_l[s_idx:e_idx]
        
        size = len(src_l_cut)
        
        lr_optimizer.zero_grad()
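        # Compute embeddings under no_grad: the TGAN backbone receives no
        # gradient, so only the LR head is updated below.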
        with torch.no_grad():
            src_embed = tgan.tem_conv(src_l_cut, ts_l_cut, NODE_LAYER)
        
        src_label = torch.from_numpy(label_l_cut).float().to(device)  # labels must live on the same device as the model output
        lr_prob = lr_model(src_embed).sigmoid()
        lr_loss = lr_criterion(lr_prob, src_label)
        lr_loss.backward()
        lr_optimizer.step()
        if k % 10000 == 0:
            print(f'{k}/{num_batch}')

    # End-of-epoch evaluation on both splits to track the probe's quality.
    train_auc, train_loss = eval_epoch(train_src_l, train_dst_l, train_ts_l, train_label_l, BATCH_SIZE, lr_model, tgan)
    test_auc, test_loss = eval_epoch(test_src_l, test_dst_l, test_ts_l, test_label_l, BATCH_SIZE, lr_model, tgan)
    #torch.save(lr_model.state_dict(), './saved_models/edge_{}_wkiki_node_class.pth'.format(DATA))
    logger.info(f'train auc: {train_auc}, test auc: {test_auc}')

# Final evaluation on the held-out test split after all epochs.
test_auc, test_loss = eval_epoch(test_src_l, test_dst_l, test_ts_l, test_label_l, BATCH_SIZE, lr_model, tgan)
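# Report the final scores through the same logger used above.
logger.info(f'final test auc: {test_auc}, test loss: {test_loss}')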