def clean_rewrites(self):
    """Delete rewrite pairs that contain a non-grammatical command."""
    c = self.cursor
    non_grammatical = []
    # First pass: collect every command that fails to parse, so the table
    # is not mutated while the SELECT cursor is still being iterated.
    for s1, s2 in c.execute("SELECT s1, s2 FROM Rewrites"):
        ast = data_tools.bash_parser(s1)
        if not ast:
            non_grammatical.append(s1)
        ast2 = data_tools.bash_parser(s2)
        if not ast2:
            non_grammatical.append(s2)
    # Second pass: drop every rewrite pair that mentions a bad command
    # on either side.
    for s in non_grammatical:
        print("Removing %s from Rewrites" % s)
        c.execute("DELETE FROM Rewrites WHERE s1 = ?", (s,))
        c.execute("DELETE FROM Rewrites WHERE s2 = ?", (s,))
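# Usage sketch for clean_rewrites (an assumption: that it is a method of the
# same DBConnection class used by the helpers below; the context-manager
# calling pattern mirrors decode_set and test_rewrite):
#
#   with DBConnection() as db:
#       db.clean_rewrites()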
def add_to_set(nl_data, cm_data, split):
    with_parent = True
    for nl, cm in zip(getattr(nl_data, split), getattr(cm_data, split)):
        ast = data_tools.bash_parser(cm)
        if ast:
            if is_simple(ast):
                # Character- and token-level views of both sides of the pair.
                nl_chars = data_tools.char_tokenizer(
                    nl, data_tools.basic_tokenizer,
                    normalize_digits=False, normalize_long_pattern=False)
                cm_chars = data_tools.char_tokenizer(
                    cm, data_tools.bash_tokenizer,
                    normalize_digits=False, normalize_long_pattern=False)
                nl_tokens = data_tools.basic_tokenizer(nl)
                cm_tokens = data_tools.ast2tokens(ast, with_parent=with_parent)
                cm_seq = data_tools.ast2list(ast, list=[], with_parent=with_parent)
                # Pruned view: serialize the AST after the normalizer has
                # pruned it.
                pruned_ast = normalizer.prune_ast(ast)
                cm_pruned_tokens = data_tools.ast2tokens(
                    pruned_ast, loose_constraints=True, with_parent=with_parent)
                cm_pruned_seq = data_tools.ast2list(
                    pruned_ast, list=[], with_parent=with_parent)
                # Normalized view: replace literal arguments with their types.
                cm_normalized_tokens = data_tools.ast2tokens(
                    ast, loose_constraints=True, arg_type_only=True,
                    with_parent=with_parent)
                cm_normalized_seq = data_tools.ast2list(
                    ast, arg_type_only=True, list=[], with_parent=with_parent)
                # Canonical view: additionally ignore flag ordering.
                cm_canonical_tokens = data_tools.ast2tokens(
                    ast, loose_constraints=True, arg_type_only=True,
                    ignore_flag_order=True, with_parent=with_parent)
                cm_canonical_seq = data_tools.ast2list(
                    ast, arg_type_only=True, ignore_flag_order=True, list=[],
                    with_parent=with_parent)
                getattr(nl_list, split).append(nl)
                getattr(cm_list, split).append(cm)
                getattr(nl_char_list, split).append(nl_chars)
                getattr(nl_token_list, split).append(nl_tokens)
                getattr(cm_char_list, split).append(cm_chars)
                getattr(cm_token_list, split).append(cm_tokens)
                getattr(cm_seq_list, split).append(cm_seq)
                getattr(cm_pruned_token_list, split).append(cm_pruned_tokens)
                getattr(cm_pruned_seq_list, split).append(cm_pruned_seq)
                getattr(cm_normalized_token_list, split).append(cm_normalized_tokens)
                getattr(cm_normalized_seq_list, split).append(cm_normalized_seq)
                getattr(cm_canonical_token_list, split).append(cm_canonical_tokens)
                getattr(cm_canonical_seq_list, split).append(cm_canonical_seq)
            else:
                print("Rare command: " + cm)
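# Usage sketch for add_to_set (a hypothetical driver loop; it assumes nl_data
# and cm_data expose per-split attributes, as the getattr(nl_data, split)
# calls above suggest, and that the module-level result lists such as nl_list
# and cm_list have been initialized elsewhere):
#
#   for split in ("train", "dev", "test"):
#       add_to_set(nl_data, cm_data, split)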
def decode_set(model, dataset, rev_nl_vocab, rev_cm_vocab, verbose=True):
    grouped_dataset = data_utils.group_data_by_nl(dataset)
    with DBConnection() as db:
        db.remove_model(model_name)
        num_eval = 0
        for nl_temp in grouped_dataset:
            batch_nl_strs, batch_cm_strs, batch_nls, batch_cmds = \
                grouped_dataset[nl_temp]
            nl_str = batch_nl_strs[0]
            nl = batch_nls[0]
            if verbose:
                print("Example {}".format(num_eval + 1))
                print("Original English: " + nl_str.strip())
                print("English: " + nl_temp)
                for j in xrange(len(batch_cm_strs)):
                    print("GT Command {}: {}".format(
                        j + 1, batch_cm_strs[j].strip()))
            top_k_results = model.test(nl, 10)
            for i in xrange(len(top_k_results)):
                nn, cmd, score = top_k_results[i]
                # Use loop variables distinct from i below so the prediction
                # index printed later is not clobbered.
                nn_str = ' '.join([rev_nl_vocab[j] for j in nn])
                tokens = []
                for token_id in cmd:
                    pred_token = rev_cm_vocab[token_id]
                    # Tokens may carry a parent prefix ("parent@@token");
                    # keep only the surface token.
                    if "@@" in pred_token:
                        pred_token = pred_token.split("@@")[-1]
                    tokens.append(pred_token)
                pred_cmd = ' '.join(tokens)
                tree = data_tools.bash_parser(pred_cmd)
                if verbose:
                    print("NN: {}".format(nn_str))
                    print("Prediction {}: {} ({})".format(i, pred_cmd, score))
                    print("AST: ")
                    data_tools.pretty_print(tree, 0)
                    print
                db.add_prediction(model_name, nl_str, pred_cmd, float(score),
                                  update_mode=False)
            num_eval += 1
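# Usage sketch for decode_set (hypothetical objects; decode_set also reads
# the module-level model_name, so that must be bound before calling):
#
#   decode_set(model, dev_set, rev_nl_vocab, rev_cm_vocab, verbose=True)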
def test_rewrite(cmd):
    with DBConnection() as db:
        ast = data_tools.bash_parser(cmd)
        rewrites = list(db.get_rewrites(ast))
        for i in xrange(len(rewrites)):
            print("rewrite %d: %s" % (i, data_tools.ast2command(rewrites[i])))
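# Usage sketch for test_rewrite (the command string is a hypothetical
# example, not one taken from the Rewrites table):
#
#   test_rewrite('find . -name "*.txt"')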