예제 #1
0
import time

import torch

from data import data
import model
import train

RANDOM_SEED = 200720

if __name__ == '__main__':
    print("Loading in Wiki2 data")
    print()
    start_time = time.time()
    batch_size, max_len = 512, 64
    train_iter, vocab = data.load_wiki2_data(batch_size, max_len)
    print(f"Done! Took {time.time()-start_time:.0f} seconds")

    # Actual BERT paper parameters:
    # * num_hidden = 768
    # * num_heads = 12
    # * num_hidden_feed_forward = 768
    # * num_layers = 12
    net = model.BERTModel(
        len(vocab),
        num_hidden=128,
        num_heads=2,
        num_hidden_feed_forward=256,
        num_layers=2,
        dropout=0.2,
        max_len=max_len,
예제 #2
0
import time

import torch

from data import data


if __name__ == '__main__':
    # Smoke test for the sentiment-analysis data pipeline: load the Wiki2
    # vocab/tokenizer, build the sentiment-analysis iterator, and verify the
    # tensor shapes of the first batch.
    batch_size, max_len = 512, 64
    # Only the vocab and tokenizer are needed here; the Wiki2 train iterator
    # itself is discarded.
    _, vocab, tokenizer = data.load_wiki2_data(batch_size, max_len)
    start_time = time.time()
    print("Testing Sentiment Analysis dataset")
    sentiment_analysis_iter = data.load_sentiment_analysis_data(
        tokenizer, batch_size, max_len, vocab
    )
    # Check only the first batch, then stop. Assert against the configured
    # batch_size/max_len (was hard-coded 512/64, which would silently go
    # stale if the configuration above changed). Token ids, attention
    # weights, and segment ids are all (batch_size, max_len); labels are
    # one per example.
    for (examples_X, weights_X, segments_X, labels_y) in sentiment_analysis_iter:
        assert examples_X.shape == torch.Size([batch_size, max_len])
        assert weights_X.shape == torch.Size([batch_size, max_len])
        assert segments_X.shape == torch.Size([batch_size, max_len])
        assert labels_y.shape == torch.Size([batch_size])
        break
    print(f"All done! Took {time.time()-start_time:.0f} seconds")