コード例 #1
0
def get_batch_emb_old(X, batch_size, doc_len, sen_len, tokenizer, estimator):
    """Sample a random batch from ``X`` and compute its embeddings and sequence info.

    Draws ``batch_size`` distinct indices into ``X`` and runs the sampled items
    through ``read_examples_batch`` -> ``convert_all_to_features`` ->
    ``input_fn_builder`` -> ``get_batch_emb_2``. It then derives the per-batch
    sequence information from the same features.

    Args:
        X: Indexable collection of items to sample from.
        batch_size: Number of distinct items to draw. ``np.random.choice`` with
            ``replace=False`` raises ``ValueError`` if this exceeds ``len(X)``.
        doc_len: Document length forwarded to the helpers.
        sen_len: Sentence length forwarded to the helpers.
        tokenizer: Tokenizer forwarded to ``convert_all_to_features``.
        estimator: Estimator forwarded to ``get_batch_emb_2``.

    Returns:
        Tuple ``(X_emb, batch_doc_seq, batch_sen_seq, X_batch)``, where
        ``X_batch`` is the list of sampled items.
    """
    idx = np.random.choice(len(X), batch_size, replace=False)
    X_batch = [X[k] for k in idx]
    X_batch_data = read_examples_batch(X_batch, doc_len)
    X_batch_feature = convert_all_to_features(X_batch_data, sen_len, tokenizer)
    X_batch_input = input_fn_builder(X_batch_feature, sen_len)
    X_emb = get_batch_emb_2(X_batch_input, estimator, doc_len, sen_len)
    # The get_batch_seq in this file takes (X_batch, doc_len, sen_len,
    # tokenizer, estimator), so calling it with three arguments raises
    # TypeError. get_batch_seq_2 takes exactly (features, doc_len, sen_len);
    # reuse the features computed above instead of re-tokenizing.
    # NOTE(review): confirm get_batch_seq_2 returns the pair callers expect here.
    batch_doc_seq, batch_sen_seq = get_batch_seq_2(X_batch_feature, doc_len, sen_len)
    return X_emb, batch_doc_seq, batch_sen_seq, X_batch
コード例 #2
0
def get_batch_emb(X_batch, doc_len, sen_len, tokenizer, estimator):
    """Return embeddings for an already-selected batch of items.

    The batch goes through ``read_examples_batch``, then
    ``convert_all_to_features``, then ``input_fn_builder``. The resulting input
    is handed to ``get_batch_emb_2`` together with ``estimator``.

    Args:
        X_batch: Items to embed (passed to ``read_examples_batch``).
        doc_len: Document length forwarded to the helpers.
        sen_len: Sentence length forwarded to the helpers.
        tokenizer: Tokenizer forwarded to ``convert_all_to_features``.
        estimator: Estimator forwarded to ``get_batch_emb_2``.

    Returns:
        Whatever ``get_batch_emb_2`` returns for this batch.
    """
    examples = read_examples_batch(X_batch, doc_len)
    features = convert_all_to_features(examples, sen_len, tokenizer)
    return get_batch_emb_2(
        input_fn_builder(features, sen_len), estimator, doc_len, sen_len
    )
コード例 #3
0
# In[ ]:

epochs = 4
batch_size = 2
# Number of random batches drawn per epoch (0 if len(X) < batch_size).
n_iters = len(X) // batch_size

# In[21]:

for epoch in range(epochs):
    print('epoch: ', epoch)
    t1 = time.time()
    for _ in tqdm(range(n_iters), total=n_iters):
        # Each iteration draws a fresh batch of distinct indices. Batches are
        # sampled independently of each other, so one "epoch" here does not
        # guarantee that every item of X is visited exactly once.
        idx = np.random.choice(len(X), batch_size, replace=False)
        X_batch = [X[k] for k in idx]
        # NOTE(review): this 2-argument call does not match the 5-argument
        # get_batch_seq(X_batch, doc_len, sen_len, tokenizer, estimator)
        # defined later in this file. It relies on whichever get_batch_seq is
        # in scope when this cell runs — confirm.
        a, b = get_batch_seq(X_batch, doc_len)
        X_batch_data = read_examples_batch(X_batch, doc_len)
        X_batch_feature = convert_all_to_features(X_batch_data, sen_len,
                                                  tokenizer)
        X_batch_input = input_fn_builder(X_batch_feature, sen_len)
        X_emb = get_batch_emb_2(X_batch_input, estimator, doc_len, sen_len)
        # Debug output: lengths of the two sequence results and embedding shape.
        print(len(a), len(b), X_emb.shape)

    print('time:', time.time() - t1)
# In[ ]:

# In[ ]:

# In[ ]:
コード例 #4
0
def get_batch_seq(X_batch, doc_len, sen_len, tokenizer, estimator):
    """Return the ``(batch_doc_id, batch_doc_mask)`` pair for a batch of items.

    The raw items are read with ``read_examples_batch`` and converted to
    features with ``convert_all_to_features``. The features are then passed
    to ``get_batch_seq_2``, whose two results are returned as a tuple.

    Args:
        X_batch: Items to process (passed to ``read_examples_batch``).
        doc_len: Document length forwarded to the helpers.
        sen_len: Sentence length forwarded to the helpers.
        tokenizer: Tokenizer forwarded to ``convert_all_to_features``.
        estimator: Not used by this function. It is presumably kept so the
            signature mirrors ``get_batch_emb`` — confirm before removing.

    Returns:
        Tuple ``(batch_doc_id, batch_doc_mask)`` from ``get_batch_seq_2``.
    """
    features = convert_all_to_features(
        read_examples_batch(X_batch, doc_len), sen_len, tokenizer
    )
    # Unpack explicitly so exactly two values are required and a tuple is returned.
    doc_ids, doc_masks = get_batch_seq_2(features, doc_len, sen_len)
    return doc_ids, doc_masks