class ExtractiveQAModel(Initializable):
    """The dictionary-equipped extractive QA model.

    Parameters
    ----------
    dim : int
        The default dimensionality for the components.
    emd_dim : int
        The dimensionality for the embeddings. If 0, `dim` is used.
    coattention : bool
        Use the coattention mechanism.
    num_input_words : int
        The number of input words. If 0, `vocab.size()` is used.
        The vocabulary object.
    use_definitions : bool
        Triggers the use of definitions.
    reuse_word_embeddings : bool
    compose_type : str

    """
    def __init__(self, dim, emb_dim, readout_dims, num_input_words,
                 def_num_input_words, vocab, use_definitions, def_word_gating,
                 compose_type, coattention, def_reader, reuse_word_embeddings,
                 random_unk, **kwargs):
        self._vocab = vocab
        if emb_dim == 0:
            emb_dim = dim
        if num_input_words == 0:
            num_input_words = vocab.size()
        if def_num_input_words == 0:
            def_num_input_words = num_input_words

        self._coattention = coattention
        self._num_input_words = num_input_words
        self._use_definitions = use_definitions
        self._random_unk = random_unk
        self._reuse_word_embeddings = reuse_word_embeddings

        lookup_num_words = num_input_words
        if reuse_word_embeddings:
            lookup_num_words = max(num_input_words, def_num_input_words)
        if random_unk:
            lookup_num_words = vocab.size()

        # Dima: we can have slightly less copy-paste here if we
        # copy the RecurrentFromFork class from my other projects.
        children = []
        self._lookup = LookupTable(lookup_num_words, emb_dim)
        self._encoder_fork = Linear(emb_dim, 4 * dim, name='encoder_fork')
        self._encoder_rnn = LSTM(dim, name='encoder_rnn')
        self._question_transform = Linear(dim, dim, name='question_transform')
        self._bidir_fork = Linear(3 * dim if coattention else 2 * dim,
                                  4 * dim,
                                  name='bidir_fork')
        self._bidir = Bidirectional(LSTM(dim), name='bidir')
        children.extend([
            self._lookup, self._encoder_fork, self._encoder_rnn,
            self._question_transform, self._bidir, self._bidir_fork
        ])

        activations = [Rectifier()] * len(readout_dims) + [None]
        readout_dims = [2 * dim] + readout_dims + [1]
        self._begin_readout = MLP(activations,
                                  readout_dims,
                                  name='begin_readout')
        self._end_readout = MLP(activations, readout_dims, name='end_readout')
        self._softmax = NDimensionalSoftmax()
        children.extend(
            [self._begin_readout, self._end_readout, self._softmax])

        if self._use_definitions:
            # A potential bug here: we pass the same vocab to the def reader.
            # If a different token is reserved for UNK in text and in the definitions,
            # we can be screwed.
            def_reader_class = eval(def_reader)
            def_reader_kwargs = dict(
                num_input_words=def_num_input_words,
                dim=dim,
                emb_dim=emb_dim,
                vocab=vocab,
                lookup=self._lookup if reuse_word_embeddings else None)
            if def_reader_class == MeanPoolReadDefinitions:
                def_reader_kwargs.update(dict(normalize=True, translate=False))
            self._def_reader = def_reader_class(**def_reader_kwargs)
            self._combiner = MeanPoolCombiner(dim=dim,
                                              emb_dim=emb_dim,
                                              def_word_gating=def_word_gating,
                                              compose_type=compose_type)
            children.extend([self._def_reader, self._combiner])

        super(ExtractiveQAModel, self).__init__(children=children, **kwargs)

        # create default input variables
        self.contexts = tensor.lmatrix('contexts')
        self.context_mask = tensor.matrix('contexts_mask')
        self.questions = tensor.lmatrix('questions')
        self.question_mask = tensor.matrix('questions_mask')
        self.answer_begins = tensor.lvector('answer_begins')
        self.answer_ends = tensor.lvector('answer_ends')
        input_vars = [
            self.contexts, self.context_mask, self.questions,
            self.question_mask, self.answer_begins, self.answer_ends
        ]
        if self._use_definitions:
            self.defs = tensor.lmatrix('defs')
            self.def_mask = tensor.matrix('def_mask')
            self.contexts_def_map = tensor.lmatrix('contexts_def_map')
            self.questions_def_map = tensor.lmatrix('questions_def_map')
            input_vars.extend([
                self.defs, self.def_mask, self.contexts_def_map,
                self.questions_def_map
            ])
        self.input_vars = OrderedDict([(var.name, var) for var in input_vars])

    def set_embeddings(self, embeddings):
        self._lookup.parameters[0].set_value(
            embeddings.astype(theano.config.floatX))

    def embeddings_var(self):
        return self._lookup.parameters[0]

    def def_reading_parameters(self):
        # list() so that this also works under Python 3, where values()
        # returns a view without extend().
        parameters = list(
            Selector(self._def_reader).get_parameters().values())
        parameters.extend(Selector(self._combiner).get_parameters().values())
        if self._reuse_word_embeddings:
            lookup_parameters = list(
                Selector(self._lookup).get_parameters().values())
            parameters = [p for p in parameters if p not in lookup_parameters]
        return parameters

    @application
    def _encode(self,
                application_call,
                text,
                mask,
                def_embs=None,
                def_map=None,
                text_name=None):
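        # Map out-of-vocabulary word ids to UNK; the lt/ge arithmetic below is
        # an elementwise equivalent of tensor.switch.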
        if not self._random_unk:
            text = (tensor.lt(text, self._num_input_words) * text +
                    tensor.ge(text, self._num_input_words) * self._vocab.unk)
        if text_name:
            application_call.add_auxiliary_variable(
                unk_ratio(text, mask, self._vocab.unk),
                name='{}_unk_ratio'.format(text_name))
        embs = self._lookup.apply(text)
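        # With random_unk, out-of-vocabulary words keep their own embeddings,
        # but disconnected_grad blocks gradient flow through them.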
        if self._random_unk:
            embs = (tensor.lt(text, self._num_input_words)[:, :, None] * embs +
                    tensor.ge(text, self._num_input_words)[:, :, None] *
                    disconnected_grad(embs))
        if def_embs is not None:
            embs = self._combiner.apply(embs, mask, def_embs, def_map)
        add_role(embs, EMBEDDINGS)
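        # flip01 (a helper defined elsewhere in this package) is assumed to
        # swap the batch and time axes: Blocks recurrent bricks expect
        # time-major (time, batch, features) input, hence also mask=mask.T.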
        encoded = flip01(
            self._encoder_rnn.apply(self._encoder_fork.apply(flip01(embs)),
                                    mask=mask.T)[0])
        return encoded

    @application
    def apply(self,
              application_call,
              contexts,
              contexts_mask,
              questions,
              questions_mask,
              answer_begins,
              answer_ends,
              defs=None,
              def_mask=None,
              contexts_def_map=None,
              questions_def_map=None):
        def_embs = None
        if self._use_definitions:
            def_embs = self._def_reader.apply(defs, def_mask)

        context_enc = self._encode(contexts, contexts_mask, def_embs,
                                   contexts_def_map, 'context')
        question_enc_pre = self._encode(questions, questions_mask, def_embs,
                                        questions_def_map, 'question')
        question_enc = tensor.tanh(
            self._question_transform.apply(question_enc_pre))

        # should be (batch size, context length, question_length)
        affinity = tensor.batched_dot(context_enc, flip12(question_enc))
        affinity_mask = contexts_mask[:, :, None] * questions_mask[:, None, :]
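        # Push masked positions to a large negative value so that the
        # softmaxes below assign them (near) zero attention weight.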
        affinity = affinity * affinity_mask - 1000.0 * (1 - affinity_mask)
        # soft-aligns every position in the context to positions in the question
        d2q_att_weights = self._softmax.apply(affinity, extra_ndim=1)
        application_call.add_auxiliary_variable(d2q_att_weights.copy(),
                                                name='d2q_att_weights')
        # soft-aligns every position in the question to positions in the document
        q2d_att_weights = self._softmax.apply(flip12(affinity), extra_ndim=1)
        application_call.add_auxiliary_variable(q2d_att_weights.copy(),
                                                name='q2d_att_weights')

        # question encoding "in the view of the document"
        question_enc_informed = tensor.batched_dot(q2d_att_weights,
                                                   context_enc)
        question_enc_concatenated = tensor.concatenate(
            [question_enc, question_enc_informed], 2)
        # document encoding "in the view of the question"
        context_enc_informed = tensor.batched_dot(d2q_att_weights,
                                                  question_enc_concatenated)

        if self._coattention:
            context_enc_concatenated = tensor.concatenate(
                [context_enc, context_enc_informed], 2)
        else:
            question_repr_repeated = tensor.repeat(question_enc[:, [-1], :],
                                                   context_enc.shape[1],
                                                   axis=1)
            context_enc_concatenated = tensor.concatenate(
                [context_enc, question_repr_repeated], 2)

        # note: forward and backward LSTMs share the
        # input weights in the current impl
        bidir_states = flip01(
            self._bidir.apply(self._bidir_fork.apply(
                flip01(context_enc_concatenated)),
                              mask=contexts_mask.T)[0])

        begin_readouts = self._begin_readout.apply(bidir_states)[:, :, 0]
        begin_readouts = begin_readouts * contexts_mask - 1000.0 * (
            1 - contexts_mask)
        begin_costs = self._softmax.categorical_cross_entropy(
            answer_begins, begin_readouts)

        end_readouts = self._end_readout.apply(bidir_states)[:, :, 0]
        end_readouts = end_readouts * contexts_mask - 1000.0 * (1 -
                                                                contexts_mask)
        end_costs = self._softmax.categorical_cross_entropy(
            answer_ends, end_readouts)

        predicted_begins = begin_readouts.argmax(axis=-1)
        predicted_ends = end_readouts.argmax(axis=-1)
        exact_match = (tensor.eq(predicted_begins, answer_begins) *
                       tensor.eq(predicted_ends, answer_ends))
        application_call.add_auxiliary_variable(predicted_begins,
                                                name='predicted_begins')
        application_call.add_auxiliary_variable(predicted_ends,
                                                name='predicted_ends')
        application_call.add_auxiliary_variable(exact_match,
                                                name='exact_match')

        return begin_costs + end_costs

    def apply_with_default_vars(self):
        return self.apply(*self.input_vars.values())
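

# A rough usage sketch: it assumes a `vocab` object exposing `size()` and
# `unk`, initialization schemes such as GlorotUniform / Constant (e.g. from
# blocks.initialization), and placeholder values for the definition-related
# arguments, since `use_definitions` is off here:
#
#     qa = ExtractiveQAModel(
#         dim=128, emb_dim=0, readout_dims=[128], num_input_words=10000,
#         def_num_input_words=0, vocab=vocab, use_definitions=False,
#         def_word_gating='none', compose_type='sum', coattention=True,
#         def_reader='MeanPoolReadDefinitions', reuse_word_embeddings=False,
#         random_unk=False,
#         weights_init=GlorotUniform(), biases_init=Constant(0.0))
#     qa.initialize()
#     cost = qa.apply_with_default_vars().mean()
#     train_fn = theano.function(list(qa.input_vars.values()), cost)

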
class ESIM(Initializable):
    """
    ESIM model based on https://github.com/NYU-MLL/multiNLI/blob/master/python/models/esim.py
    """

    # (cf. seq_length, emb_dim and hidden_dim in the reference implementation)
    def __init__(
            self,
            dim,
            emb_dim,
            vocab,
            def_emb_translate_dim=-1,
            def_dim=-1,
            encoder='bilstm',
            bn=True,
            def_reader=None,
            def_combiner=None,
            dropout=0.5,
            num_input_words=-1,
            # Others
            **kwargs):

        self._dropout = dropout
        self._vocab = vocab
        self._emb_dim = emb_dim
        self._def_reader = def_reader
        self._def_combiner = def_combiner

        if encoder != 'bilstm':
            raise NotImplementedError()

        if def_emb_translate_dim < 0:
            self.def_emb_translate_dim = emb_dim
        else:
            self.def_emb_translate_dim = def_emb_translate_dim

        if def_dim < 0:
            self._def_dim = emb_dim
        else:
            self._def_dim = def_dim

        if num_input_words > 0:
            logger.info("Restricting vocab to " + str(num_input_words))
            self._num_input_words = num_input_words
        else:
            self._num_input_words = vocab.size()

        children = []

        if self.def_emb_translate_dim != self._emb_dim:
            self._translate_pre_def = Linear(input_dim=emb_dim,
                                             output_dim=def_emb_translate_dim)
            children.append(self._translate_pre_def)
        else:
            self._translate_pre_def = None

        ## Embedding
        self._lookup = LookupTable(self._num_input_words,
                                   emb_dim,
                                   weights_init=GlorotUniform())
        children.append(self._lookup)

        if def_reader:
            self._final_emb_dim = self._def_dim
            self._def_reader = def_reader
            self._def_combiner = def_combiner
            children.extend([self._def_reader, self._def_combiner])
        else:
            self._final_emb_dim = self._emb_dim

        ## BiLSTM
        self._hyp_bidir_fork = Linear(
            self._def_dim if def_reader else self._emb_dim,
            4 * dim,
            name='hyp_bidir_fork')
        self._hyp_bidir = Bidirectional(LSTM(dim), name='hyp_bidir')
        self._prem_bidir_fork = Linear(
            self._def_dim if def_reader else self._emb_dim,
            4 * dim,
            name='prem_bidir_fork')
        self._prem_bidir = Bidirectional(LSTM(dim), name='prem_bidir')
        children.extend([self._hyp_bidir_fork, self._hyp_bidir])
        children.extend([self._prem_bidir, self._prem_bidir_fork])

        ## BiLSTM no. 2 (encoded attentioned embeddings)
        self._hyp_bidir_fork2 = Linear(8 * dim,
                                       4 * dim,
                                       name='hyp_bidir_fork2')
        self._hyp_bidir2 = Bidirectional(LSTM(dim), name='hyp_bidir2')
        self._prem_bidir_fork2 = Linear(8 * dim,
                                        4 * dim,
                                        name='prem_bidir_fork2')
        self._prem_bidir2 = Bidirectional(LSTM(dim), name='prem_bidir2')
        children.extend([self._hyp_bidir_fork2, self._hyp_bidir2])
        children.extend([self._prem_bidir2, self._prem_bidir_fork2])

        self._rnns = [
            self._prem_bidir2, self._hyp_bidir2, self._prem_bidir,
            self._hyp_bidir
        ]

        ## MLP
        if bn:
            self._mlp = BatchNormalizedMLP([Tanh()], [8 * dim, dim],
                                           conserve_memory=False,
                                           name="mlp")
            self._pred = BatchNormalizedMLP([Softmax()], [dim, 3],
                                            conserve_memory=False,
                                            name="pred_mlp")
        else:
            self._mlp = MLP([Tanh()], [8 * dim, dim], name="mlp")
            self._pred = MLP([Softmax()], [dim, 3], name="pred_mlp")

        children.append(self._mlp)
        children.append(self._pred)

        ## Softmax
        self._ndim_softmax = NDimensionalSoftmax()
        children.append(self._ndim_softmax)

        super(ESIM, self).__init__(children=children, **kwargs)

    def get_embeddings_lookups(self):
        return [self._lookup]

    def set_embeddings(self, embeddings):
        self._lookup.parameters[0].set_value(
            embeddings.astype(theano.config.floatX))

    def get_def_embeddings_lookups(self):
        return [self._def_reader._def_lookup]

    def set_def_embeddings(self, embeddings):
        self._def_reader._def_lookup.parameters[0].set_value(
            embeddings.astype(theano.config.floatX))

    @application
    def apply(self,
              application_call,
              s1_preunk,
              s1_mask,
              s2_preunk,
              s2_mask,
              def_mask=None,
              defs=None,
              s1_def_map=None,
              s2_def_map=None,
              train_phase=True):
        # Shortlist words (sometimes we want smaller vocab, especially when dict is small)
        s1 = (tensor.lt(s1_preunk, self._num_input_words) * s1_preunk +
              tensor.ge(s1_preunk, self._num_input_words) * self._vocab.unk)
        s2 = (tensor.lt(s2_preunk, self._num_input_words) * s2_preunk +
              tensor.ge(s2_preunk, self._num_input_words) * self._vocab.unk)

        ### Embed ###

        s1_emb = self._lookup.apply(s1)
        s2_emb = self._lookup.apply(s2)

        # Multiplying by 1 creates a copy, so naming the auxiliary variable
        # does not rename s1_emb itself.
        application_call.add_auxiliary_variable(1 * s1_emb,
                                                name='s1_word_embeddings')

        if self._def_reader:
            assert defs is not None

            def_embs = self._def_reader.apply(defs, def_mask)

            if self._translate_pre_def:
                logger.info("Translate pre def")
                s1_emb = s1_emb.reshape(
                    (s1_emb.shape[0] * s1_emb.shape[1], s1_emb.shape[2]))
                s2_emb = s2_emb.reshape(
                    (s2_emb.shape[0] * s2_emb.shape[1], s2_emb.shape[2]))
                s1_emb = self._translate_pre_def.apply(s1_emb)
                s2_emb = self._translate_pre_def.apply(s2_emb)
                s1_emb = s1_emb.reshape(
                    (s1_preunk.shape[0], s1_preunk.shape[1], -1))
                s2_emb = s2_emb.reshape(
                    (s2_preunk.shape[0], s2_preunk.shape[1], -1))

            s1_emb = self._def_combiner.apply(s1_emb,
                                              s1_mask,
                                              def_embs,
                                              s1_def_map,
                                              word_ids=s1,
                                              train_phase=train_phase,
                                              call_name="s1")

            s2_emb = self._def_combiner.apply(s2_emb,
                                              s2_mask,
                                              def_embs,
                                              s2_def_map,
                                              word_ids=s2,
                                              train_phase=train_phase,
                                              call_name="s2")
        else:
            if train_phase and self._dropout > 0:
                s1_emb = apply_dropout(s1_emb, drop_prob=self._dropout)
                s2_emb = apply_dropout(s2_emb, drop_prob=self._dropout)

        ### Encode ###

        # TODO: Share this bilstm?
        s1_bilstm, _ = self._prem_bidir.apply(
            flip01(self._prem_bidir_fork.apply(s1_emb)),
            mask=s1_mask.T)  # (n_seq, batch_size, 2 * dim)
        s2_bilstm, _ = self._hyp_bidir.apply(
            flip01(self._hyp_bidir_fork.apply(s2_emb)),
            mask=s2_mask.T)  # (n_seq, batch_size, 2 * dim)
        s1_bilstm = flip01(s1_bilstm)  # (batch_size, n_seq, 2 * dim)
        s2_bilstm = flip01(s2_bilstm)

        ### Attention ###

        # Compute E matrix (eq. 11)
        # E_ij = <s1[i], s2[j]>
        # each call computes E[i, :]
        def compute_e_row(s1_i, s2_bilstm, s2_mask):
            # s1_i is (batch_size, 2 * dim): the i-th position of s1
            # s2_bilstm is (batch_size, seq_len, 2 * dim)
            # s2_mask is (batch_size, seq_len); unused, see NOTE below
            b_size = s2_bilstm.shape[0]
            # Repeat to width 2 so that batched_dot gets two 3-D operands;
            # only the first column of the result is kept.
            s1_i = s1_i.reshape((b_size, s1_i.shape[1], 1))
            s1_i = T.repeat(s1_i, 2, axis=2)  # (batch_size, 2 * dim, 2)
            assert s2_bilstm.ndim == 3
            assert s1_i.ndim == 3
            score = T.batched_dot(s2_bilstm, s1_i)  # (batch_size, seq_len, 2)
            score = score[:, :, 0].reshape(
                (b_size, -1))  # (batch_size, seq_len)
            return score  # E[i, :]

        # NOTE: No point in masking here
        E, _ = theano.scan(compute_e_row,
                           sequences=[s1_bilstm.transpose(1, 0, 2)],
                           non_sequences=[s2_bilstm, s2_mask])
        # (seq_len, batch_size, seq_len)
        E = E.dimshuffle(1, 0, 2)
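        # after the dimshuffle: (batch_size, s1 length, s2 length)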
        assert E.ndim == 3

        s2s_att_weights = self._ndim_softmax.apply(E, extra_ndim=1)
        application_call.add_auxiliary_variable(s2s_att_weights.copy(),
                                                name='s2s_att_weights')

        # (batch_size, seq_len, seq_len)

        ### Compute tilde vectors (eq. 12 and 13) ###

        def compute_tilde_vector(e_i, s, s_mask):
            # e_i is (batch_size, seq_len)
            # s is (batch_size, seq_len, 2 * dim), s_mask is (batch_size, seq_len)
            # masked_softmax (a helper defined elsewhere in this package) is
            # assumed to renormalize the softmax over unmasked positions only.
            # s_tilde_i = sum_j softmax(e_i)_j * s_j, shape (batch_size, 2 * dim)
            score = masked_softmax(e_i, s_mask, axis=1)
            score = score.dimshuffle(0, 1, "x")

            s_tilde_i = (score *
                         (s * s_mask.dimshuffle(0, 1, "x"))).sum(axis=1)
            return s_tilde_i

        # s1_tilde, s2_tilde (after the dimshuffles) are
        # (batch_size, seq_len, 2 * dim)
        s1_tilde, _ = theano.scan(compute_tilde_vector,
                                  sequences=[E.dimshuffle(1, 0, 2)],
                                  non_sequences=[s2_bilstm, s2_mask])
        s1_tilde = s1_tilde.dimshuffle(1, 0, 2)
        s2_tilde, _ = theano.scan(compute_tilde_vector,
                                  sequences=[E.dimshuffle(2, 0, 1)],
                                  non_sequences=[s1_bilstm, s1_mask])
        s2_tilde = s2_tilde.dimshuffle(1, 0, 2)

        ### Compose (eq. 14 and 15) ###

        # (batch_size, seq_len, 8 * dim)
        s1_comp = T.concatenate(
            [s1_bilstm, s1_tilde, s1_bilstm - s1_tilde, s1_bilstm * s1_tilde],
            axis=2)
        s2_comp = T.concatenate(
            [s2_bilstm, s2_tilde, s2_bilstm - s2_tilde, s2_bilstm * s2_tilde],
            axis=2)
        ### Encode (eq. 16 and 17) ###

        # (batch_size, seq_len, 8 * dim)
        # TODO: Share this bilstm?
        s1_comp_bilstm, _ = self._prem_bidir2.apply(
            self._prem_bidir_fork2.apply(flip01(s1_comp)),
            mask=s1_mask.T)  # (n_seq, batch_size, 2 * dim)
        s2_comp_bilstm, _ = self._hyp_bidir2.apply(
            self._hyp_bidir_fork2.apply(flip01(s2_comp)),
            mask=s2_mask.T)  # (n_seq, batch_size, 2 * dim)
        s1_comp_bilstm = flip01(s1_comp_bilstm)  # (batch_size, n_seq, 2 * dim)
        s2_comp_bilstm = flip01(s2_comp_bilstm)

        ### Pooling Layer ###

        # Masked mean pooling: sum over valid positions divided by the number
        # of valid positions.
        s1_comp_bilstm_ave = (
            (s1_mask.dimshuffle(0, 1, "x") * s1_comp_bilstm).sum(axis=1)
            / s1_mask.sum(axis=1).dimshuffle(0, "x"))

        # Masked max pooling: padded positions are pushed to a large negative
        # value so that they never win the max.
        s1_comp_bilstm_max = T.max(
            (1 - s1_mask.dimshuffle(0, 1, "x")) * -10000
            + s1_mask.dimshuffle(0, 1, "x") * s1_comp_bilstm,
            axis=1)

        s2_comp_bilstm_ave = (
            (s2_mask.dimshuffle(0, 1, "x") * s2_comp_bilstm).sum(axis=1)
            / s2_mask.sum(axis=1).dimshuffle(0, "x"))
        # Each pooled vector is (batch_size, 2 * dim).
        s2_comp_bilstm_max = T.max(
            (1 - s2_mask.dimshuffle(0, 1, "x")) * -10000
            + s2_mask.dimshuffle(0, 1, "x") * s2_comp_bilstm,
            axis=1)

        ### Final classifier ###

        # MLP layer
        # (batch_size, 8 * dim)
        m = T.concatenate([
            s1_comp_bilstm_ave, s1_comp_bilstm_max, s2_comp_bilstm_ave,
            s2_comp_bilstm_max
        ],
                          axis=1)
        pre_logits = self._mlp.apply(m)

        if train_phase:
            pre_logits = apply_dropout(pre_logits, drop_prob=self._dropout)

        # Get prediction. Note: self._pred ends in a Softmax activation, so
        # despite the name these are class probabilities, not raw logits.
        self.logits = self._pred.apply(pre_logits)

        return self.logits
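

# A rough usage sketch: it assumes a `vocab` object exposing `size()` and
# `unk` and initialization schemes such as GlorotUniform / Constant (e.g. from
# blocks.initialization). `apply` returns class probabilities of shape
# (batch_size, 3):
#
#     esim = ESIM(dim=300, emb_dim=300, vocab=vocab,
#                 weights_init=GlorotUniform(), biases_init=Constant(0.0))
#     esim.initialize()
#     s1 = tensor.lmatrix('s1')
#     s1_mask = tensor.matrix('s1_mask')
#     s2 = tensor.lmatrix('s2')
#     s2_mask = tensor.matrix('s2_mask')
#     probs = esim.apply(s1, s1_mask, s2, s2_mask, train_phase=False)
#     predict_fn = theano.function([s1, s1_mask, s2, s2_mask], probs)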