def get_content_tokens(ast):
    """Count the content (non-argument) tokens of a bash AST, stripping any
    __FLAG__ prefix before counting."""
    content_tokens = collections.defaultdict(int)
    for token in data_tools.ast2tokens(ast, loose_constraints=True,
                                       arg_type_only=True, with_prefix=True):
        if not token.startswith('__ARG__'):
            if token.startswith('__FLAG__'):
                token = token[len('__FLAG__'):]
            content_tokens[token] += 1
    return content_tokens
def get_content_tokens(ast):
    """Variant of the above: split each kind-prefixed token on
    nast.KIND_PREFIX and count every token whose kind is not 'argument'."""
    content_tokens = collections.defaultdict(int)
    for compound_token in data_tools.ast2tokens(
            ast, loose_constraints=True, arg_type_only=True,
            with_prefix=True, with_flag_argtype=True):
        kind_token = compound_token.split(nast.KIND_PREFIX)
        if len(kind_token) == 2:
            kind, token = kind_token
        else:
            kind = ''
            token = kind_token[0]
        if kind.lower() != 'argument':
            content_tokens[token] += 1
    return content_tokens
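# Usage sketch for get_content_tokens (illustrative only; assumes
# `bash_parser` from bashlint's data_tools, as used by print_tokens below --
# the exact keys depend on ast2tokens' prefixing scheme):
#
#   ast = data_tools.bash_parser('find . -name "*.txt" -mtime -7')
#   for token, count in sorted(get_content_tokens(ast).items()):
#       print('{}\t{}'.format(token, count))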
def eval_slot_filling(dataset):
    """
    Evaluate the global slot-filling algorithm on ground-truth templates:
    reports argument-alignment precision/recall/F1 and argument-filling
    accuracy.
    """
    vocabs = data_utils.load_vocab(FLAGS)
    rev_tg_vocab = vocabs.rev_tg_vocab
    rev_tg_full_vocab = vocabs.rev_tg_full_vocab

    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)) as sess:
        # Create model.
        FLAGS.beam_size = 1
        FLAGS.token_decoding_algorithm = 'beam_search'
        FLAGS.force_reading_input = True
        model = graph_utils.create_model(sess, FLAGS, Seq2SeqModel,
                                         buckets=_buckets, forward_only=True)

        # Load the k-nearest-neighbor slot-filling classifier trained on the
        # saved alignment examples.
        model_param_dir = os.path.join(
            FLAGS.model_dir, 'train.mappings.X.Y.npz')
        train_X, train_Y = data_utils.load_slot_filling_data(model_param_dir)
        slot_filling_classifier = classifiers.KNearestNeighborModel(
            FLAGS.num_nn_slot_filling, train_X, train_Y)
        print('Slot filling classifier parameters loaded.')

        num_correct_argument = 0.0
        num_argument = 0.0
        num_correct_align = 0.0
        num_predict_align = 0.0
        num_gt_align = 0.0

        for bucket_id in xrange(len(_buckets)):
            for data_id in xrange(len(dataset[bucket_id])):
                dp = dataset[bucket_id][data_id]
                gt_mappings = [tuple(m) for m in dp.mappings]
                outputs = dp.tg_ids[1:-1]
                full_outputs = dp.tg_full_ids[1:-1]
                if gt_mappings:
                    _, entities = tokenizer.ner_tokenizer(dp.sc_txt)
                    nl_fillers = entities[0]
                    encoder_inputs = [dp.sc_ids]
                    encoder_full_inputs = [dp.sc_copy_ids] \
                        if FLAGS.use_copy else [dp.sc_full_ids]
                    decoder_inputs = [dp.tg_ids]
                    decoder_full_inputs = [dp.tg_full_ids] \
                        if FLAGS.use_copy else [dp.tg_copy_ids]
                    pointer_targets = [dp.pointer_targets] \
                        if FLAGS.use_copy else None

                    # Run a forced-decoding step to obtain the encoder and
                    # decoder hidden states of the ground-truth template.
                    formatted_example = model.format_example(
                        [encoder_inputs, encoder_full_inputs],
                        [decoder_inputs, decoder_full_inputs],
                        pointer_targets=pointer_targets,
                        bucket_id=bucket_id)
                    model_outputs = model.step(
                        sess, formatted_example, bucket_id, forward_only=True)
                    encoder_outputs = model_outputs.encoder_hidden_states
                    decoder_outputs = model_outputs.decoder_hidden_states
                    print(decoder_outputs[:, 0, :])

                    # Collect the decoded tokens and the slots (entity
                    # placeholders) that need to be filled.
                    cm_slots = {}
                    output_tokens = []
                    for ii in xrange(len(outputs)):
                        output = outputs[ii]
                        if output < len(rev_tg_vocab):
                            token = rev_tg_vocab[output]
                            if "@@" in token:
                                token = token.split("@@")[-1]
                            output_tokens.append(token)
                            if token.startswith('__ARG__'):
                                token = token[len('__ARG__'):]
                            if nl_fillers is not None and \
                                    token in constants._ENTITIES:
                                if ii > 0 and slot_filling.is_min_flag(
                                        rev_tg_vocab[outputs[ii - 1]]):
                                    token_type = 'Timespan'
                                else:
                                    token_type = token
                                cm_slots[ii] = (token, token_type)
                        else:
                            output_tokens.append(data_utils._UNK)

                    if FLAGS.use_copy:
                        P = pointer_targets[0][0] > 0
                        pointers = model_outputs.pointers[0]
                        pointers = np.multiply(
                            np.sum(P.astype(float)[:pointers.shape[0],
                                                   -pointers.shape[1]:],
                                   1, keepdims=True),
                            pointers)
                    else:
                        pointers = None

                    tree, _, mappings = slot_filling.stable_slot_filling(
                        output_tokens, nl_fillers, cm_slots, pointers,
                        encoder_outputs[0], decoder_outputs[0],
                        slot_filling_classifier, verbose=True)

                    # Score the predicted slot-filler alignments against the
                    # ground-truth mappings.
                    if mappings is not None:
                        # print(gt_mappings)
                        for mapping in mappings:
                            # print(mapping)
                            if mapping in gt_mappings:
                                num_correct_align += 1
                        num_predict_align += len(mappings)
                        num_gt_align += len(gt_mappings)

                    # Score the filled-in argument strings against the
                    # ground-truth arguments.
                    tokens = data_tools.ast2tokens(tree)
                    if not tokens:
                        continue
                    for ii in xrange(len(outputs)):
                        output = outputs[ii]
                        token = rev_tg_vocab[output]
                        if token.startswith('__ARG__'):
                            token = token[len('__ARG__'):]
                        if token in constants._ENTITIES:
                            argument = rev_tg_full_vocab[full_outputs[ii]]
                            if argument.startswith('__ARG__'):
                                argument = argument[len('__ARG__'):]
                            pred = tokens[ii]
                            if constants.remove_quotation(argument) == \
                                    constants.remove_quotation(pred):
                                num_correct_argument += 1
                            num_argument += 1
                # Stop after the first example in the bucket that has
                # ground-truth mappings.
                if gt_mappings:
                    break

        precision = num_correct_align / num_predict_align
        recall = num_correct_align / num_gt_align
        print("Argument Alignment Precision: {}".format(precision))
        print("Argument Alignment Recall: {}".format(recall))
        print("Argument Alignment F1: {}".format(
            2 * precision * recall / (precision + recall)))
        print("Argument filling accuracy: {}".format(
            num_correct_argument / num_argument))
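# Worked example of the alignment metrics printed above: if the model
# predicts 4 alignments, 3 of which appear among 5 ground-truth alignments,
#   precision = 3/4 = 0.75, recall = 3/5 = 0.6,
#   F1 = 2 * 0.75 * 0.6 / (0.75 + 0.6) ~= 0.667.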
def print_tokens(cmd):
    """Parse a bash command and print its linearized token sequence."""
    ast = data_tools.bash_parser(cmd)
    tokens = data_tools.ast2tokens(ast)
    print(" ".join(tokens))
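if __name__ == '__main__':
    # Minimal demo (a sketch: assumes the bashlint package is importable and
    # that this command parses; the printed tokenization depends on
    # ast2tokens' default settings).
    print_tokens('find . -name "*.txt" -print')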