Example No. 1
import tensorflow as tf
from tensorflow.contrib.cudnn_rnn import CudnnGRU, CudnnLSTM, CudnnRNNTanh
from tensorflow.contrib.rnn import (BasicLSTMCell, BasicRNNCell, GRUBlockCell,
                                    LayerNormBasicLSTMCell, LSTMBlockCell)


def get_cell(cell_type, size, layers=1, direction='unidirectional'):
    """Return an RNN cell of the requested type (TensorFlow 1.x contrib API).

    Note: the cudnn_* and lstm_block_fused variants are not RNNCell
    instances; they use a different, time-major call convention.
    """
    if cell_type == "layer_norm_basic":
        cell = LayerNormBasicLSTMCell(size)
    elif cell_type == "lstm_block_fused":
        cell = tf.contrib.rnn.LSTMBlockFusedCell(size)
    elif cell_type == "cudnn_lstm":
        cell = CudnnLSTM(layers, size, direction=direction)
    elif cell_type == "cudnn_gru":
        cell = CudnnGRU(layers, size, direction=direction)
    elif cell_type == "lstm_block":
        cell = LSTMBlockCell(size)
    elif cell_type == "gru_block":
        cell = GRUBlockCell(size)
    elif cell_type == "rnn":
        cell = BasicRNNCell(size)
    elif cell_type == "cudnn_rnn":
        cell = CudnnRNNTanh(layers, size)
    else:
        cell = BasicLSTMCell(size)
    return cell
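
A minimal usage sketch, assuming TensorFlow 1.x: the RNNCell-based types plug straight into tf.nn.dynamic_rnn, while the Cudnn* and fused cells are instead called directly on time-major inputs. The placeholder shape below is illustrative.

cell = get_cell("gru_block", size=128)
inputs = tf.placeholder(tf.float32, [None, None, 300])  # [batch, time, dim], illustrative
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
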
Example No. 2
    def __call__(self, is_train, scope=None):
        # Cell factory: returns a fresh GRUBlockCell; is_train and scope are
        # accepted for interface compatibility but unused here.
        return GRUBlockCell(self.num_units)
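
For context, a self-contained sketch of the kind of factory class this method belongs to; the GRUCellFactory name and the dummy input are hypothetical, not from the source.

import tensorflow as tf
from tensorflow.contrib.rnn import GRUBlockCell

class GRUCellFactory(object):  # hypothetical stand-in for the source class
    def __init__(self, num_units):
        self.num_units = num_units

    def __call__(self, is_train, scope=None):
        return GRUBlockCell(self.num_units)

cell = GRUCellFactory(256)(is_train=True)
outputs, _ = tf.nn.dynamic_rnn(cell, tf.zeros([4, 10, 32]),  # dummy [B, T, D] input
                               dtype=tf.float32)
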
Example No. 3
    def _init(self):
        ExtractionQAModel._init(self)
        if self._composition == "GRU":
            if self._layer_norm:
                rnn_constructor = lambda size: FusedRNNCellAdaptor(LayerNormGRUCell(size), use_dynamic_rnn=True)
            else:
                rnn_constructor = lambda size: FusedRNNCellAdaptor(GRUBlockCell(size), use_dynamic_rnn=True)
        elif self._composition == "RNN":
            rnn_constructor = lambda size: FusedRNNCellAdaptor(BasicRNNCell(size), use_dynamic_rnn=True)
        else:
            if self._layer_norm:
                rnn_constructor = lambda size: FusedRNNCellAdaptor(LayerNormLSTMCell(size), use_dynamic_rnn=True)
            else:
                rnn_constructor = lambda size: LSTMBlockFusedCell(size)

        with tf.device(self._device0):
            # Boolean mode flag: re-running the initializer resets it to
            # train mode (False); assigning True switches to eval mode.
            self._eval = tf.get_variable("is_eval", initializer=False, trainable=False)
            self._set_train = self._eval.initializer
            self._set_eval = self._eval.assign(True)

            self.context_mask = tfutil.mask_for_lengths(self.context_length, self._batch_size, self.embedder.max_length)

            question_binary_mask = tfutil.mask_for_lengths(self.question_length,
                                                           self.question_embedder.batch_size,
                                                           self.question_embedder.max_length,
                                                           value=1.0,
                                                           mask_right=False)

            with tf.variable_scope("preprocessing_layer"):

                question_binary_mask = tf.gather(question_binary_mask, self.context_partition)
                self._embedded_question_not_dropped = tf.gather(self._embedded_question_not_dropped, self.context_partition)

                # context
                if self._with_features:
                    mask = tf.get_variable("attention_mask", [1, 1, self._embedded_question_not_dropped.get_shape()[-1].value],
                                           initializer=tf.constant_initializer(1.0))
                    # compute word wise features
                    #masked_question = self.question_embedder.output * mask
                    # [B, Q, L]
                    q2c_scores = tf.matmul(self._embedded_question_not_dropped * mask,
                                                 self._embedded_context_not_dropped, adjoint_b=True)
                    q2c_scores = q2c_scores + tf.expand_dims(self.context_mask, 1)
                    #c2q_weights = tf.reduce_max(q2c_scores / (tf.reduce_max(q2c_scores, [2], keep_dims=True) + 1e-5), [1])

                    q2c_weights = tf.reduce_sum(tf.nn.softmax(q2c_scores) *
                                                tf.expand_dims(question_binary_mask, 2), [1])

                    # [B, L , 1]
                    self.context_features = tf.concat(axis=2, values=[tf.expand_dims(self._word_in_question, 2),
                                                          #tf.expand_dims(c2q_weights, 2),
                                                          tf.expand_dims(q2c_weights,  2)])

                    embedded_ctxt = tf.concat(axis=2, values=[self.embedded_context, self.context_features])


                    in_question_feature = tf.ones(tf.stack([self.question_embedder.batch_size,
                                                           self.question_embedder.max_length, 2]))
                    embedded_question = tf.concat(axis=2, values=[self.embedded_question, in_question_feature])
                else:
                    embedded_ctxt = self.embedded_context
                    embedded_question = self.embedded_question

                if self._with_question_type_features:
                    # Need to add another zero vector so that the total number
                    # of features is even, for LSTM performance reasons.
                    question_type_features = tf.stack([self._is_factoid,
                                                      self._is_list,
                                                      self._is_yesno,
                                                      tf.zeros(tf.shape(self._is_list),
                                                               dtype=tf.bool)],
                                                     axis=1)
                    question_type_features = tf.cast(question_type_features, tf.float32)
                    question_type_features = tf.expand_dims(question_type_features, 1)

                    embedded_question = tf.concat(axis=2, values=[embedded_question,
                                                      tf.tile(question_type_features,
                                                              tf.stack([1, tf.shape(embedded_question)[1], 1]))])

                    question_type_features = tf.gather(question_type_features, self.context_partition)
                    embedded_ctxt = tf.concat(axis=2, values=[embedded_ctxt,
                                                  tf.tile(question_type_features,
                                                          tf.stack([1, tf.shape(embedded_ctxt)[1], 1]))])

                if self._with_entity_tag_features:
                    embedded_question = tf.concat(axis=2, values=[embedded_question,
                                                      tf.cast(self._question_tags, tf.float32)])
                    embedded_ctxt = tf.concat(axis=2, values=[embedded_ctxt,
                                                  tf.cast(self._context_tags, tf.float32)])

                self.encoded_question = self._preprocessing_layer(rnn_constructor, embedded_question,
                                                                  self.question_length, projection_scope="question_proj")

                self.encoded_ctxt = self._preprocessing_layer(rnn_constructor, embedded_ctxt, self.context_length,
                                                              share_rnn=True, projection_scope="context_proj",
                                                              num_fusion_layers=self._num_intrafusion_layers)

                # single time attention over question
                attention_scores = tf.contrib.layers.fully_connected(self.encoded_question, 1,
                                                                     activation_fn=None,
                                                                     weights_initializer=None,
                                                                     biases_initializer=None,
                                                                     scope="attention")
                attention_scores = attention_scores + tf.expand_dims(
                    tfutil.mask_for_lengths(self.question_length, self.question_embedder.batch_size,
                                            self.question_embedder.max_length), 2)
                attention_weights = tf.nn.softmax(attention_scores, 1)
                self.question_attention_weights = attention_weights
                self.question_representation = tf.reduce_sum(attention_weights * self.encoded_question, [1])

                # Replicate the question features once per context paragraph
                # (gather by context_partition)
                self.encoded_question = tf.gather(self.encoded_question, self.context_partition)
                self.question_representation_per_context = tf.gather(self.question_representation, self.context_partition)
                self.question_length = tf.gather(self.question_length, self.context_partition)

            if self._with_inter_fusion:
                with tf.variable_scope("inter_fusion"):
                    with tf.variable_scope("associative") as vs:
                        mask = tf.get_variable("attention_mask", [1, 1, self.size], initializer=tf.constant_initializer(1.0))
                        mask = tf.nn.relu(mask)
                        for i in range(1):  # single fusion hop; the loop plus vs.reuse_variables() allows more
                            # [B, Q, L]
                            inter_scores = tf.matmul(self.encoded_question * mask, self.encoded_ctxt, adjoint_b=True)
                            inter_scores = inter_scores + tf.expand_dims(self.context_mask, 1)

                            inter_weights = tf.nn.softmax(inter_scores)
                            inter_weights = inter_weights * tf.expand_dims(question_binary_mask, 2)
                            # [B, L, Q] x [B, Q, S] -> [B, L, S]
                            co_states = tf.matmul(inter_weights, self.encoded_question,
                                                  adjoint_a=True)  # adj_x was the pre-1.0 batch_matmul kwarg

                            u = tf.contrib.layers.fully_connected(tf.concat(axis=2, values=[self.encoded_ctxt, co_states]), self.size,
                                                                  activation_fn=tf.sigmoid,
                                                                  biases_initializer=tf.constant_initializer(1.0),
                                                                  scope="update_gate")
                            self.encoded_ctxt = u * self.encoded_ctxt + (1.0 - u) * co_states
                            vs.reuse_variables()

                    with tf.variable_scope("recurrent") as vs:
                        self.encoded_ctxt.set_shape([None, None, self.size])  # dynamic_rnn needs a static last dim
                        self.encoded_ctxt = dynamic_rnn(GatedAggregationRNNCell(self.size),
                                                        tf.reverse_sequence(self.encoded_ctxt, self.context_length, 1),
                                                        self.context_length,
                                                        dtype=tf.float32, time_major=False, scope="backward")[0]

                        self.encoded_ctxt = dynamic_rnn(GatedAggregationRNNCell(self.size),
                                                        tf.reverse_sequence(self.encoded_ctxt, self.context_length, 1),
                                                        self.context_length,
                                                        dtype=tf.float32, time_major=False, scope="forward")[0]

            # No matching layer, so set matched_output to encoded_ctxt (for compatibility)
            self.matched_output = self.encoded_ctxt

            with tf.variable_scope("pointer_layer"):
                self.predicted_context_indices, \
                self._start_scores, self._start_pointer, self.start_probs, \
                self._end_scores, self._end_pointer, self.end_probs = \
                    self._spn_answer_layer(self.question_representation_per_context, self.encoded_ctxt)

            self.yesno_added = False
            if self._with_yesno:
                self.add_yesno(add_model_scope=False)

            self._train_variables = [p for p in tf.trainable_variables() if self.name in p.name]
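
The gated fusion in the inter_fusion block above reduces to a convex combination u * encoded_ctxt + (1 - u) * co_states, where the gate u is computed from both inputs. A minimal standalone sketch of that pattern, with illustrative names and shapes (not taken from the source module):

import tensorflow as tf

B, L, S = 4, 20, 64                         # illustrative batch/length/size
encoded_ctxt = tf.random_normal([B, L, S])  # context states
co_states = tf.random_normal([B, L, S])     # attention-weighted question states

u = tf.contrib.layers.fully_connected(
    tf.concat(axis=2, values=[encoded_ctxt, co_states]), S,
    activation_fn=tf.sigmoid,
    biases_initializer=tf.constant_initializer(1.0),  # biases the gate toward keeping ctxt
    scope="update_gate")
fused = u * encoded_ctxt + (1.0 - u) * co_states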