Example 1
from pathlib import Path
from typing import Any, Dict, List

# Assumed: Predictor and Vocabulary are the AllenNLP classes of those names;
# read_dataset is a project-local helper (not shown) that returns the dataset
# samples as JSON dicts.
from allennlp.data import Vocabulary
from allennlp.predictors import Predictor


def get_predictions(data_fp: Path, predictor: Predictor, incl_labels: bool,
                    vocab: Vocabulary) -> List[Dict[str, Any]]:
    '''
    :param data_fp: File path to the dataset file to predict on.
    :param predictor: A predictor used to generate the predictions.
    :param incl_labels: Whether or not to include the original gold labels in
                        the returned samples.
    :param vocab: Required to recover the original gold labels.
    :returns: A list of dictionaries containing the data read from the dataset
              file, with the predictions and, optionally, the gold labels added.
    '''
    data_samples = read_dataset(data_fp, incl_labels=incl_labels, vocab=vocab)
    data_samples = iter(data_samples)
    batch_size = 64
    data_exists = True
    new_data_samples = []
    while data_exists:
        data_batch = []
        # Collect up to `batch_size` samples for the next batch.
        for _ in range(batch_size):
            try:
                data_batch.append(next(data_samples))
            except StopIteration:
                data_exists = False
                break
        if data_batch:
            predictions = predictor.predict_batch_json(data_batch)
            for prediction, data_sample in zip(predictions, data_batch):
                data_sample['prediction'] = prediction['labels']
                new_data_samples.append(data_sample)
    return new_data_samples
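
As a usage sketch for the function above (assuming an AllenNLP-style setup: the archive path, predictor name, vocabulary directory and dataset path are placeholders, and the Predictor.from_path / Vocabulary.from_files loading calls are assumed from AllenNLP rather than taken from the original example):

from pathlib import Path

from allennlp.data import Vocabulary
from allennlp.predictors import Predictor

# Hypothetical archive, predictor name, vocabulary directory and dataset path,
# used purely for illustration.
predictor = Predictor.from_path("model.tar.gz", predictor_name="text_classifier")
vocab = Vocabulary.from_files("vocabulary/")

samples = get_predictions(Path("test.jsonl"), predictor,
                          incl_labels=True, vocab=vocab)
print(samples[0]['prediction'])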
Example 2
from typing import Any, Dict, List

# Assumed: Predictor is the AllenNLP Predictor class, as in Example 1.
from allennlp.predictors import Predictor


def batched_predict_json(predictor: Predictor,
                         examples: List[Dict[str, Any]],
                         batch_size: int = 16) -> List[Dict[str, Any]]:
    results: List[Dict[str, Any]] = []
    # Slice the examples into fixed-size batches and predict each batch in turn.
    for i in range(0, len(examples), batch_size):
        batch_examples = examples[i:i + batch_size]
        batch_results = predictor.predict_batch_json(batch_examples)
        results.extend(batch_results)
    return results
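
Combined with the function above, a minimal, self-contained sketch of the batching behaviour; the EchoPredictor stub below is hypothetical and simply returns its inputs in place of a real model:

from typing import Any, Dict, List


class EchoPredictor:
    """Stand-in for a real Predictor: returns each input dict unchanged."""

    def predict_batch_json(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        return inputs


examples = [{"sentence": f"example {i}"} for i in range(40)]
results = batched_predict_json(EchoPredictor(), examples, batch_size=16)
# range(0, 40, 16) yields the batch starts 0, 16 and 32, so the batches have
# sizes 16, 16 and 8, and all 40 results come back in the original order.
assert len(results) == len(examples)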