예제 #1
0
def parse_all():
    """Run the bash parser, verbosely, over every line of ../bash/all.cm.

    The file path is built from the absolute parent of the current working
    directory. Lines are passed to ``data_tools.bash_parser`` unmodified
    (trailing newline included) and the returned ASTs are discarded, so the
    only observable effect is the parser's verbose output.
    """
    # NOTE(review): ``cmd`` is declared global, so the loop below rebinds a
    # module-level name on each iteration. Kept for compatibility; confirm
    # nothing reads it before dropping the declaration.
    global cmd
    cm_path = os.path.join(os.path.abspath(os.pardir), "bash", "all.cm")
    logger.info("Reading from " + cm_path)
    with open(cm_path, "r") as cm_file:
        all_lines = list(cm_file)
    for cmd in all_lines:
        data_tools.bash_parser(cmd, verbose=True)
예제 #2
0
def test_ted():
    """Interactively print the tree edit distance between two commands.

    Loops forever: reads two commands from stdin, parses each with
    ``data_tools.bash_parser`` and prints ``zss.simple_distance`` between
    the resulting trees, using ``nast.Node.get_children`` /
    ``nast.Node.get_label`` to walk nodes and ``tree_dist.temp_local_dist``
    as the label-distance function. The loop only ends when ``input``
    raises (e.g. EOF or Ctrl-C).
    """
    while True:
        first_cmd = input(">cmd1: ")
        second_cmd = input(">cmd2: ")
        first_tree = data_tools.bash_parser(first_cmd)
        second_tree = data_tools.bash_parser(second_cmd)
        distance = zss.simple_distance(
            first_tree,
            second_tree,
            nast.Node.get_children,
            nast.Node.get_label,
            tree_dist.temp_local_dist,
        )
        print("ted = {}".format(distance))
        print()
예제 #3
0
 def clean_rewrites(self):
     """Delete Rewrites rows that reference an unparseable command.

     Every (s1, s2) pair of the Rewrites table is checked with
     ``data_tools.bash_parser``; each string for which the parser returns
     a falsy value is recorded (once per occurrence, so duplicates are
     possible). After the SELECT has been fully consumed, every row that
     contains a recorded string in either column is deleted. No commit is
     issued here.
     """
     cursor = self.cursor
     rejected = []
     for left, right in cursor.execute("SELECT s1, s2 FROM Rewrites"):
         for text in (left, right):
             if not data_tools.bash_parser(text):
                 rejected.append(text)
     for text in rejected:
         print("Removing %s from Rewrites" % text)
         cursor.execute("DELETE FROM Rewrites WHERE s1 = ?", (text, ))
         cursor.execute("DELETE FROM Rewrites WHERE s2 = ?", (text, ))
예제 #4
0
def get_flag_statistics(top_utilities, input_file):
    """Print, per utility, how many distinct flags appear in a command file.

    :param top_utilities: utilities to report on. Any iterable is accepted
        (e.g. the set returned by ``compute_top_utilities``); results are
        printed in reverse iteration order.
    :param input_file: path to a file with one bash command per line.

    Commands that ``data_tools.bash_parser`` rejects (falsy result) are
    echoed to stdout. For each utility a line
    ``{axis:"<u>",num_flags:<n>,num_flags_in_data:<m>},`` is printed, where
    ``n`` is ``data_tools.get_utility_statistics(u)`` and ``m`` is the number
    of distinct option-node values seen under that utility. A final line
    reports the sum of ``m`` over all utilities.
    """
    # Materialize once so sets/generators work and can be walked in reverse.
    utilities = list(top_utilities)
    flag_counts = {u: set() for u in utilities}
    with open(input_file) as f:
        for cmd in f:
            ast = data_tools.bash_parser(cmd, verbose=False)
            if not ast:
                print(cmd)
                continue
            # Iterative DFS collecting every option node under a utility.
            stack = list(ast.children)
            while stack:
                node = stack.pop()
                if node.is_option():
                    u = node.utility.value
                    if u in flag_counts:
                        flag_counts[u].add(node.value)
                stack.extend(node.children)

    total_flag_count = 0
    for u in reversed(utilities):
        num_flags_in_data = len(flag_counts[u])
        print('{{axis:"{}",num_flags:{},num_flags_in_data:{}}},'
              .format(u, data_tools.get_utility_statistics(u),
                      num_flags_in_data))
        total_flag_count += num_flags_in_data
    print('Total # distinct flags = {}'.format(total_flag_count))
예제 #5
0
def rewrite(ast, temp):
    """Rewrite an AST into an equivalent one using the given template.

    Step 1 parses ``temp`` into a new AST with ``data_tools.bash_parser``.
    Step 2 walks that AST and fills every non-reserved argument node with
    the value of the first not-yet-used argument slot of ``ast`` whose
    ``arg_type`` matches. Children of argument nodes are not visited.

    :param ast: the original AST whose argument values are reused.
    :param temp: the template command string.
    :return: the filled-in template AST, or None if ``temp`` does not parse.
    """
    # Each slot is accessed as [node, used]; ``used`` is flipped to True once
    # the slot's value has been consumed, so slots must be mutable.
    arg_slots = lint.arg_slots(ast)

    def rewrite_fun(node):
        if node.kind == "argument" and not node.is_reserved():
            for slot in arg_slots:
                if not slot[1] and slot[0].arg_type == node.arg_type:
                    node.value = slot[0].value
                    slot[1] = True
                    break
        else:
            for child in node.children:
                rewrite_fun(child)

    # TODO: improve the heuristics.
    ast2 = data_tools.bash_parser(temp)
    if ast2 is not None:
        rewrite_fun(ast2)

    return ast2
예제 #6
0
def compute_top_utilities(path, k):
    """Return the ``k`` most frequent utilities of a command file, plus ties.

    :param path: file with one bash command per line. Blank lines are
        skipped; previously the first blank line silently ended the scan.
    :param k: number of utilities to select. Utilities whose frequency ties
        with the k-th selected one are also kept. If ``k`` is never reached
        (or ``k <= 0``), every non-blacklisted utility is returned.
    :return: a set of utility names, excluding ``bash.BLACK_LIST`` members.
    """
    print('computing top most frequent utilities...')
    utilities = collections.defaultdict(int)
    with open(path) as f:
        for line in f:
            command = line.strip()
            if not command:
                continue
            ast = data_tools.bash_parser(command, verbose=False)
            for u in data_tools.get_utilities(ast):
                utilities[u] += 1

    top_utilities = []
    freq_threshold = -1
    for u, freq in sorted(utilities.items(), key=lambda x: x[1], reverse=True):
        # Once k utilities are selected, continue only while frequencies tie.
        if freq_threshold > 0 and freq < freq_threshold:
            break
        # Xingyu Wei: To update for word assignment and control flow,
        # we remove the Grey List
        if u in bash.BLACK_LIST:
            continue
        top_utilities.append(u)
        print('{}: {} ({})'.format(len(top_utilities), u, freq))
        if len(top_utilities) == k:
            freq_threshold = freq
    return set(top_utilities)
예제 #7
0
def filter_by_most_frequent_utilities(data_dir, num_utilities):
    """Write filtered copies of the aligned all.nl / all.cm files.

    The top ``num_utilities`` utilities (plus ties) are computed from
    ``<data_dir>/all.cm``. A (description, command) pair is kept only if the
    description has at most 50 whitespace-separated words, the command
    parses, and every utility in the command belongs to the top set. Kept
    pairs are written to ``all.nl.filtered`` / ``all.cm.filtered``.
    """
    def uses_only(tree, allowed):
        # Report and reject on the first utility outside ``allowed``.
        for utility in data_tools.get_utilities(tree):
            if utility not in allowed:
                print('Utility currently not handled: {} - {}'.format(
                    utility,
                    data_tools.ast2command(tree, loose_constraints=True)))
                return False
        return True

    top_utilities = compute_top_utilities(
        os.path.join(data_dir, 'all.cm'), num_utilities)
    for split in ['all']:
        with open(os.path.join(data_dir, split + '.nl')) as nl_in:
            descriptions = [line.strip() for line in nl_in]
        with open(os.path.join(data_dir, split + '.cm')) as cm_in:
            commands = [line.strip() for line in cm_in]
        nl_out_path = os.path.join(data_dir, split + '.nl.filtered')
        cm_out_path = os.path.join(data_dir, split + '.cm.filtered')
        with open(nl_out_path, 'w') as nl_out, open(cm_out_path, 'w') as cm_out:
            for nl, cm in zip(descriptions, commands):
                if len(nl.split()) > 50:
                    print('lenthy description skipped: {}'.format(nl))
                    continue
                tree = data_tools.bash_parser(cm, verbose=True)
                if not tree or not uses_only(tree, top_utilities):
                    continue
                nl_out.write('{}\n'.format(nl))
                cm_out.write('{}\n'.format(cm))
예제 #8
0
def min_dist(asts, ast2, rewrite=False, ignore_arg_value=False):
    """
    Compute the minimum tree edit distance of the prediction to the set of
        ground truth ASTs.
    :param asts: set of gold ASTs.
    :param ast2: predicted AST; a falsy value (ungrammatical prediction) is
        replaced by the parse of the trivial command "find".
    :param rewrite: set to true if rewrite ground truths with templates
        (not implemented; raises NotImplementedError).
    :param ignore_arg_value: set to true if ignore literal values in the ASTs
        (uses ``temp_dist`` instead of ``str_dist``).
    :return: the smallest distance found, capped at 1e8; 1e8 is also the
        result when ``asts`` is empty.
    """
    if not ast2:
        ast2 = data_tools.bash_parser("find")

    if rewrite:
        raise NotImplementedError

    best = 1e8
    for gold in asts:
        if ignore_arg_value:
            candidate = temp_dist(gold, ast2)
        else:
            candidate = str_dist(gold, ast2)
        best = min(best, candidate)
    return best
예제 #9
0
def combine_annotations_multi_files():
    """
    Combine multiple annotation files and discard annotations that conflict.

    Every file in the directory named by ``sys.argv[1]`` is read as a CSV
    with (at least) the columns ``description``, ``prediction``,
    ``correct template`` and ``correct command``. A row whose normalized
    description is blank inherits the description of the closest preceding
    row in the same file. Examples are keyed by (description, prediction);
    a key that receives two different template or command judgements is
    discarded. Surviving examples are written, with the prediction's
    template, to ``manual_annotations.additional`` in the working directory.
    """
    input_dir = sys.argv[1]

    template_evals = {}
    command_evals = {}
    discarded_keys = set()

    for in_csv in os.listdir(input_dir):
        in_csv_path = os.path.join(input_dir, in_csv)
        with open(in_csv_path) as f:
            reader = csv.DictReader(f)
            current_description = ''
            for row in reader:
                template_eval = normalize_judgement(row['correct template'])
                command_eval = normalize_judgement(row['correct command'])
                description = get_example_nl_key(row['description'])
                if description.strip():
                    current_description = description
                else:
                    description = current_description
                prediction = row['prediction']
                example_key = '{}<NL_PREDICTION>{}'.format(
                    description, prediction)
                if example_key in template_evals and \
                        template_evals[example_key] != template_eval:
                    discarded_keys.add(example_key)
                    continue
                if example_key in command_evals and \
                        command_evals[example_key] != command_eval:
                    discarded_keys.add(example_key)
                    continue
                template_evals[example_key] = template_eval
                command_evals[example_key] = command_eval
            print('{} read ({} manually annotated examples, {} discarded)'.
                  format(in_csv_path, len(template_evals),
                         len(discarded_keys)))

    # Write to new file
    assert len(template_evals) == len(command_evals)
    with open('manual_annotations.additional', 'w') as o_f:
        # NOTE(review): this header says "correct comand" while the reader
        # above expects "correct command"; kept byte-identical in case a
        # downstream consumer relies on it — confirm before changing.
        o_f.write(
            'description,prediction,template,correct template,correct comand\n'
        )
        for key in sorted(template_evals):
            if key in discarded_keys:
                continue
            description, prediction = key.split('<NL_PREDICTION>')
            # Look up the judgements of *this* key; the previous code used
            # the stale ``example_key`` left over from the reading loop.
            template_eval = template_evals[key]
            command_eval = command_evals[key]
            pred_tree = data_tools.bash_parser(prediction)
            pred_temp = data_tools.ast2template(pred_tree,
                                                loose_constraints=True)
            o_f.write('"{}","{}","{}",{},{}\n'.format(
                description.replace('"', '""'), prediction.replace('"', '""'),
                pred_temp.replace('"', '""'), template_eval, command_eval))
예제 #10
0
def populate_command_template():
    """Delete overlong commands and recompute the template of the others.

    Commands whose string is longer than 600 characters are deleted. Every
    other command gets ``template`` set from
    ``data_tools.ast2template(..., loose_constraints=True)`` and is saved.
    """
    for command in Command.objects.all():
        if len(command.str) <= 600:
            tree = data_tools.bash_parser(command.str)
            command.template = data_tools.ast2template(
                tree, loose_constraints=True)
            command.save()
        else:
            command.delete()
예제 #11
0
def run():
    """Harvest example bash one-liners per utility from a StackOverflow dump.

    ``sys.argv[1]`` names a SQLite dump with ``questions`` and ``answers``
    tables. Answers are visited in order of descending question score.
    Every parseable one-liner extracted from their code blocks is credited
    to each of its utilities that is in ``bash.top_100_utilities``. Per
    utility, commands are stored keyed by template (one command per
    distinct template). No new template is added once
    ``NUM_COMMAND_THRESHOLD`` templates are stored. The question URL is
    recorded whenever a command is stored.

    Every 1000 posts the scan checks whether every top-100 utility has
    reached the threshold, and stops early if so. (Previously the
    ``completed`` flag started as False and was never set to True, so the
    early stop could not happen.)

    Results are pickled to ``stackoverflow.urls`` (utility -> set of URLs)
    and ``stackoverflow.commands`` (utility -> {template: command}). Finally,
    each utility is printed with its template count, followed by the stored
    templates (the dict keys).
    """
    import contextlib

    sqlite_filename = sys.argv[1]
    url_prefix = 'https://stackoverflow.com/questions/'

    urls = {}
    commands = {}

    # sqlite3's own context manager only commits/rolls back; closing()
    # additionally releases the connection when the scan is over.
    with contextlib.closing(sqlite3.connect(
            sqlite_filename, detect_types=sqlite3.PARSE_DECLTYPES)) as db:
        count = 0
        for post_id, answer_body in db.cursor().execute("""
                SELECT questions.Id, answers.Body FROM questions, answers
                WHERE questions.Id = answers.ParentId
                ORDER BY questions.Score DESC"""):
            print(post_id)
            for code_block in extract_code(answer_body):
                for cmd in extract_oneliner_from_code(code_block):
                    print('command string: {}'.format(cmd))
                    ast = data_tools.bash_parser(cmd)
                    if not ast:
                        continue
                    for utility in data_tools.get_utilities(ast):
                        if utility not in bash.top_100_utilities:
                            continue
                        print('extracted: {}, {}'.format(utility, cmd))
                        temp = data_tools.ast2template(ast, loose_constraints=True)
                        post_url = '{}{}'.format(url_prefix, post_id)
                        if utility not in commands:
                            commands[utility] = {temp: cmd}
                            urls[utility] = {post_url}
                        elif (len(commands[utility]) < NUM_COMMAND_THRESHOLD
                              and temp not in commands[utility]):
                            commands[utility][temp] = cmd
                            urls[utility].add(post_url)
            count += 1
            if count % 1000 == 0:
                # Assume done until some utility is found below threshold.
                completed = True
                for utility in bash.top_100_utilities:
                    if utility not in commands or \
                            len(commands[utility]) < NUM_COMMAND_THRESHOLD:
                        completed = False
                    else:
                        print('{} collection done.'.format(utility))

                if completed:
                    break

    with open('stackoverflow.urls', 'wb') as o_f:
        pickle.dump(urls, o_f)
    with open('stackoverflow.commands', 'wb') as o_f:
        pickle.dump(commands, o_f)

    for utility in commands:
        print('{} ({})'.format(utility, len(commands[utility])))
        # NOTE(review): iterating the dict yields the templates (keys), not
        # the stored commands; confirm which one is meant to be printed.
        for temp in commands[utility]:
            print(temp)
예제 #12
0
def get_u_hist_from_file(input_file):
    """Count how often each utility occurs in a file of bash commands.

    :param input_file: path to a file with one bash command per line.
    :return: ``collections.defaultdict(int)`` mapping utility name to its
        occurrence count; utilities in ``bash.BLACK_LIST`` or
        ``bash.GREY_LIST`` are not counted.
    """
    histogram = collections.defaultdict(int)
    with open(input_file) as f:
        for line in f:
            tree = data_tools.bash_parser(line, verbose=False)
            for utility in data_tools.get_utilities(tree):
                if utility not in bash.BLACK_LIST and \
                        utility not in bash.GREY_LIST:
                    histogram[utility] += 1
    return histogram
예제 #13
0
def Cust_Cmd_Tokenizer(String, parse="Template"):
    """Tokenize a bash command according to our custom needs.

    :param String: the bash command to tokenize.
    :param parse: ``"Template"`` (default) parses the command, renders its
        template with ``data_tools.ast2template(..., ignore_flag_order=False)``
        and splits it on single spaces. ``"Norm"`` returns the tokens from
        ``cm_to_partial_tokens`` using ``data_tools.bash_tokenizer``.
    :return: list of token strings.
    :raises ValueError: if ``parse`` is neither ``"Norm"`` nor ``"Template"``
        (previously both this case and "Norm" crashed with
        UnboundLocalError on ``Template``).
    """
    if parse == "Norm":
        Command = cm_to_partial_tokens(String,
                                       tokenizer=data_tools.bash_tokenizer)
        # Presumably a token list already; accept a plain string as well and
        # split it the same way as a template — TODO confirm return type.
        if isinstance(Command, str):
            return Command.split(" ")
        return list(Command)
    if parse == "Template":
        AST = data_tools.bash_parser(String)
        Template = data_tools.ast2template(AST, ignore_flag_order=False)
        return Template.split(" ")
    raise ValueError(
        'parse must be "Norm" or "Template", got {!r}'.format(parse))
예제 #14
0
def populate_command_tags():
    """Delete overlong commands and rebuild the utility tags of the others.

    Commands longer than 600 characters are deleted. For every other
    command the tag set is cleared and the command string printed. Then one
    tag per utility (from ``get_tag``) is added, printing each utility, and
    the command is saved.
    """
    for command in Command.objects.all():
        if len(command.str) > 600:
            command.delete()
            continue
        command.tags.clear()
        print(command.str)
        tree = data_tools.bash_parser(command.str)
        for utility in data_tools.get_utilities(tree):
            print(utility)
            command.tags.add(get_tag(utility))
        command.save()
예제 #15
0
def get_command(command_str):
    """Return the Command row for ``command_str``, creating it if missing.

    The string is stripped before lookup. A newly created command is tagged
    with one tag per utility (via ``get_tag``) and its template is set from
    ``data_tools.ast2template(..., loose_constraints=True)``.
    """
    key = command_str.strip()
    if Command.objects.filter(str=key).exists():
        return Command.objects.get(str=key)

    new_cmd = Command.objects.create(str=key)
    tree = data_tools.bash_parser(key)
    for utility in data_tools.get_utilities(tree):
        new_cmd.tags.add(get_tag(utility))
    new_cmd.template = data_tools.ast2template(tree, loose_constraints=True)
    new_cmd.save()
    return new_cmd
예제 #16
0
def compute_metric(predicted_cmd, predicted_confidence, ground_truth_cmd,
                   metric_params):
    """Score a predicted command against a ground-truth command.

    Both commands are converted with ``str()`` unless they are exactly of
    type ``str``, then parsed with ``bash_parser``. Their utility nodes are
    paired position by position after ``pad_arrays``. Each pair scores

        confidence * (utility_score * flag_norm - (1 - utility_score))

    where ``flag_norm = (u1 + u2 * flag_score) / (u1 + u2)``.

    :param predicted_confidence: converted to float; 1.0 is used when the
        conversion raises.
    :param metric_params: mapping providing the weights ``'u1'`` and ``'u2'``.
    :return: mean of the per-pair scores, or 0.0 when there are no pairs.
    """
    predicted_cmd = (predicted_cmd if type(predicted_cmd) is str
                     else str(predicted_cmd))
    ground_truth_cmd = (ground_truth_cmd if type(ground_truth_cmd) is str
                        else str(ground_truth_cmd))
    if type(predicted_confidence) is not float:
        try:
            predicted_confidence = float(predicted_confidence)
        except Exception:
            predicted_confidence = 1.0

    predicted_tree = bash_parser(predicted_cmd)
    truth_tree = bash_parser(ground_truth_cmd)

    predicted_nodes = get_utility_nodes(predicted_tree)
    truth_nodes = get_utility_nodes(truth_tree)
    truth_nodes, predicted_nodes = pad_arrays(truth_nodes, predicted_nodes)

    u1 = metric_params['u1']
    u2 = metric_params['u2']

    pair_scores = []
    for truth_util, pred_util in zip(truth_nodes, predicted_nodes):
        utility_score = get_utility_score(truth_util, pred_util)
        flag_score = get_flag_score(truth_util, pred_util)
        flag_norm = (u1 + u2 * flag_score) / (u1 + u2)
        pair_scores.append(
            predicted_confidence *
            ((utility_score * flag_norm) - (1 - utility_score)))

    if not pair_scores:
        return 0.0
    return np.mean(pair_scores)
def update_graph(cmd, graph):
    """Add "utility pipes into utility" edges for a top-level pipeline.

    If the first top-level node of ``cmd``'s AST is a
    ``bashlint.nast.PipelineNode``, then for every pair of adjacent pipeline
    stages that are both utilities, ``graph[prev].add(next)`` is called.
    ``graph`` must map names to sets; presumably a
    ``collections.defaultdict(set)`` — confirm with callers.

    Commands that fail to parse, or whose AST has no top-level node, are
    ignored (previously they raised AttributeError / IndexError).
    """
    parsed = data_tools.bash_parser(cmd)
    if not parsed or not parsed.children:
        return
    child = parsed.children[0]
    if not isinstance(child, bashlint.nast.PipelineNode):
        return
    prev_name = ""
    for stage in child.children:
        # Non-utility stages break the chain so no edge spans across them.
        cur_name = stage.value if stage.is_utility() else ""
        if prev_name and cur_name:
            graph[prev_name].add(cur_name)
        prev_name = cur_name
def add_utilities(cmd, counter):
    """Add every utility occurrence in ``cmd`` to ``counter``.

    The AST of ``cmd`` is walked recursively in pre-order. Argument nodes
    that are not utilities are not descended into. ``counter.update``
    receives the list of utility names (``node.value``), one entry per
    occurrence.
    """
    def collect(node):
        is_util = node.is_utility()
        if not is_util and node.is_argument():
            return []
        found = [node.value] if is_util else []
        for child in node.children:
            found.extend(collect(child))
        return found

    counter.update(collect(data_tools.bash_parser(cmd)))
예제 #19
0
def compute_flag_stats():
    """Compare flag coverage of the 10 most and 10 least frequent utilities.

    ``sys.argv[1]`` is a file of bash commands used to rank utilities by
    frequency (``bash.BLACK_LIST`` / ``bash.GREY_LIST`` members excluded).
    ``sys.argv[2]`` is a training file. Its commands are tokenized with
    ``with_flag_head=True``, and tokens of the form ``<utility>@@<flag>``
    are collected.

    For each of the two groups, prints one line per utility:
    ``u, data_tools.get_utility_statistics(u), #distinct flags seen``.
    A blank line separates the groups.
    """
    input_file = sys.argv[1]
    train_file = sys.argv[2]

    u_hist = collections.defaultdict(int)
    with open(input_file) as in_f:
        for cmd in in_f:
            ast = data_tools.bash_parser(cmd, verbose=False)
            for u in data_tools.get_utilities(ast):
                if u in bash.BLACK_LIST or u in bash.GREY_LIST:
                    continue
                u_hist[u] += 1

    sorted_u_by_freq = sorted(u_hist.items(), key=lambda x: x[1], reverse=True)
    most_frequent_10 = [u for u, _ in sorted_u_by_freq[:10]]
    least_frequent_10 = [u for u, _ in sorted_u_by_freq[-10:]]

    most_frequent_10_flags = collections.defaultdict(set)
    least_frequent_10_flags = collections.defaultdict(set)
    with open(train_file) as train_f:
        for cmd in train_f:
            tokens = data_tools.bash_tokenizer(cmd,
                                               loose_constraints=True,
                                               with_flag_head=True)
            for token in tokens:
                if '@@' in token:
                    # ``flag`` rather than ``f`` so the open file handle is
                    # no longer shadowed inside its own ``with`` block.
                    u, flag = token.split('@@')
                    if u in most_frequent_10:
                        most_frequent_10_flags[u].add(flag)
                    if u in least_frequent_10:
                        least_frequent_10_flags[u].add(flag)

    def report(utilities, flags_by_utility):
        # Utilities never seen with a flag report 0; .get avoids inserting
        # empty sets into the defaultdict.
        for u in utilities:
            print(u, data_tools.get_utility_statistics(u),
                  len(flags_by_utility.get(u, ())))

    report(most_frequent_10, most_frequent_10_flags)
    print()
    report(least_frequent_10, least_frequent_10_flags)
예제 #20
0
def gen_non_specific_description_check_csv(data_dir):
    """Write a CSV sheet of examples whose descriptions look non-specific.

    Reads the aligned ``all.nl`` / ``all.cm`` files in ``data_dir``. A pair
    is selected when the description contains one of a fixed list of vague
    phrases (each matched with a surrounding space on both sides) and the
    command parses. Each selected pair produces two rows in
    ``annotation_check_sheet.non.specific.csv``: the quoted command and
    description, then a row with a placeholder for a new description.
    """
    vague_phrases = (' specific ', ' a file ', ' a folder ', ' a directory ',
                     ' some ', ' a pattern ', ' a string ')

    with open(os.path.join(data_dir, 'all.nl')) as nl_in:
        nl_list = [line.strip() for line in nl_in]
    with open(os.path.join(data_dir, 'all.cm')) as cm_in:
        cm_list = [line.strip() for line in cm_in]
    assert len(nl_list) == len(cm_list)

    with open('annotation_check_sheet.non.specific.csv', 'w') as o_f:
        o_f.write('Utility,Command,Description\n')
        for nl, cm in zip(nl_list, cm_list):
            if not any(phrase in nl for phrase in vague_phrases):
                continue
            if not data_tools.bash_parser(cm):
                continue
            quoted_cm = cm.replace('"', '""')
            quoted_nl = nl.replace('"', '""')
            o_f.write(',"{}","{}"\n'.format(quoted_cm, quoted_nl))
            o_f.write(',,<Type a new description here>\n')
예제 #21
0
def u_hist_to_radar_chart():
    """Print utility frequencies from ``sys.argv[1]`` as radar-chart entries.

    Utilities of every command in the file are counted, skipping
    ``bash.BLACK_LIST`` and ``bash.GREY_LIST`` members. Entries of the form
    ``{axis:"<u>",value:<freq>},`` are printed in descending frequency
    order, followed by a blank line.
    """
    input_file = sys.argv[1]

    u_hist = collections.defaultdict(int)
    with open(input_file) as f:
        for cmd in f:
            tree = data_tools.bash_parser(cmd, verbose=False)
            for u in data_tools.get_utilities(tree):
                if u not in bash.BLACK_LIST and u not in bash.GREY_LIST:
                    u_hist[u] += 1

    ranked = sorted(u_hist.items(), key=lambda item: item[1], reverse=True)
    # NOTE(review): only ranks 51 and beyond are emitted, i.e. everything
    # *except* the 50 most frequent utilities — confirm that is intended.
    for u, freq in ranked[50:]:
        print('{{axis:"{}",value:{:.2f}}},'.format(u, freq))
    print()
예제 #22
0
def load_commands_in_url(stackoverflow_dump_path):
    """Attach the bash one-liners found in each URL's StackOverflow answers.

    For every ``URL`` object, its command set is cleared. The answers to
    the question id taken from the URL (the part after ``url_prefix``) are
    then read from the SQLite dump. Each one-liner in their code blocks
    that ``data_tools.bash_parser`` accepts is turned into a ``Command`` via
    ``get_command`` and added to the URL, which is then saved.
    """
    import contextlib

    url_prefix = 'https://stackoverflow.com/questions/'
    # sqlite3's own context manager only ends the transaction; closing()
    # also releases the connection.
    with contextlib.closing(sqlite3.connect(
            stackoverflow_dump_path,
            detect_types=sqlite3.PARSE_DECLTYPES)) as db:
        for url in URL.objects.all():
            url.commands.clear()
            print(url.str)
            question_id = url.str[len(url_prefix):]
            for answer_body, in db.cursor().execute("""
                    SELECT answers.Body FROM answers
                    WHERE answers.ParentId = ?""", (question_id,)):
                # NOTE(review): html_content is overwritten per answer, so
                # only the last answer's body survives — confirm intent.
                url.html_content = answer_body
                for code_block in extract_code(url.html_content):
                    # NOTE(review): ``run`` elsewhere calls the singular
                    # ``extract_oneliner_from_code``; verify which exists.
                    for cmd in extract_oneliners_from_code(code_block):
                        if data_tools.bash_parser(cmd):
                            command = get_command(cmd)
                            print('extracted: {}'.format(cmd))
                            url.commands.add(command)
            url.save()
예제 #23
0
def stable_slot_filling(template_tokens,
                        sc_fillers,
                        tg_slots,
                        pointer_targets,
                        encoder_outputs,
                        decoder_outputs,
                        slot_filling_classifier,
                        verbose=False):
    """
    Fills the argument slots using learnt local alignment scores and a greedy
    global alignment algorithm (stable marriage).

    :param template_tokens: list of tokens in the command template; filled
        in place when alignment succeeds
    :param sc_fillers: the slot fillers extracted from the source sequence,
        indexed by token id
    :param tg_slots: the argument slots in the command template, indexed by
        token id
    :param pointer_targets: [encoder_length, decoder_length], local alignment
        scores between source and target tokens; computed here when None
    :param encoder_outputs: [encoder_length, dim] sequence of encoder hidden states
    :param decoder_outputs: [decoder_length, dim] sequence of decoder hidden states
    :param slot_filling_classifier: the classifier that produces the local
        alignment scores
    :param verbose: print all local alignment scores if set to true
    :return: ``(tree, temp, mappings)``. ``(None, None, None)`` if some
        source filler has no type-compatible slot; ``(None, None, mappings)``
        if the alignment leaves fillers unassigned; otherwise the re-parsed
        filled command and its rendering via ``data_tools.ast2command``.
    """

    # Step a): prepare (binary) type alignment matrix based on type info
    M = np.zeros([len(encoder_outputs), len(decoder_outputs)], dtype=np.int32)
    for f in sc_fillers:
        # Ids index rows/columns of M, so the bound is strict (the former
        # ``<=`` let an out-of-range id through to an IndexError).
        assert f < len(encoder_outputs)
        surface, filler_type = sc_fillers[f]
        matched = False
        for s in tg_slots:
            assert s < len(decoder_outputs)
            slot_value, slot_type = tg_slots[s]
            if slot_filler_type_match(slot_type, filler_type):
                M[f, s] = 1
                matched = True
        if not matched:
            # If no target slot can hold a source filler, skip the alignment
            # step and return None
            return None, None, None

    # Step b): compute local alignment scores if they are not provided already
    if pointer_targets is None:
        assert encoder_outputs is not None
        assert decoder_outputs is not None
        assert slot_filling_classifier is not None
        pointer_targets = np.zeros(
            [len(encoder_outputs), len(decoder_outputs)])
        cm_slots_keys = list(tg_slots.keys())
        for f in range(M.shape[0]):
            # Only ambiguous fillers (more than one compatible slot) are
            # scored by the classifier.
            if np.sum(M[f]) > 1:
                # use reversed index for the encoder embeddings matrix
                ff = len(encoder_outputs) - f - 1
                X = np.concatenate([
                    np.concatenate([encoder_outputs[ff:ff + 1],
                                    decoder_outputs[s:s + 1]], axis=1)
                    for s in cm_slots_keys
                ], axis=0)
                X = X / norm(X, axis=1)[:, None]
                raw_scores = slot_filling_classifier.predict(X)
                for ii in range(len(raw_scores)):
                    s = cm_slots_keys[ii]
                    pointer_targets[f, s] = raw_scores[ii]
                    if verbose:
                        print('• alignment ({}, {}): {}\t{}\t{}'.format(
                            f, s, sc_fillers[f], tg_slots[s], raw_scores[ii]))

    M = M + M * pointer_targets
    # convert M into a dictionary representation of a sparse matrix
    M_dict = collections.defaultdict(dict)
    for i in range(M.shape[0]):
        if np.sum(M[i]) > 0:
            for j in range(M.shape[1]):
                if M[i, j] > 0:
                    M_dict[i][j] = M[i, j]

    mappings, remained_fillers = stable_marriage_alignment(M_dict)

    if remained_fillers:
        return None, None, mappings

    for f, s in mappings:
        template_tokens[s] = get_fill_in_value(tg_slots[s], sc_fillers[f])
    cmd = ' '.join(template_tokens)
    tree = data_tools.bash_parser(cmd)
    if tree is not None:
        data_tools.fill_default_value(tree)
    temp = data_tools.ast2command(tree,
                                  loose_constraints=True,
                                  ignore_flag_order=False)
    return tree, temp, mappings
예제 #24
0
def clean_cmd(cmd):
    """Normalize a command string by round-tripping it through its AST.

    Parses ``cmd`` with ``bash_parser`` and renders it with ``_clean_cmd``.
    The literal substrings ``::;`` and ``::+`` are then removed, surrounding
    whitespace is stripped, and internal whitespace runs are collapsed to
    single spaces.
    """
    cmd = _clean_cmd(bash_parser(cmd)).replace('::;', '').replace('::+', '')
    # Raw string: '\s' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning from Python 3.12 on).
    return re.sub(r'\s+', ' ', cmd.strip())
예제 #25
0
def get_manual_evaluation_metrics(grouped_dataset,
                                  prediction_list,
                                  FLAGS,
                                  num_examples=-1,
                                  interactive=True,
                                  verbose=True):
    """Compute top-1 / top-3 template and command accuracy with human help.

    A fixed random sample (``random.seed(100)``) of the examples is
    evaluated. For each of up to three predictions per example, three
    sources are consulted in order:

    - automatic matching against the ground truths (``tree_dist.one_match``);
    - cached judgements from ``<FLAGS.data_dir>/manual_judgements``;
    - when neither settles the case and ``interactive`` is true, a prompt to
      the user, whose answer is stored with ``add_judgement``.

    :param grouped_dataset: sequence whose items are indexed as ``item[1]``,
        a list of data points with ``sc_txt`` / ``tg_txt`` attributes.
    :param prediction_list: per-example list of predicted commands, aligned
        with ``grouped_dataset``.
    :param FLAGS: configuration providing ``data_dir`` and ``dataset``.
    :param num_examples: evaluate only this many sampled examples if > 0.
    :param interactive: prompt for judgements that are still missing.
    :param verbose: print the accuracy summary.
    :return: ``{'acc_f': [top1, top3], 'acc_t': [top1, top3]}``, where
        ``acc_f`` is command accuracy and ``acc_t`` template accuracy; all
        zeros when no example is evaluated.
    :raises ValueError: if the dataset and prediction list lengths differ.
    """

    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    # Get dev set samples (fixed)
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    if num_examples > 0:
        sample_ids = example_ids[:num_examples]
    else:
        sample_ids = example_ids

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'), verbose=True)
    # New judgements are written back into the caches below; guard against
    # a None result so those writes cannot fail.
    if structure_eval_cache is None:
        structure_eval_cache = {}
    if command_eval_cache is None:
        command_eval_cache = {}

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    def show_example(exam_id, sc_txt, command_gts, pred_cmd):
        # Show the description, all ground truths and the prediction.
        print('#{}. {}'.format(exam_id, sc_txt))
        for j, gt in enumerate(command_gts):
            print('- GT{}: {}'.format(j, gt))
        print('> {}'.format(pred_cmd))

    # Interactive manual evaluation
    num_t_top_1_correct = 0.0
    num_f_top_1_correct = 0.0
    num_t_top_3_correct = 0.0
    num_f_top_3_correct = 0.0

    for exam_id, example_id in enumerate(sample_ids):
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        # Parse ground truths with the same parser as the predictions so
        # non-bash datasets are compared like-for-like.
        command_gt_asts = [cmd_parser(gt) for gt in command_gts]
        predictions = prediction_list[example_id]
        top_3_s_correct_marked = False
        top_3_f_correct_marked = False
        for i in range(min(3, len(predictions))):
            pred_cmd = predictions[i]
            pred_ast = cmd_parser(pred_cmd)
            pred_temp = data_tools.ast2template(pred_ast,
                                                loose_constraints=True)
            temp_match = tree_dist.one_match(command_gt_asts,
                                             pred_ast,
                                             ignore_arg_value=True)
            str_match = tree_dist.one_match(command_gt_asts,
                                            pred_ast,
                                            ignore_arg_value=False)
            # Match ground truths & existing judgements
            command_example_key = '{}<NL_PREDICTION>{}'.format(
                sc_key, pred_cmd)
            structure_example_key = '{}<NL_PREDICTION>{}'.format(
                sc_key, pred_temp)
            command_eval, structure_eval = '', ''
            if str_match:
                command_eval = 'y'
                structure_eval = 'y'
            elif temp_match:
                structure_eval = 'y'
            # Cached judgements override the automatic match.
            if command_example_key in command_eval_cache:
                command_eval = command_eval_cache[command_example_key]
            if structure_example_key in structure_eval_cache:
                structure_eval = structure_eval_cache[structure_example_key]
            # Prompt for new judgements
            if command_eval != 'y':
                if structure_eval == 'y':
                    if not command_eval and interactive:
                        show_example(exam_id, sc_txt, command_gts, pred_cmd)
                        command_eval = input('CORRECT COMMAND? [y/reason] ')
                        add_judgement(FLAGS.data_dir, sc_txt, pred_cmd,
                                      structure_eval, command_eval)
                        print()
                else:
                    if not structure_eval and interactive:
                        show_example(exam_id, sc_txt, command_gts, pred_cmd)
                        structure_eval = input(
                            'CORRECT STRUCTURE? [y/reason] ')
                        if structure_eval == 'y':
                            command_eval = input(
                                'CORRECT COMMAND? [y/reason] ')
                        add_judgement(FLAGS.data_dir, sc_txt, pred_cmd,
                                      structure_eval, command_eval)
                        print()
                structure_eval_cache[structure_example_key] = structure_eval
                command_eval_cache[command_example_key] = command_eval
            if structure_eval == 'y':
                if i == 0:
                    num_t_top_1_correct += 1
                if not top_3_s_correct_marked:
                    num_t_top_3_correct += 1
                    top_3_s_correct_marked = True
            if command_eval == 'y':
                if i == 0:
                    num_f_top_1_correct += 1
                if not top_3_f_correct_marked:
                    num_f_top_3_correct += 1
                    top_3_f_correct_marked = True

    # Avoid ZeroDivisionError when nothing was sampled (all counts are 0).
    denom = len(sample_ids) or 1
    metrics = {}
    acc_f_1 = num_f_top_1_correct / denom
    acc_f_3 = num_f_top_3_correct / denom
    acc_t_1 = num_t_top_1_correct / denom
    acc_t_3 = num_t_top_3_correct / denom
    metrics['acc_f'] = [acc_f_1, acc_f_3]
    metrics['acc_t'] = [acc_t_1, acc_t_3]

    if verbose:
        print('{} examples evaluated'.format(len(sample_ids)))
        print('Top 1 Command Acc = {:.3f}'.format(acc_f_1))
        print('Top 3 Command Acc = {:.3f}'.format(acc_f_3))
        print('Top 1 Template Acc = {:.3f}'.format(acc_t_1))
        print('Top 3 Template Acc = {:.3f}'.format(acc_t_3))
    return metrics
예제 #26
0
def _lookup_ip_location(ip_address, timeout=5.0):
    """Return ``(organization, city, region, country)`` for ``ip_address``.

    The placeholder address ``123.456.789.012`` is answered locally; any other
    address is looked up via ipinfo.io. Missing or null fields fall back to
    ``''`` for the organization and ``'--'`` for the location fields. If the
    lookup itself fails (network error, timeout, or a non-JSON reply) every
    field falls back.

    :param ip_address: client IP address.
    :param timeout: seconds to wait for ipinfo.io before giving up.
    """
    if ip_address == '123.456.789.012':
        return '', '--', '--', '--'
    try:
        r = requests.get('http://ipinfo.io/{}/json'.format(ip_address),
                         timeout=timeout)
        info = r.json()
    except (requests.RequestException, ValueError):
        info = {}
    if not isinstance(info, dict):
        info = {}

    def field(name, default):
        # Only an absent/None value is replaced; '' is kept as-is.
        value = info.get(name)
        return default if value is None else value

    return (field('org', ''), field('city', '--'),
            field('region', '--'), field('country', '--'))


def translate(request, ip_address):
    """Render model translations of a natural-language request.

    The request string is read from POST data for POST requests and from the
    query string otherwise. When ``CACHE_TRANSLATIONS`` is set, previously
    stored ``Translation`` rows for the request are reused; if none are usable
    (and ``WEBSITE_DEVELOP`` is not set) the model is queried through
    ``translate_fun`` and its predictions are stored/updated.

    :param request: Django ``HttpRequest``.
    :param ip_address: client IP address; used to look up or create the
        ``User`` row and to find this client's votes on each translation.
    :return: ``HttpResponse`` rendering ``translator/translate.html``, or a
        redirect to ``/`` when the request string is missing or blank.
    """
    template = loader.get_template('translator/translate.html')
    if request.method == 'POST':
        request_str = request.POST.get('request_str')
    else:
        request_str = request.GET.get('request_str')

    if not request_str or not request_str.strip():
        return redirect('/')

    # Drop any trailing slashes.
    request_str = request_str.rstrip('/')

    # check if the natural language request is in the database
    nl = get_nl(request_str)

    trans_list = []
    annotated_trans_list = []

    if CACHE_TRANSLATIONS:
        # Highest scores first: the page ranks translations by score in
        # descending order, so ascending order here would keep the worst
        # NUM_TRANSLATIONS candidates. Slicing issues a LIMIT query.
        cached_trans = Translation.objects.filter(nl=nl).order_by('-score')
        for trans in cached_trans[:NUM_TRANSLATIONS]:
            pred_tree = data_tools.bash_parser(trans.pred_cmd.str)
            # Skip stored commands that no longer parse.
            if pred_tree is not None:
                trans_list.append(trans)
                annotated_trans_list.append(tokens2html(pred_tree))

    # check if the user is in the database
    try:
        user = User.objects.get(ip_address=ip_address)
    except ObjectDoesNotExist:
        organization, city, region, country = _lookup_ip_location(ip_address)
        user = User.objects.create(ip_address=ip_address,
                                   organization=organization,
                                   city=city,
                                   region=region,
                                   country=country)

    # save the natural language request issued by this IP Address
    nl_request = NLRequest.objects.create(nl=nl, user=user)

    if not trans_list and not WEBSITE_DEVELOP:
        # call learning model and store the translations
        batch_outputs, output_logits = translate_fun(request_str)

        if batch_outputs:
            top_k_predictions = batch_outputs[0]
            top_k_scores = output_logits[0]

            for (pred_tree, pred_cmd), score in zip(top_k_predictions,
                                                    top_k_scores):
                cmd = get_command(pred_cmd)
                trans = Translation.objects.filter(nl=nl,
                                                   pred_cmd=cmd).first()
                if trans is None:
                    trans = Translation.objects.create(nl=nl,
                                                       pred_cmd=cmd,
                                                       score=score)
                else:
                    # Refresh the stored score with the latest model output.
                    trans.score = score
                    trans.save()
                trans_list.append(trans)
                annotated_trans_list.append(tokens2html(pred_tree))

    translation_list = []
    for trans, annotated_cmd in zip(trans_list, annotated_trans_list):
        vote = Vote.objects.filter(translation=trans,
                                   ip_address=ip_address).first()
        upvoted, downvoted, starred = "", "", ""
        if vote is not None:
            upvoted = 1 if vote.upvoted else ""
            downvoted = 1 if vote.downvoted else ""
            starred = 1 if vote.starred else ""
        # Backslashes are doubled before the command is handed to the template.
        translation_list.append(
            (trans, upvoted, downvoted, starred,
             trans.pred_cmd.str.replace('\\', '\\\\'), annotated_cmd))

    # sort translation_list based on voting results
    translation_list.sort(key=lambda x: x[0].num_votes + x[0].score,
                          reverse=True)
    context = {'nl_request': nl_request, 'trans_list': translation_list}
    return HttpResponse(template.render(context, request))
예제 #27
0
def cmd2html(cmd_str):
    """Convert a bash command string into a highlighted HTML string.

    Thin wrapper around ``ast2html``: the command is parsed with
    ``data_tools.bash_parser`` and the HTML fragments produced by
    ``ast2html`` are joined with single spaces.
    """
    ast = data_tools.bash_parser(cmd_str)
    html_tokens = ast2html(ast)
    return " ".join(html_tokens)
예제 #28
0
def decode(encoder_full_inputs,
           model_outputs,
           FLAGS,
           vocabs,
           sc_fillers=None,
           slot_filling_classifier=None):
    """
    Transform the neural network output into readable strings and apply output
    filtering (if any).

    :param encoder_full_inputs: encoder input ids, indexed as
        ``encoder_full_inputs[position][batch_id]``; used to resolve copynet
        copy indices (positions are counted from the end).
    :param model_outputs: model output object; reads ``output_symbols``,
        ``encoder_hidden_states`` and ``decoder_hidden_states``, plus
        ``pointers`` (supervised copy) and ``char_output_symbols``
        (FLAGS.tg_char).
    :param FLAGS: decoding hyperparameters.
    :param vocabs: vocabulary object providing the reverse vocabularies.
    :param sc_fillers: per-example source-side argument fillers; needed when
        decoding normalized bash (and for argument slot filling).
    :param slot_filling_classifier: needed when FLAGS.fill_argument_slots.
    :return batch_outputs: nested list of (target_ast, target) tuples
        - target_ast is a python tree object for target languages that we know
          how to parse and a dummy string for those we don't
        - target is the output string
        If FLAGS.tg_char is set, ``(batch_outputs, batch_char_outputs)`` is
        returned instead, where ``batch_char_outputs`` holds the
        character-decoded beam strings of each example.
    """
    rev_sc_vocab = vocabs.rev_sc_vocab
    rev_tg_vocab = vocabs.rev_tg_vocab
    rev_sc_char_vocab = vocabs.rev_sc_char_vocab
    rev_tg_char_vocab = vocabs.rev_tg_char_vocab

    encoder_outputs = model_outputs.encoder_hidden_states
    decoder_outputs = model_outputs.decoder_hidden_states

    if FLAGS.fill_argument_slots:
        assert (sc_fillers is not None)
        assert (slot_filling_classifier is not None)
        assert (encoder_outputs is not None)
        assert (decoder_outputs is not None)

    output_symbols = model_outputs.output_symbols
    batch_size = len(output_symbols)
    batch_outputs = []
    num_output_examples = 0

    # Prepare copied indices if the model is trained with explicit copy
    # alignments.
    if FLAGS.use_copy and FLAGS.copy_fun == 'supervised':
        pointers = model_outputs.pointers
        sc_length = pointers.shape[1]
        tg_length = pointers.shape[2]
        if FLAGS.token_decoding_algorithm == 'greedy':
            batch_pointers = np.reshape(pointers,
                                        [batch_size, 1, sc_length, tg_length])
        else:
            batch_pointers = np.reshape(
                pointers, [batch_size, FLAGS.beam_size, sc_length, tg_length])

    for batch_id in xrange(batch_size):

        def as_str(output, r_sc_vocab, r_tg_vocab):
            """Map an output id to a target-vocab token, a copied source
            token (copynet), or _UNK."""
            if output < FLAGS.tg_vocab_size:
                token = r_tg_vocab[output]
            else:
                if FLAGS.use_copy and FLAGS.copy_fun == 'copynet':
                    # Copy ids index source positions counted from the end
                    # of encoder_full_inputs.
                    token = r_sc_vocab[encoder_full_inputs[
                        len(encoder_full_inputs) - 1 -
                        (output - FLAGS.tg_vocab_size)][batch_id]]
                else:
                    return data_utils._UNK
            return token

        top_k_predictions = output_symbols[batch_id]
        if FLAGS.token_decoding_algorithm == 'beam_search':
            assert (len(top_k_predictions) == FLAGS.beam_size)
            beam_outputs = []
        else:
            # pack greedy decoding results into size-1 beam
            top_k_predictions = [top_k_predictions]

        for beam_id in xrange(len(top_k_predictions)):
            # Step 1: transform the neural network output into readable strings
            prediction = top_k_predictions[beam_id]
            outputs = [int(pred) for pred in prediction]

            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]

            if FLAGS.char:
                target = ''.join([
                    as_str(output, rev_sc_char_vocab, rev_tg_char_vocab)
                    for output in outputs
                ]).replace(constants._SPACE, ' ')
            else:
                output_tokens = []
                tg_slots = {}
                for token_id in xrange(len(outputs)):
                    output = outputs[token_id]
                    pred_token = as_str(output, rev_sc_vocab, rev_tg_vocab)
                    if pred_token.startswith('__ARG__'):
                        pred_token = pred_token[len('__ARG__'):]
                    if '@@' in pred_token:
                        pred_token = pred_token.split('@@')[-1]
                    # process argument slots
                    if pred_token in constants._ENTITIES:
                        if token_id > 0 and slot_filling.is_min_flag(
                                rev_tg_vocab[outputs[token_id - 1]]):
                            pred_token_type = 'Timespan'
                        else:
                            pred_token_type = pred_token
                        tg_slots[token_id] = (pred_token, pred_token_type)
                    output_tokens.append(pred_token)

                if FLAGS.partial_token:
                    # process partial-token outputs: tokens between
                    # _ARG_START and _ARG_END are concatenated into one token
                    merged_output_tokens = []
                    buffer = ''
                    load_buffer = False
                    for token in output_tokens:
                        if load_buffer:
                            if token == data_utils._ARG_END:
                                merged_output_tokens.append(buffer)
                                load_buffer = False
                                buffer = ''
                            else:
                                buffer += token
                        else:
                            if token == data_utils._ARG_START:
                                load_buffer = True
                            else:
                                merged_output_tokens.append(token)
                    output_tokens = merged_output_tokens

                target = ' '.join(output_tokens)
            # Step 2: check if the predicted command template is grammatical
            if FLAGS.grammatical_only and not FLAGS.explain:
                if FLAGS.dataset.startswith('bash'):
                    # Raw string: '\s' is an invalid escape in a plain literal.
                    target = re.sub(r'( ;\s+)|( ;$)', ' \\; ', target)
                    target_ast = data_tools.bash_parser(target)
                elif FLAGS.dataset.startswith('regex'):
                    # TODO: check if a predicted regular expression is legal
                    target_ast = '__DUMMY_TREE__'
                else:
                    target_ast = data_tools.paren_parser(target)
                # filter out non-grammatical output
                if target_ast is None:
                    continue
            else:
                target_ast = '__DUMMY_TREE__'

            # Step 3: check if the predicted command templates have enough
            # slots to hold the fillers (to rule out templates that are
            # trivially unqualified)
            output_example = False
            if FLAGS.explain or not FLAGS.dataset.startswith('bash') \
                    or not FLAGS.normalized:
                output_example = True
            else:
                # Step 3: match the fillers to the argument slots
                batch_sc_fillers = sc_fillers[batch_id]
                if len(tg_slots) >= len(batch_sc_fillers):
                    if FLAGS.use_copy and FLAGS.copy_fun == 'supervised':
                        target_ast, target, _ = slot_filling.stable_slot_filling(
                            output_tokens,
                            batch_sc_fillers,
                            tg_slots,
                            batch_pointers[batch_id, beam_id, :, :],
                            None,
                            None,
                            None,
                            verbose=False)
                    elif FLAGS.fill_argument_slots:
                        target_ast, target, _ = slot_filling.stable_slot_filling(
                            output_tokens,
                            batch_sc_fillers,
                            tg_slots,
                            None,
                            encoder_outputs[batch_id],
                            decoder_outputs[batch_id * FLAGS.beam_size +
                                            beam_id],
                            slot_filling_classifier,
                            verbose=False)
                    else:
                        output_example = True
                    if not output_example and (target_ast is not None):
                        output_example = True

            if output_example:
                if FLAGS.token_decoding_algorithm == 'greedy':
                    batch_outputs.append((target_ast, target))
                else:
                    beam_outputs.append((target_ast, target))
                num_output_examples += 1

            # The threshold is used to increase decoding speed.
            # NOTE(review): num_output_examples is never reset per example and
            # is compared with '==', so this cap fires at most once per call.
            if num_output_examples == 20:
                break

        if FLAGS.token_decoding_algorithm == 'beam_search':
            if beam_outputs:
                batch_outputs.append(beam_outputs)

    # Step 4: apply character decoding
    if FLAGS.tg_char:
        char_output_symbols = model_outputs.char_output_symbols
        sentence_length = char_output_symbols.shape[0]
        batch_char_outputs = []
        # Per example: reshape to [sentence_length, beam_size, token_size + 1]
        # and transpose to [beam_size, sentence_length, token_size + 1].
        batch_char_predictions = [
            np.transpose(
                np.reshape(x, [sentence_length, FLAGS.beam_size,
                               FLAGS.max_tg_token_size + 1]),
                (1, 0, 2))
            for x in np.split(char_output_symbols, batch_size, 1)]
        for batch_id in xrange(len(batch_char_predictions)):
            beam_char_outputs = []
            top_k_char_predictions = batch_char_predictions[batch_id]
            for k in xrange(len(top_k_char_predictions)):
                top_k_char_prediction = top_k_char_predictions[k]
                sent = []
                for i in xrange(sentence_length):
                    word = ''
                    for j in xrange(FLAGS.max_tg_token_size):
                        char_prediction = top_k_char_prediction[i, j]
                        if char_prediction == data_utils.CEOS_ID or \
                                char_prediction == data_utils.CPAD_ID:
                            break
                        elif char_prediction in rev_tg_char_vocab:
                            word += rev_tg_char_vocab[char_prediction]
                        else:
                            word += data_utils._CUNK
                    sent.append(word)
                if data_utils._CATOM in sent:
                    sent = sent[:sent.index(data_utils._CATOM)]
                beam_char_outputs.append(' '.join(sent))
            batch_char_outputs.append(beam_char_outputs)
        return batch_outputs, batch_char_outputs
    else:
        return batch_outputs
예제 #29
0
def decode(model_outputs, FLAGS, vocabs, sc_fillers=None,
           slot_filling_classifier=None, copy_tokens=None):
    """
    Transform the neural network output into readable strings and apply output
    filtering (if any).

    :param model_outputs: model output object; reads ``output_symbols``,
        ``encoder_hidden_states`` and ``decoder_hidden_states``.
    :param FLAGS: decoding hyperparameters.
    :param vocabs: vocabulary object providing ``rev_tg_vocab``.
    :param sc_fillers: per-example source-side argument fillers; needed when
        decoding normalized bash (and for argument slot filling).
    :param slot_filling_classifier: needed when FLAGS.fill_argument_slots.
    :param copy_tokens: per-example source tokens that copynet ids
        (``output - FLAGS.tg_vocab_size``) index into; needed when
        ``FLAGS.use_copy`` with ``FLAGS.copy_fun == 'copynet'``.
    :return batch_outputs: nested list of (target_ast, target) tuples
        - target_ast is a python tree object for target languages that we know
          how to parse and a dummy string for those we don't
        - target is the output string
    """
    rev_tg_vocab = vocabs.rev_tg_vocab

    encoder_outputs = model_outputs.encoder_hidden_states
    decoder_outputs = model_outputs.decoder_hidden_states

    if FLAGS.fill_argument_slots:
        assert(sc_fillers is not None)
        assert(slot_filling_classifier is not None)
        assert(encoder_outputs is not None)
        assert(decoder_outputs is not None)

    output_symbols = model_outputs.output_symbols
    batch_size = len(output_symbols)
    batch_outputs = []
    num_output_examples = 0

    for batch_id in xrange(batch_size):
        def as_str(output, r_tg_vocab):
            """Map an output id to a target-vocab token, a copied source
            token (copynet), or _UNK when the id is out of range."""
            if output < FLAGS.tg_vocab_size:
                token = r_tg_vocab[output]
            else:
                if FLAGS.use_copy and FLAGS.copy_fun == 'copynet':
                    source_id = output - FLAGS.tg_vocab_size
                    if 0 <= source_id < len(copy_tokens[batch_id]):
                        token = copy_tokens[batch_id][source_id]
                    else:
                        return data_utils._UNK
                else:
                    return data_utils._UNK
            return token

        top_k_predictions = output_symbols[batch_id]
        if FLAGS.token_decoding_algorithm == 'beam_search':
            assert(len(top_k_predictions) == FLAGS.beam_size)
            beam_outputs = []
        else:
            # pack greedy decoding results into size-1 beam
            top_k_predictions = [top_k_predictions]

        for beam_id in xrange(len(top_k_predictions)):
            # Step 1: transform the neural network output into readable strings
            prediction = top_k_predictions[beam_id]
            outputs = [int(pred) for pred in prediction]

            # If there is an EOS (or PAD) symbol in outputs, cut them there.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]
            if data_utils.PAD_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.PAD_ID)]
            output_tokens = []
            tg_slots = {}
            for token_id in xrange(len(outputs)):
                output = outputs[token_id]
                pred_token = as_str(output, rev_tg_vocab)
                if data_tools.flag_suffix in pred_token:
                    pred_token = pred_token.split(data_tools.flag_suffix)[0]
                # process argument slots
                if pred_token in bash.argument_types:
                    if token_id > 0 and format_args.is_min_flag(
                            rev_tg_vocab[outputs[token_id - 1]]):
                        pred_token_type = 'Timespan'
                    else:
                        pred_token_type = pred_token
                    tg_slots[token_id] = (pred_token, pred_token_type)
                output_tokens.append(pred_token)

            if FLAGS.channel == 'partial.token':
                # process partial-token outputs: tokens between _ARG_START and
                # _ARG_END are concatenated; an unterminated run is kept too
                merged_output_tokens = []
                buffer = ''
                load_buffer = False
                for token in output_tokens:
                    if load_buffer:
                        if token == data_utils._ARG_END:
                            merged_output_tokens.append(buffer)
                            load_buffer = False
                            buffer = ''
                        else:
                            buffer += token
                    else:
                        if token == data_utils._ARG_START:
                            load_buffer = True
                        else:
                            merged_output_tokens.append(token)
                if buffer:
                    merged_output_tokens.append(buffer)
                output_tokens = merged_output_tokens

            if FLAGS.channel == 'char':
                # Characters are concatenated; the _SPACE symbol becomes ' '.
                space_symbol = data_utils.constants._SPACE
                target = ''.join(' ' if tok == space_symbol else tok
                                 for tok in output_tokens)
            else:
                target = ' '.join(output_tokens)

            # Step 2: check if the predicted command template is grammatical
            if FLAGS.grammatical_only and not FLAGS.explain:
                if FLAGS.dataset.startswith('bash'):
                    # Raw string: '\s' is an invalid escape in a plain literal.
                    target = re.sub(r'( ;\s+)|( ;$)', ' \\; ', target)
                    target_ast = data_tools.bash_parser(target, verbose=False)
                elif FLAGS.dataset.startswith('regex'):
                    # TODO: check if a predicted regular expression is legal
                    target_ast = '__DUMMY_TREE__'
                else:
                    target_ast = data_tools.paren_parser(target)
                # filter out non-grammatical output
                if target_ast is None:
                    continue
            else:
                target_ast = '__DUMMY_TREE__'

            # Step 3: check if the predicted command templates have enough
            # slots to hold the fillers (to rule out templates that are
            # trivially unqualified)
            output_example = False
            if FLAGS.explain or not FLAGS.dataset.startswith('bash') \
                    or not FLAGS.normalized:
                output_example = True
            else:
                # Step 3: match the fillers to the argument slots
                batch_sc_fillers = sc_fillers[batch_id]
                if len(tg_slots) >= len(batch_sc_fillers):
                    if FLAGS.fill_argument_slots:
                        target_ast, target, _ = slot_filling.stable_slot_filling(
                            output_tokens, batch_sc_fillers, tg_slots, None,
                            encoder_outputs[batch_id],
                            decoder_outputs[batch_id*FLAGS.beam_size+beam_id],
                            slot_filling_classifier, verbose=False)
                    else:
                        output_example = True
                    if not output_example and (target_ast is not None):
                        output_example = True

            if output_example:
                if FLAGS.token_decoding_algorithm == 'greedy':
                    batch_outputs.append((target_ast, target))
                else:
                    beam_outputs.append((target_ast, target))
                num_output_examples += 1

            # The threshold is used to increase decoding speed.
            # NOTE(review): num_output_examples is never reset per example and
            # is compared with '==', so this cap fires at most once per call.
            if num_output_examples == 20:
                break

        if FLAGS.token_decoding_algorithm == 'beam_search':
            if beam_outputs:
                batch_outputs.append(beam_outputs)

    return batch_outputs
예제 #30
0
def decode_set(sess, model, dataset, top_k, FLAGS, verbose=False):
    """
    Compute top-k predictions on the dev/test dataset and write the predictions
    to disk.

    Two timestamped files are written to ``model.model_dir`` and then copied
    to ``FLAGS.model_dir`` under a ``latest`` name:
      - ``predictions.<decode_sig>.<ts>``: one line per example; with beam
        search the predictions on a line are separated by ``|||``.
      - ``predictions.<decode_sig>.<ts>.csv``: evaluation rows, written for
        beam-search predictions and for examples with no prediction.

    :param sess: A TensorFlow session.
    :param model: Prediction model object.
    :param dataset: Dataset to decode (grouped by source via
        ``data_utils.group_parallel_data``).
    :param top_k: Number of top predictions to compute.
    :param FLAGS: Training/testing hyperparameter settings.
    :param verbose: If set, also print decoding results to screen.
    """
    nl2bash = FLAGS.dataset.startswith('bash') and not FLAGS.explain

    tokenizer_selector = 'cm' if FLAGS.explain else 'nl'
    grouped_dataset = data_utils.group_parallel_data(
        dataset, tokenizer_selector=tokenizer_selector)
    vocabs = data_utils.load_vocabulary(FLAGS)
    rev_sc_vocab = vocabs.rev_sc_vocab

    # The classifier is built from FLAGS alone, so build it once rather than
    # once per example.
    if FLAGS.fill_argument_slots:
        slot_filling_classifier = get_slot_filling_classifer(FLAGS)

    ts = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H%M%S')
    pred_file_path = os.path.join(model.model_dir, 'predictions.{}.{}'.format(
        model.decode_sig, ts))
    eval_file_path = os.path.join(model.model_dir, 'predictions.{}.{}.csv'.format(
        model.decode_sig, ts))
    # Context managers close both files even if decoding raises.
    with open(pred_file_path, 'w') as pred_file, \
            open(eval_file_path, 'w') as eval_file:
        eval_file.write('example_id, description, ground_truth, prediction, ' +
                        'correct template, correct command\n')
        for example_id in xrange(len(grouped_dataset)):
            _, data_group = grouped_dataset[example_id]

            sc_txt = data_group[0].sc_txt.strip()
            sc_tokens = [rev_sc_vocab[i] for i in data_group[0].sc_ids]
            if FLAGS.channel == 'char':
                sc_temp = ''.join(sc_tokens)
                sc_temp = sc_temp.replace(constants._SPACE, ' ')
            else:
                sc_temp = ' '.join(sc_tokens)
            tg_txts = [dp.tg_txt for dp in data_group]
            tg_asts = [data_tools.bash_parser(tg_txt) for tg_txt in tg_txts]
            if verbose:
                print('\nExample {}:'.format(example_id))
                print('Original Source: {}'.format(sc_txt.encode('utf-8')))
                print('Source: {}'.format(sc_temp.encode('utf-8')))
                for j in xrange(len(data_group)):
                    print('GT Target {}: {}'.format(
                        j+1, data_group[j].tg_txt.encode('utf-8')))

            if FLAGS.fill_argument_slots:
                batch_outputs, sequence_logits = translate_fun(
                    data_group, sess, model, vocabs, FLAGS,
                    slot_filling_classifier=slot_filling_classifier)
            else:
                batch_outputs, sequence_logits = translate_fun(
                    data_group, sess, model, vocabs, FLAGS)
            if FLAGS.tg_char:
                batch_outputs, batch_char_outputs = batch_outputs

            # CSV fields are quoted with embedded quotes doubled.
            eval_row = '{},"{}",'.format(example_id, sc_txt.replace('"', '""'))
            if batch_outputs:
                if FLAGS.token_decoding_algorithm == 'greedy':
                    # NOTE(review): no eval row is written for greedy decoding.
                    tree, pred_cmd = batch_outputs[0]
                    if nl2bash:
                        pred_cmd = data_tools.ast2command(
                            tree, loose_constraints=True)
                    score = sequence_logits[0]
                    if verbose:
                        print('Prediction: {} ({})'.format(pred_cmd, score))
                    pred_file.write('{}\n'.format(pred_cmd))
                elif FLAGS.token_decoding_algorithm == 'beam_search':
                    top_k_predictions = batch_outputs[0]
                    if FLAGS.tg_char:
                        top_k_char_predictions = batch_char_outputs[0]
                    top_k_scores = sequence_logits[0]
                    num_preds = min(FLAGS.beam_size, top_k, len(top_k_predictions))
                    for j in xrange(num_preds):
                        # Only the first row of an example carries its id and
                        # description.
                        if j > 0:
                            eval_row = ',,'
                        if j < len(tg_txts):
                            eval_row += '"{}",'.format(
                                tg_txts[j].strip().replace('"', '""'))
                        else:
                            eval_row += ','
                        top_k_pred_tree, top_k_pred_cmd = top_k_predictions[j]
                        if nl2bash:
                            pred_cmd = data_tools.ast2command(
                                top_k_pred_tree, loose_constraints=True)
                        else:
                            pred_cmd = top_k_pred_cmd
                        pred_file.write('{}|||'.format(pred_cmd.encode('utf-8')))
                        eval_row += '"{}",'.format(pred_cmd.replace('"', '""'))
                        temp_match = tree_dist.one_match(
                            tg_asts, top_k_pred_tree, ignore_arg_value=True)
                        str_match = tree_dist.one_match(
                            tg_asts, top_k_pred_tree, ignore_arg_value=False)
                        if temp_match:
                            eval_row += 'y,'
                        if str_match:
                            eval_row += 'y'
                        eval_file.write('{}\n'.format(eval_row.encode('utf-8')))
                        if verbose:
                            print('Prediction {}: {} ({})'.format(
                                j+1, pred_cmd.encode('utf-8'), top_k_scores[j]))
                            if FLAGS.tg_char:
                                print('Character-based prediction {}: {}'.format(
                                    j+1, top_k_char_predictions[j].encode('utf-8')))
                    pred_file.write('\n')
            else:
                print(APOLOGY_MSG)
                pred_file.write('\n')
                eval_file.write('{}\n'.format(eval_row))
                eval_file.write('\n')
                eval_file.write('\n')
    shutil.copyfile(pred_file_path, os.path.join(FLAGS.model_dir,
        'predictions.{}.latest'.format(model.decode_sig)))
    shutil.copyfile(eval_file_path, os.path.join(FLAGS.model_dir,
        'predictions.{}.latest.csv'.format(model.decode_sig)))