def clean_rewrites(self):
     """Delete Rewrites rows that mention a command the bash parser rejects.

     Every ``s1``/``s2`` value in the Rewrites table is passed to
     ``data_tools.bash_parser``. Commands for which the parser returns a
     falsy result are collected and, once the scan is finished, every row
     holding such a command in either column is deleted through
     ``self.cursor``. Each removal is announced on stdout.
     """
     cursor = self.cursor
     rejected = []
     for first, second in cursor.execute("SELECT s1, s2 FROM Rewrites"):
         for command in (first, second):
             if not data_tools.bash_parser(command):
                 rejected.append(command)

     # Deletions are issued only after the SELECT above has been consumed.
     delete_statements = (
         "DELETE FROM Rewrites WHERE s1 = ?",
         "DELETE FROM Rewrites WHERE s2 = ?",
     )
     for command in rejected:
         print("Removing %s from Rewrites" % command)
         for statement in delete_statements:
             cursor.execute(statement, (command,))
 def add_to_set(nl_data, cm_data, split):
     """Fill the ``split`` partition of each output list from the paired data.

     The ``split`` attributes of ``nl_data`` and ``cm_data`` are walked in
     lockstep. A pair is skipped when ``data_tools.bash_parser`` returns a
     falsy result for the command, and is reported as a "Rare command" (then
     skipped) when ``is_simple`` rejects the parsed tree. For every remaining
     pair the character, token and sequence forms below are computed and
     appended to the ``split`` attribute of the matching ``*_list`` object.

     ``is_simple``, ``normalizer`` and the ``*_list`` containers are not
     defined in this function; they are resolved from the surrounding scope.
     """
     with_parent = True
     # Options shared by both character-level tokenizer calls.
     char_options = dict(normalize_digits=False, normalize_long_pattern=False)

     def store(container, value):
         # Append to the partition of ``container`` selected by ``split``.
         getattr(container, split).append(value)

     for nl, cm in zip(getattr(nl_data, split), getattr(cm_data, split)):
         tree = data_tools.bash_parser(cm)
         if not tree:
             # Commands the parser rejects are dropped without a message.
             continue
         if not is_simple(tree):
             print("Rare command: " + cm)
             continue

         # Character- and token-level views of the raw strings.
         nl_chars = data_tools.char_tokenizer(
             nl, data_tools.basic_tokenizer, **char_options)
         cm_chars = data_tools.char_tokenizer(
             cm, data_tools.bash_tokenizer, **char_options)
         nl_tokens = data_tools.basic_tokenizer(nl)

         # Views of the full tree.
         cm_tokens = data_tools.ast2tokens(tree, with_parent=with_parent)
         cm_seq = data_tools.ast2list(tree, list=[], with_parent=with_parent)

         # Views of the tree returned by normalizer.prune_ast.
         pruned_tree = normalizer.prune_ast(tree)
         cm_pruned_tokens = data_tools.ast2tokens(
             pruned_tree, loose_constraints=True, with_parent=with_parent)
         cm_pruned_seq = data_tools.ast2list(
             pruned_tree, list=[], with_parent=with_parent)

         # Views requested with arg_type_only=True.
         cm_normalized_tokens = data_tools.ast2tokens(
             tree, loose_constraints=True, arg_type_only=True,
             with_parent=with_parent)
         cm_normalized_seq = data_tools.ast2list(
             tree, arg_type_only=True, list=[], with_parent=with_parent)

         # Views requested with arg_type_only=True and ignore_flag_order=True.
         cm_canonical_tokens = data_tools.ast2tokens(
             tree, loose_constraints=True, arg_type_only=True,
             ignore_flag_order=True, with_parent=with_parent)
         cm_canonical_seq = data_tools.ast2list(
             tree, arg_type_only=True, ignore_flag_order=True, list=[],
             with_parent=with_parent)

         # Nothing is stored until every view has been computed.
         store(nl_list, nl)
         store(cm_list, cm)
         store(nl_char_list, nl_chars)
         store(nl_token_list, nl_tokens)
         store(cm_char_list, cm_chars)
         store(cm_token_list, cm_tokens)
         store(cm_seq_list, cm_seq)
         store(cm_pruned_token_list, cm_pruned_tokens)
         store(cm_pruned_seq_list, cm_pruned_seq)
         store(cm_normalized_token_list, cm_normalized_tokens)
         store(cm_normalized_seq_list, cm_normalized_seq)
         store(cm_canonical_token_list, cm_canonical_tokens)
         store(cm_canonical_seq_list, cm_canonical_seq)
Example #3
0
def decode_set(model, dataset, rev_nl_vocab, rev_cm_vocab, verbose=True):
    """Run ``model`` over a grouped dataset and store its predictions.

    ``dataset`` is grouped with ``data_utils.group_data_by_nl``. For each
    group, the first entry of ``batch_nls`` is passed to ``model.test(nl, 10)``
    and every returned ``(nn, cmd, score)`` triple is turned into a command
    string and written to the database with ``db.add_prediction``. Existing
    predictions for the model are removed first via ``db.remove_model``.

    Args:
        model: object exposing ``test(nl, 10)`` that returns an iterable of
            ``(nn, cmd, score)`` triples.
        dataset: input handed unchanged to ``data_utils.group_data_by_nl``.
        rev_nl_vocab: indexable lookup from the ids in ``nn`` to strings.
        rev_cm_vocab: indexable lookup from the ids in ``cmd`` to strings.
        verbose: when true, print the inputs, each prediction and its AST.
    """
    grouped_dataset = data_utils.group_data_by_nl(dataset)

    with DBConnection() as db:
        # NOTE(review): ``model_name`` is neither a parameter nor a local of
        # this function; it has to come from module scope -- confirm.
        db.remove_model(model_name)
        num_eval = 0
        for nl_temp in grouped_dataset:
            batch_nl_strs, batch_cm_strs, batch_nls, batch_cmds = \
                grouped_dataset[nl_temp]

            nl_str = batch_nl_strs[0]
            nl = batch_nls[0]
            if verbose:
                print("Example {}".format(num_eval + 1))
                print("Original English: " + nl_str.strip())
                print("English: " + nl_temp)
                for gt_rank, gt_cmd in enumerate(batch_cm_strs, 1):
                    print("GT Command {}: {}".format(gt_rank, gt_cmd.strip()))

            top_k_results = model.test(nl, 10)
            # ``rank`` has its own name so the token loops below cannot
            # overwrite it; previously they all shared ``i`` and the
            # "Prediction" line printed a token id instead of the rank.
            for rank, (nn, cmd, score) in enumerate(top_k_results):
                nn_str = ' '.join([rev_nl_vocab[tok_id] for tok_id in nn])
                # Keep only the text after the last "@@" in each token
                # (tokens without "@@" are unchanged).
                tokens = [rev_cm_vocab[tok_id].split("@@")[-1]
                          for tok_id in cmd]
                pred_cmd = ' '.join(tokens)
                tree = data_tools.bash_parser(pred_cmd)
                if verbose:
                    print("NN: {}".format(nn_str))
                    print("Prediction {}: {} ({})".format(rank, pred_cmd, score))
                    print("AST: ")
                    data_tools.pretty_print(tree, 0)
                    # Blank separator line; a bare ``print`` is a no-op when
                    # print is a function.
                    print("")
                db.add_prediction(model_name, nl_str, pred_cmd, float(score),
                                  update_mode=False)

            num_eval += 1
def test_rewrite(cmd):
    """Print every rewrite the database returns for the command ``cmd``.

    ``cmd`` is parsed with ``data_tools.bash_parser`` and the resulting tree
    is handed to ``db.get_rewrites``. Each rewrite is converted back to text
    with ``data_tools.ast2command`` and printed with a 0-based index.
    """
    with DBConnection() as db:
        parsed = data_tools.bash_parser(cmd)
        # Fetch all rewrites up front, before anything is printed.
        candidates = list(db.get_rewrites(parsed))
        for index, candidate in enumerate(candidates):
            print("rewrite %d: %s" % (index, data_tools.ast2command(candidate)))