import os

# Project-local dependencies; these import paths are assumptions about the
# package layout and may need adjusting.
import data_utils
import data_tools
from db import DBConnection


def decode_set(model, dataset, rev_nl_vocab, rev_cm_vocab, verbose=True):
    """Decode with a retrieval-style model exposing test(nl, k) ->
    [(neighbor_ids, command_ids, score)] and record the predictions."""
    # Group examples that share the same natural language template.
    grouped_dataset = data_utils.group_data_by_nl(dataset)

    with DBConnection() as db:
        # NOTE: model_name is assumed to be defined at module scope; it is
        # not passed in as a parameter.
        db.remove_model(model_name)

        num_eval = 0
        for nl_temp in grouped_dataset:
            batch_nl_strs, batch_cm_strs, batch_nls, batch_cmds = \
                grouped_dataset[nl_temp]
            nl_str = batch_nl_strs[0]
            nl = batch_nls[0]
            if verbose:
                print("Example {}".format(num_eval + 1))
                print("Original English: " + nl_str.strip())
                print("English: " + nl_temp)
                for j in xrange(len(batch_cm_strs)):
                    print("GT Command {}: {}".format(
                        j + 1, batch_cm_strs[j].strip()))

            top_k_results = model.test(nl, 10)
            for i in xrange(len(top_k_results)):
                nn, cmd, score = top_k_results[i]
                nn_str = ' '.join([rev_nl_vocab[j] for j in nn])
                # Merge subword units: keep only the text after the "@@"
                # separator in each predicted token.
                tokens = []
                for token_id in cmd:
                    pred_token = rev_cm_vocab[token_id]
                    if "@@" in pred_token:
                        pred_token = pred_token.split("@@")[-1]
                    tokens.append(pred_token)
                pred_cmd = ' '.join(tokens)
                tree = data_tools.bash_parser(pred_cmd)
                if verbose:
                    print("NN: {}".format(nn_str))
                    print("Prediction {}: {} ({})".format(i + 1, pred_cmd, score))
                    print("AST: ")
                    data_tools.pretty_print(tree, 0)
                    print("")
                db.add_prediction(model_name, nl_str, pred_cmd, float(score),
                                  update_mode=False)
            num_eval += 1
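# Usage sketch for the retrieval decoder above (names are hypothetical):
#
#   dev_set = ...  # evaluation examples, loaded elsewhere
#   decode_set(knn_model, dev_set, rev_nl_vocab, rev_cm_vocab, verbose=True)
#
# where knn_model is any object exposing test(nl, k) ->
# [(neighbor_ids, command_ids, score)].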
def decode_set(sess, model, dataset, rev_nl_vocab, rev_cm_vocab, FLAGS,
               verbose=True):
    """Decode with the bucketed encoder-decoder model (takes a TF session)
    and record the predictions."""
    # Group examples by natural language template and distribute the groups
    # over the model's buckets.
    grouped_dataset = data_utils.group_data_by_nl(
        dataset, use_bucket=True,
        use_nl_temp=FLAGS.dataset.startswith("bash"))
    bucketed_nl_strs, bucketed_cm_strs, bucketed_nls, bucketed_cmds = \
        data_utils.bucket_grouped_data(grouped_dataset, model.buckets)

    with DBConnection() as db:
        db.remove_model(model.model_sig)

        for bucket_id in xrange(len(model.buckets)):
            bucket_nl_strs = bucketed_nl_strs[bucket_id]
            bucket_cm_strs = bucketed_cm_strs[bucket_id]
            bucket_nls = bucketed_nls[bucket_id]
            bucket_cmds = bucketed_cmds[bucket_id]
            bucket_size = len(bucket_nl_strs)

            num_batches = int(bucket_size / FLAGS.batch_size)
            if bucket_size % FLAGS.batch_size != 0:
                num_batches += 1

            for b in xrange(num_batches):
                batch_nl_strs = bucket_nl_strs[b*FLAGS.batch_size:(b+1)*FLAGS.batch_size]
                batch_cm_strs = bucket_cm_strs[b*FLAGS.batch_size:(b+1)*FLAGS.batch_size]
                batch_nls = bucket_nls[b*FLAGS.batch_size:(b+1)*FLAGS.batch_size]
                batch_cmds = bucket_cmds[b*FLAGS.batch_size:(b+1)*FLAGS.batch_size]

                # Pad the last batch to a full batch by repeating its final
                # example; batch_size records how many entries are real.
                if len(batch_nl_strs) < FLAGS.batch_size:
                    batch_size = len(batch_nl_strs)
                    batch_nl_strs = batch_nl_strs + \
                        [batch_nl_strs[-1]] * (FLAGS.batch_size - len(batch_nl_strs))
                    batch_cm_strs = batch_cm_strs + \
                        [batch_cm_strs[-1]] * (FLAGS.batch_size - len(batch_cm_strs))
                    batch_nls = batch_nls + \
                        [batch_nls[-1]] * (FLAGS.batch_size - len(batch_nls))
                    batch_cmds = batch_cmds + \
                        [batch_cmds[-1]] * (FLAGS.batch_size - len(batch_cmds))
                else:
                    batch_size = FLAGS.batch_size

                formatted_example = model.format_example(
                    batch_nls, batch_cmds, bucket_id=bucket_id)
                output_symbols, output_logits, losses, attn_masks = \
                    model.step(sess, formatted_example, bucket_id,
                               forward_only=True)
                # decode() (defined elsewhere in this module) turns output
                # symbols into (tree, command, output tokens) tuples.
                batch_outputs = decode(output_symbols, rev_cm_vocab, FLAGS)

                for batch_id in xrange(batch_size):
                    example_id = b * FLAGS.batch_size + batch_id
                    nl_str = batch_nl_strs[batch_id]
                    cm_strs = batch_cm_strs[batch_id]
                    nl = batch_nls[batch_id]
                    nl_temp = ' '.join([rev_nl_vocab[i] for i in nl])

                    if verbose:
                        print("Example {}:{}".format(bucket_id, example_id))
                        print("Original English: " + nl_str.strip())
                        print("English: " + nl_temp)
                        for j in xrange(len(cm_strs)):
                            print("GT Command {}: {}".format(
                                j + 1, cm_strs[j].strip()))

                    if FLAGS.decoding_algorithm == "greedy":
                        tree, pred_cmd, outputs = batch_outputs[batch_id]
                        score = output_logits[batch_id]
                        db.add_prediction(model.model_sig, nl_str, pred_cmd,
                                          float(score))
                        if verbose:
                            print("Prediction: {} ({})".format(pred_cmd, score))
                            # print("AST: ")
                            # data_tools.pretty_print(tree, 0)
                            # print()
                    elif FLAGS.decoding_algorithm == "beam_search":
                        top_k_predictions = batch_outputs[batch_id]
                        top_k_scores = output_logits[batch_id]
                        assert(len(top_k_predictions) == FLAGS.beam_size)
                        # Record at most the 10 highest-scoring candidates.
                        for j in xrange(min(FLAGS.beam_size, 10)):
                            top_k_pred_tree, top_k_pred_cmd, top_k_outputs = \
                                top_k_predictions[j]
                            if verbose:
                                print("Prediction {}: {} ({})".format(
                                    j + 1, top_k_pred_cmd, top_k_scores[j]))
                            db.add_prediction(model.model_sig, nl_str,
                                              top_k_pred_cmd,
                                              float(top_k_scores[j]),
                                              update_mode=False)
                            # print("AST: ")
                            # data_tools.pretty_print(top_k_pred_tree, 0)
                        if verbose:
                            print()
                        outputs = top_k_predictions[0][2]
                    else:
                        raise ValueError("Unrecognized decoding algorithm: {}."
                                         .format(FLAGS.decoding_algorithm))

                    if attn_masks is not None:
                        if FLAGS.decoding_algorithm == "greedy":
                            M = attn_masks[batch_id, :, :]
                        elif FLAGS.decoding_algorithm == "beam_search":
                            M = attn_masks[batch_id, 0, :, :]
                        visualize_attn_masks(
                            M, nl, outputs, rev_nl_vocab, rev_cm_vocab,
                            os.path.join(FLAGS.model_dir,
                                         "{}-{}.jpg".format(bucket_id,
                                                            example_id)))
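# `visualize_attn_masks` is called above but defined elsewhere in the
# project. Below is a minimal sketch of such a helper, assuming matplotlib
# is available, that M has shape (target_length, source_length), and that
# `target` holds printable tokens; the body is an illustration, not the
# project's implementation.
def visualize_attn_masks(M, source, target, rev_nl_vocab, rev_cm_vocab,
                         output_path):
    import matplotlib
    matplotlib.use('Agg')  # render off-screen; no display required
    import matplotlib.pyplot as plt

    target_length, source_length = M.shape
    # Axis labels: decode source ids through the reverse vocabulary and
    # truncate both label lists to the mask's dimensions.
    source_tokens = [rev_nl_vocab[i] for i in source][:source_length]
    target_tokens = [str(t) for t in target][:target_length]

    fig, ax = plt.subplots()
    heatmap = ax.imshow(M, interpolation='nearest', cmap=plt.cm.Blues)
    fig.colorbar(heatmap)
    ax.set_xticks(range(len(source_tokens)))
    ax.set_xticklabels(source_tokens, rotation=90)
    ax.set_yticks(range(len(target_tokens)))
    ax.set_yticklabels(target_tokens)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)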