コード例 #1
0
def compute_util_acc(pred_cmd, gt_cmd):
    """Return the per-position utility accuracy of ``pred_cmd`` against ``gt_cmd``.

    Both commands are parsed with ``bash_parser`` and reduced to their utility
    nodes with ``get_utility_nodes``. The two node lists are aligned with
    ``pad_arrays`` and each aligned pair is scored with ``get_utility_score``.

    Args:
        pred_cmd: The predicted command.
        gt_cmd: The ground-truth command.

    Returns:
        The summed pairwise utility score divided by the length of the padded
        ground-truth list, or ``-1`` when that padded list is empty.
    """
    # Parse both commands before extracting nodes, prediction first.
    pred_tree = bash_parser(pred_cmd)
    gt_tree = bash_parser(gt_cmd)
    pred_nodes = get_utility_nodes(pred_tree)
    gt_nodes = get_utility_nodes(gt_tree)

    # pad_arrays takes (and returns) the ground-truth list first.
    aligned_gt, aligned_pred = pad_arrays(gt_nodes, pred_nodes)

    n_positions = len(aligned_gt)
    if n_positions == 0:
        return -1

    # map over two iterables stops at the shorter one, just like zip.
    matches = sum(map(get_utility_score, aligned_gt, aligned_pred))
    return matches / n_positions
コード例 #2
0
def compute_metric_loss(predicted_cmd, predicted_confidence, ground_truth_cmd,
                        metric_params):
    """Compute confidence-scaled utility and flag losses for one prediction.

    Both commands are parsed with ``bash_parser``. Their utility nodes are
    extracted with ``get_utility_nodes`` and aligned with ``pad_arrays``.
    Each aligned (ground truth, prediction) pair then contributes:

    * a utility loss of ``2 * (1 - util_correct)``, where ``util_correct``
      comes from ``get_utility_score``;
    * a flag loss of ``1 - get_flag_score(...)``, multiplied by
      ``util_correct`` (so it only counts when the utility is correct) and
      weighted by ``u2 / (u1 + u2)``.

    Args:
        predicted_cmd: The predicted command. Any value whose type is not
            exactly ``str`` is converted with ``str()``.
        predicted_confidence: Confidence attached to the prediction. Any value
            whose type is not exactly ``float`` is converted with ``float()``.
            If that conversion raises, ``1.0`` is used instead.
        ground_truth_cmd: The reference command, converted like
            ``predicted_cmd``.
        metric_params: Mapping providing the weights ``'u1'`` and ``'u2'``.

    Returns:
        A tuple ``(util_loss_mean, flag_loss_mean)``. Each element is the mean
        of its per-position losses (``0.0`` when there are no positions),
        multiplied by the confidence.

    Raises:
        KeyError: If ``metric_params`` lacks ``'u1'`` or ``'u2'``.
        ZeroDivisionError: If ``u1 + u2 == 0`` (for plain Python numbers).
    """
    if type(predicted_cmd) is not str:
        predicted_cmd = str(predicted_cmd)
    if type(ground_truth_cmd) is not str:
        ground_truth_cmd = str(ground_truth_cmd)
    if type(predicted_confidence) is not float:
        try:
            predicted_confidence = float(predicted_confidence)
        except Exception:
            # Deliberate best-effort fallback: 1.0 leaves the losses unscaled.
            predicted_confidence = 1.0

    predicted_ast = bash_parser(predicted_cmd)
    ground_truth_ast = bash_parser(ground_truth_cmd)

    predicted_utilities = get_utility_nodes(predicted_ast)
    ground_truth_utilities = get_utility_nodes(ground_truth_ast)

    ground_truth_utilities, predicted_utilities = pad_arrays(
        ground_truth_utilities, predicted_utilities)

    util_losses = []
    flag_losses = []
    u1 = metric_params['u1']
    u2 = metric_params['u2']
    # Only the flag term is re-weighted. The utility term is intentionally
    # NOT scaled by u1 / (u1 + u2). When get_utility_score returns 0 or 1,
    # each position then satisfies
    #   compute_metric score == confidence * (1 - util_loss - flag_loss).
    flag_weight = u2 / (u1 + u2)

    for ground_truth_utility, predicted_utility in zip(ground_truth_utilities,
                                                       predicted_utilities):
        util_correct = get_utility_score(ground_truth_utility,
                                         predicted_utility)
        # Loss is 2 when the utility is wrong, 0 when it is correct.
        util_loss = 2 * (1 - util_correct)
        flag_loss = 1 - get_flag_score(ground_truth_utility, predicted_utility)
        # The flag loss only applies when the utility itself is correct.
        flag_loss *= util_correct

        util_losses.append(util_loss)
        flag_losses.append(flag_loss * flag_weight)

    util_loss_mean = 0.0 if len(util_losses) == 0 else np.mean(util_losses)
    util_loss_mean *= predicted_confidence
    flag_loss_mean = 0.0 if len(flag_losses) == 0 else np.mean(flag_losses)
    flag_loss_mean *= predicted_confidence
    return util_loss_mean, flag_loss_mean
コード例 #3
0
def compute_metric(predicted_cmd, predicted_confidence, ground_truth_cmd,
                   metric_params):
    """Score one predicted command against its reference command.

    Both commands are parsed with ``bash_parser``. Their utility nodes are
    extracted with ``get_utility_nodes`` and aligned with ``pad_arrays``.
    Each aligned pair is scored as::

        confidence * (util * (u1 + u2 * flag) / (u1 + u2) - (1 - util))

    Here ``util`` comes from ``get_utility_score`` and ``flag`` from
    ``get_flag_score``. With ``util == 1`` and ``flag == 0`` a pair still earns
    ``confidence * u1 / (u1 + u2)``. With ``util == 0`` it contributes
    ``-confidence``.

    Args:
        predicted_cmd: The predicted command. Any value whose type is not
            exactly ``str`` is converted with ``str()``.
        predicted_confidence: Confidence attached to the prediction. Any value
            whose type is not exactly ``float`` is converted with ``float()``.
            If that conversion raises, ``1.0`` is used instead.
        ground_truth_cmd: The reference command, converted like
            ``predicted_cmd``.
        metric_params: Mapping providing the weights ``'u1'`` and ``'u2'``.

    Returns:
        The mean of the per-pair scores, or ``0.0`` when there are no pairs.
    """
    predicted_cmd = predicted_cmd if type(predicted_cmd) is str else str(predicted_cmd)
    ground_truth_cmd = (ground_truth_cmd if type(ground_truth_cmd) is str
                        else str(ground_truth_cmd))

    confidence = predicted_confidence
    if type(confidence) is not float:
        try:
            confidence = float(confidence)
        except Exception:
            # Best-effort fallback: treat an unusable confidence as 1.0.
            confidence = 1.0

    # Parse first (prediction, then reference), then extract nodes in the same order.
    pred_tree = bash_parser(predicted_cmd)
    ref_tree = bash_parser(ground_truth_cmd)
    pred_nodes = get_utility_nodes(pred_tree)
    ref_nodes = get_utility_nodes(ref_tree)
    # pad_arrays takes (and returns) the reference list first.
    ref_nodes, pred_nodes = pad_arrays(ref_nodes, pred_nodes)

    u1, u2 = metric_params['u1'], metric_params['u2']

    def _pair_score(ref_node, pred_node):
        # Score one aligned (reference, prediction) utility pair.
        util_hit = get_utility_score(ref_node, pred_node)
        flag_hit = get_flag_score(ref_node, pred_node)
        flag_factor = (u1 + u2 * flag_hit) / (u1 + u2)
        return confidence * ((util_hit * flag_factor) - (1 - util_hit))

    per_position = [_pair_score(ref, pred) for ref, pred in zip(ref_nodes, pred_nodes)]
    if not per_position:
        return 0.0
    return np.mean(per_position)