Example #1
def get_content_tokens(ast):
    content_tokens = collections.defaultdict(int)
    for token in data_tools.ast2tokens(ast,
                                       loose_constraints=True,
                                       arg_type_only=True,
                                       with_prefix=True):
        if not token.startswith('__ARG__'):
            if token.startswith('__FLAG__'):
                token = token[len('__FLAG__'):]
            content_tokens[token] += 1
    return content_tokens
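
A minimal usage sketch for Example #1, assuming the nl2bash-style bashlint.data_tools module is importable (Example #4 below uses its bash_parser) and pasted into the same module that defines get_content_tokens; the command string is illustrative only:

import collections                  # needed by get_content_tokens above
from bashlint import data_tools     # assumed import path, as in Example #4

ast = data_tools.bash_parser('find . -name "*.txt" -mtime -7')
counts = get_content_tokens(ast)    # maps each non-argument token (possibly prefixed) to its count
for token, freq in sorted(counts.items()):
    print(token, freq)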
Example #2
def get_content_tokens(ast):
    content_tokens = collections.defaultdict(int)
    for compound_token in data_tools.ast2tokens(ast, loose_constraints=True,
            arg_type_only=True, with_prefix=True, with_flag_argtype=True):
        kind_token = compound_token.split(nast.KIND_PREFIX)
        if len(kind_token) == 2:
            kind, token = kind_token
        else:
            kind = ''
            token = kind_token[0]
        if kind.lower() != 'argument':
            content_tokens[token] += 1
    return content_tokens
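
For reference, the kind-prefix handling in Example #2 can be isolated. A small sketch, assuming nast.KIND_PREFIX is an ordinary separator string (the value below is a placeholder, not the real constant):

KIND_PREFIX = '<kind>'  # placeholder standing in for nast.KIND_PREFIX

def split_kind(compound_token, sep=KIND_PREFIX):
    # Mirrors the branch in Example #2: a well-formed compound token is
    # "<kind><sep><token>"; anything else is a bare token with no kind.
    parts = compound_token.split(sep)
    if len(parts) == 2:
        return parts[0], parts[1]
    return '', parts[0]

print(split_kind('argument' + KIND_PREFIX + 'File'))  # ('argument', 'File') -> filtered out
print(split_kind('ls'))                               # ('', 'ls') -> counted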
Example #3
def eval_slot_filling(dataset):
    """
    Evaluate global slot filling algorithm F1 using ground truth templates.
    """
    vocabs = data_utils.load_vocab(FLAGS)
    rev_tg_vocab = vocabs.rev_tg_vocab
    rev_tg_full_vocab = vocabs.rev_tg_full_vocab

    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)) as sess:
        # Create model.
        FLAGS.beam_size = 1
        FLAGS.token_decoding_algorithm = 'beam_search'
        FLAGS.force_reading_input = True
        model = graph_utils.create_model(sess,
                                         FLAGS,
                                         Seq2SeqModel,
                                         buckets=_buckets,
                                         forward_only=True)

        model_param_dir = os.path.join(FLAGS.model_dir,
                                       'train.mappings.X.Y.npz')
        train_X, train_Y = data_utils.load_slot_filling_data(model_param_dir)
        slot_filling_classifier = classifiers.KNearestNeighborModel(
            FLAGS.num_nn_slot_filling, train_X, train_Y)
        print('Slot filling classifier parameters loaded.')

        num_correct_argument = 0.0
        num_argument = 0.0
        num_correct_align = 0.0
        num_predict_align = 0.0
        num_gt_align = 0.0
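        # Evaluate every bucket; data points without ground-truth argument mappings are skipped.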
        for bucket_id in xrange(len(_buckets)):
            for data_id in xrange(len(dataset[bucket_id])):
                dp = dataset[bucket_id][data_id]
                gt_mappings = [tuple(m) for m in dp.mappings]
                outputs = dp.tg_ids[1:-1]
                full_outputs = dp.tg_full_ids[1:-1]
                if gt_mappings:
                    _, entities = tokenizer.ner_tokenizer(dp.sc_txt)
                    nl_fillers = entities[0]
                    encoder_inputs = [dp.sc_ids]
                    encoder_full_inputs = [dp.sc_copy_ids] \
                        if FLAGS.use_copy else [dp.sc_full_ids]
                    decoder_inputs = [dp.tg_ids]
                    decoder_full_inputs = [dp.tg_full_ids] \
                        if FLAGS.use_copy else [dp.tg_copy_ids]
                    pointer_targets = [dp.pointer_targets] \
                        if FLAGS.use_copy else None
                    formatted_example = model.format_example(
                        [encoder_inputs, encoder_full_inputs],
                        [decoder_inputs, decoder_full_inputs],
                        pointer_targets=pointer_targets,
                        bucket_id=bucket_id)
                    model_outputs = model.step(sess,
                                               formatted_example,
                                               bucket_id,
                                               forward_only=True)
                    encoder_outputs = model_outputs.encoder_hidden_states
                    decoder_outputs = model_outputs.decoder_hidden_states
                    print(decoder_outputs[:, 0, :])

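                    # Collect candidate slots: output positions whose decoded token is an entity type.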
                    cm_slots = {}
                    output_tokens = []
                    for ii in xrange(len(outputs)):
                        output = outputs[ii]
                        if output < len(rev_tg_vocab):
                            token = rev_tg_vocab[output]
                            if "@@" in token:
                                token = token.split("@@")[-1]
                            output_tokens.append(token)
                            if token.startswith('__ARG__'):
                                token = token[len('__ARG__'):]
                            if nl_fillers is not None and \
                                    token in constants._ENTITIES:
                                if ii > 0 and slot_filling.is_min_flag(
                                        rev_tg_vocab[outputs[ii - 1]]):
                                    token_type = 'Timespan'
                                else:
                                    token_type = token
                                cm_slots[ii] = (token, token_type)
                        else:
                            output_tokens.append(data_utils._UNK)
                    if FLAGS.use_copy:
                        P = pointer_targets[0][0] > 0
                        pointers = model_outputs.pointers[0]
                        pointers = np.multiply(
                            np.sum(P.astype(float)[:pointers.shape[0],
                                                   -pointers.shape[1]:],
                                   1,
                                   keepdims=True), pointers)
                    else:
                        pointers = None
                    tree, _, mappings = slot_filling.stable_slot_filling(
                        output_tokens,
                        nl_fillers,
                        cm_slots,
                        pointers,
                        encoder_outputs[0],
                        decoder_outputs[0],
                        slot_filling_classifier,
                        verbose=True)

                    if mappings is not None:
                        # print(gt_mappings)
                        for mapping in mappings:
                            # print(mapping)
                            if mapping in gt_mappings:
                                num_correct_align += 1
                        num_predict_align += len(mappings)
                    num_gt_align += len(gt_mappings)

                    tokens = data_tools.ast2tokens(tree)
                    if not tokens:
                        continue
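                    # Score argument filling: compare each filled slot against the gold argument, ignoring quotation marks.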
                    for ii in xrange(len(outputs)):
                        output = outputs[ii]
                        token = rev_tg_vocab[output]
                        if token.startswith('__ARG__'):
                            token = token[len('__ARG__'):]
                        if token in constants._ENTITIES:
                            argument = rev_tg_full_vocab[full_outputs[ii]]
                            if argument.startswith('__ARG__'):
                                argument = argument[len('__ARG__'):]
                            pred = tokens[ii]
                            if constants.remove_quotation(argument) == \
                                    constants.remove_quotation(pred):
                                num_correct_argument += 1
                            num_argument += 1
            if gt_mappings:
                break
        precision = num_correct_align / num_predict_align
        recall = num_correct_align / num_gt_align
        print("Argument Alignment Precision: {}".format(precision))
        print("Argument Alignment Recall: {}".format(recall))
        print("Argument Alignment F1: {}".format(2 * precision * recall /
                                                 (precision + recall)))

        print("Argument filling accuracy: {}".format(num_correct_argument /
                                                     num_argument))
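
The closing metrics in Example #3 are ordinary precision/recall/F1 over alignment counts plus a filling accuracy. A self-contained sketch of that arithmetic (the function name and sample numbers are illustrative, not part of the original code):

def alignment_f1(num_correct, num_predicted, num_gold):
    # Precision: correct alignments / alignments the model predicted.
    # Recall:    correct alignments / alignments in the ground truth.
    precision = num_correct / num_predicted if num_predicted else 0.0
    recall = num_correct / num_gold if num_gold else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)

print(alignment_f1(8.0, 10.0, 12.0))  # (0.8, 0.666..., 0.727...)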
Example #4
def print_tokens(cmd):
    ast = bashlint.data_tools.bash_parser(cmd)
    tokens = data_tools.ast2tokens(ast)
    print(" ".join(tokens))