Ejemplo n.º 1
0
    def decode(self, dec_lstm, vectors, output):
        """Compute the attention-decoder training loss for one sequence.

        dec_lstm: decoder RNN builder; vectors: per-position encoder
        vectors; output: iterable of target word ids. Returns the sum of
        per-step negative log-likelihoods as a dynet expression.
        """
        w = dy.parameter(decoder_w)
        b = dy.parameter(decoder_b)
        w1 = dy.parameter(attention_w1)

        # Encoder states as columns; project once for the attention MLP.
        input_mat = dy.concatenate_cols(vectors)
        w1dt = w1 * input_mat

        prev_embedding = output_lookup[EOS]
        state = dec_lstm.initial_state()
        prev_context = dy.vecInput(STATE_SIZE * 2)

        step_losses = []
        for target_word in output:
            # Feed last output embedding together with the previous context.
            state = state.add_input(dy.concatenate([prev_embedding, prev_context]))
            hidden = state.output()
            context, _ = self.attend(input_mat, state, w1dt)

            scores = dy.affine_transform([b, w, dy.concatenate([hidden, context])])
            step_losses.append(dy.pickneglogsoftmax(scores, target_word))

            prev_embedding = output_lookup[target_word]
            prev_context = context

        return dy.esum(step_losses)
Ejemplo n.º 2
0
def decode(dec_lstm, vectors, output):
	"""Return the total attention-decoder loss for one target sequence."""
	w = dy.parameter(decoder_w)
	b = dy.parameter(decoder_b)
	w1 = dy.parameter(attention_w1)
	input_mat = dy.concatenate_cols(vectors)
	# Attention projection of the source is computed lazily on step one.
	w1dt = None

	prev_emb = output_lookup[output[-1]]
	s = dec_lstm.initial_state()
	prev_ctx = dy.vecInput(state_size * 2)
	step_losses = []

	for wid in output:
		if w1dt is None:
			w1dt = w1 * input_mat
		s = s.add_input(dy.concatenate([prev_emb, prev_ctx]))
		h_t = s.output()
		ctx = attend(input_mat, s, w1dt)
		scores = dy.affine_transform([b, w, dy.concatenate([h_t, ctx])])
		step_losses.append(dy.pickneglogsoftmax(scores, wid))
		prev_emb = output_lookup[wid]
		prev_ctx = ctx

	return dy.esum(step_losses)
    def decode(self, dec_lstm, vectors, output_batch):
        """Batched attention-decoder training loss.

        dec_lstm: decoder RNN builder; vectors: per-position encoder
        vectors; output_batch: list of target sentences (word ids).
        Returns the padding-masked loss, averaged over the batch.
        """
        w = dy.parameter(decoder_w)
        b = dy.parameter(decoder_b)
        w1 = dy.parameter(attention_w1)
        input_mat = dy.concatenate_cols(vectors)
        input_len = len(vectors)

        # Project the encoder matrix once for the attention MLP.
        w1dt = w1 * input_mat

        curr_bsize = len(output_batch)

        #n_batch_start = [1] * curr_bsize
        #last_output_embeddings = dy.lookup_batch(output_lookup, n_batch_start)

        s = dec_lstm.initial_state()
        c_t_previous = dy.vecInput(STATE_SIZE * 2)

        loss = []

        # Transpose to time-major id lists; masks flag real (1) vs padded (0).
        output_batch, masks = self.create_sentence_batch(output_batch)

        for i, (word_batch,
                mask_word) in enumerate(zip(output_batch[1:], masks[1:]),
                                        start=1):
            # Teacher forcing: feed the previous gold words plus last context.
            last_output_embeddings = dy.lookup_batch(output_lookup,
                                                     output_batch[i - 1])
            vector = dy.concatenate([last_output_embeddings, c_t_previous])
            s = s.add_input(vector)
            h_t = s.output()
            c_t, alpha_t = self.attend(input_mat, s, w1dt, input_len,
                                       curr_bsize)

            h_c_concat = dy.concatenate([h_t, c_t])
            out_vector = dy.affine_transform([b, w, h_c_concat])

            if DROPOUT > 0.0:
                out_vector = dy.dropout(out_vector, DROPOUT)

            loss_current = dy.pickneglogsoftmax_batch(out_vector,
                                                      output_batch[i])

            # Zero the loss contribution of padded positions.
            if 0 in mask_word:
                mask_vals = dy.inputVector(mask_word)
                mask_vals = dy.reshape(mask_vals, (1, ), curr_bsize)
                loss_current = loss_current * mask_vals

            loss.append(loss_current)
            c_t_previous = c_t

        loss = dy.esum(loss)
        # Average total loss over the sentences of this batch.
        loss = dy.sum_batches(loss) / curr_bsize
        #perplexity = loss.value() * curr_bsize / float(sum([x.count(1) for x in masks[1:]]))
        return loss
Ejemplo n.º 4
0
    def update_batch(self, words_batch, tags_batch):
        """Run one training step on a minibatch of tagged sentences.

        words_batch / tags_batch: parallel lists of token and tag
        sequences. Builds the BiLSTM tagger loss, backpropagates, applies
        one SGD update, and returns the summed batch loss as a float.
        """
        dynet.renew_cg()
        length = max(len(words) for words in words_batch)
        # Pad word/tag id matrices to (max_len, batch); unused cells stay 0.
        word_ids = np.zeros((length, len(words_batch)), dtype='int32')
        for j, words in enumerate(words_batch):
            for i, word in enumerate(words):
                word_ids[i, j] = self.vw.w2i.get(word, self.UNK)
        tag_ids = np.zeros((length, len(words_batch)), dtype='int32')
        for j, tags in enumerate(tags_batch):
            for i, tag in enumerate(tags):
                tag_ids[i, j] = self.vt.w2i.get(tag, self.UNK)
        # Batched embeddings per time step, with additive noise as regularizer.
        wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)]
        wembs = [dynet.noise(we, 0.1) for we in wembs]

        f_state = self._fwd_lstm.initial_state()
        b_state = self._bwd_lstm.initial_state()

        fw = [x.output() for x in f_state.add_inputs(wembs)]
        bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]

        H = dynet.parameter(self._pH)
        O = dynet.parameter(self._pO)

        errs = []
        for i, (f, b) in enumerate(zip(fw, reversed(bw))):
            f_b = dynet.concatenate([f, b])
            r_t = O * (dynet.tanh(H * f_b))
            err = dynet.pickneglogsoftmax_batch(r_t, tag_ids[i])
            errs.append(dynet.sum_batches(err))
        sum_errs = dynet.esum(errs)
        # NOTE(review): removed the dead local `squared = -sum_errs`, which
        # was never used after assignment.
        losses = sum_errs.scalar_value()
        sum_errs.backward()
        self._sgd.update()
        return losses
Ejemplo n.º 5
0
    def predict_batch(self, words_batch):
        """Tag a batch of sentences; returns one predicted-tag list per sentence."""
        dynet.renew_cg()
        n_sents = len(words_batch)
        max_len = max(len(ws) for ws in words_batch)

        # Padded (max_len x batch) word-id matrix; unused cells stay 0.
        word_ids = np.zeros((max_len, n_sents), dtype='int32')
        for col, words in enumerate(words_batch):
            for row, word in enumerate(words):
                word_ids[row, col] = self.vw.w2i.get(word, self.UNK)
        wembs = [dynet.lookup_batch(self._E, word_ids[t]) for t in range(max_len)]

        fwd = self._fwd_lstm.initial_state()
        bwd = self._bwd_lstm.initial_state()
        fw = [st.output() for st in fwd.add_inputs(wembs)]
        bw = [st.output() for st in bwd.add_inputs(reversed(wembs))]

        H = dynet.parameter(self._pH)
        O = dynet.parameter(self._pO)

        tags_batch = [[] for _ in range(n_sents)]
        for f, b in zip(fw, reversed(bw)):
            scores = O * (dynet.tanh(H * dynet.concatenate([f, b])))
            probs = dynet.softmax(scores).npvalue()
            # Pick the argmax tag per sentence at this time step.
            for col in range(n_sents):
                tags_batch[col].append(self.vt.i2w[np.argmax(probs.T[col])])
        return tags_batch
Ejemplo n.º 6
0
def encode_sentence(enc_fwd_lstm, enc_bwd_lstm, sentence):
	"""Bi-LSTM encode a sentence: concat fwd/bwd hidden states per position."""
	fwd_states = run_lstm(enc_fwd_lstm.initial_state(), sentence)
	bwd_states = run_lstm(enc_bwd_lstm.initial_state(), list(reversed(sentence)))
	# Re-reverse so backward states line up with their positions.
	bwd_states = bwd_states[::-1]
	return [dy.concatenate([f, b]) for f, b in zip(fwd_states, bwd_states)]
Ejemplo n.º 7
0
 def encode_sentence_batch(self, enc_fwd_lstm, enc_bwd_lstm, sentence_batch):
     """Encode a batch of embedded time steps with a bidirectional LSTM."""
     rev_batch = list(reversed(sentence_batch))
     fwd_out = self.run_lstm(enc_fwd_lstm.initial_state(), sentence_batch)
     bwd_out = self.run_lstm(enc_bwd_lstm.initial_state(), rev_batch)
     # Re-reverse the backward pass so both directions align per position.
     bwd_out = bwd_out[::-1]
     return [dy.concatenate([f, b]) for f, b in zip(fwd_out, bwd_out)]
Ejemplo n.º 8
0
def decode_batch(dec_lstm, input_encodings, output_sentences):
	"""Batched attention-decoder loss, averaged over the batch.

	NOTE(review): the original mixed tabs and spaces on several lines
	(a TabError under Python 3); indentation normalized to tabs only,
	logic unchanged.
	"""
	w = dy.parameter(decoder_w)
	b = dy.parameter(decoder_b)
	w1 = dy.parameter(attention_w1)

	# Targets transposed to time-major id lists, plus 0/1 padding masks.
	output_words, masks = sentences_to_batch(output_sentences)
	decoder_target_input = zip(output_words[1:], masks[1:])

	batch_size = len(output_sentences)
	input_length = len(input_encodings)

	input_mat = dy.concatenate_cols(input_encodings)
	# Project the encoder matrix once for the attention MLP.
	w1dt = w1 * input_mat

	s = dec_lstm.initial_state()
	c_t_minus_1 = dy.vecInput(state_size * 2)
	loss = []

	for t, (words_t, mask_t) in enumerate(decoder_target_input, start=1):
		last_output_embeddings = dy.lookup_batch(output_lookup, output_words[t - 1])
		vector = dy.concatenate([last_output_embeddings, c_t_minus_1])
		s = s.add_input(vector)
		h_t = s.output()
		c_t = attend_batch(input_mat, s, w1dt, batch_size, input_length)
		predicted = dy.affine_transform([b, w, dy.concatenate([h_t, c_t])])
		if dropout > 0.:
			predicted = dy.dropout(predicted, dropout)
		cur_loss = dy.pickneglogsoftmax_batch(predicted, words_t)
		c_t_minus_1 = c_t

		# Mask the loss where the target position is padding (mask == 0).
		if 0 in mask_t:
			mask = dy.inputVector(mask_t)
			mask = dy.reshape(mask, (1, ), batch_size)
			cur_loss = cur_loss * mask

		loss.append(cur_loss)

	# Get the average batch loss
	loss = dy.esum(loss)
	loss = dy.sum_batches(loss) / batch_size
	return loss
    def generate(self, in_seq, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
        """Greedy-decode a translation for `in_seq`; returns the output string."""
        embedded = self.embed_sentence(in_seq)
        encoded = self.encode_sentence(enc_fwd_lstm, enc_bwd_lstm, embedded)

        w = dy.parameter(decoder_w)
        b = dy.parameter(decoder_b)
        w1 = dy.parameter(attention_w1)

        input_mat = dy.concatenate_cols(encoded)
        # Attention projection of the source is fixed for the whole decode.
        w1dt = w1 * input_mat

        prev_emb = output_lookup[EOS]
        state = dec_lstm.initial_state()
        prev_ctx = dy.vecInput(STATE_SIZE * 2)

        out = ''
        eos_seen = 0
        # Cap the output length at twice the input length.
        for _ in range(len(in_seq) * 2):
            if eos_seen == 2:
                break
            state = state.add_input(dy.concatenate([prev_emb, prev_ctx]))
            h_t = state.output()
            ctx, _ = self.attend(input_mat, state, w1dt)

            scores = dy.affine_transform([b, w, dy.concatenate([h_t, ctx])])
            probs = dy.softmax(scores).vec_value()
            next_char = probs.index(max(probs))

            prev_emb = output_lookup[next_char]
            prev_ctx = ctx

            if next_char == EOS:
                eos_seen += 1
                continue
            out += " " + id2word_en[next_char]

        return out
Ejemplo n.º 10
0
    def myDecode(self, dec_lstm, h_encodings, target_sents):
        """Batched attention-decoder loss, averaged per sentence.

        dec_lstm: decoder RNN builder; h_encodings: per-position encoder
        vectors; target_sents: list of target sentences (word ids).
        Returns the padding-masked, batch-averaged loss expression.
        """
        w = dy.parameter(decoder_w)
        b = dy.parameter(decoder_b)
        w1 = dy.parameter(attention_w1)

        # Pad targets to equal length; dec_mask marks real (1) vs padded (0).
        dec_wrds, dec_mask = self.pad_zero(target_sents)
        curr_bsize = len(target_sents)
        h_len = len(h_encodings)

        H_source = dy.concatenate_cols(h_encodings)
        s = dec_lstm.initial_state()
        ctx_t0 = dy.vecInput(hidden_size * 2)
        # Project the encoder matrix once for the attention MLP.
        w1dt = w1 * H_source

        loss = []

        for sent in range(1, len(dec_wrds)):
            # Teacher forcing: feed the previous gold words plus last context.
            last_output_embeddings = dy.lookup_batch(output_lookup, dec_wrds[sent-1])
            x = dy.concatenate([ctx_t0, last_output_embeddings])
            s = s.add_input(x)
            h_t = s.output()

            ctx_t, alpha_t = self.attend(H_source, s, w1dt, h_len, curr_bsize)
            output_vector = w * dy.concatenate([h_t, ctx_t]) + b
            ctx_t0 = ctx_t
            if dropout_config:
                output_vector = dy.dropout(output_vector, dropout_val)

            temp_loss = dy.pickneglogsoftmax_batch(output_vector, dec_wrds[sent])

            # Mask out loss contributions from padded positions.
            if 0 in dec_mask[sent]:
                mask_expr = dy.inputVector(dec_mask[sent])
                mask_expr = dy.reshape(mask_expr, (1, ), curr_bsize)
                temp_loss = temp_loss * mask_expr

            loss.append(temp_loss)

        loss = dy.esum(loss)
        # BUG FIX: average over the actual size of this batch (curr_bsize),
        # not the module-level `batch_size`, which is wrong whenever the
        # last batch of an epoch is smaller (compare the sibling decoder
        # earlier in this file, which divides by curr_bsize).
        loss = dy.sum_batches(loss) / curr_bsize
        return loss
Ejemplo n.º 11
0
    def encode_generation(self, enc_fwd_lstm, enc_bwd_lstm, sentence):
        """Bidirectionally encode an embedded sentence for decoding."""
        fwd = self.run_lstm(enc_fwd_lstm.initial_state(), sentence)
        bwd = self.run_lstm(enc_bwd_lstm.initial_state(), list(reversed(sentence)))
        # Re-reverse so backward outputs align with their positions.
        bwd = bwd[::-1]
        # Per-position encoding: forward state concatenated with backward state.
        return [dy.concatenate([f, b]) for f, b in zip(fwd, bwd)]
    def attend(self, input_mat, state, w1dt, input_len, batch_size):
        """Batched MLP attention: returns (context vector, attention weights)."""
        global attention_w2
        global attention_v
        w2 = dy.parameter(attention_w2)
        v = dy.parameter(attention_v)
        # Score every source column against the projected decoder state.
        w2dt = w2 * dy.concatenate(list(state.s()))
        scores = dy.transpose(v * dy.tanh(dy.colwise_add(w1dt, w2dt)))
        scores = dy.reshape(scores, (input_len, ), batch_size)
        att_weights = dy.softmax(scores)
        # Context = attention-weighted sum of the encoder columns.
        return input_mat * att_weights, att_weights
Ejemplo n.º 13
0
def decode_batch(dec_lstm, input_encodings, output_sentences):
	"""Batched attention-decoder loss, averaged over the batch.

	BUG FIX(review): the original mis-indented `c_t_minus_1 = c_t` with
	spaces outside the tab-indented loop body, and was missing the final
	loss aggregation and return statement entirely; both restored to
	match the sibling decode_batch implementation earlier in this file.
	"""
	w = dy.parameter(decoder_w)
	b = dy.parameter(decoder_b)
	w1 = dy.parameter(attention_w1)

	# Targets transposed to time-major id lists, plus 0/1 padding masks.
	output_words, masks = sentences_to_batch(output_sentences)

	batch_size = len(output_sentences)
	input_length = len(input_encodings)

	input_mat = dy.concatenate_cols(input_encodings)
	# Project the encoder matrix once for the attention MLP.
	w1dt = w1 * input_mat

	s = dec_lstm.initial_state()
	c_t_minus_1 = dy.vecInput(state_size * 2)
	loss = []

	for t in range(1, len(output_words)):
		last_output_embeddings = dy.lookup_batch(output_lookup, output_words[t-1])
		vector = dy.concatenate([c_t_minus_1, last_output_embeddings])
		s = s.add_input(vector)
		h_t = s.output()
		c_t = attend_batch(input_mat, s, w1dt, batch_size, input_length)
		predicted = w * dy.concatenate([h_t, c_t]) + b
		if dropout > 0.:
			predicted = dy.dropout(predicted, dropout)
		cur_loss = dy.pickneglogsoftmax_batch(predicted, output_words[t])
		c_t_minus_1 = c_t

		# Mask the loss where the target position is padding (mask == 0).
		if 0 in masks[t]:
			mask = dy.inputVector(masks[t])
			mask = dy.reshape(mask, (1, ), batch_size)
			cur_loss = cur_loss * mask

		loss.append(cur_loss)

	# Get the average batch loss.
	loss = dy.esum(loss)
	loss = dy.sum_batches(loss) / batch_size
	return loss
Ejemplo n.º 14
0
    def generate(self, in_seq, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
        """Greedily translate `in_seq`; returns the decoded sentence string."""
        embedded = self.embed_sentence(in_seq, True)
        encoded = self.encode_generation(enc_fwd_lstm, enc_bwd_lstm, embedded)
        src_len = len(encoded)
        bsize = 1

        w = dy.parameter(decoder_w)
        b = dy.parameter(decoder_b)
        w1 = dy.parameter(attention_w1)

        H_source = dy.concatenate_cols(encoded)
        state = dec_lstm.initial_state()
        prev_ctx = dy.vecInput(hidden_size * 2)
        prev_emb = output_lookup[word2idx_en['<s>']]
        # Attention projection of the source is fixed for the whole decode.
        w1dt = w1 * H_source

        words = []
        eos_count = 0
        for _ in range(len(in_seq) * 2):
            if eos_count == 1:
                break
            state = state.add_input(dy.concatenate([prev_ctx, prev_emb]))
            h_t = state.output()
            ctx, _ = self.attend(H_source, state, w1dt, src_len, bsize)

            probs = dy.softmax(w * dy.concatenate([h_t, ctx]) + b).vec_value()
            next_char = probs.index(max(probs))
            prev_emb = output_lookup[next_char]
            if idx2word_en[next_char] == '<EOS>':
                eos_count += 1
                continue

            words.append(idx2word_en[next_char])
            # As in the original, the context is only advanced on non-EOS steps.
            prev_ctx = ctx

        return ' '.join(words)
Ejemplo n.º 15
0
def attend_batch(input_mat, state, w1dt, batch_size, input_length):
	"""Batched MLP attention; returns the context vector expression.

	NOTE(review): indentation normalized to tabs only -- the original
	mixed tabs and spaces, which is a TabError under Python 3; the
	computation is unchanged.
	"""
	global attention_w2
	global attention_v
	w2 = dy.parameter(attention_w2)
	v = dy.parameter(attention_v)
	# Score every source position against the projected decoder state.
	w2dt = w2*dy.concatenate(list(state.s()))
	unnormalized = dy.transpose(v * dy.tanh(dy.colwise_add(w1dt, w2dt)))
	attention_reshaped = dy.reshape(unnormalized, (input_length, ), batch_size)
	att_weights = dy.softmax(attention_reshaped)
	# Context = attention-weighted sum of the encoder columns.
	context = input_mat * att_weights
	return context
Ejemplo n.º 16
0
def generate(in_seq, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
	"""Greedy-decode a translation for `in_seq`; returns the word string."""
	dy.renew_cg()
	embedded = embed_sentence(in_seq)
	encoded = encode_batch(enc_fwd_lstm, enc_bwd_lstm, embedded)

	w = dy.parameter(decoder_w)
	b = dy.parameter(decoder_b)
	w1 = dy.parameter(attention_w1)
	input_mat = dy.concatenate_cols(encoded)
	# Attention projection is computed lazily on the first step, then cached.
	w1dt = None

	prev_emb = output_lookup[BOS]
	s = dec_lstm.initial_state()
	prev_ctx = dy.vecInput(state_size * 2)

	out = []
	eos_count = 0
	# Cap the output length at twice the input length.
	for _ in range(len(in_seq) * 2):
		if eos_count == 1:
			break
		if w1dt is None:
			w1dt = w1 * input_mat
		s = s.add_input(dy.concatenate([prev_emb, prev_ctx]))
		h_t = s.output()
		ctx = attend_batch(input_mat, s, w1dt, 1, 1)
		scores = dy.affine_transform([b, w, dy.concatenate([h_t, ctx])])
		probs = dy.softmax(scores).vec_value()
		next_word = probs.index(max(probs))
		prev_emb = output_lookup[next_word]
		prev_ctx = ctx
		if next_word == EOS:
			eos_count += 1
			continue

		out.append(english_word_vocab.i2w[next_word])
	# As in the original, the first collected token is skipped when joining.
	return " ".join(out[1:])
Ejemplo n.º 17
0
    def encode_sentence(self, enc_fwd_lstm, enc_bwd_lstm, sentence):
        """Pad, embed and bidirectionally encode a batch of sentences."""
        dy.renew_cg()

        padded, _ = self.pad_zero(sentence)
        embedded = self.embed_sentence(padded)

        fwd = enc_fwd_lstm.initial_state().transduce(embedded)
        bwd = enc_bwd_lstm.initial_state().transduce(embedded[::-1])
        # Re-reverse so backward outputs align with their positions.
        bwd = bwd[::-1]

        # One (2*hidden) vector per position: forward over backward.
        return [dy.concatenate([f, b]) for f, b in zip(fwd, bwd)]
Ejemplo n.º 18
0
def attend(input_mat, state, w1dt):
	"""MLP attention over encoder columns; returns the context vector.

	input_mat: (encoder_state x seqlen) encoder vectors as columns.
	w1dt: (attdim x seqlen) pre-projected encoder matrix.
	"""
	global attention_w2
	global attention_v
	w2 = dy.parameter(attention_w2)
	v = dy.parameter(attention_v)

	# Project the full decoder state, add it to every source column, score.
	w2dt = w2 * dy.concatenate(list(state.s()))
	# (seqlen,) row of unnormalized scores -> softmax weights.
	scores = dy.transpose(v * dy.tanh(dy.colwise_add(w1dt, w2dt)))
	weights = dy.softmax(scores)
	# Context = weighted sum of the encoder columns.
	return input_mat * weights
Ejemplo n.º 19
0
    def __call__(self, words):
        """Tag a single sentence; returns one predicted tag per word."""
        dynet.renew_cg()
        ids = [self.vw.w2i.get(w, self.UNK) for w in words]
        embeddings = [self._E[i] for i in ids]

        fwd = self._fwd_lstm.initial_state()
        bwd = self._bwd_lstm.initial_state()
        fw = [st.output() for st in fwd.add_inputs(embeddings)]
        bw = [st.output() for st in bwd.add_inputs(reversed(embeddings))]

        H = dynet.parameter(self._pH)
        O = dynet.parameter(self._pO)

        tags = []
        for f, b in zip(fw, reversed(bw)):
            # Score tags from the concatenated bidirectional state.
            scores = O * (dynet.tanh(H * dynet.concatenate([f, b])))
            probs = dynet.softmax(scores)
            tags.append(self.vt.i2w[np.argmax(probs.npvalue())])
        return tags
Ejemplo n.º 20
0
    def predict_emb(self, chars):
        """Predict a word embedding from its character-id sequence."""
        dy.renew_cg()

        fwd_init = self.char_fwd_lstm.initial_state()
        bwd_init = self.char_bwd_lstm.initial_state()

        H = dy.parameter(self.lstm_to_rep_params)
        Hb = dy.parameter(self.lstm_to_rep_bias)
        O = dy.parameter(self.mlp_out)
        Ob = dy.parameter(self.mlp_out_bias)

        # Surround the character sequence with padding sentinels.
        pad = self.c2i[PADDING_CHAR]
        embedded = [self.char_lookup[c] for c in [pad] + chars + [pad]]

        fwd_out = fwd_init.transduce(embedded)
        bwd_out = bwd_init.transduce(reversed(embedded))

        # The final states of both directions summarize the word.
        rep = dy.concatenate([fwd_out[-1], bwd_out[-1]])
        return O * dy.tanh(H * rep + Hb) + Ob
Ejemplo n.º 21
0
    def __step_batch(self, batch):
        """Run one batched encoder-decoder training step.

        batch: list of (source_words, target_words) pairs. Returns a
        tuple (summed masked batch loss expression, num_words), where
        num_words = source length * batch size.
        """
        dy.renew_cg()
        W_y = dy.parameter(self.W_y)
        b_y = dy.parameter(self.b_y)
        # NOTE(review): F is taken from the first source sentence -- this
        # assumes all sources in the batch share one length; confirm batches
        # are length-bucketed upstream.
        F = len(batch[0][0])
        num_words = F * len(batch)

        src_batch = [x[0] for x in batch]
        tgt_batch = [x[1] for x in batch]
        src_rev_batch = [list(reversed(x)) for x in src_batch]
        # batch  = [ [a1,a2,a3,a4,a5], [b1,b2,b3,b4,b5], [c1,c2,c3,c4] ..]
        # transpose the batch into
        #   src_cws: [[a1,b1,c1,..], [a2,b2,c2,..], .. [a5,b5,</S>]]

        src_cws = map(list, zip(*src_batch))  # transpose
        src_rev_cws = map(list, zip(*src_rev_batch))

        # Bidirectional representations
        l2r_state = self.l2r_builder.initial_state()
        r2l_state = self.r2l_builder.initial_state()
        l2r_contexts = []
        r2l_contexts = []
        for (cw_l2r_list, cw_r2l_list) in zip(src_cws, src_rev_cws):
            # Unknown tokens map to id 0.
            l2r_state = l2r_state.add_input(
                dy.lookup_batch(self.src_lookup, [
                    self.src_token_to_id.get(cw_l2r, 0)
                    for cw_l2r in cw_l2r_list
                ]))
            r2l_state = r2l_state.add_input(
                dy.lookup_batch(self.src_lookup, [
                    self.src_token_to_id.get(cw_r2l, 0)
                    for cw_r2l in cw_r2l_list
                ]))

            l2r_contexts.append(
                l2r_state.output())  #[<S>, x_1, x_2, ..., </S>]
            r2l_contexts.append(
                r2l_state.output())  #[</S> x_n, x_{n-1}, ... <S>]

        r2l_contexts.reverse()  #[<S>, x_1, x_2, ..., </S>]

        # Combine the left and right representations for every word
        h_fs = []
        for (l2r_i, r2l_i) in zip(l2r_contexts, r2l_contexts):
            h_fs.append(dy.concatenate([l2r_i, r2l_i]))
        h_fs_matrix = dy.concatenate_cols(h_fs)

        losses = []

        # Decoder
        # batch  = [ [a1,a2,a3,a4,a5], [b1,b2,b3,b4,b5], [c1,c2,c3,c4] ..]
        # transpose the batch into
        #   tgt_cws: [[a1,b1,c1,..], [a2,b2,c2,..], .. [a5,b5,</S>]]
        #   masks: [1,1,1,..], [1,1,1,..], ...[1,1,0,..]]

        tgt_cws = []
        masks = []
        maxLen = max([len(l) for l in tgt_batch])
        for i in range(maxLen):
            tgt_cws.append([])
            masks.append([])
        for sentence in tgt_batch:
            for j in range(maxLen):
                if j > len(sentence) - 1:
                    # Pad short targets with '</S>' and mask their loss out.
                    tgt_cws[j].append('</S>')
                    masks[j].append(0)
                else:
                    tgt_cws[j].append(sentence[j])
                    masks[j].append(1)
        c_t = dy.vecInput(self.hidden_size * 2)
        start = dy.concatenate(
            [dy.lookup(self.tgt_lookup, self.tgt_token_to_id['<S>']), c_t])
        dec_state = self.dec_builder.initial_state().add_input(start)
        # Teacher forcing: feed gold word cws, predict the next word nws.
        for (cws, nws, mask) in zip(tgt_cws, tgt_cws[1:], masks):
            h_e = dec_state.output()
            _, c_t = self.__attention_mlp(h_fs_matrix, h_e, F)
            # Get the embedding for the current target word
            embed_t = dy.lookup_batch(
                self.tgt_lookup,
                [self.tgt_token_to_id.get(cw, 0) for cw in cws])
            # Create input vector to the decoder
            x_t = dy.concatenate([embed_t, c_t])
            dec_state = dec_state.add_input(x_t)
            y_star = dy.affine_transform([b_y, W_y, dec_state.output()])
            loss = dy.pickneglogsoftmax_batch(
                y_star, [self.tgt_token_to_id.get(nw, 0) for nw in nws])
            # Zero the loss on padded target positions.
            mask_exp = dy.reshape(dy.inputVector(mask), (1, ), len(mask))
            loss = loss * mask_exp
            losses.append(loss)
        return dy.sum_batches(dy.esum(losses)), num_words
Ejemplo n.º 22
0
    def translate_sentence_beam_compact(self, sent, k=2):
        """Translate `sent` with a beam search of width k.

        Keeps k live hypotheses; '<unk>' outputs are replaced by the
        source word with the highest not-yet-used attention weight.
        Returns the best-scoring translation without the leading '<S>'.

        BUG FIX(review): the unknown-word loop tested `am in used` -- i.e.
        membership of an int in the *list of per-beam sets*, which is
        always False -- instead of `am in used[i]`. Also `F / 2` is float
        division under Python 3, so the used-set reset could never fire
        for odd F; changed to `F // 2` to keep the Python-2 semantics.
        """
        dy.renew_cg()
        F = len(sent)
        W_y = dy.parameter(self.W_y)
        b_y = dy.parameter(self.b_y)
        sent_rev = list(reversed(sent))
        # Bidirectional representations
        l2r_state = self.l2r_builder.initial_state()
        r2l_state = self.r2l_builder.initial_state()
        l2r_contexts = []
        r2l_contexts = []
        for (cw_l2r, cw_r2l) in zip(sent, sent_rev):
            l2r_state = l2r_state.add_input(
                dy.lookup(self.src_lookup, self.src_token_to_id.get(cw_l2r,
                                                                    0)))
            r2l_state = r2l_state.add_input(
                dy.lookup(self.src_lookup, self.src_token_to_id.get(cw_r2l,
                                                                    0)))
            l2r_contexts.append(
                l2r_state.output())  #[<S>, x_1, x_2, ..., </S>]
            r2l_contexts.append(
                r2l_state.output())  #[</S> x_n, x_{n-1}, ... <S>]
        r2l_contexts.reverse()

        h_fs = []
        for (l2r_i, r2l_i) in zip(l2r_contexts, r2l_contexts):
            h_fs.append(dy.concatenate([l2r_i, r2l_i]))
        h_fs_matrix = dy.concatenate_cols(h_fs)

        # Per-beam bookkeeping: liveness, partial output, log-prob, and the
        # source positions already consumed by <unk> replacement.
        valid = [True for i in range(k)]
        trans = [['<S>'] for i in range(k)]
        prob = [0 for i in range(k)]
        used = [set([0, F - 1]) for i in range(k)]
        cw = [trans[i][-1] for i in range(k)]
        c_t = dy.vecInput(l2r_contexts[0].npvalue().shape[0] * 2)
        start = dy.concatenate(
            [dy.lookup(self.tgt_lookup, self.tgt_token_to_id['<S>']), c_t])
        dec_state = [
            self.dec_builder.initial_state().add_input(start) for i in range(k)
        ]
        b = [i for i in range(k)]
        FIRST = True
        while True in valid:

            h_e = [dec_state[i].output() for i in range(k)]
            ac = [self.__attention_mlp(h_fs_matrix, elem, F) for elem in h_e]
            alignment = [elem[0] for elem in ac]
            c_t = [elem[1] for elem in ac]
            embed_t = [
                dy.lookup(self.tgt_lookup, self.tgt_token_to_id.get(cw[i], 0))
                for i in range(k)
            ]
            x_t = [dy.concatenate([embed_t[i], c_t[i]]) for i in range(k)]
            dec_state = [dec_state[b[i]].add_input(x_t[i]) for i in range(k)]
            y_star = [
                dy.affine_transform([b_y, W_y, dec_state[i].output()])
                for i in range(k)
            ]
            p = [dy.log_softmax(y_star[i]) for i in range(k)]
            tmp = [p[i].npvalue() for i in range(k)]
            l = []
            if not FIRST:
                # Expand every live beam with its k best successor words.
                idx = [np.argpartition(-tmp[i].T, k)[0][:k] for i in range(k)]
                val = [-np.partition(-tmp[i].T, k)[0][:k] for i in range(k)]
                for i in range(k):
                    for j in range(k):
                        l.append((self.tgt_id_to_token[idx[i][j]],
                                  val[i][j] + prob[i], i))

            else:
                # First step: all beams share one state; take its top k words.
                idx = np.argpartition(-tmp[0].T, k)[0][:k]
                val = -np.partition(-tmp[0].T, k)[0][:k]
                FIRST = False
                for i in range(k):
                    l.append((self.tgt_id_to_token[idx[i]], val[i], i))

            # Keep the k best (word, score, parent-beam) candidates.
            l = sorted(l, key=lambda x: -x[1])[:k]
            cw = [l[i][0] for i in range(k)]
            prob = [l[i][1] for i in range(k)]
            b = [l[i][2] for i in range(k)]

            trans = [list(trans[b[i]]) for i in range(k)]
            valid = [valid[b[i]] for i in range(k)]

            for i in range(k):
                if valid[i] == False: continue
                if cw[i] != '</S>' and len(trans[i]) < MAX_LEN:
                    if cw[i] == '<unk>':  #unknown words replacement
                        a = alignment[i].npvalue()
                        am = np.argmax(a)
                        # Skip source positions this beam already used
                        # (was `am in used`, which never matched).
                        while am in used[i]:
                            a[am] = -1
                            am = np.argmax(a)
                        used[i].add(am)
                        if len(used[i]) == F // 2:
                            used[i] = set([0, F - 1])
                        trans[i].append(sent[am])
                    else:
                        trans[i].append(cw[i])
                else:
                    valid[i] = False

        index = prob.index(max(prob))
        return ' '.join(trans[index][1:])
Ejemplo n.º 23
0
def encode_batch(enc_fwd_lstm, enc_bwd_lstm, sentences):
	"""Bidirectionally encode a batch of sentences.

	BUG FIX(review): the original indented `bwd_vectors = bwd_vectors[::-1]`
	with spaces inside an otherwise tab-indented body (an indentation
	error); normalized to tabs, logic unchanged.
	"""
	global input_lookup
	fwd_state = enc_fwd_lstm.initial_state()
	bwd_state = enc_bwd_lstm.initial_state()

	# Time-major word-id batches plus padding masks.
	input_words, masks = sentences_to_batch(sentences)
	input_embeddings = [dy.lookup_batch(input_lookup, wids) for wids in input_words]
	input_embeddings_rev = input_embeddings[::-1]

	# Get the forward and backward encodings; re-reverse the backward pass
	# so both directions align per position.
	fwd_vectors = fwd_state.transduce(input_embeddings)
	bwd_vectors = bwd_state.transduce(input_embeddings_rev)
	bwd_vectors = bwd_vectors[::-1]

	input_vectors = [dy.concatenate(list(p)) for p in zip(fwd_vectors, bwd_vectors)]
	return input_vectors

#Return the context after calculating attention : MLP method
def attend_batch(input_mat, state, w1dt, batch_size, input_length):
	"""Batched MLP attention; returns the context vector expression.

	BUG FIX(review): the original was missing the final `return context`
	(so callers always received None) and mixed tabs with spaces; both
	repaired to match the sibling attend_batch earlier in this file.
	"""
	global attention_w2
	global attention_v
	w2 = dy.parameter(attention_w2)
	v = dy.parameter(attention_v)
	# Score every source position against the projected decoder state.
	w2dt = w2*dy.concatenate(list(state.s()))
	unnormalized = dy.transpose(v * dy.tanh(dy.colwise_add(w1dt, w2dt)))
	attention_reshaped = dy.reshape(unnormalized, (input_length, ), batch_size)
	att_weights = dy.softmax(attention_reshaped)
	# Context = attention-weighted sum of the encoder columns.
	context = input_mat * att_weights
	return context
Ejemplo n.º 24
0
    def translate_sentence(self, sent):
        """Greedily translate `sent` into the target language.

        '<unk>' predictions are replaced by the source word with the
        highest attention weight (sentence-boundary positions excluded).
        Returns the translation without the leading '<S>'.
        """
        dy.renew_cg()
        W_y = dy.parameter(self.W_y)
        b_y = dy.parameter(self.b_y)
        F = len(sent)

        sent_rev = list(reversed(sent))

        # Bidirectional representations
        l2r_state = self.l2r_builder.initial_state()
        r2l_state = self.r2l_builder.initial_state()
        l2r_contexts = []
        r2l_contexts = []
        for (cw_l2r, cw_r2l) in zip(sent, sent_rev):
            # Unknown source tokens map to id 0.
            l2r_state = l2r_state.add_input(
                dy.lookup(self.src_lookup, self.src_token_to_id.get(cw_l2r,
                                                                    0)))
            r2l_state = r2l_state.add_input(
                dy.lookup(self.src_lookup, self.src_token_to_id.get(cw_r2l,
                                                                    0)))
            l2r_contexts.append(
                l2r_state.output())  #[<S>, x_1, x_2, ..., </S>]
            r2l_contexts.append(
                r2l_state.output())  #[</S> x_n, x_{n-1}, ... <S>]
        r2l_contexts.reverse()

        # One concatenated bidirectional vector per source position.
        h_fs = []
        for (l2r_i, r2l_i) in zip(l2r_contexts, r2l_contexts):
            h_fs.append(dy.concatenate([l2r_i, r2l_i]))
        h_fs_matrix = dy.concatenate_cols(h_fs)

        # Decoder
        trans_sentence = ['<S>']
        cw = trans_sentence[-1]
        c_t = dy.vecInput(l2r_contexts[0].npvalue().shape[0] * 2)
        start = dy.concatenate(
            [dy.lookup(self.tgt_lookup, self.tgt_token_to_id['<S>']), c_t])
        dec_state = self.dec_builder.initial_state().add_input(start)
        while len(trans_sentence) < MAX_LEN:
            h_e = dec_state.output()
            alignment, c_t = self.__attention_mlp(h_fs_matrix, h_e, F)
            embed_t = dy.lookup(self.tgt_lookup,
                                self.tgt_token_to_id.get(cw, 0))
            # Create input vector to the decoder
            x_t = dy.concatenate([embed_t, c_t])
            dec_state = dec_state.add_input(x_t)
            y_star = dy.affine_transform([b_y, W_y, dec_state.output()])
            cw = self.tgt_id_to_token[np.argmax(y_star.npvalue())]
            if cw == '<unk>':  #unknown words replacement
                # Use the most-attended source word, skipping the
                # boundary positions if they win the argmax.
                a = alignment.npvalue()
                if np.argmax(a) == 0:
                    a[0] = 0
                if np.argmax(a) == len(a) - 1:
                    a[len(a) - 1] = 0
                trans_sentence.append(sent[np.argmax(a)])
                continue
            if cw == '</S>':
                break
            trans_sentence.append(cw)

        return ' '.join(trans_sentence[1:])