Example #1
    def predict_batch(self, words_batch):
        dynet.renew_cg()
        length = max(len(words) for words in words_batch)
        word_ids = np.zeros((length, len(words_batch)), dtype='int32')
        for j, words in enumerate(words_batch):
            for i, word in enumerate(words):
                word_ids[i, j] = self.vw.w2i.get(word, self.UNK)
        wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)]
        
        f_state = self._fwd_lstm.initial_state()
        b_state = self._bwd_lstm.initial_state()

        fw = [x.output() for x in f_state.add_inputs(wembs)]
        bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]

        H = dynet.parameter(self._pH)
        O = dynet.parameter(self._pO)
        
        tags_batch = [[] for _ in range(len(words_batch))]
        for i, (f, b) in enumerate(zip(fw, reversed(bw))):
            r_t = O * (dynet.tanh(H * dynet.concatenate([f, b])))
            out = dynet.softmax(r_t).npvalue()
            for j in range(len(words_batch)):
                tags_batch[j].append(self.vt.i2w[np.argmax(out.T[j])])
        return tags_batch
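The method above packs a batch of sentences into a (max_length x batch_size) matrix of word ids and performs one batched embedding lookup per time step. A minimal standalone sketch of that batching pattern follows; the toy vocabulary size, embedding dimension, and the names pc, E, sents, and ids are illustrative, not taken from the example above.

import dynet as dy
import numpy as np

# toy lookup table: a 10-word vocabulary embedded in 4 dimensions
pc = dy.ParameterCollection()
E = pc.add_lookup_parameters((10, 4))

dy.renew_cg()
# three sentences of word ids, already padded to the same length (pad id 0 here)
sents = [[1, 2, 3], [4, 5, 0], [6, 0, 0]]
ids = np.array(sents, dtype='int32').T   # shape (length, batch) = (3, 3)

# one batched lookup per time step: each expression holds the t-th word of every sentence
wembs = [dy.lookup_batch(E, ids[t]) for t in range(ids.shape[0])]
print(wembs[0].dim())                    # per-example shape and batch size, e.g. ((4,), 3)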
Example #2
    def update_batch(self, words_batch, tags_batch):
        dynet.renew_cg()
        length = max(len(words) for words in words_batch)
        word_ids = np.zeros((length, len(words_batch)), dtype='int32')
        for j, words in enumerate(words_batch):
            for i, word in enumerate(words):
                word_ids[i, j] = self.vw.w2i.get(word, self.UNK)
        tag_ids = np.zeros((length, len(words_batch)), dtype='int32')
        for j, tags in enumerate(tags_batch):
            for i, tag in enumerate(tags):
                tag_ids[i, j] = self.vt.w2i.get(tag, self.UNK)
        wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)]
        wembs = [dynet.noise(we, 0.1) for we in wembs]
        
        f_state = self._fwd_lstm.initial_state()
        b_state = self._bwd_lstm.initial_state()

        fw = [x.output() for x in f_state.add_inputs(wembs)]
        bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]

        H = dynet.parameter(self._pH)
        O = dynet.parameter(self._pO)
        
        errs = []
        for i, (f, b) in enumerate(zip(fw, reversed(bw))):
            f_b = dynet.concatenate([f,b])
            r_t = O * (dynet.tanh(H * f_b))
            err = dynet.pickneglogsoftmax_batch(r_t, tag_ids[i])
            errs.append(dynet.sum_batches(err))
        sum_errs = dynet.esum(errs)
        losses = sum_errs.scalar_value()
        sum_errs.backward()
        self._sgd.update()
        return losses
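update_batch above pads shorter sentences with id 0 and does not mask the loss, so the simplest driver feeds it mini-batches of equal-length sentences. A small sketch of such a driver, assuming a tagger object that exposes the update_batch method above and a list of (words, tags) training pairs; the helper name and batch size are illustrative.

from collections import defaultdict
import random

def iter_equal_length_batches(train_pairs, batch_size):
    """Group (words, tags) pairs by sentence length so no padding or masking is needed."""
    by_len = defaultdict(list)
    for words, tags in train_pairs:
        by_len[len(words)].append((words, tags))
    buckets = list(by_len.values())
    random.shuffle(buckets)
    for bucket in buckets:
        for start in range(0, len(bucket), batch_size):
            chunk = bucket[start:start + batch_size]
            yield [w for w, _ in chunk], [t for _, t in chunk]

# one epoch over the data (tagger and train_pairs are assumed to exist):
#   for words_batch, tags_batch in iter_equal_length_batches(train_pairs, 32):
#       total_loss += tagger.update_batch(words_batch, tags_batch)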
Example #3
def encode_batch(enc_fwd_lstm, enc_bwd_lstm, sentences):
	#print "in encode batch"
	global input_lookup
	
	input_words, masks = sentences_to_batch(sentences)
	input_embeddings = [dy.lookup_batch(input_lookup, wids) for wids in input_words]
	input_embeddings_rev = input_embeddings[::-1]
Example #4
 def embed_sentence_batch(self, sentence_batch):
     global input_lookup
     padded_sentence_batch, _ = self.create_sentence_batch(sentence_batch)
     batched_lookup = [
         dy.lookup_batch(input_lookup, wids)
         for wids in padded_sentence_batch
     ]
     return batched_lookup
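Several examples in this list call a padding helper (create_sentence_batch, sentences_to_batch, pad_zero) whose body is not shown. A plausible minimal sketch of such a helper, returning per-time-step id lists plus 0/1 masks; the pad id and the assumption that sentences are already lists of word ids are illustrative.

def create_sentence_batch(sentence_batch, pad_id=0):
    """Transpose a batch of id sequences into per-time-step id lists and 0/1 masks."""
    max_len = max(len(sent) for sent in sentence_batch)
    step_ids, step_masks = [], []
    for t in range(max_len):
        step_ids.append([sent[t] if t < len(sent) else pad_id
                         for sent in sentence_batch])
        step_masks.append([1 if t < len(sent) else 0
                           for sent in sentence_batch])
    return step_ids, step_masks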
Example #5
    def build_sentence_graph(self, sents, labels):
        renew_cg()
        f_init = self.builder.initial_state()

        STOP = vocab.w2i["<stop>"]
        START = vocab.w2i["<start>"]

        W_exp = parameter(self.R)
        b_exp = parameter(self.bias)
        state = f_init

        # get the wids and masks for each step
        tot_words = 0
        wids = []
        masks = []

        for i in range(len(sents[0])):
            wids.append([(START if len(sents[0]) - len(sent) > i else
                          vocab.w2i[sent[i - len(sents[0]) + len(sent)]])
                         for sent in sents])
            #wids.append([(vocab.w2i[sent[i]] if len(sent) > i else STOP) for sent in sents])
            mask = [(1 if len(sent) > i else 0) for sent in sents]
            masks.append(mask)
            tot_words += sum(mask)

        #print "wids:"
        #print wids

        # start the rnn by inputting "<start>"
        init_ids = [START] * len(sents)
        s = f_init.add_input(dy.lookup_batch(self.lookup, init_ids))

        # feed word vectors into the RNN and predict the next word
        losses = []
        for wid in wids:
            # calculate the softmax and loss
            score = W_exp * s.output() + b_exp
            # update the state of the RNN
            wemb = dy.lookup_batch(self.lookup, wid)
            s = s.add_input(wemb)

        loss = dy.pickneglogsoftmax_batch(score, labels)

        losses.append(loss)
        return dy.sum_batches(dy.esum(losses))
Example #6
    def decode(self, dec_lstm, vectors, output_batch):
        # output = [EOS] + list(output) + [EOS]
        # output = [char2int[c] for c in output]

        w = dy.parameter(decoder_w)
        b = dy.parameter(decoder_b)
        w1 = dy.parameter(attention_w1)
        input_mat = dy.concatenate_cols(vectors)
        input_len = len(vectors)

        w1dt = w1 * input_mat

        curr_bsize = len(output_batch)

        #n_batch_start = [1] * curr_bsize
        #last_output_embeddings = dy.lookup_batch(output_lookup, n_batch_start)

        s = dec_lstm.initial_state()
        c_t_previous = dy.vecInput(STATE_SIZE * 2)

        loss = []

        output_batch, masks = self.create_sentence_batch(output_batch)

        for i, (word_batch,
                mask_word) in enumerate(zip(output_batch[1:], masks[1:]),
                                        start=1):
            last_output_embeddings = dy.lookup_batch(output_lookup,
                                                     output_batch[i - 1])
            vector = dy.concatenate([last_output_embeddings, c_t_previous])
            s = s.add_input(vector)
            h_t = s.output()
            c_t, alpha_t = self.attend(input_mat, s, w1dt, input_len,
                                       curr_bsize)

            h_c_concat = dy.concatenate([h_t, c_t])
            out_vector = dy.affine_transform([b, w, h_c_concat])

            if DROPOUT > 0.0:
                out_vector = dy.dropout(out_vector, DROPOUT)

            loss_current = dy.pickneglogsoftmax_batch(out_vector,
                                                      output_batch[i])

            if 0 in mask_word:
                mask_vals = dy.inputVector(mask_word)
                mask_vals = dy.reshape(mask_vals, (1, ), curr_bsize)
                loss_current = loss_current * mask_vals

            loss.append(loss_current)
            c_t_previous = c_t

        loss = dy.esum(loss)
        loss = dy.sum_batches(loss) / curr_bsize
        #perplexity = loss.value() * curr_bsize / float(sum([x.count(1) for x in masks[1:]]))
        return loss
Example #7
def calc_lm_loss(sents):

    dy.renew_cg()
    # parameters -> expressions
    W_exp = dy.parameter(W_sm)
    b_exp = dy.parameter(b_sm)

    # initialize the RNN
    f_init = RNN.initial_state()

    # get the wids and masks for each step
    tot_words = 0
    wids = []
    masks = []

    for i in range(len(sents[0])):
        wids.append([(vw.w2i[sent[i]] if len(sent) > i else STOP) for sent in sents])
        mask = [(1 if len(sent) > i else 0) for sent in sents]
        masks.append(mask)
        tot_words += sum(mask)

    # start the rnn by inputting "<start>"
    init_ids = [START] * len(sents)
    s = f_init.add_input(dy.lookup_batch(WORDS_LOOKUP, init_ids))

    # feed word vectors into the RNN and predict the next word
    losses = []
    for wid, mask in zip(wids, masks):
        # calculate the softmax and loss
        score = W_exp * s.output() + b_exp
        loss = dy.pickneglogsoftmax_batch(score, wid)
        # mask the loss if at least one sentence is shorter
        if mask[-1] != 1:
            mask_expr = dy.inputVector(mask)
            mask_expr = dy.reshape(mask_expr, (1,), MB_SIZE)
            loss = loss * mask_expr
        losses.append(loss)
        # update the state of the RNN
        wemb = dy.lookup_batch(WORDS_LOOKUP, wid)
        s = s.add_input(wemb)

    return dy.sum_batches(dy.esum(losses)), tot_words
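calc_lm_loss returns one summed batch-loss expression plus the number of real (unmasked) words, so the surrounding training code only needs a forward value, a backward pass, and a trainer update per mini-batch. A sketch of such a loop, assuming train is a list of sentences (lists of words), that model, MB_SIZE, and the parameters used inside calc_lm_loss are already set up, and that a plain SGD trainer is acceptable.

import dynet as dy

trainer = dy.SimpleSGDTrainer(model)   # `model` is assumed to own the parameters above

# sorting by length keeps the sentences in a mini-batch roughly the same size,
# which reduces the amount of masked padding
train.sort(key=lambda s: -len(s))
# keep only full mini-batches, since calc_lm_loss reshapes its mask with MB_SIZE
batches = [train[i:i + MB_SIZE] for i in range(0, len(train) - MB_SIZE + 1, MB_SIZE)]

for epoch in range(5):
    epoch_loss, epoch_words = 0.0, 0
    for sents in batches:
        loss_exp, n_words = calc_lm_loss(sents)
        epoch_loss += loss_exp.scalar_value()   # forward pass
        loss_exp.backward()
        trainer.update()
        epoch_words += n_words
    print("per-word loss:", epoch_loss / epoch_words)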
Example #8
def encode_batch(enc_fwd_lstm, enc_bwd_lstm, sentences):
	#print "in encode batch"
	global input_lookup
	fwd_state = enc_fwd_lstm.initial_state()
	bwd_state = enc_bwd_lstm.initial_state()

	input_words, masks = sentences_to_batch(sentences)
	input_embeddings = [dy.lookup_batch(input_lookup, wids) for wids in input_words]
	input_embeddings_rev = input_embeddings[::-1]

	#Get the forward and backward encodings
	fwd_vectors = fwd_state.transduce(input_embeddings)
	bwd_vectors = bwd_state.transduce(input_embeddings_rev)
Example #9
    def build_sentence_graph(self, sents, labels):
        renew_cg()
        f_init = self.builder.initial_state()

        START = vocab.w2i["<start>"]

        W_exp = parameter(self.R)
        b_exp = parameter(self.bias)
        state = f_init

        # get the wids and masks for each step
        tot_words = 0
        wids = []
        masks = []

        #pad the sequences that are shorter, with a START at the beginning
        for i in range(len(sents[0])):
            wids.append([(START if len(sents[0]) - len(sent) > i else
                          vocab.w2i[sent[i - len(sents[0]) + len(sent)]])
                         for sent in sents])
            mask = [(1 if len(sent) > i else 0) for sent in sents]
            masks.append(mask)
            tot_words += sum(mask)

        init_ids = [START] * len(sents)
        s = f_init.add_input(dy.lookup_batch(self.lookup, init_ids))

        losses = []
        for wid in wids:
            # calculate the softmax and loss
            score = W_exp * s.output() + b_exp
            # update the state of the RNN
            wemb = dy.lookup_batch(self.lookup, wid)
            s = s.add_input(wemb)

        loss = dy.pickneglogsoftmax_batch(score, labels)
        losses.append(loss)
        return dy.sum_batches(dy.esum(losses))
Example #10
    def build_sentence_graph(self, sents, labels):
        renew_cg()
        f_init = self.builder.initial_state()

        START = vocab.w2i["<start>"]

        W_exp = parameter(self.R)
        b_exp = parameter(self.bias)
        state = f_init

        # get the wids and masks for each step
        tot_words = 0
        wids = []
        masks = []

        #pad the sequences that are shorter, with a START at the beginning
        for i in range(len(sents[0])):
            wids.append([(START if len(sents[0])-len(sent) > i else vocab.w2i[sent[i - len(sents[0])+len(sent)]]) for sent in sents])
            mask = [(1 if len(sent) > i else 0) for sent in sents]
            masks.append(mask)
            tot_words += sum(mask)

        init_ids = [START] * len(sents)
        s = f_init.add_input(dy.lookup_batch(self.lookup, init_ids))

        losses = []
        for wid in wids:
            # calculate the softmax and loss
            score = W_exp * s.output() + b_exp
            # update the state of the RNN
            wemb = dy.lookup_batch(self.lookup, wid)
            s = s.add_input(wemb)

        loss = dy.pickneglogsoftmax_batch(score, labels)
        losses.append(loss)
        return dy.sum_batches(dy.esum(losses))
Example #11
def decode_batch(dec_lstm, input_encodings, output_sentences):
	#print "in decode batch"
	w = dy.parameter(decoder_w)
	b = dy.parameter(decoder_b)
	w1 = dy.parameter(attention_w1)

	output_words, masks = sentences_to_batch(output_sentences)
	decoder_target_input = zip(output_words[1:], masks[1:])

	batch_size = len(output_sentences)
	input_length = len(input_encodings)

	input_mat = dy.concatenate_cols(input_encodings)
	#print "Computing w1dt"
	w1dt = w1 * input_mat

	s = dec_lstm.initial_state()
	c_t_minus_1 = dy.vecInput(state_size * 2)
	loss = []

	for t, (words_t, mask_t) in enumerate(decoder_target_input, start = 1):
		last_output_embeddings = dy.lookup_batch(output_lookup, output_words[t-1])
		vector = dy.concatenate([last_output_embeddings, c_t_minus_1])
		s = s.add_input(vector)
		h_t = s.output()
		#print "Calling attend"
		c_t = attend_batch(input_mat, s, w1dt, batch_size, input_length)
		predicted = dy.affine_transform([b, w, dy.concatenate([h_t, c_t])])
		if(dropout > 0.):
			predicted = dy.dropout(predicted, dropout)
		cur_loss = dy.pickneglogsoftmax_batch(predicted, words_t)
		c_t_minus_1 = c_t

		#Mask the loss in case mask == 0
		if 0 in mask_t:
		    mask = dy.inputVector(mask_t)
		    mask = dy.reshape(mask, (1, ), batch_size)
		    cur_loss = cur_loss * mask

		loss.append(cur_loss)
	
	#Get the average batch loss
	loss = dy.esum(loss)
	loss = dy.sum_batches(loss) / batch_size
	return loss
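The masking idiom used by this decoder, dy.inputVector followed by dy.reshape with a batch dimension, turns a plain 0/1 list into a batched scalar that zeroes out the loss of padded positions when multiplied in. A tiny standalone demonstration with made-up scores and targets; the lookup table here only exists to produce a batched expression.

import dynet as dy

pc = dy.ParameterCollection()
E = pc.add_lookup_parameters((10, 5))   # 10 rows of 5 values, used as fake class scores

dy.renew_cg()
scores = dy.lookup_batch(E, [1, 2, 3])              # dim (5,), batch size 3
targets = [2, 0, 4]                                 # gold class per batch element
loss = dy.pickneglogsoftmax_batch(scores, targets)  # batched scalar loss

# pretend the third element is padding: zero out its loss
mask = dy.reshape(dy.inputVector([1, 1, 0]), (1,), 3)
masked = loss * mask

print(loss.npvalue())                               # the three per-example losses
print(dy.sum_batches(masked).scalar_value())        # total loss over the real elements only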
Example #12
    def myDecode(self, dec_lstm, h_encodings, target_sents):
        w = dy.parameter(decoder_w)
        b = dy.parameter(decoder_b)
        w1 = dy.parameter(attention_w1)

        dec_wrds, dec_mask = self.pad_zero(target_sents)
        curr_bsize = len(target_sents)
        h_len = len(h_encodings)

        H_source = dy.concatenate_cols(h_encodings)
        s = dec_lstm.initial_state()
        ctx_t0 = dy.vecInput(hidden_size * 2)
        w1dt = w1 * H_source

        loss = []

        #print curr_bsize
        for sent in range(1, len(dec_wrds)):
            last_output_embeddings = dy.lookup_batch(output_lookup, dec_wrds[sent-1])
            x = dy.concatenate([ctx_t0, last_output_embeddings])
            s = s.add_input(x)
            h_t = s.output()

            ctx_t, alpha_t = self.attend(H_source, s, w1dt, h_len, curr_bsize)
            output_vector = w * dy.concatenate([h_t, ctx_t]) + b
            #probs = dy.softmax(output_vector)
            ctx_t0 = ctx_t
            if dropout_config:
                output_vector = dy.dropout(output_vector, dropout_val)

            temp_loss = dy.pickneglogsoftmax_batch(output_vector, dec_wrds[sent])

            if 0 in dec_mask[sent]:
                mask_expr = dy.inputVector(dec_mask[sent])
                mask_expr = dy.reshape(mask_expr, (1, ), curr_bsize)
                temp_loss = temp_loss * mask_expr

            loss.append(temp_loss)

        loss = dy.esum(loss)
        loss = dy.sum_batches(loss) / curr_bsize
        return loss
Example #13
def decode_batch(dec_lstm, input_encodings, output_sentences):
	#print "in decode batch"
	w = dy.parameter(decoder_w)
	b = dy.parameter(decoder_b)
	w1 = dy.parameter(attention_w1)

	output_words, masks = sentences_to_batch(output_sentences)
	#decoder_target_input = zip(output_words[1:], masks[1:])

	batch_size = len(output_sentences)
	input_length = len(input_encodings)

	input_mat = dy.concatenate_cols(input_encodings)
	#print "Computing w1dt"
	w1dt = w1 * input_mat

	s = dec_lstm.initial_state()
	c_t_minus_1 = dy.vecInput(state_size * 2)
	loss = []

	for t in range(1, len(output_words)):
	    last_output_embeddings = dy.lookup_batch(output_lookup, output_words[t-1])
	    vector = dy.concatenate([c_t_minus_1, last_output_embeddings])
	    s = s.add_input(vector)
	    h_t = s.output()
	    #print "Calling attend"
	    c_t = attend_batch(input_mat, s, w1dt, batch_size, input_length)
	    predicted = w * dy.concatenate([h_t, c_t]) + b
	    if(dropout > 0.):
	        predicted = dy.dropout(predicted, dropout)
	    cur_loss = dy.pickneglogsoftmax_batch(predicted, output_words[t])
	    c_t_minus_1 = c_t

	    #Mask the loss in case mask == 0
	    if 0 in masks[t]:
	        mask = dy.inputVector(masks[t])
	        mask = dy.reshape(mask, (1, ), batch_size)
	        cur_loss = cur_loss * mask

	    loss.append(cur_loss)

	#Get the average batch loss
	loss = dy.esum(loss)
	loss = dy.sum_batches(loss) / batch_size
	return loss
Example #14
 def embed_sentence(self, sentence, generation=False):
     if generation:
         embeddings = [input_lookup[word] for word in sentence]
     else:
         embeddings = [dy.lookup_batch(input_lookup, wids) for wids in sentence]
     return embeddings
Example #15
def build_batch_sentences_graph(isents, builders):
    #NEXT STAGE - remove sentences that are not the same length from batch, and remove all mask stuff
    renew_cg()
    f_init, b_init = [b.initial_state() for b in builders]
    W = parameter(W_sm)
    b = parameter(b_sm)

    #Train first forward direction LSTM
    #Init forward aux variables
    f_tot_words = 0
    f_wids_list = []
    f_masks = []
    #Prepare the batch. Mask shorter sentences
    for i in range(len(isents[0])):
        f_wids_list.append([(vw.w2i[sent[i]] if len(sent) > i else STOP) for sent in isents])
        mask = [(1 if len(sent) > i else 0) for sent in isents]
        f_masks.append(mask)
        f_tot_words += sum(mask)
    #Feed the batched words into forward lstm
    f_outputs = []
    state = f_init
    for wids in f_wids_list: #wids is a list of i-th word in each sentence across the batch
        wembs = dy.lookup_batch(WORDS_LOOKUP, wids)
        state = state.add_input(wembs)
        f_outputs.append(state.output()) #get output batch

    #Train backward direction of lstm
    b_tot_words = 0
    b_wids_list = []
    b_masks = []
    # Prepare the batch. Mask shorter sentences
    for i in range(len(isents[0])):
        b_wids_list.append([(vw.w2i[sent[::-1][i]] if len(sent) > i else START) for sent in isents])
        mask = [(1 if len(sent) > i else 0) for sent in isents]
        b_masks.append(mask)
        b_tot_words += sum(mask)
    # Feed the batched words into backward lstm
    b_outputs = []
    state = b_init
    for wids in b_wids_list:  # wids is a list of i-th word in each sentence across the batch
        wembs = dy.lookup_batch(WORDS_LOOKUP, wids)
        state = state.add_input(wembs)
        b_outputs.append(state.output())  # get output batch

    #Compute loss via batch softmax
    errs = []

    for i in range(len(isents[0])-3):
        y = concatenate([f_outputs[i], b_outputs[::-1][i+2]])
        r = b + (W * y)
        err = pickneglogsoftmax_batch(r, f_wids_list[i+1])
        #print "FMASK", f_masks
        if f_masks[i][-1] != 1 or b_masks[::-1][i+2][-1] != 1:
            # combine the forward and backward masks for this position elementwise,
            # so padded sentences contribute zero loss
            complete_mask = [fm * bm for fm, bm in zip(f_masks[i], b_masks[::-1][i+2])]
            mask_expr = dy.inputVector(complete_mask)
            mask_expr = dy.reshape(mask_expr, (1,), MB_SIZE)
            err = err * mask_expr
        errs.append(err)

    losses = sum_batches(esum(errs))
    return losses
Example #16
def build_batch_sentences_graph(isents, builders):
    #NEXT STAGE - remove sentences that are not the same length from batch, and remove all mask stuff
    renew_cg()
    f_init, b_init = [b.initial_state() for b in builders]
    W = parameter(W_sm)
    b = parameter(b_sm)

    #Train first forward direction LSTM
    #Init forward aux variables
    f_tot_words = 0
    f_wids_list = []
    f_masks = []
    #Prepare the batch. Mask shorter sentences
    for i in range(len(isents[0])):
        f_wids_list.append([(vw.w2i[sent[i]] if len(sent) > i else STOP)
                            for sent in isents])
        mask = [(1 if len(sent) > i else 0) for sent in isents]
        f_masks.append(mask)
        f_tot_words += sum(mask)
    #Feed the batched words into forward lstm
    f_outputs = []
    state = f_init
    for wids in f_wids_list:  #wids is a list of i-th word in each sentence across the batch
        wembs = dy.lookup_batch(WORDS_LOOKUP, wids)
        state = state.add_input(wembs)
        f_outputs.append(state.output())  #get output batch

    #Train backward direction of lstm
    b_tot_words = 0
    b_wids_list = []
    b_masks = []
    # Prepare the batch. Mask shorter sentences
    for i in range(len(isents[0])):
        b_wids_list.append([(vw.w2i[sent[::-1][i]] if len(sent) > i else START)
                            for sent in isents])
        mask = [(1 if len(sent) > i else 0) for sent in isents]
        b_masks.append(mask)
        b_tot_words += sum(mask)
    # Feed the batched words into backward lstm
    b_outputs = []
    state = b_init
    for wids in b_wids_list:  # wids is a list of i-th word in each sentence across the batch
        wembs = dy.lookup_batch(WORDS_LOOKUP, wids)
        state = state.add_input(wembs)
        b_outputs.append(state.output())  # get output batch

    #Compute loss via batch softmax
    errs = []

    for i in range(len(isents[0]) - 3):
        y = concatenate([f_outputs[i], b_outputs[::-1][i + 2]])
        r = b + (W * y)
        err = pickneglogsoftmax_batch(r, f_wids_list[i + 1])
        #print "FMASK", f_masks
        if f_masks[i][-1] != 1 or b_masks[::-1][i + 2][-1] != 1:
            # combine the forward and backward masks for this position elementwise,
            # so padded sentences contribute zero loss
            complete_mask = [fm * bm for fm, bm in zip(f_masks[i], b_masks[::-1][i + 2])]
            mask_expr = dy.inputVector(complete_mask)
            mask_expr = dy.reshape(mask_expr, (1, ), MB_SIZE)
            err = err * mask_expr
        errs.append(err)

    losses = sum_batches(esum(errs))
    return losses
Example #17
    def __step_batch(self, batch):
        dy.renew_cg()
        W_y = dy.parameter(self.W_y)
        b_y = dy.parameter(self.b_y)
        F = len(batch[0][0])
        num_words = F * len(batch)

        src_batch = [x[0] for x in batch]
        tgt_batch = [x[1] for x in batch]
        src_rev_batch = [list(reversed(x)) for x in src_batch]
        # batch  = [ [a1,a2,a3,a4,a5], [b1,b2,b3,b4,b5], [c1,c2,c3,c4] ..]
        # transpose the batch into
        #   src_cws: [[a1,b1,c1,..], [a2,b2,c2,..], .. [a5,b5,</S>]]

        src_cws = map(list, zip(*src_batch))  # transpose
        src_rev_cws = map(list, zip(*src_rev_batch))

        # Bidirectional representations
        l2r_state = self.l2r_builder.initial_state()
        r2l_state = self.r2l_builder.initial_state()
        l2r_contexts = []
        r2l_contexts = []
        for (cw_l2r_list, cw_r2l_list) in zip(src_cws, src_rev_cws):
            l2r_state = l2r_state.add_input(
                dy.lookup_batch(self.src_lookup, [
                    self.src_token_to_id.get(cw_l2r, 0)
                    for cw_l2r in cw_l2r_list
                ]))
            r2l_state = r2l_state.add_input(
                dy.lookup_batch(self.src_lookup, [
                    self.src_token_to_id.get(cw_r2l, 0)
                    for cw_r2l in cw_r2l_list
                ]))

            l2r_contexts.append(
                l2r_state.output())  #[<S>, x_1, x_2, ..., </S>]
            r2l_contexts.append(
                r2l_state.output())  #[</S> x_n, x_{n-1}, ... <S>]

        r2l_contexts.reverse()  #[<S>, x_1, x_2, ..., </S>]

        # Combine the left and right representations for every word
        h_fs = []
        for (l2r_i, r2l_i) in zip(l2r_contexts, r2l_contexts):
            h_fs.append(dy.concatenate([l2r_i, r2l_i]))
        h_fs_matrix = dy.concatenate_cols(h_fs)

        losses = []

        # Decoder
        # batch  = [ [a1,a2,a3,a4,a5], [b1,b2,b3,b4,b5], [c1,c2,c3,c4] ..]
        # transpose the batch into
        #   tgt_cws: [[a1,b1,c1,..], [a2,b2,c2,..], .. [a5,b5,</S>]]
        #   masks: [1,1,1,..], [1,1,1,..], ...[1,1,0,..]]

        tgt_cws = []
        masks = []
        maxLen = max([len(l) for l in tgt_batch])
        for i in range(maxLen):
            tgt_cws.append([])
            masks.append([])
        for sentence in tgt_batch:
            for j in range(maxLen):
                if j > len(sentence) - 1:
                    tgt_cws[j].append('</S>')
                    masks[j].append(0)
                else:
                    tgt_cws[j].append(sentence[j])
                    masks[j].append(1)
        c_t = dy.vecInput(self.hidden_size * 2)
        start = dy.concatenate(
            [dy.lookup(self.tgt_lookup, self.tgt_token_to_id['<S>']), c_t])
        dec_state = self.dec_builder.initial_state().add_input(start)
        for (cws, nws, mask) in zip(tgt_cws, tgt_cws[1:], masks):
            h_e = dec_state.output()
            _, c_t = self.__attention_mlp(h_fs_matrix, h_e, F)
            # Get the embedding for the current target word
            embed_t = dy.lookup_batch(
                self.tgt_lookup,
                [self.tgt_token_to_id.get(cw, 0) for cw in cws])
            # Create input vector to the decoder
            x_t = dy.concatenate([embed_t, c_t])
            dec_state = dec_state.add_input(x_t)
            y_star = dy.affine_transform([b_y, W_y, dec_state.output()])
            loss = dy.pickneglogsoftmax_batch(
                y_star, [self.tgt_token_to_id.get(nw, 0) for nw in nws])
            mask_exp = dy.reshape(dy.inputVector(mask), (1, ), len(mask))
            loss = loss * mask_exp
            losses.append(loss)
        return dy.sum_batches(dy.esum(losses)), num_words
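__step_batch returns the summed batch loss together with a word count, so the surrounding training code only needs a forward value, a backward pass, and a trainer update per mini-batch. A hedged sketch of a sibling method that could sit in the same class; self.trainer (e.g. a dy.AdamTrainer) and the shape of train_batches are assumptions, not shown in the example.

    def train_epoch(self, train_batches):
        """Illustrative driver for __step_batch; assumes self.trainer is a dynet trainer."""
        total_loss, total_words = 0.0, 0
        for batch in train_batches:
            loss, num_words = self.__step_batch(batch)
            total_loss += loss.value()   # forces the forward pass
            loss.backward()
            self.trainer.update()
            total_words += num_words
        return total_loss / max(total_words, 1)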