import time

import torch

from data import data
import model
import train

RANDOM_SEED = 200720

if __name__ == '__main__':
    print("Loading in Wiki2 data")
    print()
    start_time = time.time()
    batch_size, max_len = 512, 64
    train_iter, vocab = data.load_wiki2_data(batch_size, max_len)
    print(f"Done! Took {time.time()-start_time:.0f} seconds")

    # Actual BERT-Base paper parameters:
    # * num_hidden = 768
    # * num_heads = 12
    # * num_hidden_feed_forward = 3072
    # * num_layers = 12
    # A much smaller configuration is used here so pre-training is
    # feasible on modest hardware.
    net = model.BERTModel(
        len(vocab),
        num_hidden=128,
        num_heads=2,
        num_hidden_feed_forward=256,
        num_layers=2,
        dropout=0.2,
        max_len=max_len,
    )
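RANDOM_SEED is defined in the script above but never referenced there, so the seeding presumably happens elsewhere in the full training entry point. A minimal sketch of such seeding, using only standard random/torch calls (the seed_everything helper name is my own, not from this repo):

import random

import torch


def seed_everything(seed: int) -> None:
    """Seed Python's and PyTorch's RNGs so runs are reproducible."""
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU (and CUDA) generators
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # explicitly seed every CUDA device


seed_everything(200720)  # e.g. the RANDOM_SEED value used above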
import time

import torch

from data import data

if __name__ == '__main__':
    batch_size, max_len = 512, 64
    _, vocab, tokenizer = data.load_wiki2_data(batch_size, max_len)

    start_time = time.time()
    print("Testing Sentiment Analysis dataset")
    sentiment_analysis_iter = data.load_sentiment_analysis_data(
        tokenizer, batch_size, max_len, vocab
    )
    # Sanity-check the tensor shapes of the first batch, then stop.
    for (examples_X, weights_X, segments_X, labels_y) in sentiment_analysis_iter:
        assert examples_X.shape == torch.Size([512, 64])
        assert weights_X.shape == torch.Size([512, 64])
        assert segments_X.shape == torch.Size([512, 64])
        assert labels_y.shape == torch.Size([512])
        break
    print(f"All done! Took {time.time()-start_time:.0f} seconds")
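The assertions above only look at the first batch and hardcode the full batch shape. A sketch of a drop-in replacement for that loop which walks every batch and also tolerates a smaller final batch, assuming the underlying DataLoader may not drop it (the code above does not say either way):

    # Alternative check: validate every batch, allowing a short last batch.
    for examples_X, weights_X, segments_X, labels_y in sentiment_analysis_iter:
        n = examples_X.shape[0]  # actual batch size; only the last may be smaller
        assert n <= batch_size
        assert examples_X.shape == torch.Size([n, max_len])
        assert weights_X.shape == torch.Size([n, max_len])
        assert segments_X.shape == torch.Size([n, max_len])
        assert labels_y.shape == torch.Size([n])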