def student_train(dataset):
    # Teacher soft targets for the whole dataset, then a standard distillation loop.
    t_logits = teacher_predict(dataset)
    train_loader = get_train_data(dataset)

    student = biLSTM()
    total_params = sum(p.numel() for p in student.parameters())
    print(f'{total_params:,} total parameters.')

    optimizer = torch.optim.SGD(student.parameters(), lr=0.05)
    tra_best_loss = float('inf')
    student.train()

    for epoch in range(400):
        hidden_train = None
        print('Epoch [{}/{}]'.format(epoch + 1, 400))
        for i, (x, y) in enumerate(train_loader):
            optimizer.zero_grad()
            s_logits, _ = student(x, hidden_train)
            hidden_train = None  # reset the hidden state between batches
            label = y.squeeze(1).long()
            # Distillation loss against the teacher's logits (alpha=0.2, temperature=4).
            loss = get_loss(t_logits[i], s_logits.squeeze(1), label, 0.2, 4)
            # loss = get_loss(t_logits[i], s_logits, label, 0, 4)
            loss.backward()
            optimizer.step()
            if loss.item() < tra_best_loss:
                tra_best_loss = loss.item()
                torch.save(student.state_dict(), 'data/saved_dict/lstm.ckpt')
                print(loss.item())

    student_test(dataset)
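# --- Hypothetical helper (not shown in this repo): `get_loss` is called above as
# --- get_loss(teacher_logits, student_logits, labels, alpha, T). A minimal sketch of what it
# --- might look like, assuming the standard Hinton-style distillation objective: KL divergence
# --- on temperature-softened logits blended with hard-label cross-entropy via alpha. The name
# --- `get_loss_sketch` and the exact blending are assumptions, not the repo's implementation.
import torch
import torch.nn.functional as F

def get_loss_sketch(t_logits, s_logits, labels, alpha, T):
    # Soft-target term: KL(student_T || teacher_T), scaled by T^2 so gradient
    # magnitudes stay comparable to the hard-label term.
    soft_loss = F.kl_div(F.log_softmax(s_logits / T, dim=-1),
                         F.softmax(t_logits / T, dim=-1),
                         reduction='batchmean') * (T * T)
    # Hard-target term: ordinary cross-entropy against the ground-truth labels.
    hard_loss = F.cross_entropy(s_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss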
def main_train():
    '''Main function for training + validation.'''
    # whether we are resuming training on a saved model
    RESUME = True

    # Hyper-parameters
    NUM_EPOCH = 100
    LOSS_FUNCTION = CrossEntropyLoss
    OPTIMIZER = optim.Adam
    BATCH_SIZE = 256
    MAX_VOCAB_SIZE = 50000  # takes the 50,000 most frequent words as the vocab
    lr = 1e-4
    optimiser_params = {'lr': lr, 'weight_decay': 1e-5}
    EMBEDDING_DIM = 200
    HIDDEN_DIM = 256  # for the LSTM model
    OUTPUT_DIM = 3
    MODEL_MODE = 'RNN'  # 'RNN' or 'CNN'
    conv_out_ch = 200  # for the CNN model
    filter_sizes = [3, 4, 5]  # for the CNN model
    SPLIT_RATIO = 0.85  # ratio of the train set; 1.0 means 100% training, 0% valid data
    EXPERIMENT_NAME = "Adam_lr" + str(lr) + "_max_vocab_size" + str(MAX_VOCAB_SIZE)

    if RESUME:
        params = open_experiment(EXPERIMENT_NAME)
    else:
        params = create_experiment(EXPERIMENT_NAME)
    cfg_path = params["cfg_path"]

    # Prepare data
    data_handler = data_provider_V2(cfg_path=cfg_path, batch_size=BATCH_SIZE, split_ratio=SPLIT_RATIO,
                                    max_vocab_size=MAX_VOCAB_SIZE, mode=Mode.TRAIN, model_mode=MODEL_MODE)
    train_iterator, valid_iterator, vocab_size, PAD_IDX, UNK_IDX, \
        pretrained_embeddings, weights, classes = data_handler.data_loader()

    print(f'\nSummary:\n----------------------------------------------------')
    print(f'Total # of Training tweets: {BATCH_SIZE * len(train_iterator):,}')
    if SPLIT_RATIO == 1:
        print(f'Total # of Valid. tweets: {0}')
    else:
        print(f'Total # of Valid. tweets: {BATCH_SIZE * len(valid_iterator):,}')

    # Initialize trainer
    trainer = Training(cfg_path, num_epochs=NUM_EPOCH, RESUME=RESUME, model_mode=MODEL_MODE)

    if MODEL_MODE == 'RNN':
        MODEL = biLSTM(vocab_size=vocab_size, embeddings=pretrained_embeddings, embedding_dim=EMBEDDING_DIM,
                       hidden_dim=HIDDEN_DIM, output_dim=OUTPUT_DIM, pad_idx=PAD_IDX, unk_idx=UNK_IDX)
    elif MODEL_MODE == 'CNN':
        MODEL = CNN1d(vocab_size=vocab_size, embeddings=pretrained_embeddings, embedding_dim=EMBEDDING_DIM,
                      conv_out_ch=conv_out_ch, filter_sizes=filter_sizes, output_dim=OUTPUT_DIM,
                      pad_idx=PAD_IDX, unk_idx=UNK_IDX)

    if RESUME:
        trainer.load_checkpoint(model=MODEL, optimiser=OPTIMIZER, optimiser_params=optimiser_params,
                                loss_function=LOSS_FUNCTION, weight=weights)
    else:
        trainer.setup_model(model=MODEL, optimiser=OPTIMIZER, optimiser_params=optimiser_params,
                            loss_function=LOSS_FUNCTION, weight=weights)

    trainer.execute_training(train_loader=train_iterator, valid_loader=valid_iterator, batch_size=BATCH_SIZE)
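# --- Assumed wiring (a sketch, not necessarily how the repo's Training class does it): the
# --- per-class `weights` tensor returned by the data handler is presumably passed to the
# --- class-weighted loss, and the optimiser is built from the same hyper-parameters, e.g.:
#
#     criterion = LOSS_FUNCTION(weight=weights)                       # class-weighted cross-entropy
#     optimiser = OPTIMIZER(MODEL.parameters(), **optimiser_params)   # Adam(lr=1e-4, weight_decay=1e-5)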
def student_train(dataset):
    # Stratified split of the teacher-labelled data into a train and a held-out set.
    X_train, X_test, y_train, y_test = \
        train_test_split(dataset['text'], dataset['pred'], stratify=dataset['pred'],
                         test_size=0.2, random_state=1)
    train_student = data2frame(X_train, y_train)
    test_student = data2frame(X_test, y_test)

    # Teacher soft targets for both splits.
    _, t_logits = teacher_predict(train_student)
    _, t_test = teacher_predict(test_student)
    train_loader = get_train_data(train_student)

    student = biLSTM()
    total_params = sum(p.numel() for p in student.parameters())
    print(f'{total_params:,} total parameters.')

    optimizer = torch.optim.SGD(student.parameters(), lr=0.05)
    total_batch = 0
    total_epoch = 100
    tra_best_loss = float('inf')
    dev_best_loss = float('inf')
    student.train()
    start_time = time.time()

    for epoch in range(total_epoch):
        hidden_train = None
        print('Epoch [{}/{}]'.format(epoch + 1, total_epoch))
        for i, (x, y) in enumerate(train_loader):
            optimizer.zero_grad()
            s_logits, _ = student(x, hidden_train)
            hidden_train = None  # reset the hidden state between batches
            label = y.squeeze(1).long()
            # Distillation loss (alpha=1, temperature=3).
            loss = get_loss(t_logits[i], s_logits.squeeze(1), label, 1, 3)
            loss.backward()
            optimizer.step()

            if total_batch % 50 == 0:
                cur_pred = torch.squeeze(s_logits, dim=1)
                train_acc = metrics.accuracy_score(
                    y.squeeze(1).long(), torch.max(cur_pred, 1)[1].cpu().numpy())
                _, dev_loss, dev_acc = student_evaluate(test_student, student, t_test)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(student.state_dict(), 'data/saved_dict/lstm.ckpt')
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = ('Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, '
                       'Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}')
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                student.train()
            total_batch += 1

    student_test(test_student)
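# --- Hypothetical helper (not shown here): `student_evaluate` is called above as
# --- `_, dev_loss, dev_acc = student_evaluate(test_set, student, t_logits_test)`. A minimal
# --- sketch of what it might do, reusing the same assumed helpers (`get_train_data`, `get_loss`,
# --- `metrics`) from this module: run the student over the held-out split, accumulate the
# --- distillation loss against the cached teacher logits, and report accuracy.
def student_evaluate_sketch(test_set, student, t_logits_test):
    student.eval()
    loss_total, labels_all, preds_all = 0.0, [], []
    loader = get_train_data(test_set)  # assumed to also serve evaluation batches
    hidden = None
    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            s_logits, _ = student(x, hidden)
            hidden = None
            label = y.squeeze(1).long()
            loss_total += get_loss(t_logits_test[i], s_logits.squeeze(1), label, 1, 3).item()
            preds = torch.max(s_logits.squeeze(1), 1)[1].cpu().numpy()
            labels_all.extend(label.cpu().numpy())
            preds_all.extend(preds)
    acc = metrics.accuracy_score(labels_all, preds_all)
    return preds_all, loss_total / len(loader), acc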
def student_predict(dataset):
    model = biLSTM()
    data = get_train_data(dataset)
    model.load_state_dict(torch.load('data/saved_dict/lstm.ckpt'))
    model.eval()
    predict_all = []
    hidden_predict = None
    with torch.no_grad():
        for texts, labels in data:
            pred_X, hidden_predict = model(texts, hidden_predict)
            hidden_predict = None
            cur_pred = torch.squeeze(pred_X, dim=1)
            predic = torch.max(cur_pred, 1)[1].cpu().numpy()
            predict_all = np.append(predict_all, predic)
    return predict_all
def student_predict(x, y):
    model = biLSTM()
    data = get_train_data(x, y)
    model.load_state_dict(torch.load('data/saved_dict/lstm.ckpt'))
    model.eval()
    predict_all = []
    hidden_predict = None
    total_params = sum(p.numel() for p in model.parameters())
    print(f'{total_params:,} total parameters.')
    with torch.no_grad():
        for texts, labels in data:
            pred_X, hidden_predict = model(texts, hidden_predict)
            hidden_predict = None
            cur_pred = torch.squeeze(pred_X, dim=1)
            predic = torch.max(cur_pred, 1)[1].cpu().numpy()
            predict_all = np.append(predict_all, predic)
    return predict_all
def predict_alone(dataset):
    model = biLSTM()
    # model = StudentNet()
    valid_data = get_train_data(dataset)
    model.load_state_dict(torch.load('data/saved_dict/lstm_student.ckpt'))
    model.eval()
    predict_all = []
    hidden_predict = None
    result = torch.Tensor()
    with torch.no_grad():
        for texts, labels in valid_data:
            pred_X, hidden_predict = model(texts, hidden_predict)
            # pred_X = model(texts)
            hidden_predict = None
            cur_pred = torch.squeeze(pred_X, dim=1)
            result = torch.cat((result, cur_pred), dim=0)
            predic = torch.max(cur_pred, 1)[1].cpu().numpy()
            predict_all = np.append(predict_all, predic)
    # print(result.detach().numpy())
    return predict_all
def student_train_alone(dataset):
    # Train the student on hard labels only (no distillation), as a baseline.
    train_loader = get_train_data(dataset)
    student = biLSTM()
    optimizer = torch.optim.SGD(student.parameters(), lr=0.05)
    tra_best_loss = float('inf')
    student.train()
    criterion = nn.CrossEntropyLoss()

    for epoch in range(200):
        hidden_train = None
        print('Epoch [{}/{}]'.format(epoch + 1, 200))
        for x, y in train_loader:
            optimizer.zero_grad()
            pred_X, _ = student(x, hidden_train)
            hidden_train = None
            loss = criterion(pred_X.squeeze(1), y.squeeze(1).long())
            loss.backward()
            optimizer.step()
            if loss.item() < tra_best_loss:
                tra_best_loss = loss.item()
                torch.save(student.state_dict(), 'data/saved_dict/lstm_student.ckpt')
                print(loss.item())
def main_train_postreply():
    '''
    Main function for training + validation of the second part of the project:
    sentiment analysis of the Post-Replies.
    '''
    # whether we are resuming training on a saved model
    RESUME = False

    # Hyper-parameters
    NUM_EPOCH = 500
    LOSS_FUNCTION = CrossEntropyLoss
    OPTIMIZER = optim.Adam
    BATCH_SIZE = 256
    MAX_VOCAB_SIZE = 750000  # takes the 750,000 most frequent words as the vocab
    lr = 9e-5
    optimiser_params = {'lr': lr, 'weight_decay': 1e-4}
    EMBEDDING_DIM = 200
    HIDDEN_DIM = 300
    OUTPUT_DIM = 3
    MODEL_MODE = "CNN"  # "RNN" or "CNN"
    conv_out_ch = 200  # for the CNN model
    filter_sizes = [3, 4, 5]  # for the CNN model
    SPLIT_RATIO = 0.9  # ratio of the train set; 1.0 means 100% training, 0% valid data
    EXPERIMENT_NAME = "new_october_CNN"

    if RESUME:
        params = open_experiment(EXPERIMENT_NAME)
    else:
        params = create_experiment(EXPERIMENT_NAME)
    cfg_path = params["cfg_path"]

    # Prepare data
    data_handler = data_provider_PostReply(cfg_path=cfg_path, batch_size=BATCH_SIZE, split_ratio=SPLIT_RATIO,
                                           max_vocab_size=MAX_VOCAB_SIZE, mode=Mode.TRAIN, model_mode=MODEL_MODE)
    train_iterator, valid_iterator, vocab_size, PAD_IDX, UNK_IDX, \
        pretrained_embeddings, weights, classes = data_handler.data_loader()

    if SPLIT_RATIO == 1:
        total_valid_tweets = 0
    else:
        total_valid_tweets = BATCH_SIZE * len(valid_iterator)
    total_train_tweets = BATCH_SIZE * len(train_iterator)

    print(f'\nSummary:\n----------------------------------------------------')
    print(f'Total # of Training tweets: {total_train_tweets:,}')
    print(f'Total # of Valid. tweets: {total_valid_tweets:,}')

    # Initialize trainer
    trainer = Training(cfg_path, num_epochs=NUM_EPOCH, RESUME=RESUME, model_mode=MODEL_MODE)

    if MODEL_MODE == "RNN":
        MODEL = biLSTM(vocab_size=vocab_size, embeddings=pretrained_embeddings, embedding_dim=EMBEDDING_DIM,
                       hidden_dim=HIDDEN_DIM, output_dim=OUTPUT_DIM, pad_idx=PAD_IDX, unk_idx=UNK_IDX)
    elif MODEL_MODE == "CNN":
        MODEL = CNN1d(vocab_size=vocab_size, embeddings=pretrained_embeddings, embedding_dim=EMBEDDING_DIM,
                      conv_out_ch=conv_out_ch, filter_sizes=filter_sizes, output_dim=OUTPUT_DIM,
                      pad_idx=PAD_IDX, unk_idx=UNK_IDX)

    if RESUME:
        trainer.load_checkpoint(model=MODEL, optimiser=OPTIMIZER, optimiser_params=optimiser_params,
                                loss_function=LOSS_FUNCTION, weight=weights)
    else:
        trainer.setup_model(model=MODEL, optimiser=OPTIMIZER, optimiser_params=optimiser_params,
                            loss_function=LOSS_FUNCTION, weight=weights)

    # write the params to the config file
    params = read_config(cfg_path)
    params['Network']['vocab_size'] = vocab_size
    params['Network']['PAD_IDX'] = PAD_IDX
    params['Network']['UNK_IDX'] = UNK_IDX
    params['Network']['classes'] = classes
    params['Network']['SPLIT_RATIO'] = SPLIT_RATIO
    params['Network']['MAX_VOCAB_SIZE'] = MAX_VOCAB_SIZE
    params['Network']['HIDDEN_DIM'] = HIDDEN_DIM
    params['Network']['EMBEDDING_DIM'] = EMBEDDING_DIM
    params['Network']['conv_out_ch'] = conv_out_ch
    params['Network']['MODEL_MODE'] = MODEL_MODE
    params['total_train_tweets'] = total_train_tweets
    params['total_valid_tweets'] = total_valid_tweets
    write_config(params, cfg_path, sort_keys=True)

    trainer.execute_training(train_loader=train_iterator, valid_loader=valid_iterator, batch_size=BATCH_SIZE)
from config import *
from student import *
from teacher import *
from models.bert import *
from models.biLSTM import *


if __name__ == '__main__':
    set_seed(1)
    cfg = Config()
    start_time = time.time()
    print("Loading data...")
    train_text, train_label = get_dataset(cfg.train_path)
    test_text, test_label = get_dataset(cfg.test_path)
    train_loader = get_loader(train_text, train_label, cfg.tokenizer)
    test_loader = get_loader(test_text, test_label, cfg.tokenizer)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # Teacher: BERT; student: biLSTM distilled from the teacher's outputs.
    T_model = BERT_Model(cfg).to(cfg.device)
    if cfg.train_teacher:
        teacher_train(T_model, cfg, train_loader, test_loader)
    if cfg.train_student:
        S_model = biLSTM(cfg).to(cfg.device)
        student_train(T_model, S_model, cfg, train_loader, test_loader)
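# --- Assumed helper (defined elsewhere in the repo): `set_seed(1)` above is presumably the usual
# --- reproducibility shim that seeds Python, NumPy and PyTorch. A minimal sketch under that
# --- assumption; the name `set_seed_sketch` is illustrative only.
import random

import numpy as np
import torch

def set_seed_sketch(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True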