Example #1
 def transduce(self, embeds):
     # embeds is an (embed_dim, seq_len) matrix: turn it into a list of per-position column vectors
     expr_seq = []
     seq_len = embeds.dim()[0][1]
     for i in range(seq_len):
         # select column i ((embed_dim, 1)) and squeeze it to an (embed_dim,) vector
         expr_seq.append(dy.max_dim(dy.select_cols(embeds, [i]), 1))
     # transducing records the transducer's final states; the per-step encodings are not needed here
     encodings = self.seq_transducer.transduce(ExpressionSequence(expr_seq))
     return self.seq_transducer.get_final_states()[-1].main_expr()
Example #2
    def translate(self, x, beam_size=1):
        """Translate a source sentence
        
        Translate a single source sentence by decoding using beam search

        Arguments:
            x (list): Source sentence (list of indices)
        
        Keyword Arguments:
            beam_size (int): Size of the beam for beam search. A value of 1 means greedy decoding (default: 1)
        
        Returns:
            list: generated translation (list of indices)
        """
        dy.renew_cg()
        input_len = len(x)
        encodings = self.encode([x], test=True)
        # Decode
        # Add parameters to the graph
        Wp, bp = self.Wp_p.expr(), self.bp_p.expr()
        Wo, bo = self.Wo_p.expr(), self.bo_p.expr()
        D, b = dy.transpose(dy.parameter(self.MT_p)), self.b_p.expr()
        # Initialize decoder with last encoding
        last_enc = dy.select_cols(encodings, [encodings.dim()[0][-1] - 1])
        init_state = dy.affine_transform([bp, Wp, last_enc])
        ds = self.dec.initial_state([init_state, dy.zeroes((self.dh, ))])
        # Initialize context
        context = dy.zeroes((self.enc_dim, ))
        # Initialize beam
        beam = [(ds, context, [self.trg_sos], 0.0)]
        # Loop
        for i in range(int(min(self.max_len, input_len * 1.5))):
            new_beam = []
            for ds, pc, pw, logprob in beam:
                embs = dy.lookup(self.MT_p, pw[-1])
                # Run LSTM
                ds = ds.add_input(dy.concatenate([embs, pc]))
                h = ds.output()
                # Compute next context
                context, _ = self.attend(encodings, h)
                # Compute output with residual connections
                output = dy.affine_transform(
                    [bo, Wo, dy.concatenate([h, context, embs])])
                # Score
                s = dy.affine_transform([b, D, output])
                # Probabilities
                p = dy.softmax(s).npvalue().flatten()
                # Careful of float error
                p = p / p.sum()
                kbest = np.argsort(p)
                for nw in kbest[-beam_size:]:
                    new_beam.append(
                        (ds, context, pw + [nw], logprob + np.log(p[nw])))

            beam = sorted(new_beam, key=lambda x: x[-1])[-beam_size:]

            if beam[-1][2][-1] == self.trg_eos:
                break

        return beam[-1][2]
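For illustration, a minimal usage sketch, assuming a trained model instance (here called s2s) and a source sentence already mapped to vocabulary indices; both names are hypothetical and not part of the snippet above:

    src_ids = [12, 7, 431, 9]                      # source sentence as word indices (made up)
    hyp_ids = s2s.translate(src_ids, beam_size=5)  # beam search with width 5
    greedy_ids = s2s.translate(src_ids)            # beam_size=1 falls back to greedy decoding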
Example #3
    def decode_loss(self, src_encodings, masks, tgt_seqs, sents_len):
        """
        :param tgt_seqs: (tgt_heads, tgt_labels): list (length=batch_size) of (src_len)
        """

        tgt_heads, tgt_labels = tgt_seqs

        src_len = len(tgt_heads[0])
        batch_size = len(tgt_heads)
        np_tgt_heads = np.array(tgt_heads).flatten()  # (src_len * batch_size)
        np_tgt_labels = np.array(tgt_labels).flatten()
        np_masks = np.array(masks).transpose().flatten()
        masks_expr = dy.inputVector(np_masks)
        masks_expr = dy.reshape(masks_expr, (1, ),
                                batch_size=src_len * batch_size)

        s_arc, s_label = self.cal_scores(
            src_encodings
        )  # (src_len, src_len, bs), ([(src_len, src_len, bs)])
        s_arc = dy.select_cols(s_arc, range(1, src_len + 1))
        s_label = [
            dy.select_cols(label, range(1, src_len + 1)) for label in s_label
        ]
        s_arc_value = s_arc.npvalue()
        s_arc_choice = np.argmax(
            s_arc_value,
            axis=0).transpose().flatten()  # (src_len * batch_size)
        s_pick_labels = [
            dy.pick_batch(
                dy.reshape(score, (src_len + 1, ),
                           batch_size=src_len * batch_size), s_arc_choice)
            for score in s_label
        ]
        s_argmax_labels = dy.concatenate(s_pick_labels,
                                         d=0)  # n_labels, src_len * batch_size

        reshape_s_arc = dy.reshape(s_arc, (src_len + 1, ),
                                   batch_size=src_len * batch_size)
        arc_loss = dy.pickneglogsoftmax_batch(reshape_s_arc, np_tgt_heads)
        arc_loss = arc_loss * masks_expr
        label_loss = dy.pickneglogsoftmax_batch(s_argmax_labels, np_tgt_labels)
        label_loss = label_loss * masks_expr
        loss = dy.sum_batches(arc_loss + label_loss)
        return loss
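The dy.reshape calls above fold the per-token axis into the batch axis, so a single batched softmax scores every (token, sentence) pair at once; the gold arrays and masks are flattened so their order lines up with the reshaped batch. A standalone sketch of that reshape-into-batch trick, with made-up sizes (L, B, scores, and gold are all hypothetical):

    import numpy as np
    import dynet as dy

    dy.renew_cg()
    L, B = 4, 2                                                           # src_len, batch_size (made up)
    scores = dy.inputTensor(np.random.randn(L + 1, L, B), batched=True)   # dim ((L+1, L), B)
    gold = np.random.randint(0, L + 1, size=L * B).tolist()               # one gold head per (token, sentence)
    flat = dy.reshape(scores, (L + 1,), batch_size=L * B)                 # fold the token axis into the batch
    loss = dy.sum_batches(dy.pickneglogsoftmax_batch(flat, gold))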
Example #4
def split(x, dim=1):
    head_shape, batch_size = x.dim()
    res = []
    if dim == 0:
        for i in range(head_shape[0]):
            res.append(dy.select_rows(x, [i]))
    elif dim == 1:
        for i in range(head_shape[1]):
            res.append(dy.select_cols(x, [i]))
    return res
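A minimal usage sketch of split (the 2x3 input matrix is made up): dim=1 yields one column expression per column, dim=0 one row expression per row.

    import numpy as np
    import dynet as dy

    dy.renew_cg()
    m = dy.inputTensor(np.arange(6.0).reshape(2, 3))  # dim ((2, 3), 1)
    cols = split(m, dim=1)                            # 3 expressions of dim ((2, 1), 1)
    rows = split(m, dim=0)                            # 2 expressions of dim ((1, 3), 1)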
Example #5
 def decode_loss(self, encodings, trg, test=False):
     """Compute the negative conditional log likelihood of the target sentence given the encoding of the source sentence
     
     Arguments:
         encodings (dynet.Expression): Source sentence encodings obtained with self.encode
         trg (list): List of target sentences
     
     Keyword Arguments:
         test (bool): Switch used for things like dropout where the behaviour is different at test time (default: False)
     
     Returns:
         dynet.Expression: Expression of the loss averaged on the minibatch
     """
     y, masksy = self.prepare_batch(trg, self.trg_eos)
     slen, bsize = y.shape
     # Add parameters to the graph
     Wp, bp = self.Wp_p.expr(), self.bp_p.expr()
     Wo, bo = self.Wo_p.expr(), self.bo_p.expr()
     D, b = dy.transpose(dy.parameter(self.MT_p)), self.b_p.expr()
     # Initialize decoder with last encoding
     last_enc = dy.select_cols(encodings, [encodings.dim()[0][-1] - 1])
     init_state = dy.affine_transform([bp, Wp, last_enc])
     ds = self.dec.initial_state(
         [init_state, dy.zeroes((self.dh, ), batch_size=bsize)])
     # Initialize context
     context = dy.zeroes((self.enc_dim, ), batch_size=bsize)
     # Start decoding
     errs = []
     for cw, nw, mask in zip(y, y[1:], masksy[1:]):
         embs = dy.lookup_batch(self.MT_p, cw)
         # Run LSTM
         ds = ds.add_input(dy.concatenate([embs, context]))
         h = ds.output()
         # Compute next context
         context, _ = self.attend(encodings, h)
         # Compute output with residual connections
         output = dy.affine_transform(
             [bo, Wo, dy.concatenate([h, context, embs])])
         if not test:
             output = dy.dropout(output, self.dr)
         # Score
         s = dy.affine_transform([b, D, output])
         masksy_e = dy.inputTensor(mask, batched=True)
         # Loss
         err = dy.cmult(dy.pickneglogsoftmax_batch(s, nw), masksy_e)
         errs.append(err)
     # Add all losses together
     err = dy.sum_batches(dy.esum(errs)) / float(bsize)
     return err
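For illustration, a rough sketch of a training step built around this loss, assuming a model instance s2s exposing the encode method from Example #2, batched index lists src_batch / trg_batch, and a DyNet trainer; all of these names are assumptions, not part of the snippet:

    dy.renew_cg()
    encodings = s2s.encode(src_batch, test=False)   # src_batch: list of source index lists
    loss = s2s.decode_loss(encodings, trg_batch)    # trg_batch: list of target index lists
    loss_value = loss.value()                       # run the forward pass
    loss.backward()
    trainer.update()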
Example #6
    def _get_scores(self, sentence, is_train, elmo_embeddings):
        lstm_outputs = self._featurize_sentence(
            sentence, is_train=is_train, elmo_embeddings=elmo_embeddings)

        other_encodings = []
        single_word_encodings = []
        temporary_span_to_index = {}
        for left in range(len(sentence)):
            for right in range(left + 1, len(sentence) + 1):
                encoding = self._get_span_encoding(left, right, lstm_outputs)
                span = (left, right)
                if right - left == 1:
                    temporary_span_to_index[span] = len(single_word_encodings)
                    single_word_encodings.append(encoding)
                else:
                    temporary_span_to_index[span] = len(other_encodings)
                    other_encodings.append(encoding)

        encodings = single_word_encodings + other_encodings
        span_to_index = {}
        for span, index in temporary_span_to_index.items():
            if span[1] - span[0] == 1:
                new_index = index
            else:
                new_index = index + len(single_word_encodings)
            span_to_index[span] = new_index
        span_encodings = dy.rectify(
            dy.reshape(self.f_encoding(dy.concatenate_to_batch(encodings)),
                       (self.hidden_dim, len(encodings))))
        label_scores = self.f_label(span_encodings)
        label_scores_reshaped = dy.reshape(
            label_scores, (self.label_vocab.size, len(encodings)))
        label_log_probabilities = dy.log_softmax(label_scores_reshaped)
        single_word_span_encodings = dy.select_cols(
            span_encodings, list(range(len(single_word_encodings))))
        tag_scores = self.f_tag(single_word_span_encodings)
        tag_scores_reshaped = dy.reshape(
            tag_scores, (self.tag_vocab.size, len(single_word_encodings)))
        tag_log_probabilities = dy.log_softmax(tag_scores_reshaped)
        return label_log_probabilities, tag_log_probabilities, span_to_index
Example #7
    def predict(self, text):
        dy.renew_cg()

        output_mgc = []
        last_mgc = self.start_lookup[0]

        x = self._make_input(text)
        x_speaker = self._get_speaker_embedding()

        x_fw = self.encoder_fw.initial_state().transduce(x)
        x_bw = self.encoder_bw.initial_state().transduce(reversed(x))

        encoder = [
            dy.concatenate([fw, bw, x_speaker])
            for fw, bw in zip(x_fw, reversed(x_bw))
        ]
        encoder = dy.concatenate(encoder, 1)
        hidden_encoder = self.att_w1 * encoder

        decoder = self.decoder.initial_state().add_input(
            self.decoder_start_lookup[0])

        last_att_pos = 0
        warm_up = 5
        finish = 5

        for k in range(5 * len(text)):
            attention_weights = (
                self.att_v *
                dy.tanh(hidden_encoder + self.att_w2 * decoder.output()))[0]
            current_pos = np.argmax(attention_weights.npvalue())

            if current_pos < last_att_pos:
                current_pos = last_att_pos

            if current_pos > last_att_pos:
                current_pos = last_att_pos + 1

            last_att_pos = current_pos
            att = dy.select_cols(encoder, [current_pos])

            if warm_up > 0:
                last_att_pos = 0
                warm_up -= 1

            if last_att_pos >= len(text):
                if finish > 0:
                    finish -= 1
                else:
                    break

            mgc_proj = dy.tanh(self.last_mgc_proj_w * last_mgc +
                               self.last_mgc_proj_b)
            decoder = decoder.add_input(dy.concatenate([mgc_proj, att]))
            hidden = dy.tanh(self.hid_w * decoder.output() + self.hid_b)

            highway_hidden = self.highway_w * att
            output_mgc.append(
                dy.logistic(highway_hidden + self.proj_w_1 * hidden +
                            self.proj_b_1))
            output_mgc.append(
                dy.logistic(highway_hidden + self.proj_w_2 * hidden +
                            self.proj_b_2))
            output_mgc.append(
                dy.logistic(highway_hidden + self.proj_w_3 * hidden +
                            self.proj_b_3))

            last_mgc = output_mgc[-1]
            '''
                The stop layer seems to finish in about 40% of cases, on average 3 steps earlier.
                However, it introduces another matmul, and further testing showed that decoding is faster without it.

                The testing was done at the sentence level; the results may be better with word-level parallelization,
                so I'm not removing it yet.
            '''
            # output_stop = dy.tanh(self.stop_w * decoder.output() + self.stop_b)
            # if output_stop.value() < -0.5:
            #     break

        return output_mgc
Example #8
def split_cols(matrix):
    (rows, cols), bt = matrix.dim()  # dim() returns ((rows, cols), batch_size)
    assert bt == 1
    return [dy.reshape(dy.select_cols(matrix, [i]), (rows,), batch_size=bt) for i in range(cols)]
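Compared with split in Example #4, each selected column here is reshaped to a plain vector, dropping the singleton column dimension. A tiny sketch (the 3x3 identity input is made up):

    import numpy as np
    import dynet as dy

    dy.renew_cg()
    m = dy.inputTensor(np.eye(3))   # dim ((3, 3), 1)
    cols = split_cols(m)            # 3 vector expressions of dim ((3,), 1)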
Example #9
    def _span_parser_predict_pos(self, sentence, is_train, elmo_embeddings, gold=None):
        if gold is not None:
            assert isinstance(gold, ParseNode)

        lstm_outputs = self._featurize_sentence(sentence, is_train=is_train,
                                                elmo_embeddings=elmo_embeddings)

        other_encodings = []
        single_word_encodings = []
        temporary_span_to_index = {}
        for left in range(len(sentence)):
            for right in range(left + 1, len(sentence) + 1):
                encoding = self._get_span_encoding(left, right, lstm_outputs)
                span = (left, right)
                if right - left == 1:
                    temporary_span_to_index[span] = len(single_word_encodings)
                    single_word_encodings.append(encoding)
                else:
                    temporary_span_to_index[span] = len(other_encodings)
                    other_encodings.append(encoding)

        encodings = single_word_encodings + other_encodings
        span_to_index = {}
        for span, index in temporary_span_to_index.items():
            if span[1] - span[0] == 1:
                new_index = index
            else:
                new_index = index + len(single_word_encodings)
            span_to_index[span] = new_index
        span_encodings = dy.rectify(dy.reshape(self.f_encoding(dy.concatenate_to_batch(encodings)),
                                               (self.hidden_dim, len(encodings))))
        label_scores = self.f_label(span_encodings)
        label_scores_reshaped = dy.reshape(label_scores, (self.label_vocab.size, len(encodings)))
        label_log_probabilities = dy.log_softmax(label_scores_reshaped)
        single_word_span_encodings = dy.select_cols(span_encodings,
                                                    list(range(len(single_word_encodings))))
        tag_scores = self.f_tag(single_word_span_encodings)
        tag_scores_reshaped = dy.reshape(tag_scores,
                                         (self.tag_vocab.size, len(single_word_encodings)))
        tag_log_probabilities = dy.log_softmax(tag_scores_reshaped)

        if is_train:
            total_loss = dy.zeros(1)
            span_to_gold_label = get_all_spans(gold)
            for span, oracle_label in span_to_gold_label.items():
                oracle_label_index = self.label_vocab.index(oracle_label)
                index = span_to_index[span]
                if span[1] - span[0] == 1:
                    oracle_tag = sentence[span[0]][0]
                    total_loss -= tag_log_probabilities[self.tag_vocab.index(oracle_tag)][index]
                total_loss -= label_log_probabilities[oracle_label_index][index]
            return total_loss
        else:
            label_log_probabilities_np = label_log_probabilities.npvalue()
            tag_log_probabilities_np = tag_log_probabilities.npvalue()
            sentence_with_tags = []
            num_correct = 0
            total = 0
            # print('output has gold pos tags')
            for word_index, (oracle_tag, word) in enumerate(sentence):
                # tag_index = np.argmax(tag_log_probabilities_np[:, word_index])
                # tag = self.tag_vocab.value(tag_index)
                # oracle_tag_is_deletable = oracle_tag in deletable_tags
                # predicted_tag_is_deletable = tag in deletable_tags
                # if oracle_tag is not None and oracle_tag not in self.tag_vocab.indices:
                #     print(oracle_tag, 'not in tag vocab')
                #     oracle_tag = None
                # if oracle_tag is not None:
                #     oracle_tag_index = self.tag_vocab.index(oracle_tag)
                #     if oracle_tag_index == tag_index and tag != oracle_tag:
                #         if oracle_tag[0] != '-':
                #             print(tag, oracle_tag)
                #         tag = oracle_tag
                #     num_correct += tag_index == oracle_tag_index

                # if oracle_tag is not None and oracle_tag_is_deletable != predicted_tag_is_deletable:
                #     # print('falling back on gold tag', oracle_tag, tag)
                #     sentence_with_tags.append((oracle_tag, word))
                # else:
                #     sentence_with_tags.append((tag, word))
                assert oracle_tag is not None
                sentence_with_tags.append((oracle_tag, word))

                total += 1
            tree, additional_info = optimal_parser(label_log_probabilities_np,
                                                   span_to_index,
                                                   sentence_with_tags,
                                                   self.empty_label_index,
                                                   self.label_vocab,
                                                   gold)
            return tree, (additional_info, num_correct, total)