def update_batch(self, words_batch, tags_batch):
    dynet.renew_cg()
    length = max(len(words) for words in words_batch)

    # Build (length x batch) matrices of word and tag ids, padded with zeros.
    word_ids = np.zeros((length, len(words_batch)), dtype='int32')
    for j, words in enumerate(words_batch):
        for i, word in enumerate(words):
            word_ids[i, j] = self.vw.w2i.get(word, self.UNK)
    tag_ids = np.zeros((length, len(words_batch)), dtype='int32')
    for j, tags in enumerate(tags_batch):
        for i, tag in enumerate(tags):
            tag_ids[i, j] = self.vt.w2i.get(tag, self.UNK)

    # Embed each time step as a batched lookup and add Gaussian noise.
    wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)]
    wembs = [dynet.noise(we, 0.1) for we in wembs]

    # Run the forward and backward LSTMs over the embeddings.
    f_state = self._fwd_lstm.initial_state()
    b_state = self._bwd_lstm.initial_state()
    fw = [x.output() for x in f_state.add_inputs(wembs)]
    bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]

    H = dynet.parameter(self._pH)
    O = dynet.parameter(self._pO)

    # Score each position and accumulate the batched negative log-likelihood.
    errs = []
    for i, (f, b) in enumerate(zip(fw, reversed(bw))):
        f_b = dynet.concatenate([f, b])
        r_t = O * dynet.tanh(H * f_b)
        err = dynet.pickneglogsoftmax_batch(r_t, tag_ids[i])
        errs.append(dynet.sum_batches(err))

    sum_errs = dynet.esum(errs)
    loss = sum_errs.scalar_value()
    sum_errs.backward()
    self._sgd.update()
    return loss
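# A minimal sketch of how update_batch might be driven, assuming a tagger object
# with the attributes used above (vw, vt, _E, _fwd_lstm, _bwd_lstm, _pH, _pO, _sgd).
# `tagger`, `train_words`, `train_tags`, and `BATCH` are hypothetical names, not
# part of the original code.
def train_epoch(tagger, train_words, train_tags, BATCH=32):
    total_loss = 0.0
    for start in range(0, len(train_words), BATCH):
        words_batch = train_words[start:start + BATCH]
        tags_batch = train_tags[start:start + BATCH]
        total_loss += tagger.update_batch(words_batch, tags_batch)
    return total_loss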
def do_gpu():
    G.renew_cg()
    W = G.parameter(gpW)
    W = W * W * W * W * W * W * W
    z = G.squared_distance(W, W)
    z.value()
    z.backward()
def predict_batch(self, words_batch):
    dynet.renew_cg()
    length = max(len(words) for words in words_batch)

    # Build a (length x batch) matrix of word ids, padded with zeros.
    word_ids = np.zeros((length, len(words_batch)), dtype='int32')
    for j, words in enumerate(words_batch):
        for i, word in enumerate(words):
            word_ids[i, j] = self.vw.w2i.get(word, self.UNK)

    wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)]

    f_state = self._fwd_lstm.initial_state()
    b_state = self._bwd_lstm.initial_state()
    fw = [x.output() for x in f_state.add_inputs(wembs)]
    bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]

    H = dynet.parameter(self._pH)
    O = dynet.parameter(self._pO)

    # Pick the highest-scoring tag for every position of every sentence.
    tags_batch = [[] for _ in range(len(words_batch))]
    for i, (f, b) in enumerate(zip(fw, reversed(bw))):
        r_t = O * dynet.tanh(H * dynet.concatenate([f, b]))
        out = dynet.softmax(r_t).npvalue()
        for j in range(len(words_batch)):
            tags_batch[j].append(self.vt.i2w[np.argmax(out.T[j])])
    return tags_batch
def train(self, train_set_de, train_set_en, n_epochs):
    batched_X_Y = minibatched_data(train_set_de, train_set_en, BATCH_SIZE)
    for i in range(n_epochs):
        random.shuffle(batched_X_Y)
        j = 0
        for batch in batched_X_Y:
            j += BATCH_SIZE
            batched_X = [x[0] for x in batch]
            batched_Y = [x[1] for x in batch]

            dy.renew_cg()
            enc_sentence, dec_sentence = batched_X, batched_Y
            loss = self.get_loss(enc_sentence, dec_sentence,
                                 encoder_fwd_lstm, encoder_bwd_lstm, decoder_lstm)
            loss_value = loss.value()
            loss.backward()
            trainer.update()

            # Per-word perplexity of the current minibatch.
            perplexity = np.exp(loss_value * len(batched_Y) /
                                float(sum(len(sent) for sent in batched_Y)))

            if j % 1000 == 0:
                log_line = ("Epoch: " + str(i) + " Samples: " + str(j) +
                            " LOSS: " + str(loss_value) +
                            " Perplexity: " + str(perplexity))
                with codecs.open('models_batched/output_loss_samples.txt', 'a+') as open_file:
                    open_file.write(log_line + "\n")
                print(log_line)

            if j % 20000 == 0:
                model.save('models_batched/model_' + str(i) + '_' + str(j),
                           [encoder_fwd_lstm, encoder_bwd_lstm, decoder_lstm,
                            input_lookup, output_lookup, attention_w1, attention_w2,
                            attention_v, decoder_w, decoder_b])
def get_loss(self, input_sentence, output_sentence, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
    dy.renew_cg()
    embedded = self.embed_sentence(input_sentence)
    encoded = self.encode_sentence(enc_fwd_lstm, enc_bwd_lstm, embedded)
    return self.decode(dec_lstm, encoded, output_sentence)
def calculate_bleu_score(self):
    print("generating sentences")
    with open("gen_e1k_18.txt", "w") as out:
        for i, sent in enumerate(encoder_test_wids):
            dy.renew_cg()
            decoded_test_sent = self.generate(sent, enc_fwd_lstm, enc_bwd_lstm, dec_lstm)
            out.write(decoded_test_sent + "\n")
def predictNextWord(sentence, builder, wlookup, mR, mB):
    dy.renew_cg()
    init_state = builder.initial_state()
    R = dy.parameter(mR)
    bias = dy.parameter(mB)
    state = init_state
    for cw in sentence:  # assume each word is already a word-id
        x_t = dy.lookup(wlookup, int(cw))
        state = state.add_input(x_t)
    y_t = state.output()
    r_t = bias + (R * y_t)
    prob = dy.softmax(r_t)
    return prob
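# A small sketch of turning predictNextWord's softmax expression into an actual
# prediction; `builder`, `wlookup`, `mR`, `mB` are assumed to exist as above, and
# `i2w` and `context_ids` are hypothetical names for an id-to-word list and a
# list of word-ids.
def greedy_next_word(context_ids, builder, wlookup, mR, mB, i2w):
    probs = predictNextWord(context_ids, builder, wlookup, mR, mB).npvalue()
    return i2w[int(np.argmax(probs))]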
def encode_sentence(self, enc_fwd_lstm, enc_bwd_lstm, sentence):
    dy.renew_cg()
    enc_sents, _ = self.pad_zero(sentence)
    sent_embed = self.embed_sentence(enc_sents)
    sentrev_embed = sent_embed[::-1]

    # Run the forward LSTM over the sentence and the backward LSTM over its reverse.
    enc_fwd = enc_fwd_lstm.initial_state()
    enc_bwd = enc_bwd_lstm.initial_state()
    fwd_vectors = enc_fwd.transduce(sent_embed)
    bwd_vectors = enc_bwd.transduce(sentrev_embed)
    bwd_vectors = bwd_vectors[::-1]

    # Concatenate the forward and backward vectors at each position.
    vectors = [dy.concatenate(list(p)) for p in zip(fwd_vectors, bwd_vectors)]
    return vectors
def calc_lm_loss(sents):
    dy.renew_cg()
    # parameters -> expressions
    W_exp = dy.parameter(W_sm)
    b_exp = dy.parameter(b_sm)

    # initialize the RNN
    f_init = RNN.initial_state()

    # get the wids and masks for each step
    tot_words = 0
    wids = []
    masks = []
    for i in range(len(sents[0])):
        wids.append([(vw.w2i[sent[i]] if len(sent) > i else STOP) for sent in sents])
        mask = [(1 if len(sent) > i else 0) for sent in sents]
        masks.append(mask)
        tot_words += sum(mask)

    # start the rnn by inputting "<start>"
    init_ids = [START] * len(sents)
    s = f_init.add_input(dy.lookup_batch(WORDS_LOOKUP, init_ids))

    # feed word vectors into the RNN and predict the next word
    losses = []
    for wid, mask in zip(wids, masks):
        # calculate the softmax and loss
        score = W_exp * s.output() + b_exp
        loss = dy.pickneglogsoftmax_batch(score, wid)
        # mask the loss if at least one sentence is shorter
        if mask[-1] != 1:
            mask_expr = dy.inputVector(mask)
            mask_expr = dy.reshape(mask_expr, (1,), MB_SIZE)
            loss = loss * mask_expr
        losses.append(loss)
        # update the state of the RNN
        wemb = dy.lookup_batch(WORDS_LOOKUP, wid)
        s = s.add_input(wemb)

    return dy.sum_batches(dy.esum(losses)), tot_words
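# A minimal training-loop sketch around calc_lm_loss, assuming the globals it uses
# (RNN, W_sm, b_sm, WORDS_LOOKUP, vw, START, STOP, MB_SIZE) are already defined;
# `train_sents` and `trainer` are hypothetical names. Only full minibatches are
# processed, since calc_lm_loss reshapes its masks to MB_SIZE.
def train_lm_epoch(train_sents, trainer):
    total_loss, total_words = 0.0, 0
    for start in range(0, len(train_sents) - MB_SIZE + 1, MB_SIZE):
        batch = train_sents[start:start + MB_SIZE]
        # calc_lm_loss uses len(sents[0]) as the batch length, so put the
        # longest sentence first.
        batch = sorted(batch, key=len, reverse=True)
        loss_expr, n_words = calc_lm_loss(batch)
        total_loss += loss_expr.scalar_value()
        total_words += n_words
        loss_expr.backward()
        trainer.update()
    return np.exp(total_loss / total_words)  # corpus perplexity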
def __call__(self, words):
    dynet.renew_cg()
    word_ids = [self.vw.w2i.get(w, self.UNK) for w in words]
    wembs = [self._E[w] for w in word_ids]

    f_state = self._fwd_lstm.initial_state()
    b_state = self._bwd_lstm.initial_state()
    fw = [x.output() for x in f_state.add_inputs(wembs)]
    bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]

    H = dynet.parameter(self._pH)
    O = dynet.parameter(self._pO)

    tags = []
    for i, (f, b) in enumerate(zip(fw, reversed(bw))):
        r_t = O * dynet.tanh(H * dynet.concatenate([f, b]))
        out = dynet.softmax(r_t)
        tags.append(self.vt.i2w[np.argmax(out.npvalue())])
    return tags
def predict_emb(self, chars):
    dy.renew_cg()
    finit = self.char_fwd_lstm.initial_state()
    binit = self.char_bwd_lstm.initial_state()

    H = dy.parameter(self.lstm_to_rep_params)
    Hb = dy.parameter(self.lstm_to_rep_bias)
    O = dy.parameter(self.mlp_out)
    Ob = dy.parameter(self.mlp_out_bias)

    # Pad the character sequence and run the forward/backward character LSTMs.
    pad_char = self.c2i[PADDING_CHAR]
    char_ids = [pad_char] + chars + [pad_char]
    embeddings = [self.char_lookup[cid] for cid in char_ids]
    bi_fwd_out = finit.transduce(embeddings)
    bi_bwd_out = binit.transduce(reversed(embeddings))

    # Concatenate the final states and map them through a one-hidden-layer MLP.
    rep = dy.concatenate([bi_fwd_out[-1], bi_bwd_out[-1]])
    return O * dy.tanh(H * rep + Hb) + Ob
def generate(in_seq, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
    dy.renew_cg()
    embedded = embed_sentence(in_seq)
    encoded = encode_batch(enc_fwd_lstm, enc_bwd_lstm, embedded)

    w = dy.parameter(decoder_w)
    b = dy.parameter(decoder_b)
    w1 = dy.parameter(attention_w1)
    input_mat = dy.concatenate_cols(encoded)
    w1dt = None

    last_output_embeddings = output_lookup[BOS]
    s = dec_lstm.initial_state()
    c_t_minus_1 = dy.vecInput(state_size * 2)
    out = []
    count_EOS = 0
    for i in range(len(in_seq) * 2):
        if count_EOS == 1:
            break
        # w1dt can be computed and cached once for the entire decoding phase
        w1dt = w1dt or w1 * input_mat
        vector = dy.concatenate([last_output_embeddings, c_t_minus_1])
        s = s.add_input(vector)
        h_t = s.output()
        c_t = attend_batch(input_mat, s, w1dt, 1, 1)
        out_vector = dy.affine_transform([b, w, dy.concatenate([h_t, c_t])])
        probs = dy.softmax(out_vector).vec_value()
        next_word = probs.index(max(probs))
        last_output_embeddings = output_lookup[next_word]
        c_t_minus_1 = c_t
        if next_word == EOS:
            count_EOS += 1
            continue
        out.append(english_word_vocab.i2w[next_word])
    return " ".join(out[1:])
def __step_batch(self, batch):
    dy.renew_cg()
    W_y = dy.parameter(self.W_y)
    b_y = dy.parameter(self.b_y)
    F = len(batch[0][0])
    num_words = F * len(batch)

    src_batch = [x[0] for x in batch]
    tgt_batch = [x[1] for x in batch]
    src_rev_batch = [list(reversed(x)) for x in src_batch]

    # batch = [[a1,a2,a3,a4,a5], [b1,b2,b3,b4,b5], [c1,c2,c3,c4], ...]
    # transpose the batch into
    # src_cws: [[a1,b1,c1,...], [a2,b2,c2,...], ..., [a5,b5,</S>]]
    src_cws = list(map(list, zip(*src_batch)))
    src_rev_cws = list(map(list, zip(*src_rev_batch)))

    # Bidirectional representations
    l2r_state = self.l2r_builder.initial_state()
    r2l_state = self.r2l_builder.initial_state()
    l2r_contexts = []
    r2l_contexts = []
    for (cw_l2r_list, cw_r2l_list) in zip(src_cws, src_rev_cws):
        l2r_state = l2r_state.add_input(dy.lookup_batch(
            self.src_lookup,
            [self.src_token_to_id.get(cw_l2r, 0) for cw_l2r in cw_l2r_list]))
        r2l_state = r2l_state.add_input(dy.lookup_batch(
            self.src_lookup,
            [self.src_token_to_id.get(cw_r2l, 0) for cw_r2l in cw_r2l_list]))
        l2r_contexts.append(l2r_state.output())  # [<S>, x_1, x_2, ..., </S>]
        r2l_contexts.append(r2l_state.output())  # [</S>, x_n, x_{n-1}, ..., <S>]
    r2l_contexts.reverse()  # [<S>, x_1, x_2, ..., </S>]

    # Combine the left and right representations for every word
    h_fs = []
    for (l2r_i, r2l_i) in zip(l2r_contexts, r2l_contexts):
        h_fs.append(dy.concatenate([l2r_i, r2l_i]))
    h_fs_matrix = dy.concatenate_cols(h_fs)

    losses = []

    # Decoder
    # transpose the target batch the same way and record per-step masks:
    # tgt_cws: [[a1,b1,c1,...], [a2,b2,c2,...], ..., [a5,b5,</S>]]
    # masks:   [[1,1,1,...], [1,1,1,...], ..., [1,1,0,...]]
    tgt_cws = []
    masks = []
    maxLen = max(len(l) for l in tgt_batch)
    for i in range(maxLen):
        tgt_cws.append([])
        masks.append([])
    for sentence in tgt_batch:
        for j in range(maxLen):
            if j > len(sentence) - 1:
                tgt_cws[j].append('</S>')
                masks[j].append(0)
            else:
                tgt_cws[j].append(sentence[j])
                masks[j].append(1)

    c_t = dy.vecInput(self.hidden_size * 2)
    start = dy.concatenate(
        [dy.lookup(self.tgt_lookup, self.tgt_token_to_id['<S>']), c_t])
    dec_state = self.dec_builder.initial_state().add_input(start)
    for (cws, nws, mask) in zip(tgt_cws, tgt_cws[1:], masks):
        h_e = dec_state.output()
        _, c_t = self.__attention_mlp(h_fs_matrix, h_e, F)
        # Get the embedding for the current target word
        embed_t = dy.lookup_batch(
            self.tgt_lookup, [self.tgt_token_to_id.get(cw, 0) for cw in cws])
        # Create input vector to the decoder
        x_t = dy.concatenate([embed_t, c_t])
        dec_state = dec_state.add_input(x_t)
        y_star = dy.affine_transform([b_y, W_y, dec_state.output()])
        loss = dy.pickneglogsoftmax_batch(
            y_star, [self.tgt_token_to_id.get(nw, 0) for nw in nws])
        mask_exp = dy.reshape(dy.inputVector(mask), (1,), len(mask))
        loss = loss * mask_exp
        losses.append(loss)

    return dy.sum_batches(dy.esum(losses)), num_words
def get_batch_loss(input_sentences, output_sentences, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
    dy.renew_cg()
    encoded_batch = encode_batch(enc_fwd_lstm, enc_bwd_lstm, input_sentences)
    return decode_batch(dec_lstm, encoded_batch, output_sentences)
def translate_sentence(self, sent):
    dy.renew_cg()
    W_y = dy.parameter(self.W_y)
    b_y = dy.parameter(self.b_y)
    F = len(sent)
    sent_rev = list(reversed(sent))

    # Bidirectional representations
    l2r_state = self.l2r_builder.initial_state()
    r2l_state = self.r2l_builder.initial_state()
    l2r_contexts = []
    r2l_contexts = []
    for (cw_l2r, cw_r2l) in zip(sent, sent_rev):
        l2r_state = l2r_state.add_input(
            dy.lookup(self.src_lookup, self.src_token_to_id.get(cw_l2r, 0)))
        r2l_state = r2l_state.add_input(
            dy.lookup(self.src_lookup, self.src_token_to_id.get(cw_r2l, 0)))
        l2r_contexts.append(l2r_state.output())  # [<S>, x_1, x_2, ..., </S>]
        r2l_contexts.append(r2l_state.output())  # [</S>, x_n, x_{n-1}, ..., <S>]
    r2l_contexts.reverse()

    h_fs = []
    for (l2r_i, r2l_i) in zip(l2r_contexts, r2l_contexts):
        h_fs.append(dy.concatenate([l2r_i, r2l_i]))
    h_fs_matrix = dy.concatenate_cols(h_fs)

    # Decoder
    trans_sentence = ['<S>']
    cw = trans_sentence[-1]
    c_t = dy.vecInput(l2r_contexts[0].npvalue().shape[0] * 2)
    start = dy.concatenate(
        [dy.lookup(self.tgt_lookup, self.tgt_token_to_id['<S>']), c_t])
    dec_state = self.dec_builder.initial_state().add_input(start)
    while len(trans_sentence) < MAX_LEN:
        h_e = dec_state.output()
        alignment, c_t = self.__attention_mlp(h_fs_matrix, h_e, F)
        embed_t = dy.lookup(self.tgt_lookup, self.tgt_token_to_id.get(cw, 0))
        # Create input vector to the decoder
        x_t = dy.concatenate([embed_t, c_t])
        dec_state = dec_state.add_input(x_t)
        y_star = dy.affine_transform([b_y, W_y, dec_state.output()])
        cw = self.tgt_id_to_token[np.argmax(y_star.npvalue())]
        if cw == '<unk>':
            # Unknown-word replacement: copy the source word with the highest
            # attention weight, skipping the sentence-boundary positions.
            a = alignment.npvalue()
            if np.argmax(a) == 0:
                a[0] = 0
            if np.argmax(a) == len(a) - 1:
                a[len(a) - 1] = 0
            trans_sentence.append(sent[np.argmax(a)])
            continue
        if cw == '</S>':
            break
        trans_sentence.append(cw)
    return ' '.join(trans_sentence[1:])
def get_loss(input_sentence, output_sentence, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
    dy.renew_cg()
    embedded = embed_sentence(input_sentence)
    encoded = encode_sentence(enc_fwd_lstm, enc_bwd_lstm, embedded)
    return decode(dec_lstm, encoded, output_sentence)
def translate_sentence_beam_compact(self, sent, k=2):
    dy.renew_cg()
    F = len(sent)
    W_y = dy.parameter(self.W_y)
    b_y = dy.parameter(self.b_y)
    sent_rev = list(reversed(sent))

    # Bidirectional representations
    l2r_state = self.l2r_builder.initial_state()
    r2l_state = self.r2l_builder.initial_state()
    l2r_contexts = []
    r2l_contexts = []
    for (cw_l2r, cw_r2l) in zip(sent, sent_rev):
        l2r_state = l2r_state.add_input(
            dy.lookup(self.src_lookup, self.src_token_to_id.get(cw_l2r, 0)))
        r2l_state = r2l_state.add_input(
            dy.lookup(self.src_lookup, self.src_token_to_id.get(cw_r2l, 0)))
        l2r_contexts.append(l2r_state.output())  # [<S>, x_1, x_2, ..., </S>]
        r2l_contexts.append(r2l_state.output())  # [</S>, x_n, x_{n-1}, ..., <S>]
    r2l_contexts.reverse()

    h_fs = []
    for (l2r_i, r2l_i) in zip(l2r_contexts, r2l_contexts):
        h_fs.append(dy.concatenate([l2r_i, r2l_i]))
    h_fs_matrix = dy.concatenate_cols(h_fs)

    # One beam entry per hypothesis: validity flag, partial translation,
    # accumulated log-probability, source positions already used for <unk>
    # replacement, and the last emitted word.
    valid = [True for i in range(k)]
    trans = [['<S>'] for i in range(k)]
    prob = [0 for i in range(k)]
    used = [set([0, F - 1]) for i in range(k)]
    cw = [trans[i][-1] for i in range(k)]

    c_t = dy.vecInput(l2r_contexts[0].npvalue().shape[0] * 2)
    start = dy.concatenate(
        [dy.lookup(self.tgt_lookup, self.tgt_token_to_id['<S>']), c_t])
    dec_state = [
        self.dec_builder.initial_state().add_input(start) for i in range(k)
    ]
    b = [i for i in range(k)]
    FIRST = True
    while True in valid:
        h_e = [dec_state[i].output() for i in range(k)]
        ac = [self.__attention_mlp(h_fs_matrix, elem, F) for elem in h_e]
        alignment = [elem[0] for elem in ac]
        c_t = [elem[1] for elem in ac]
        embed_t = [
            dy.lookup(self.tgt_lookup, self.tgt_token_to_id.get(cw[i], 0))
            for i in range(k)
        ]
        x_t = [dy.concatenate([embed_t[i], c_t[i]]) for i in range(k)]
        dec_state = [dec_state[b[i]].add_input(x_t[i]) for i in range(k)]
        y_star = [
            dy.affine_transform([b_y, W_y, dec_state[i].output()])
            for i in range(k)
        ]
        p = [dy.log_softmax(y_star[i]) for i in range(k)]
        tmp = [p[i].npvalue() for i in range(k)]

        # Expand each live hypothesis with its k best continuations, then keep
        # the k best (word, score, parent-beam) triples overall.
        l = []
        if not FIRST:
            idx = [np.argpartition(-tmp[i].T, k)[0][:k] for i in range(k)]
            val = [-np.partition(-tmp[i].T, k)[0][:k] for i in range(k)]
            for i in range(k):
                for j in range(k):
                    l.append((self.tgt_id_to_token[idx[i][j]],
                              val[i][j] + prob[i], i))
        else:
            idx = np.argpartition(-tmp[0].T, k)[0][:k]
            val = -np.partition(-tmp[0].T, k)[0][:k]
            FIRST = False
            for i in range(k):
                l.append((self.tgt_id_to_token[idx[i]], val[i], i))
        l = sorted(l, key=lambda x: -x[1])[:k]

        cw = [l[i][0] for i in range(k)]
        prob = [l[i][1] for i in range(k)]
        b = [l[i][2] for i in range(k)]
        trans = [list(trans[b[i]]) for i in range(k)]
        valid = [valid[b[i]] for i in range(k)]
        for i in range(k):
            if not valid[i]:
                continue
            if cw[i] != '</S>' and len(trans[i]) < MAX_LEN:
                if cw[i] == '<unk>':
                    # Unknown-word replacement: copy the most-attended source
                    # word that has not been used yet for this hypothesis.
                    a = alignment[i].npvalue()
                    am = np.argmax(a)
                    while am in used[i]:
                        a[am] = -1
                        am = np.argmax(a)
                    used[i].add(am)
                    if len(used[i]) == F // 2:
                        used[i] = set([0, F - 1])
                    trans[i].append(sent[am])
                else:
                    trans[i].append(cw[i])
            else:
                valid[i] = False

    index = prob.index(max(prob))
    return ' '.join(trans[index][1:])