def build_program_decoder_for_analysis(token_emb_size, rnn_cell):
    """Build a program decoder block that also exposes the raw RNN outputs.

    Mirrors ``build_program_decoder(token_emb_size, rnn_cell,
    just_tokens=False)``: a ``td.RNN`` over ``rnn_cell`` (wrapped in a
    ``td.ScopedLayer`` named ``'decoder'``, with its initial state taken from
    the input) followed by a fully connected ReLU layer mapped over every
    step.

    Args:
        token_emb_size: Number of output units of the per-step FC layer.
        rnn_cell: The recurrent cell driving the decoder.

    Returns:
        A block producing ``(un_normalised_token_probs, decoder_outputs)``.
        ``decoder_outputs`` is the whole sequence of per-step RNN outputs
        (item 0 of ``td.RNN``'s result), NOT the RNN's final state. For cells
        whose step output is their hidden state, the last element of that
        sequence is the final hidden state.
    """
    decoder_rnn = td.ScopedLayer(rnn_cell, 'decoder')
    # With initial_state_from_input=True the RNN presumably expects an
    # (initial_state, input_sequence) pair; GetItem(0) keeps only the
    # per-step outputs and drops the final state.
    decoder_rnn_output = (td.RNN(decoder_rnn, initial_state_from_input=True)
                          >> td.GetItem(0))
    # NOTE(review): named 'encoder_fc' although it belongs to the decoder;
    # left unchanged so the TF variable names (and any checkpoints) still match.
    fc_layer = td.FC(token_emb_size,
                     activation=tf.nn.relu,
                     initializer=tf.contrib.layers.xavier_initializer(),
                     name='encoder_fc')
    un_normalised_token_probs = td.Map(fc_layer)
    # Pair the per-step FC outputs with the unmodified RNN output sequence.
    return decoder_rnn_output >> td.AllOf(un_normalised_token_probs,
                                          td.Identity())
def logits_and_state():
    """Create a block that goes from tokens to (logits, state) tuples.

    The input is indexed as a tuple:

    * item 0 is the tree. A tree of length 1 is a leaf whose item 0 is the
      word; any other tree is treated as a node with two subtrees, each
      embedded with ``embed_subtree()``.
    * item 1 goes through ``makeContextMat``.
    * items 2 and 3 go through ``makeEntPositMat``.

    Items 1-3 each become a ``td.Vector(10)``; judging by the original names
    they are context and entity-1 / entity-2 position features.

    The assembled inputs feed ``tree_lstm``. Its output is sent to
    ``output_layer`` (logits) and also passed through unchanged (state).
    """
    unknown_idx = len(word_idx)

    def lookup_word(word):
        # Out-of-vocabulary words map to unknown_idx. Without this default,
        # dict.get returns None, which td.Scalar('int32') cannot convert.
        return word_idx.get(word, unknown_idx)

    def feature_vector(index, transform):
        # Builds a fresh pipeline on every call. The word case and the
        # children case each get their own instance, as the original code
        # did by writing each pipeline out twice.
        return td.GetItem(index) >> td.InputTransform(transform) >> td.Vector(10)

    # input[0][0] is the leaf's word.
    # <td.Pipe>: None -> TensorType((200,), 'float32')
    word2vec = (td.GetItem(0) >> td.GetItem(0)
                >> td.InputTransform(lookup_word)
                >> td.Scalar('int32')
                >> word_embedding)
    context2vec1 = feature_vector(1, makeContextMat)
    context2vec2 = feature_vector(1, makeContextMat)
    ent1posit1 = feature_vector(2, makeEntPositMat)
    ent1posit2 = feature_vector(2, makeEntPositMat)
    ent2posit1 = feature_vector(3, makeEntPositMat)
    ent2posit2 = feature_vector(3, makeEntPositMat)
    pairs2vec = td.GetItem(0) >> (embed_subtree(), embed_subtree())
    # The tree is binary, so the tree LSTM takes two child states; a leaf
    # supplies two zero states.
    zero_state = td.Zeros((tree_lstm.state_size, ) * 2)
    # Interior nodes have no word, so their word input is a zero vector
    # (word_embedding.output_type.shape[0] == 200 per the original note).
    zero_inp = td.Zeros(word_embedding.output_type.shape[0])
    word_case = td.AllOf(word2vec, zero_state, context2vec1, ent1posit1,
                         ent2posit1)
    children_case = td.AllOf(zero_inp, pairs2vec, context2vec2, ent1posit2,
                             ent2posit2)
    # Leaf (tree of length 1) -> word case; anything else -> children case.
    tree2vec = td.OneOf(lambda x: 1 if len(x[0]) == 1 else 2,
                        [(1, word_case), (2, children_case)])
    # Logits and tree-LSTM output.
    return tree2vec >> tree_lstm >> (output_layer, td.Identity())
def build_program_decoder(token_emb_size, rnn_cell, just_tokens=False):
    """Build a blind ("look-behind") program decoder block.

    The block runs ``rnn_cell`` (wrapped in a ``td.ScopedLayer`` named
    ``'decoder'``) as a ``td.RNN`` whose initial state comes from the input.
    A fully connected ReLU layer is then mapped over each step's output.

    Args:
        token_emb_size: Number of output units of the per-step FC layer.
        rnn_cell: The recurrent cell driving the decoder.
        just_tokens: If true, emit only the per-step FC outputs.

    Returns:
        With ``just_tokens`` true, a block producing the per-step FC outputs.
        Otherwise, a block producing
        ``(per-step FC outputs, per-step RNN outputs)``.
    """
    scoped_cell = td.ScopedLayer(rnn_cell, 'decoder')
    # Keep item 0 of the RNN result (the output sequence) and drop the
    # final state.
    step_outputs = (td.RNN(scoped_cell, initial_state_from_input=True)
                    >> td.GetItem(0))
    # NOTE(review): named 'encoder_fc' although it belongs to the decoder;
    # presumably kept so TF variable names stay stable.
    projection = td.FC(
        token_emb_size,
        activation=tf.nn.relu,
        initializer=tf.contrib.layers.xavier_initializer(),
        name='encoder_fc',
    )
    if not just_tokens:
        return step_outputs >> td.AllOf(td.Map(projection), td.Identity())
    return step_outputs >> td.Map(projection)
def logits_and_state():
    """Build the block that turns a tree into a ``(logits, state)`` pair.

    A tree of length 1 is a leaf whose item 0 is a word. A tree of length 2
    holds two subtrees, each embedded with ``embed_subtree()``.

    * A leaf feeds ``tree_lstm`` its word embedding plus two zero child
      states.
    * An inner node feeds a zero word vector plus its two child embeddings.

    The ``tree_lstm`` output goes to ``output_layer`` and is also passed
    through unchanged.
    """
    oov_id = len(word_idx)

    def to_id(token):
        # Words missing from the vocabulary all share the id len(word_idx).
        return word_idx.get(token, oov_id)

    leaf_vector = (td.GetItem(0)
                   >> td.InputTransform(to_id)
                   >> td.Scalar('int32')
                   >> word_embedding)
    child_vectors = (embed_subtree(), embed_subtree())

    # One zero state per child of a binary node.
    lstm_state_size = tree_lstm.state_size
    empty_children = td.Zeros((lstm_state_size, lstm_state_size))
    # Inner nodes carry no word, so they get an all-zero word vector.
    empty_word = td.Zeros(word_embedding.output_type.shape[0])

    leaf_inputs = td.AllOf(leaf_vector, empty_children)
    branch_inputs = td.AllOf(empty_word, child_vectors)
    node_inputs = td.OneOf(len, [(1, leaf_inputs), (2, branch_inputs)])

    heads = (output_layer, td.Identity())
    return node_inputs >> tree_lstm >> heads
def bidirectional_dynamic_FC(fw_cell, bw_cell, hidden):
    """Build a bidirectional RNN tagger as a ``td.Composition``.

    Input item 0 is run forward through ``td.RNN(fw_cell)``. A reversed copy
    is run through ``td.RNN(bw_cell)``, and that result is reversed back into
    left-to-right order. At each position the forward and backward outputs
    are concatenated and passed through a shared ``td.FC(1, activation=None)``
    layer. Each per-position result is recorded under the ``'logits'``
    metric.

    Args:
        fw_cell: Cell for the forward-direction ``td.RNN``.
        bw_cell: Cell for the backward-direction ``td.RNN``.
        hidden: NOTE(review): not used anywhere in this function.

    Returns:
        The ``td.Composition``. Its output is the sequence of per-position
        1-unit FC outputs.

    NOTE(review): the local names say "conv_lstm", but nothing here is
    convolutional.
    """
    bidir_conv_lstm = td.Composition()
    with bidir_conv_lstm.scope():
        # Assumes the input is (sequence, labels) -- TODO confirm with caller.
        fw_seq = td.Identity().reads(bidir_conv_lstm.input[0])
        # Records each element of input item 1 under the 'labels' metric;
        # td.Void() discards the value. The name `labels` is never read
        # again. The block presumably exists only for its metric side effect
        # inside this Composition.
        labels = (
            td.GetItem(1) >> td.Map(td.Metric("labels")) >> td.Void()).reads(
                bidir_conv_lstm.input)
        # Reverse the sequence for the backward RNN.
        bw_seq = td.Slice(step=-1).reads(fw_seq)
        # GetItem(0) keeps the RNN output sequence and drops the final state.
        forward_dir = (td.RNN(fw_cell) >> td.GetItem(0)).reads(fw_seq)
        back_dir = (td.RNN(bw_cell) >> td.GetItem(0)).reads(bw_seq)
        # Restore left-to-right order so the positions line up with forward_dir.
        back_to_leftright = td.Slice(step=-1).reads(back_dir)
        # A single FC layer instance, shared by every position in the ZipWith.
        output_transform = td.FC(1, activation=None)
        bidir_common = (td.ZipWith(
            td.Concat() >> output_transform >> td.Metric('logits'))).reads(
                forward_dir, back_to_leftright)
        bidir_conv_lstm.output.reads(bidir_common)
    return bidir_conv_lstm