Code Example #1
File: filter_data.py Project: changlinzhang/nl2bash
def select(ast, utility_set):
    for ut in data_tools.get_utilities(ast):
        if ut not in utility_set:
            print('Utility currently not handled: {} - {}'.format(
                ut, data_tools.ast2command(ast, loose_constraints=True)))
            return False
    return True
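
The select helper above keeps an AST only when every utility it uses falls inside a given whitelist. A minimal usage sketch; the import path and the whitelist contents are assumptions, not taken from the source:

from bashlint import data_tools  # assumed import path within nl2bash

utility_set = {'find', 'grep', 'xargs'}  # hypothetical whitelist
ast = data_tools.bash_parser('find . -name "*.py" | xargs grep TODO')
if ast is not None and select(ast, utility_set):
    print('kept: ' + data_tools.ast2command(ast, loose_constraints=True))
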
Code Example #2
def stable_slot_filling(template_tokens,
                        sc_fillers,
                        tg_slots,
                        pointer_targets,
                        encoder_outputs,
                        decoder_outputs,
                        slot_filling_classifier,
                        verbose=False):
    """
    Fills the argument slots using learnt local alignment scores and a greedy 
    global alignment algorithm (stable marriage).

    :param template_tokens: list of tokens in the command template
    :param sc_fillers: the slot fillers extracted from the source sequence,
        indexed by token id
    :param tg_slots: the argument slots in the command template, indexed by
        token id
    :param pointer_targets: [encoder_length, decoder_length], local alignment
        scores between source and target tokens
    :param encoder_outputs: [encoder_length, dim] sequence of encoder hidden states
    :param decoder_outputs: [decoder_length, dim] sequence of decoder hidden states
    :param slot_filling_classifier: the classifier that produces the local
        alignment scores
    :param verbose: print all local alignment scores if set to true
    """

    # Step a): prepare (binary) type alignment matrix based on type info
    M = np.zeros([len(encoder_outputs), len(decoder_outputs)], dtype=np.int32)
    for f in sc_fillers:
        assert (f < len(encoder_outputs))
        surface, filler_type = sc_fillers[f]
        matched = False
        for s in tg_slots:
            assert (s < len(decoder_outputs))
            slot_value, slot_type = tg_slots[s]
            if slot_filler_type_match(slot_type, filler_type):
                M[f, s] = 1
                matched = True
        if not matched:
            # If no target slot can hold a source filler, skip the alignment
            # step and return None
            return None, None, None

    # Step b): compute local alignment scores if they are not provided already
    if pointer_targets is None:
        assert (encoder_outputs is not None)
        assert (decoder_outputs is not None)
        assert (slot_filling_classifier is not None)
        pointer_targets = np.zeros(
            [len(encoder_outputs), len(decoder_outputs)])
        for f in xrange(M.shape[0]):
            if np.sum(M[f]) > 1:
                X = []
                # use reversed index for the encoder embeddings matrix
                ff = len(encoder_outputs) - f - 1
                cm_slots_keys = list(tg_slots.keys())
                for s in cm_slots_keys:
                    X.append(np.concatenate(
                        [encoder_outputs[ff:ff + 1],
                         decoder_outputs[s:s + 1]], axis=1))
                X = np.concatenate(X, axis=0)
                X = X / norm(X, axis=1)[:, None]
                raw_scores = slot_filling_classifier.predict(X)
                for ii in xrange(len(raw_scores)):
                    s = cm_slots_keys[ii]
                    pointer_targets[f, s] = raw_scores[ii]
                    if verbose:
                        print('• alignment ({}, {}): {}\t{}\t{}'.format(
                            f, s, sc_fillers[f], tg_slots[s], raw_scores[ii]))

    M = M + M * pointer_targets
    # convert M into a dictionary representation of a sparse matrix
    M_dict = collections.defaultdict(dict)
    for i in xrange(M.shape[0]):
        if np.sum(M[i]) > 0:
            for j in xrange(M.shape[1]):
                if M[i, j] > 0:
                    M_dict[i][j] = M[i, j]

    mappings, remained_fillers = stable_marriage_alignment(M_dict)

    if not remained_fillers:
        for f, s in mappings:
            template_tokens[s] = get_fill_in_value(tg_slots[s], sc_fillers[f])
        cmd = ' '.join(template_tokens)
        tree = data_tools.bash_parser(cmd)
        if tree is not None:
            data_tools.fill_default_value(tree)
        temp = data_tools.ast2command(tree,
                                      loose_constraints=True,
                                      ignore_flag_order=False)
    else:
        tree, temp = None, None

    return tree, temp, mappings
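
stable_marriage_alignment itself is not part of this excerpt. Below is a self-contained toy sketch of the greedy stable-marriage idea over the sparse score dictionary, returning (filler, slot) pairs in the format consumed above; it is an illustration, not the project's implementation:

import collections

def toy_stable_marriage_alignment(M_dict):
    """Toy sketch: fillers propose to slots in order of decreasing score;
    a slot keeps only its best-scoring proposer. M_dict maps
    filler -> {slot: score}. Not the project's actual implementation."""
    prefs = {f: sorted(slots, key=slots.get, reverse=True)
             for f, slots in M_dict.items()}
    next_choice = collections.defaultdict(int)  # next slot to try per filler
    engaged = {}                                # slot -> filler holding it
    free = list(prefs)                          # fillers still proposing
    while free:
        f = free.pop()
        if next_choice[f] >= len(prefs[f]):
            continue                            # out of candidates; unmatched
        s = prefs[f][next_choice[f]]
        next_choice[f] += 1
        rival = engaged.get(s)
        if rival is None:
            engaged[s] = f
        elif M_dict[f][s] > M_dict[rival][s]:
            engaged[s] = f                      # displace the weaker rival
            free.append(rival)
        else:
            free.append(f)                      # rejected; try the next slot
    mappings = [(f, s) for s, f in engaged.items()]
    matched = {f for f, _ in mappings}
    return mappings, [f for f in prefs if f not in matched]

# e.g. toy_stable_marriage_alignment({0: {2: 1.9}, 1: {2: 1.2, 4: 1.1}})
# -> ([(0, 2), (1, 4)], [])
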
Code Example #3
File: decode_tools.py Project: hpplinux/nl2bash
def decode_set(sess, model, dataset, top_k, FLAGS, verbose=False):
    """
    Compute top-k predictions on the dev/test dataset and write the predictions
    to disk.

    :param sess: A TensorFlow session.
    :param model: Prediction model object.
    :param dataset: The dev/test dataset to decode.
    :param top_k: Number of top predictions to compute.
    :param FLAGS: Training/testing hyperparameter settings.
    :param verbose: If set, also print decoding results to screen.
    """
    nl2bash = FLAGS.dataset.startswith('bash') and not FLAGS.explain

    tokenizer_selector = 'cm' if FLAGS.explain else 'nl'
    grouped_dataset = data_utils.group_parallel_data(
        dataset, tokenizer_selector=tokenizer_selector)
    vocabs = data_utils.load_vocabulary(FLAGS)
    rev_sc_vocab = vocabs.rev_sc_vocab

    ts = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H%M%S')
    pred_file_path = os.path.join(model.model_dir, 'predictions.{}.{}'.format(
        model.decode_sig, ts))
    pred_file = open(pred_file_path, 'w')
    eval_file_path = os.path.join(model.model_dir, 'predictions.{}.{}.csv'.format(
        model.decode_sig, ts))
    eval_file = open(eval_file_path, 'w')
    eval_file.write('example_id, description, ground_truth, prediction, ' +
                    'correct template, correct command\n')
    for example_id in xrange(len(grouped_dataset)):
        key, data_group = grouped_dataset[example_id]

        sc_txt = data_group[0].sc_txt.strip()
        sc_tokens = [rev_sc_vocab[i] for i in data_group[0].sc_ids]
        if FLAGS.channel == 'char':
            sc_temp = ''.join(sc_tokens)
            sc_temp = sc_temp.replace(constants._SPACE, ' ')
        else:
            sc_temp = ' '.join(sc_tokens)
        tg_txts = [dp.tg_txt for dp in data_group]
        tg_asts = [data_tools.bash_parser(tg_txt) for tg_txt in tg_txts]
        if verbose:
            print('\nExample {}:'.format(example_id))
            print('Original Source: {}'.format(sc_txt.encode('utf-8')))
            print('Source: {}'.format(sc_temp.encode('utf-8')))
            for j in xrange(len(data_group)):
                print('GT Target {}: {}'.format(j+1, data_group[j].tg_txt.encode('utf-8')))

        if FLAGS.fill_argument_slots:
            slot_filling_classifier = get_slot_filling_classifer(FLAGS)
            batch_outputs, sequence_logits = translate_fun(data_group, sess, model,
                vocabs, FLAGS, slot_filling_classifier=slot_filling_classifier)
        else:
            batch_outputs, sequence_logits = translate_fun(data_group, sess, model,
                vocabs, FLAGS)
        if FLAGS.tg_char:
            batch_outputs, batch_char_outputs = batch_outputs

        eval_row = '{},"{}",'.format(example_id, sc_txt.replace('"', '""'))
        if batch_outputs:
            if FLAGS.token_decoding_algorithm == 'greedy':
                tree, pred_cmd = batch_outputs[0]
                if nl2bash:
                    pred_cmd = data_tools.ast2command(
                        tree, loose_constraints=True)
                score = sequence_logits[0]
                if verbose:
                    print('Prediction: {} ({})'.format(pred_cmd, score))
                pred_file.write('{}\n'.format(pred_cmd))
            elif FLAGS.token_decoding_algorithm == 'beam_search':
                top_k_predictions = batch_outputs[0]
                if FLAGS.tg_char:
                    top_k_char_predictions = batch_char_outputs[0]
                top_k_scores = sequence_logits[0]
                num_preds = min(FLAGS.beam_size, top_k, len(top_k_predictions))
                for j in xrange(num_preds):
                    if j > 0:
                        eval_row = ',,'
                    if j < len(tg_txts):
                        eval_row += '"{}",'.format(tg_txts[j].strip().replace('"', '""'))
                    else:
                        eval_row += ','
                    top_k_pred_tree, top_k_pred_cmd = top_k_predictions[j]
                    if nl2bash:
                        pred_cmd = data_tools.ast2command(
                            top_k_pred_tree, loose_constraints=True)
                    else:
                        pred_cmd = top_k_pred_cmd
                    pred_file.write('{}|||'.format(pred_cmd.encode('utf-8')))
                    eval_row += '"{}",'.format(pred_cmd.replace('"', '""'))
                    temp_match = tree_dist.one_match(
                        tg_asts, top_k_pred_tree, ignore_arg_value=True)
                    str_match = tree_dist.one_match(
                        tg_asts, top_k_pred_tree, ignore_arg_value=False)
                    if temp_match:
                        eval_row += 'y'
                    eval_row += ','
                    if str_match:
                        eval_row += 'y'
                    eval_file.write('{}\n'.format(eval_row.encode('utf-8')))
                    if verbose:
                        print('Prediction {}: {} ({})'.format(
                            j+1, pred_cmd.encode('utf-8'), top_k_scores[j]))
                        if FLAGS.tg_char:
                            print('Character-based prediction {}: {}'.format(
                                j+1, top_k_char_predictions[j].encode('utf-8')))
                pred_file.write('\n')
        else:
            print(APOLOGY_MSG)
            pred_file.write('\n')
            eval_file.write('{}\n'.format(eval_row))
            eval_file.write('\n')
            eval_file.write('\n')
    pred_file.close()
    eval_file.close()
    shutil.copyfile(pred_file_path, os.path.join(FLAGS.model_dir,
        'predictions.{}.latest'.format(model.decode_sig)))
    shutil.copyfile(eval_file_path, os.path.join(FLAGS.model_dir,
        'predictions.{}.latest.csv'.format(model.decode_sig)))
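
The eval rows above escape embedded double quotes by hand ('"' becomes '""'). For reference, Python's csv module produces the same quoting; a quick sketch, not used by the source:

import csv, io

buf = io.StringIO()
csv.writer(buf, quoting=csv.QUOTE_ALL).writerow(
    ['0', 'find files named "*.log"', 'find . -name "*.log"'])
print(buf.getvalue().strip())
# "0","find files named ""*.log""","find . -name ""*.log"""
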
Code Example #4
def decode_set(sess, model, dataset, top_k, FLAGS, verbose=True):
    """
    Compute top-k predictions on the dev/test dataset and write the predictions
    to disk.

    :param sess: A TensorFlow session.
    :param model: Prediction model object.
    :param dataset: The dev/test dataset to decode.
    :param top_k: Number of top predictions to compute.
    :param FLAGS: Training/testing hyperparameter settings.
    :param verbose: If set, also print decoding results to screen.
    """
    nl2bash = FLAGS.dataset.startswith('bash') and not FLAGS.explain

    tokenizer_selector = 'cm' if FLAGS.explain else 'nl'
    grouped_dataset = data_utils.group_data(
        dataset,
        use_bucket=model.buckets,
        use_temp=FLAGS.normalized,
        tokenizer_selector=tokenizer_selector)
    vocabs = data_utils.load_vocab(FLAGS)
    rev_sc_vocab = vocabs.rev_sc_vocab

    if FLAGS.fill_argument_slots:
        # create slot filling classifier
        mapping_param_dir = os.path.join(FLAGS.model_dir,
                                         'train.mappings.X.Y.npz')
        train_X, train_Y = data_utils.load_slot_filling_data(mapping_param_dir)
        slot_filling_classifier = classifiers.KNearestNeighborModel(
            FLAGS.num_nn_slot_filling, train_X, train_Y)
        print('Slot filling classifier parameters loaded.')
    else:
        slot_filling_classifier = None

    ts = datetime.datetime.fromtimestamp(
        time.time()).strftime('%Y-%m-%d-%H%M%S')
    pred_file_path = os.path.join(
        model.model_dir, 'predictions.{}.{}'.format(model.decode_sig, ts))
    pred_file = open(pred_file_path, 'w')
    for example_id in xrange(len(grouped_dataset)):
        key, data_group = grouped_dataset[example_id]

        sc_txt = data_group[0].sc_txt
        sc_temp = ' '.join([rev_sc_vocab[i] for i in data_group[0].sc_ids])
        if verbose:
            print('Example {}:'.format(example_id))
            print('Original Source: {}'.format(sc_txt))
            print('Source: {}'.format(sc_temp))
            for j in xrange(len(data_group)):
                print('GT Target {}: {}'.format(j + 1, data_group[j].tg_txt))

        batch_outputs, output_logits = translate_fun(
            data_group,
            sess,
            model,
            vocabs,
            FLAGS,
            slot_filling_classifier=slot_filling_classifier)
        if FLAGS.tg_char:
            batch_outputs, batch_char_outputs = batch_outputs

        if batch_outputs:
            if FLAGS.token_decoding_algorithm == 'greedy':
                tree, pred_cmd = batch_outputs[0]
                if nl2bash:
                    pred_cmd = data_tools.ast2command(tree,
                                                      loose_constraints=True)
                score = output_logits[0]
                if verbose:
                    print('Prediction: {} ({})'.format(pred_cmd, score))
                pred_file.write('{}\n'.format(pred_cmd))
            elif FLAGS.token_decoding_algorithm == 'beam_search':
                top_k_predictions = batch_outputs[0]
                if FLAGS.tg_char:
                    top_k_char_predictions = batch_char_outputs[0]
                top_k_scores = output_logits[0]
                num_preds = min(FLAGS.beam_size, top_k, len(top_k_predictions))
                for j in xrange(num_preds):
                    top_k_pred_tree, top_k_pred_cmd = top_k_predictions[j]
                    if nl2bash:
                        pred_cmd = data_tools.ast2command(
                            top_k_pred_tree, loose_constraints=True)
                    else:
                        pred_cmd = top_k_pred_cmd
                    pred_file.write('{}|||'.format(pred_cmd))
                    if verbose:
                        print('Prediction {}: {} ({})'.format(
                            j + 1, pred_cmd, top_k_scores[j]))
                        if FLAGS.tg_char:
                            print('Character-based prediction {}: {}'.format(
                                j + 1, top_k_char_predictions[j]))
                pred_file.write('\n')
        else:
            print(APOLOGY_MSG)
            pred_file.write('\n')
    pred_file.close()
    shutil.copyfile(
        pred_file_path,
        os.path.join(FLAGS.model_dir,
                     'predictions.{}.latest'.format(model.decode_sig)))
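
Code Example #2 normalizes each row of X to unit length before calling slot_filling_classifier.predict(X), so a nearest-neighbor scorer over those rows amounts to cosine similarity. Below is a toy stand-in for classifiers.KNearestNeighborModel; its real behavior is an assumption here, not taken from the source:

import numpy as np

class ToyKNearestNeighborModel(object):
    """Scores each query row by averaging the labels of its k most
    cosine-similar training rows. A sketch, not the project's class."""

    def __init__(self, k, train_X, train_Y):
        self.k = k
        # Normalize training rows so dot products are cosine similarities.
        self.train_X = train_X / np.linalg.norm(train_X, axis=1, keepdims=True)
        self.train_Y = np.asarray(train_Y, dtype=np.float32)

    def predict(self, X):
        sims = X.dot(self.train_X.T)           # queries are pre-normalized
        nn = np.argsort(-sims, axis=1)[:, :self.k]
        return self.train_Y[nn].mean(axis=1)   # one score per query row
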
Code Example #5
File: translate.py Project: sxdkxgwan/awesome_nmt
def decode_set(model, dataset, rev_sc_vocab, rev_tg_vocab, verbose=True):
    grouped_dataset = data_utils.group_data_by_nl(dataset, use_bucket=False,
                                                  use_temp=False)

    with DBConnection() as db:
        db.remove_model(model_name)
        
        num_eval = 0
        for sc_temp in grouped_dataset:
            batch_sc_strs, batch_tg_strs, batch_scs, batch_cmds = \
                grouped_dataset[sc_temp]
            _, entities = tokenizer.ner_tokenizer(sc_temp)
            nl_fillers = entities[-1]
            if nl_fillers is not None:
                cm_slots = {}

            sc_str = batch_sc_strs[0]
            nl = batch_scs[0]
            if verbose:
                print("Example {}".format(num_eval+1))
                print("Original English: " + sc_str.strip())
                print("English: " + sc_temp)
                for j in xrange(len(batch_tg_strs)):
                    print("GT Command {}: {}".format(j+1, batch_tg_strs[j].strip()))
            # retrieve top-ranked command template
            top_k_results = model.test(nl, 100)
            count = 0
            for i in xrange(len(top_k_results)):
                nn, output_tokens, score = top_k_results[i]
                nn_str = ' '.join([rev_sc_vocab[j] for j in nn])
                tokens = []
                for j in xrange(1, len(output_tokens)-1):
                    pred_token = rev_tg_vocab[output_tokens[j]]
                    if "@@" in pred_token:
                        pred_token = pred_token.split("@@")[-1]
                    if nl_fillers is not None and \
                            pred_token in constants._ENTITIES:
                        if j > 0 and slot_filling.is_min_flag(
                                rev_tg_vocab[output_tokens[j-1]]):
                            pred_token_type = 'Timespan'
                        else:
                            pred_token_type = pred_token
                        cm_slots[j] = (pred_token, pred_token_type)
                    tokens.append(pred_token)
                pred_cmd = ' '.join(tokens)
                # check if the predicted command templates have enough slots to
                # hold the fillers (to rule out templates that are trivially
                # unqualified)
                if FLAGS.dataset.startswith("bash"):
                    pred_cmd = re.sub('( ;\s+)|( ;$)', ' \\; ', pred_cmd)
                    tree = data_tools.bash_parser(pred_cmd)
                else:
                    tree = data_tools.paren_parser(pred_cmd)
                if nl_fillers is None or len(cm_slots) >= len(nl_fillers):
                    # Step 2: check if the predicted command template is grammatical
                    # filter out non-grammatical output
                    if tree is not None:
                        matched = slot_filling.heuristic_slot_filling(tree, nl_fillers)
                if tree is not None:
                    slot_filling.fill_default_value(tree)
                    pred_cmd = data_tools.ast2command(tree)
                if verbose:
                    print("NN: {}".format(nn_str))
                    print("Prediction {}: {} ({})".format(i, pred_cmd, score))
                db.add_prediction(model_name, sc_str, pred_cmd, float(score),
                                      update_mode=False)
                count += 1
                if count == 10:
                    break
            print("")        
            num_eval += 1
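
Code Example #5 rewrites a bare trailing ' ;' into the escaped ' \; ' that find -exec requires before the command is re-parsed. A quick demonstration of that substitution:

import re

cmd = 'find . -name "*.tmp" -exec rm {} ;'
print(re.sub(r'( ;\s+)|( ;$)', r' \; ', cmd))
# -> find . -name "*.tmp" -exec rm {} \;  (note the trailing space)
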
Code Example #6
def test_rewrite(cmd):
    with DBConnection() as db:
        ast = data_tools.bash_parser(cmd)
        rewrites = list(db.get_rewrites(ast))
        for i in xrange(len(rewrites)):
            print("rewrite %d: %s" % (i, data_tools.ast2command(rewrites[i])))