Example #1
    def _vector_logit(self, projected_decoder_state: tf.Tensor,
                      vector_value: tf.Tensor,
                      scope: str) -> Tuple[tf.Tensor, tf.Tensor]:
        """Get the projected value and logit for a single vector (e.g., the
        sentinel vector)."""
        assert_shape(projected_decoder_state, [-1, 1, -1])
        assert_shape(vector_value, [-1, -1])

        with tf.variable_scope("{}_logit".format(scope)):
            vector_bias = tf.get_variable(
                "vector_bias", [], initializer=tf.constant_initializer(0.0))

            proj_vector_for_logit = tf.expand_dims(
                linear(vector_value,
                       self.attention_state_size,
                       scope="vector_projection"), 1)

            if self._share_projections:
                proj_vector_for_ctx = proj_vector_for_logit
            else:
                proj_vector_for_ctx = tf.expand_dims(
                    linear(vector_value,
                           self.attention_state_size,
                           scope="vector_ctx_proj"), 1)

            vector_logit = tf.reduce_sum(
                self.attn_v *
                tf.tanh(projected_decoder_state + proj_vector_for_logit),
                [2]) + vector_bias
            assert_shape(vector_logit, [-1, 1])
            return proj_vector_for_ctx, vector_logit
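Every example on this page calls the same linear helper: it takes a tensor or a list of tensors, concatenates them, and applies one learned affine projection inside a named variable scope. The helper below is only a minimal sketch under that assumption; the project's actual implementation may differ in details such as bias handling.

import tensorflow as tf

def linear(args, output_size, scope="linear"):
    """Minimal sketch: concatenate the inputs and project them affinely."""
    if not isinstance(args, (list, tuple)):
        args = [args]
    concatenated = tf.concat(args, 1) if len(args) > 1 else args[0]
    input_size = concatenated.get_shape()[-1].value
    with tf.variable_scope(scope):
        weights = tf.get_variable("weights", [input_size, output_size])
        bias = tf.get_variable(
            "bias", [output_size], initializer=tf.constant_initializer(0.0))
        return tf.matmul(concatenated, weights) + bias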
Example #2
    def states(self) -> tf.Tensor:
        """Residual convolutions over the embedded inputs, with a projected
        skip connection from the inputs added to the final states."""
        convolutions = linear(self.ordered_embedded_inputs,
                              self.conv_features,
                              scope="order_and_embed")
        for layer in range(self.encoder_layers):
            convolutions = self._residual_conv(convolutions,
                                               "encoder_conv_{}".format(layer))

        return convolutions + linear(self.ordered_embedded_inputs,
                                     self.conv_features,
                                     scope="input_to_final_state")
Example #3
    def __call__(self, inputs, state, scope=None):
        """Gated recurrent unit (GRU) with num_units cells."""
        with tf.variable_scope(scope or type(self).__name__):  # "GRUCell"
            with tf.variable_scope("Gates"):  # Reset gate and update gate.
                # We start with bias of 1.0 to not reset and not update.
                r, u = tf.split(linear([inputs, state], 2 * self._num_units),
                                2, 1)
                r, u = noisy_sigmoid(r, self.training), noisy_sigmoid(
                    u, self.training)
            with tf.variable_scope("Candidate"):
                c = noisy_tanh(linear([inputs, r * state], self._num_units),
                               self.training)
                new_h = u * state + (1 - u) * c
        return new_h, new_h
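A cell exposing this __call__ plugs into the standard TensorFlow RNN API. The usage below is only a sketch: the class name NoisyGRUCell, its constructor arguments, and the tensor shapes are assumptions for illustration, not taken from the example.

import tensorflow as tf

# NoisyGRUCell and its constructor signature are hypothetical names here.
cell = NoisyGRUCell(num_units=256, training=tf.constant(True))
inputs = tf.placeholder(tf.float32, [None, 20, 128])  # batch x time x features
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)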
Example #4
    def step(self, att_objects: List[BaseAttention], input_: tf.Tensor,
             prev_state: tf.Tensor, prev_attns: List[tf.Tensor]):

        with tf.variable_scope(self.step_scope):
            cell = self._get_rnn_cell()

            # Merge input and previous attentions into one vector of the
            # right size.
            if self._attention_on_input:
                x = linear([input_] + prev_attns, self.embedding_size)
            else:
                x = input_

            # Run the RNN.
            cell_output, state = cell(x, prev_state)

            # Run the attention mechanism.
            if self._rnn_cell_str == 'GRU':
                attns = [
                    a.attention(cell_output, prev_state, x)
                    for a in att_objects
                ]
            elif self._rnn_cell_str == 'LSTM':
                attns = [
                    a.attention(cell_output, prev_state.c, x)
                    for a in att_objects
                ]
            else:
                raise ValueError("Unknown RNN cell.")

            if self._conditional_gru and self._rnn_cell_str == "GRU":
                cell_cond = self._get_conditional_gru_cell()
                cond_input = tf.concat(attns, -1)
                cell_output, state = cell_cond(cond_input,
                                               state,
                                               scope="cond_gru_2_cell")

            with tf.name_scope("rnn_output_projection"):
                if attns:
                    output = linear([cell_output] + attns,
                                    cell.output_size,
                                    scope="AttnOutputProjection")
                else:
                    output = cell_output

            logits = self._logit_function(output)

        return logits, state, attns
Example #5
    def __init__(self,
                 mlp_input,
                 layer_configuration,
                 dropout_plc,
                 output_size,
                 name='multilayer_perceptron'):

        with tf.variable_scope(name):
            last_layer = mlp_input
            last_layer_size = mlp_input.get_shape()[1].value

            self.n_params = 0
            for i, size in enumerate(layer_configuration):
                last_layer = tf.tanh(
                    linear(last_layer,
                           size,
                           scope="dense_layer_{}".format(i + 1)))
                last_layer = tf.nn.dropout(last_layer, dropout_plc)
                self.n_params += last_layer_size * size
                last_layer_size = size

            with tf.variable_scope("classification_layer"):
                self.n_params += last_layer_size * output_size
                w_out = tf.get_variable("W_out",
                                        shape=[last_layer_size, output_size])

                b_out = tf.get_variable("b_out",
                                        initializer=tf.zeros_initializer(
                                            [output_size]))

                self.logits = tf.matmul(last_layer, w_out) + b_out
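A hedged usage sketch of a perceptron constructed this way; the class name MultilayerPerceptron, the feature size, and the layer sizes are illustrative assumptions only.

import tensorflow as tf

mlp_input = tf.placeholder(tf.float32, [None, 300])   # batch x features
dropout_plc = tf.placeholder_with_default(1.0, [])    # keep probability
mlp = MultilayerPerceptron(mlp_input, [256, 128], dropout_plc, output_size=10)
predictions = tf.argmax(mlp.logits, 1)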
Example #6
    def attention(self, query_state):
        """Put attention masks on att_states_reshaped
           using hidden_features and query.
        """

        with tf.variable_scope(self.scope + "/Attention") as varscope:
            # Sort-of a hack to get the matrix (Bahdanau's W_a) in the linear
            # projection to be initialized this way. The biases are initialized
            # as zeros.
            varscope.set_initializer(
                tf.random_normal_initializer(stddev=0.001))
            y = linear(query_state, self.attention_vec_size, scope=varscope)
            y = tf.reshape(y, [-1, 1, 1, self.attention_vec_size])

            # pylint: disable=invalid-name
            # code copied from tensorflow. Suggestion: rename the variables
            # according to the Bahdanau paper
            s = self.get_logits(y)

            if self.input_weights is None:
                a = tf.nn.softmax(s)
            else:
                a_all = tf.nn.softmax(s) * self.input_weights
                norm = tf.reduce_sum(a_all, 1, keep_dims=True) + 1e-8
                a = a_all / norm

            self.logits_in_time.append(s)
            self.attentions_in_time.append(a)

            # Now calculate the attention-weighted vector d.
            d = tf.reduce_sum(
                tf.expand_dims(tf.expand_dims(a, -1), -1) *
                self.att_states_reshaped, [1, 2])

            return tf.reshape(d, [-1, self.attn_size])
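The get_logits call hides the actual Bahdanau-style scoring. A minimal sketch of what such a method typically computes follows; hidden_features (the pre-projected encoder states) and attn_v (the learned scoring vector) are assumed names, not taken from the example.

    def get_logits(self, y):
        # Additive attention energies: reduce v * tanh(W_a h_j + U_a s_i)
        # over the feature axes, giving one score per input position.
        return tf.reduce_sum(
            self.attn_v * tf.tanh(self.hidden_features + y), [2, 3])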
Example #7
    def func(train_mode: tf.Tensor,
             rnn_size: Optional[int] = None,
             encoders: Optional[List[Any]] = None) -> tf.Tensor:
        """Linearly project the encoders' encoded value to rnn_size
        and apply dropout

        Arguments:
            train_mode: tf 0-D bool Tensor specifying the training mode
            rnn_size: The size of the resulting vector
            encoders: The list of encoders
        """
        if rnn_size is None:
            raise ValueError("You must supply rnn_size for this type of "
                             "encoder projection")

        if encoders is None or not encoders:
            raise ValueError("There must be at least one encoder for this type"
                             " of encoder projection")

        with tf.variable_scope("encoders_projection") as scope:
            encoded_concat = tf.concat([e.encoded for e in encoders], 1)
            encoded_concat = dropout(
                encoded_concat, dropout_keep_prob, train_mode)

            return linear(encoded_concat, rnn_size, scope)
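The dropout call above is a thin wrapper that applies tf.nn.dropout only while training. A minimal sketch of such a wrapper, assuming train_mode is a scalar boolean tensor:

import tensorflow as tf

def dropout(variable, keep_prob, train_mode):
    """Apply dropout in training mode only; a no-op when keep_prob is 1."""
    if keep_prob == 1.0:
        return variable
    return tf.cond(train_mode,
                   lambda: tf.nn.dropout(variable, keep_prob),
                   lambda: variable)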
Example #8
    def attention(self, decoder_state, decoder_prev_state, decoder_input):
        with tf.variable_scope(self.scope):
            projected_state = linear(decoder_state, self.attention_state_size)
            projected_state = tf.expand_dims(projected_state, 1)

            assert_shape(projected_state, [-1, 1, self.attention_state_size])
            attn_ctx_vectors = [
                a.attention(decoder_state, decoder_prev_state, decoder_input)
                for a in self._attn_objs]

            proj_ctxs, attn_logits = [list(t) for t in zip(*[
                self._vector_logit(projected_state, ctx_vec, scope=enc.name)
                for ctx_vec, enc in zip(attn_ctx_vectors, self._encoders)])]

            if self._use_sentinels:
                sentinel_value = _sentinel(decoder_state,
                                           decoder_prev_state,
                                           decoder_input)
                proj_sentinel, sentinel_logit = self._vector_logit(
                    projected_state, sentinel_value, scope="sentinel")
                proj_ctxs.append(proj_sentinel)
                attn_logits.append(sentinel_logit)

            attention_distr = tf.nn.softmax(tf.concat(attn_logits, 1))
            self.attentions_in_time.append(attention_distr)

            if self._share_projections:
                output_cxts = proj_ctxs
            else:
                output_cxts = [
                    tf.expand_dims(
                        linear(ctx_vec, self.attention_state_size,
                               scope="proj_attn_{}".format(enc.name)), 1)
                    for ctx_vec, enc in zip(attn_ctx_vectors, self._encoders)]
                if self._use_sentinels:
                    output_cxts.append(tf.expand_dims(
                        linear(sentinel_value, self.attention_state_size,
                               scope="proj_sentinel"), 1))

            projections_concat = tf.concat(output_cxts, 1)
            context = tf.reduce_sum(
                tf.expand_dims(attention_distr, 2) * projections_concat, [1])

            return context
Example #9
    def input_plus_attention(self, *args) -> tf.Tensor:
        """Merge input and previous attentions into one vector of the
         right size.
        """
        loop_state = LoopState(*args)

        embedded_input = self.embed_input_symbol(*loop_state)

        return linear([embedded_input] + loop_state.prev_contexts,
                      self.embedding_size)
Example #10
    def step(self,
             att_objects: List[Attention],
             input_: tf.Tensor,
             prev_state: tf.Tensor,
             prev_attns: List[tf.Tensor]):

        with tf.variable_scope(self.step_scope):
            cell = self._get_rnn_cell()

            # Merge input and previous attentions into one vector of the
            # right size.
            if self._attention_on_input:
                x = linear([input_] + prev_attns, self.embedding_size)
            else:
                x = input_

            # Run the RNN.
            cell_output, state = cell(x, prev_state)

            # Run the attention mechanism.
            attns = [a.attention(cell_output) for a in att_objects]

            if self._conditional_gru:
                x_2 = linear(
                    attns, self.embedding_size, scope="cond_gru_2_linproj")
                # Run the RNN for the second time
                cell_output, state = cell(
                    x_2, state, scope="cond_gru_2_cell")

            with tf.name_scope("rnn_output_projection"):
                if attns:
                    output = linear([cell_output] + attns,
                                    cell.output_size,
                                    scope="AttnOutputProjection")
                else:
                    output = cell_output

            logits = self._logit_function(output)

        return logits, state, attns
Example #11
def _sentinel(state, prev_state, input_):
    """Sentinel value given the decoder state."""
    with tf.variable_scope("sentinel"):

        decoder_state_size = state.get_shape()[-1].value
        concatenation = tf.concat([prev_state, input_], 1)

        gate = tf.nn.sigmoid(linear(concatenation, decoder_state_size))
        sentinel_value = gate * state

        assert_shape(sentinel_value, [-1, decoder_state_size])

        return sentinel_value
Example #12
    def _logit_function(self, rnn_output):
        """Compute logits on the vocabulary given the state

        This variant simply linearly project the vectors to fit
        the size of the vocabulary

        Arguments:
            rnn_output: the output of the decoder RNN
                        (after output projection)

        Returns:
            A Tensor of shape batch_size x vocabulary_size
        """
        return linear(self._dropout(rnn_output), self.vocabulary_size)
Example #13
    def attention(self, query: tf.Tensor, decoder_prev_state: tf.Tensor,
                  decoder_input: tf.Tensor, loop_state: AttentionLoopState,
                  step: tf.Tensor) -> Tuple[tf.Tensor, AttentionLoopState]:
        with tf.variable_scope(self.att_scope_name):
            projected_state = linear(query, self.attention_state_size)
            projected_state = tf.expand_dims(projected_state, 1)

            assert_shape(projected_state, [-1, 1, self.attention_state_size])

            logits = []

            for proj, bias in zip(self.encoder_projections_for_logits,
                                  self.encoder_attn_biases):

                logits.append(
                    tf.reduce_sum(
                        self.attn_v * tf.tanh(projected_state + proj), [2]) +
                    bias)

            if self._use_sentinels:
                sentinel_value = _sentinel(query, decoder_prev_state,
                                           decoder_input)
                projected_sentinel, sentinel_logit = self._vector_logit(
                    projected_state, sentinel_value, scope="sentinel")
                logits.append(sentinel_logit)

            attentions = self._renorm_softmax(tf.concat(logits, 1))

            self.attentions_in_time.append(attentions)

            if self._use_sentinels:
                tiled_encoder_projections = self._tile_encoders_for_beamsearch(
                    projected_sentinel)

                projections_concat = tf.concat(
                    tiled_encoder_projections + [projected_sentinel], 1)

            else:
                projections_concat = tf.concat(
                    self.encoder_projections_for_ctx, 1)

            contexts = tf.reduce_sum(
                tf.expand_dims(attentions, 2) * projections_concat, [1])

            next_loop_state = AttentionLoopState(
                contexts=loop_state.contexts.write(step, contexts),
                weights=loop_state.weights.write(step, attentions))

            return contexts, next_loop_state
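The AttentionLoopState read and written here behaves like a namedtuple with two TensorArray fields, as the constructor call above shows. A hedged sketch of its definition (the real class may differ):

from collections import namedtuple

AttentionLoopState = namedtuple("AttentionLoopState", ["contexts", "weights"])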
Example #14
    def attention(self, decoder_state: tf.Tensor,
                  decoder_prev_state: tf.Tensor, _,
                  loop_state: AttentionLoopState,
                  step: tf.Tensor) -> Tuple[tf.Tensor, AttentionLoopState]:
        """put attention masks on att_states_reshaped
           using hidden_features and query.
        """

        with tf.variable_scope(self.scope + "/Attention") as varscope:
            # Sort-of a hack to get the matrix (Bahdanau's W_a) in the linear
            # projection to be initialized this way. The biases are initialized
            # as zeros.
            varscope.set_initializer(
                tf.random_normal_initializer(stddev=0.001))
            y = linear(decoder_state,
                       self.attention_state_size,
                       scope=varscope)
            y = tf.reshape(y, [-1, 1, 1, self.attention_state_size])

            # pylint: disable=invalid-name
            # code copied from tensorflow. Suggestion: rename the variables
            # according to the Bahdanau paper
            s = self.get_logits(y, loop_state.weights)

            if self.input_weights is None:
                weights = tf.nn.softmax(s)
            else:
                weights_all = tf.nn.softmax(s) * self.input_weights
                norm = tf.reduce_sum(weights_all, 1, keep_dims=True) + 1e-8
                weights = weights_all / norm
            # pylint: enable=invalid-name

            # Now calculate the attention-weighted vector d.
            context = tf.reduce_sum(
                tf.expand_dims(tf.expand_dims(weights, -1), -1) *
                self.att_states_reshaped, [1, 2])
            context = tf.reshape(context, [-1, self.attn_size])

            next_contexts = loop_state.contexts.write(step, context)
            next_weights = loop_state.weights.write(step, weights)

            next_loop_state = AttentionLoopState(contexts=next_contexts,
                                                 weights=next_weights)

            return context, next_loop_state
Example #15
    def attention(self, decoder_state, decoder_prev_state, decoder_input):
        with tf.variable_scope(self.scope):
            projected_state = linear(decoder_state, self.attention_state_size)
            projected_state = tf.expand_dims(projected_state, 1)

            assert_shape(projected_state, [-1, 1, self.attention_state_size])

            logits = []

            for proj, bias in zip(self.encoder_projections_for_logits,
                                  self.encoder_attn_biases):

                logits.append(tf.reduce_sum(
                    self.attn_v * tf.tanh(projected_state + proj), [2]) + bias)

            if self._use_sentinels:
                sentinel_value = _sentinel(decoder_state,
                                           decoder_prev_state,
                                           decoder_input)
                projected_sentinel, sentinel_logit = self._vector_logit(
                    projected_state, sentinel_value, scope="sentinel")
                logits.append(sentinel_logit)

            attentions = self._renorm_softmax(tf.concat(logits, 1))

            self.attentions_in_time.append(attentions)

            if self._use_sentinels:
                tiled_encoder_projections = self._tile_encoders_for_beamsearch(
                    projected_sentinel)

                projections_concat = tf.concat(
                    tiled_encoder_projections + [projected_sentinel], 1)

            else:
                projections_concat = tf.concat(
                    self.encoder_projections_for_ctx, 1)

            contexts = tf.reduce_sum(
                tf.expand_dims(attentions, 2) * projections_concat, [1])

            return contexts
Example #16
    def __init__(self, mlp_input, layer_configuration, dropout_plc,
                 output_size, name: str = 'multilayer_perceptron',
                 activation_fn=tf.nn.relu) -> None:

        with tf.variable_scope(name):
            last_layer_size = mlp_input.get_shape()[-1].value

            last_layer = multilayer_projection(mlp_input,
                                               layer_configuration,
                                               activation=activation_fn,
                                               dropout_plc=dropout_plc,
                                               scope="deep_output_mlp")
            self.n_params = 0
            for size in layer_configuration:
                self.n_params += last_layer_size * size
                last_layer_size = size

            with tf.variable_scope("classification_layer"):
                self.n_params += last_layer_size * output_size
                self.logits = linear(last_layer, output_size)
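The multilayer_projection helper used above evidently stacks dense layers with dropout in between. A minimal sketch under that assumption (the real signature may differ), reusing the linear helper sketched under Example #1 and assuming tensorflow is imported as tf:

def multilayer_projection(mlp_input, layer_sizes, activation=tf.nn.relu,
                          dropout_plc=1.0, scope="mlp"):
    layer = mlp_input
    with tf.variable_scope(scope):
        for i, size in enumerate(layer_sizes):
            layer = activation(
                linear(layer, size, scope="dense_layer_{}".format(i)))
            layer = tf.nn.dropout(layer, dropout_plc)
    return layer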
Example #17
    def attention(self, decoder_state: tf.Tensor,
                  decoder_prev_state: tf.Tensor, _) -> tf.Tensor:

        with tf.variable_scope(self.scope + "/RecurrentAttn") as varscope:
            initial_state = linear(decoder_state, self._state_size, varscope)
            initial_state = tf.tanh(initial_state)

            # TODO dropout?
            # we'd need the train_mode and dropout_keep_prob parameters

            sentence_lengths = tf.to_int32(tf.reduce_sum(
                self.input_weights, 1))

            _, encoded_tup = tf.nn.bidirectional_dynamic_rnn(
                self.fw_cell,
                self.bw_cell,
                self.attention_states,
                sequence_length=sentence_lengths,
                initial_state_fw=initial_state,
                initial_state_bw=initial_state,
                dtype=tf.float32)

            return tf.concat(encoded_tup, 1)
Example #18
    def _projection(prev_state, prev_output, ctx_tensors, train_mode):
        return linear([prev_state] + ctx_tensors,
                      output_size,
                      scope="AttnOutputProjection")
Example #19
    def _attention_decoder(
        self,
        go_symbols: tf.Tensor,
        train_inputs: tf.Tensor = None,
        attention_on_input=True,
        conditional_gru: bool = False,
        train_mode: bool = False,
        scope: Union[str, tf.VariableScope] = None
    ) -> Tuple[List[tf.Tensor], List[tf.Tensor]]:
        """Run the decoder RNN.

        Arguments:
            go_symbols: The tensor of start symbols of shape (1, batch_size)
            train_inputs: Training inputs to feed the decoder with. These are
                not used when `train_mode = False`
            attention_on_input: Flag whether attention from previous time step
                is fed to the input in the next step.
            conditional_gru: Flag that enables conditional GRU architecture
            train_mode: Boolean flag whether the decoder is running in
                train (with ground truth inputs) or runtime mode (with inputs
                decoded using the loop function)
            scope: Variable scope to use
        """
        att_objects = [
            self.get_attention_object(e, train_mode) for e in self.encoders
        ]
        att_objects = [a for a in att_objects if a is not None]

        cell = self._get_rnn_cell()

        with tf.variable_scope(scope or "attention_decoder"):
            if self._rnn_cell == 'GRU':
                state = self.initial_state
            elif self._rnn_cell == 'LSTM':
                # pylint: disable=redefined-variable-type
                state = tf.nn.rnn_cell.LSTMStateTuple(self.initial_state,
                                                      self.initial_state)
                # pylint: enable=redefined-variable-type
            else:
                raise ValueError("Unknown RNN cell.")

            outputs = []
            prev = None

            attns = [
                tf.zeros([self.batch_size, a.attn_size]) for a in att_objects
            ]
            states = []
            for i in range(self.max_output_len):
                if i > 0:
                    tf.get_variable_scope().reuse_variables()

                if prev is None:
                    assert i == 0
                    inp = go_symbols[0]
                elif train_mode:
                    inp = train_inputs[i - 1]
                else:
                    with tf.variable_scope("loop_function", reuse=True):
                        out_activation = self._logit_function(prev)
                        prev_word_index = tf.argmax(out_activation, 1)
                        inp = self._embed_and_dropout(prev_word_index)

                # Merge input and previous attentions into one vector of the
                # right size.
                if attention_on_input:
                    x = linear([inp] + attns, self.embedding_size)
                else:
                    x = inp

                # Run the RNN.
                cell_output, state = cell(x, state)

                # Run the attention mechanism.
                attns = [a.attention(cell_output) for a in att_objects]

                if conditional_gru:
                    x_2 = linear(attns,
                                 self.embedding_size,
                                 scope="cond_gru_2_linproj")
                    # Run the RNN for the second time
                    cell_output, state = cell(x_2,
                                              state,
                                              scope="cond_gru_2_cell")

                states.append(state)

                with tf.name_scope("rnn_output_projection"):
                    if attns:
                        output = linear([cell_output] + attns,
                                        cell.output_size,
                                        scope="AttnOutputProjection")
                    else:
                        output = cell_output

                prev = output
                outputs.append(output)

        return outputs, states
Example #20
        def body(*args) -> LoopState:
            loop_state = LoopState(*args)
            step = loop_state.step

            with tf.variable_scope(self.step_scope):
                # Compute the input to the RNN
                rnn_input = self.input_projection(*loop_state)

                # Run the RNN.
                cell = self._get_rnn_cell()
                if self._rnn_cell_str == 'GRU':
                    cell_output, state = cell(rnn_input,
                                              loop_state.prev_rnn_output)
                    next_state = state
                    attns = [
                        a.attention(cell_output, loop_state.prev_rnn_output,
                                    rnn_input, att_loop_state, loop_state.step)
                        for a, att_loop_state in zip(
                            att_objects, loop_state.attention_loop_states)
                    ]
                    if att_objects:
                        contexts, att_loop_states = zip(*attns)
                    else:
                        contexts, att_loop_states = [], []

                    if self._conditional_gru:
                        cell_cond = self._get_conditional_gru_cell()
                        cond_input = tf.concat(contexts, -1)
                        cell_output, state = cell_cond(cond_input,
                                                       state,
                                                       scope="cond_gru_2_cell")
                elif self._rnn_cell_str == 'LSTM':
                    prev_state = tf.contrib.rnn.LSTMStateTuple(
                        loop_state.prev_rnn_state, loop_state.prev_rnn_output)
                    cell_output, state = cell(rnn_input, prev_state)
                    next_state = state.c
                    attns = [
                        a.attention(cell_output, loop_state.prev_rnn_output,
                                    rnn_input, att_loop_state, loop_state.step)
                        for a, att_loop_state in zip(
                            att_objects, loop_state.attention_loop_states)
                    ]
                    if att_objects:
                        contexts, att_loop_states = zip(*attns)
                    else:
                        contexts, att_loop_states = [], []
                else:
                    raise ValueError("Unknown RNN cell.")

                with tf.name_scope("rnn_output_projection"):
                    if attns:
                        output = linear([cell_output] + list(contexts),
                                        cell.output_size,
                                        scope="AttnOutputProjection")
                    else:
                        output = cell_output
                        att_loop_states = []

                logits = self._logit_function(output)

            self.step_scope.reuse_variables()

            if sample:
                next_symbols = tf.multinomial(logits, num_samples=1)
            elif train_mode:
                next_symbols = loop_state.train_inputs[step]
            else:
                next_symbols = tf.to_int32(tf.argmax(logits, axis=1))
                int_unfinished_mask = tf.to_int32(
                    tf.logical_not(loop_state.finished))

                # Note this works only when PAD_TOKEN_INDEX is 0. Otherwise
                # this would have to be rewritten.
                assert PAD_TOKEN_INDEX == 0
                next_symbols = next_symbols * int_unfinished_mask

            has_just_finished = tf.equal(next_symbols, END_TOKEN_INDEX)
            has_finished = tf.logical_or(loop_state.finished,
                                         has_just_finished)

            new_loop_state = LoopState(
                step=step + 1,
                input_symbol=next_symbols,
                train_inputs=loop_state.train_inputs,
                prev_rnn_state=next_state,
                prev_rnn_output=cell_output,
                rnn_outputs=loop_state.rnn_outputs.write(
                    step + 1, cell_output),
                prev_contexts=list(contexts),
                prev_logits=logits,
                logits=loop_state.logits.write(step, logits),
                finished=has_finished,
                mask=loop_state.mask.write(step, tf.logical_not(has_finished)),
                attention_loop_states=list(att_loop_states))
            return new_loop_state
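The LoopState container threaded through this body is namedtuple-like; its fields can be read off the new_loop_state construction at the end. A hedged sketch of its definition (the real one may carry additional fields or defaults):

from collections import namedtuple

LoopState = namedtuple("LoopState", [
    "step", "input_symbol", "train_inputs", "prev_rnn_state",
    "prev_rnn_output", "rnn_outputs", "prev_contexts", "prev_logits",
    "logits", "finished", "mask", "attention_loop_states"])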
Example #21
    def attention(self, query: tf.Tensor, decoder_prev_state: tf.Tensor,
                  decoder_input: tf.Tensor, loop_state: HierarchicalLoopState,
                  step: tf.Tensor) -> Tuple[tf.Tensor, HierarchicalLoopState]:

        with tf.variable_scope(self.att_scope_name):
            projected_state = linear(query, self.attention_state_size)
            projected_state = tf.expand_dims(projected_state, 1)

            assert_shape(projected_state, [-1, 1, self.attention_state_size])
            attn_ctx_vectors, child_loop_states = zip(*[
                a.attention(query, decoder_prev_state, decoder_input, ls, step)
                for a, ls in zip(self.attentions, loop_state.child_loop_states)
            ])

            proj_ctxs, attn_logits = [
                list(t) for t in zip(*[
                    self._vector_logit(projected_state,
                                       ctx_vec,
                                       scope=att.name)  # type: ignore
                    for ctx_vec, att in zip(attn_ctx_vectors, self.attentions)
                ])
            ]

            if self._use_sentinels:
                sentinel_value = _sentinel(query, decoder_prev_state,
                                           decoder_input)
                proj_sentinel, sentinel_logit = self._vector_logit(
                    projected_state, sentinel_value, scope="sentinel")
                proj_ctxs.append(proj_sentinel)
                attn_logits.append(sentinel_logit)

            attention_distr = tf.nn.softmax(tf.concat(attn_logits, 1))
            self.attentions_in_time.append(attention_distr)

            if self._share_projections:
                output_cxts = proj_ctxs
            else:
                output_cxts = [
                    tf.expand_dims(
                        linear(ctx_vec,
                               self.attention_state_size,
                               scope="proj_attn_{}".format(att.name)),
                        1)  # type: ignore
                    for ctx_vec, att in zip(attn_ctx_vectors, self.attentions)
                ]
                if self._use_sentinels:
                    output_cxts.append(
                        tf.expand_dims(
                            linear(sentinel_value,
                                   self.attention_state_size,
                                   scope="proj_sentinel"), 1))

            projections_concat = tf.concat(output_cxts, 1)
            context = tf.reduce_sum(
                tf.expand_dims(attention_distr, 2) * projections_concat, [1])

            prev_loop_state = loop_state.loop_state
            next_contexts = prev_loop_state.contexts.write(step, context)
            next_weights = prev_loop_state.weights.write(step, attention_distr)

            next_loop_state = AttentionLoopState(contexts=next_contexts,
                                                 weights=next_weights)

            next_hier_loop_state = HierarchicalLoopState(
                child_loop_states=list(child_loop_states),
                loop_state=next_loop_state)

            return context, next_hier_loop_state
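HierarchicalLoopState, as used here, simply pairs the child attentions' loop states with this attention's own AttentionLoopState. A hedged sketch of its definition:

from collections import namedtuple

HierarchicalLoopState = namedtuple(
    "HierarchicalLoopState", ["child_loop_states", "loop_state"])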
Example #22
    def predictions(self):
        return linear(self._mlp_output,
                      self.dimension,
                      scope="output_projection")