예제 #1
0
파일: stack.py 프로젝트: hitluobin/spinn
    def _project_embeddings(self, raw_embeddings, dropout_mask=None):
        """
        Project raw word embeddings into the model dimension.

        Applies the embedding projection network, then optionally batch
        normalization and input dropout, retaining intermediate values in
        order to support backpropagation.

        Args:
          raw_embeddings: Embedding lookups fed to the projection network.
          dropout_mask: Optional mask forwarded to `util.Dropout`.

        Returns:
          A pair `(projected, mask)`. `mask` is None unless input dropout is
          enabled, in which case it is the mask returned by `util.Dropout`.
        """
        out = self._embedding_projection_network(
            raw_embeddings, self.word_embedding_dim, self.model_dim, self._vs,
            name=self._prefix + "project")

        if self.use_input_batch_norm:
            out = util.BatchNorm(
                out, self.model_dim, self._vs, self._prefix + "buffer",
                self.training_mode, axes=[0, 1])

        if not self.use_input_dropout:
            return out, None

        # The dropout mask is handed back so it can be retained for backprop.
        out, mask = util.Dropout(
            out, self.embedding_dropout_keep_rate, self.training_mode,
            dropout_mask=dropout_mask, return_mask=True)
        return out, mask
예제 #2
0
def build_sentence_model(cls, vocab_size, seq_length, tokens, transitions,
                         num_classes, training_mode, ground_truth_transitions_visible, vs,
                         initial_embeddings=None, project_embeddings=False, ss_mask_gen=None, ss_prob=0.0):
    """
    Construct a classifier which makes use of some hard-stack model.

    Args:
      cls: Hard stack class to use (from e.g. `rembed.stack`)
      vocab_size: Vocabulary size, forwarded to `util.ModelSpec`.
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano batch (integer matrix), `batch_size * seq_length`
      transitions: Theano batch (integer matrix), `batch_size * seq_length`
      num_classes: Number of output classes
      training_mode: A Theano scalar indicating whether to act as a training model
        with dropout (1.0) or to act as an eval model with rescaling (0.0).
      ground_truth_transitions_visible: A Theano scalar. If set (1.0), allow the model access
        to ground truth transitions. This can be disabled at evaluation time to force Model 1
        (or 2S) to evaluate in the Model 2 style with predicted transitions. Has no effect on Model 0.
      vs: Variable store.
      initial_embeddings: Optional initial embeddings, forwarded to `ThinStack`.
      project_embeddings: If True, project word embeddings to `model_dim` with a
        linear layer; otherwise `word_embedding_dim` must equal `model_dim`.
        Not consulted when `cls` is `rembed.plain_rnn.RNN`.
      ss_mask_gen: Forwarded to `ThinStack` (presumably the random stream used
        for scheduled sampling — confirm).
      ss_prob: Forwarded to `ThinStack` (presumably the scheduled-sampling
        probability — confirm).

    Returns:
      `(model, logits, zero_fn)`: the `ThinStack` model, the classifier
      logits, and a callable that invokes `model.zero()`.
    """

    # Prepare layer which performs stack element composition.
    if cls is rembed.plain_rnn.RNN:
        compose_network = partial(util.LSTMLayer,
                                  initializer=util.HeKaimingInitializer())
        embedding_projection_network = None
    else:
        if FLAGS.lstm_composition:
            compose_network = partial(util.TreeLSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        else:
            assert not FLAGS.connect_tracking_comp, "Can only connect tracking and composition unit while using TreeLSTM"
            compose_network = partial(util.ReLULayer,
                                      initializer=util.HeKaimingInitializer())

        if project_embeddings:
            embedding_projection_network = util.Linear
        else:
            assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
                "word_embedding_dim must equal model_dim unless a projection layer is used."
            embedding_projection_network = util.IdentityLayer

    # Floor division keeps these dimensions ints: they are used as layer sizes
    # and slice bounds, and `/` would produce floats under Python 3.
    model_visible_dim = FLAGS.model_dim // 2 if FLAGS.lstm_composition else FLAGS.model_dim
    spec = util.ModelSpec(FLAGS.model_dim, FLAGS.word_embedding_dim,
                          FLAGS.batch_size, vocab_size, seq_length,
                          model_visible_dim=model_visible_dim)

    # TODO: Check non-Model0 support.
    recurrence = cls(spec, vs, compose_network,
                     use_context_sensitive_shift=FLAGS.context_sensitive_shift,
                     context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
                     use_tracking_lstm=FLAGS.use_tracking_lstm,
                     tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim)

    model = ThinStack(spec, recurrence, embedding_projection_network,
                      training_mode, ground_truth_transitions_visible, vs,
                      X=tokens,
                      transitions=transitions,
                      initial_embeddings=initial_embeddings,
                      embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
                      use_input_batch_norm=False,
                      ss_mask_gen=ss_mask_gen,
                      ss_prob=ss_prob)

    # Extract top element of final stack timestep. With LSTM composition only
    # the first half of each sentence embedding is fed to the classifier
    # (matching `model_visible_dim` above).
    if FLAGS.lstm_composition:
        sentence_vector = model.sentence_embeddings[:, :FLAGS.model_dim // 2]
        sentence_vector_dim = FLAGS.model_dim // 2
    else:
        sentence_vector = model.sentence_embeddings
        sentence_vector_dim = FLAGS.model_dim

    sentence_vector = util.BatchNorm(sentence_vector, sentence_vector_dim, vs, "sentence_vector", training_mode)
    sentence_vector = util.Dropout(sentence_vector, FLAGS.semantic_classifier_keep_rate, training_mode)

    # Feed forward through a single output layer
    logits = util.Linear(
        sentence_vector, sentence_vector_dim, num_classes, vs,
        name="semantic_classifier", use_bias=True)

    def zero_fn():
        model.zero()

    return model, logits, zero_fn
예제 #3
0
def build_sentence_pair_model(cls, vocab_size, seq_length, tokens, transitions,
                     num_classes, training_mode, ground_truth_transitions_visible, vs,
                     initial_embeddings=None, project_embeddings=False, ss_mask_gen=None, ss_prob=0.0):
    """
    Construct a sentence-pair classifier which makes use of some hard-stack model.

    Args:
      cls: Hard stack class to use (from e.g. `rembed.stack`)
      vocab_size: Vocabulary size, forwarded to `util.ModelSpec`.
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano batch (integer tensor), `batch_size * seq_length * k`;
        `tokens[:, :, 0]` is the premise and `tokens[:, :, 1]` the hypothesis.
      transitions: Theano batch (integer tensor) laid out like `tokens`.
      num_classes: Number of output classes
      training_mode: A Theano scalar indicating whether to act as a training model
        with dropout (1.0) or to act as an eval model with rescaling (0.0).
      ground_truth_transitions_visible: A Theano scalar. If set (1.0), allow the model access
        to ground truth transitions. This can be disabled at evaluation time to force Model 1
        (or 2S) to evaluate in the Model 2 style with predicted transitions. Has no effect on Model 0.
      vs: Variable store.
      initial_embeddings: Optional initial embeddings, forwarded to both `ThinStack`s.
      project_embeddings: If True, project word embeddings to `model_dim` with a
        linear layer; otherwise `word_embedding_dim` must equal `model_dim`.
        Not consulted when `cls` is `rembed.plain_rnn.RNN`.
      ss_mask_gen: Forwarded to both `ThinStack`s (presumably the random stream
        used for scheduled sampling — confirm).
      ss_prob: Forwarded to both `ThinStack`s (presumably the scheduled-sampling
        probability — confirm).

    Returns:
      `(premise_model, hypothesis_model, logits, zero_fn)`; `zero_fn` calls
      `zero()` on both models.
    """

    # Prepare layer which performs stack element composition.
    if cls is rembed.plain_rnn.RNN:
        compose_network = partial(util.LSTMLayer,
                                  initializer=util.HeKaimingInitializer())
        embedding_projection_network = None
    else:
        if FLAGS.lstm_composition:
            compose_network = partial(util.TreeLSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        else:
            assert not FLAGS.connect_tracking_comp, "Can only connect tracking and composition unit while using TreeLSTM"
            compose_network = partial(util.ReLULayer,
                                      initializer=util.HeKaimingInitializer())

        if project_embeddings:
            embedding_projection_network = util.Linear
        else:
            assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
                "word_embedding_dim must equal model_dim unless a projection layer is used."
            embedding_projection_network = util.IdentityLayer

    # Floor division keeps these dimensions ints: they are used as layer sizes,
    # reshape targets and slice bounds, and `/` would give floats on Python 3.
    model_visible_dim = FLAGS.model_dim // 2 if FLAGS.lstm_composition else FLAGS.model_dim
    spec = util.ModelSpec(FLAGS.model_dim, FLAGS.word_embedding_dim,
                          FLAGS.batch_size, vocab_size, seq_length,
                          model_visible_dim=model_visible_dim)

    # Split the two sentences
    premise_tokens = tokens[:, :, 0]
    hypothesis_tokens = tokens[:, :, 1]

    premise_transitions = transitions[:, :, 0]
    hypothesis_transitions = transitions[:, :, 1]

    # TODO: Check non-Model0 support.
    recurrence = cls(spec, vs, compose_network,
                     use_context_sensitive_shift=FLAGS.context_sensitive_shift,
                     context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
                     use_tracking_lstm=FLAGS.use_tracking_lstm,
                     tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim)

    # Build two hard stack models which scan over input sequences.
    premise_model = ThinStack(spec, recurrence, embedding_projection_network,
        training_mode, ground_truth_transitions_visible, vs,
        X=premise_tokens,
        transitions=premise_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        use_input_batch_norm=False,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        use_attention=FLAGS.use_attention,
        name="premise")

    # NOTE(review): `premise_stack_tops` is computed but never passed to the
    # hypothesis model below. The attention-capable `_make_scan` reads
    # `self.premise_stack_tops` on the hypothesis side — confirm whether
    # ThinStack needs this wired through.
    premise_stack_tops = premise_model.stack_tops if FLAGS.use_attention != "None" else None

    hypothesis_model = ThinStack(spec, recurrence, embedding_projection_network,
        training_mode, ground_truth_transitions_visible, vs,
        X=hypothesis_tokens,
        transitions=hypothesis_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        use_input_batch_norm=False,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        use_attention=FLAGS.use_attention,
        name="hypothesis")

    # Extract top element of final stack timestep. These per-sentence vectors
    # are only needed without attention or for the difference/product features.
    if FLAGS.use_attention == "None" or FLAGS.use_difference_feature or FLAGS.use_product_feature:
        premise_vector = premise_model.sentence_embeddings
        hypothesis_vector = hypothesis_model.sentence_embeddings

        if FLAGS.lstm_composition:
            sentence_vector_dim = FLAGS.model_dim // 2
            premise_vector = premise_vector[:, :sentence_vector_dim]
            hypothesis_vector = hypothesis_vector[:, :sentence_vector_dim]
        else:
            sentence_vector_dim = FLAGS.model_dim

    if FLAGS.use_attention != "None":
        # Use the attention weighted representation
        h_dim = FLAGS.model_dim // 2
        mlp_input = hypothesis_model.final_weighed_representation.reshape((-1, h_dim))
        mlp_input_dim = h_dim
    else:
        # Create standard MLP features
        mlp_input = T.concatenate([premise_vector, hypothesis_vector], axis=1)
        mlp_input_dim = 2 * sentence_vector_dim

    if FLAGS.use_difference_feature:
        mlp_input = T.concatenate([mlp_input, premise_vector - hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    if FLAGS.use_product_feature:
        mlp_input = T.concatenate([mlp_input, premise_vector * hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    mlp_input = util.BatchNorm(mlp_input, mlp_input_dim, vs, "sentence_vectors", training_mode)
    mlp_input = util.Dropout(mlp_input, FLAGS.semantic_classifier_keep_rate, training_mode)

    # Apply a combining MLP
    prev_features = mlp_input
    prev_features_dim = mlp_input_dim
    for layer in range(FLAGS.num_sentence_pair_combination_layers):
        layer_name = "combining_mlp/" + str(layer)
        prev_features = util.ReLULayer(prev_features, prev_features_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
            name=layer_name,
            initializer=util.HeKaimingInitializer())
        prev_features_dim = FLAGS.sentence_pair_combination_layer_dim

        prev_features = util.BatchNorm(prev_features, prev_features_dim, vs, layer_name, training_mode)
        prev_features = util.Dropout(prev_features, FLAGS.semantic_classifier_keep_rate, training_mode)

    # Feed forward through a single output layer
    logits = util.Linear(
        prev_features, prev_features_dim, num_classes, vs,
        name="semantic_classifier", use_bias=True)

    def zero_fn():
        premise_model.zero()
        hypothesis_model.zero()

    return premise_model, hypothesis_model, logits, zero_fn
예제 #4
0
    def _make_scan(self):
        """Build the sequential composition / scan graph.

        Runs `self._step` under `theano.scan` for `self.seq_length` steps and
        stores the results on the instance: `final_stack`,
        `final_representations`, `embeddings`, `transitions_pred`,
        `tracking_c_state_final`, and, when attention is enabled,
        `stack_tops` (premise side) or `final_weighed_representation`
        (hypothesis side, for the supported attention types).
        """

        batch_size, max_stack_size = self.X.shape

        # Stack batch is a 3D tensor.
        stack_shape = (batch_size, max_stack_size, self.stack_dim)
        stack_init = T.zeros(stack_shape)

        # Allocate two helper stack copies (passed as non_seqs into scan).
        stack_pushed = T.zeros(stack_shape)
        stack_merged = T.zeros(stack_shape)

        # Look up all of the embeddings that will be used.
        raw_embeddings = self.word_embeddings[
            self.X]  # batch_size * seq_length * emb_dim

        if self.context_sensitive_shift:
            # Use the raw embedding vectors, they will be combined with the hidden state of
            # the tracking unit later
            buffer_t = raw_embeddings
            buffer_emb_dim = self.word_embedding_dim
        else:
            # Allocate a "buffer" stack initialized with projected embeddings,
            # and maintain a cursor in this buffer.
            buffer_t = self._embedding_projection_network(
                raw_embeddings,
                self.word_embedding_dim,
                self.model_dim,
                self._vs,
                name="project")
            if self.use_input_batch_norm:
                buffer_t = util.BatchNorm(buffer_t,
                                          self.model_dim,
                                          self._vs,
                                          "buffer",
                                          self.training_mode,
                                          axes=[0, 1])
            if self.use_input_dropout:
                buffer_t = util.Dropout(buffer_t,
                                        self.embedding_dropout_keep_rate,
                                        self.training_mode)
            buffer_emb_dim = self.model_dim

        # Collapse buffer to (batch_size * buffer_size) * emb_dim for fast indexing.
        buffer_t = buffer_t.reshape((-1, buffer_emb_dim))

        buffer_cur_init = T.zeros((batch_size, ), dtype="int")

        DUMMY = T.zeros((2, ))  # a dummy tensor used as a place-holder

        # Dimshuffle inputs to seq_len * batch_size for scanning
        transitions = self.transitions.dimshuffle(1, 0)

        # Initialize the hidden state for the tracking LSTM, if needed. The
        # state is the concatenation [h; c], each tracking_lstm_hidden_dim wide.
        if self.use_tracking_lstm:
            if self.initialize_hyp_tracking_state and self.is_hypothesis:
                # Initialize the c state of tracking unit from the c state of premise model.
                h_state_init = T.zeros(
                    (batch_size, self.tracking_lstm_hidden_dim))
                hidden_init = T.concatenate(
                    [h_state_init, self.premise_tracking_c_state_final],
                    axis=1)
            else:
                hidden_init = T.zeros(
                    (batch_size, self.tracking_lstm_hidden_dim * 2))
        else:
            hidden_init = DUMMY

        # Initialize the attention representation if needed
        if self.use_attention not in {"TreeWangJiang", "TreeThang", "None"
                                      } and self.is_hypothesis:
            # Floor division: h_dim is used as a shape / slice bound and must
            # stay an int (`/` yields a float under Python 3).
            h_dim = self.model_dim // 2
            if self.use_attention == "WangJiang" or self.use_attention == "Thang":
                attention_init = T.zeros((batch_size, 2 * h_dim))
            else:
                attention_init = T.zeros((batch_size, h_dim))
        else:
            # If we're not using a sequential attention accumulator (i.e., no attention or
            # tree attention), pass the small DUMMY placeholder instead.
            attention_init = DUMMY

        # Set up the output list for scanning over _step().
        if self._predict_transitions:
            outputs_info = [
                stack_init, buffer_cur_init, hidden_init, attention_init, None
            ]
        else:
            outputs_info = [
                stack_init, buffer_cur_init, hidden_init, attention_init
            ]

        # Prepare data to scan over.
        sequences = [transitions]
        if self.interpolate:
            # Generate Bernoulli RVs to simulate scheduled sampling
            # if the interpolate flag is on.
            ss_mask_gen_matrix = self.ss_mask_gen.binomial(transitions.shape,
                                                           p=self.ss_prob)
            # Take in the RV sequence as input.
            sequences.append(ss_mask_gen_matrix)
        else:
            # Take in the RV sequence as a dummy output. This is
            # done to avoid defining another step function.
            outputs_info = [DUMMY] + outputs_info

        non_sequences = [
            stack_pushed, stack_merged, buffer_t,
            self.ground_truth_transitions_visible
        ]

        if self.use_attention != "None" and self.is_hypothesis:
            h_dim = self.model_dim // 2
            projected_stack_tops = util.AttentionUnitInit(
                self.premise_stack_tops, h_dim, self._vs)
            non_sequences = non_sequences + [
                self.premise_stack_tops, projected_stack_tops
            ]
        else:
            DUMMY2 = T.zeros((2, ))  # another dummy tensor
            non_sequences = non_sequences + [DUMMY, DUMMY2]

        scan_ret = theano.scan(self._step,
                               sequences=sequences,
                               non_sequences=non_sequences,
                               outputs_info=outputs_info,
                               n_steps=self.seq_length,
                               name="stack_fwd")

        # Without interpolation a DUMMY output was prepended to outputs_info,
        # shifting the stack / cursor / hidden / attention outputs by one slot.
        stack_ind = 0 if self.interpolate else 1
        self.final_stack = scan_ret[0][stack_ind][-1]
        self.final_representations = self.final_stack[:, 0, :self.model_dim]
        self.embeddings = self.final_stack[:, 0]

        if self._predict_transitions:
            self.transitions_pred = scan_ret[0][-1].dimshuffle(1, 0, 2)
        else:
            self.transitions_pred = T.zeros((batch_size, 0))

        if self.use_attention != "None" and not self.is_hypothesis:
            # Store the stack top at each step as an attribute.
            h_dim = self.model_dim // 2
            self.stack_tops = scan_ret[0][stack_ind][:, :, 0, :h_dim].reshape(
                (max_stack_size, batch_size, h_dim))

        if self.use_attention != "None" and self.is_hypothesis:
            h_dim = self.model_dim // 2
            if self.use_attention == "Rocktaschel":
                # Output slot stack_ind + 3 carries the attention accumulator.
                self.final_weighed_representation = util.AttentionUnitFinalRepresentation(
                    scan_ret[0][stack_ind + 3][-1], self.embeddings[:, :h_dim],
                    h_dim, self._vs)
            elif self.use_attention in {"WangJiang", "Thang"}:
                self.final_weighed_representation = scan_ret[0][
                    stack_ind + 3][-1][:, :h_dim]
            elif self.use_attention in {"TreeWangJiang", "TreeThang"}:
                self.final_weighed_representation = scan_ret[0][stack_ind][
                    -1][:, 0, 2 * h_dim:3 * h_dim]

        if self.initialize_hyp_tracking_state and not self.is_hypothesis:
            # Store the final c states of the tracking unit (second half of [h; c]).
            self.tracking_c_state_final = scan_ret[0][
                stack_ind + 2][-1][:, self.tracking_lstm_hidden_dim:]
        else:
            self.tracking_c_state_final = None
예제 #5
0
    def _make_scan(self):
        """Build the sequential composition / scan graph.

        Sets `self.final_stack` (the stack after the last scan step) and
        `self.transitions_pred` (the per-step action predictions dimshuffled
        with (1, 0, 2), or None when there is no prediction network).

        Raises:
          ValueError: if `self.use_predictions` is set but no prediction
            network is configured. Previously this surfaced as an
            UnboundLocalError (`actions_t`) while scan traced the step.
        """
        has_predict_network = self._predict_network is not None
        if self.use_predictions and not has_predict_network:
            raise ValueError(
                "use_predictions requires a prediction network, but "
                "_predict_network is None")

        batch_size, max_stack_size = self.X.shape

        # Stack batch is a 3D tensor.
        stack_shape = (batch_size, max_stack_size, self.model_dim)
        stack_init = T.zeros(stack_shape)

        # Allocate two helper stack copies (passed as non_seqs into scan).
        stack_pushed = T.zeros(stack_shape)
        stack_merged = T.zeros(stack_shape)

        # Look up all of the embeddings that will be used.
        raw_embeddings = self.embeddings[
            self.X]  # batch_size * seq_length * emb_dim

        # Allocate a "buffer" stack initialized with projected embeddings,
        # and maintain a cursor in this buffer.
        buffer_t = self._embedding_projection_network(raw_embeddings,
                                                      self.word_embedding_dim,
                                                      self.model_dim,
                                                      self._vs,
                                                      name="project")
        buffer_t = util.Dropout(buffer_t, self.embedding_dropout_keep_rate,
                                self.apply_dropout)

        # Collapse buffer to (batch_size * buffer_size) * emb_dim for fast indexing.
        buffer_t = buffer_t.reshape((-1, self.model_dim))

        buffer_cur_init = T.zeros((batch_size, ), dtype="int")

        # TODO(jgauthier): Implement linear memory (was in previous HardStack;
        # dropped it during a refactor)

        def _step_core(transitions_t, ss_mask_t, stack_t, buffer_cur_t,
                       stack_pushed, stack_merged, buffer):
            """One stack update; `ss_mask_t` is None when SS is off."""
            # Extract top buffer values.
            idxs = buffer_cur_t + (T.arange(batch_size) * self.seq_length)
            buffer_top_t = buffer[idxs]

            if has_predict_network:
                # We are predicting our own stack operations.
                predict_inp = T.concatenate(
                    [stack_t[:, 0], stack_t[:, 1], buffer_top_t], axis=1)
                actions_t = self._predict_network(predict_inp,
                                                  self.model_dim * 3,
                                                  2,
                                                  self._vs,
                                                  name="predict_actions")

            if self.use_predictions:
                predicted_mask = actions_t.argmax(axis=1)
                if ss_mask_t is not None:
                    # Interpolate between truth and prediction, using
                    # Bernoulli RVs generated prior to the step.
                    mask = (transitions_t * ss_mask_t +
                            predicted_mask * (1 - ss_mask_t))
                else:
                    # Use predicted actions to build a mask.
                    mask = predicted_mask
            else:
                # Use transitions provided from external parser.
                mask = transitions_t

            # Now update the stack: first precompute merge results.
            merge_items = stack_t[:, :2].reshape((-1, self.model_dim * 2))
            merge_value = self._compose_network(merge_items,
                                                self.model_dim * 2,
                                                self.model_dim,
                                                self._vs,
                                                name="compose")

            # Compute new stack value.
            stack_next = update_hard_stack(stack_t, stack_pushed, stack_merged,
                                           buffer_top_t, merge_value, mask)

            # Move buffer cursor as necessary. Since mask == 1 when merge, we
            # should increment each buffer cursor by 1 - mask
            buffer_cur_next = buffer_cur_t + (1 - mask)

            # Return order must match outputs_info below.
            if has_predict_network:
                return stack_next, actions_t, buffer_cur_next
            return stack_next, buffer_cur_next

        # Two scan entry points share `_step_core`: the scheduled-sampling one
        # takes the mask as an extra sequence, the plain one passes None, so no
        # matrix of ones has to be allocated when SS is turned off.
        def step_ss(transitions_t, ss_mask_gen_matrix_t, stack_t, buffer_cur_t,
                    stack_pushed, stack_merged, buffer):
            return _step_core(transitions_t, ss_mask_gen_matrix_t, stack_t,
                              buffer_cur_t, stack_pushed, stack_merged, buffer)

        def step(transitions_t, stack_t, buffer_cur_t, stack_pushed,
                 stack_merged, buffer):
            return _step_core(transitions_t, None, stack_t, buffer_cur_t,
                              stack_pushed, stack_merged, buffer)

        # Dimshuffle inputs to seq_len * batch_size for scanning
        transitions = self.transitions.dimshuffle(1, 0)

        # Generate Bernoulli RVs to simulate scheduled sampling, if the interpolate flag is on
        if self.interpolate:
            ss_mask_gen_matrix = self.ss_mask_gen.binomial(transitions.shape,
                                                           p=self.ss_prob)
            step_fn = step_ss
            sequences = [transitions, ss_mask_gen_matrix]
        else:
            step_fn = step
            sequences = [transitions]

        # If we have a prediction network, we need an extra outputs_info
        # element (the `None`) to carry along prediction values
        if has_predict_network:
            outputs_info = [stack_init, None, buffer_cur_init]
        else:
            outputs_info = [stack_init, buffer_cur_init]

        scan_ret = theano.scan(
            step_fn,
            sequences=sequences,
            non_sequences=[stack_pushed, stack_merged, buffer_t],
            outputs_info=outputs_info)[0]

        self.final_stack = scan_ret[0][-1]

        self.transitions_pred = None
        if has_predict_network:
            self.transitions_pred = scan_ret[1].dimshuffle(1, 0, 2)
예제 #6
0
def build_sentence_model(cls,
                         vocab_size,
                         seq_length,
                         tokens,
                         transitions,
                         num_classes,
                         apply_dropout,
                         vs,
                         initial_embeddings=None,
                         project_embeddings=False,
                         ss_mask_gen=None,
                         ss_prob=0.0):
    """
    Build a hard-stack sentence encoder topped by a linear classifier.

    Args:
      cls: Hard stack class to instantiate (from e.g. `rembed.stack`).
      vocab_size: Vocabulary size, forwarded to `cls`.
      seq_length: Length of each sequence provided to the stack model.
      tokens: Theano batch (integer matrix), `batch_size * seq_length`.
      transitions: Theano batch (integer matrix), `batch_size * seq_length`.
      num_classes: Number of output classes.
      apply_dropout: 1.0 at training time, 0.0 at eval time (to avoid corrupting outputs in dropout).
      vs: Variable store.
      initial_embeddings: Optional initial embeddings, forwarded to `cls`.
      project_embeddings: If True, project embeddings with a linear layer;
        otherwise `word_embedding_dim` must equal `model_dim`.
      ss_mask_gen: Forwarded to `cls` (scheduled sampling).
      ss_prob: Forwarded to `cls` (scheduled sampling).

    Returns:
      `(transitions_pred, logits)` — the stack's transition predictions and
      the classifier logits.
    """
    # Composition layer: TreeLSTM or ReLU, both He-initialized.
    layer_cls = util.TreeLSTMLayer if FLAGS.lstm_composition else util.ReLULayer
    compose_network = partial(layer_cls, initializer=util.HeKaimingInitializer())

    # Without a projection layer the embeddings feed the stack as-is, so the
    # two widths have to agree.
    if not project_embeddings:
        assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
            "word_embedding_dim must equal model_dim unless a projection layer is used."
    embedding_projection_network = (util.Linear if project_embeddings
                                    else util.IdentityLayer)

    stack = cls(FLAGS.model_dim, FLAGS.word_embedding_dim, vocab_size,
                seq_length, compose_network, embedding_projection_network,
                apply_dropout, vs,
                X=tokens, transitions=transitions,
                initial_embeddings=initial_embeddings,
                embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
                ss_mask_gen=ss_mask_gen, ss_prob=ss_prob)

    # The sentence representation is the top element of the final stack.
    sentence_vector = util.Dropout(
        stack.final_stack[:, 0].reshape((-1, FLAGS.model_dim)),
        FLAGS.semantic_classifier_keep_rate,
        apply_dropout)

    logits = util.Linear(sentence_vector, FLAGS.model_dim, num_classes, vs,
                         use_bias=True)
    return stack.transitions_pred, logits
예제 #7
0
def build_sentence_pair_model(cls,
                              vocab_size,
                              seq_length,
                              tokens,
                              transitions,
                              num_classes,
                              apply_dropout,
                              vs,
                              initial_embeddings=None,
                              project_embeddings=False,
                              ss_mask_gen=None,
                              ss_prob=0.0):
    """
    Construct a sentence-pair classifier which makes use of two hard-stack models.

    The premise and the hypothesis are each encoded by their own instance of
    `cls`, both built against the same variable store `vs`. Element 0 of each
    final stack is taken as the sentence vector. The two vectors are
    concatenated, passed through dropout and a combining MLP, and then fed to
    a linear output layer.

    Args:
      cls: Hard stack class to use (from e.g. `rembed.stack`)
      vocab_size: Vocabulary size, forwarded to `cls`.
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano integer tensor with a trailing "sentence" axis,
        `batch_size * seq_length * 2`. `tokens[:, :, 0]` holds the premise
        and `tokens[:, :, 1]` holds the hypothesis.
      transitions: Theano integer tensor laid out like `tokens`
        (`[:, :, 0]` premise, `[:, :, 1]` hypothesis).
      num_classes: Number of output classes
      apply_dropout: 1.0 at training time, 0.0 at eval time (to avoid corrupting outputs in dropout)
      vs: Variable store.
      initial_embeddings: Optional initial embedding matrix, forwarded to `cls`.
      project_embeddings: If True, word embeddings are mapped to
        `FLAGS.model_dim` with `util.Linear`. Otherwise they are passed
        through `util.IdentityLayer`, which requires
        `FLAGS.word_embedding_dim == FLAGS.model_dim`.
      ss_mask_gen: Random generator for scheduled-sampling masks, forwarded
        to `cls`.
      ss_prob: Scheduled-sampling probability, forwarded to `cls`.

    Returns:
      A tuple `(premise_transitions_pred, hypothesis_transitions_pred, logits)`:
      the `transitions_pred` attribute of each stack model, followed by the
      output of the final linear layer with `num_classes` units.

    Raises:
      ValueError: if `project_embeddings` is False and
        `FLAGS.word_embedding_dim != FLAGS.model_dim`.
    """

    # Prepare layer which performs stack element composition.
    composition_layer = util.TreeLSTMLayer if FLAGS.lstm_composition else util.ReLULayer
    compose_network = partial(composition_layer,
                              initializer=util.HeKaimingInitializer())

    if project_embeddings:
        embedding_projection_network = util.Linear
    else:
        # Explicit check rather than `assert`: assertions are stripped under
        # `python -O`, which would defer the failure to an obscure shape error.
        if FLAGS.word_embedding_dim != FLAGS.model_dim:
            raise ValueError(
                "word_embedding_dim must equal model_dim unless a projection layer is used.")
        embedding_projection_network = util.IdentityLayer

    def build_stack(sentence_tokens, sentence_transitions):
        # Both sentences share every constructor argument except their inputs.
        return cls(FLAGS.model_dim,
                   FLAGS.word_embedding_dim,
                   vocab_size,
                   seq_length,
                   compose_network,
                   embedding_projection_network,
                   apply_dropout,
                   vs,
                   X=sentence_tokens,
                   transitions=sentence_transitions,
                   initial_embeddings=initial_embeddings,
                   embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
                   ss_mask_gen=ss_mask_gen,
                   ss_prob=ss_prob)

    # Split the two sentences along the trailing axis and build one hard
    # stack per sentence. The premise is built first, matching the original
    # variable-creation order in `vs`.
    premise_model = build_stack(tokens[:, :, 0], transitions[:, :, 0])
    hypothesis_model = build_stack(tokens[:, :, 1], transitions[:, :, 1])

    # Extract top element of final stack timestep.
    premise_vector = premise_model.final_stack[:, 0].reshape((-1, FLAGS.model_dim))
    hypothesis_vector = hypothesis_model.final_stack[:, 0].reshape((-1, FLAGS.model_dim))

    # Concatenate and apply dropout
    mlp_input = T.concatenate([premise_vector, hypothesis_vector], axis=1)
    dropout_mlp_input = util.Dropout(mlp_input,
                                     FLAGS.semantic_classifier_keep_rate,
                                     apply_dropout)

    # Apply a combining MLP
    pair_features = util.MLP(dropout_mlp_input,
                             2 * FLAGS.model_dim,
                             FLAGS.model_dim,
                             vs,
                             hidden_dims=[FLAGS.model_dim],
                             name="combining_mlp",
                             initializer=util.HeKaimingInitializer())

    # Feed forward through a single output layer
    logits = util.Linear(pair_features,
                         FLAGS.model_dim,
                         num_classes,
                         vs,
                         use_bias=True)

    return premise_model.transitions_pred, hypothesis_model.transitions_pred, logits
예제 #8
0
파일: stack.py 프로젝트: mihail911/rembed
    def _make_scan(self):
        """Build the sequential composition / scan graph.

        Runs `self._step` under `theano.scan`, one step per timestep of
        `self.transitions`, and stores the results on the instance:

          - `self.final_stack`: the stack output at the last scan step.
          - `self.embeddings`: `self.final_stack[:, 0]`, i.e. element 0 of
            each example's final stack. This REBINDS the attribute that is
            used below as the word-embedding lookup table.
          - `self.transitions_pred`: None, unless `self._predict_transitions`
            is set. In that case it is the last scan output with its first
            two axes swapped (`dimshuffle(1, 0, 2)`).
        """

        # X is unpacked as a 2D batch of indices into self.embeddings; the
        # stack gets one slot per input position.
        batch_size, max_stack_size = self.X.shape

        # Stack batch is a 3D tensor.
        stack_shape = (batch_size, max_stack_size, self.model_dim)
        stack_init = T.zeros(stack_shape)

        # Allocate two helper stack copies (passed as non_seqs into scan).
        stack_pushed = T.zeros(stack_shape)
        stack_merged = T.zeros(stack_shape)

        # Look up all of the embeddings that will be used.
        raw_embeddings = self.embeddings[
            self.X]  # batch_size * seq_length * emb_dim

        if self.context_sensitive_shift:
            # Use the raw embedding vectors, they will be combined with the hidden state of
            # the tracking unit later
            buffer_t = raw_embeddings
            buffer_emb_dim = self.word_embedding_dim
        else:
            # Allocate a "buffer" stack initialized with projected embeddings,
            # and maintain a cursor in this buffer.
            # Pipeline: projection -> optional batch norm -> optional dropout.
            buffer_t = self._embedding_projection_network(
                raw_embeddings,
                self.word_embedding_dim,
                self.model_dim,
                self._vs,
                name="project")
            if self.use_input_batch_norm:
                buffer_t = util.BatchNorm(buffer_t,
                                          self.model_dim,
                                          self._vs,
                                          "buffer",
                                          self.training_mode,
                                          axes=[0, 1])
            if self.use_input_dropout:
                buffer_t = util.Dropout(buffer_t,
                                        self.embedding_dropout_keep_rate,
                                        self.training_mode)
            buffer_emb_dim = self.model_dim

        # Collapse buffer to (batch_size * buffer_size) * emb_dim for fast indexing.
        buffer_t = buffer_t.reshape((-1, buffer_emb_dim))

        # One buffer cursor per batch example, starting at position 0.
        buffer_cur_init = T.zeros((batch_size, ), dtype="int")

        DUMMY = T.zeros((2, ))  # a dummy tensor used as a place-holder

        # Dimshuffle inputs to seq_len * batch_size for scanning
        # (theano.scan iterates over the first axis of each sequence).
        transitions = self.transitions.dimshuffle(1, 0)

        # Initialize the hidden state for the tracking LSTM, if needed.
        if self.use_tracking_lstm:
            # TODO: Unify what 'dim' means with LSTM. Here, it's the dim of
            # each of h and c. For 'model_dim', it's the combined dimension
            # of the full hidden state (so h and c are each model_dim/2).
            hidden_init = T.zeros(
                (batch_size, self.tracking_lstm_hidden_dim * 2))
        else:
            # Placeholder so the recurrent-state slot layout is the same
            # whether or not the tracking LSTM is used.
            hidden_init = DUMMY

        # Set up the output list for scanning over _step().
        # A None entry in outputs_info marks an output that scan collects at
        # every step but does not feed back into _step (here: the transition
        # predictions).
        if self._predict_transitions:
            outputs_info = [stack_init, buffer_cur_init, hidden_init, None]
        else:
            outputs_info = [stack_init, buffer_cur_init, hidden_init]

        # Prepare data to scan over.
        sequences = [transitions]
        if self.interpolate:
            # Generate Bernoulli RVs to simulate scheduled sampling
            # if the interpolate flag is on.
            ss_mask_gen_matrix = self.ss_mask_gen.binomial(transitions.shape,
                                                           p=self.ss_prob)
            # Take in the RV sequence as input.
            sequences.append(ss_mask_gen_matrix)
        else:
            # Take in the RV sequence as a dummy output instead. This is
            # done to avoid defining another step function: in both branches
            # the second positional argument of _step is filled (either the
            # mask element or the previous DUMMY output). Prepending DUMMY
            # also shifts every scan output index by one (see stack_ind).
            outputs_info = [DUMMY] + outputs_info

        # theano.scan returns (outputs, updates); only the outputs are kept.
        scan_ret = theano.scan(self._step,
                               sequences=sequences,
                               non_sequences=[
                                   stack_pushed, stack_merged, buffer_t,
                                   self.ground_truth_transitions_visible
                               ],
                               outputs_info=outputs_info)[0]

        # Scan outputs follow the order of outputs_info, so the stack output
        # sits at index 1 when DUMMY was prepended above.
        stack_ind = 0 if self.interpolate else 1
        self.final_stack = scan_ret[stack_ind][-1]
        # NOTE(review): this rebinds self.embeddings, which was read above as
        # the word-embedding lookup table. Any later code that reads
        # self.embeddings gets the final stack-top vectors instead.
        # Confirm this is intended.
        self.embeddings = self.final_stack[:, 0]

        self.transitions_pred = None
        if self._predict_transitions:
            # scan stacks per-step outputs along axis 0 (time); swap the first
            # two axes so batch comes first.
            self.transitions_pred = scan_ret[-1].dimshuffle(1, 0, 2)