Example 1
    def _context_sensitive_shift(self, inputs):
        """
        Compute a buffer top representation by mixing buffer top and hidden state.

        NB: This hasn't been an especially effective tool so far.
        """
        assert self.use_tracking_lstm
        buffer_top, tracking_hidden = inputs[2:4]

        # Exclude the cell value from the computation.
        tracking_hidden = tracking_hidden[:, :self.tracking_lstm_hidden_dim]

        inp = T.concatenate([tracking_hidden, buffer_top], axis=1)
        inp_dim = self._spec.word_embedding_dim + self.tracking_lstm_hidden_dim
        layer = util.ReLULayer if self.context_sensitive_use_relu else util.Linear
        return layer(inp, inp_dim, self._spec.model_dim, self._vs,
                     name="context_comb_unit", use_bias=True,
                     initializer=util.HeKaimingInitializer())
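
What this unit computes: the tracker's hidden state is concatenated with the
word embedding about to be shifted, and the pair is projected down to
model_dim, optionally through a ReLU. Below is a minimal NumPy sketch of the
same computation; the function name, W, and b are illustrative stand-ins for
the parameters that util.Linear or util.ReLULayer would draw from the
variable store.

import numpy as np

def context_sensitive_shift_sketch(tracking_h, buffer_top, W, b, use_relu=True):
    # tracking_h: (batch, tracking_dim) tracker hidden state, cell excluded.
    # buffer_top: (batch, embedding_dim) embedding about to be shifted.
    # W: (tracking_dim + embedding_dim, model_dim); b: (model_dim,).
    inp = np.concatenate([tracking_h, buffer_top], axis=1)
    out = inp.dot(W) + b
    return np.maximum(out, 0.0) if use_relu else out

out = context_sensitive_shift_sketch(
    np.zeros((2, 4)), np.zeros((2, 6)), np.zeros((10, 8)), np.zeros(8))
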
Example 2
    def _step(self, transitions_t, ss_mask_gen_matrix_t, stack_t, buffer_cur_t,
              tracking_hidden, attention_hidden, buffer,
              ground_truth_transitions_visible, premise_stack_tops,
              projected_stack_tops):
        """TODO document"""
        batch_size, _ = self.X.shape

        # Extract top buffer values.
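        # The buffer is stored flattened, one row per (example, position)
        # pair, so example b's cursor maps to row b * seq_length + cursor.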
        idxs = buffer_cur_t + (T.arange(batch_size) * self.seq_length)

        if self.context_sensitive_shift:
            # Combine with the hidden state from previous unit.
            tracking_h_t = tracking_hidden[:, :self.tracking_lstm_hidden_dim]
            context_comb_input_t = T.concatenate([tracking_h_t, buffer[idxs]],
                                                 axis=1)
            context_comb_input_dim = self.word_embedding_dim + self.tracking_lstm_hidden_dim
            comb_layer = util.ReLULayer if self.context_sensitive_use_relu else util.Linear
            buffer_top_t = comb_layer(context_comb_input_t,
                                      context_comb_input_dim,
                                      self.model_dim,
                                      self._vs,
                                      name="context_comb_unit",
                                      use_bias=True,
                                      initializer=util.HeKaimingInitializer())
        else:
            buffer_top_t = buffer[idxs]

        if self._prediction_and_tracking_network is not None:
            # We are predicting our own stack operations.
            h_dim = self.model_dim // 2  # TODO(SB): Turn this off when not using TreeLSTM.

            predict_inp = T.concatenate([
                stack_t[:, 0, :h_dim], stack_t[:, 1, :h_dim],
                buffer_top_t[:, :h_dim]
            ], axis=1)

            if self.use_tracking_lstm:
                # Update the hidden state and obtain predicted actions.
                tracking_hidden, actions_t = self._prediction_and_tracking_network(
                    tracking_hidden,
                    predict_inp,
                    h_dim * 3,
                    self.tracking_lstm_hidden_dim,
                    self._vs,
                    name="prediction_and_tracking")
            else:
                # Obtain predicted actions directly.
                actions_t = self._prediction_and_tracking_network(
                    predict_inp,
                    h_dim * 3,
                    util.NUM_TRANSITION_TYPES,
                    self._vs,
                    name="prediction_and_tracking")

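        # Choose the source of this step's transitions: Model 2 predicts its
        # own (optionally blended with ground truth via scheduled sampling),
        # Model 1 follows externally supplied transitions where visible, and
        # Model 0 always uses the ground truth.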
        if self.train_with_predicted_transitions:
            # Model 2 case.
            if self.interpolate:
                # Only use ground truth transitions if they are marked as visible to the model.
                effective_ss_mask_gen_matrix_t = ss_mask_gen_matrix_t * ground_truth_transitions_visible
                # Interpolate between truth and prediction using Bernoulli RVs
                # generated prior to the step.
                mask = (transitions_t * effective_ss_mask_gen_matrix_t +
                        actions_t.argmax(axis=1) *
                        (1 - effective_ss_mask_gen_matrix_t))
            else:
                # Use predicted actions to build a mask.
                mask = actions_t.argmax(axis=1)
        elif self._predict_transitions:
            # Use transitions provided by an external parser, except where
            # they are masked out.
            mask = (transitions_t * ground_truth_transitions_visible +
                    actions_t.argmax(axis=1) *
                    (1 - ground_truth_transitions_visible))
        else:
            # Model 0 case.
            mask = transitions_t

        # Now update the stack: first precompute reduce results.
        if self.model_dim != self.stack_dim:
            stack1 = stack_t[:, 0, :self.model_dim].reshape(
                (-1, self.model_dim))
            stack2 = stack_t[:, 1, :self.model_dim].reshape(
                (-1, self.model_dim))
        else:
            stack1 = stack_t[:, 0].reshape((-1, self.model_dim))
            stack2 = stack_t[:, 1].reshape((-1, self.model_dim))
        reduce_items = (stack1, stack2)
        if self.connect_tracking_comp:
            tracking_h_t = tracking_hidden[:, :self.tracking_lstm_hidden_dim]
            reduce_value = self._compose_network(
                reduce_items,
                tracking_h_t,
                self.model_dim,
                self._vs,
                name="compose",
                external_state_dim=self.tracking_lstm_hidden_dim)
        else:
            reduce_value = self._compose_network(reduce_items,
                                                 (self.model_dim, ) * 2,
                                                 self.model_dim,
                                                 self._vs,
                                                 name="compose")

        # Compute new stack value.
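        # update_stack shifts buffer_top_t onto the stack where mask == 0 and
        # replaces the top two elements with reduce_value where mask == 1.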
        stack_next = update_stack(stack_t, buffer_top_t, reduce_value, mask,
                                  self.model_dim)

        # If attention is to be used and premise_stack_tops is not None (i.e.
        # we are processing the hypothesis), compute the attention-weighted
        # representation.
        if self.use_attention != "None" and self.is_hypothesis:
            h_dim = self.model_dim // 2
            if self.use_attention in {"TreeWangJiang", "TreeThang"}:
                mask_ = mask.dimshuffle(0, "x")
                mask_ = T.cast(mask_, dtype=theano.config.floatX)
                attention_hidden_l = stack_t[:, 0, self.model_dim:] * mask_
                attention_hidden_r = stack_t[:, 1, self.model_dim:] * mask_

                tree_attention_hidden = self._attention_unit(
                    attention_hidden_l,
                    attention_hidden_r,
                    stack_next[:, 0, :h_dim],
                    premise_stack_tops,
                    projected_stack_tops,
                    h_dim,
                    self._vs,
                    name="attention_unit")
                stack_next = T.set_subtensor(stack_next[:, 0, self.model_dim:],
                                             tree_attention_hidden)
            else:
                attention_hidden = self._attention_unit(
                    attention_hidden,
                    stack_next[:, 0, :h_dim],
                    premise_stack_tops,
                    projected_stack_tops,
                    h_dim,
                    self._vs,
                    name="attention_unit")

        # Move buffer cursor as necessary. Since mask == 1 when reduce, we
        # should increment each buffer cursor by 1 - mask.
        buffer_cur_next = buffer_cur_t + (1 - mask)

        if self._predict_transitions:
            ret_val = stack_next, buffer_cur_next, tracking_hidden, attention_hidden, actions_t
        else:
            ret_val = stack_next, buffer_cur_next, tracking_hidden, attention_hidden

        if not self.interpolate:
            # Use ss_mask as a redundant return value.
            ret_val = (ss_mask_gen_matrix_t, ) + ret_val

        return ret_val
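
The interpolation branch in _step is scheduled sampling over transitions: a
per-example Bernoulli mask, drawn before the step, decides whether the
ground-truth or the predicted transition drives the stack update. A small
NumPy sketch of the blend; all names here are illustrative.

import numpy as np

rng = np.random.RandomState(0)
transitions_t = np.array([0, 1, 0, 1])      # ground-truth shift/reduce bits
predicted_t = np.array([1, 1, 0, 0])        # argmax of predicted action logits
visible = 1.0                               # ground_truth_transitions_visible
ss_mask_t = rng.binomial(1, 0.75, size=4)   # 1 = keep the ground truth

effective_mask = ss_mask_t * visible
mask = transitions_t * effective_mask + predicted_t * (1 - effective_mask)
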
Example 3
def build_sentence_model(cls, vocab_size, seq_length, tokens, transitions,
                         num_classes, training_mode, ground_truth_transitions_visible, vs,
                         initial_embeddings=None, project_embeddings=False, ss_mask_gen=None, ss_prob=0.0):
    """
    Construct a classifier which makes use of some hard-stack model.

    Args:
      cls: Hard stack class to use (from e.g. `spinn.fat_stack`)
      vocab_size: Size of the input vocabulary.
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano batch (integer matrix), `batch_size * seq_length`
      transitions: Theano batch (integer matrix), `batch_size * seq_length`
      num_classes: Number of output classes
      training_mode: A Theano scalar indicating whether to act as a training model
        with dropout (1.0) or to act as an eval model with rescaling (0.0).
      ground_truth_transitions_visible: A Theano scalar. If set (1.0), allow the model access
        to ground truth transitions. This can be disabled at evaluation time to force Model 1
        (or 2S) to evaluate in the Model 2 style with predicted transitions. Has no effect on Model 0.
      vs: Variable store.
    """

    # Prepare layer which performs stack element composition.
    if cls is spinn.plain_rnn.RNN:
        compose_network = partial(util.LSTMLayer,
                                  initializer=util.HeKaimingInitializer())
        embedding_projection_network = None
    elif cls is spinn.cbow.CBOW:
        compose_network = None
        embedding_projection_network = None
    else:
        if FLAGS.lstm_composition:
            compose_network = partial(util.TreeLSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        else:
            assert not FLAGS.connect_tracking_comp, "Can only connect tracking and composition unit while using TreeLSTM"
            compose_network = partial(util.ReLULayer,
                                      initializer=util.HeKaimingInitializer())

        if project_embeddings:
            embedding_projection_network = util.Linear
        else:
            assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
                "word_embedding_dim must equal model_dim unless a projection layer is used."
            embedding_projection_network = util.IdentityLayer

    # Build hard stack which scans over input sequence.
    sentence_model = cls(
        FLAGS.model_dim, FLAGS.word_embedding_dim, vocab_size, seq_length,
        compose_network, embedding_projection_network, training_mode, ground_truth_transitions_visible, vs,
        predict_use_cell=FLAGS.predict_use_cell,
        use_tracking_lstm=FLAGS.use_tracking_lstm,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
        X=tokens,
        transitions=transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        connect_tracking_comp=FLAGS.connect_tracking_comp,
        context_sensitive_shift=FLAGS.context_sensitive_shift,
        context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
        use_input_batch_norm=False)

    # Extract top element of final stack timestep.
    if FLAGS.lstm_composition or cls is spinn.plain_rnn.RNN:
        sentence_vector = sentence_model.final_representations[:, :FLAGS.model_dim // 2].reshape(
            (-1, FLAGS.model_dim // 2))
        sentence_vector_dim = FLAGS.model_dim // 2
    else:
        sentence_vector = sentence_model.final_representations.reshape((-1, FLAGS.model_dim))
        sentence_vector_dim = FLAGS.model_dim

    sentence_vector = util.BatchNorm(sentence_vector, sentence_vector_dim, vs, "sentence_vector", training_mode)
    sentence_vector = util.Dropout(sentence_vector, FLAGS.semantic_classifier_keep_rate, training_mode)

    # Feed forward through a single output layer
    logits = util.Linear(
        sentence_vector, sentence_vector_dim, num_classes, vs,
        name="semantic_classifier", use_bias=True)

    return sentence_model.transitions_pred, logits
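
The model_dim // 2 slicing above reflects how LSTM-style composition packs
its output: hidden and cell vectors sit side by side along the feature axis
(as in Example 1, where the cell half is excluded), and only the hidden half
feeds the classifier. A shape-only NumPy illustration, with made-up
dimensions:

import numpy as np

batch_size, model_dim = 8, 16
final_representations = np.random.randn(batch_size, model_dim)  # packs [h; c]

h = final_representations[:, :model_dim // 2]  # hidden half -> classifier
c = final_representations[:, model_dim // 2:]  # cell half, dropped here
assert h.shape == (batch_size, model_dim // 2)
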
Example 4
def build_sentence_pair_model(cls, vocab_size, seq_length, tokens, transitions,
                              num_classes, training_mode,
                              ground_truth_transitions_visible, vs,
                              initial_embeddings=None, project_embeddings=False,
                              ss_mask_gen=None, ss_prob=0.0):
    """
    Construct a classifier which makes use of some hard-stack model.

    Args:
      cls: Hard stack class to use (from e.g. `spinn.fat_stack`)
      vocab_size: Size of the input vocabulary.
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano batch (integer matrix), `batch_size * seq_length`
      transitions: Theano batch (integer matrix), `batch_size * seq_length`
      num_classes: Number of output classes
      training_mode: A Theano scalar indicating whether to act as a training model
        with dropout (1.0) or to act as an eval model with rescaling (0.0).
      ground_truth_transitions_visible: A Theano scalar. If set (1.0), allow the model access
        to ground truth transitions. This can be disabled at evaluation time to force Model 1
        (or 2S) to evaluate in the Model 2 style with predicted transitions. Has no effect on Model 0.
      vs: Variable store.
    """


    # Prepare layer which performs stack element composition.
    if cls is spinn.plain_rnn.RNN:
        compose_network = partial(util.LSTMLayer,
                                  initializer=util.HeKaimingInitializer())
        embedding_projection_network = None
    elif cls is spinn.cbow.CBOW:
        compose_network = None
        embedding_projection_network = None
    else:
        if FLAGS.lstm_composition:
            compose_network = partial(util.TreeLSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        else:
            assert not FLAGS.connect_tracking_comp, "Can only connect tracking and composition unit while using TreeLSTM"
            compose_network = partial(util.ReLULayer,
                                      initializer=util.HeKaimingInitializer())

        if project_embeddings:
            embedding_projection_network = util.Linear
        else:
            assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
                "word_embedding_dim must equal model_dim unless a projection layer is used."
            embedding_projection_network = util.IdentityLayer

    # Split the two sentences
    premise_tokens = tokens[:, :, 0]
    hypothesis_tokens = tokens[:, :, 1]

    premise_transitions = transitions[:, :, 0]
    hypothesis_transitions = transitions[:, :, 1]

    # Build two hard stack models which scan over input sequences.
    premise_model = cls(
        FLAGS.model_dim, FLAGS.word_embedding_dim, vocab_size, seq_length,
        compose_network, embedding_projection_network, training_mode, ground_truth_transitions_visible, vs,
        predict_use_cell=FLAGS.predict_use_cell,
        use_tracking_lstm=FLAGS.use_tracking_lstm,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
        X=premise_tokens,
        transitions=premise_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        connect_tracking_comp=FLAGS.connect_tracking_comp,
        context_sensitive_shift=FLAGS.context_sensitive_shift,
        context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
        use_attention=FLAGS.use_attention,
        initialize_hyp_tracking_state=FLAGS.initialize_hyp_tracking_state)

    premise_stack_tops = premise_model.stack_tops if FLAGS.use_attention != "None" else None
    premise_tracking_c_state_final = (premise_model.tracking_c_state_final
                                      if cls not in (spinn.plain_rnn.RNN, spinn.cbow.CBOW)
                                      else None)
    hypothesis_model = cls(
        FLAGS.model_dim, FLAGS.word_embedding_dim, vocab_size, seq_length,
        compose_network, embedding_projection_network, training_mode, ground_truth_transitions_visible, vs,
        predict_use_cell=FLAGS.predict_use_cell,
        use_tracking_lstm=FLAGS.use_tracking_lstm,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
        X=hypothesis_tokens,
        transitions=hypothesis_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        connect_tracking_comp=FLAGS.connect_tracking_comp,
        context_sensitive_shift=FLAGS.context_sensitive_shift,
        context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
        use_attention=FLAGS.use_attention,
        premise_stack_tops=premise_stack_tops,
        is_hypothesis=True,
        initialize_hyp_tracking_state=FLAGS.initialize_hyp_tracking_state,
        premise_tracking_c_state_final=premise_tracking_c_state_final)

    # Extract top element of final stack timestep.
    if FLAGS.use_attention == "None" or FLAGS.use_difference_feature or FLAGS.use_product_feature:
        premise_vector = premise_model.final_representations
        hypothesis_vector = hypothesis_model.final_representations

        if (FLAGS.lstm_composition and cls is not spinn.cbow.CBOW) or cls is spinn.plain_rnn.RNN:
            premise_vector = premise_vector[:, :FLAGS.model_dim // 2].reshape(
                (-1, FLAGS.model_dim // 2))
            hypothesis_vector = hypothesis_vector[:, :FLAGS.model_dim // 2].reshape(
                (-1, FLAGS.model_dim // 2))
            sentence_vector_dim = FLAGS.model_dim // 2
        else:
            premise_vector = premise_vector.reshape((-1, FLAGS.model_dim))
            hypothesis_vector = hypothesis_vector.reshape((-1, FLAGS.model_dim))
            sentence_vector_dim = FLAGS.model_dim

    if FLAGS.use_attention != "None":
        # Use the attention-weighted representation.
        h_dim = FLAGS.model_dim // 2
        mlp_input = hypothesis_model.final_weighed_representation.reshape((-1, h_dim))
        mlp_input_dim = h_dim
    else:
        # Create standard MLP features
        mlp_input = T.concatenate([premise_vector, hypothesis_vector], axis=1)
        mlp_input_dim = 2 * sentence_vector_dim

    if FLAGS.use_difference_feature:
        mlp_input = T.concatenate([mlp_input, premise_vector - hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    if FLAGS.use_product_feature:
        mlp_input = T.concatenate([mlp_input, premise_vector * hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    mlp_input = util.BatchNorm(mlp_input, mlp_input_dim, vs, "sentence_vectors", training_mode)
    mlp_input = util.Dropout(mlp_input, FLAGS.semantic_classifier_keep_rate, training_mode)

    if FLAGS.classifier_type == "ResNet":
        features = util.Linear(
            mlp_input, mlp_input_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
            name="resnet/linear", use_bias=True)
        features_dim = FLAGS.sentence_pair_combination_layer_dim

        for layer in range(FLAGS.num_sentence_pair_combination_layers):
            features = util.HeKaimingResidualLayerSet(
                features, features_dim, vs, training_mode,
                name="resnet/" + str(layer),
                dropout_keep_rate=FLAGS.semantic_classifier_keep_rate,
                depth=FLAGS.resnet_unit_depth,
                initializer=util.HeKaimingInitializer())
            features = util.BatchNorm(features, features_dim, vs, "combining_mlp/" + str(layer), training_mode)
            features = util.Dropout(features, FLAGS.semantic_classifier_keep_rate, training_mode)
    elif FLAGS.classifier_type == "Highway":
        features = util.Linear(
            mlp_input, mlp_input_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
            name="highway/linear", use_bias=True)
        features_dim = FLAGS.sentence_pair_combination_layer_dim

        for layer in range(FLAGS.num_sentence_pair_combination_layers):
            features = util.HighwayLayer(
                features, features_dim, vs, training_mode,
                name="highway/" + str(layer),
                dropout_keep_rate=FLAGS.semantic_classifier_keep_rate,
                initializer=util.HeKaimingInitializer())
            features = util.BatchNorm(features, features_dim, vs, "combining_mlp/" + str(layer), training_mode)
            features = util.Dropout(features, FLAGS.semantic_classifier_keep_rate, training_mode)
    else:
        # Apply a combining MLP.
        features = mlp_input
        features_dim = mlp_input_dim
        for layer in range(FLAGS.num_sentence_pair_combination_layers):
            features = util.ReLULayer(
                features, features_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
                name="combining_mlp/" + str(layer),
                initializer=util.HeKaimingInitializer())
            features_dim = FLAGS.sentence_pair_combination_layer_dim

            features = util.BatchNorm(features, features_dim, vs, "combining_mlp/" + str(layer), training_mode)
            features = util.Dropout(features, FLAGS.semantic_classifier_keep_rate, training_mode)

    # Feed forward through a single output layer
    logits = util.Linear(
        features, features_dim, num_classes, vs,
        name="semantic_classifier", use_bias=True)

    return premise_model.transitions_pred, hypothesis_model.transitions_pred, logits
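
The mlp_input assembled above is the standard sentence-matching feature set:
the two sentence vectors, optionally joined by their elementwise difference
and product. A compact NumPy sketch of the same construction, with the FLAGS
replaced by plain booleans:

import numpy as np

premise_vec = np.random.randn(8, 50)
hypothesis_vec = np.random.randn(8, 50)
use_difference_feature = use_product_feature = True

features = [premise_vec, hypothesis_vec]
if use_difference_feature:
    features.append(premise_vec - hypothesis_vec)
if use_product_feature:
    features.append(premise_vec * hypothesis_vec)

mlp_input = np.concatenate(features, axis=1)  # here: (8, 4 * 50)
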
Example 5
def build_sentence_pair_model(cls,
                              vocab_size,
                              seq_length,
                              tokens,
                              transitions,
                              num_classes,
                              training_mode,
                              ground_truth_transitions_visible,
                              vs,
                              initial_embeddings=None,
                              project_embeddings=False,
                              ss_mask_gen=None,
                              ss_prob=0.0):
    """
    Construct a classifier which makes use of some hard-stack model.

    Args:
      cls: Hard stack class to use (from e.g. `spinn.stack`)
      vocab_size: Size of the input vocabulary.
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano batch (integer matrix), `batch_size * seq_length`
      transitions: Theano batch (integer matrix), `batch_size * seq_length`
      num_classes: Number of output classes
      training_mode: A Theano scalar indicating whether to act as a training model
        with dropout (1.0) or to act as an eval model with rescaling (0.0).
      ground_truth_transitions_visible: A Theano scalar. If set (1.0), allow the model access
        to ground truth transitions. This can be disabled at evaluation time to force Model 1
        (or 2S) to evaluate in the Model 2 style with predicted transitions. Has no effect on Model 0.
      vs: Variable store.
    """

    # Prepare layer which performs stack element composition.
    if cls is spinn.plain_rnn.RNN:
        compose_network = partial(util.LSTMLayer,
                                  initializer=util.HeKaimingInitializer())
        embedding_projection_network = None
    else:
        if FLAGS.lstm_composition:
            compose_network = partial(util.TreeLSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        else:
            assert not FLAGS.connect_tracking_comp, "Can only connect tracking and composition unit while using TreeLSTM"
            compose_network = partial(util.ReLULayer,
                                      initializer=util.HeKaimingInitializer())

        if project_embeddings:
            embedding_projection_network = util.Linear
        else:
            assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
                "word_embedding_dim must equal model_dim unless a projection layer is used."
            embedding_projection_network = util.IdentityLayer

    model_visible_dim = FLAGS.model_dim // 2 if FLAGS.lstm_composition else FLAGS.model_dim
    spec = util.ModelSpec(FLAGS.model_dim,
                          FLAGS.word_embedding_dim,
                          FLAGS.batch_size,
                          vocab_size,
                          seq_length,
                          model_visible_dim=model_visible_dim)

    # Split the two sentences
    premise_tokens = tokens[:, :, 0]
    hypothesis_tokens = tokens[:, :, 1]

    premise_transitions = transitions[:, :, 0]
    hypothesis_transitions = transitions[:, :, 1]

    # TODO: Check non-Model0 support.
    recurrence = cls(
        spec,
        vs,
        compose_network,
        use_context_sensitive_shift=FLAGS.context_sensitive_shift,
        context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
        use_tracking_lstm=FLAGS.use_tracking_lstm,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim)

    # Build two hard stack models which scan over input sequences.
    premise_model = ThinStack(
        spec,
        recurrence,
        embedding_projection_network,
        training_mode,
        ground_truth_transitions_visible,
        vs,
        X=premise_tokens,
        transitions=premise_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        use_input_batch_norm=False,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        name="premise")

    hypothesis_model = ThinStack(
        spec,
        recurrence,
        embedding_projection_network,
        training_mode,
        ground_truth_transitions_visible,
        vs,
        X=hypothesis_tokens,
        transitions=hypothesis_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        use_input_batch_norm=False,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        name="hypothesis")

    # Extract the top element of each final stack timestep. NB: this assumes
    # ThinStack exposes `final_representations` like the models above; adjust
    # the attribute name if the ThinStack API differs.
    premise_vector = premise_model.final_representations.reshape(
        (-1, model_visible_dim))
    hypothesis_vector = hypothesis_model.final_representations.reshape(
        (-1, model_visible_dim))
    sentence_vector_dim = model_visible_dim

    # Create standard MLP features.
    mlp_input = T.concatenate([premise_vector, hypothesis_vector], axis=1)
    mlp_input_dim = 2 * sentence_vector_dim

    if FLAGS.use_difference_feature:
        mlp_input = T.concatenate(
            [mlp_input, premise_vector - hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    if FLAGS.use_product_feature:
        mlp_input = T.concatenate(
            [mlp_input, premise_vector * hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    mlp_input = util.BatchNorm(mlp_input, mlp_input_dim, vs,
                               "sentence_vectors", training_mode)
    mlp_input = util.Dropout(mlp_input, FLAGS.semantic_classifier_keep_rate,
                             training_mode)

    # Apply a combining MLP
    prev_features = mlp_input
    prev_features_dim = mlp_input_dim
    for layer in range(FLAGS.num_sentence_pair_combination_layers):
        prev_features = util.ReLULayer(
            prev_features,
            prev_features_dim,
            FLAGS.sentence_pair_combination_layer_dim,
            vs,
            name="combining_mlp/" + str(layer),
            initializer=util.HeKaimingInitializer())
        prev_features_dim = FLAGS.sentence_pair_combination_layer_dim

        prev_features = util.BatchNorm(prev_features, prev_features_dim, vs,
                                       "combining_mlp/" + str(layer),
                                       training_mode)
        prev_features = util.Dropout(prev_features,
                                     FLAGS.semantic_classifier_keep_rate,
                                     training_mode)

    # Feed forward through a single output layer
    logits = util.Linear(prev_features,
                         prev_features_dim,
                         num_classes,
                         vs,
                         name="semantic_classifier",
                         use_bias=True)

    def zero_fn():
        premise_model.zero()
        hypothesis_model.zero()

    return premise_model, hypothesis_model, logits, zero_fn