def logits_and_state(self):
    """Creates a block that goes from tokens to (logits, state) tuples."""
    # Out-of-vocabulary words all map to one dedicated trailing index.
    unknown_idx = len(self.vocab)

    def lookup_word(word):
        return self.vocab.get(word, unknown_idx)

    # (GetItem(key) >> block).eval(inp) => block.eval(inp[key])
    # InputTransform(fn): a plain Python function, lifted to a block.
    # Scalar: scalar int input feeding the embedding lookup.
    word2vec = (td.GetItem(0) >> td.InputTransform(lookup_word) >>
                td.Scalar('int32') >> self.word_embedding)

    # FIX: this definition was commented out in the original, which made the
    # later `pair_case = td.AllOf(zero_inp, pair2vec)` raise NameError.
    # Two fresh sub-blocks embed the left and right subtrees independently.
    pair2vec = (self.embed_subtree(), self.embed_subtree())

    # Trees are binary, so the tree layer takes two states as its
    # input_state.
    zero_state = td.Zeros((self.tree_lstm_cell.state_size,) * 2)
    # Input is a word vector.
    zero_inp = td.Zeros(self.word_embedding.output_type.shape[0])

    # AllOf(a, b, c).eval(inp) => (a.eval(inp), b.eval(inp), c.eval(inp))
    word_case = td.AllOf(word2vec, zero_state)
    pair_case = td.AllOf(zero_inp, pair2vec)

    # OneOf(fn, [(key, block), ...]): fn(input) selects which block runs.
    # A length-1 input is a leaf word; length-2 is a pair of subtrees.
    tree2vec = td.OneOf(len, [(1, word_case), (2, pair_case)])
    return tree2vec >> self.tree_lstm_cell
def logits_and_state():
    """Creates a block that goes from tokens to (logits, state) tuples."""
    # Out-of-vocabulary words all map to one dedicated trailing index.
    unknown_idx = len(word_idx)

    # FIX: the original called word_idx.get(word) with no default, so unknown
    # words produced None (and the adjacent comment claiming unknown_idx was
    # the default was wrong). Pass unknown_idx as the default explicitly.
    lookup_word = lambda word: word_idx.get(word, unknown_idx)

    # TF Fold blocks are single-use: each wiring point needs a fresh block
    # instance, so these factories build a new pipeline on every call.
    def context_block():
        # Item 1 of the input -> context feature vector of length 10.
        return (td.GetItem(1) >> td.InputTransform(makeContextMat) >>
                td.Vector(10))

    def ent_posit_block(item):
        # Item `item` of the input -> entity-position vector of length 10.
        return (td.GetItem(item) >> td.InputTransform(makeEntPositMat) >>
                td.Vector(10))

    # Leaf word: unwrap the token, look up its index, embed it.
    word2vec = (
        td.GetItem(0) >> td.GetItem(0) >> td.InputTransform(lookup_word) >>
        td.Scalar('int32') >> word_embedding
    )  # <td.Pipe>: None -> TensorType((200,), 'float32')

    # Separate instances for the leaf case and the children case.
    context2vec1 = context_block()
    context2vec2 = context_block()
    ent1posit1 = ent_posit_block(2)
    ent1posit2 = ent_posit_block(2)
    ent2posit1 = ent_posit_block(3)
    ent2posit2 = ent_posit_block(3)

    # Internal node: recursively embed the two child subtrees.
    pairs2vec = td.GetItem(0) >> (embed_subtree(), embed_subtree())

    # Our binary tree can have two child nodes, therefore the zero state
    # carries two child states.
    zero_state = td.Zeros((tree_lstm.state_size,) * 2)
    # Input is a word vector (word_embedding.output_type.shape[0] == 200).
    zero_inp = td.Zeros(word_embedding.output_type.shape[0])

    word_case = td.AllOf(word2vec, zero_state, context2vec1, ent1posit1,
                         ent2posit1)
    children_case = td.AllOf(zero_inp, pairs2vec, context2vec2, ent1posit2,
                             ent2posit2)

    # A single-token first item is a leaf word; otherwise recurse into the
    # children.
    tree2vec = td.OneOf(lambda x: 1 if len(x[0]) == 1 else 2,
                        [(1, word_case), (2, children_case)])

    # Logits and LSTM states.
    return tree2vec >> tree_lstm >> (output_layer, td.Identity())
def logits_and_state():
    """Creates a block that goes from tokens to (logits, state) tuples."""
    # Every out-of-vocabulary token shares one trailing index.
    oov_index = len(word_idx)

    def to_index(token):
        return word_idx.get(token, oov_index)

    # Leaf case: the first element is the word; map it to its embedding.
    embed_word = (td.GetItem(0) >> td.InputTransform(to_index) >>
                  td.Scalar('int32') >> word_embedding)

    # Branch case: recursively embed both child subtrees.
    embed_pair = (embed_subtree(), embed_subtree())

    # Binary trees: the tree layer consumes two child states as input_state.
    empty_state = td.Zeros((tree_lstm.state_size,) * 2)
    # Zero stand-in where no word vector exists.
    empty_input = td.Zeros(word_embedding.output_type.shape[0])

    leaf_case = td.AllOf(embed_word, empty_state)
    branch_case = td.AllOf(empty_input, embed_pair)

    # Dispatch on length: 1 item => leaf word, 2 items => pair of subtrees.
    encode_tree = td.OneOf(len, [(1, leaf_case), (2, branch_case)])
    return encode_tree >> tree_lstm >> (output_layer, td.Identity())