import sys

import dynet as dy

# MLP, ParserState, and the sample() helper are defined elsewhere in this
# project; the module names below are assumptions and may need adjusting to
# the actual file layout.
# from mlp import MLP
# from parser_state import ParserState
# from sample import sample


class TopDownDepLM:
    """Top-down dependency language model over a stack of open constituents."""

    def __init__(self, pc, vocab, layers, state_dim, final_hidden_dim, tied,
                 residual):
        self.vocab = vocab
        self.layers = layers
        self.state_dim = state_dim
        self.tied = tied
        self.residual = residual
        self.done_with_left = vocab.convert('</LEFT>')
        self.done_with_right = vocab.convert('</RIGHT>')
        vocab_size = len(self.vocab)

        self.pc = pc.add_subcollection()
        if not self.tied:
            self.word_embs = self.pc.add_lookup_parameters(
                (vocab_size, state_dim))
        # LSTM over the stack of open constituents, and an LSTM over the
        # contents of each individual open constituent.
        self.top_lstm = dy.LSTMBuilder(layers, state_dim, state_dim, self.pc)
        self.vertical_lstm = dy.LSTMBuilder(layers, state_dim, state_dim,
                                            self.pc)
        self.gate_mlp = MLP(self.pc, [2 * state_dim, state_dim, state_dim])
        self.open_constit_lstms = []
        self.debug_stack = []
        self.spine = []
        self.final_mlp = MLP(self.pc,
                             [state_dim, final_hidden_dim, vocab_size])
        self.top_initial_state = [
            self.pc.add_parameters((state_dim, )) for _ in range(2 * layers)
        ]
        self.open_initial_state = [
            self.pc.add_parameters((state_dim, )) for _ in range(2 * layers)
        ]

    def set_dropout(self, r):
        self.dropout_rate = r
        self.top_lstm.set_dropout(r)
        self.vertical_lstm.set_dropout(r)
        self.final_mlp.set_dropout(r)

    def new_graph(self):
        # Do LSTM builders need reset?
        self.final_mlp.new_graph()
        self.gate_mlp.new_graph()

    def embed_word(self, word):
        if self.tied:
            # Reuse rows of the final softmax matrix as input embeddings.
            word_embs = self.final_mlp.layers[-1].w
            word_emb = dy.select_rows(word_embs, [word])
            word_emb = dy.transpose(word_emb)
        else:
            word_emb = dy.lookup(self.word_embs, word)
        return word_emb

    def add_to_last(self, word):
        # Append a word (e.g. the </LEFT> marker) to the topmost open constituent.
        assert len(self.open_constit_lstms) > 0
        word_emb = self.embed_word(word)
        new_rep = self.open_constit_lstms[-1].add_input(word_emb)
        self.open_constit_lstms[-1] = new_rep
        self.debug_stack[-1].append(self.vocab.to_word(word))

    def pop_and_add(self, word):
        # Close the topmost constituent and attach it to its parent.
        assert len(self.open_constit_lstms) >= 1
        word_emb = self.embed_word(word)
        child_state = self.open_constit_lstms[-1].add_input(word_emb)
        child_emb = child_state.output()
        self.open_constit_lstms.pop()
        if len(self.open_constit_lstms) > 0:
            self.open_constit_lstms[-1] = self.open_constit_lstms[
                -1].add_input(child_emb)
        self.spine.pop()

        self.debug_stack[-1].append(self.vocab.to_word(word))
        debug_child = self.debug_stack.pop()
        if len(self.debug_stack) > 0:
            self.debug_stack[-1].append(debug_child)

    def push(self, word):
        # Open a new constituent headed by this word.
        word_emb = self.embed_word(word)
        new_state = self.vertical_lstm.initial_state()
        new_state = new_state.set_s(self.open_initial_state)
        new_state = new_state.add_input(word_emb)
        self.open_constit_lstms.append(new_state)
        self.spine.append(word)
        self.debug_stack.append([self.vocab.to_word(word)])

    def add_input(self, state, word):
        if word == self.done_with_left:
            self.add_to_last(word)
        elif word == self.done_with_right:
            self.pop_and_add(word)
        else:
            self.push(word)
        #print('After:', self.debug_stack)
        assert len(self.debug_stack) == len(self.open_constit_lstms)
        return ParserState(self.open_constit_lstms, self.spine)

    def new_sent(self):
        new_state = self.vertical_lstm.initial_state()
        new_state = new_state.set_s(self.open_initial_state)
        self.open_constit_lstms = [new_state]
        self.spine = [-1]
        self.debug_stack = [[]]
        return ParserState(self.open_constit_lstms, self.spine)

    def debug_embed_vertical(self, vertical):
        # Re-embed one open constituent from its word list (debugging aid).
        state = self.vertical_lstm.initial_state()
        state = state.set_s(self.open_initial_state)
        for word in vertical:
            if type(word) == list:
                emb = self.debug_embed_vertical(word)
            else:
                emb = self.embed_word(self.vocab.convert(word))
            state = state.add_input(emb)
        return state.output()

    def debug_embed(self):
        # Re-embed the whole stack from debug_stack (debugging aid).
        top_state = self.top_lstm.initial_state()
        top_state = top_state.set_s(self.top_initial_state)
        assert len(self.open_constit_lstms) == len(self.debug_stack)
        for i, open_constit in enumerate(self.debug_stack):
            emb = self.debug_embed_vertical(open_constit)
            top_state = top_state.add_input(emb)
            alt = self.open_constit_lstms[i]
            #c = 'O' if np.isclose(emb.npvalue(), alt.output().npvalue()).all() else 'X'
            #print(c, emb.npvalue(), alt.output().npvalue())
            #assert np.isclose(emb.npvalue(), alt.output().npvalue()).all()
        #print()
        return top_state

    warned = False

    def compute_loss(self, state, word):
        top_state = self.top_lstm.initial_state()
        top_state = top_state.set_s(self.top_initial_state)
        assert len(state.open_constits) == len(state.spine)
        for open_constit, spine_word in zip(state.open_constits, state.spine):
            constit_emb = open_constit.output()
            if self.residual and spine_word != -1:
                # Gated residual connection between the constituent summary
                # and the embedding of its head (spine) word. A plain additive
                # residual (constit_emb += spine_word_emb) was disabled here.
                spine_word_emb = self.embed_word(spine_word)
                inp = dy.concatenate([constit_emb, spine_word_emb])
                mask = self.gate_mlp(inp)
                mask = dy.logistic(mask)
                constit_emb = dy.cmult(1 - mask, constit_emb)
                constit_emb = constit_emb + dy.cmult(mask, spine_word_emb)
            top_state = top_state.add_input(constit_emb)

        #debug_top_state = self.debug_embed()
        #assert np.isclose(top_state.output().npvalue(), debug_top_state.output().npvalue()).all()

        logits = self.final_mlp(top_state.output())
        loss = dy.pickneglogsoftmax(logits, word)
        #if not self.warned:
        #    sys.stderr.write('WARNING: compute_loss hacked to not include actual terminals.\n')
        #    self.warned = True
        #if word != 0 and word != 1:
        #    probs = -dy.softmax(logits)
        #    left_prob = dy.pick(probs, 0)
        #    right_prob = dy.pick(probs, 1)
        #    loss = dy.log(1 - left_prob - right_prob)
        #else:
        #    loss = dy.pickneglogsoftmax(logits, word)
        return loss

    def build_graph(self, sent):
        state = self.new_sent()
        losses = []
        for word in sent:
            loss = self.compute_loss(state, word)
            losses.append(loss)
            state = self.add_input(state, word)
        return dy.esum(losses)
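

# --- Usage sketch for TopDownDepLM (illustrative only) -----------------------
# A minimal, hedged example of training the model on one sentence. The toy
# vocabulary class and the transition encoding below are assumptions: the
# model only requires that '</LEFT>' and '</RIGHT>' have ids in the vocabulary
# and that every other id opens a new constituent; here each head is assumed
# to be followed by its left subtrees, '</LEFT>', its right subtrees, '</RIGHT>'.


class _ToyVocab:
    """Hypothetical stand-in for the project's real vocabulary class."""

    def __init__(self, words):
        self._i2w = list(words)
        self._w2i = {w: i for i, w in enumerate(self._i2w)}

    def convert(self, word):
        return self._w2i[word]

    def to_word(self, index):
        return self._i2w[index]

    def __len__(self):
        return len(self._i2w)


def _train_top_down_example():
    vocab = _ToyVocab(['</LEFT>', '</RIGHT>', 'the', 'dog', 'barks'])
    pc = dy.ParameterCollection()
    lm = TopDownDepLM(pc, vocab, layers=1, state_dim=32, final_hidden_dim=32,
                      tied=False, residual=False)
    trainer = dy.SimpleSGDTrainer(pc)
    # "barks" heads "dog", which heads "the"; each constituent is closed by
    # the two marker symbols once its children have been generated.
    sent = [vocab.convert(w) for w in
            ['barks', 'dog', 'the', '</LEFT>', '</RIGHT>',
             '</LEFT>', '</RIGHT>', '</LEFT>', '</RIGHT>']]
    for _ in range(3):
        dy.renew_cg()
        lm.new_graph()
        loss = lm.build_graph(sent)
        loss.value()      # forward pass
        loss.backward()   # backward pass
        trainer.update()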
class RNNLM:
    """Baseline left-to-right RNN language model."""

    def __init__(self, pc, layers, emb_dim, hidden_dim, vocab_size, tied):
        # spec must contain every constructor argument so that from_spec()
        # can rebuild an identical model.
        self.spec = (layers, emb_dim, hidden_dim, vocab_size, tied)
        self.pc = pc.add_subcollection()
        self.rnn = dy.LSTMBuilder(layers, emb_dim, hidden_dim, self.pc)
        self.initial_state_params = [
            self.pc.add_parameters((hidden_dim, )) for _ in range(2 * layers)
        ]
        self.output_mlp = MLP(self.pc, [hidden_dim, hidden_dim, vocab_size])
        self.tied = tied
        if not self.tied:
            self.word_embs = self.pc.add_lookup_parameters(
                (vocab_size, emb_dim))
        self.dropout_rate = 0.0

    def new_graph(self):
        self.output_mlp.new_graph()
        self.initial_state = [
            dy.parameter(p) for p in self.initial_state_params
        ]
        #self.exp = dy.scalarInput(-0.5)

    def set_dropout(self, r):
        self.dropout_rate = r
        self.output_mlp.set_dropout(r)
        self.rnn.set_dropout(r)

    def embed_word(self, word):
        if self.tied:
            # Reuse rows of the output softmax matrix as input embeddings.
            word_embs = self.output_mlp.layers[-1].w
            word_emb = dy.select_rows(word_embs, [word])
            word_emb = dy.transpose(word_emb)
        else:
            word_emb = dy.lookup(self.word_embs, word)
        # Normalize word vectors to have length one
        #word_emb_norm = dy.pow(dy.dot_product(word_emb, word_emb), self.exp)
        #word_emb = word_emb * word_emb_norm
        return word_emb

    def build_graph(self, sent):
        state = self.rnn.initial_state()
        state = state.set_s(self.initial_state)
        losses = []
        for word in sent:
            assert state is not None
            so = state.output()
            assert so is not None
            output_dist = self.output_mlp(so)
            loss = dy.pickneglogsoftmax(output_dist, word)
            losses.append(loss)
            word_emb = self.embed_word(word)
            if self.dropout_rate > 0.0:
                word_emb = dy.dropout(word_emb, self.dropout_rate)
            state = state.add_input(word_emb)
        return dy.esum(losses)

    def sample(self, eos, max_len):
        #dy.renew_cg()
        #self.new_graph()
        state = self.rnn.initial_state()
        state = state.set_s(self.initial_state)
        sent = []
        while len(sent) < max_len:
            assert state is not None
            so = state.output()
            assert so is not None
            output_dist = dy.softmax(self.output_mlp(so))
            output_dist = output_dist.vec_value()
            word = sample(output_dist)
            sent.append(word)
            if word == eos:
                break
            word_emb = self.embed_word(word)
            state = state.add_input(word_emb)
        return sent

    def param_collection(self):
        return self.pc

    @staticmethod
    def from_spec(spec, pc):
        rnnlm = RNNLM(pc, *spec)
        return rnnlm
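

# --- Usage sketch for RNNLM (illustrative only) -------------------------------
# A minimal, hedged example of fitting the baseline RNN language model on a
# toy integer-encoded corpus and then sampling from it. The corpus, the eos
# id, and the hyperparameters are placeholders; sampling relies on the
# project's sample() helper used inside RNNLM.sample().


def _train_rnnlm_example():
    eos = 0
    corpus = [[2, 3, 4, eos], [4, 3, 2, eos]]
    pc = dy.ParameterCollection()
    lm = RNNLM(pc, layers=1, emb_dim=32, hidden_dim=64, vocab_size=5,
               tied=False)
    trainer = dy.SimpleSGDTrainer(pc)
    lm.set_dropout(0.3)
    for _ in range(3):
        for sent in corpus:
            dy.renew_cg()
            lm.new_graph()
            loss = lm.build_graph(sent)
            loss.value()      # forward pass
            loss.backward()   # backward pass
            trainer.update()
    # Sampling: renew the graph, rebuild per-graph expressions, disable dropout.
    lm.set_dropout(0.0)
    dy.renew_cg()
    lm.new_graph()
    return lm.sample(eos=eos, max_len=10)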
class BottomUpDepLM:
    """Bottom-up (shift-reduce) dependency language model with a stack LSTM."""

    def __init__(self,
                 pc,
                 action_vocab,
                 word_vocab_size,
                 rel_vocab_size,
                 layers,
                 hidden_dim,
                 labelled=True,
                 tied=False):
        self.labelled = labelled
        self.tied = tied
        self.action_vocab = action_vocab
        self.pc = pc.add_subcollection()
        action_vocab_size = len(action_vocab)

        if not self.tied:
            self.word_embs = self.pc.add_lookup_parameters(
                (word_vocab_size, hidden_dim))
        self.action_mlp = MLP(self.pc,
                              [hidden_dim, hidden_dim, action_vocab_size])
        self.word_mlp = MLP(self.pc,
                            [hidden_dim, hidden_dim, word_vocab_size])
        self.combine_mlp = MLP(self.pc,
                               [2 * hidden_dim, hidden_dim, hidden_dim])
        self.stack_lstm = dy.LSTMBuilder(layers, hidden_dim, hidden_dim,
                                         self.pc)
        self.initial_state_params = [
            self.pc.add_parameters((hidden_dim, )) for _ in range(2 * layers)
        ]
        self.stack_embs = []

        if labelled:
            self.rel_embs = self.pc.add_lookup_parameters(
                (rel_vocab_size, hidden_dim))
            self.rel_mlp = MLP(self.pc,
                               [hidden_dim, hidden_dim, rel_vocab_size])

    def new_graph(self):
        self.action_mlp.new_graph()
        self.word_mlp.new_graph()
        self.combine_mlp.new_graph()
        if self.labelled:
            self.rel_mlp.new_graph()
        self.initial_state = [
            dy.parameter(p) for p in self.initial_state_params
        ]

    def new_sent(self):
        self.stack_embs = []
        self.stack = []
        state = self.stack_lstm.initial_state()
        state = state.set_s(self.initial_state)
        self.stack_embs.append(state)

    def set_dropout(self, r):
        self.action_mlp.set_dropout(r)
        self.word_mlp.set_dropout(r)
        self.combine_mlp.set_dropout(r)
        self.stack_lstm.set_dropout(r)
        if self.labelled:
            self.rel_mlp.set_dropout(r)

    def combine(self, head, child, direction):
        # Compose a head and its child into a single stack representation.
        # Note: the direction argument is currently unused.
        head_and_child = dy.concatenate([head, child])
        return self.combine_mlp(head_and_child)

    def embed_word(self, word):
        if self.tied:
            word_embs = self.word_mlp.layers[-1].w
            word_emb = dy.select_rows(word_embs, [word])
            word_emb = dy.transpose(word_emb)
        else:
            word_emb = dy.lookup(self.word_embs, word)
        return word_emb

    def embed_stack_naive(self):
        # Recompute the stack LSTM from scratch (reference implementation).
        state = self.stack_lstm.initial_state()
        state = state.set_s(self.initial_state)
        for item in self.stack:
            state = state.add_input(item)
        return state.output()

    def embed_stack(self):
        # Incremental version: the stack LSTM states are kept in stack_embs.
        return self.stack_embs[-1].output()

    def pop(self):
        self.stack.pop()
        self.stack_embs.pop()

    def push(self, v):
        self.stack.append(v)
        state = self.stack_embs[-1]
        state = state.add_input(v)
        self.stack_embs.append(state)

    def shift(self, word):
        word_emb = self.embed_word(word)
        self.push(word_emb)

    def reduce_right(self):
        assert len(self.stack) >= 2
        head = self.stack[-1]
        child = self.stack[-2]
        self.pop()
        self.pop()
        combined = self.combine(head, child, 'right')
        self.push(combined)

    def reduce_left(self):
        assert len(self.stack) >= 2
        head = self.stack[-2]
        child = self.stack[-1]
        self.pop()
        self.pop()
        combined = self.combine(head, child, 'left')
        self.push(combined)

    warned = False

    def build_graph(self, sent):
        losses = []
        self.new_sent()
        for action, subtype in sent:
            action_str = self.action_vocab.to_word(action)

            # Predict the next action from the current stack summary.
            hidden_state = self.embed_stack()
            action_logits = self.action_mlp(hidden_state)
            action_nlp = dy.pickneglogsoftmax(action_logits, action)

            loss = action_nlp
            if action_str == 'shift':
                if not self.warned:
                    sys.stderr.write(
                        'WARNING: Hacked to not include terminal losses\n')
                    self.warned = True
                #word_logits = self.word_mlp(hidden_state)
                #word_nlp = dy.pickneglogsoftmax(word_logits, subtype)
                #loss += word_nlp
            elif self.labelled:
                # The relation loss is likewise currently disabled.
                rel_logits = self.rel_mlp(hidden_state)
                rel_nlp = dy.pickneglogsoftmax(rel_logits, subtype)
                #loss += rel_nlp
            losses.append(loss)

            # Execute the reference action.
            if action_str == 'shift':
                self.shift(subtype)
            elif action_str == 'right':
                self.reduce_right()
            elif action_str == 'left':
                self.reduce_left()
            else:
                assert False, 'Unknown action: %s' % action_str
        return dy.esum(losses)
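

# --- Usage sketch for BottomUpDepLM (illustrative only) ----------------------
# A minimal, hedged example of scoring one oracle action sequence.
# build_graph expects (action_id, subtype) pairs, where the action vocabulary
# maps ids to the strings 'shift', 'left', and 'right'; for 'shift' the
# subtype is a word id, otherwise a relation id. The tiny action vocabulary
# below is a hypothetical stand-in for the project's real one.


class _ToyActionVocab:
    """Hypothetical stand-in exposing convert(), to_word(), and __len__()."""

    def __init__(self, actions):
        self._i2w = list(actions)
        self._w2i = {a: i for i, a in enumerate(self._i2w)}

    def convert(self, action):
        return self._w2i[action]

    def to_word(self, index):
        return self._i2w[index]

    def __len__(self):
        return len(self._i2w)


def _train_bottom_up_example():
    action_vocab = _ToyActionVocab(['shift', 'left', 'right'])
    pc = dy.ParameterCollection()
    lm = BottomUpDepLM(pc, action_vocab, word_vocab_size=10, rel_vocab_size=4,
                       layers=1, hidden_dim=32, labelled=True)
    trainer = dy.SimpleSGDTrainer(pc)
    shift = action_vocab.convert('shift')
    left = action_vocab.convert('left')
    # Oracle for a two-word sentence: shift word 3, shift word 7, then a
    # 'left' reduction, which (per reduce_left above) makes the most recently
    # shifted word the child of the one beneath it, with relation id 1.
    oracle = [(shift, 3), (shift, 7), (left, 1)]
    dy.renew_cg()
    lm.new_graph()
    loss = lm.build_graph(oracle)
    loss.value()      # forward pass
    loss.backward()   # backward pass
    trainer.update()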