# Example #1
import time

import numpy as np
import torch
from tqdm import tqdm

# batch_iter, labelize, sclstm_tokenize, bert_tokenize_for_valid_examples,
# untokenize_without_unks, untokenize_without_unks2, and get_metrics are
# project-local helpers, assumed importable from this repository's modules.


def model_predictions(model, data, vocab, DEVICE, BATCH_SIZE=16):
    """
    model: an instance of BertSCLSTM
    data: list of tuples, with each tuple consisting of correct and incorrect 
            sentence string (would be split at whitespaces)
    """

    topk = 1
    # print("###############################################")
    inference_st_time = time.time()
    final_sentences = []
    VALID_BATCH_SIZE = BATCH_SIZE
    # print("data size: {}".format(len(data)))
    data_iter = batch_iter(data, batch_size=VALID_BATCH_SIZE, shuffle=False)
    model.eval()
    model.to(DEVICE)
    for batch_id, (batch_labels, batch_sentences) in enumerate(data_iter):
        # set batch data for bert
        batch_labels_, batch_sentences_, batch_bert_inp, batch_bert_splits = bert_tokenize_for_valid_examples(
            batch_labels, batch_sentences)
        if len(batch_labels_) == 0:
            print("################")
            print("Not predicting the following lines due to a pre-processing mismatch:\n")
            print([(a, b) for a, b in zip(batch_labels, batch_sentences)])
            print("################")
            continue
        batch_labels, batch_sentences = batch_labels_, batch_sentences_
        batch_bert_inp = {k: v.to(DEVICE) for k, v in batch_bert_inp.items()}
        # set batch data for others
        batch_labels_ids, batch_lengths = labelize(batch_labels, vocab)
        batch_idxs, batch_lengths_ = sclstm_tokenize(batch_sentences, vocab)
        assert (batch_lengths_ == batch_lengths).all()
        assert len(batch_bert_splits) == len(batch_idxs)
        batch_idxs = [batch_idxs_.to(DEVICE) for batch_idxs_ in batch_idxs]
        batch_lengths = batch_lengths.to(DEVICE)
        batch_labels_ids = batch_labels_ids.to(DEVICE)
        # forward
        with torch.no_grad():
            """
            NEW: batch_predictions can now be of shape (batch_size,batch_max_seq_len,topk) if topk>1, else (batch_size,batch_max_seq_len)
            """
            _, batch_predictions = model(batch_idxs,
                                         batch_lengths,
                                         batch_bert_inp,
                                         batch_bert_splits,
                                         targets=batch_labels_ids,
                                         topk=topk)
        batch_predictions = untokenize_without_unks(batch_predictions,
                                                    batch_lengths, vocab,
                                                    batch_labels)
        final_sentences.extend(batch_predictions)
    # print("total inference time for this data is: {:4f} secs".format(time.time()-inference_st_time))
    return final_sentences
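

# A minimal usage sketch for model_predictions. Loading the trained
# BertSCLSTM and its vocab is left to the caller; the sample sentences
# below are illustrative assumptions, not test data from this repository.
def example_predict(model, vocab, device=None):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    test_data = [
        ("he went to the store", "he wnet to the store"),  # (correct, incorrect)
    ]
    corrected = model_predictions(model, test_data, vocab, device, BATCH_SIZE=2)
    print(corrected)  # expected: ["he went to the store"] if the typo is fixed
    return corrected
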
def model_inference(model, data, topk, DEVICE, BATCH_SIZE=16, vocab_=None):
    """
    model: an instance of BertSCLSTM
    data: list of tuples, with each tuple consisting of correct and incorrect 
            sentence string (would be split at whitespaces)
    topk: how many of the topk softmax predictions are considered for metrics calculations
    """
    if vocab_ is not None:
        vocab = vocab_  # otherwise fall back to a module-level `vocab`
    print("###############################################")
    inference_st_time = time.time()
    _corr2corr, _corr2incorr, _incorr2corr, _incorr2incorr = 0, 0, 0, 0
    _mistakes = []
    VALID_BATCH_SIZE = BATCH_SIZE
    valid_loss = 0.
    valid_acc = 0.
    print("data size: {}".format(len(data)))
    data_iter = batch_iter(data, batch_size=VALID_BATCH_SIZE, shuffle=False)
    model.eval()
    model.to(DEVICE)
    for batch_id, (batch_labels, batch_sentences) in tqdm(enumerate(data_iter),
                                                          total=int(np.ceil(len(data) / VALID_BATCH_SIZE))):
        torch.cuda.empty_cache()
        st_time = time.time()
        # set batch data for bert
        batch_labels_, batch_sentences_, batch_bert_inp, batch_bert_splits = bert_tokenize_for_valid_examples(
            batch_labels, batch_sentences)
        if len(batch_labels_) == 0:
            print("################")
            print("Not predicting the following lines due to a pre-processing mismatch:\n")
            print([(a, b) for a, b in zip(batch_labels, batch_sentences)])
            print("################")
            continue
        batch_labels, batch_sentences = batch_labels_, batch_sentences_
        batch_bert_inp = {k: v.to(DEVICE) for k, v in batch_bert_inp.items()}
        # set batch data for others
        batch_labels_ids, batch_lengths = labelize(batch_labels, vocab)
        batch_idxs, batch_lengths_ = sclstm_tokenize(batch_sentences, vocab)
        assert (batch_lengths_ == batch_lengths).all()
        assert len(batch_bert_splits) == len(batch_idxs)
        batch_idxs = [batch_idxs_.to(DEVICE) for batch_idxs_ in batch_idxs]
        batch_lengths = batch_lengths.to(DEVICE)
        batch_labels_ids = batch_labels_ids.to(DEVICE)
        # forward
        try:
            with torch.no_grad():
                # NEW: batch_predictions has shape (batch_size, batch_max_seq_len, topk)
                # when topk > 1, else (batch_size, batch_max_seq_len)
                batch_loss, batch_predictions = model(batch_idxs, batch_lengths, batch_bert_inp,
                                                      batch_bert_splits, targets=batch_labels_ids, topk=topk)
        except RuntimeError:
            print(f"batch_idxs:{len(batch_idxs)}, batch_lengths:{batch_lengths.shape}, "
                  f"batch_bert_inp:{len(batch_bert_inp.keys())}, batch_labels_ids:{batch_labels_ids.shape}")
            raise  # re-raise with the original traceback instead of a bare Exception
        valid_loss += batch_loss.item()  # accumulate a Python float, not a GPU tensor
        # compute accuracy in numpy
        batch_labels_ids = batch_labels_ids.cpu().detach().numpy()
        batch_lengths = batch_lengths.cpu().detach().numpy()
        # based on topk, obtain either strings of batch_predictions or list of tokens
        if topk == 1:
            batch_predictions = untokenize_without_unks(batch_predictions, batch_lengths, vocab, batch_sentences)
        else:
            batch_predictions = untokenize_without_unks2(batch_predictions, batch_lengths, vocab, batch_sentences, topk=None)
        #corr2corr, corr2incorr, incorr2corr, incorr2incorr, mistakes = \
        #    get_metrics(batch_labels,batch_sentences,batch_predictions,check_until_topk=topk,return_mistakes=True)
        #_mistakes.extend(mistakes)
        # batch_labels = [line.lower() for line in batch_labels]
        # batch_sentences = [line.lower() for line in batch_sentences]
        # batch_predictions = [line.lower() for line in batch_predictions]
        corr2corr, corr2incorr, incorr2corr, incorr2incorr = \
            get_metrics(batch_labels, batch_sentences, batch_predictions,
                        check_until_topk=topk, return_mistakes=False)
        _corr2corr += corr2corr
        _corr2incorr += corr2incorr
        _incorr2corr += incorr2corr
        _incorr2incorr += incorr2incorr
        
        # free batch tensors before the next iteration
        del batch_loss
        del batch_predictions
        del batch_labels, batch_lengths, batch_idxs, batch_lengths_, batch_bert_inp
        torch.cuda.empty_cache()

        '''
        # update progress
        progressBar(batch_id+1,
                    int(np.ceil(len(data) / VALID_BATCH_SIZE)), 
                    ["batch_time","batch_loss","avg_batch_loss","batch_acc","avg_batch_acc"], 
                    [time.time()-st_time,batch_loss,valid_loss/(batch_id+1),None,None])
        '''
    print(f"\nEpoch {None} valid_loss: {valid_loss/(batch_id+1)}")
    print("total inference time for this data is: {:4f} secs".format(time.time()-inference_st_time))
    print("###############################################")
    print("")
    #for mistake in _mistakes:
    #    print(mistake)
    print("")
    print("total token count: {}".format(_corr2corr+_corr2incorr+_incorr2corr+_incorr2incorr))
    print(f"_corr2corr:{_corr2corr}, _corr2incorr:{_corr2incorr}, _incorr2corr:{_incorr2corr}, _incorr2incorr:{_incorr2incorr}")
    print(f"accuracy is {(_corr2corr+_incorr2corr)/(_corr2corr+_corr2incorr+_incorr2corr+_incorr2incorr)}")
    print(f"word correction rate is {(_incorr2corr)/(_incorr2corr+_incorr2incorr)}")
    print("###############################################")
    return
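

# A minimal usage sketch for model_inference. With topk > 1, a word counts
# as corrected if the reference appears among the top-k candidates; the
# model/vocab loading and the sample data are assumptions for illustration.
def example_evaluate(model, vocab, device=None):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    test_data = [
        ("she is going home", "she is goinng home"),  # (correct, incorrect)
    ]
    # prints total tokens, accuracy = (corr2corr + incorr2corr) / total, and
    # word correction rate = incorr2corr / (incorr2corr + incorr2incorr)
    model_inference(model, test_data, topk=3, DEVICE=device, BATCH_SIZE=2, vocab_=vocab)
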
 print(f"In epoch: {epoch_id}")
 progress_write_file.write(f"In epoch: {epoch_id}\n")
 progress_write_file.flush()
 # train loss and backprop
 train_loss = 0.
 train_acc = 0.
 print("train_data size: {}".format(len(train_data)))
 progress_write_file.write("train_data size: {}\n".format(len(train_data)))
 progress_write_file.flush()
 train_data_iter = batch_iter(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
 nbatches = int(np.ceil(len(train_data)/TRAIN_BATCH_SIZE))
 optimizer.zero_grad()
 for batch_id, (batch_labels,batch_sentences) in enumerate(train_data_iter):
     st_time = time.time()
     # set batch data for bert
     batch_labels_, batch_sentences_, batch_bert_inp, batch_bert_splits = bert_tokenize_for_valid_examples(batch_labels,batch_sentences)                
     # print("$$$$$$ NEW BATCH $$$$$$$$$$$$$")
     # print("Before....")
     # print([len(x.split()) for x in batch_labels])
     # print([len(x.split()) for x in batch_sentences])
     # print("After...")
     # print([len(x.split()) for x in batch_labels_])
     # print([len(x.split()) for x in batch_sentences_])
     # print("At least a mismatch...")
     # if len(batch_labels)!=len(batch_labels_):
     #     for x in batch_labels:
     #         print(x)
     #     print("<------------------>")
     #     print("<------------------>")
     #     for x in batch_sentences:
     #         print(x)
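    # --- The source fragment is truncated here. What follows is a hedged
    # sketch of a typical continuation, mirroring the batch set-up used in
    # model_inference above; it is NOT the author's exact training step.
    # `DEVICE` is assumed from the surrounding code as well. ---
    if len(batch_labels_) == 0:
        continue  # skip batches with a pre-processing mismatch, as above
    batch_labels, batch_sentences = batch_labels_, batch_sentences_
    batch_bert_inp = {k: v.to(DEVICE) for k, v in batch_bert_inp.items()}
    batch_labels_ids, batch_lengths = labelize(batch_labels, vocab)
    batch_idxs, _ = sclstm_tokenize(batch_sentences, vocab)
    batch_idxs = [idxs.to(DEVICE) for idxs in batch_idxs]
    model.train()
    batch_loss, _ = model(batch_idxs, batch_lengths.to(DEVICE), batch_bert_inp,
                          batch_bert_splits, targets=batch_labels_ids.to(DEVICE), topk=1)
    batch_loss.backward()   # backprop; optimizer was zeroed before the loop
    optimizer.step()
    optimizer.zero_grad()
    train_loss += batch_loss.item()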