def update_batch(self, words_batch, tags_batch):
    dynet.renew_cg()
    length = max(len(words) for words in words_batch)
    # Build (time, batch) matrices of word and tag ids, defaulting to UNK.
    word_ids = np.zeros((length, len(words_batch)), dtype='int32')
    for j, words in enumerate(words_batch):
        for i, word in enumerate(words):
            word_ids[i, j] = self.vw.w2i.get(word, self.UNK)
    tag_ids = np.zeros((length, len(words_batch)), dtype='int32')
    for j, tags in enumerate(tags_batch):
        for i, tag in enumerate(tags):
            tag_ids[i, j] = self.vt.w2i.get(tag, self.UNK)
    # Batched embedding lookup per time step, with Gaussian noise as regularization.
    wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)]
    wembs = [dynet.noise(we, 0.1) for we in wembs]
    # Run the forward and backward LSTMs over the embedded sequence.
    f_state = self._fwd_lstm.initial_state()
    b_state = self._bwd_lstm.initial_state()
    fw = [x.output() for x in f_state.add_inputs(wembs)]
    bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]
    H = dynet.parameter(self._pH)
    O = dynet.parameter(self._pO)
    errs = []
    for i, (f, b) in enumerate(zip(fw, reversed(bw))):
        f_b = dynet.concatenate([f, b])
        r_t = O * dynet.tanh(H * f_b)
        err = dynet.pickneglogsoftmax_batch(r_t, tag_ids[i])
        errs.append(dynet.sum_batches(err))
    sum_errs = dynet.esum(errs)
    losses = sum_errs.scalar_value()
    sum_errs.backward()
    self._sgd.update()
    return losses

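# A minimal training-loop sketch (not from the original source): `tagger` is assumed to be
# an instance of the class that defines update_batch above, and `train_data` a list of
# (words, tags) pairs. It only illustrates how the batched update might be driven.
def train_epoch(tagger, train_data, batch_size=32):
    total_loss = 0.0
    for start in range(0, len(train_data), batch_size):
        batch = train_data[start:start + batch_size]
        words_batch = [words for words, tags in batch]
        tags_batch = [tags for words, tags in batch]
        total_loss += tagger.update_batch(words_batch, tags_batch)
    return total_loss
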
def decode(self, dec_lstm, vectors, output_batch):
    w = dy.parameter(decoder_w)
    b = dy.parameter(decoder_b)
    w1 = dy.parameter(attention_w1)

    # Encoder states as a matrix; precompute w1 * H once for the attention MLP.
    input_mat = dy.concatenate_cols(vectors)
    input_len = len(vectors)
    w1dt = w1 * input_mat

    curr_bsize = len(output_batch)
    s = dec_lstm.initial_state()
    c_t_previous = dy.vecInput(STATE_SIZE * 2)
    loss = []

    # Pad the target batch and build per-step 0/1 masks.
    output_batch, masks = self.create_sentence_batch(output_batch)
    for i, (word_batch, mask_word) in enumerate(zip(output_batch[1:], masks[1:]), start=1):
        # Teacher forcing: feed the previous gold word and the previous context vector.
        last_output_embeddings = dy.lookup_batch(output_lookup, output_batch[i - 1])
        vector = dy.concatenate([last_output_embeddings, c_t_previous])
        s = s.add_input(vector)
        h_t = s.output()
        c_t, alpha_t = self.attend(input_mat, s, w1dt, input_len, curr_bsize)
        out_vector = dy.affine_transform([b, w, dy.concatenate([h_t, c_t])])
        if DROPOUT > 0.0:
            out_vector = dy.dropout(out_vector, DROPOUT)
        loss_current = dy.pickneglogsoftmax_batch(out_vector, word_batch)
        # Zero out the loss for sentences that have already ended.
        if 0 in mask_word:
            mask_vals = dy.inputVector(mask_word)
            mask_vals = dy.reshape(mask_vals, (1,), curr_bsize)
            loss_current = loss_current * mask_vals
        loss.append(loss_current)
        c_t_previous = c_t

    # Average loss over the batch.
    loss = dy.sum_batches(dy.esum(loss)) / curr_bsize
    return loss

def decode_batch(dec_lstm, input_encodings, output_sentences):
    w = dy.parameter(decoder_w)
    b = dy.parameter(decoder_b)
    w1 = dy.parameter(attention_w1)

    # Pad the target sentences and build per-step 0/1 masks.
    output_words, masks = sentences_to_batch(output_sentences)
    decoder_target_input = zip(output_words[1:], masks[1:])

    batch_size = len(output_sentences)
    input_length = len(input_encodings)
    input_mat = dy.concatenate_cols(input_encodings)
    w1dt = w1 * input_mat

    s = dec_lstm.initial_state()
    c_t_minus_1 = dy.vecInput(state_size * 2)
    loss = []

    for t, (words_t, mask_t) in enumerate(decoder_target_input, start=1):
        # Teacher forcing: feed the previous gold word and the previous context vector.
        last_output_embeddings = dy.lookup_batch(output_lookup, output_words[t - 1])
        vector = dy.concatenate([last_output_embeddings, c_t_minus_1])
        s = s.add_input(vector)
        h_t = s.output()
        c_t = attend_batch(input_mat, s, w1dt, batch_size, input_length)
        predicted = dy.affine_transform([b, w, dy.concatenate([h_t, c_t])])
        if dropout > 0.:
            predicted = dy.dropout(predicted, dropout)
        cur_loss = dy.pickneglogsoftmax_batch(predicted, words_t)
        c_t_minus_1 = c_t
        # Mask the loss where mask == 0 (padded positions).
        if 0 in mask_t:
            mask = dy.inputVector(mask_t)
            mask = dy.reshape(mask, (1,), batch_size)
            cur_loss = cur_loss * mask
        loss.append(cur_loss)

    # Get the average batch loss.
    loss = dy.esum(loss)
    loss = dy.sum_batches(loss) / batch_size
    return loss

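# The padding helper sentences_to_batch is referenced above but not shown. A plausible
# sketch (an assumption, not the original helper), taking sentences as lists of integer
# word ids and padding with a hypothetical EOS id: it transposes the batch into per-step
# id lists plus parallel 0/1 masks (1 = real token, 0 = padding), matching how
# decode_batch indexes output_words[t] and masks[t].
def sentences_to_batch(sentences):
    max_len = max(len(s) for s in sentences)
    words, masks = [], []
    for t in range(max_len):
        words.append([s[t] if t < len(s) else EOS for s in sentences])
        masks.append([1 if t < len(s) else 0 for s in sentences])
    return words, masks
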
def build_sentence_graph(self, sents, labels):
    renew_cg()
    f_init = self.builder.initial_state()
    START = vocab.w2i["<start>"]
    W_exp = parameter(self.R)
    b_exp = parameter(self.bias)

    # Get the word ids and masks for each step; shorter sequences are
    # left-padded with START so that all sentences end together.
    tot_words = 0
    wids = []
    masks = []
    for i in range(len(sents[0])):
        wids.append([(START if len(sents[0]) - len(sent) > i
                      else vocab.w2i[sent[i - len(sents[0]) + len(sent)]])
                     for sent in sents])
        mask = [(1 if len(sent) > i else 0) for sent in sents]
        masks.append(mask)
        tot_words += sum(mask)

    # Start the RNN by inputting "<start>".
    init_ids = [START] * len(sents)
    s = f_init.add_input(dy.lookup_batch(self.lookup, init_ids))

    # Feed word vectors into the RNN and predict the labels at every step.
    losses = []
    for wid in wids:
        # Calculate the softmax loss against the sentence labels.
        score = W_exp * s.output() + b_exp
        # Update the state of the RNN.
        wemb = dy.lookup_batch(self.lookup, wid)
        s = s.add_input(wemb)
        loss = dy.pickneglogsoftmax_batch(score, labels)
        losses.append(loss)
    return dy.sum_batches(dy.esum(losses))

def myDecode(self, dec_lstm, h_encodings, target_sents):
    w = dy.parameter(decoder_w)
    b = dy.parameter(decoder_b)
    w1 = dy.parameter(attention_w1)

    # Pad the target batch and build per-step 0/1 masks.
    dec_wrds, dec_mask = self.pad_zero(target_sents)
    curr_bsize = len(target_sents)

    h_len = len(h_encodings)
    H_source = dy.concatenate_cols(h_encodings)
    s = dec_lstm.initial_state()
    ctx_t0 = dy.vecInput(hidden_size * 2)
    w1dt = w1 * H_source
    loss = []

    for sent in range(1, len(dec_wrds)):
        # Teacher forcing: feed the previous gold word and the previous context vector.
        last_output_embeddings = dy.lookup_batch(output_lookup, dec_wrds[sent - 1])
        x = dy.concatenate([ctx_t0, last_output_embeddings])
        s = s.add_input(x)
        h_t = s.output()
        ctx_t, alpha_t = self.attend(H_source, s, w1dt, h_len, curr_bsize)
        output_vector = w * dy.concatenate([h_t, ctx_t]) + b
        ctx_t0 = ctx_t
        if dropout_config:
            output_vector = dy.dropout(output_vector, dropout_val)
        temp_loss = dy.pickneglogsoftmax_batch(output_vector, dec_wrds[sent])
        # Zero out the loss for sentences that have already ended.
        if 0 in dec_mask[sent]:
            mask_expr = dy.inputVector(dec_mask[sent])
            mask_expr = dy.reshape(mask_expr, (1,), curr_bsize)
            temp_loss = temp_loss * mask_expr
        loss.append(temp_loss)

    # Average loss over the batch.
    loss = dy.sum_batches(dy.esum(loss)) / curr_bsize
    return loss

def calc_lm_loss(sents):
    dy.renew_cg()

    # parameters -> expressions
    W_exp = dy.parameter(W_sm)
    b_exp = dy.parameter(b_sm)

    # initialize the RNN
    f_init = RNN.initial_state()

    # get the wids and masks for each step
    tot_words = 0
    wids = []
    masks = []
    for i in range(len(sents[0])):
        wids.append([(vw.w2i[sent[i]] if len(sent) > i else STOP) for sent in sents])
        mask = [(1 if len(sent) > i else 0) for sent in sents]
        masks.append(mask)
        tot_words += sum(mask)

    # start the rnn by inputting "<start>"
    init_ids = [START] * len(sents)
    s = f_init.add_input(dy.lookup_batch(WORDS_LOOKUP, init_ids))

    # feed word vectors into the RNN and predict the next word
    losses = []
    for wid, mask in zip(wids, masks):
        # calculate the softmax and loss
        score = W_exp * s.output() + b_exp
        loss = dy.pickneglogsoftmax_batch(score, wid)
        # mask the loss if at least one sentence is shorter
        if mask[-1] != 1:
            mask_expr = dy.inputVector(mask)
            mask_expr = dy.reshape(mask_expr, (1,), MB_SIZE)
            loss = loss * mask_expr
        losses.append(loss)
        # update the state of the RNN
        wemb = dy.lookup_batch(WORDS_LOOKUP, wid)
        s = s.add_input(wemb)

    return dy.sum_batches(dy.esum(losses)), tot_words

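# A minimal usage sketch (assumptions: `train` is a list of tokenized sentences and
# `trainer` is a dy.SimpleSGDTrainer; neither appears in the snippet above). Because
# calc_lm_loss uses len(sents[0]) as the maximum length and reshapes masks with MB_SIZE,
# each minibatch must start with its longest sentence and contain exactly MB_SIZE sentences.
train.sort(key=lambda s: -len(s))
for i in range(0, len(train) - MB_SIZE + 1, MB_SIZE):
    batch = train[i:i + MB_SIZE]
    loss_expr, n_words = calc_lm_loss(batch)
    batch_loss = loss_expr.value()  # forward pass
    loss_expr.backward()            # backward pass
    trainer.update()                # parameter update
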
def build_sentence_graph(self, sents, labels):
    renew_cg()
    f_init = self.builder.initial_state()
    START = vocab.w2i["<start>"]
    W_exp = parameter(self.R)
    b_exp = parameter(self.bias)

    # get the wids and masks for each step;
    # pad the sequences that are shorter with a START at the beginning
    tot_words = 0
    wids = []
    masks = []
    for i in range(len(sents[0])):
        wids.append([(START if len(sents[0]) - len(sent) > i
                      else vocab.w2i[sent[i - len(sents[0]) + len(sent)]])
                     for sent in sents])
        mask = [(1 if len(sent) > i else 0) for sent in sents]
        masks.append(mask)
        tot_words += sum(mask)

    # start the RNN by inputting "<start>"
    init_ids = [START] * len(sents)
    s = f_init.add_input(dy.lookup_batch(self.lookup, init_ids))

    losses = []
    for wid in wids:
        # calculate the softmax and loss
        score = W_exp * s.output() + b_exp
        # update the state of the RNN
        wemb = dy.lookup_batch(self.lookup, wid)
        s = s.add_input(wemb)
        loss = dy.pickneglogsoftmax_batch(score, labels)
        losses.append(loss)
    return dy.sum_batches(dy.esum(losses))

# (tail of a decode_batch-style function: end of the per-step loop and loss aggregation)
        if dropout > 0.:
            predicted = dy.dropout(predicted, dropout)
        cur_loss = dy.pickneglogsoftmax_batch(predicted, output_words[t])
        c_t_minus_1 = c_t
        # Mask the loss in case mask == 0
        if 0 in masks[t]:
            mask = dy.inputVector(masks[t])
            mask = dy.reshape(mask, (1,), batch_size)
            cur_loss = cur_loss * mask
        loss.append(cur_loss)

    # Get the average batch loss
    loss = dy.esum(loss)
    loss = dy.sum_batches(loss) / batch_size
    return loss


# Generate translations from the current trained state of the model
def generate(in_seq, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
    dy.renew_cg()
    embedded = embed_sentence(in_seq)
    encoded = encode_batch(enc_fwd_lstm, enc_bwd_lstm, embedded)

    w = dy.parameter(decoder_w)
    b = dy.parameter(decoder_b)
    w1 = dy.parameter(attention_w1)
    input_mat = dy.concatenate_cols(encoded)
    w1dt = None

def __step_batch(self, batch):
    dy.renew_cg()
    W_y = dy.parameter(self.W_y)
    b_y = dy.parameter(self.b_y)
    F = len(batch[0][0])
    num_words = F * len(batch)

    src_batch = [x[0] for x in batch]
    tgt_batch = [x[1] for x in batch]
    src_rev_batch = [list(reversed(x)) for x in src_batch]

    # batch = [[a1,a2,a3,a4,a5], [b1,b2,b3,b4,b5], [c1,c2,c3,c4], ..]
    # transpose the batch into
    # src_cws: [[a1,b1,c1,..], [a2,b2,c2,..], .. [a5,b5,</S>]]
    src_cws = map(list, zip(*src_batch))  # transpose
    src_rev_cws = map(list, zip(*src_rev_batch))

    # Bidirectional representations
    l2r_state = self.l2r_builder.initial_state()
    r2l_state = self.r2l_builder.initial_state()
    l2r_contexts = []
    r2l_contexts = []
    for (cw_l2r_list, cw_r2l_list) in zip(src_cws, src_rev_cws):
        l2r_state = l2r_state.add_input(
            dy.lookup_batch(self.src_lookup,
                            [self.src_token_to_id.get(cw_l2r, 0) for cw_l2r in cw_l2r_list]))
        r2l_state = r2l_state.add_input(
            dy.lookup_batch(self.src_lookup,
                            [self.src_token_to_id.get(cw_r2l, 0) for cw_r2l in cw_r2l_list]))
        l2r_contexts.append(l2r_state.output())  # [<S>, x_1, x_2, ..., </S>]
        r2l_contexts.append(r2l_state.output())  # [</S>, x_n, x_{n-1}, ..., <S>]
    r2l_contexts.reverse()  # [<S>, x_1, x_2, ..., </S>]

    # Combine the left and right representations for every word
    h_fs = []
    for (l2r_i, r2l_i) in zip(l2r_contexts, r2l_contexts):
        h_fs.append(dy.concatenate([l2r_i, r2l_i]))
    h_fs_matrix = dy.concatenate_cols(h_fs)

    losses = []

    # Decoder
    # transpose the target batch and build masks:
    # tgt_cws: [[a1,b1,c1,..], [a2,b2,c2,..], .. [a5,b5,</S>]]
    # masks:   [[1,1,1,..], [1,1,1,..], ... [1,1,0,..]]
    tgt_cws = []
    masks = []
    maxLen = max([len(l) for l in tgt_batch])
    for i in range(maxLen):
        tgt_cws.append([])
        masks.append([])
    for sentence in tgt_batch:
        for j in range(maxLen):
            if j > len(sentence) - 1:
                tgt_cws[j].append('</S>')
                masks[j].append(0)
            else:
                tgt_cws[j].append(sentence[j])
                masks[j].append(1)

    c_t = dy.vecInput(self.hidden_size * 2)
    start = dy.concatenate(
        [dy.lookup(self.tgt_lookup, self.tgt_token_to_id['<S>']), c_t])
    dec_state = self.dec_builder.initial_state().add_input(start)
    for (cws, nws, mask) in zip(tgt_cws, tgt_cws[1:], masks):
        h_e = dec_state.output()
        _, c_t = self.__attention_mlp(h_fs_matrix, h_e, F)
        # Get the embedding for the current target word
        embed_t = dy.lookup_batch(
            self.tgt_lookup, [self.tgt_token_to_id.get(cw, 0) for cw in cws])
        # Create input vector to the decoder
        x_t = dy.concatenate([embed_t, c_t])
        dec_state = dec_state.add_input(x_t)
        y_star = dy.affine_transform([b_y, W_y, dec_state.output()])
        loss = dy.pickneglogsoftmax_batch(
            y_star, [self.tgt_token_to_id.get(nw, 0) for nw in nws])
        mask_exp = dy.reshape(dy.inputVector(mask), (1,), len(mask))
        loss = loss * mask_exp
        losses.append(loss)

    return dy.sum_batches(dy.esum(losses)), num_words