def model_predictions(model, data, vocab, DEVICE, BATCH_SIZE=16):
    """
    Run greedy (top-1) inference and return the predicted sentences.

    Args:
        model: an instance of ElmoSCTransformer
        data: list of tuples, with each tuple consisting of correct and
            incorrect sentence string (would be split at whitespaces)
        vocab: vocabulary object consumed by labelize / sctrans_tokenize /
            untokenize_without_unks
        DEVICE: torch device the model and every batch tensor are moved to
        BATCH_SIZE: number of sentence pairs per inference batch

    Returns:
        list of predicted sentence strings, in input order
    """
    topk = 1  # greedy decoding: predictions come back as (batch, max_seq_len)
    print("###############################################")
    inference_st_time = time.time()
    final_sentences = []
    print("data size: {}".format(len(data)))
    data_iter = batch_iter(data, batch_size=BATCH_SIZE, shuffle=False)
    model.eval()
    model.to(DEVICE)
    for batch_id, (batch_clean_sentences, batch_corrupt_sentences) in enumerate(data_iter):
        # set batch data
        batch_labels, batch_lengths = labelize(batch_clean_sentences, vocab)
        batch_idxs, batch_lengths_, inverted_mask = sctrans_tokenize(
            batch_corrupt_sentences, vocab)
        # tokenizations of the clean and corrupt sides must agree on lengths
        assert (batch_lengths_ == batch_lengths).all()
        batch_idxs = [batch_idxs_.to(DEVICE) for batch_idxs_ in batch_idxs]
        batch_lengths = batch_lengths.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)
        inverted_mask = inverted_mask.to(DEVICE)
        batch_elmo_inp = elmo_batch_to_ids(
            [line.split() for line in batch_corrupt_sentences]).to(DEVICE)
        # forward
        with torch.no_grad():
            # batch_predictions is (batch_size, batch_max_seq_len, topk) if
            # topk > 1, else (batch_size, batch_max_seq_len)
            _, batch_predictions = model(batch_idxs,
                                         inverted_mask,
                                         batch_lengths,
                                         batch_elmo_inp,
                                         targets=batch_labels,
                                         topk=topk)
        # NOTE(review): UNKs are backfilled from batch_clean_sentences (the
        # gold side) here, whereas model_inference backfills from the corrupt
        # side — confirm this asymmetry is intentional
        batch_predictions = untokenize_without_unks(batch_predictions,
                                                    batch_lengths, vocab,
                                                    batch_clean_sentences)
        final_sentences.extend(batch_predictions)
    print("total inference time for this data is: {:4f} secs".format(
        time.time() - inference_st_time))
    return final_sentences
def model_inference(model, data, topk, DEVICE, BATCH_SIZE=16):
    """
    Run inference over `data` and print loss and correction metrics.

    Args:
        model: an instance of ElmoSCTransformer
        data: list of tuples, with each tuple consisting of correct and
            incorrect sentence string (would be split at whitespaces)
        topk: how many of the topk softmax predictions are considered for
            metrics calculations
        DEVICE: torch device the model and every batch tensor are moved to
        BATCH_SIZE: number of sentence pairs per batch

    Returns:
        None; all results are printed to stdout.
    """
    # NOTE(review): unlike model_predictions, `vocab` is not a parameter here —
    # this function relies on a module-level `vocab` being in scope; confirm,
    # or thread it through explicitly.
    print("###############################################")
    inference_st_time = time.time()
    _corr2corr, _corr2incorr, _incorr2corr, _incorr2incorr = 0, 0, 0, 0
    valid_loss = 0.
    print("data size: {}".format(len(data)))
    data_iter = batch_iter(data, batch_size=BATCH_SIZE, shuffle=False)
    model.eval()
    model.to(DEVICE)
    for batch_id, (batch_clean_sentences, batch_corrupt_sentences) in tqdm(enumerate(data_iter)):
        torch.cuda.empty_cache()
        # set batch data
        batch_labels, batch_lengths = labelize(batch_clean_sentences, vocab)
        batch_idxs, batch_lengths_, inverted_mask = sctrans_tokenize(
            batch_corrupt_sentences, vocab)
        # tokenizations of the clean and corrupt sides must agree on lengths
        assert (batch_lengths_ == batch_lengths).all()
        batch_idxs = [batch_idxs_.to(DEVICE) for batch_idxs_ in batch_idxs]
        batch_lengths = batch_lengths.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)
        inverted_mask = inverted_mask.to(DEVICE)
        batch_elmo_inp = elmo_batch_to_ids(
            [line.split() for line in batch_corrupt_sentences]).to(DEVICE)
        # forward
        try:
            with torch.no_grad():
                # batch_predictions is (batch_size, batch_max_seq_len, topk)
                # if topk > 1, else (batch_size, batch_max_seq_len)
                batch_loss, batch_predictions = model(batch_idxs,
                                                      inverted_mask,
                                                      batch_lengths,
                                                      batch_elmo_inp,
                                                      targets=batch_labels,
                                                      topk=topk)
        except RuntimeError:
            print(
                f"batch_idxs:{len(batch_idxs)},batch_lengths:{batch_lengths.shape},batch_elmo_inp:{batch_elmo_inp.shape},batch_labels:{batch_labels.shape}"
            )
            # re-raise the original RuntimeError with its traceback instead of
            # masking it with a blank Exception("") (RuntimeError is an
            # Exception subclass, so existing callers still catch it)
            raise
        valid_loss += batch_loss
        # compute accuracy in numpy
        batch_labels = batch_labels.cpu().detach().numpy()
        batch_lengths = batch_lengths.cpu().detach().numpy()
        # based on topk, obtain either strings of batch_predictions or
        # list of tokens
        if topk == 1:
            batch_predictions = untokenize_without_unks(
                batch_predictions, batch_lengths, vocab,
                batch_corrupt_sentences)
        else:
            # NOTE(review): topk=None is passed even though topk > 1 —
            # presumably untokenize_without_unks2 then keeps all k candidate
            # tokens; confirm against its definition
            batch_predictions = untokenize_without_unks2(
                batch_predictions, batch_lengths, vocab,
                batch_corrupt_sentences, topk=None)
        # metrics are computed case-insensitively
        batch_clean_sentences = [
            line.lower() for line in batch_clean_sentences
        ]
        batch_corrupt_sentences = [
            line.lower() for line in batch_corrupt_sentences
        ]
        batch_predictions = [line.lower() for line in batch_predictions]
        corr2corr, corr2incorr, incorr2corr, incorr2incorr = \
            get_metrics(batch_clean_sentences, batch_corrupt_sentences,
                        batch_predictions, check_until_topk=topk,
                        return_mistakes=False)
        _corr2corr += corr2corr
        _corr2incorr += corr2incorr
        _incorr2corr += incorr2corr
        _incorr2incorr += incorr2incorr
        # free per-batch tensors before the next iteration
        del batch_loss
        del batch_predictions
        del batch_labels, batch_lengths, batch_idxs, batch_lengths_, batch_elmo_inp
        torch.cuda.empty_cache()
    print(f"\nEpoch {None} valid_loss: {valid_loss/(batch_id+1)}")
    print("total inference time for this data is: {:4f} secs".format(
        time.time() - inference_st_time))
    print("###############################################")
    print("")
    print("")
    total_tokens = _corr2corr + _corr2incorr + _incorr2corr + _incorr2incorr
    print("total token count: {}".format(total_tokens))
    print(
        f"_corr2corr:{_corr2corr}, _corr2incorr:{_corr2incorr}, _incorr2corr:{_incorr2corr}, _incorr2incorr:{_incorr2incorr}"
    )
    # guard the ratio prints against division by zero on empty/clean data
    if total_tokens > 0:
        print(
            f"accuracy is {(_corr2corr+_incorr2corr)/total_tokens}"
        )
    if (_incorr2corr + _incorr2incorr) > 0:
        print(
            f"word correction rate is {(_incorr2corr)/(_incorr2corr+_incorr2incorr)}"
        )
    print("###############################################")
    return
for batch_id, (batch_labels, batch_sentences) in enumerate(train_data_iter): optimizer.zero_grad() st_time = time.time() # set batch data batch_labels, batch_lengths = labelize(batch_labels, vocab) batch_idxs, batch_lengths_, inverted_mask = sctrans_tokenize( batch_sentences, vocab) assert (batch_lengths_ == batch_lengths).all() == True batch_idxs = [ batch_idxs_.to(DEVICE) for batch_idxs_ in batch_idxs ] batch_lengths = batch_lengths.to(DEVICE) batch_labels = batch_labels.to(DEVICE) inverted_mask = inverted_mask.to(DEVICE) batch_elmo_inp = elmo_batch_to_ids( [line.split() for line in batch_sentences]).to(DEVICE) # forward model.train() loss = model(batch_idxs, inverted_mask, batch_lengths, batch_elmo_inp, targets=batch_labels) batch_loss = loss.cpu().detach().numpy() train_loss += batch_loss # backward loss.backward() optimizer.step() # compute accuracy in numpy model.eval() with torch.no_grad():