Exemplo n.º 1
0
def test_ted():
    """Interactively print the tree edit distance between pairs of commands.

    Repeatedly prompts for two bash commands, parses each one with
    ``data_tools.bash_parser`` and prints the distance returned by
    ``zss.simple_distance``, using ``nast.Node.get_children`` /
    ``nast.Node.get_label`` to walk the trees and
    ``ast_based.temp_local_dist`` as the label-distance function.

    The loop runs until end-of-input (e.g. Ctrl-D), at which point the
    function returns instead of dying with an ``EOFError`` traceback.
    """
    while True:
        try:
            cmd1 = raw_input(">cmd1: ")
            cmd2 = raw_input(">cmd2: ")
        except EOFError:
            # No more input: leave the interactive session quietly.
            return
        ast1 = data_tools.bash_parser(cmd1)
        ast2 = data_tools.bash_parser(cmd2)
        dist = zss.simple_distance(
            ast1, ast2, nast.Node.get_children, nast.Node.get_label,
            ast_based.temp_local_dist
        )
        print("ted = {}".format(dist))
        print()
Exemplo n.º 2
0
def min_dist(asts, ast2, rewrite=True, ignore_arg_value=False):
    """
    Compute the minimum tree edit distance of the prediction to the set of ground truth ASTs.
    :param asts: set of gold ASTs.
    :param ast2: predicted AST. A falsy value (e.g. an unparsable prediction)
        is replaced by the AST of the bare command "find".
    :param rewrite: set to true if rewrite ground truths with templates.
    :param ignore_arg_value: set to true if ignore literal values in the ASTs
        (uses temp_dist instead of str_dist).
    :return: the smallest distance between ast2 and any (possibly rewritten)
        gold AST, or sys.maxint when there is nothing to compare against.
    """
    # tolerate ungrammatical predictions
    if not ast2:
        ast2 = data_tools.bash_parser("find")

    if rewrite:
        # Fetching template rewrites is the only step that needs a DB handle.
        with er.DBConnection() as db:
            ast_rewrites = get_rewrites(asts, db)
    else:
        ast_rewrites = asts

    # The metric does not depend on the gold AST, so pick it once.
    dist_fun = temp_dist if ignore_arg_value else str_dist

    best = sys.maxint
    for ast1 in ast_rewrites:
        dist = dist_fun(ast1, ast2)
        if dist < best:
            best = dist
    return best
Exemplo n.º 3
0
def decode(output_symbols, rev_cm_vocab, FLAGS):
    """Convert decoder output ids into (tree, cmd, search_history) triples.

    :param output_symbols: one entry per example; each entry holds the
        predicted id sequence(s) for that example. For non-greedy decoding
        there must be exactly FLAGS.beam_size sequences per example.
    :param rev_cm_vocab: id -> token lookup for the command vocabulary.
    :param FLAGS: configuration object; reads decoding_algorithm, beam_size,
        decoder_topology, char and dataset.
    :return: with FLAGS.decoding_algorithm == "greedy", a flat list holding
        one (tree, cmd, search_history) triple per predicted sequence; with
        any other algorithm, one list of such triples per example.
    """
    greedy = FLAGS.decoding_algorithm == "greedy"
    batch_outputs = []

    for top_k_predictions in output_symbols:
        assert greedy or len(top_k_predictions) == FLAGS.beam_size
        decoded = [_decode_one_prediction(prediction, rev_cm_vocab, FLAGS)
                   for prediction in top_k_predictions]
        if greedy:
            batch_outputs.extend(decoded)
        else:
            # Every non-greedy algorithm keeps the per-example beam together
            # (previously only "beam_search" initialised this list).
            batch_outputs.append(decoded)

    return batch_outputs


def _decode_one_prediction(prediction, rev_cm_vocab, FLAGS):
    """Turn one predicted id sequence into a (tree, cmd, search_history) triple."""
    outputs = [int(pred) for pred in prediction]

    # If there is an EOS symbol in outputs, cut them at that point.
    if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]

    if FLAGS.decoder_topology != "rnn":
        tree, cmd, search_history = to_readable(outputs, rev_cm_vocab)
        return tree, cmd, search_history

    if FLAGS.char:
        cmd = "".join([tf.compat.as_str(rev_cm_vocab[output])
                       for output in outputs]).replace(data_utils._UNK, ' ')
    else:
        tokens = []
        for output in outputs:
            if output < len(rev_cm_vocab):
                pred_token = rev_cm_vocab[output]
                # Keep only the text after the last "@@" marker
                # (presumably a sub-word joiner).
                if "@@" in pred_token:
                    pred_token = pred_token.split("@@")[-1]
                tokens.append(pred_token)
            else:
                # Ids outside the vocabulary are rendered as the UNK token.
                tokens.append(data_utils._UNK)
        cmd = " ".join(tokens)

    if FLAGS.dataset in ["bash", "bash.cl"]:
        # Put backslashes back in front of standalone ';', ')' and '('
        # tokens, presumably so bash_parser reads them as escaped literals.
        cmd = re.sub(r'( ;\s+)|( ;$)', ' \\; ', cmd)
        cmd = re.sub(r'( \)\s+)|( \)$)', ' \\) ', cmd)
        cmd = re.sub(r'(^\( )|( \( )', ' \\( ', cmd)
        tree = data_tools.bash_parser(cmd)
    else:
        tree = data_tools.paren_parser(cmd)
    return tree, cmd, outputs
Exemplo n.º 4
0
def translate(request, ip_address):
    """Render the translation page for a natural-language request.

    The request text is read from ``request_str`` in POST data (for POST
    requests) or the query string otherwise. Trailing slashes are dropped,
    and an empty or blank request redirects to ``/``.

    Translations are reused from the database when CACHE_TRANSLATIONS is set
    and the request was seen before (only those whose command still parses).
    Otherwise, unless WEBSITE_DEVELOP is set, they are produced by
    translate_fun and stored. The request and the issuing IP address are
    recorded. Translations are rendered sorted by ``num_votes + score``
    (highest first), together with this IP address's previous votes.

    :param request: the incoming Django request.
    :param ip_address: IP address of the client issuing the request.
    """
    template = loader.get_template('translator/translate.html')
    if request.method == 'POST':
        request_str = request.POST.get('request_str')
    else:
        request_str = request.GET.get('request_str')

    # Strip trailing slashes before the emptiness check so that a request
    # consisting only of slashes is rejected as well.
    request_str = (request_str or '').rstrip('/')
    if not request_str.strip():
        return redirect('/')

    trans_list = []
    html_strs = []
    if CACHE_TRANSLATIONS and NLRequest.objects.filter(
            request_str=request_str).exists():
        # The request has been translated before: reuse the cached
        # translations whose command can still be parsed.
        for trans in Translation.objects.filter(
                request__request_str=request_str):
            pred_tree = data_tools.bash_parser(trans.pred_cmd)
            if pred_tree is not None:
                trans_list.append(trans)
                html_strs.append(tokens2html(pred_tree))

    nl_request, _ = NLRequest.objects.get_or_create(request_str=request_str)
    user = _get_or_create_user(ip_address)

    # Record (once) that this IP address issued this request.
    if not NLRequestIPAddress.objects.filter(
            request=nl_request, user=user).exists():
        NLRequestIPAddress.objects.create(request=nl_request, user=user)

    if not trans_list and not WEBSITE_DEVELOP:
        # Call the learning model and store the translations.
        batch_outputs, output_logits = translate_fun(request_str)
        if batch_outputs:
            top_k_predictions = batch_outputs[0]
            top_k_scores = output_logits[0]
            for i, (pred_tree, pred_cmd, _) in enumerate(top_k_predictions):
                trans = Translation.objects.create(
                    request=nl_request, pred_cmd=pred_cmd,
                    score=top_k_scores[i])
                trans_list.append(trans)
                html_strs.append(tokens2html(pred_tree))

    translation_list = []
    for trans, html_str in zip(trans_list, html_strs):
        # One query per translation instead of exists() followed by get().
        vote = Vote.objects.filter(
            translation=trans, ip_address=ip_address).first()
        if vote is not None:
            upvoted = 1 if vote.upvoted else ""
            downvoted = 1 if vote.downvoted else ""
            starred = 1 if vote.starred else ""
        else:
            upvoted, downvoted, starred = "", "", ""
        translation_list.append((trans, upvoted, downvoted, starred,
                                 trans.pred_cmd.replace('\\', '\\\\'),
                                 html_str))

    # sort translation_list based on voting results
    translation_list.sort(
        key=lambda x: x[0].num_votes + x[0].score, reverse=True)
    context = {
        'nl_request': nl_request,
        'trans_list': translation_list
    }
    return HttpResponse(template.render(context, request))


def _get_or_create_user(ip_address, timeout=5):
    """Return the User for ip_address, creating it (with geo info) if new.

    Organization and location come from ipinfo.io. The lookup is
    best-effort: a network error, timeout, non-JSON reply or missing field
    leaves the corresponding value as an empty string instead of failing
    the whole page.
    NOTE(review): assumes these User fields accept '' — confirm against the model.

    :param ip_address: client IP address.
    :param timeout: seconds to wait for ipinfo.io before giving up.
    """
    try:
        return User.objects.get(ip_address=ip_address)
    except ObjectDoesNotExist:
        pass

    try:
        r = requests.get('http://ipinfo.io/{}/json'.format(ip_address),
                         timeout=timeout)
        info = r.json()
    except (requests.RequestException, ValueError):
        info = {}

    return User.objects.create(
        ip_address=ip_address,
        organization=info.get('org', ''),
        city=info.get('city', ''),
        region=info.get('region', ''),
        country=info.get('country', ''),
    )
Exemplo n.º 5
0
def cmd2html(cmd_str):
    """Render a bash command string as highlighted HTML.

    Thin wrapper around ast2html: the command is parsed with
    data_tools.bash_parser and the HTML fragments produced by ast2html
    are joined with single spaces.
    """
    tree = data_tools.bash_parser(cmd_str)
    html_fragments = ast2html(tree)
    return " ".join(html_fragments)