# Example 1
def print_dialogue_batch(learner: Learner,
                         modeldata: ModelData,
                         input_field,
                         output_field,
                         num_batches=1,
                         num_sentences=-1,
                         is_test=False,
                         num_beams=1,
                         smoothing_function=None,
                         weights=None):
    """Print decoded dialogue inputs, targets and predictions with BLEU scores.

    Runs the learner over the (validation or test) data, decodes token ids
    back to strings, prints each sample, and finally reports the mean and
    standard deviation of the per-sentence BLEU scores.

    Args:
        learner: Model wrapper exposing ``predict_with_targs_and_inputs``.
        modeldata: Provides ``itos`` to map token ids back to token strings.
        input_field: Field used to decode the input token ids.
        output_field: Field used to decode target/prediction token ids.
        num_batches: Max number of batches to print; <= 0 means all.
        num_sentences: Max sentences per batch to print; <= 0 means all.
        is_test: If True, predict on the test set instead of validation.
        num_beams: Number of decoding beams.
        smoothing_function: BLEU smoothing; defaults to
            ``SmoothingFunction().method1``.
        weights: BLEU n-gram weights; defaults to uniform trigram weights.
    """
    # Resolve mutable/complex defaults here rather than in the signature.
    weights = (1 / 3., 1 / 3., 1 / 3.) if weights is None else weights
    smoothing_function = SmoothingFunction(
    ).method1 if smoothing_function is None else smoothing_function
    predictions, targets, inputs = learner.predict_with_targs_and_inputs(
        is_test=is_test, num_beams=num_beams)
    bleu_scores = []
    for batch_num, (batch_input, target, prediction) in enumerate(
            zip(inputs, targets, predictions)):
        # Transpose number of utterances to beams: [sl, bs, nb].
        # (renamed from `input` to avoid shadowing the builtin)
        batch_input = np.transpose(batch_input, [1, 2, 0])
        inputs_str: BatchBeamTokens = modeldata.itos(batch_input, input_field)
        inputs_str: List[str] = ["\n".join(conv) for conv in inputs_str]
        predictions_str: BatchBeamTokens = modeldata.itos(
            prediction, output_field)
        targets_str: BatchBeamTokens = modeldata.itos(target, output_field)
        for index, (inp, targ, pred) in enumerate(
                zip(inputs_str, targets_str, predictions_str)):
            # pred[0] carries a leading (BOS-like) token, hence the [1:]
            # slice before comparing/scoring against the target.
            if targ[0].split() == pred[0].split()[1:]:
                # Short-circuit exact matches to a perfect score.
                bleu_score = 1
            else:
                bleu_score = sentence_bleu(
                    [targ[0].split()],
                    pred[0].split()[1:],
                    smoothing_function=smoothing_function,
                    weights=weights)
            # Fixed typos in the printed labels: PREDICTON -> PREDICTION,
            # blue -> bleu (consistent with print_batch's output).
            print(
                f'BATCH: {batch_num} SAMPLE : {index}\nINPUT:\n{"".join(inp)}\nTARGET:\n{"".join(targ)}\nPREDICTION:\n{"".join(pred)}\nbleu: {bleu_score}\n\n'
            )
            bleu_scores.append(bleu_score)
            # Stop once `num_sentences` sentences were printed.  The old
            # `index - 1` comparison was off by two and printed
            # num_sentences + 2 sentences.
            if 0 < num_sentences <= index + 1:
                break
        # Same off-by-two fix for the batch limit.
        if 0 < num_batches <= batch_num + 1:
            break
    print(
        f'bleu score: mean: {np.mean(bleu_scores)}, std: {np.std(bleu_scores)}'
    )
# Example 2
def print_batch(learner: Learner,
                modeldata: ModelData,
                input_field,
                output_field,
                num_batches=1,
                num_sentences=-1,
                is_test=False,
                num_beams=1,
                weights=None,
                smoothing_function=None):
    """Print decoded inputs, targets and predictions with per-sample BLEU.

    Runs the learner over the (validation or test) data, decodes token ids
    back to strings, prints each sample with its BLEU score, and finally
    reports the mean BLEU score.

    Args:
        learner: Model wrapper exposing ``predict_with_targs_and_inputs``.
        modeldata: Provides ``itos`` to map token ids back to token strings.
        input_field: Field used to decode the input token ids.
        output_field: Field used to decode target/prediction token ids.
        num_batches: Max number of batches to print; <= 0 means all.
        num_sentences: Max sentences per batch to print; <= 0 means all.
        is_test: If True, predict on the test set instead of validation.
        num_beams: Number of decoding beams.
        weights: BLEU n-gram weights; defaults to uniform trigram weights.
        smoothing_function: BLEU smoothing; defaults to
            ``SmoothingFunction().method1``.
    """
    predictions, targets, inputs = learner.predict_with_targs_and_inputs(
        is_test=is_test, num_beams=num_beams)
    # Resolve mutable/complex defaults here rather than in the signature.
    weights = (1 / 3., 1 / 3., 1 / 3.) if weights is None else weights
    smoothing_function = SmoothingFunction(
    ).method1 if smoothing_function is None else smoothing_function
    bleu_scores = []
    for batch_num, (batch_input, target, prediction) in enumerate(
            zip(inputs, targets, predictions)):
        # `batch_input` renamed from `input` to avoid shadowing the builtin.
        inputs_str: BatchBeamTokens = modeldata.itos(batch_input, input_field)
        predictions_str: BatchBeamTokens = modeldata.itos(
            prediction, output_field)
        targets_str: BatchBeamTokens = modeldata.itos(target, output_field)
        for index, (inp, targ, pred) in enumerate(
                zip(inputs_str, targets_str, predictions_str)):
            bleu_score = sentence_bleu([targ],
                                       pred,
                                       smoothing_function=smoothing_function,
                                       weights=weights)
            print(
                f'batch: {batch_num} sample : {index}\ninput: {" ".join(inp)}\ntarget: { " ".join(targ)}\nprediction: {" ".join(pred)}\nbleu: {bleu_score}\n\n'
            )
            bleu_scores.append(bleu_score)
            # Stop once `num_sentences` sentences were printed.  The old
            # `index - 1` comparison was off by two and printed
            # num_sentences + 2 sentences.
            if 0 < num_sentences <= index + 1:
                break
        # Same off-by-two fix for the batch limit.
        if 0 < num_batches <= batch_num + 1:
            break
    print(f'mean bleu score: {np.mean(bleu_scores)}')