# --- Training entry point --------------------------------------------------
# Relies on names defined elsewhere in this file/project: args, device,
# train_data_path, test_data_path, ngrams, embed_dim, lr, num_epochs,
# count, get_csv_iterator, Dataset, TextSentiment, train, test.
logging.basicConfig(level=getattr(logging, args.logging_level))
start_time = time.time()

logging.info("Loading vocab from: {}".format(args.vocab))
vocab = torch.load(args.vocab)

# Count lines/labels per split into separate names: previously the test
# split's label count silently overwrote the training one before being
# used to size the model's output layer.
logging.info("Counting training lines and labels")
num_labels, train_num_lines = count(train_data_path)
logging.info("Counting testing lines and labels")
test_num_labels, test_num_lines = count(test_data_path)
if test_num_labels != num_labels:
    # Only warn -- keep going, as the original script would have.
    logging.warning("Label count mismatch: train=%d test=%d",
                    num_labels, test_num_labels)

logging.info("Loading iterable datasets")
train_dataset = Dataset(
    get_csv_iterator(train_data_path, ngrams, vocab),
    train_num_lines, num_epochs)
test_dataset = Dataset(
    get_csv_iterator(test_data_path, ngrams, vocab),
    test_num_lines, num_epochs)

logging.info("Creating models")
# Output dimension is driven by the labels seen during training.
model = TextSentiment(len(vocab), embed_dim, num_labels).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)

logging.info("Setup took: {:3.0f}s".format(time.time() - start_time))
logging.info("Starting training")
train(lr, num_epochs, train_dataset)
test(test_dataset)

if args.save_model_path:
    # logging.info for consistency with every other message in this script
    # (was a bare print). Saves the full module object, not a state_dict --
    # loaders elsewhere depend on this format, so it is left as-is.
    logging.info("Saving model to {}".format(args.save_model_path))
    torch.save(model.to('cpu'), args.save_model_path)
import argparse import torch ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"} WEIGHT_PATH = "../weights/text_news0.2672930294473966.pth" vocab = pickle.load(open(".data/save_vocab.p", "rb")) device = "cuda" if torch.cuda.is_available() else "cpu" VOCAB_SIZE = 1308844 EMBED_DIM = 32 NUM_CLASS = 4 model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS) checkpoint = torch.load(WEIGHT_PATH, map_location=torch.device('cpu')) model.load_state_dict(checkpoint) model.to(device) def predict(text, model, vocab, ngrams): tokenizer = get_tokenizer("basic_english") with torch.no_grad(): text = torch.tensor([ vocab[token] for token in ngrams_iterator(tokenizer(text), ngrams) ]) output = model(text, torch.tensor([0])) return output.argmax(1).item() + 1 parser = argparse.ArgumentParser( description='Text_classification With Pytorch') parser.add_argument("--text",