Example #1
def decode(encoder_full_inputs,
           model_outputs,
           FLAGS,
           vocabs,
           sc_fillers=None,
           slot_filling_classifier=None):
    """
    Transform the neural network output into readable strings and apply output
    filtering (if any).
    :param encoder_full_inputs: raw encoder inputs (time-major), used to
        resolve copied source tokens under the copynet scheme.
    :param model_outputs: model output object exposing output_symbols,
        encoder/decoder hidden states and, optionally, pointers and
        char_output_symbols.
    :param FLAGS: hyperparameter and configuration flags.
    :param vocabs: container of the forward and reverse vocabularies.
    :param sc_fillers: argument fillers extracted from the source sentences.
    :param slot_filling_classifier: classifier used to fill argument slots.
    :return batch_outputs: nested list of (target_ast, target) tuples
        - target_ast is a python tree object for target languages that we know
          how to parse and a dummy string for those we don't
        - target is the output string
    """
    rev_sc_vocab = vocabs.rev_sc_vocab
    rev_tg_vocab = vocabs.rev_tg_vocab
    rev_sc_char_vocab = vocabs.rev_sc_char_vocab
    rev_tg_char_vocab = vocabs.rev_tg_char_vocab

    encoder_outputs = model_outputs.encoder_hidden_states
    decoder_outputs = model_outputs.decoder_hidden_states
    # print("encoder_outputs.shape = {}".format(encoder_outputs.shape))
    # print("decoder_outputs.shape = {}".format(decoder_outputs.shape))

    if FLAGS.fill_argument_slots:
        assert (sc_fillers is not None)
        assert (slot_filling_classifier is not None)
        assert (encoder_outputs is not None)
        assert (decoder_outputs is not None)

    output_symbols = model_outputs.output_symbols
    batch_size = len(output_symbols)
    batch_outputs = []
    num_output_examples = 0

    # Prepare copied indices if the model is trained with explicit copy
    # alignments.
    if FLAGS.use_copy and FLAGS.copy_fun == 'supervised':
        pointers = model_outputs.pointers
        sc_length = pointers.shape[1]
        tg_length = pointers.shape[2]
        if FLAGS.token_decoding_algorithm == 'greedy':
            batch_pointers = np.reshape(pointers,
                                        [batch_size, 1, sc_length, tg_length])
        else:
            batch_pointers = np.reshape(
                pointers, [batch_size, FLAGS.beam_size, sc_length, tg_length])

    for batch_id in xrange(batch_size):

        def as_str(output, r_sc_vocab, r_tg_vocab):
            if output < FLAGS.tg_vocab_size:
                token = r_tg_vocab[output]
            else:
                if FLAGS.use_copy and FLAGS.copy_fun == 'copynet':
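                    # output ids beyond the target vocab size index into the
                    # source sentence; the len(...) - 1 - offset arithmetic
                    # suggests the encoder inputs are stored reversed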
                    token = r_sc_vocab[encoder_full_inputs[
                        len(encoder_full_inputs) - 1 -
                        (output - FLAGS.tg_vocab_size)][batch_id]]
                else:
                    return data_utils._UNK
            return token

        top_k_predictions = output_symbols[batch_id]
        if FLAGS.token_decoding_algorithm == 'beam_search':
            assert (len(top_k_predictions) == FLAGS.beam_size)
            beam_outputs = []
        else:
            # pack greedy decoding results into size-1 beam
            top_k_predictions = [top_k_predictions]

        for beam_id in xrange(len(top_k_predictions)):
            # Step 1: transform the neural network output into readable strings
            prediction = top_k_predictions[beam_id]
            outputs = [int(pred) for pred in prediction]

            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]

            if FLAGS.char:
                target = ''.join([
                    as_str(output, rev_sc_char_vocab, rev_tg_char_vocab)
                    for output in outputs
                ]).replace(constants._SPACE, ' ')
            else:
                output_tokens = []
                tg_slots = {}
                for token_id in xrange(len(outputs)):
                    output = outputs[token_id]
                    pred_token = as_str(output, rev_sc_vocab, rev_tg_vocab)
                    if pred_token.startswith('__ARG__'):
                        pred_token = pred_token[len('__ARG__'):]
                    if '@@' in pred_token:
                        pred_token = pred_token.split('@@')[-1]
                    # process argument slots
                    if pred_token in constants._ENTITIES:
                        if token_id > 0 and slot_filling.is_min_flag(
                                rev_tg_vocab[outputs[token_id - 1]]):
                            pred_token_type = 'Timespan'
                        else:
                            pred_token_type = pred_token
                        tg_slots[token_id] = (pred_token, pred_token_type)
                    output_tokens.append(pred_token)

                if FLAGS.partial_token:
                    # process partial-token outputs
                    merged_output_tokens = []
                    buffer = ''
                    load_buffer = False
                    for token in output_tokens:
                        if load_buffer:
                            if token == data_utils._ARG_END:
                                merged_output_tokens.append(buffer)
                                load_buffer = False
                                buffer = ''
                            else:
                                buffer += token
                        else:
                            if token == data_utils._ARG_START:
                                load_buffer = True
                            else:
                                merged_output_tokens.append(token)
                    if buffer:
                        # flush an unterminated argument span (as Example #2
                        # below does); otherwise trailing sub-tokens are lost
                        merged_output_tokens.append(buffer)
                    output_tokens = merged_output_tokens

                target = ' '.join(output_tokens)
            # Step 2: check if the predicted command template is grammatical
            if FLAGS.grammatical_only and not FLAGS.explain:
                if FLAGS.dataset.startswith('bash'):
                    target = re.sub(r'( ;\s+)|( ;$)', r' \; ', target)
                    target_ast = data_tools.bash_parser(target)
                elif FLAGS.dataset.startswith('regex'):
                    # TODO: check if a predicted regular expression is legal
                    target_ast = '__DUMMY_TREE__'
                else:
                    target_ast = data_tools.paren_parser(target)
                # filter out non-grammatical output
                if target_ast is None:
                    continue
            else:
                target_ast = '__DUMMY_TREE__'

            # Step 3: check if the predicted command templates have enough
            # slots to hold the fillers (to rule out templates that are
            # trivially unqualified)
            output_example = False
            if FLAGS.explain or not FLAGS.dataset.startswith('bash') \
                    or not FLAGS.normalized:
                output_example = True
            else:
                # match the fillers to the argument slots
                batch_sc_fillers = sc_fillers[batch_id]
                if len(tg_slots) >= len(batch_sc_fillers):
                    if FLAGS.use_copy and FLAGS.copy_fun == 'supervised':
                        target_ast, target, _ = slot_filling.stable_slot_filling(
                            output_tokens,
                            batch_sc_fillers,
                            tg_slots,
                            batch_pointers[batch_id, beam_id, :, :],
                            None,
                            None,
                            None,
                            verbose=False)
                    elif FLAGS.fill_argument_slots:
                        target_ast, target, _ = slot_filling.stable_slot_filling(
                            output_tokens,
                            batch_sc_fillers,
                            tg_slots,
                            None,
                            encoder_outputs[batch_id],
                            decoder_outputs[batch_id * FLAGS.beam_size +
                                            beam_id],
                            slot_filling_classifier,
                            verbose=False)
                    else:
                        output_example = True
                    if not output_example and (target_ast is not None):
                        output_example = True

            if output_example:
                if FLAGS.token_decoding_algorithm == 'greedy':
                    batch_outputs.append((target_ast, target))
                else:
                    beam_outputs.append((target_ast, target))
                num_output_examples += 1

            # Cap the number of emitted candidates to speed up decoding.
            if num_output_examples == 20:
                break

        if FLAGS.token_decoding_algorithm == 'beam_search':
            if beam_outputs:
                batch_outputs.append(beam_outputs)

    # Step 4: apply character decoding
    if FLAGS.tg_char:
        char_output_symbols = model_outputs.char_output_symbols
        sentence_length = char_output_symbols.shape[0]
        batch_char_outputs = []
        # per example: [beam_size, sentence_length, max_tg_token_size + 1]
        batch_char_predictions = \
            [np.transpose(np.reshape(x, [sentence_length, FLAGS.beam_size,
                                         FLAGS.max_tg_token_size + 1]),
                          (1, 0, 2))
             for x in np.split(char_output_symbols, batch_size, 1)]
        for batch_id in xrange(len(batch_char_predictions)):
            beam_char_outputs = []
            top_k_char_predictions = batch_char_predictions[batch_id]
            for k in xrange(len(top_k_char_predictions)):
                top_k_char_prediction = top_k_char_predictions[k]
                sent = []
                for i in xrange(sentence_length):
                    word = ''
                    for j in xrange(FLAGS.max_tg_token_size):
                        char_prediction = top_k_char_prediction[i, j]
                        if char_prediction == data_utils.CEOS_ID or \
                            char_prediction == data_utils.CPAD_ID:
                            break
                        elif char_prediction in rev_tg_char_vocab:
                            word += rev_tg_char_vocab[char_prediction]
                        else:
                            word += data_utils._CUNK
                    sent.append(word)
                if data_utils._CATOM in sent:
                    sent = sent[:sent.index(data_utils._CATOM)]
                beam_char_outputs.append(' '.join(sent))
            batch_char_outputs.append(beam_char_outputs)
        return batch_outputs, batch_char_outputs
    else:
        return batch_outputs
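For orientation, here is a self-contained sketch of the copy-index lookup that
as_str performs above. The tiny vocabulary and the reversed, time-major layout
of encoder_full_inputs are illustrative assumptions inferred from the indexing
arithmetic; the real code additionally maps the looked-up ids through
rev_sc_vocab, while the sketch stores source tokens directly.

TG_VOCAB_SIZE = 5
rev_tg_vocab = {0: '_PAD', 1: '_EOS', 2: '_UNK', 3: 'find', 4: '-name'}
# time-major, reversed source sequence for a batch of one (assumed layout)
encoder_full_inputs = [['"*.txt"'], ['named'], ['files'], ['find']]

def as_str_sketch(output, batch_id=0):
    if output < TG_VOCAB_SIZE:
        return rev_tg_vocab[output]
    # ids >= TG_VOCAB_SIZE point into the (reversed) source sentence
    src_pos = len(encoder_full_inputs) - 1 - (output - TG_VOCAB_SIZE)
    return encoder_full_inputs[src_pos][batch_id]

print(as_str_sketch(3))                  # 'find' (regular target-vocab token)
print(as_str_sketch(TG_VOCAB_SIZE + 3))  # '"*.txt"' (copied source token)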
Example #2
def decode(model_outputs, FLAGS, vocabs, sc_fillers=None,
           slot_filling_classifier=None, copy_tokens=None):
    """
    Transform the neural network output into readable strings and apply output
    filtering (if any).
    :param model_outputs: model output object exposing output_symbols and
        encoder/decoder hidden states.
    :param FLAGS: hyperparameter and configuration flags.
    :param vocabs: container of the forward and reverse vocabularies.
    :param sc_fillers: argument fillers extracted from the source sentences.
    :param slot_filling_classifier: classifier used to fill argument slots.
    :param copy_tokens: per-example source tokens used to resolve copy indices.
    :return batch_outputs: nested list of (target_ast, target) tuples
        - target_ast is a python tree object for target languages that we know
          how to parse and a dummy string for those we don't
        - target is the output string
    """
    rev_tg_vocab = vocabs.rev_tg_vocab

    encoder_outputs = model_outputs.encoder_hidden_states
    decoder_outputs = model_outputs.decoder_hidden_states

    if FLAGS.fill_argument_slots:
        assert(sc_fillers is not None)
        assert(slot_filling_classifier is not None)
        assert(encoder_outputs is not None)
        assert(decoder_outputs is not None)

    output_symbols = model_outputs.output_symbols
    batch_size = len(output_symbols)
    batch_outputs = []
    num_output_examples = 0

    for batch_id in xrange(batch_size):
        def as_str(output, r_tg_vocab):
            if output < FLAGS.tg_vocab_size:
                token = r_tg_vocab[output]
            else:
                if FLAGS.use_copy and FLAGS.copy_fun == 'copynet':
                    source_id = output - FLAGS.tg_vocab_size
                    if source_id >= 0 and source_id < len(copy_tokens[batch_id]):
                        token = copy_tokens[batch_id][source_id]
                    else:
                        return data_utils._UNK
                else:
                    return data_utils._UNK
            return token

        top_k_predictions = output_symbols[batch_id]
        if FLAGS.token_decoding_algorithm == 'beam_search':
            assert(len(top_k_predictions) == FLAGS.beam_size)
            beam_outputs = []
        else:
            # pack greedy decoding results into size-1 beam
            top_k_predictions = [top_k_predictions]

        for beam_id in xrange(len(top_k_predictions)):
            # Step 1: transform the neural network output into readable strings
            prediction = top_k_predictions[beam_id]
            outputs = [int(pred) for pred in prediction]
            
            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]
            if data_utils.PAD_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.PAD_ID)]
            output_tokens = []
            tg_slots = {}
            for token_id in xrange(len(outputs)):
                output = outputs[token_id]
                pred_token = as_str(output, rev_tg_vocab)
                if data_tools.flag_suffix in pred_token:
                    pred_token = pred_token.split(data_tools.flag_suffix)[0]
                # process argument slots
                if pred_token in bash.argument_types:
                    if token_id > 0 and format_args.is_min_flag(
                        rev_tg_vocab[outputs[token_id-1]]):
                        pred_token_type = 'Timespan'
                    else:
                        pred_token_type = pred_token
                    tg_slots[token_id] = (pred_token, pred_token_type)
                output_tokens.append(pred_token)

            if FLAGS.channel == 'partial.token':
                # process partial-token outputs
                merged_output_tokens = []
                buffer = ''
                load_buffer = False
                for token in output_tokens:
                    if load_buffer:
                        if token == data_utils._ARG_END:
                            merged_output_tokens.append(buffer)
                            load_buffer = False
                            buffer = ''
                        else:
                            buffer += token
                    else:
                        if token == data_utils._ARG_START:
                            load_buffer = True
                        else:
                            merged_output_tokens.append(token)
                if buffer:
                    merged_output_tokens.append(buffer)
                output_tokens = merged_output_tokens
    
            if FLAGS.channel == 'char':
                target = ''
                for char in output_tokens:
                    if char == data_utils.constants._SPACE:
                        target += ' '
                    else:
                        target += char
            else:
                target = ' '.join(output_tokens)
            
            # Step 2: check if the predicted command template is grammatical
            if FLAGS.grammatical_only and not FLAGS.explain:
                if FLAGS.dataset.startswith('bash'):
                    target = re.sub(r'( ;\s+)|( ;$)', r' \; ', target)
                    target_ast = data_tools.bash_parser(target, verbose=False)
                elif FLAGS.dataset.startswith('regex'):
                    # TODO: check if a predicted regular expression is legal
                    target_ast = '__DUMMY_TREE__'
                else:
                    target_ast = data_tools.paren_parser(target)
                # filter out non-grammatical output
                if target_ast is None:
                    continue
            else:
                target_ast = '__DUMMY_TREE__'

            # Step 3: check if the predicted command templates have enough
            # slots to hold the fillers (to rule out templates that are
            # trivially unqualified)
            output_example = False
            if FLAGS.explain or not FLAGS.dataset.startswith('bash') \
                    or not FLAGS.normalized:
                output_example = True
            else:
                # match the fillers to the argument slots
                batch_sc_fillers = sc_fillers[batch_id]
                if len(tg_slots) >= len(batch_sc_fillers):
                    if FLAGS.fill_argument_slots:
                        target_ast, target, _ = slot_filling.stable_slot_filling(
                            output_tokens, batch_sc_fillers, tg_slots, None,
                            encoder_outputs[batch_id],
                            decoder_outputs[batch_id*FLAGS.beam_size+beam_id],
                            slot_filling_classifier, verbose=False)
                    else:
                        output_example = True
                    if not output_example and (target_ast is not None):
                        output_example = True

            if output_example:
                if FLAGS.token_decoding_algorithm == 'greedy':
                    batch_outputs.append((target_ast, target))
                else:
                    beam_outputs.append((target_ast, target))
                num_output_examples += 1

            # Cap the number of emitted candidates to speed up decoding.
            if num_output_examples == 20:
                break

        if FLAGS.token_decoding_algorithm == 'beam_search':
            if beam_outputs:
                batch_outputs.append(beam_outputs)

    return batch_outputs
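The partial-token channel above reassembles sub-token spans delimited by
data_utils._ARG_START and data_utils._ARG_END into single arguments, and this
version also flushes an unterminated buffer at the end of the sequence. A
minimal self-contained sketch of that merge, with placeholder marker strings
standing in for the data_utils constants:

ARG_START, ARG_END = '<arg_start>', '<arg_end>'  # stand-in markers

def merge_partial_tokens(tokens):
    merged, buf, loading = [], '', False
    for token in tokens:
        if loading:
            if token == ARG_END:
                merged.append(buf)
                buf, loading = '', False
            else:
                buf += token
        elif token == ARG_START:
            loading = True
        else:
            merged.append(token)
    if buf:  # flush an unterminated argument span
        merged.append(buf)
    return merged

print(merge_partial_tokens(
    ['find', '.', '-name', ARG_START, '*', '.', 'txt', ARG_END]))
# ['find', '.', '-name', '*.txt']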
Example #3
def decode_set(model, dataset, rev_sc_vocab, rev_tg_vocab, verbose=True):
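    """
    Decode the top-ranked command templates for each natural language input
    group and store the predictions in the database.
    NOTE: model_name and FLAGS are referenced as module-level globals here.
    """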
    grouped_dataset = data_utils.group_data_by_nl(dataset, use_bucket=False,
                                                  use_temp=False)

    with DBConnection() as db:
        db.remove_model(model_name)
        
        num_eval = 0
        for sc_temp in grouped_dataset:
            batch_sc_strs, batch_tg_strs, batch_scs, batch_cmds = \
                grouped_dataset[sc_temp]
            _, entities = tokenizer.ner_tokenizer(sc_temp)
            nl_fillers = entities[-1]
            if nl_fillers is not None:
                cm_slots = {}

            sc_str = batch_sc_strs[0]
            nl = batch_scs[0]
            if verbose:
                print("Example {}".format(num_eval+1))
                print("Original English: " + sc_str.strip())
                print("English: " + sc_temp)
                for j in xrange(len(batch_tg_strs)):
                    print("GT Command {}: {}".format(j+1, batch_tg_strs[j].strip()))
            # retrieve top-ranked command template
            top_k_results = model.test(nl, 100)
            count = 0
            for i in xrange(len(top_k_results)):
                nn, output_tokens, score = top_k_results[i]
                nn_str = ' '.join([rev_sc_vocab[j] for j in nn])
                if nl_fillers is not None:
                    # reset the slot map for each candidate; reusing it across
                    # candidates would let earlier predictions' slots inflate
                    # the slot count checked below
                    cm_slots = {}
                tokens = []
                for j in xrange(1, len(output_tokens)-1):
                    pred_token = rev_tg_vocab[output_tokens[j]]
                    if "@@" in pred_token:
                        pred_token = pred_token.split("@@")[-1]
                    if nl_fillers is not None and \
                            pred_token in constants._ENTITIES:
                        if j > 0 and slot_filling.is_min_flag(
                                rev_tg_vocab[output_tokens[j-1]]):
                            pred_token_type = 'Timespan'
                        else:
                            pred_token_type = pred_token
                        cm_slots[j] = (pred_token, pred_token_type)
                    tokens.append(pred_token)
                pred_cmd = ' '.join(tokens)
                # check if the predicted command templates have enough slots to
                # hold the fillers (to rule out templates that are trivially
                # unqualified)
                if FLAGS.dataset.startswith("bash"):
                    pred_cmd = re.sub(r'( ;\s+)|( ;$)', r' \; ', pred_cmd)
                    tree = data_tools.bash_parser(pred_cmd)
                else:
                    tree = data_tools.paren_parser(pred_cmd)
                if nl_fillers is None or len(cm_slots) >= len(nl_fillers):
                    # Step 2: check if the predicted command template is grammatical
                    # filter out non-grammatical output
                    if tree is not None:
                        # return value unused; the heuristic filling appears
                        # to update the tree in place
                        slot_filling.heuristic_slot_filling(tree, nl_fillers)
                if tree is not None:
                    slot_filling.fill_default_value(tree)
                    pred_cmd = data_tools.ast2command(tree)
                if verbose:
                    print("NN: {}".format(nn_str))
                    print("Prediction {}: {} ({})".format(i, pred_cmd, score))
                db.add_prediction(model_name, sc_str, pred_cmd, float(score),
                                  update_mode=False)
                count += 1
                if count == 10:
                    break
            print("")        
            num_eval += 1
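All three examples share the same slot-extraction pattern: entity placeholder
tokens become (token, type) slots, and a slot preceded by a "min"-style flag
is retyped as 'Timespan'. Below is a self-contained sketch of that pattern;
ENTITIES and is_min_flag are illustrative stand-ins for constants._ENTITIES
and slot_filling.is_min_flag.

ENTITIES = {'_FILE', '_NUMBER', '_TIMESPAN'}  # stand-in entity placeholders

def is_min_flag(token):
    # stand-in: flags whose numeric argument is a minute count
    return token in ('-mmin', '-amin', '-cmin')

def extract_slots(tokens):
    slots = {}
    for i, token in enumerate(tokens):
        if token in ENTITIES:
            if i > 0 and is_min_flag(tokens[i - 1]):
                slots[i] = (token, 'Timespan')
            else:
                slots[i] = (token, token)
    return slots

print(extract_slots(['find', '.', '-mmin', '_NUMBER', '-name', '_FILE']))
# {3: ('_NUMBER', 'Timespan'), 5: ('_FILE', '_FILE')}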