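# The snippets below are shown without their import preamble. A plausible
# preamble, assuming the standard Blocks/Theano APIs, is sketched here.
# flip01, flip12, unk_ratio, masked_softmax, apply_dropout, WordToIdOp,
# BidirectionalSum, GlorotUniform, MeanPoolCombiner, MeanPoolReadDefinitions
# and the EMBEDDINGS role are project-local helpers, not part of Blocks itself.
import logging
from collections import OrderedDict

import theano
import theano.tensor as T
from theano import tensor
from theano.gradient import disconnected_grad

from blocks.bricks import (Initializable, Linear, MLP, BatchNormalizedMLP,
                           Rectifier, Tanh, Logistic, Softmax,
                           NDimensionalSoftmax)
from blocks.bricks.base import application
from blocks.bricks.lookup import LookupTable
from blocks.bricks.recurrent import LSTM, Bidirectional
from blocks.initialization import Constant, Uniform
from blocks.roles import add_role
from blocks.select import Selector

logger = logging.getLogger(__name__)
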
class ExtractiveQAModel(Initializable):
    """The dictionary-equipped extractive QA model.

    Parameters
    ----------
    dim : int
        The default dimensionality for the components.
    emb_dim : int
        The dimensionality for the embeddings. If 0, `dim` is used.
    coattention : bool
        Use the coattention mechanism.
    num_input_words : int
        The number of input words. If 0, `vocab.size()` is used.
    def_num_input_words : int
        The number of input words for the definition reader. If 0,
        `num_input_words` is used.
    vocab
        The vocabulary object.
    use_definitions : bool
        Triggers the use of definitions.
    reuse_word_embeddings : bool
        Reuse the word embedding lookup table in the definition reader.
    compose_type : str
        How the combiner composes definition embeddings with word embeddings.

    """
    def __init__(self, dim, emb_dim, readout_dims, num_input_words,
                 def_num_input_words, vocab, use_definitions, def_word_gating,
                 compose_type, coattention, def_reader, reuse_word_embeddings,
                 random_unk, **kwargs):
        self._vocab = vocab
        if emb_dim == 0:
            emb_dim = dim
        if num_input_words == 0:
            num_input_words = vocab.size()
        if def_num_input_words == 0:
            def_num_input_words = num_input_words

        self._coattention = coattention
        self._num_input_words = num_input_words
        self._use_definitions = use_definitions
        self._random_unk = random_unk
        self._reuse_word_embeddings = reuse_word_embeddings

        lookup_num_words = num_input_words
        if reuse_word_embeddings:
            lookup_num_words = max(num_input_words, def_num_input_words)
        if random_unk:
            lookup_num_words = vocab.size()

        # Dima: we can have slightly less copy-paste here if we
        # copy the RecurrentFromFork class from my other projects.
        children = []
        self._lookup = LookupTable(lookup_num_words, emb_dim)
        self._encoder_fork = Linear(emb_dim, 4 * dim, name='encoder_fork')
        self._encoder_rnn = LSTM(dim, name='encoder_rnn')
        self._question_transform = Linear(dim, dim, name='question_transform')
        self._bidir_fork = Linear(3 * dim if coattention else 2 * dim,
                                  4 * dim,
                                  name='bidir_fork')
        self._bidir = Bidirectional(LSTM(dim), name='bidir')
        children.extend([
            self._lookup, self._encoder_fork, self._encoder_rnn,
            self._question_transform, self._bidir, self._bidir_fork
        ])

        activations = [Rectifier()] * len(readout_dims) + [None]
        readout_dims = [2 * dim] + readout_dims + [1]
        self._begin_readout = MLP(activations,
                                  readout_dims,
                                  name='begin_readout')
        self._end_readout = MLP(activations, readout_dims, name='end_readout')
        self._softmax = NDimensionalSoftmax()
        children.extend(
            [self._begin_readout, self._end_readout, self._softmax])

        if self._use_definitions:
            # A potential bug here: we pass the same vocab to the def reader.
            # If a different token is reserved for UNK in text and in the definitions,
            # we can be screwed.
            def_reader_class = eval(def_reader)
            def_reader_kwargs = dict(
                num_input_words=def_num_input_words,
                dim=dim,
                emb_dim=emb_dim,
                vocab=vocab,
                lookup=self._lookup if reuse_word_embeddings else None)
            if def_reader_class == MeanPoolReadDefinitions:
                def_reader_kwargs.update(dict(normalize=True, translate=False))
            self._def_reader = def_reader_class(**def_reader_kwargs)
            self._combiner = MeanPoolCombiner(dim=dim,
                                              emb_dim=emb_dim,
                                              def_word_gating=def_word_gating,
                                              compose_type=compose_type)
            children.extend([self._def_reader, self._combiner])

        super(ExtractiveQAModel, self).__init__(children=children, **kwargs)

        # create default input variables
        self.contexts = tensor.lmatrix('contexts')
        self.context_mask = tensor.matrix('contexts_mask')
        self.questions = tensor.lmatrix('questions')
        self.question_mask = tensor.matrix('questions_mask')
        self.answer_begins = tensor.lvector('answer_begins')
        self.answer_ends = tensor.lvector('answer_ends')
        input_vars = [
            self.contexts, self.context_mask, self.questions,
            self.question_mask, self.answer_begins, self.answer_ends
        ]
        if self._use_definitions:
            self.defs = tensor.lmatrix('defs')
            self.def_mask = tensor.matrix('def_mask')
            self.contexts_def_map = tensor.lmatrix('contexts_def_map')
            self.questions_def_map = tensor.lmatrix('questions_def_map')
            input_vars.extend([
                self.defs, self.def_mask, self.contexts_def_map,
                self.questions_def_map
            ])
        self.input_vars = OrderedDict([(var.name, var) for var in input_vars])

    def set_embeddings(self, embeddings):
        self._lookup.parameters[0].set_value(
            embeddings.astype(theano.config.floatX))

    def embeddings_var(self):
        return self._lookup.parameters[0]

    def def_reading_parameters(self):
        parameters = list(
            Selector(self._def_reader).get_parameters().values())
        parameters.extend(Selector(self._combiner).get_parameters().values())
        if self._reuse_word_embeddings:
            lookup_parameters = Selector(
                self._lookup).get_parameters().values()
            parameters = [p for p in parameters if p not in lookup_parameters]
        return parameters

    @application
    def _encode(self,
                application_call,
                text,
                mask,
                def_embs=None,
                def_map=None,
                text_name=None):
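        # Words outside the input shortlist are mapped to UNK here; when
        # random_unk is set they keep their ids instead, and their (randomly
        # initialized) embeddings are excluded from the gradient below.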
        if not self._random_unk:
            text = (tensor.lt(text, self._num_input_words) * text +
                    tensor.ge(text, self._num_input_words) * self._vocab.unk)
        if text_name:
            application_call.add_auxiliary_variable(
                unk_ratio(text, mask, self._vocab.unk),
                name='{}_unk_ratio'.format(text_name))
        embs = self._lookup.apply(text)
        if self._random_unk:
            embs = (tensor.lt(text, self._num_input_words)[:, :, None] * embs +
                    tensor.ge(text, self._num_input_words)[:, :, None] *
                    disconnected_grad(embs))
        if def_embs is not None:
            embs = self._combiner.apply(embs, mask, def_embs, def_map)
        add_role(embs, EMBEDDINGS)
        encoded = flip01(
            self._encoder_rnn.apply(self._encoder_fork.apply(flip01(embs)),
                                    mask=mask.T)[0])
        return encoded

    @application
    def apply(self,
              application_call,
              contexts,
              contexts_mask,
              questions,
              questions_mask,
              answer_begins,
              answer_ends,
              defs=None,
              def_mask=None,
              contexts_def_map=None,
              questions_def_map=None):
        def_embs = None
        if self._use_definitions:
            def_embs = self._def_reader.apply(defs, def_mask)

        context_enc = self._encode(contexts, contexts_mask, def_embs,
                                   contexts_def_map, 'context')
        question_enc_pre = self._encode(questions, questions_mask, def_embs,
                                        questions_def_map, 'question')
        question_enc = tensor.tanh(
            self._question_transform.apply(question_enc_pre))

        # should be (batch size, context length, question_length)
        affinity = tensor.batched_dot(context_enc, flip12(question_enc))
        affinity_mask = contexts_mask[:, :, None] * questions_mask[:, None, :]
        affinity = affinity * affinity_mask - 1000.0 * (1 - affinity_mask)
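        # padded positions get a large negative score so the softmaxes below
        # assign them (near-)zero attention weight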
        # soft-aligns every position in the context to positions in the question
        d2q_att_weights = self._softmax.apply(affinity, extra_ndim=1)
        application_call.add_auxiliary_variable(d2q_att_weights.copy(),
                                                name='d2q_att_weights')
        # soft-aligns every position in the question to positions in the document
        q2d_att_weights = self._softmax.apply(flip12(affinity), extra_ndim=1)
        application_call.add_auxiliary_variable(q2d_att_weights.copy(),
                                                name='q2d_att_weights')

        # question encoding "in the view of the document"
        question_enc_informed = tensor.batched_dot(q2d_att_weights,
                                                   context_enc)
        question_enc_concatenated = tensor.concatenate(
            [question_enc, question_enc_informed], 2)
        # document encoding "in the view of the question"
        context_enc_informed = tensor.batched_dot(d2q_att_weights,
                                                  question_enc_concatenated)

        if self._coattention:
            context_enc_concatenated = tensor.concatenate(
                [context_enc, context_enc_informed], 2)
        else:
            question_repr_repeated = tensor.repeat(question_enc[:, [-1], :],
                                                   context_enc.shape[1],
                                                   axis=1)
            context_enc_concatenated = tensor.concatenate(
                [context_enc, question_repr_repeated], 2)

        # note: forward and backward LSTMs share the
        # input weights in the current impl
        bidir_states = flip01(
            self._bidir.apply(self._bidir_fork.apply(
                flip01(context_enc_concatenated)),
                              mask=contexts_mask.T)[0])

        begin_readouts = self._begin_readout.apply(bidir_states)[:, :, 0]
        begin_readouts = begin_readouts * contexts_mask - 1000.0 * (
            1 - contexts_mask)
        begin_costs = self._softmax.categorical_cross_entropy(
            answer_begins, begin_readouts)

        end_readouts = self._end_readout.apply(bidir_states)[:, :, 0]
        end_readouts = end_readouts * contexts_mask - 1000.0 * (1 -
                                                                contexts_mask)
        end_costs = self._softmax.categorical_cross_entropy(
            answer_ends, end_readouts)

        predicted_begins = begin_readouts.argmax(axis=-1)
        predicted_ends = end_readouts.argmax(axis=-1)
        exact_match = (tensor.eq(predicted_begins, answer_begins) *
                       tensor.eq(predicted_ends, answer_ends))
        application_call.add_auxiliary_variable(predicted_begins,
                                                name='predicted_begins')
        application_call.add_auxiliary_variable(predicted_ends,
                                                name='predicted_ends')
        application_call.add_auxiliary_variable(exact_match,
                                                name='exact_match')

        return begin_costs + end_costs

    def apply_with_default_vars(self):
        return self.apply(*self.input_vars.values())
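
# A minimal usage sketch (hypothetical hyper-parameter values; `vocab` is the
# project's vocabulary object; definitions are disabled, so the def_* arguments
# are unused placeholders):
#
#     qa = ExtractiveQAModel(
#         dim=200, emb_dim=300, readout_dims=[200], num_input_words=10000,
#         def_num_input_words=0, vocab=vocab, use_definitions=False,
#         def_word_gating='none', compose_type='sum', coattention=True,
#         def_reader='MeanPoolReadDefinitions', reuse_word_embeddings=False,
#         random_unk=False, weights_init=Uniform(width=0.1),
#         biases_init=Constant(0.))
#     qa.initialize()
#     cost = qa.apply_with_default_vars().mean()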

# Example No. 4 (Seq2Seq constructor)

    def __init__(self,
                 emb_dim,
                 dim,
                 num_input_words,
                 num_output_words,
                 vocab,
                 proximity_coef=0,
                 proximity_distance='l2',
                 encoder='lstm',
                 decoder='lstm',
                 shared_rnn=False,
                 translate_layer=None,
                 word_dropout=0.,
                 tied_in_out=False,
                 vocab_keys=None,
                 seed=0,
                 reconstruction_coef=1.,
                 provide_targets=False,
                 **kwargs):
        """
        translate_layer: either a string naming the activation function to use,
                         or a list containing the activations for an MLP
        """
        if emb_dim == 0:
            emb_dim = dim
        if num_input_words == 0:
            num_input_words = vocab.size()
        if num_output_words == 0:
            num_output_words = vocab.size()

        self._word_dropout = word_dropout

        self._tied_in_out = tied_in_out

        if not encoder:
            if proximity_coef:
                raise ValueError("Err: meaningless penalty term (no encoder)")
            if not vocab_keys:
                raise ValueError("Err: specify a key vocabulary (no encoder)")

        if tied_in_out and num_input_words != num_output_words:
            raise ValueError("Can't tie in and out embeddings. Different "
                             "vocabulary size")
        if shared_rnn and (encoder != 'lstm' or decoder != 'lstm'):
            raise ValueError(
                "can't share RNN because either encoder or decoder "
                "is not an RNN")
        if shared_rnn and decoder == 'lstm_c':
            raise ValueError(
                "can't share RNN because the decoder takes different "
                "inputs")
        if word_dropout < 0 or word_dropout > 1:
            raise ValueError("invalid value for word dropout",
                             str(word_dropout))
        if proximity_distance not in ['l1', 'l2', 'cos']:
            raise ValueError(
                "unrecognized distance: {}".format(proximity_distance))

        if proximity_coef and emb_dim != dim and not translate_layer:
            raise ValueError(
                "if proximity penalisation is used, emb_dim should equal dim "
                "or there should be a translate layer")

        if encoder not in [
                None, 'lstm', 'bilstm', 'mean', 'weighted_mean', 'max_bilstm',
                'bilstm_sum', 'max_bilstm_sum'
        ]:
            raise ValueError('encoder not recognized')
        if decoder not in ['skip-gram', 'lstm', 'lstm_c']:
            raise ValueError('decoder not recognized')
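
        # Summary of the checks above: encoder is one of None, 'lstm',
        # 'bilstm', 'max_bilstm', 'bilstm_sum', 'max_bilstm_sum', 'mean' or
        # 'weighted_mean'; decoder is one of 'skip-gram', 'lstm' or 'lstm_c';
        # shared_rnn additionally requires both encoder and decoder = 'lstm'.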

        self._proximity_distance = proximity_distance
        self._decoder = decoder
        self._encoder = encoder
        self._num_input_words = num_input_words
        self._num_output_words = num_output_words
        self._vocab = vocab
        self._proximity_coef = proximity_coef
        self._reconstruction_coef = reconstruction_coef
        self._provide_targets = provide_targets

        self._word_to_id = WordToIdOp(self._vocab)
        if vocab_keys:
            self._key_to_id = WordToIdOp(vocab_keys)

        children = []

        if encoder or (not encoder and decoder in ['lstm', 'lstm_c']):
            self._main_lookup = LookupTable(self._num_input_words,
                                            emb_dim,
                                            name='main_lookup')
            children.append(self._main_lookup)
        if provide_targets:
            # this is useful to simulate Hill's baseline without pretrained embeddings
            # in the encoder, only as targets for the encoder.
            self._target_lookup = LookupTable(self._num_input_words,
                                              emb_dim,
                                              name='target_lookup')
            children.append(self._target_lookup)
        if not encoder:
            self._key_lookup = LookupTable(vocab_keys.size(),
                                           emb_dim,
                                           name='key_lookup')
            children.append(self._key_lookup)
        elif encoder == 'lstm':
            self._encoder_fork = Linear(emb_dim, 4 * dim, name='encoder_fork')
            self._encoder_rnn = LSTM(dim, name='encoder_rnn')
            children.extend([self._encoder_fork, self._encoder_rnn])
        elif encoder in ['bilstm', 'max_bilstm']:
            # dim is the dim of the concatenated vector
            self._encoder_fork = Linear(emb_dim, 2 * dim, name='encoder_fork')
            self._encoder_rnn = Bidirectional(LSTM(dim // 2,
                                                   name='encoder_rnn'))
            children.extend([self._encoder_fork, self._encoder_rnn])
        elif encoder in ['bilstm_sum', 'max_bilstm_sum']:
            self._encoder_fork = Linear(emb_dim, 4 * dim, name='encoder_fork')
            self._encoder_rnn = BidirectionalSum(LSTM(dim, name='encoder_rnn'))
            children.extend([self._encoder_fork, self._encoder_rnn])
        elif encoder == 'mean':
            pass
        elif encoder == 'weighted_mean':
            self._encoder_w = MLP([Logistic()], [dim, 1],
                                  name="encoder_weights")
            children.extend([self._encoder_w])
        else:
            raise NotImplementedError()

        if decoder in ['lstm', 'lstm_c']:
            dim_after_translate = emb_dim
            if shared_rnn:
                self._decoder_fork = self._encoder_fork
                self._decoder_rnn = self._encoder_rnn
            else:
                if decoder == 'lstm_c':
                    dim_2 = dim + emb_dim
                else:
                    dim_2 = dim
                self._decoder_fork = Linear(dim_2,
                                            4 * dim,
                                            name='decoder_fork')
                self._decoder_rnn = LSTM(dim, name='decoder_rnn')
            children.extend([self._decoder_fork, self._decoder_rnn])
        elif decoder == 'skip-gram':
            dim_after_translate = emb_dim

        self._translate_layer = None
        activations = {'sigmoid': Logistic(), 'tanh': Tanh(), 'linear': None}

        if translate_layer:
            if isinstance(translate_layer, str):
                translate_layer = [translate_layer]
            assert isinstance(translate_layer, list)
            activations_translate = [activations[a] for a in translate_layer]
            dims_translate = ([dim] * len(translate_layer) +
                              [dim_after_translate])
            self._translate_layer = MLP(activations_translate,
                                        dims_translate,
                                        name="translate_layer")
            children.append(self._translate_layer)

        if not self._tied_in_out:
            self._pre_softmax = Linear(emb_dim, self._num_output_words)
            children.append(self._pre_softmax)
        if decoder in ['lstm', 'lstm_c']:
            self._softmax = NDimensionalSoftmax()
        elif decoder in ['skip-gram']:
            self._softmax = Softmax()
        children.append(self._softmax)

        super(Seq2Seq, self).__init__(children=children, **kwargs)
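
# A minimal constructor sketch (hypothetical values; `vocab` is the project's
# vocabulary object; Seq2Seq is assumed to be an Initializable brick like the
# other classes shown here):
#
#     s2s = Seq2Seq(emb_dim=300, dim=300, num_input_words=10000,
#                   num_output_words=10000, vocab=vocab,
#                   encoder='lstm', decoder='lstm',
#                   weights_init=Uniform(width=0.1), biases_init=Constant(0.))
#     s2s.initialize()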

# Example No. 5 (ESIM)

class ESIM(Initializable):
    """
    ESIM model based on https://github.com/NYU-MLL/multiNLI/blob/master/python/models/esim.py
    """

    # seq_length, emb_dim, hidden_dim
    def __init__(
            self,
            dim,
            emb_dim,
            vocab,
            def_emb_translate_dim=-1,
            def_dim=-1,
            encoder='bilstm',
            bn=True,
            def_reader=None,
            def_combiner=None,
            dropout=0.5,
            num_input_words=-1,
            # Others
            **kwargs):

        self._dropout = dropout
        self._vocab = vocab
        self._emb_dim = emb_dim
        self._def_reader = def_reader
        self._def_combiner = def_combiner

        if encoder != 'bilstm':
            raise NotImplementedError()

        if def_emb_translate_dim < 0:
            self.def_emb_translate_dim = emb_dim
        else:
            self.def_emb_translate_dim = def_emb_translate_dim

        if def_dim < 0:
            self._def_dim = emb_dim
        else:
            self._def_dim = def_dim

        if num_input_words > 0:
            logger.info("Restricting vocab to " + str(num_input_words))
            self._num_input_words = num_input_words
        else:
            self._num_input_words = vocab.size()

        children = []

        if self.def_emb_translate_dim != self._emb_dim:
            self._translate_pre_def = Linear(input_dim=emb_dim,
                                             output_dim=def_emb_translate_dim)
            children.append(self._translate_pre_def)
        else:
            self._translate_pre_def = None

        ## Embedding
        self._lookup = LookupTable(self._num_input_words,
                                   emb_dim,
                                   weights_init=GlorotUniform())
        children.append(self._lookup)

        if def_reader:
            self._final_emb_dim = self._def_dim
            self._def_reader = def_reader
            self._def_combiner = def_combiner
            children.extend([self._def_reader, self._def_combiner])
        else:
            self._final_emb_dim = self._emb_dim

        ## BiLSTM
        self._hyp_bidir_fork = Linear(
            self._def_dim if def_reader else self._emb_dim,
            4 * dim,
            name='hyp_bidir_fork')
        self._hyp_bidir = Bidirectional(LSTM(dim), name='hyp_bidir')
        self._prem_bidir_fork = Linear(
            self._def_dim if def_reader else self._emb_dim,
            4 * dim,
            name='prem_bidir_fork')
        self._prem_bidir = Bidirectional(LSTM(dim), name='prem_bidir')
        children.extend([self._hyp_bidir_fork, self._hyp_bidir])
        children.extend([self._prem_bidir, self._prem_bidir_fork])

        ## BiLSTM no. 2 (encoded attentioned embeddings)
        self._hyp_bidir_fork2 = Linear(8 * dim,
                                       4 * dim,
                                       name='hyp_bidir_fork2')
        self._hyp_bidir2 = Bidirectional(LSTM(dim), name='hyp_bidir2')
        self._prem_bidir_fork2 = Linear(8 * dim,
                                        4 * dim,
                                        name='prem_bidir_fork2')
        self._prem_bidir2 = Bidirectional(LSTM(dim), name='prem_bidir2')
        children.extend([self._hyp_bidir_fork2, self._hyp_bidir2])
        children.extend([self._prem_bidir2, self._prem_bidir_fork2])

        self._rnns = [
            self._prem_bidir2, self._hyp_bidir2, self._prem_bidir,
            self._hyp_bidir
        ]

        ## MLP
        if bn:
            self._mlp = BatchNormalizedMLP([Tanh()], [8 * dim, dim],
                                           conserve_memory=False,
                                           name="mlp")
            self._pred = BatchNormalizedMLP([Softmax()], [dim, 3],
                                            conserve_memory=False,
                                            name="pred_mlp")
        else:
            self._mlp = MLP([Tanh()], [8 * dim, dim], name="mlp")
            self._pred = MLP([Softmax()], [dim, 3], name="pred_mlp")

        children.append(self._mlp)
        children.append(self._pred)

        ## Softmax
        self._ndim_softmax = NDimensionalSoftmax()
        children.append(self._ndim_softmax)

        super(ESIM, self).__init__(children=children, **kwargs)

    def get_embeddings_lookups(self):
        return [self._lookup]

    def set_embeddings(self, embeddings):
        self._lookup.parameters[0].set_value(
            embeddings.astype(theano.config.floatX))

    def get_def_embeddings_lookups(self):
        return [self._def_reader._def_lookup]

    def set_def_embeddings(self, embeddings):
        self._def_reader._def_lookup.parameters[0].set_value(
            embeddings.astype(theano.config.floatX))

    @application
    def apply(self,
              application_call,
              s1_preunk,
              s1_mask,
              s2_preunk,
              s2_mask,
              def_mask=None,
              defs=None,
              s1_def_map=None,
              s2_def_map=None,
              train_phase=True):
        # Shortlist words (sometimes we want smaller vocab, especially when dict is small)
        s1 = (tensor.lt(s1_preunk, self._num_input_words) * s1_preunk +
              tensor.ge(s1_preunk, self._num_input_words) * self._vocab.unk)
        s2 = (tensor.lt(s2_preunk, self._num_input_words) * s2_preunk +
              tensor.ge(s2_preunk, self._num_input_words) * self._vocab.unk)

        ### Embed ###

        s1_emb = self._lookup.apply(s1)
        s2_emb = self._lookup.apply(s2)
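        # (1 * s1_emb below creates a fresh variable, much like .copy(), so the
        # auxiliary name attaches to a copy rather than to the embeddings)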

        application_call.add_auxiliary_variable(1 * s1_emb,
                                                name='s1_word_embeddings')

        if self._def_reader:
            assert defs is not None

            def_embs = self._def_reader.apply(defs, def_mask)

            if self._translate_pre_def:
                logger.info("Translate pre def")
                s1_emb = s1_emb.reshape(
                    (s1_emb.shape[0] * s1_emb.shape[1], s1_emb.shape[2]))
                s2_emb = s2_emb.reshape(
                    (s2_emb.shape[0] * s2_emb.shape[1], s2_emb.shape[2]))
                s1_emb = self._translate_pre_def.apply(s1_emb)
                s2_emb = self._translate_pre_def.apply(s2_emb)
                s1_emb = s1_emb.reshape(
                    (s1_preunk.shape[0], s1_preunk.shape[1], -1))
                s2_emb = s2_emb.reshape(
                    (s2_preunk.shape[0], s2_preunk.shape[1], -1))

            s1_emb = self._def_combiner.apply(s1_emb,
                                              s1_mask,
                                              def_embs,
                                              s1_def_map,
                                              word_ids=s1,
                                              train_phase=train_phase,
                                              call_name="s1")

            s2_emb = self._def_combiner.apply(s2_emb,
                                              s2_mask,
                                              def_embs,
                                              s2_def_map,
                                              word_ids=s2,
                                              train_phase=train_phase,
                                              call_name="s2")
        else:
            if train_phase and self._dropout > 0:
                s1_emb = apply_dropout(s1_emb, drop_prob=self._dropout)
                s2_emb = apply_dropout(s2_emb, drop_prob=self._dropout)

        ### Encode ###

        # TODO: Share this bilstm?
        s1_bilstm, _ = self._prem_bidir.apply(
            flip01(self._prem_bidir_fork.apply(s1_emb)),
            mask=s1_mask.T)  # (batch_size, n_seq, 2 * dim)
        s2_bilstm, _ = self._hyp_bidir.apply(
            flip01(self._hyp_bidir_fork.apply(s2_emb)),
            mask=s2_mask.T)  # (batch_size, n_seq, 2 * dim)
        s1_bilstm = flip01(s1_bilstm)
        s2_bilstm = flip01(s2_bilstm)

        ### Attention ###

        # Compute E matrix (eq. 11)
        # E_ij = <s1[i], s2[j]>
        # each call computes E[i, :]
        def compute_e_row(s2_i, s1_bilstm, s1_mask):
            b_size = s1_bilstm.shape[0]
            # s2_i is (batch_size, emb_dim)
            # s1_bilstm is (batch_size, seq_len, emb_dim)
            # s1_mask is (batch_size, seq_len)
            # s2_i = s2_i.reshape((s2_i.shape[0], s2_i.shape[1], 1))
            s2_i = s2_i.reshape((b_size, s2_i.shape[1], 1))
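            # duplicate the vector along a trailing axis of length 2 so that
            # batched_dot returns a 3-D tensor; both columns are identical and
            # only the first is kept below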
            s2_i = T.repeat(s2_i, 2, axis=2)
            # s2_i is (batch_size, emb_dim, 2)
            assert s1_bilstm.ndim == 3
            assert s2_i.ndim == 3
            score = T.batched_dot(s1_bilstm, s2_i)  # (batch_size, seq_len, 1)
            score = score[:, :, 0].reshape(
                (b_size, -1))  # (batch_size, seq_len)
            return score  # E[i, :]

        # NOTE: No point in masking here
        E, _ = theano.scan(compute_e_row,
                           sequences=[s1_bilstm.transpose(1, 0, 2)],
                           non_sequences=[s2_bilstm, s2_mask])
        # (seq_len, batch_size, seq_len)
        E = E.dimshuffle(1, 0, 2)
        assert E.ndim == 3

        s2s_att_weights = self._ndim_softmax.apply(E, extra_ndim=1)
        application_call.add_auxiliary_variable(s2s_att_weights.copy(),
                                                name='s2s_att_weights')

        # (batch_size, seq_len, seq_len)

        ### Compute tilde vectors (eq. 12 and 13) ###

        def compute_tilde_vector(e_i, s, s_mask):
            # e_i is (batch_size, seq_len)
            # s_mask is (batch_size, seq_len)
            # s_tilde_i = \sum e_ij b_j, (batch_size, seq_len)
            score = masked_softmax(e_i, s_mask, axis=1)
            score = score.dimshuffle(0, 1, "x")

            s_tilde_i = (score *
                         (s * s_mask.dimshuffle(0, 1, "x"))).sum(axis=1)
            return s_tilde_i

        # (batch_size, seq_len, def_dim)
        s1_tilde, _ = theano.scan(compute_tilde_vector,
                                  sequences=[E.dimshuffle(1, 0, 2)],
                                  non_sequences=[s2_bilstm, s2_mask])
        s1_tilde = s1_tilde.dimshuffle(1, 0, 2)
        s2_tilde, _ = theano.scan(compute_tilde_vector,
                                  sequences=[E.dimshuffle(2, 0, 1)],
                                  non_sequences=[s1_bilstm, s1_mask])
        s2_tilde = s2_tilde.dimshuffle(1, 0, 2)

        ### Compose (eq. 14 and 15) ###

        # (batch_size, seq_len, 8 * dim)
        s1_comp = T.concatenate(
            [s1_bilstm, s1_tilde, s1_bilstm - s1_tilde, s1_bilstm * s1_tilde],
            axis=2)
        s2_comp = T.concatenate(
            [s2_bilstm, s2_tilde, s2_bilstm - s2_tilde, s2_bilstm * s2_tilde],
            axis=2)
        ### Encode (eq. 16 and 17) ###

        # (batch_size, seq_len, 8 * dim)
        # TODO: Share this bilstm?
        s1_comp_bilstm, _ = self._prem_bidir2.apply(
            self._prem_bidir_fork2.apply(flip01(s1_comp)),
            mask=s1_mask.T)  # (batch_size, n_seq, 2 * dim)
        s2_comp_bilstm, _ = self._hyp_bidir2.apply(
            self._hyp_bidir_fork2.apply(flip01(s2_comp)),
            mask=s2_mask.T)  # (batch_size, n_seq, 2 * dim)
        s1_comp_bilstm = flip01(s1_comp_bilstm)
        s2_comp_bilstm = flip01(s2_comp_bilstm)
        ### Pooling Layer ###

        s1_comp_bilstm_ave = (s1_mask.dimshuffle(0, 1, "x") * s1_comp_bilstm).sum(axis=1) \
                            / s1_mask.sum(axis=1).dimshuffle(0, "x")

        s1_comp_bilstm_max = T.max( ((1 - s1_mask.dimshuffle(0, 1, "x")) * -10000) + \
                                    (s1_mask.dimshuffle(0, 1, "x")) * s1_comp_bilstm, axis=1)

        s2_comp_bilstm_ave = (s2_mask.dimshuffle(0, 1, "x") * s2_comp_bilstm).sum(axis=1) \
                             / s2_mask.sum(axis=1).dimshuffle(0, "x")
        # (batch_size, dim)
        s2_comp_bilstm_max = T.max(((1 - s2_mask.dimshuffle(0, 1, "x")) * -10000) + \
                                   (s2_mask.dimshuffle(0, 1, "x")) * s2_comp_bilstm, axis=1)

        ### Final classifier ###

        # MLP layer
        # (batch_size, 8 * dim)
        m = T.concatenate([
            s1_comp_bilstm_ave, s1_comp_bilstm_max, s2_comp_bilstm_ave,
            s2_comp_bilstm_max
        ],
                          axis=1)
        pre_logits = self._mlp.apply(m)

        if train_phase:
            pre_logits = apply_dropout(pre_logits, drop_prob=self._dropout)

        # Get prediction; the final MLP ends in a Softmax, so despite the name
        # these are class probabilities rather than raw logits
        self.logits = self._pred.apply(pre_logits)

        return self.logits
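
# A minimal usage sketch (hypothetical dimensions; `vocab` is the project's
# vocabulary object and no definition reader/combiner is plugged in):
#
#     s1 = tensor.lmatrix('sentence1')
#     s1_mask = tensor.matrix('sentence1_mask')
#     s2 = tensor.lmatrix('sentence2')
#     s2_mask = tensor.matrix('sentence2_mask')
#     esim = ESIM(dim=300, emb_dim=300, vocab=vocab,
#                 weights_init=Uniform(width=0.1), biases_init=Constant(0.))
#     esim.initialize()
#     probs = esim.apply(s1, s1_mask, s2, s2_mask, train_phase=False)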