def decode(self, dec_lstm, vectors, output_batch):
    """Return the batch-averaged, masked loss of the attentional decoder.

    Args:
        dec_lstm: decoder RNN builder; a fresh state is taken from
            ``dec_lstm.initial_state()``.
        vectors: list of encoder output expressions; they are joined
            column-wise into the attention memory.
        output_batch: list of target sentences. Its length is used as the
            batch size, and ``self.create_sentence_batch`` turns it into
            per-step word ids and 0/1 masks.

    Returns:
        A DyNet expression: per-step losses summed over time steps, summed
        over the batch, and divided by the number of sentences.
    """
    W_out = dy.parameter(decoder_w)
    b_out = dy.parameter(decoder_b)
    W_att = dy.parameter(attention_w1)
    memory = dy.concatenate_cols(vectors)
    memory_len = len(vectors)

    # Project the memory once; attend() reuses the projection every step.
    projected_memory = W_att * memory

    batch_size = len(output_batch)

    state = dec_lstm.initial_state()
    # Previous attention context fed with each input token (starts as zeros).
    prev_context = dy.vecInput(STATE_SIZE * 2)

    step_losses = []

    padded_ids, step_masks = self.create_sentence_batch(output_batch)

    targets = zip(padded_ids[1:], step_masks[1:])
    for t, (target_ids, mask_t) in enumerate(targets, start=1):
        prev_embeddings = dy.lookup_batch(output_lookup, padded_ids[t - 1])
        state = state.add_input(dy.concatenate([prev_embeddings, prev_context]))
        hidden = state.output()
        context, _ = self.attend(memory, state, projected_memory, memory_len,
                                 batch_size)

        features = dy.concatenate([hidden, context])
        scores = dy.affine_transform([b_out, W_out, features])

        if DROPOUT > 0.0:
            scores = dy.dropout(scores, DROPOUT)

        step_loss = dy.pickneglogsoftmax_batch(scores, target_ids)

        # Only build a mask expression when some sentence is already padding.
        if 0 in mask_t:
            mask_expr = dy.reshape(dy.inputVector(mask_t), (1, ), batch_size)
            step_loss = step_loss * mask_expr

        step_losses.append(step_loss)
        prev_context = context

    total = dy.esum(step_losses)
    return dy.sum_batches(total) / batch_size
Пример #2
0
def decode_batch(dec_lstm, input_encodings, output_sentences):
    """Return the batch-averaged, masked loss of the attentional decoder.

    Args:
        dec_lstm: decoder RNN builder; a fresh state is taken from
            ``dec_lstm.initial_state()``.
        input_encodings: list of encoder output expressions; they are joined
            column-wise into the attention memory.
        output_sentences: list of target sentences, converted by
            ``sentences_to_batch`` into per-step word ids and 0/1 masks.

    Returns:
        A DyNet expression: per-step losses summed over time steps, summed
        over the batch, and divided by the number of sentences.
    """
    # Indentation is spaces-only: the original mixed tabs and spaces inside
    # the loop body, which Python 3 rejects with TabError.
    w = dy.parameter(decoder_w)
    b = dy.parameter(decoder_b)
    w1 = dy.parameter(attention_w1)

    output_words, masks = sentences_to_batch(output_sentences)
    # Step t reads word t-1 and predicts word t, masked by masks[t].
    decoder_target_input = zip(output_words[1:], masks[1:])

    batch_size = len(output_sentences)
    input_length = len(input_encodings)

    input_mat = dy.concatenate_cols(input_encodings)
    # Project the encoder states once; attend_batch reuses it every step.
    w1dt = w1 * input_mat

    s = dec_lstm.initial_state()
    # Previous attention context (starts as zeros).
    c_t_minus_1 = dy.vecInput(state_size * 2)
    loss = []

    for t, (words_t, mask_t) in enumerate(decoder_target_input, start=1):
        last_output_embeddings = dy.lookup_batch(output_lookup,
                                                 output_words[t - 1])
        vector = dy.concatenate([last_output_embeddings, c_t_minus_1])
        s = s.add_input(vector)
        h_t = s.output()
        c_t = attend_batch(input_mat, s, w1dt, batch_size, input_length)
        predicted = dy.affine_transform([b, w, dy.concatenate([h_t, c_t])])
        if dropout > 0.0:
            predicted = dy.dropout(predicted, dropout)
        cur_loss = dy.pickneglogsoftmax_batch(predicted, words_t)
        c_t_minus_1 = c_t

        # Zero the loss of sentences that are already padding at step t.
        if 0 in mask_t:
            mask = dy.inputVector(mask_t)
            mask = dy.reshape(mask, (1, ), batch_size)
            cur_loss = cur_loss * mask

        loss.append(cur_loss)

    # Average the summed loss over the sentences in this batch.
    loss = dy.esum(loss)
    return dy.sum_batches(loss) / batch_size
Пример #3
0
    def myDecode(self, dec_lstm, h_encodings, target_sents):
        """Return the batch-averaged, masked loss of the attentional decoder.

        Args:
            dec_lstm: decoder RNN builder; a fresh state is taken from
                ``dec_lstm.initial_state()``.
            h_encodings: list of encoder output expressions; they are joined
                column-wise into the attention memory.
            target_sents: list of target sentences, padded by
                ``self.pad_zero`` into per-step word ids and 0/1 masks.

        Returns:
            A DyNet expression: masked per-step losses summed over time steps
            and over the batch, divided by the number of sentences in THIS
            batch.
        """
        w = dy.parameter(decoder_w)
        b = dy.parameter(decoder_b)
        w1 = dy.parameter(attention_w1)

        dec_wrds, dec_mask = self.pad_zero(target_sents)
        curr_bsize = len(target_sents)
        h_len = len(h_encodings)

        H_source = dy.concatenate_cols(h_encodings)
        s = dec_lstm.initial_state()
        # Previous attention context (starts as zeros).
        ctx_t0 = dy.vecInput(hidden_size * 2)
        # Project the encoder states once; attend() reuses it every step.
        w1dt = w1 * H_source

        loss = []

        # Step t reads word t-1 and predicts word t.
        for step in range(1, len(dec_wrds)):
            last_output_embeddings = dy.lookup_batch(output_lookup,
                                                     dec_wrds[step - 1])
            x = dy.concatenate([ctx_t0, last_output_embeddings])
            s = s.add_input(x)
            h_t = s.output()

            ctx_t, _ = self.attend(H_source, s, w1dt, h_len, curr_bsize)
            output_vector = w * dy.concatenate([h_t, ctx_t]) + b
            ctx_t0 = ctx_t
            if dropout_config:
                output_vector = dy.dropout(output_vector, dropout_val)

            temp_loss = dy.pickneglogsoftmax_batch(output_vector,
                                                   dec_wrds[step])

            # Zero the loss of sentences that are already padding here.
            if 0 in dec_mask[step]:
                mask_expr = dy.inputVector(dec_mask[step])
                mask_expr = dy.reshape(mask_expr, (1, ), curr_bsize)
                temp_loss = temp_loss * mask_expr

            loss.append(temp_loss)

        loss = dy.esum(loss)
        # Divide by this batch's own size (not an outside `batch_size`), so
        # a smaller batch is averaged correctly.
        return dy.sum_batches(loss) / curr_bsize
Пример #4
0
def calc_lm_loss(sents):
    """Build a fresh graph and return the masked, batched language-model loss.

    Sentences shorter than the longest one are padded with ``STOP``, and
    their padded positions are masked out of the loss.

    Args:
        sents: list of token sequences; tokens are mapped to ids via
            ``vw.w2i``.

    Returns:
        ``(loss, tot_words)``:
        - ``loss`` is the masked negative log-likelihood, summed over steps
          and over the batch.
        - ``tot_words`` is the number of real (unpadded) tokens.
    """
    dy.renew_cg()
    # parameters -> expressions
    W_exp = dy.parameter(W_sm)
    b_exp = dy.parameter(b_sm)

    # initialize the RNN
    f_init = RNN.initial_state()

    # Use the actual batch size (a final batch may be smaller than MB_SIZE)
    # and the actual longest sentence (the batch need not be sorted).
    batch_size = len(sents)
    max_len = max(len(sent) for sent in sents)

    # get the wids and masks for each step
    tot_words = 0
    wids = []
    masks = []
    for i in range(max_len):
        wids.append([vw.w2i[sent[i]] if len(sent) > i else STOP
                     for sent in sents])
        mask = [1 if len(sent) > i else 0 for sent in sents]
        masks.append(mask)
        tot_words += sum(mask)

    # start the rnn by inputting "<start>"
    s = f_init.add_input(dy.lookup_batch(WORDS_LOOKUP, [START] * batch_size))

    # feed word vectors into the RNN and predict the next word
    losses = []
    for wid, mask in zip(wids, masks):
        # calculate the softmax and loss
        score = W_exp * s.output() + b_exp
        loss = dy.pickneglogsoftmax_batch(score, wid)
        # Mask the loss when any sentence is already padding at this step;
        # this does not depend on the batch being sorted by length.
        if 0 in mask:
            mask_expr = dy.inputVector(mask)
            mask_expr = dy.reshape(mask_expr, (1,), batch_size)
            loss = loss * mask_expr
        losses.append(loss)
        # update the state of the RNN
        s = s.add_input(dy.lookup_batch(WORDS_LOOKUP, wid))

    return dy.sum_batches(dy.esum(losses)), tot_words
Пример #5
0
def decode_batch(dec_lstm, input_encodings, output_sentences):
    """Return the batch-averaged, masked loss of the attentional decoder.

    Args:
        dec_lstm: decoder RNN builder; a fresh state is taken from
            ``dec_lstm.initial_state()``.
        input_encodings: list of encoder output expressions; they are joined
            column-wise into the attention memory.
        output_sentences: list of target sentences, converted by
            ``sentences_to_batch`` into per-step word ids and 0/1 masks.

    Returns:
        A DyNet expression: per-step losses summed over time steps, summed
        over the batch, and divided by the number of sentences.
    """
    w = dy.parameter(decoder_w)
    b = dy.parameter(decoder_b)
    w1 = dy.parameter(attention_w1)

    output_words, masks = sentences_to_batch(output_sentences)

    batch_size = len(output_sentences)
    input_length = len(input_encodings)

    input_mat = dy.concatenate_cols(input_encodings)
    # Project the encoder states once; attend_batch reuses it every step.
    w1dt = w1 * input_mat

    s = dec_lstm.initial_state()
    # Previous attention context (starts as zeros).
    c_t_minus_1 = dy.vecInput(state_size * 2)
    loss = []

    # Step t reads word t-1 and predicts word t.
    for t in range(1, len(output_words)):
        last_output_embeddings = dy.lookup_batch(output_lookup,
                                                 output_words[t - 1])
        vector = dy.concatenate([c_t_minus_1, last_output_embeddings])
        s = s.add_input(vector)
        h_t = s.output()
        c_t = attend_batch(input_mat, s, w1dt, batch_size, input_length)
        predicted = w * dy.concatenate([h_t, c_t]) + b
        if dropout > 0.0:
            predicted = dy.dropout(predicted, dropout)
        cur_loss = dy.pickneglogsoftmax_batch(predicted, output_words[t])
        # The context must be carried to the next step inside the loop.
        c_t_minus_1 = c_t

        # Mask the loss in case mask == 0
        if 0 in masks[t]:
            mask = dy.inputVector(masks[t])
            mask = dy.reshape(mask, (1, ), batch_size)
            cur_loss = cur_loss * mask

        loss.append(cur_loss)

    # Get the average batch loss (previously missing: the function
    # returned None).
    loss = dy.esum(loss)
    return dy.sum_batches(loss) / batch_size
Пример #6
0
def build_batch_sentences_graph(isents, builders):
    """Build a fresh graph and return the masked bidirectional word-prediction loss.

    A forward LSTM reads each sentence left to right and a backward LSTM reads
    each reversed sentence. For position ``i``, the forward output
    ``f_outputs[i]`` and the backward output ``b_outputs[::-1][i + 2]`` are
    concatenated and used to predict the forward word id at ``i + 1``.

    Args:
        isents: list of token sequences; tokens are mapped via ``vw.w2i``.
            The padded length is taken from ``isents[0]``, so the batch is
            presumably sorted longest-first -- TODO confirm with the caller.
        builders: pair of RNN builders ``(forward, backward)``.

    Returns:
        A DyNet expression: per-step losses summed over steps and batch.
    """
    # NEXT STAGE - remove sentences that are not the same length from batch,
    # and remove all mask stuff
    renew_cg()
    f_init, b_init = [b.initial_state() for b in builders]
    W = parameter(W_sm)
    b = parameter(b_sm)

    batch_size = len(isents)
    max_len = len(isents[0])

    # Forward direction: i-th word of each sentence, padded with STOP.
    f_wids_list = []
    f_masks = []
    for i in range(max_len):
        f_wids_list.append([vw.w2i[sent[i]] if len(sent) > i else STOP
                            for sent in isents])
        f_masks.append([1 if len(sent) > i else 0 for sent in isents])

    f_outputs = []
    state = f_init
    for wids in f_wids_list:
        state = state.add_input(dy.lookup_batch(WORDS_LOOKUP, wids))
        f_outputs.append(state.output())

    # Backward direction: i-th word of each reversed sentence, padded with
    # START. Reverse each sentence once rather than once per position.
    rev_sents = [sent[::-1] for sent in isents]
    b_wids_list = []
    b_masks = []
    for i in range(max_len):
        b_wids_list.append([vw.w2i[rsent[i]] if len(rsent) > i else START
                            for rsent in rev_sents])
        b_masks.append([1 if len(rsent) > i else 0 for rsent in rev_sents])

    b_outputs = []
    state = b_init
    for wids in b_wids_list:
        state = state.add_input(dy.lookup_batch(WORDS_LOOKUP, wids))
        b_outputs.append(state.output())

    # Reverse once instead of on every step of the loss loop.
    b_outputs_rev = b_outputs[::-1]
    b_masks_rev = b_masks[::-1]

    # Compute loss via batch softmax
    errs = []
    for i in range(max_len - 3):
        y = concatenate([f_outputs[i], b_outputs_rev[i + 2]])
        r = b + (W * y)
        err = pickneglogsoftmax_batch(r, f_wids_list[i + 1])
        f_mask = f_masks[i]
        b_mask = b_masks_rev[i + 2]
        if 0 in f_mask or 0 in b_mask:
            # One 0/1 weight per sentence: keep the loss only where both
            # contexts are real tokens. (Previously `f_masks + b_masks`
            # concatenated the whole list-of-lists, which inputVector
            # cannot accept.)
            step_mask = [fm * bm for fm, bm in zip(f_mask, b_mask)]
            mask_expr = dy.inputVector(step_mask)
            mask_expr = dy.reshape(mask_expr, (1, ), batch_size)
            err = err * mask_expr
        errs.append(err)

    return sum_batches(esum(errs))
Пример #7
0
 def loss(self, observation, target_rep):
     """Return the DyNet squared distance between ``observation`` and the target.

     ``target_rep`` is a plain sequence of numbers; it is wrapped with
     ``dy.inputVector`` before the comparison.
     """
     target_expr = dy.inputVector(target_rep)
     return dy.squared_distance(observation, target_expr)
Пример #8
0
def build_batch_sentences_graph(isents, builders):
    """Build a fresh graph and return the masked bidirectional word-prediction loss.

    A forward LSTM reads each sentence left to right and a backward LSTM reads
    each reversed sentence. For position ``i``, the forward output
    ``f_outputs[i]`` and the backward output ``b_outputs[::-1][i + 2]`` are
    concatenated and used to predict the forward word id at ``i + 1``.

    Args:
        isents: list of token sequences; tokens are mapped via ``vw.w2i``.
            The padded length is taken from ``isents[0]``, so the batch is
            presumably sorted longest-first -- TODO confirm with the caller.
        builders: pair of RNN builders ``(forward, backward)``.

    Returns:
        A DyNet expression: per-step losses summed over steps and batch.
    """
    # NEXT STAGE - remove sentences that are not the same length from batch,
    # and remove all mask stuff
    renew_cg()
    f_init, b_init = [b.initial_state() for b in builders]
    W = parameter(W_sm)
    b = parameter(b_sm)

    batch_size = len(isents)
    max_len = len(isents[0])

    # Forward direction: i-th word of each sentence, padded with STOP.
    f_wids_list = []
    f_masks = []
    for i in range(max_len):
        f_wids_list.append([vw.w2i[sent[i]] if len(sent) > i else STOP
                            for sent in isents])
        f_masks.append([1 if len(sent) > i else 0 for sent in isents])

    f_outputs = []
    state = f_init
    for wids in f_wids_list:
        state = state.add_input(dy.lookup_batch(WORDS_LOOKUP, wids))
        f_outputs.append(state.output())

    # Backward direction: i-th word of each reversed sentence, padded with
    # START. Reverse each sentence once rather than once per position.
    rev_sents = [sent[::-1] for sent in isents]
    b_wids_list = []
    b_masks = []
    for i in range(max_len):
        b_wids_list.append([vw.w2i[rsent[i]] if len(rsent) > i else START
                            for rsent in rev_sents])
        b_masks.append([1 if len(rsent) > i else 0 for rsent in rev_sents])

    b_outputs = []
    state = b_init
    for wids in b_wids_list:
        state = state.add_input(dy.lookup_batch(WORDS_LOOKUP, wids))
        b_outputs.append(state.output())

    # Reverse once instead of on every step of the loss loop.
    b_outputs_rev = b_outputs[::-1]
    b_masks_rev = b_masks[::-1]

    # Compute loss via batch softmax
    errs = []
    for i in range(max_len - 3):
        y = concatenate([f_outputs[i], b_outputs_rev[i + 2]])
        r = b + (W * y)
        err = pickneglogsoftmax_batch(r, f_wids_list[i + 1])
        f_mask = f_masks[i]
        b_mask = b_masks_rev[i + 2]
        if 0 in f_mask or 0 in b_mask:
            # One 0/1 weight per sentence: keep the loss only where both
            # contexts are real tokens. (Previously `f_masks + b_masks`
            # concatenated the whole list-of-lists, which inputVector
            # cannot accept.)
            step_mask = [fm * bm for fm, bm in zip(f_mask, b_mask)]
            mask_expr = dy.inputVector(step_mask)
            mask_expr = dy.reshape(mask_expr, (1, ), batch_size)
            err = err * mask_expr
        errs.append(err)

    return sum_batches(esum(errs))
Пример #9
0
    def __step_batch(self, batch):
        """Build a fresh graph and return the masked attentional seq2seq loss.

        Args:
            batch: list of ``(source_tokens, target_tokens)`` pairs.

        Returns:
            ``(loss, num_words)``:
            - ``loss`` is the masked negative log-likelihood, summed over
              decoder steps and over the batch.
            - ``num_words`` is ``len(batch[0][0]) * len(batch)``.
              NOTE(review): that counts *source* tokens, assuming every
              source has the length of the first one. Confirm this is the
              intended normalizer.
        """
        dy.renew_cg()
        W_y = dy.parameter(self.W_y)
        b_y = dy.parameter(self.b_y)
        F = len(batch[0][0])
        num_words = F * len(batch)

        src_batch = [x[0] for x in batch]
        tgt_batch = [x[1] for x in batch]
        src_rev_batch = [list(reversed(x)) for x in src_batch]

        # Transpose sources to time-major order: src_cws[t] holds token t of
        # every source. zip() truncates to the shortest source, so sources
        # are presumably bucketed to equal length -- TODO confirm.
        src_cws = [list(column) for column in zip(*src_batch)]
        src_rev_cws = [list(column) for column in zip(*src_rev_batch)]

        # Bidirectional representations
        l2r_state = self.l2r_builder.initial_state()
        r2l_state = self.r2l_builder.initial_state()
        l2r_contexts = []
        r2l_contexts = []
        for cw_l2r_list, cw_r2l_list in zip(src_cws, src_rev_cws):
            l2r_ids = [self.src_token_to_id.get(cw, 0) for cw in cw_l2r_list]
            r2l_ids = [self.src_token_to_id.get(cw, 0) for cw in cw_r2l_list]
            l2r_state = l2r_state.add_input(
                dy.lookup_batch(self.src_lookup, l2r_ids))
            r2l_state = r2l_state.add_input(
                dy.lookup_batch(self.src_lookup, r2l_ids))
            l2r_contexts.append(l2r_state.output())
            r2l_contexts.append(r2l_state.output())

        # Put the right-to-left outputs back in source order.
        r2l_contexts.reverse()

        # Combine the left and right representations for every position.
        h_fs = [dy.concatenate([l2r_i, r2l_i])
                for l2r_i, r2l_i in zip(l2r_contexts, r2l_contexts)]
        h_fs_matrix = dy.concatenate_cols(h_fs)

        # Transpose targets to time-major order, padding with '</S>'.
        # masks[t][k] is 1 iff target k has a real token at position t.
        max_len = max(len(sentence) for sentence in tgt_batch)
        tgt_cws = [[sentence[j] if j < len(sentence) else '</S>'
                    for sentence in tgt_batch]
                   for j in range(max_len)]
        masks = [[1 if j < len(sentence) else 0 for sentence in tgt_batch]
                 for j in range(max_len)]

        c_t = dy.vecInput(self.hidden_size * 2)
        start = dy.concatenate(
            [dy.lookup(self.tgt_lookup, self.tgt_token_to_id['<S>']), c_t])
        dec_state = self.dec_builder.initial_state().add_input(start)

        losses = []
        # Each step reads token k (cws) and predicts token k+1 (nws). The
        # loss must therefore be masked by the TARGET's mask, masks[k + 1].
        # Masking with masks[k] wrongly counted one padded target per
        # shorter sentence.
        for cws, nws, mask in zip(tgt_cws, tgt_cws[1:], masks[1:]):
            h_e = dec_state.output()
            _, c_t = self.__attention_mlp(h_fs_matrix, h_e, F)
            # Get the embedding for the current target word
            embed_t = dy.lookup_batch(
                self.tgt_lookup,
                [self.tgt_token_to_id.get(cw, 0) for cw in cws])
            # Create input vector to the decoder
            dec_state = dec_state.add_input(dy.concatenate([embed_t, c_t]))
            y_star = dy.affine_transform([b_y, W_y, dec_state.output()])
            loss = dy.pickneglogsoftmax_batch(
                y_star, [self.tgt_token_to_id.get(nw, 0) for nw in nws])
            mask_exp = dy.reshape(dy.inputVector(mask), (1, ), len(mask))
            losses.append(loss * mask_exp)
        return dy.sum_batches(dy.esum(losses)), num_words