def predict_and_save(dataset: GroundedScanDataset,
                     model: nn.Module,
                     output_file_path: str,
                     max_decoding_steps: int,
                     max_testing_examples=None,
                     **kwargs):
    """
    Predict data in dataset with a model and write the predictions to output_file_path.

    Examples are pulled from the dataset one at a time (batch size 1) and decoded with
    gradient tracking disabled. One record per example is collected and the whole list
    is dumped as JSON once prediction has finished.

    :param dataset: a dataset with test examples
    :param model: a trained model from model.py
    :param output_file_path: a path where a .json file with predictions will be saved.
    :param max_decoding_steps: after how many steps to force quit decoding
    :param max_testing_examples: after how many examples to stop predicting, if None all examples will be evaluated
    :param kwargs: unused; extra keyword arguments are accepted and ignored
    :return: output_file_path, unchanged
    :raises ValueError: if max_testing_examples is negative
    """
    # Function-scope import so this block does not depend on the file's import header.
    from itertools import islice

    # Reject a nonsensical cap before the output file is opened (and truncated).
    if max_testing_examples is not None and max_testing_examples < 0:
        raise ValueError(
            "max_testing_examples must be None or a non-negative integer, got {}".format(
                max_testing_examples))

    # The file is opened before predicting, so an unwritable path fails immediately
    # instead of after a long prediction run.
    with open(output_file_path, mode='w') as outfile:
        output = []
        with torch.no_grad():
            predictions = predict(
                dataset.get_data_iterator(batch_size=1),
                model=model,
                max_decoding_steps=max_decoding_steps,
                pad_idx=dataset.target_vocabulary.pad_idx,
                sos_idx=dataset.target_vocabulary.sos_idx,
                eos_idx=dataset.target_vocabulary.eos_idx)
            # islice(..., None) yields everything; with a number it stops pulling from
            # predict() after that many examples, so no extra example is decoded.
            for (input_sequence, derivation_spec, situation_spec,
                 output_sequence, target_sequence, attention_weights_commands,
                 attention_weights_situations, position_accuracy) in islice(
                     predictions, max_testing_examples):
                # Score against the target with its first and last token
                # (<SOS> and <EOS>) removed.
                accuracy = sequence_accuracy(output_sequence,
                                             target_sequence[0].tolist()[1:-1])
                input_str_sequence = dataset.array_to_sentence(
                    input_sequence[0].tolist(), vocabulary="input")
                input_str_sequence = input_str_sequence[
                    1:-1]  # Get rid of <SOS> and <EOS>
                target_str_sequence = dataset.array_to_sentence(
                    target_sequence[0].tolist(), vocabulary="target")
                target_str_sequence = target_str_sequence[
                    1:-1]  # Get rid of <SOS> and <EOS>
                output_str_sequence = dataset.array_to_sentence(
                    output_sequence, vocabulary="target")
                output.append({
                    "input": input_str_sequence,
                    "prediction": output_str_sequence,
                    "derivation": derivation_spec,
                    "target": target_str_sequence,
                    "situation": situation_spec,
                    "attention_weights_input": attention_weights_commands,
                    "attention_weights_situation":
                    attention_weights_situations,
                    "accuracy": accuracy,
                    # An example is an exact match only when accuracy equals 100.
                    "exact_match": accuracy == 100,
                    "position_accuracy": position_accuracy
                })
        json.dump(output, outfile, indent=4)
        # Log only after the dump succeeded, so the message is never emitted
        # for a file that failed to serialize.
        logger.info("Wrote predictions for {} examples.".format(len(output)))
    return output_file_path
Exemple #2
0
def predict_and_save(dataset: GroundedScanDataset, model: nn.Module, output_file_path: str, max_decoding_steps: int,
                     max_testing_examples=None, **kwargs):
    """
    Run the model over the dataset one example at a time and dump the predictions as JSON.

    Decoding happens with gradient tracking disabled. For every example a record with the
    input, prediction, target, derivation, situation, both attention-weight outputs, the
    sequence accuracy and an exact-match flag is collected; the full list is written to
    output_file_path once prediction has finished.

    :param dataset: dataset providing get_data_iterator, array_to_sentence and target_vocabulary
    :param model: the model handed to predict()
    :param output_file_path: path of the JSON file to write (opened in 'w' mode, so it is overwritten)
    :param max_decoding_steps: forwarded to predict() as max_decoding_steps
    :param max_testing_examples: maximum number of examples to predict; None means all examples
    :param kwargs: unused; extra keyword arguments are accepted and ignored
    :return: output_file_path, unchanged
    :raises ValueError: if max_testing_examples is negative
    """
    # Function-scope import so this block does not depend on the file's import header.
    from itertools import islice

    # Reject a nonsensical cap before the output file is opened (and truncated).
    if max_testing_examples is not None and max_testing_examples < 0:
        raise ValueError(
            "max_testing_examples must be None or a non-negative integer, got {}".format(max_testing_examples))

    # The file is opened before predicting, so an unwritable path fails immediately
    # instead of after a long prediction run.
    with open(output_file_path, mode='w') as outfile:
        output = []
        with torch.no_grad():
            predictions = predict(
                dataset.get_data_iterator(batch_size=1), model=model, max_decoding_steps=max_decoding_steps,
                pad_idx=dataset.target_vocabulary.pad_idx, sos_idx=dataset.target_vocabulary.sos_idx,
                eos_idx=dataset.target_vocabulary.eos_idx)
            # islice(..., None) yields everything; with a number it stops pulling from
            # predict() after that many examples, so no extra example is decoded.
            # The eighth value yielded by predict() is not used by this variant.
            for (input_sequence, derivation_spec, situation_spec, output_sequence, target_sequence,
                 attention_weights_commands, attention_weights_situations, _) in islice(
                    predictions, max_testing_examples):
                # Score against the target with its first and last token (<SOS> and <EOS>) removed.
                accuracy = sequence_accuracy(output_sequence, target_sequence[0].tolist()[1:-1])
                input_str_sequence = dataset.array_to_sentence(input_sequence[0].tolist(), vocabulary="input")
                input_str_sequence = input_str_sequence[1:-1]  # Get rid of <SOS> and <EOS>
                target_str_sequence = dataset.array_to_sentence(target_sequence[0].tolist(), vocabulary="target")
                target_str_sequence = target_str_sequence[1:-1]  # Get rid of <SOS> and <EOS>
                output_str_sequence = dataset.array_to_sentence(output_sequence, vocabulary="target")
                output.append({"input": input_str_sequence, "prediction": output_str_sequence,
                               "derivation": derivation_spec,
                               "target": target_str_sequence, "situation": situation_spec,
                               "attention_weights_input": attention_weights_commands,
                               "attention_weights_situation": attention_weights_situations,
                               "accuracy": accuracy,
                               # An example is an exact match only when accuracy equals 100.
                               "exact_match": accuracy == 100})
        json.dump(output, outfile, indent=4)
        # Log only after the dump succeeded, so the message is never emitted
        # for a file that failed to serialize.
        logger.info("Wrote predictions for {} examples.".format(len(output)))
    return output_file_path