Example #1
import dynet as dy

# MLP and ParserState are helper classes assumed to be defined elsewhere
# in this codebase.
class TopDownDepLM:
    def __init__(self, pc, vocab, layers, state_dim, final_hidden_dim, tied,
                 residual):
        self.vocab = vocab
        self.layers = layers
        self.state_dim = state_dim
        self.tied = tied
        self.residual = residual
        self.done_with_left = vocab.convert('</LEFT>')
        self.done_with_right = vocab.convert('</RIGHT>')
        vocab_size = len(self.vocab)

        self.pc = pc.add_subcollection()
        if not self.tied:
            self.word_embs = self.pc.add_lookup_parameters(
                (vocab_size, state_dim))

        self.top_lstm = dy.LSTMBuilder(layers, state_dim, state_dim, self.pc)
        self.vertical_lstm = dy.LSTMBuilder(layers, state_dim, state_dim,
                                            self.pc)
        self.gate_mlp = MLP(self.pc, [2 * state_dim, state_dim, state_dim])
        self.open_constit_lstms = []
        self.debug_stack = []
        self.spine = []
        self.final_mlp = MLP(self.pc,
                             [state_dim, final_hidden_dim, vocab_size])

        self.top_initial_state = [
            self.pc.add_parameters((state_dim, )) for _ in range(2 * layers)
        ]
        self.open_initial_state = [
            self.pc.add_parameters((state_dim, )) for _ in range(2 * layers)
        ]

    def set_dropout(self, r):
        self.dropout_rate = r
        self.top_lstm.set_dropout(r)
        self.vertical_lstm.set_dropout(r)
        self.final_mlp.set_dropout(r)

    def new_graph(self):
        # Do LSTM builders need reset?
        self.final_mlp.new_graph()
        self.gate_mlp.new_graph()

    def embed_word(self, word):
        if self.tied:
            # Tied embeddings: reuse rows of the final softmax weight matrix
            # as the input word embeddings.
            word_embs = self.final_mlp.layers[-1].w
            word_emb = dy.select_rows(word_embs, [word])
            word_emb = dy.transpose(word_emb)
        else:
            word_emb = dy.lookup(self.word_embs, word)
        return word_emb

    def add_to_last(self, word):
        assert len(self.open_constit_lstms) > 0
        word_emb = self.embed_word(word)
        new_rep = self.open_constit_lstms[-1].add_input(word_emb)
        self.open_constit_lstms[-1] = new_rep

        self.debug_stack[-1].append(self.vocab.to_word(word))

    def pop_and_add(self, word):
        assert len(self.open_constit_lstms) >= 1
        word_emb = self.embed_word(word)
        child_state = self.open_constit_lstms[-1].add_input(word_emb)
        child_emb = child_state.output()
        self.open_constit_lstms.pop()
        if len(self.open_constit_lstms) > 0:
            self.open_constit_lstms[-1] = self.open_constit_lstms[
                -1].add_input(child_emb)
        self.spine.pop()

        self.debug_stack[-1].append(self.vocab.to_word(word))
        debug_child = self.debug_stack.pop()
        if len(self.debug_stack) > 0:
            self.debug_stack[-1].append(debug_child)

    def push(self, word):
        word_emb = self.embed_word(word)

        new_state = self.vertical_lstm.initial_state()
        new_state = new_state.set_s(self.open_initial_state)
        new_state = new_state.add_input(word_emb)
        self.open_constit_lstms.append(new_state)
        self.spine.append(word)

        self.debug_stack.append([self.vocab.to_word(word)])

    def add_input(self, state, word):
        if word == self.done_with_left:
            self.add_to_last(word)
        elif word == self.done_with_right:
            self.pop_and_add(word)
        else:
            self.push(word)
        #print('After:', self.debug_stack)
        assert len(self.debug_stack) == len(self.open_constit_lstms)
        return ParserState(self.open_constit_lstms, self.spine)

    def new_sent(self):
        new_state = self.vertical_lstm.initial_state()
        new_state = new_state.set_s(self.open_initial_state)
        self.open_constit_lstms = [new_state]
        self.spine = [-1]
        self.debug_stack = [[]]
        return ParserState(self.open_constit_lstms, self.spine)

    def debug_embed_vertical(self, vertical):
        state = self.vertical_lstm.initial_state()
        state = state.set_s(self.open_initial_state)
        for word in vertical:
            if type(word) == list:
                emb = self.debug_embed_vertical(word)
            else:
                emb = self.embed_word(self.vocab.convert(word))
            state = state.add_input(emb)
        return state.output()

    def debug_embed(self):
        top_state = self.top_lstm.initial_state()
        top_state = top_state.set_s(self.top_initial_state)

        assert len(self.open_constit_lstms) == len(self.debug_stack)
        for i, open_constit in enumerate(self.debug_stack):
            emb = self.debug_embed_vertical(open_constit)
            top_state = top_state.add_input(emb)
            alt = self.open_constit_lstms[i]
            #c = 'O' if np.isclose(emb.npvalue(), alt.output().npvalue()).all() else 'X'
            #print(c, emb.npvalue(), alt.output().npvalue())
            #assert np.isclose(emb.npvalue(), alt.output().npvalue()).all()
        #print()
        return top_state

    warned = False

    def compute_loss(self, state, word):
        top_state = self.top_lstm.initial_state()
        top_state = top_state.set_s(self.top_initial_state)
        assert len(state.open_constits) == len(state.spine)
        for open_constit, spine_word in zip(state.open_constits, state.spine):
            constit_emb = open_constit.output()
            if self.residual and spine_word != -1:
                # Gated residual connection: mix the open constituent's
                # representation with its spine word embedding. (A plain
                # additive residual, constit_emb += spine_word_emb, was
                # disabled in favor of the gate.)
                spine_word_emb = self.embed_word(spine_word)
                inp = dy.concatenate([constit_emb, spine_word_emb])
                mask = dy.logistic(self.gate_mlp(inp))
                constit_emb = dy.cmult(1 - mask, constit_emb) \
                    + dy.cmult(mask, spine_word_emb)
            top_state = top_state.add_input(constit_emb)
        #debug_top_state = self.debug_embed()
        #assert np.isclose(top_state.output().npvalue(), debug_top_state.output().npvalue()).all()

        logits = self.final_mlp(top_state.output())
        loss = dy.pickneglogsoftmax(logits, word)

        #if not self.warned:
        #  sys.stderr.write('WARNING: compute_loss hacked to not include actual terminals.\n')
        #  self.warned = True
        #if word != 0 and word != 1:
        #  probs = -dy.softmax(logits)
        #  left_prob = dy.pick(probs, 0)
        #  right_prob = dy.pick(probs, 1)
        #  loss = dy.log(1 - left_prob - right_prob)
        #else:
        #  loss = dy.pickneglogsoftmax(logits, word)

        return loss

    def build_graph(self, sent):
        state = self.new_sent()

        losses = []
        for word in sent:
            loss = self.compute_loss(state, word)
            losses.append(loss)
            state = self.add_input(state, word)

        return dy.esum(losses)
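
A minimal training-loop sketch for TopDownDepLM, assuming DyNet, the MLP and
ParserState helpers from this codebase, a `vocab` object exposing convert(),
to_word(), and __len__ (including the '</LEFT>' and '</RIGHT>' symbols), and a
hypothetical `corpus` of sentences already converted to transition/word ids:

pc = dy.ParameterCollection()
lm = TopDownDepLM(pc, vocab, layers=2, state_dim=128,
                  final_hidden_dim=128, tied=False, residual=True)
trainer = dy.SimpleSGDTrainer(pc)
lm.set_dropout(0.3)
for sent in corpus:  # each sent: list of transition/word ids
    dy.renew_cg()
    lm.new_graph()
    loss = lm.build_graph(sent)
    loss.value()
    loss.backward()
    trainer.update()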
Example #2
import dynet as dy

# MLP and the module-level `sample` helper are assumed to be defined
# elsewhere in this codebase.
class RNNLM:
    def __init__(self, pc, layers, emb_dim, hidden_dim, vocab_size, tied):
        # Include `tied` so from_spec() can fully reconstruct the model.
        self.spec = (layers, emb_dim, hidden_dim, vocab_size, tied)
        self.pc = pc.add_subcollection()
        self.rnn = dy.LSTMBuilder(layers, emb_dim, hidden_dim, self.pc)
        self.initial_state_params = [
            self.pc.add_parameters((hidden_dim, )) for _ in range(2 * layers)
        ]
        self.output_mlp = MLP(self.pc, [hidden_dim, hidden_dim, vocab_size])
        self.tied = tied
        if not self.tied:
            self.word_embs = self.pc.add_lookup_parameters(
                (vocab_size, emb_dim))
        self.dropout_rate = 0.0

    def new_graph(self):
        self.output_mlp.new_graph()
        self.initial_state = [
            dy.parameter(p) for p in self.initial_state_params
        ]
        #self.exp = dy.scalarInput(-0.5)

    def set_dropout(self, r):
        self.dropout_rate = r
        self.output_mlp.set_dropout(r)
        self.rnn.set_dropout(r)

    def embed_word(self, word):
        if self.tied:
            word_embs = self.output_mlp.layers[-1].w
            word_emb = dy.select_rows(word_embs, [word])
            word_emb = dy.transpose(word_emb)
        else:
            word_emb = dy.lookup(self.word_embs, word)

        # Normalize word vectors to have length one
        #word_emb_norm = dy.pow(dy.dot_product(word_emb, word_emb), self.exp)
        #word_emb = word_emb * word_emb_norm
        return word_emb

    def build_graph(self, sent):
        state = self.rnn.initial_state()
        state = state.set_s(self.initial_state)

        losses = []
        for word in sent:
            assert state is not None
            so = state.output()
            assert so is not None
            output_dist = self.output_mlp(so)
            loss = dy.pickneglogsoftmax(output_dist, word)
            losses.append(loss)
            word_emb = self.embed_word(word)
            if self.dropout_rate > 0.0:
                word_emb = dy.dropout(word_emb, self.dropout_rate)

            state = state.add_input(word_emb)
        return dy.esum(losses)

    def sample(self, eos, max_len):
        #dy.renew_cg()
        #self.new_graph()
        state = self.rnn.initial_state()
        state = state.set_s(self.initial_state)
        sent = []
        while len(sent) < max_len:
            assert state is not None
            so = state.output()
            assert so is not None
            output_dist = dy.softmax(self.output_mlp(so))
            output_dist = output_dist.vec_value()
            # `sample` here is the module-level helper that draws an index
            # from the probability list.
            word = sample(output_dist)
            sent.append(word)
            if word == eos:
                break
            word_emb = self.embed_word(word)
            state = state.add_input(word_emb)
        return sent

    def param_collection(self):
        return self.pc

    @staticmethod
    def from_spec(spec, pc):
        rnnlm = RNNLM(pc, *spec)
        return rnnlm
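
A small driver for RNNLM, assuming DyNet 2.x and the MLP helper from this
codebase; the word ids and sizes below are toy values for illustration, and
sampling relies on the assumed module-level `sample` helper:

pc = dy.ParameterCollection()
lm = RNNLM(pc, layers=1, emb_dim=32, hidden_dim=32, vocab_size=100, tied=False)
trainer = dy.SimpleSGDTrainer(pc)
toy_corpus = [[5, 7, 9, 1], [4, 4, 2, 1]]  # toy word ids; 1 = end of sentence
for sent in toy_corpus:
    dy.renew_cg()
    lm.new_graph()
    loss = lm.build_graph(sent)
    loss.value()
    loss.backward()
    trainer.update()

# Sampling also needs a fresh computation graph.
dy.renew_cg()
lm.new_graph()
generated = lm.sample(eos=1, max_len=20)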
Example #3
import sys

import dynet as dy

# MLP is assumed to be defined elsewhere in this codebase.
class BottomUpDepLM:
    def __init__(self,
                 pc,
                 action_vocab,
                 word_vocab_size,
                 rel_vocab_size,
                 layers,
                 hidden_dim,
                 labelled=True,
                 tied=False):
        self.labelled = labelled
        self.tied = tied
        self.action_vocab = action_vocab
        self.pc = pc.add_subcollection()
        action_vocab_size = len(action_vocab)

        if not self.tied:
            self.word_embs = self.pc.add_lookup_parameters(
                (word_vocab_size, hidden_dim))
        self.action_mlp = MLP(self.pc,
                              [hidden_dim, hidden_dim, action_vocab_size])
        self.word_mlp = MLP(self.pc, [hidden_dim, hidden_dim, word_vocab_size])

        self.combine_mlp = MLP(self.pc,
                               [2 * hidden_dim, hidden_dim, hidden_dim])

        self.stack_lstm = dy.LSTMBuilder(layers, hidden_dim, hidden_dim,
                                         self.pc)
        self.initial_state_params = [
            self.pc.add_parameters((hidden_dim, )) for _ in range(2 * layers)
        ]
        self.stack_embs = []

        if labelled:
            self.rel_embs = self.pc.add_lookup_parameters(
                (rel_vocab_size, hidden_dim))
            self.rel_mlp = MLP(self.pc,
                               [hidden_dim, hidden_dim, rel_vocab_size])

    def new_graph(self):
        self.action_mlp.new_graph()
        self.word_mlp.new_graph()
        self.combine_mlp.new_graph()
        if self.labelled:
            self.rel_mlp.new_graph()
        self.initial_state = [
            dy.parameter(p) for p in self.initial_state_params
        ]

    def new_sent(self):
        self.stack_embs = []
        self.stack = []
        state = self.stack_lstm.initial_state()
        state = state.set_s(self.initial_state)
        self.stack_embs.append(state)

    def set_dropout(self, r):
        self.action_mlp.set_dropout(r)
        self.word_mlp.set_dropout(r)
        self.combine_mlp.set_dropout(r)
        self.stack_lstm.set_dropout(r)
        if self.labelled:
            self.rel_mlp.set_dropout(r)

    def combine(self, head, child, direction):
        # NOTE: `direction` is currently unused; left and right reductions
        # share the same combination MLP.
        head_and_child = dy.concatenate([head, child])
        return self.combine_mlp(head_and_child)

    def embed_word(self, word):
        if self.tied:
            word_embs = self.word_mlp.layers[-1].w
            word_emb = dy.select_rows(word_embs, [word])
            word_emb = dy.transpose(word_emb)
        else:
            word_emb = dy.lookup(self.word_embs, word)
        return word_emb

    def embed_stack_naive(self):
        state = self.stack_lstm.initial_state()
        state = state.set_s(self.initial_state)
        for item in self.stack:
            state = state.add_input(item)
        return state.output()

    def embed_stack(self):
        return self.stack_embs[-1].output()

    def pop(self):
        self.stack.pop()
        self.stack_embs.pop()

    def push(self, v):
        self.stack.append(v)
        state = self.stack_embs[-1]
        state = state.add_input(v)
        self.stack_embs.append(state)

    def shift(self, word):
        word_emb = self.embed_word(word)
        self.push(word_emb)

    def reduce_right(self):
        assert len(self.stack) >= 2
        head = self.stack[-1]
        child = self.stack[-2]
        self.pop()
        self.pop()
        combined = self.combine(head, child, 'right')
        self.push(combined)

    def reduce_left(self):
        assert len(self.stack) >= 2
        head = self.stack[-2]
        child = self.stack[-1]
        self.pop()
        self.pop()
        combined = self.combine(head, child, 'left')
        self.push(combined)

    warned = False

    def build_graph(self, sent):
        losses = []
        self.new_sent()
        for action, subtype in sent:
            action_str = self.action_vocab.to_word(action)

            # predict action
            hidden_state = self.embed_stack()
            action_logits = self.action_mlp(hidden_state)
            action_nlp = dy.pickneglogsoftmax(action_logits, action)

            loss = action_nlp
            if action_str == 'shift':
                if not self.warned:
                    sys.stderr.write(
                        'WARNING: Hacked to not include terminal losses\n')
                    self.warned = True
                #word_logits = self.word_mlp(hidden_state)
                #word_nlp = dy.pickneglogsoftmax(word_logits, subtype)
                #loss += word_nlp
            elif self.labelled:
                # The relation loss is computed but, like the terminal loss
                # above, not added to the total in this hacked version.
                rel_logits = self.rel_mlp(hidden_state)
                rel_nlp = dy.pickneglogsoftmax(rel_logits, subtype)
                #loss += rel_nlp
            losses.append(loss)

            # Do the reference action
            if action_str == 'shift':
                self.shift(subtype)
            elif action_str == 'right':
                self.reduce_right()
            elif action_str == 'left':
                self.reduce_left()
            else:
                assert False, 'Unknown action: %s' % action_str

        return dy.esum(losses)
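
A hypothetical driver for BottomUpDepLM, assuming an `action_vocab` with
'shift'/'left'/'right' entries and a to_word() method, and a `corpus` of
oracle transition sequences of (action_id, subtype_id) pairs; the sizes shown
are illustrative only:

pc = dy.ParameterCollection()
lm = BottomUpDepLM(pc, action_vocab, word_vocab_size=10000, rel_vocab_size=40,
                   layers=2, hidden_dim=128, labelled=True, tied=False)
trainer = dy.SimpleSGDTrainer(pc)
for oracle in corpus:  # oracle: list of (action_id, subtype_id) pairs
    dy.renew_cg()
    lm.new_graph()
    loss = lm.build_graph(oracle)
    loss.value()
    loss.backward()
    trainer.update()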