def build_program_decoder_for_analysis(token_emb_size, rnn_cell):
    """
    Does the same as build_program_decoder_for_analysis, but also returns
        the final hidden state of the decoder
    """
    decoder_rnn = td.ScopedLayer(rnn_cell, 'decoder')
    decoder_rnn_output = td.RNN(decoder_rnn,
                                initial_state_from_input=True) >> td.GetItem(0)

    fc_layer = td.FC(token_emb_size,
                     activation=tf.nn.relu,
                     initializer=tf.contrib.layers.xavier_initializer(),
                     name='encoder_fc')
    un_normalised_token_probs = td.Map(fc_layer)
    return decoder_rnn_output >> td.AllOf(un_normalised_token_probs,
                                          td.Identity())
Example #2
def logits_and_state():
    """Creates a block that goes from tokens to (logits, state) tuples."""
    unknown_idx = len(word_idx)

    lookup_word = lambda word: word_idx.get(
        word, unknown_idx)  # unknown_idx is the default for out-of-vocabulary words
    word2vec = (
        td.GetItem(0) >> td.GetItem(0) >> td.InputTransform(lookup_word) >>
        td.Scalar('int32') >> word_embedding
    )  # <td.Pipe>: None -> TensorType((200,), 'float32')
    # Separate block instances are built for the word case and the children
    # case below, since a td block can only be wired into one place.
    context2vec1 = td.GetItem(1) >> td.InputTransform(
        makeContextMat) >> td.Vector(10)
    context2vec2 = td.GetItem(1) >> td.InputTransform(
        makeContextMat) >> td.Vector(10)
    ent1posit1 = td.GetItem(2) >> td.InputTransform(
        makeEntPositMat) >> td.Vector(10)
    ent1posit2 = td.GetItem(2) >> td.InputTransform(
        makeEntPositMat) >> td.Vector(10)
    ent2posit1 = td.GetItem(3) >> td.InputTransform(
        makeEntPositMat) >> td.Vector(10)
    ent2posit2 = td.GetItem(3) >> td.InputTransform(
        makeEntPositMat) >> td.Vector(10)

    pairs2vec = td.GetItem(0) >> (embed_subtree(), embed_subtree())

    # Trees are binary, so the tree-LSTM takes two child states; the zero state
    # therefore has two components.
    zero_state = td.Zeros((tree_lstm.state_size, ) * 2)
    # Input is a word vector.
    zero_inp = td.Zeros(word_embedding.output_type.shape[0])  # shape[0] == 200

    word_case = td.AllOf(word2vec, zero_state, context2vec1, ent1posit1,
                         ent2posit1)
    children_case = td.AllOf(zero_inp, pairs2vec, context2vec2, ent1posit2,
                             ent2posit2)
    # Leaf nodes (a single token) go to word_case; internal nodes go to children_case.
    tree2vec = td.OneOf(lambda x: 1 if len(x[0]) == 1 else 2,
                        [(1, word_case), (2, children_case)])
    # tree2vec = td.OneOf(lambda pair: len(pair[0]), [(1, word_case), (2, children_case)])
    # logits and lstm states
    return tree2vec >> tree_lstm >> (output_layer, td.Identity())
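The recursive embed_subtree() calls above refer to a td.ForwardDeclaration; a hedged sketch of how that declaration is usually closed, following the TensorFlow Fold sentiment example this code is modelled on (tree_lstm, output_layer, word_embedding, word_idx and the make*Mat helpers are assumed to be defined as used above):

# The declaration must exist before logits_and_state() is called.
embed_subtree = td.ForwardDeclaration(name='embed_subtree')

# Recursive calls only need the tree-LSTM state, so the declaration is
# resolved to the state half of the (logits, state) pair.
embed_subtree.resolve_to(logits_and_state() >> td.GetItem(1))

# The root of the model keeps the logits as well.
model = logits_and_state()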
Example #3
def build_program_decoder(token_emb_size, rnn_cell, just_tokens=False):
    """
    Used for blind or 'look-behind' decoders
    """
    decoder_rnn = td.ScopedLayer(rnn_cell, 'decoder')
    decoder_rnn_output = td.RNN(decoder_rnn,
                                initial_state_from_input=True) >> td.GetItem(0)

    fc_layer = td.FC(
        token_emb_size,
        activation=tf.nn.relu,
        initializer=tf.contrib.layers.xavier_initializer(),
        name='encoder_fc'
    )

    if just_tokens:
        return decoder_rnn_output >> td.Map(fc_layer)
    else:
        return decoder_rnn_output >> td.AllOf(td.Map(fc_layer), td.Identity())
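A short usage sketch of the two variants; the GRU cell and embedding size are illustrative assumptions:

import tensorflow as tf
import tensorflow_fold as td

token_emb_size = 128
cell = tf.contrib.rnn.GRUCell(num_units=token_emb_size)

# just_tokens=True  -> a sequence of un-normalised per-token logits.
# just_tokens=False -> a (logit_sequence, rnn_output_sequence) pair.
decoder = build_program_decoder(token_emb_size, cell, just_tokens=True)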
Example #4
    def logits_and_state():
        """Creates a block that goes from tokens to (logits, state) tuples."""
        unknown_idx = len(word_idx)
        lookup_word = lambda word: word_idx.get(word, unknown_idx)

        word2vec = (td.GetItem(0) >> td.InputTransform(lookup_word) >>
                    td.Scalar('int32') >> word_embedding)

        pair2vec = (embed_subtree(), embed_subtree())

        # Trees are binary, so the tree layer takes two states as its input_state.
        zero_state = td.Zeros((tree_lstm.state_size, ) * 2)
        # Input is a word vector.
        zero_inp = td.Zeros(word_embedding.output_type.shape[0])

        word_case = td.AllOf(word2vec, zero_state)
        pair_case = td.AllOf(zero_inp, pair2vec)

        tree2vec = td.OneOf(len, [(1, word_case), (2, pair_case)])

        return tree2vec >> tree_lstm >> (output_layer, td.Identity())
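This is essentially the logits_and_state block from the TensorFlow Fold sentiment example. The names it relies on (word_idx, word_embedding, tree_lstm, output_layer, embed_subtree) are defined outside the snippet; a minimal sketch of how they are typically set up there, with an illustrative vocabulary and sizes (assumptions, not the original values):

import tensorflow_fold as td

word_idx = {'the': 0, 'movie': 1, 'rocks': 2}  # illustrative vocabulary
# One extra embedding row for the unknown-word index (unknown_idx == len(word_idx)).
word_embedding = td.Embedding(len(word_idx) + 1, 200, name='word_embedding')

output_layer = td.FC(5, activation=None, name='output_layer')  # e.g. 5 classes
embed_subtree = td.ForwardDeclaration(name='embed_subtree')

# tree_lstm wraps a two-child tree-LSTM cell (the BinaryTreeLSTMCell defined in
# the Fold sentiment example, not part of the td API); that cell is what makes
# the (state, state) zero state above well-typed:
# tree_lstm = td.ScopedLayer(BinaryTreeLSTMCell(300), 'tree_lstm')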
Example #5
def bidirectional_dynamic_FC(fw_cell, bw_cell, hidden):
    """Runs a bidirectional RNN over a sequence and applies a per-step FC
    layer to the concatenated forward/backward outputs; per-step logits and
    labels are recorded as td.Metric values (`hidden` is unused here)."""
    bidir_conv_lstm = td.Composition()
    with bidir_conv_lstm.scope():
        # input[0] is the feature sequence, input[1] the per-step labels.
        fw_seq = td.Identity().reads(bidir_conv_lstm.input[0])
        labels = (
            td.GetItem(1) >> td.Map(td.Metric("labels")) >> td.Void()).reads(
                bidir_conv_lstm.input)
        # Reverse the input sequence for the backward pass.
        bw_seq = td.Slice(step=-1).reads(fw_seq)

        forward_dir = (td.RNN(fw_cell) >> td.GetItem(0)).reads(fw_seq)
        back_dir = (td.RNN(bw_cell) >> td.GetItem(0)).reads(bw_seq)
        # Flip the backward outputs back into left-to-right order so they
        # align step-for-step with the forward outputs.
        back_to_leftright = td.Slice(step=-1).reads(back_dir)

        output_transform = td.FC(1, activation=None)

        # Concatenate aligned forward/backward outputs, project each step to a
        # single logit, and record it as the 'logits' metric.
        bidir_common = (td.ZipWith(
            td.Concat() >> output_transform >> td.Metric('logits'))).reads(
                forward_dir, back_to_leftright)

        bidir_conv_lstm.output.reads(bidir_common)
    return bidir_conv_lstm
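A hedged end-to-end sketch of how this composition might be driven; the cell choice, feature size, and label encoding below are assumptions for illustration:

import tensorflow as tf
import tensorflow_fold as td

feature_size, num_hidden = 16, 32
fw_cell = td.ScopedLayer(tf.contrib.rnn.GRUCell(num_hidden), 'fw_cell')
bw_cell = td.ScopedLayer(tf.contrib.rnn.GRUCell(num_hidden), 'bw_cell')

# Each example is a (feature_sequence, label_sequence) pair of plain Python lists.
prep = td.AllOf(td.GetItem(0) >> td.Map(td.Vector(feature_size)),
                td.GetItem(1) >> td.Map(td.Scalar()))

# The per-step logits and labels are only exposed as td.Metric values, so the
# block output itself is discarded before compiling.
root = prep >> bidirectional_dynamic_FC(fw_cell, bw_cell, num_hidden) >> td.Void()
compiler = td.Compiler.create(root)
logits = compiler.metric_tensors['logits']
labels = compiler.metric_tensors['labels']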