train_loader, dev_loader, test_loader = (DataLoader( dataset=dataset, batch_size=args.batch_size, collate_fn=utils.collate_fn, shuffle=dataset.train) for i, dataset in enumerate(dataests)) print("Building Model...") if args.model == "cnn": model = CNN.Model(vocab_size=len(vocab), embedding_size=args.embedding_size, hidden_size=args.hidden_size, filter_sizes=[3, 4, 5], dropout=args.dropout) elif args.model == "bilstm": model = BiLSTM.Model(vocab_size=len(vocab), embedding_size=args.embedding_size, hidden_size=args.hidden_size, dropout=args.dropout) else: raise ValueError( "Model should be either cnn or bilstm, {} is invalid.".format( args.model)) if torch.cuda.is_available(): model = model.cuda() trainer = Trainer(model) best_acc = 0 train_list = [] dev_list = [] for i in range(args.epochs):
def run(*args, **kwargs):
    """Train and evaluate a text classifier (CNN or BiLSTM).

    Command-line flags override nothing here; defaults for `-model` and
    `-batch_size` come from the caller via **kwargs.

    Keyword Args:
        model: architecture to build, "cnn" or "bilstm" (required).
        batch_size: mini-batch size for all three data loaders (required).

    Returns:
        (train_loss_list, dev_loss_list): per-epoch loss histories.

    Raises:
        KeyError: if 'model' or 'batch_size' is missing from kwargs.
        ValueError: if the resolved model name is neither "cnn" nor "bilstm".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-train_file', type=str, default='./data/train.csv')
    parser.add_argument('-dev_file', type=str, default='./data/dev.csv')
    parser.add_argument('-test_file', type=str, default='./data/test.csv')
    parser.add_argument('-save_path', type=str, default='./model.pkl')
    parser.add_argument('-model', type=str, default=kwargs['model'],
                        help="[cnn, bilstm]")
    parser.add_argument('-batch_size', type=int, default=kwargs['batch_size'])
    parser.add_argument('-embedding_size', type=int, default=128)
    parser.add_argument('-hidden_size', type=int, default=128)
    parser.add_argument('-learning_rate', type=float, default=1e-3)
    parser.add_argument('-dropout', type=float, default=0.5)
    parser.add_argument('-epochs', type=int, default=20)
    parser.add_argument('-seed', type=int, default=1)
    args = parser.parse_args()
    print(args)

    # Seed every RNG in play so runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    print("Loading Data...")
    datasets, vocab = utils.build_dataset(args.train_file, args.dev_file,
                                          args.test_file)
    # dataset.train flags the training split, which is the only one shuffled.
    train_loader, dev_loader, test_loader = (DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        collate_fn=utils.collate_fn,
        shuffle=dataset.train) for dataset in datasets)

    print("Building Model...")
    if args.model == "cnn":
        model = CNN.Model(vocab_size=len(vocab),
                          embedding_size=args.embedding_size,
                          hidden_size=args.hidden_size,
                          filter_sizes=[3, 4, 5],
                          dropout=args.dropout)
    elif args.model == "bilstm":
        model = BiLSTM.Model(vocab_size=len(vocab),
                             embedding_size=args.embedding_size,
                             hidden_size=args.hidden_size,
                             dropout=args.dropout)
    else:
        # Fail fast: without this branch `model` would be unbound below and
        # the function would die later with an unrelated NameError.
        raise ValueError(
            "Model should be either cnn or bilstm, {} is invalid.".format(
                args.model))

    if torch.cuda.is_available():
        model = model.cuda()

    trainer = Trainer(model, args.learning_rate)
    train_loss_list = []
    dev_loss_list = []
    best_acc = 0
    for i in range(args.epochs):
        print("Epoch: {} ################################".format(i))
        train_loss, train_acc = trainer.train(train_loader)
        dev_loss, dev_acc = trainer.evaluate(dev_loader)
        train_loss_list.append(train_loss)
        dev_loss_list.append(dev_loss)
        print("Train Loss: {:.4f} Acc: {:.4f}".format(train_loss, train_acc))
        print("Dev Loss: {:.4f} Acc: {:.4f}".format(dev_loss, dev_acc))
        # Checkpoint only when dev accuracy improves.
        if dev_acc > best_acc:
            best_acc = dev_acc
            trainer.save(args.save_path)
        print("###########################################")

    # Reload the best checkpoint and report held-out test performance.
    trainer.load(args.save_path)
    test_loss, test_acc = trainer.evaluate(test_loader)
    print("Test Loss: {:.4f} Acc: {:.4f}".format(test_loss, test_acc))
    return train_loss_list, dev_loss_list