コード例 #1
0
ファイル: stack.py プロジェクト: vishalbelsare/spinn
    def _project_embeddings(self, raw_embeddings, dropout_mask=None):
        """
        Push raw embeddings through the embedding projection network.

        Batch normalization and dropout are applied afterwards when
        `self.use_input_batch_norm` / `self.use_input_dropout` are set.

        Args:
          raw_embeddings: Embeddings handed to the projection network.
          dropout_mask: Optional mask forwarded to `util.Dropout`.

        Returns:
          A pair `(projected, mask)`. `mask` is the mask returned by
          `util.Dropout`, or None when input dropout is disabled.
        """
        embedded = self._embedding_projection_network(
            raw_embeddings,
            self.word_embedding_dim,
            self.model_dim,
            self._vs,
            name=self._prefix + "project")

        if self.use_input_batch_norm:
            embedded = util.BatchNorm(
                embedded,
                self.model_dim,
                self._vs,
                self._prefix + "buffer",
                self.training_mode,
                axes=[0, 1])

        if not self.use_input_dropout:
            # Dropout disabled: there is no mask to hand back.
            return embedded, None

        # The mask is returned next to the output because it has to be
        # retained for backprop purposes.
        embedded, mask_out = util.Dropout(
            embedded,
            self.embedding_dropout_keep_rate,
            self.training_mode,
            dropout_mask=dropout_mask,
            return_mask=True)
        return embedded, mask_out
コード例 #2
0
ファイル: fat_stack.py プロジェクト: imclab/spinn
    def _make_scan(self):
        """Build the sequential composition / scan graph.

        Looks up the embeddings for `self.X`, prepares the initial stack,
        buffer cursor, tracking-LSTM state and attention accumulator, and
        scans `self._step` over the transitions with `theano.scan`.

        Sets the following attributes on `self`:
          final_stack: Stack output of the last scan step.
          final_representations: `final_stack[:, 0, :self.model_dim]`.
          embeddings: `final_stack[:, 0]`.
          transitions_pred: The last scan output with its first two axes
            swapped when `self._predict_transitions` is set; otherwise a
            `(batch_size, 0)` tensor of zeros.
          stack_tops: Only set when attention is enabled and this is not the
            hypothesis model. Holds the first `h_dim` units of the top stack
            element at every step.
          final_weighed_representation: Only set when attention is enabled
            and this is the hypothesis model.
          tracking_c_state_final: Second half of the final tracking state
            when `initialize_hyp_tracking_state` is set on a non-hypothesis
            model, otherwise None.
        """

        batch_size, max_stack_size = self.X.shape

        # Stack batch is a 3D tensor.
        stack_shape = (batch_size, max_stack_size, self.stack_dim)
        stack_init = T.zeros(stack_shape)

        # Look up all of the embeddings that will be used.
        raw_embeddings = self.word_embeddings[
            self.X]  # batch_size * seq_length * emb_dim

        if self.context_sensitive_shift:
            # Use the raw embedding vectors, they will be combined with the hidden state of
            # the tracking unit later
            buffer_t = raw_embeddings
            buffer_emb_dim = self.word_embedding_dim
        else:
            # Allocate a "buffer" stack initialized with projected embeddings,
            # and maintain a cursor in this buffer.
            buffer_t = self._embedding_projection_network(
                raw_embeddings,
                self.word_embedding_dim,
                self.model_dim,
                self._vs,
                name="project")
            if self.use_input_batch_norm:
                buffer_t = util.BatchNorm(buffer_t,
                                          self.model_dim,
                                          self._vs,
                                          "buffer",
                                          self.training_mode,
                                          axes=[0, 1])
            if self.use_input_dropout:
                buffer_t = util.Dropout(buffer_t,
                                        self.embedding_dropout_keep_rate,
                                        self.training_mode)
            buffer_emb_dim = self.model_dim

        # Collapse buffer to (batch_size * buffer_size) * emb_dim for fast indexing.
        buffer_t = buffer_t.reshape((-1, buffer_emb_dim))

        buffer_cur_init = T.zeros((batch_size, ), dtype="int")

        DUMMY = T.zeros((2, ))  # a dummy tensor used as a place-holder

        # Dimshuffle inputs to seq_len * batch_size for scanning
        transitions = self.transitions.dimshuffle(1, 0)

        # Initialize the hidden state for the tracking LSTM, if needed.
        if self.use_tracking_lstm:
            if self.initialize_hyp_tracking_state and self.is_hypothesis:
                # Initialize the c state of tracking unit from the c state of premise model.
                h_state_init = T.zeros(
                    (batch_size, self.tracking_lstm_hidden_dim))
                hidden_init = T.concatenate(
                    [h_state_init, self.premise_tracking_c_state_final],
                    axis=1)
            else:
                hidden_init = T.zeros(
                    (batch_size, self.tracking_lstm_hidden_dim * 2))
        else:
            hidden_init = DUMMY

        # Initialize the attention representation if needed
        if self.use_attention not in {"TreeWangJiang", "TreeThang", "None"
                                      } and self.is_hypothesis:
            # Floor division keeps h_dim an integer on both Python 2 and 3;
            # it is used below as a tensor dimension.
            h_dim = self.model_dim // 2
            if self.use_attention == "WangJiang" or self.use_attention == "Thang":
                attention_init = T.zeros((batch_size, 2 * h_dim))
            else:
                attention_init = T.zeros((batch_size, h_dim))
        else:
            # If we're not using a sequential attention accumulator (i.e., no attention or
            # tree attention), use a size-zero value here.
            attention_init = DUMMY

        # Set up the output list for scanning over _step().
        if self._predict_transitions:
            outputs_info = [
                stack_init, buffer_cur_init, hidden_init, attention_init, None
            ]
        else:
            outputs_info = [
                stack_init, buffer_cur_init, hidden_init, attention_init
            ]

        # Prepare data to scan over.
        sequences = [transitions]
        if self.interpolate:
            # Generate Bernoulli RVs to simulate scheduled sampling
            # if the interpolate flag is on.
            ss_mask_gen_matrix = self.ss_mask_gen.binomial(transitions.shape,
                                                           p=self.ss_prob)
            # Take in the RV sequence as input.
            sequences.append(ss_mask_gen_matrix)
        else:
            # Take in the RV sequence as a dummy output. This is
            # done to avoid defining another step function.
            outputs_info = [DUMMY] + outputs_info

        non_sequences = [buffer_t, self.ground_truth_transitions_visible]

        if self.use_attention != "None" and self.is_hypothesis:
            h_dim = self.model_dim // 2
            projected_stack_tops = util.AttentionUnitInit(
                self.premise_stack_tops, h_dim, self._vs)
            non_sequences = non_sequences + [
                self.premise_stack_tops, projected_stack_tops
            ]
        else:
            DUMMY2 = T.zeros((2, ))  # another dummy tensor
            non_sequences = non_sequences + [DUMMY, DUMMY2]

        scan_ret = theano.scan(self._step,
                               sequences=sequences,
                               non_sequences=non_sequences,
                               outputs_info=outputs_info,
                               n_steps=self.seq_length,
                               name="stack_fwd")

        # Without interpolation a dummy output was prepended to outputs_info
        # above, so every real output sits one position later.
        stack_ind = 0 if self.interpolate else 1
        self.final_stack = scan_ret[0][stack_ind][-1]
        self.final_representations = self.final_stack[:, 0, :self.model_dim]
        self.embeddings = self.final_stack[:, 0]

        if self._predict_transitions:
            self.transitions_pred = scan_ret[0][-1].dimshuffle(1, 0, 2)
        else:
            self.transitions_pred = T.zeros((batch_size, 0))

        if self.use_attention != "None" and not self.is_hypothesis:
            # Store the stack top at each step as an attribute.
            h_dim = self.model_dim // 2
            self.stack_tops = scan_ret[0][stack_ind][:, :, 0, :h_dim].reshape(
                (max_stack_size, batch_size, h_dim))

        if self.use_attention != "None" and self.is_hypothesis:
            h_dim = self.model_dim // 2
            if self.use_attention == "Rocktaschel":
                self.final_weighed_representation = util.AttentionUnitFinalRepresentation(
                    scan_ret[0][stack_ind + 3][-1], self.embeddings[:, :h_dim],
                    h_dim, self._vs)
            elif self.use_attention in {"WangJiang", "Thang"}:
                self.final_weighed_representation = scan_ret[0][
                    stack_ind + 3][-1][:, :h_dim]
            elif self.use_attention in {"TreeWangJiang", "TreeThang"}:
                self.final_weighed_representation = scan_ret[0][stack_ind][
                    -1][:, 0, 2 * h_dim:3 * h_dim]

        if self.initialize_hyp_tracking_state and not self.is_hypothesis:
            # Store the final c states of the tracking unit.
            self.tracking_c_state_final = scan_ret[0][
                stack_ind + 2][-1][:, self.tracking_lstm_hidden_dim:]
        else:
            self.tracking_c_state_final = None
コード例 #3
0
def build_sentence_model(cls, vocab_size, seq_length, tokens, transitions,
                         num_classes, training_mode, ground_truth_transitions_visible, vs,
                         initial_embeddings=None, project_embeddings=False, ss_mask_gen=None, ss_prob=0.0):
    """
    Construct a single-sentence classifier on top of some hard-stack model.

    Args:
      cls: Hard stack class to use (from e.g. `spinn.fat_stack`)
      vocab_size: Forwarded to `cls`.
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano batch (integer matrix), `batch_size * seq_length`
      transitions: Theano batch (integer matrix), `batch_size * seq_length`
      num_classes: Number of output classes
      training_mode: A Theano scalar indicating whether to act as a training model
        with dropout (1.0) or to act as an eval model with rescaling (0.0).
      ground_truth_transitions_visible: A Theano scalar. If set (1.0), allow the model access
        to ground truth transitions. This can be disabled at evaluation time to force Model 1
        (or 2S) to evaluate in the Model 2 style with predicted transitions. Has no effect on Model 0.
      vs: Variable store.
      initial_embeddings: Forwarded to `cls`.
      project_embeddings: If true, use `util.Linear` as the embedding
        projection; otherwise `util.IdentityLayer` is used and
        `FLAGS.word_embedding_dim` must equal `FLAGS.model_dim`.
      ss_mask_gen: Forwarded to `cls`.
      ss_prob: Forwarded to `cls`.

    Returns:
      A pair `(transitions_pred, logits)`: the model's `transitions_pred`
      attribute and the output of the final linear classification layer.
    """

    # Prepare layer which performs stack element composition.
    if cls is spinn.plain_rnn.RNN:
        compose_network = partial(util.LSTMLayer,
                                  initializer=util.HeKaimingInitializer())
        embedding_projection_network = None
    elif cls is spinn.cbow.CBOW:
        compose_network = None
        embedding_projection_network = None
    else:
        if FLAGS.lstm_composition:
            compose_network = partial(util.TreeLSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        else:
            assert not FLAGS.connect_tracking_comp, "Can only connect tracking and composition unit while using TreeLSTM"
            compose_network = partial(util.ReLULayer,
                                      initializer=util.HeKaimingInitializer())

        if project_embeddings:
            embedding_projection_network = util.Linear
        else:
            assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
                "word_embedding_dim must equal model_dim unless a projection layer is used."
            embedding_projection_network = util.IdentityLayer

    # Build hard stack which scans over input sequence.
    sentence_model = cls(
        FLAGS.model_dim, FLAGS.word_embedding_dim, vocab_size, seq_length,
        compose_network, embedding_projection_network, training_mode, ground_truth_transitions_visible, vs,
        predict_use_cell=FLAGS.predict_use_cell,
        use_tracking_lstm=FLAGS.use_tracking_lstm,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
        X=tokens,
        transitions=transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        connect_tracking_comp=FLAGS.connect_tracking_comp,
        context_sensitive_shift=FLAGS.context_sensitive_shift,
        context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
        use_input_batch_norm=False)

    # Extract top element of final stack timestep.
    # NOTE(review): when `cls` is CBOW and FLAGS.lstm_composition is set, this
    # branch also keeps only the first half of the representation -- confirm
    # that is intended.
    if FLAGS.lstm_composition or cls is spinn.plain_rnn.RNN:
        # Floor division keeps the dimension an integer on both Python 2 and
        # Python 3; it is used as a slice bound and a reshape dimension.
        sentence_vector_dim = FLAGS.model_dim // 2
        sentence_vector = sentence_model.final_representations[
            :, :sentence_vector_dim].reshape((-1, sentence_vector_dim))
    else:
        sentence_vector_dim = FLAGS.model_dim
        sentence_vector = sentence_model.final_representations.reshape(
            (-1, sentence_vector_dim))

    sentence_vector = util.BatchNorm(sentence_vector, sentence_vector_dim, vs, "sentence_vector", training_mode)
    sentence_vector = util.Dropout(sentence_vector, FLAGS.semantic_classifier_keep_rate, training_mode)

    # Feed forward through a single output layer
    logits = util.Linear(
        sentence_vector, sentence_vector_dim, num_classes, vs,
        name="semantic_classifier", use_bias=True)

    return sentence_model.transitions_pred, logits
コード例 #4
0
def build_sentence_pair_model(cls, vocab_size, seq_length, tokens, transitions,
                     num_classes, training_mode, ground_truth_transitions_visible, vs,
                     initial_embeddings=None, project_embeddings=False, ss_mask_gen=None, ss_prob=0.0):
    """
    Construct a sentence-pair classifier on top of some hard-stack model.

    Two instances of `cls` are built, one per sentence. Their outputs are
    combined into MLP features and classified.

    Args:
      cls: Hard stack class to use (from e.g. `spinn.fat_stack`)
      vocab_size: Forwarded to `cls`.
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano batch of integers indexed as `tokens[:, :, 0]` (premise)
        and `tokens[:, :, 1]` (hypothesis).
      transitions: Theano batch of integers, indexed the same way as `tokens`.
      num_classes: Number of output classes
      training_mode: A Theano scalar indicating whether to act as a training model
        with dropout (1.0) or to act as an eval model with rescaling (0.0).
      ground_truth_transitions_visible: A Theano scalar. If set (1.0), allow the model access
        to ground truth transitions. This can be disabled at evaluation time to force Model 1
        (or 2S) to evaluate in the Model 2 style with predicted transitions. Has no effect on Model 0.
      vs: Variable store.
      initial_embeddings: Forwarded to `cls`.
      project_embeddings: If true, use `util.Linear` as the embedding
        projection; otherwise `util.IdentityLayer` is used and
        `FLAGS.word_embedding_dim` must equal `FLAGS.model_dim`.
      ss_mask_gen: Forwarded to `cls`.
      ss_prob: Forwarded to `cls`.

    Returns:
      A triple `(premise_transitions_pred, hypothesis_transitions_pred,
      logits)`.
    """

    # Prepare layer which performs stack element composition.
    if cls is spinn.plain_rnn.RNN:
        compose_network = partial(util.LSTMLayer,
                                  initializer=util.HeKaimingInitializer())
        embedding_projection_network = None
    elif cls is spinn.cbow.CBOW:
        compose_network = None
        embedding_projection_network = None
    else:
        if FLAGS.lstm_composition:
            compose_network = partial(util.TreeLSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        else:
            assert not FLAGS.connect_tracking_comp, "Can only connect tracking and composition unit while using TreeLSTM"
            compose_network = partial(util.ReLULayer,
                                      initializer=util.HeKaimingInitializer())

        if project_embeddings:
            embedding_projection_network = util.Linear
        else:
            assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
                "word_embedding_dim must equal model_dim unless a projection layer is used."
            embedding_projection_network = util.IdentityLayer

    # Split the two sentences
    premise_tokens = tokens[:, :, 0]
    hypothesis_tokens = tokens[:, :, 1]

    premise_transitions = transitions[:, :, 0]
    hypothesis_transitions = transitions[:, :, 1]

    # Build two hard stack models which scan over input sequences.
    premise_model = cls(
        FLAGS.model_dim, FLAGS.word_embedding_dim, vocab_size, seq_length,
        compose_network, embedding_projection_network, training_mode, ground_truth_transitions_visible, vs,
        predict_use_cell=FLAGS.predict_use_cell,
        use_tracking_lstm=FLAGS.use_tracking_lstm,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
        X=premise_tokens,
        transitions=premise_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        connect_tracking_comp=FLAGS.connect_tracking_comp,
        context_sensitive_shift=FLAGS.context_sensitive_shift,
        context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
        use_attention=FLAGS.use_attention,
        initialize_hyp_tracking_state=FLAGS.initialize_hyp_tracking_state)

    # The hypothesis model is conditioned on values produced by the premise
    # model: its per-step stack tops (for attention) and its final tracking
    # c state.
    premise_stack_tops = premise_model.stack_tops if FLAGS.use_attention != "None" else None
    premise_tracking_c_state_final = premise_model.tracking_c_state_final if cls not in [spinn.plain_rnn.RNN,
                                                                                            spinn.cbow.CBOW] else None
    hypothesis_model = cls(
        FLAGS.model_dim, FLAGS.word_embedding_dim, vocab_size, seq_length,
        compose_network, embedding_projection_network, training_mode, ground_truth_transitions_visible, vs,
        predict_use_cell=FLAGS.predict_use_cell,
        use_tracking_lstm=FLAGS.use_tracking_lstm,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
        X=hypothesis_tokens,
        transitions=hypothesis_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        connect_tracking_comp=FLAGS.connect_tracking_comp,
        context_sensitive_shift=FLAGS.context_sensitive_shift,
        context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
        use_attention=FLAGS.use_attention,
        premise_stack_tops=premise_stack_tops,
        is_hypothesis=True,
        initialize_hyp_tracking_state=FLAGS.initialize_hyp_tracking_state,
        premise_tracking_c_state_final=premise_tracking_c_state_final)

    # Extract top element of final stack timestep. The plain sentence vectors
    # are needed whenever attention is off or a difference/product feature is
    # requested.
    if FLAGS.use_attention == "None" or FLAGS.use_difference_feature or FLAGS.use_product_feature:
        premise_vector = premise_model.final_representations
        hypothesis_vector = hypothesis_model.final_representations

        if (FLAGS.lstm_composition and cls is not spinn.cbow.CBOW) or cls is spinn.plain_rnn.RNN:
            # Floor division keeps the dimension an integer on both Python 2
            # and Python 3; it is used as a slice bound and reshape dimension.
            sentence_vector_dim = FLAGS.model_dim // 2
            premise_vector = premise_vector[:, :sentence_vector_dim].reshape((-1, sentence_vector_dim))
            hypothesis_vector = hypothesis_vector[:, :sentence_vector_dim].reshape((-1, sentence_vector_dim))
        else:
            sentence_vector_dim = FLAGS.model_dim
            premise_vector = premise_vector.reshape((-1, sentence_vector_dim))
            hypothesis_vector = hypothesis_vector.reshape((-1, sentence_vector_dim))

    if FLAGS.use_attention != "None":
        # Use the attention weighted representation
        h_dim = FLAGS.model_dim // 2
        mlp_input = hypothesis_model.final_weighed_representation.reshape((-1, h_dim))
        mlp_input_dim = h_dim
    else:
        # Create standard MLP features
        mlp_input = T.concatenate([premise_vector, hypothesis_vector], axis=1)
        mlp_input_dim = 2 * sentence_vector_dim

    if FLAGS.use_difference_feature:
        mlp_input = T.concatenate([mlp_input, premise_vector - hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    if FLAGS.use_product_feature:
        mlp_input = T.concatenate([mlp_input, premise_vector * hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    mlp_input = util.BatchNorm(mlp_input, mlp_input_dim, vs, "sentence_vectors", training_mode)
    mlp_input = util.Dropout(mlp_input, FLAGS.semantic_classifier_keep_rate, training_mode)

    if FLAGS.classifier_type == "ResNet":
        features = util.Linear(
            mlp_input, mlp_input_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
            name="resnet/linear", use_bias=True)
        features_dim = FLAGS.sentence_pair_combination_layer_dim

        for layer in range(FLAGS.num_sentence_pair_combination_layers):
            features = util.HeKaimingResidualLayerSet(features, features_dim, vs, training_mode, name="resnet/" + str(layer),
                dropout_keep_rate=FLAGS.semantic_classifier_keep_rate, depth=FLAGS.resnet_unit_depth,
                initializer=util.HeKaimingInitializer())
            features = util.BatchNorm(features, features_dim, vs, "combining_mlp/" + str(layer), training_mode)
            features = util.Dropout(features, FLAGS.semantic_classifier_keep_rate, training_mode)
    elif FLAGS.classifier_type == "Highway":
        # NOTE(review): this branch reuses the "resnet/linear" name for its
        # input layer -- presumably a copy-paste. Left unchanged because the
        # name is passed to the variable store; confirm before renaming.
        features = util.Linear(
            mlp_input, mlp_input_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
            name="resnet/linear", use_bias=True)
        features_dim = FLAGS.sentence_pair_combination_layer_dim

        for layer in range(FLAGS.num_sentence_pair_combination_layers):
            features = util.HighwayLayer(features, features_dim, vs, training_mode, name="highway/" + str(layer),
                dropout_keep_rate=FLAGS.semantic_classifier_keep_rate,
                initializer=util.HeKaimingInitializer())
            features = util.BatchNorm(features, features_dim, vs, "combining_mlp/" + str(layer), training_mode)
            features = util.Dropout(features, FLAGS.semantic_classifier_keep_rate, training_mode)
    else:
        # Apply a combining MLP
        features = mlp_input
        features_dim = mlp_input_dim
        for layer in range(FLAGS.num_sentence_pair_combination_layers):
            features = util.ReLULayer(features, features_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
                name="combining_mlp/" + str(layer),
                initializer=util.HeKaimingInitializer())
            features_dim = FLAGS.sentence_pair_combination_layer_dim

            features = util.BatchNorm(features, features_dim, vs, "combining_mlp/" + str(layer), training_mode)
            features = util.Dropout(features, FLAGS.semantic_classifier_keep_rate, training_mode)

    # Feed forward through a single output layer
    logits = util.Linear(
        features, features_dim, num_classes, vs,
        name="semantic_classifier", use_bias=True)

    return premise_model.transitions_pred, hypothesis_model.transitions_pred, logits
コード例 #5
0
ファイル: classifier.py プロジェクト: vishalbelsare/spinn
def build_sentence_pair_model(cls,
                              vocab_size,
                              seq_length,
                              tokens,
                              transitions,
                              num_classes,
                              training_mode,
                              ground_truth_transitions_visible,
                              vs,
                              initial_embeddings=None,
                              project_embeddings=False,
                              ss_mask_gen=None,
                              ss_prob=0.0):
    """
    Construct a sentence-pair classifier on top of two `ThinStack` models.

    A single recurrence object is built from `cls` and shared by the premise
    and hypothesis `ThinStack` instances.

    Args:
      cls: Recurrence class to use (from e.g. `spinn.stack`)
      vocab_size: Stored in the model spec.
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano batch of integers indexed as `tokens[:, :, 0]` (premise)
        and `tokens[:, :, 1]` (hypothesis).
      transitions: Theano batch of integers, indexed the same way as `tokens`.
      num_classes: Number of output classes
      training_mode: A Theano scalar indicating whether to act as a training model
        with dropout (1.0) or to act as an eval model with rescaling (0.0).
      ground_truth_transitions_visible: A Theano scalar. If set (1.0), allow the model access
        to ground truth transitions. This can be disabled at evaluation time to force Model 1
        (or 2S) to evaluate in the Model 2 style with predicted transitions. Has no effect on Model 0.
      vs: Variable store.
      initial_embeddings: Forwarded to `ThinStack`.
      project_embeddings: If true, use `util.Linear` as the embedding
        projection; otherwise `util.IdentityLayer` is used and
        `FLAGS.word_embedding_dim` must equal `FLAGS.model_dim`.
      ss_mask_gen: Forwarded to `ThinStack`.
      ss_prob: Forwarded to `ThinStack`.

    Returns:
      A 4-tuple `(premise_model, hypothesis_model, logits, zero_fn)`.
      `zero_fn` takes no arguments and calls `zero()` on both models.
    """

    # Prepare layer which performs stack element composition.
    if cls is spinn.plain_rnn.RNN:
        compose_network = partial(util.LSTMLayer,
                                  initializer=util.HeKaimingInitializer())
        embedding_projection_network = None
    else:
        if FLAGS.lstm_composition:
            compose_network = partial(util.TreeLSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        else:
            assert not FLAGS.connect_tracking_comp, "Can only connect tracking and composition unit while using TreeLSTM"
            compose_network = partial(util.ReLULayer,
                                      initializer=util.HeKaimingInitializer())

        if project_embeddings:
            embedding_projection_network = util.Linear
        else:
            assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
                "word_embedding_dim must equal model_dim unless a projection layer is used."
            embedding_projection_network = util.IdentityLayer

    # Floor division keeps the dimension an integer on both Python 2 and
    # Python 3; it is used below as a slice bound and reshape dimension.
    model_visible_dim = FLAGS.model_dim // 2 if FLAGS.lstm_composition else FLAGS.model_dim
    spec = util.ModelSpec(FLAGS.model_dim,
                          FLAGS.word_embedding_dim,
                          FLAGS.batch_size,
                          vocab_size,
                          seq_length,
                          model_visible_dim=model_visible_dim)

    # Split the two sentences
    premise_tokens = tokens[:, :, 0]
    hypothesis_tokens = tokens[:, :, 1]

    premise_transitions = transitions[:, :, 0]
    hypothesis_transitions = transitions[:, :, 1]

    # TODO: Check non-Model0 support.
    recurrence = cls(
        spec,
        vs,
        compose_network,
        use_context_sensitive_shift=FLAGS.context_sensitive_shift,
        context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
        use_tracking_lstm=FLAGS.use_tracking_lstm,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim)

    # Build two hard stack models which scan over input sequences.
    premise_model = ThinStack(
        spec,
        recurrence,
        embedding_projection_network,
        training_mode,
        ground_truth_transitions_visible,
        vs,
        X=premise_tokens,
        transitions=premise_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        use_input_batch_norm=False,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        name="premise")

    hypothesis_model = ThinStack(
        spec,
        recurrence,
        embedding_projection_network,
        training_mode,
        ground_truth_transitions_visible,
        vs,
        X=hypothesis_tokens,
        transitions=hypothesis_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        use_input_batch_norm=False,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        name="hypothesis")

    # Extract the sentence vectors that feed the classifier. Without these
    # assignments the feature code below would reference undefined names.
    # NOTE(review): assumes ThinStack exposes `final_representations` with the
    # visible units first along the last axis -- confirm against ThinStack.
    sentence_vector_dim = model_visible_dim
    premise_vector = premise_model.final_representations[
        :, :sentence_vector_dim].reshape((-1, sentence_vector_dim))
    hypothesis_vector = hypothesis_model.final_representations[
        :, :sentence_vector_dim].reshape((-1, sentence_vector_dim))

    # Create standard MLP features
    mlp_input = T.concatenate([premise_vector, hypothesis_vector], axis=1)
    mlp_input_dim = 2 * sentence_vector_dim

    if FLAGS.use_difference_feature:
        mlp_input = T.concatenate(
            [mlp_input, premise_vector - hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    if FLAGS.use_product_feature:
        mlp_input = T.concatenate(
            [mlp_input, premise_vector * hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    mlp_input = util.BatchNorm(mlp_input, mlp_input_dim, vs,
                               "sentence_vectors", training_mode)
    mlp_input = util.Dropout(mlp_input, FLAGS.semantic_classifier_keep_rate,
                             training_mode)

    # Apply a combining MLP
    prev_features = mlp_input
    prev_features_dim = mlp_input_dim
    for layer in range(FLAGS.num_sentence_pair_combination_layers):
        prev_features = util.ReLULayer(
            prev_features,
            prev_features_dim,
            FLAGS.sentence_pair_combination_layer_dim,
            vs,
            name="combining_mlp/" + str(layer),
            initializer=util.HeKaimingInitializer())
        prev_features_dim = FLAGS.sentence_pair_combination_layer_dim

        prev_features = util.BatchNorm(prev_features, prev_features_dim, vs,
                                       "combining_mlp/" + str(layer),
                                       training_mode)
        prev_features = util.Dropout(prev_features,
                                     FLAGS.semantic_classifier_keep_rate,
                                     training_mode)

    # Feed forward through a single output layer
    logits = util.Linear(prev_features,
                         prev_features_dim,
                         num_classes,
                         vs,
                         name="semantic_classifier",
                         use_bias=True)

    def zero_fn():
        """Call `zero()` on both stack models."""
        premise_model.zero()
        hypothesis_model.zero()

    return premise_model, hypothesis_model, logits, zero_fn