def main(args):
    """Train an RCNN text classifier and optionally evaluate it on a test set.

    Builds the vocabulary from the training file, splits a validation set off
    the training data, trains, and — when ``args.test_set`` is set — reloads
    the best checkpoint and logs test metrics plus the confusion matrix.
    """
    model = RCNN(vocab_size=args.vocab_size,
                 embedding_dim=args.embedding_dim,
                 hidden_size=args.hidden_size,
                 hidden_size_linear=args.hidden_size_linear,
                 class_num=args.class_num,
                 dropout=args.dropout).to(args.device)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model, dim=0)

    train_texts, train_labels = read_file(args.train_file_path)
    word2idx = build_dictionary(train_texts, vocab_size=args.vocab_size)
    logger.info('Dictionary Finished!')

    full_dataset = CustomTextDataset(train_texts, train_labels, word2idx)
    num_train_data = len(full_dataset) - args.num_val_data
    train_dataset, val_dataset = random_split(
        full_dataset, [num_train_data, args.num_val_data])
    train_dataloader = DataLoader(dataset=train_dataset,
                                  collate_fn=lambda x: collate_fn(x, args),
                                  batch_size=args.batch_size,
                                  shuffle=True)
    valid_dataloader = DataLoader(dataset=val_dataset,
                                  collate_fn=lambda x: collate_fn(x, args),
                                  batch_size=args.batch_size,
                                  shuffle=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    train(model, optimizer, train_dataloader, valid_dataloader, args)
    logger.info('******************** Train Finished ********************')

    # Test
    if args.test_set:
        test_texts, test_labels = read_file(args.test_file_path)
        test_dataset = CustomTextDataset(test_texts, test_labels, word2idx)
        # shuffle=False: evaluation is order-independent; keep it deterministic.
        test_dataloader = DataLoader(dataset=test_dataset,
                                     collate_fn=lambda x: collate_fn(x, args),
                                     batch_size=args.batch_size,
                                     shuffle=False)
        # NOTE(review): if "best.pt" was saved from the DataParallel wrapper,
        # its keys carry a "module." prefix — confirm against the save side.
        model.load_state_dict(
            torch.load(os.path.join(args.model_save_path, "best.pt")))
        _, accuracy, precision, recall, f1, cm = evaluate(
            model, test_dataloader, args)
        logger.info('-' * 50)
        logger.info(
            f'|* TEST SET *| |ACC| {accuracy:>.4f} |PRECISION| {precision:>.4f} |RECALL| {recall:>.4f} |F1| {f1:>.4f}'
        )
        logger.info('-' * 50)
        logger.info('---------------- CONFUSION MATRIX ----------------')
        for row in cm:
            logger.info(row)
        logger.info('--------------------------------------------------')
def load_model(test_arguments):
    """Restore a trained RCNN from ``test_arguments.model_path`` on the GPU.

    The returned network is in eval mode, so dropout is disabled.
    """
    network = RCNN(test_arguments.pos_loss_method,
                   test_arguments.loss_weight_lambda).cuda()
    checkpoint = t.load(test_arguments.model_path)
    network.load_state_dict(checkpoint)
    network.eval()  # inference mode: dropout rate effectively 0
    return network
def main(args):
    """Run 10 train/evaluate trials of the RCNN classifier and plot averages.

    Each trial regenerates the data split (``setup_data``), trains a fresh
    model, reloads the best checkpoint, and evaluates on the test set.
    After all trials the averaged metrics are logged and the per-trial
    curves are saved under ./images with a filename derived from the
    enabled feature sets (lexical / syntactic / semantic).
    """
    acc_list = []
    f1_score_list = []
    prec_list = []
    recall_list = []
    for i in range(10):
        setup_data()
        model = RCNN(vocab_size=args.vocab_size,
                     embedding_dim=args.embedding_dim,
                     hidden_size=args.hidden_size,
                     hidden_size_linear=args.hidden_size_linear,
                     class_num=args.class_num,
                     dropout=args.dropout).to(args.device)
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model, dim=0)

        train_texts, train_labels = read_file(args.train_file_path)
        word2idx, embedding = build_dictionary(train_texts, args.vocab_size,
                                               args.lexical, args.syntactic,
                                               args.semantic)
        logger.info('Dictionary Finished!')

        full_dataset = CustomTextDataset(train_texts, train_labels, word2idx,
                                         args)
        num_train_data = len(full_dataset) - args.num_val_data
        train_dataset, val_dataset = random_split(
            full_dataset, [num_train_data, args.num_val_data])
        train_dataloader = DataLoader(dataset=train_dataset,
                                      collate_fn=lambda x: collate_fn(x, args),
                                      batch_size=args.batch_size,
                                      shuffle=True)
        valid_dataloader = DataLoader(dataset=val_dataset,
                                      collate_fn=lambda x: collate_fn(x, args),
                                      batch_size=args.batch_size,
                                      shuffle=True)

        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        train(model, optimizer, train_dataloader, valid_dataloader, embedding,
              args)
        logger.info('******************** Train Finished ********************')

        # Test
        if args.test_set:
            test_texts, test_labels = read_file(args.test_file_path)
            test_dataset = CustomTextDataset(test_texts, test_labels, word2idx,
                                             args)
            test_dataloader = DataLoader(
                dataset=test_dataset,
                collate_fn=lambda x: collate_fn(x, args),
                batch_size=args.batch_size,
                shuffle=True)
            # NOTE(review): if "best.pt" was saved from the DataParallel
            # wrapper, keys carry a "module." prefix — confirm save side.
            model.load_state_dict(
                torch.load(os.path.join(args.model_save_path, "best.pt")))
            _, accuracy, precision, recall, f1, cm = evaluate(
                model, test_dataloader, embedding, args)
            logger.info('-' * 50)
            logger.info(
                f'|* TEST SET *| |ACC| {accuracy:>.4f} |PRECISION| {precision:>.4f} |RECALL| {recall:>.4f} |F1| {f1:>.4f}'
            )
            logger.info('-' * 50)
            logger.info('---------------- CONFUSION MATRIX ----------------')
            for row in cm:
                logger.info(row)
            logger.info('--------------------------------------------------')
            # NOTE(review): only accuracy is rescaled (percent -> fraction);
            # precision/recall/F1 are appended as-is — confirm evaluate()'s
            # return scale is intentionally mixed.
            acc_list.append(accuracy / 100)
            prec_list.append(precision)
            recall_list.append(recall)
            f1_score_list.append(f1)

    avg_acc = sum(acc_list) / len(acc_list)
    avg_prec = sum(prec_list) / len(prec_list)
    avg_recall = sum(recall_list) / len(recall_list)
    avg_f1_score = sum(f1_score_list) / len(f1_score_list)
    logger.info('--------------------------------------------------')
    logger.info(
        f'|* TEST SET *| |Avg ACC| {avg_acc:>.4f} |Avg PRECISION| {avg_prec:>.4f} |Avg RECALL| {avg_recall:>.4f} |Avg F1| {avg_f1_score:>.4f}'
    )
    logger.info('--------------------------------------------------')

    plot_df = pd.DataFrame({
        'x_values': range(10),
        'avg_acc': acc_list,
        'avg_prec': prec_list,
        'avg_recall': recall_list,
        'avg_f1_score': f1_score_list
    })
    plt.plot('x_values',
             'avg_acc',
             data=plot_df,
             marker='o',
             markerfacecolor='blue',
             markersize=12,
             color='skyblue',
             linewidth=4)
    plt.plot('x_values',
             'avg_prec',
             data=plot_df,
             marker='',
             color='olive',
             linewidth=2)
    # NOTE(review): recall and F1 share identical style (olive, dashed) and
    # are distinguishable only via the legend — consider distinct styles.
    plt.plot('x_values',
             'avg_recall',
             data=plot_df,
             marker='',
             color='olive',
             linewidth=2,
             linestyle='dashed')
    plt.plot('x_values',
             'avg_f1_score',
             data=plot_df,
             marker='',
             color='olive',
             linewidth=2,
             linestyle='dashed')
    plt.legend()

    # Filename encodes the enabled feature sets, e.g. 'lexical-semantic.png';
    # 'plain.png' when none are enabled. Equivalent to the original nested
    # conditional chain, but order-driven and readable.
    enabled = [name for name, flag in (('lexical', args.lexical),
                                       ('semantic', args.semantic),
                                       ('syntactic', args.syntactic)) if flag]
    fname = ('-'.join(enabled) or 'plain') + '.png'
    if not path.exists('./images'):
        mkdir('./images')
    plt.savefig(path.join('./images', fname))
def load_models(config):
    """Instantiate the five ensemble members and restore their checkpoints.

    Builds CNN, LSTM, LSTM_maxpool, RCNN and a multi-label BERT classifier,
    loading each state dict from the paths listed (in that order) in
    ``config.ensemble_models``.

    Returns:
        Tuple of (cnn, lstm, lstm_mxp, rcnn, bert) modules with weights loaded.
    """
    # Close the vocabulary file deterministically instead of leaking the handle.
    with open(os.path.join(config.data_path, config.vocabulary_name),
              "rb") as f:
        vocabulary = pickle.load(f)

    def _saved_state(idx):
        # Checkpoints were saved as whole pickled modules; extract the weights.
        return torch.load(
            os.path.join(config.save_path,
                         config.ensemble_models[idx])).state_dict()

    cnn = CNN(vocab_size=len(vocabulary),
              embed_dim=config.embed_dim,
              class_num=config.class_num,
              kernel_num=config.kernel_num,
              kernel_sizes=config.kernel_sizes,
              dropout=config.dropout,
              static=config.static,
              in_channels=config.in_channels)
    cnn.load_state_dict(_saved_state(0))

    lstm = LSTM(vocab_size=len(vocabulary),
                embed_dim=config.embed_dim,
                output_dim=config.class_num,
                hidden_dim=config.hidden_dim,
                num_layers=config.num_layers,
                dropout=config.dropout)
    lstm.load_state_dict(_saved_state(1))

    lstm_mxp = LSTM_maxpool(vocab_size=len(vocabulary),
                            embed_dim=config.embed_dim,
                            output_dim=config.class_num,
                            hidden_dim=config.hidden_dim,
                            num_layers=config.num_layers,
                            dropout=config.dropout)
    lstm_mxp.load_state_dict(_saved_state(2))

    rcnn = RCNN(vocab_size=len(vocabulary),
                embed_dim=config.embed_dim,
                output_dim=config.class_num,
                hidden_dim=config.hidden_dim,
                num_layers=config.num_layers,
                dropout=config.dropout)
    rcnn.load_state_dict(_saved_state(3))

    schemas = get_schemas(config.source_path)
    state_dict = _saved_state(4)
    bert = BertForMultiLabelSequenceClassification.from_pretrained(
        config.bert_folder, state_dict=state_dict, num_labels=len(schemas))
    # NOTE(review): the state dict is applied twice — once via from_pretrained
    # and again here. Kept for safety, but confirm the second load is needed.
    bert.load_state_dict(state_dict)
    return cnn, lstm, lstm_mxp, rcnn, bert