Exemplo n.º 1
0
def model_predictions(model, data, vocab, DEVICE, BATCH_SIZE=16):
    """Decode top-1 corrected sentences for ``data`` with a BertSCLSTM model.

    model: an instance of BertSCLSTM
    data: list of tuples, each holding a correct and an incorrect sentence
        string (sentences are split at whitespace downstream)
    vocab: vocabulary handed to ``labelize``, ``sclstm_tokenize`` and
        ``untokenize_without_unks``
    DEVICE: device that the model and every batch tensor are moved to
    BATCH_SIZE: number of tuples per batch

    Returns a flat list with the decoded predictions of every processed
    batch, in batch order. A batch for which
    ``bert_tokenize_for_valid_examples`` keeps no labels is printed and
    skipped, so the result can be shorter than ``data``.
    """
    top_k = 1
    started_at = time.time()  # only consumed by the disabled timing print below
    decoded = []
    batches = batch_iter(data, batch_size=BATCH_SIZE, shuffle=False)
    model.eval()
    model.to(DEVICE)
    for labels, sentences in batches:
        # BERT-side inputs; the helper returns a (possibly empty) set of
        # labels/sentences that it could pre-process.
        kept_labels, kept_sentences, bert_inputs, bert_splits = (
            bert_tokenize_for_valid_examples(labels, sentences))
        if len(kept_labels) == 0:
            print("################")
            print(
                "Not predicting the following lines due to pre-processing mismatch: \n"
            )
            print(list(zip(labels, sentences)))
            print("################")
            continue
        labels, sentences = kept_labels, kept_sentences
        bert_inputs = {
            name: tensor.to(DEVICE) for name, tensor in bert_inputs.items()
        }
        # word-level inputs for the SC-LSTM part of the model
        label_ids, lengths = labelize(labels, vocab)
        token_ids, lengths_check = sclstm_tokenize(sentences, vocab)
        assert (lengths_check == lengths).all()
        assert len(bert_splits) == len(token_ids)
        token_ids = [ids.to(DEVICE) for ids in token_ids]
        lengths = lengths.to(DEVICE)
        label_ids = label_ids.to(DEVICE)
        with torch.no_grad():
            # predictions are (batch_size, batch_max_seq_len, topk) when
            # topk > 1, else (batch_size, batch_max_seq_len)
            _, raw_predictions = model(token_ids,
                                       lengths,
                                       bert_inputs,
                                       bert_splits,
                                       targets=label_ids,
                                       topk=top_k)
        decoded += untokenize_without_unks(raw_predictions, lengths, vocab,
                                           labels)
    # print("total inference time for this data is: {:4f} secs".format(time.time()-started_at))
    return decoded
Exemplo n.º 2
0
def model_predictions(model, data, vocab, DEVICE, BATCH_SIZE=16):
    """Decode top-1 corrected sentences for ``data`` with an ElmoSCLSTM model.

    model: an instance of ElmoSCLSTM
    data: list of tuples, each holding a correct and an incorrect sentence
        string (sentences are split at whitespace downstream)
    vocab: vocabulary handed to ``labelize``, ``sclstm_tokenize`` and
        ``untokenize_without_unks``
    DEVICE: device that the model and every batch tensor are moved to
    BATCH_SIZE: number of tuples per batch

    Returns a flat list with the decoded predictions of every batch, in
    batch order. The clean sentences of each batch are what is passed to
    ``untokenize_without_unks`` (presumably as the fallback text for UNK
    positions -- NOTE(review): confirm against that helper).
    """
    top_k = 1
    started_at = time.time()  # only consumed by the disabled timing print below
    decoded = []
    batches = batch_iter(data, batch_size=BATCH_SIZE, shuffle=False)
    model.eval()
    model.to(DEVICE)
    for clean, corrupt in batches:
        # targets come from the clean side, model inputs from the corrupt side
        label_ids, lengths = labelize(clean, vocab)
        token_ids, lengths_check = sclstm_tokenize(corrupt, vocab)
        assert (lengths_check == lengths).all()
        token_ids = [ids.to(DEVICE) for ids in token_ids]
        lengths = lengths.to(DEVICE)
        label_ids = label_ids.to(DEVICE)
        elmo_inputs = elmo_batch_to_ids(
            [sentence.split() for sentence in corrupt]).to(DEVICE)
        with torch.no_grad():
            # predictions are (batch_size, batch_max_seq_len, topk) when
            # topk > 1, else (batch_size, batch_max_seq_len)
            _, raw_predictions = model(token_ids,
                                       lengths,
                                       elmo_inputs,
                                       targets=label_ids,
                                       topk=top_k)
        decoded += untokenize_without_unks(raw_predictions, lengths, vocab,
                                           clean)
    # print("total inference time for this data is: {:4f} secs".format(time.time()-started_at))
    return decoded
def model_inference(model, data, topk, DEVICE, BATCH_SIZE=16, vocab_=None):
    """Evaluate a BertSCLSTM model on ``data`` and print loss and token metrics.

    model: an instance of BertSCLSTM
    data: list of tuples, each holding a correct and an incorrect sentence
        string (sentences are split at whitespace downstream)
    topk: how many of the topk softmax predictions are considered for
        metrics calculations
    DEVICE: device the model and batch tensors are moved to
    BATCH_SIZE: number of tuples per batch
    vocab_: vocabulary to use; when None, a module-level ``vocab`` is used

    Returns None; all results are printed. Batches for which
    ``bert_tokenize_for_valid_examples`` keeps no labels are printed and
    skipped, and are not counted in the average loss. Averages/ratios with
    a zero denominator are printed as ``nan``.

    Raises ValueError if ``vocab_`` is None and no module-level ``vocab``
    exists. A RuntimeError from the model's forward pass is re-raised
    unchanged after the batch's input sizes are printed.
    """
    if vocab_ is not None:
        vocab = vocab_
    else:
        # ``vocab`` is assigned in this function, which makes it a local name.
        # The old ``if vocab_ is not None: vocab = vocab_`` therefore raised
        # UnboundLocalError here instead of reaching a module-level ``vocab``.
        vocab = globals().get("vocab")
        if vocab is None:
            raise ValueError(
                "vocab_ must be provided when no module-level `vocab` is defined")
    print("###############################################")
    inference_st_time = time.time()
    _corr2corr, _corr2incorr, _incorr2corr, _incorr2incorr = 0, 0, 0, 0
    valid_loss = 0.
    num_evaluated_batches = 0  # batches that actually reached the model
    print("data size: {}".format(len(data)))
    data_iter = batch_iter(data, batch_size=BATCH_SIZE, shuffle=False)
    model.eval()
    model.to(DEVICE)
    for batch_labels, batch_sentences in tqdm(data_iter):
        torch.cuda.empty_cache()
        # set batch data for bert
        batch_labels_, batch_sentences_, batch_bert_inp, batch_bert_splits = \
            bert_tokenize_for_valid_examples(batch_labels, batch_sentences)
        if len(batch_labels_) == 0:
            print("################")
            print("Not predicting the following lines due to pre-processing mismatch: \n")
            print([(a, b) for a, b in zip(batch_labels, batch_sentences)])
            print("################")
            continue
        batch_labels, batch_sentences = batch_labels_, batch_sentences_
        batch_bert_inp = {k: v.to(DEVICE) for k, v in batch_bert_inp.items()}
        # set batch data for others
        batch_labels_ids, batch_lengths = labelize(batch_labels, vocab)
        batch_idxs, batch_lengths_ = sclstm_tokenize(batch_sentences, vocab)
        assert (batch_lengths_ == batch_lengths).all()
        assert len(batch_bert_splits) == len(batch_idxs)
        batch_idxs = [batch_idxs_.to(DEVICE) for batch_idxs_ in batch_idxs]
        batch_lengths = batch_lengths.to(DEVICE)
        batch_labels_ids = batch_labels_ids.to(DEVICE)
        # forward
        try:
            with torch.no_grad():
                # batch_predictions is (batch_size, batch_max_seq_len, topk)
                # if topk > 1, else (batch_size, batch_max_seq_len)
                batch_loss, batch_predictions = model(
                    batch_idxs, batch_lengths, batch_bert_inp,
                    batch_bert_splits, targets=batch_labels_ids, topk=topk)
        except RuntimeError:
            print(f"batch_idxs:{len(batch_idxs)},batch_lengths:{batch_lengths.shape},batch_bert_inp:{len(batch_bert_inp.keys())},batch_labels_ids:{batch_labels_ids.shape}")
            # Re-raise the original error; the previous ``raise Exception("")``
            # discarded its type, message and traceback.
            raise
        valid_loss += batch_loss
        num_evaluated_batches += 1
        batch_lengths = batch_lengths.cpu().detach().numpy()
        # based on topk, obtain either strings of batch_predictions or list of tokens
        if topk == 1:
            batch_predictions = untokenize_without_unks(
                batch_predictions, batch_lengths, vocab, batch_sentences)
        else:
            batch_predictions = untokenize_without_unks2(
                batch_predictions, batch_lengths, vocab, batch_sentences,
                topk=None)
        corr2corr, corr2incorr, incorr2corr, incorr2incorr = get_metrics(
            batch_labels, batch_sentences, batch_predictions,
            check_until_topk=topk, return_mistakes=False)
        _corr2corr += corr2corr
        _corr2incorr += corr2incorr
        _incorr2corr += incorr2corr
        _incorr2incorr += incorr2incorr

        # drop per-batch references before releasing cached CUDA memory
        del batch_loss, batch_predictions, batch_labels_ids
        del batch_lengths, batch_idxs, batch_lengths_, batch_bert_inp
        torch.cuda.empty_cache()

    nan = float("nan")
    avg_loss = (valid_loss / num_evaluated_batches
                if num_evaluated_batches else nan)
    total_tokens = _corr2corr + _corr2incorr + _incorr2corr + _incorr2incorr
    incorrect_tokens = _incorr2corr + _incorr2incorr
    accuracy = ((_corr2corr + _incorr2corr) / total_tokens
                if total_tokens else nan)
    correction_rate = (_incorr2corr / incorrect_tokens
                       if incorrect_tokens else nan)

    print(f"\nEpoch {None} valid_loss: {avg_loss}")
    print("total inference time for this data is: {:4f} secs".format(time.time()-inference_st_time))
    print("###############################################")
    print("")
    print("")
    print("total token count: {}".format(total_tokens))
    print(f"_corr2corr:{_corr2corr}, _corr2incorr:{_corr2incorr}, _incorr2corr:{_incorr2corr}, _incorr2incorr:{_incorr2incorr}")
    print(f"accuracy is {accuracy}")
    print(f"word correction rate is {correction_rate}")
    print("###############################################")
    return
 #     print("<------------------>")
 #     print("<------------------>")
 #     for x in batch_sentences:
 #         print(x)
 if len(batch_labels_)==0:
     print("################")
     print("Not training the following lines due to pre-processing mismatch: \n")
     print([(a,b) for a,b in zip(batch_labels,batch_sentences)])
     print("################")
     continue
 else:
     batch_labels, batch_sentences = batch_labels_, batch_sentences_
 batch_bert_inp = {k:v.to(DEVICE) for k,v in batch_bert_inp.items()}
 # set batch data for others
 batch_labels, batch_lengths = labelize(batch_labels, vocab)
 batch_idxs, batch_lengths_ = sclstm_tokenize(batch_sentences, vocab)
 assert (batch_lengths_==batch_lengths).all()==True
 assert len(batch_bert_splits)==len(batch_idxs)
 batch_idxs = [batch_idxs_.to(DEVICE) for batch_idxs_ in batch_idxs]
 batch_lengths = batch_lengths.to(DEVICE)
 batch_labels = batch_labels.to(DEVICE)                
 # forward
 model.train()
 loss = model(batch_idxs, batch_lengths, batch_bert_inp, batch_bert_splits, targets=batch_labels)
 batch_loss = loss.cpu().detach().numpy()
 train_loss += batch_loss
 # backward
 if GRADIENT_ACC > 1:
     loss = loss / GRADIENT_ACC
 loss.backward()
 # step
Exemplo n.º 5
0
def model_inference(model, data, topk, DEVICE, BATCH_SIZE=16, vocab_=None):
    """Evaluate an SCLSTM model on ``data``, print metrics and return per-line results.

    model: an instance of SCLSTM
    data: list of tuples, each holding a correct and an incorrect sentence
        string (sentences are split at whitespace downstream)
    topk: how many of the topk softmax predictions are considered for
        metrics calculations
    DEVICE: device the model and batch tensors are moved to
    BATCH_SIZE: number of tuples per batch
    vocab_: vocabulary to use; when None, a module-level ``vocab`` is used

    Returns a list of dicts with keys "id", "original", "noised",
    "predicted", "topk", "topk_prediction_probs" and
    "topk_reranker_losses". "id" is a running index across batches. The
    three sentence fields hold the lower-cased clean, corrupt and predicted
    sentences. The three ``topk*`` fields are always empty lists here.
    Averages/ratios with a zero denominator are printed as ``nan``.

    Raises ValueError if ``vocab_`` is None and no module-level ``vocab``
    exists.
    """
    if vocab_ is not None:
        vocab = vocab_
    else:
        # ``vocab`` is assigned in this function, which makes it a local name.
        # The old ``if vocab_ is not None: vocab = vocab_`` therefore raised
        # UnboundLocalError here instead of reaching a module-level ``vocab``.
        vocab = globals().get("vocab")
        if vocab is None:
            raise ValueError(
                "vocab_ must be provided when no module-level `vocab` is defined")
    results = []
    line_index = 0

    print("###############################################")
    inference_st_time = time.time()
    _corr2corr, _corr2incorr, _incorr2corr, _incorr2incorr = 0, 0, 0, 0
    valid_loss = 0.
    num_batches = 0
    print("data size: {}".format(len(data)))
    data_iter = batch_iter(data, batch_size=BATCH_SIZE, shuffle=False)
    model.eval()
    model.to(DEVICE)
    for batch_clean_sentences, batch_corrupt_sentences in tqdm(data_iter):
        # targets come from the clean side, model inputs from the corrupt side
        batch_labels, batch_lengths = labelize(batch_clean_sentences, vocab)
        batch_idxs, batch_lengths_ = sclstm_tokenize(batch_corrupt_sentences,
                                                     vocab)
        assert (batch_lengths_ == batch_lengths).all()
        batch_idxs = [batch_idxs_.to(DEVICE) for batch_idxs_ in batch_idxs]
        batch_lengths = batch_lengths.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)
        with torch.no_grad():
            # batch_predictions is (batch_size, batch_max_seq_len, topk)
            # if topk > 1, else (batch_size, batch_max_seq_len)
            batch_loss, batch_predictions = model(batch_idxs,
                                                  batch_lengths,
                                                  targets=batch_labels,
                                                  topk=topk)
        valid_loss += batch_loss
        num_batches += 1
        batch_lengths = batch_lengths.cpu().detach().numpy()
        # based on topk, obtain either strings of batch_predictions or list of tokens
        if topk == 1:
            batch_predictions = untokenize_without_unks(
                batch_predictions, batch_lengths, vocab,
                batch_corrupt_sentences)
        else:
            batch_predictions = untokenize_without_unks2(
                batch_predictions,
                batch_lengths,
                vocab,
                batch_corrupt_sentences,
                topk=None)

        # metrics and the returned results are computed on lower-cased text
        batch_clean_sentences = [
            line.lower() for line in batch_clean_sentences
        ]
        batch_corrupt_sentences = [
            line.lower() for line in batch_corrupt_sentences
        ]
        batch_predictions = [line.lower() for line in batch_predictions]
        corr2corr, corr2incorr, incorr2corr, incorr2incorr, _ = get_metrics(
            batch_clean_sentences, batch_corrupt_sentences, batch_predictions,
            check_until_topk=topk, return_mistakes=True)
        _corr2corr += corr2corr
        _corr2incorr += corr2incorr
        _incorr2corr += incorr2corr
        _incorr2incorr += incorr2incorr

        for i, (original, noised, predicted) in enumerate(
                zip(batch_clean_sentences, batch_corrupt_sentences,
                    batch_predictions)):
            results.append({
                "id": line_index + i,
                "original": original,
                "noised": noised,
                "predicted": predicted,
                "topk": [],
                "topk_prediction_probs": [],
                "topk_reranker_losses": []
            })
        line_index += len(batch_clean_sentences)

    nan = float("nan")
    avg_loss = valid_loss / num_batches if num_batches else nan
    total_tokens = _corr2corr + _corr2incorr + _incorr2corr + _incorr2incorr
    incorrect_tokens = _incorr2corr + _incorr2incorr
    accuracy = ((_corr2corr + _incorr2corr) / total_tokens
                if total_tokens else nan)
    correction_rate = (_incorr2corr / incorrect_tokens
                       if incorrect_tokens else nan)

    print(f"\nEpoch {None} valid_loss: {avg_loss}")
    print("total inference time for this data is: {:4f} secs".format(
        time.time() - inference_st_time))
    print("###############################################")
    print("total token count: {}".format(total_tokens))
    print(
        f"_corr2corr:{_corr2corr}, _corr2incorr:{_corr2incorr}, _incorr2corr:{_incorr2corr}, _incorr2incorr:{_incorr2incorr}"
    )
    print(f"accuracy is {accuracy}")
    print(f"word correction rate is {correction_rate}")
    print("###############################################")
    return results