コード例 #1
0
ファイル: tree_lstm.py プロジェクト: gottalottarock/hlstm
    def logits_and_state(self):
        """Build the block that embeds one tree node and feeds it to the cell.

        The input is dispatched on its length. A length-1 input is handled as
        a word: its item 0 is looked up in ``self.vocab`` (falling back to
        ``len(self.vocab)`` when the word is missing) and embedded with
        ``self.word_embedding``. A length-2 input is handled as a pair of
        subtrees, each processed by a block from ``self.embed_subtree()``.

        Returns:
            The block ``tree2vec >> self.tree_lstm_cell``. The original
            docstring described the result as (logits, state) tuples;
            NOTE(review): that depends on what the cell emits -- confirm.
        """
        oov_index = len(self.vocab)

        # Word branch: item 0 -> vocabulary index -> int32 scalar -> embedding.
        token_index = td.GetItem(0) >> td.InputTransform(
            lambda token: self.vocab.get(token, oov_index))
        leaf_embedding = (
            token_index >> td.Scalar('int32') >> self.word_embedding)

        # Pair branch: one sub-block per child of the binary node.
        left_subtree = self.embed_subtree()
        right_subtree = self.embed_subtree()
        children = (left_subtree, right_subtree)

        # The tree layer takes two child states, so a word node (which has
        # no children) is given two all-zero states.
        cell_state_size = self.tree_lstm_cell.state_size
        no_children = td.Zeros((cell_state_size, cell_state_size))

        # A pair node has no word of its own, so its word-vector input is
        # all zeros, sized from the embedding's output.
        embedding_dim = self.word_embedding.output_type.shape[0]
        no_word = td.Zeros(embedding_dim)

        # AllOf feeds the same input to every listed block and collects the
        # results, giving (word vector, child states) for each branch.
        leaf_branch = td.AllOf(leaf_embedding, no_children)
        pair_branch = td.AllOf(no_word, children)

        # OneOf applies `len` to the input and runs the branch whose key
        # matches the result.
        branches = [(1, leaf_branch), (2, pair_branch)]
        node_embedding = td.OneOf(len, branches)

        return node_embedding >> self.tree_lstm_cell
コード例 #2
0
def logits_and_state():
    """Creates a block that goes from tokens to (logits, state) tuples.

    The input is indexed as a 4-item record: item 0 is the tree node, item 1
    is passed through ``makeContextMat``, and items 2 and 3 are each passed
    through ``makeEntPositMat``. The node is dispatched on ``len(x[0])``:
    length 1 selects the word case, anything else the children case.

    Returns:
        The block ``tree2vec >> tree_lstm >> (output_layer, td.Identity())``,
        i.e. the first element of the tree-LSTM output goes through
        ``output_layer`` and the second is passed on unchanged.
    """
    # Index reserved for out-of-vocabulary words.
    # NOTE(review): assumes `word_embedding` has a row at this index, as in
    # the sibling implementations of this function -- confirm.
    unknown_idx = len(word_idx)

    def lookup_word(word):
        # Fall back to `unknown_idx` for unseen words. A bare
        # `word_idx.get(word)` would yield None here, which cannot be fed
        # to the int32 scalar below.
        return word_idx.get(word, unknown_idx)

    # Word is item 0 of the node, which is itself item 0 of the input.
    # Original author's note: <td.Pipe>: None -> TensorType((200,), 'float32')
    word2vec = (
        td.GetItem(0) >> td.GetItem(0) >> td.InputTransform(lookup_word) >>
        td.Scalar('int32') >> word_embedding
    )

    # Each feature pipeline is built twice, once per OneOf case.
    # NOTE(review): presumably because a block instance cannot be shared
    # between the two cases -- confirm before deduplicating.
    context2vec1 = td.GetItem(1) >> td.InputTransform(
        makeContextMat) >> td.Vector(10)
    context2vec2 = td.GetItem(1) >> td.InputTransform(
        makeContextMat) >> td.Vector(10)
    ent1posit1 = td.GetItem(2) >> td.InputTransform(
        makeEntPositMat) >> td.Vector(10)
    ent1posit2 = td.GetItem(2) >> td.InputTransform(
        makeEntPositMat) >> td.Vector(10)
    ent2posit1 = td.GetItem(3) >> td.InputTransform(
        makeEntPositMat) >> td.Vector(10)
    ent2posit2 = td.GetItem(3) >> td.InputTransform(
        makeEntPositMat) >> td.Vector(10)

    pairs2vec = td.GetItem(0) >> (embed_subtree(), embed_subtree())

    # The tree is binary, so the zero state stands in for two child states.
    zero_state = td.Zeros((tree_lstm.state_size, ) * 2)
    # Zero word vector for nodes with children; its size is taken from the
    # embedding's output (200 in the original author's setup).
    zero_inp = td.Zeros(word_embedding.output_type.shape[0])

    word_case = td.AllOf(word2vec, zero_state, context2vec1, ent1posit1,
                         ent2posit1)
    children_case = td.AllOf(zero_inp, pairs2vec, context2vec2, ent1posit2,
                             ent2posit2)

    def node_kind(x):
        # Key 1 = word case (node of length 1); key 2 = children case.
        return 1 if len(x[0]) == 1 else 2

    tree2vec = td.OneOf(node_kind, [(1, word_case), (2, children_case)])

    # Logits and LSTM state.
    return tree2vec >> tree_lstm >> (output_layer, td.Identity())
コード例 #3
0
    def logits_and_state():
        """Build the block that turns a tree node into (logits, state).

        A length-1 input is handled as a word (item 0 is looked up in
        ``word_idx``, using ``len(word_idx)`` for unknown words); a length-2
        input is handled as a pair of subtrees. The selected branch is fed to
        ``tree_lstm``, whose first output goes through ``output_layer`` while
        the second is passed on unchanged.
        """
        oov_index = len(word_idx)

        def to_index(token):
            # Unknown words share the single index just past the vocabulary.
            return word_idx.get(token, oov_index)

        # Word branch: item 0 -> vocabulary index -> int32 scalar -> embedding.
        token_index = td.GetItem(0) >> td.InputTransform(to_index)
        leaf_embedding = token_index >> td.Scalar('int32') >> word_embedding

        # Pair branch: one sub-block per child of the binary node.
        first_child = embed_subtree()
        second_child = embed_subtree()
        children = (first_child, second_child)

        # The tree layer takes two child states; a word node gets two
        # all-zero states instead.
        state_size = tree_lstm.state_size
        empty_states = td.Zeros((state_size, state_size))

        # A pair node has no word, so its word-vector input is all zeros.
        embedding_dim = word_embedding.output_type.shape[0]
        empty_word = td.Zeros(embedding_dim)

        leaf_branch = td.AllOf(leaf_embedding, empty_states)
        pair_branch = td.AllOf(empty_word, children)

        # Dispatch on the input's length: 1 -> word, 2 -> pair.
        node_embedding = td.OneOf(len, [(1, leaf_branch), (2, pair_branch)])

        heads = (output_layer, td.Identity())
        return node_embedding >> tree_lstm >> heads