Exemplo n.º 1
0
    def call(self, inputs, state):
        """Run one step, computing attention *before* the wrapped cell.

        Note: this wrapper changes the order of the attention computation;
        mind the difference from the stock ``AttentionWrapper``. Here the
        attention query is the previous step's cell state (its ``h`` part
        when ``self._cell`` is an ``rnn.LSTMCell``). The resulting attention,
        concatenated over all mechanisms, is combined with ``inputs`` through
        ``self._cell_input_fn`` before the wrapped cell runs.

        Args:
            inputs: The input at this time step.
            state: ``seq2seq.AttentionWrapperState`` from the previous step.

        Returns:
            ``(output, next_state)``. ``output`` is the concatenated attention
            when ``self._output_attention`` is set, otherwise the wrapped
            cell's output. ``next_state`` is the new
            ``seq2seq.AttentionWrapperState`` with ``time`` advanced by one.

        Raises:
            TypeError: If ``state`` is not an ``AttentionWrapperState``.
        """
        if not isinstance(state, seq2seq.AttentionWrapperState):
            raise TypeError(
                "Expected state to be instance of AttentionWrapperState. "
                "Received type %s instead." % type(state))

        if self._is_multi:
            prior_alignments = state.alignments
            prior_histories = state.alignment_history
        else:
            prior_alignments = [state.alignments]
            prior_histories = [state.alignment_history]

        # The query is the same for every mechanism, so select it once.
        query = (state.cell_state.h if isinstance(self._cell, rnn.LSTMCell)
                 else state.cell_state)

        attentions, alignments_out, histories_out, attention_states = (
            [], [], [], [])
        for idx, mechanism in enumerate(self._attention_mechanisms):
            layer = (self._attention_layers[idx]
                     if self._attention_layers else None)
            # Parts of this step follow https://github.com/bgshih/aster/issues/2
            attn, align, attn_state = _compute_attention(
                mechanism, query, prior_alignments[idx], layer)
            if self._alignment_history:
                history = prior_histories[idx].write(state.time, align)
            else:
                history = ()

            alignments_out.append(align)
            histories_out.append(history)
            attentions.append(attn)
            attention_states.append(attn_state)

        attention = array_ops.concat(attentions, 1)

        # The attention derived from the previous state feeds this step's cell.
        cell_output, next_cell_state = self._cell(
            self._cell_input_fn(inputs, attention), state.cell_state)

        next_state = seq2seq.AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            alignments=self._item_or_tuple(alignments_out),
            alignment_history=self._item_or_tuple(histories_out),
            attention_state=self._item_or_tuple(attention_states))
        output = attention if self._output_attention else cell_output
        return output, next_state
Exemplo n.º 2
0
 def get_zero_state(self):
     """Return the decoder's initial ``seq2seq.AttentionWrapperState``.

     The decoder cell's zero state is the query for a single call of
     ``self.bahdanau`` against its initial alignments. Those alignments
     weight ``self.bahdanau.values``: a unit axis is inserted at position 1,
     a batched matmul is applied, and that axis is squeezed away again. The
     result is the initial attention context.

     Returns:
         An ``AttentionWrapperState`` with ``time`` = 0 (int32 scalar) and an
         empty ``alignment_history`` (``()``).
     """
     source = self.inputs
     dtype = source.dtype
     batch_size = tf.shape(source)[0]
     zero_cell_state = self.decoder.zero_state(batch_size, dtype)

     mechanism = self.bahdanau
     alignments = mechanism(zero_cell_state,
                            mechanism.initial_alignments(batch_size, dtype))
     weights = tf.expand_dims(alignments, 1)
     context = tf.squeeze(weights @ mechanism.values, [1])

     return seq2seq.AttentionWrapperState(
         cell_state=zero_cell_state,
         time=tf.zeros([], dtype=tf.int32),
         alignments=alignments,
         alignment_history=(),
         attention=context)
Exemplo n.º 3
0
def decoding_layer(dec_embed_input, embeddings, enc_output, enc_state,
                   vocab_size, text_length, summary_length, max_summary_length,
                   rnn_size, vocab_to_int, keep_prob, batch_size, num_layers):
    """Build the attention decoder and return its training/inference logits.

    A dropout-wrapped ``LSTMCell`` is wrapped in a Bahdanau
    ``AttentionWrapper`` over ``enc_output`` (memory lengths ``text_length``).
    The same cell, initial state and output projection are then used twice
    under the ``"decode"`` variable scope. The first use is
    ``training_decoding_layer``; the second is ``inference_decoding_layer``,
    with ``reuse=True``.

    Args:
        dec_embed_input: Embedded decoder inputs, passed to
            ``training_decoding_layer``.
        embeddings: Embedding matrix, passed to ``inference_decoding_layer``.
        enc_output: Encoder outputs used as the attention memory.
        enc_state: Encoder final state; ``enc_state[0]`` becomes the decoder
            cell's initial state (presumably the forward-direction state of
            a bidirectional encoder -- TODO confirm).
        vocab_size: Number of output units of the projection layer.
        text_length: Memory sequence lengths for the attention mechanism.
        summary_length: Target lengths, passed to the training decoder.
        max_summary_length: Maximum decode length for both decoders.
        rnn_size: Units of the LSTM cell, the attention mechanism and the
            attention layer.
        vocab_to_int: Mapping that must contain ``'<GO>'`` and ``'<EOS>'``.
        keep_prob: Input keep probability for the dropout wrapper.
        batch_size: Batch size for the initial state and inference decoder.
        num_layers: Number of decoder cells constructed; must be >= 1 (see
            the note below).

    Returns:
        ``(training_logits, inference_logits)``.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """
    if num_layers < 1:
        # Without this guard the loop below never runs and `dec_cell` is unbound.
        raise ValueError("num_layers must be >= 1, got %r" % (num_layers,))

    # NOTE(review): every iteration overwrites `dec_cell`, so only the last
    # cell built is used. The decoder is therefore single-layer whatever
    # `num_layers` is. Stacking (MultiRNNCell) would also need a per-layer
    # initial cell_state, so the architecture is intentionally left unchanged.
    for layer in range(num_layers):
        with tf.variable_scope('decoder_{}'.format(layer)):
            lstm = tf.nn.rnn_cell.LSTMCell(
                rnn_size,
                initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
            dec_cell = tf.nn.rnn_cell.DropoutWrapper(lstm,
                                                     input_keep_prob=keep_prob)
    # Fully connected projection onto the vocabulary.
    output_layer = Dense(vocab_size,
                         kernel_initializer=tf.truncated_normal_initializer(
                             mean=0.0, stddev=0.1))

    attn_mech = seq.BahdanauAttention(rnn_size,
                                      enc_output,
                                      text_length,
                                      normalize=False,
                                      name='BahdanauAttention')

    # Add the attention mechanism.
    dec_cell = seq.AttentionWrapper(cell=dec_cell,
                                    attention_mechanism=attn_mech,
                                    attention_layer_size=rnn_size)

    # Start from the wrapper's own zero state, which has correctly shaped
    # time / attention / alignments / history fields. Then substitute the
    # encoder state as the cell state. Building AttentionWrapperState
    # positionally from only (cell_state, attention) raises TypeError,
    # because the namedtuple has more required fields.
    initial_state = dec_cell.zero_state(
        batch_size=batch_size, dtype=tf.float32).clone(cell_state=enc_state[0])

    with tf.variable_scope("decode"):
        training_logits = training_decoding_layer(dec_embed_input,
                                                  summary_length, dec_cell,
                                                  initial_state, output_layer,
                                                  vocab_size,
                                                  max_summary_length)
    with tf.variable_scope("decode", reuse=True):
        inference_logits = inference_decoding_layer(
            embeddings, vocab_to_int['<GO>'], vocab_to_int['<EOS>'], dec_cell,
            initial_state, output_layer, max_summary_length, batch_size)
    return training_logits, inference_logits
    def call(self, inputs, state):
        """Perform a step of attention-wrapped RNN (optionally with coverage).

        - Step 1: Mix the `inputs` and previous step's `attention` output via
            `cell_input_fn`.
        - Step 2: Call the wrapped `cell` with this input and its previous state.
        - Step 3: Score the cell's output with `attention_mechanism`. In
            coverage mode (`self.coverage`), the "previous alignments" handed
            to the mechanism are replaced by the coverage vector. That vector
            is the sum of every alignment written to the alignment history so
            far, or zeros on the first step while the history is empty.
        - Step 4: Calculate the alignments by passing the score through the
            `normalizer`.
        - Step 5: Calculate the context vector as the inner product between the
            alignments and the attention_mechanism's values (memory).
        - Step 6: Calculate the attention output by concatenating the cell output
            and context through the attention layer (a linear layer with
            `attention_layer_size` outputs).

        Args:
            inputs: (Possibly nested tuple of) Tensor, the input at this time step.
            state: An instance of `AttentionWrapperState` containing
                tensors from the previous time step.

        Returns:
            A tuple `(attention_or_cell_output, next_state)`, where:
            - `attention_or_cell_output` depending on `output_attention`.
            - `next_state` is an instance of `AttentionWrapperState`
                containing the state calculated at this time step.

        Raises:
            TypeError: If `state` is not an instance of `AttentionWrapperState`.
            ValueError: If coverage mode is on but the alignment history is not
                recorded; coverage is computed from that history.
        """
        if not isinstance(state, tf.contrib.seq2seq.AttentionWrapperState):
            raise TypeError(
                "Expected state to be instance of AttentionWrapperState. "
                "Received type %s instead." % type(state))
        if self.coverage and not self._alignment_history:
            # Otherwise the history is `()` and `.size()` fails obscurely below.
            raise ValueError(
                "Coverage mode requires the alignment history: construct the "
                "wrapper with alignment_history=True.")

        # Step 1: Calculate the true inputs to the cell based on the
        # previous attention value.
        cell_inputs = self._cell_input_fn(inputs, state.attention)
        cell_state = state.cell_state
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

        cell_batch_size = (cell_output.shape[0].value
                           or tf.shape(cell_output)[0])
        error_message = (
            "When applying AttentionWrapper %s: " % self.name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and the query (decoder output).  Are you using "
            "the BeamSearchDecoder?  You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with tf.control_dependencies(
                self._batch_size_checks(cell_batch_size, error_message)):
            cell_output = tf.identity(cell_output, name="checked_cell_output")

        if self._is_multi:
            previous_alignments = state.alignments
            previous_alignment_history = state.alignment_history
        else:
            previous_alignments = [state.alignments]
            previous_alignment_history = [state.alignment_history]

        all_alignments = []
        all_attentions = []
        all_histories = []

        for i, attention_mechanism in enumerate(self._attention_mechanisms):
            # Use a local rather than assigning into `previous_alignments`. In
            # multi-mechanism mode that is `state.alignments`, a tuple built
            # by `_item_or_tuple`, and item assignment on it raises TypeError.
            query_alignments = previous_alignments[i]
            if self.coverage:
                # The alignment history stacks time-major (decoder time
                # first). Summing over that axis turns it into the coverage
                # vector. Default arguments bind the current loop values.
                history = previous_alignment_history[i]
                query_alignments = tf.cond(
                    history.size() > 0,
                    lambda h=history: tf.reduce_sum(h.stack(), axis=0),
                    lambda a=query_alignments: tf.zeros_like(a))
            attention, alignments, _ = _compute_attention(
                attention_mechanism, cell_output, query_alignments,
                self._attention_layers[i] if self._attention_layers else None)
            alignment_history = previous_alignment_history[i].write(
                state.time, alignments) if self._alignment_history else ()

            all_alignments.append(alignments)
            all_histories.append(alignment_history)
            all_attentions.append(attention)

        attention = tf.concat(all_attentions, 1)
        # NOTE(review): the mechanism's next attention state (third value from
        # _compute_attention) is discarded and the previous one is carried
        # forward unchanged. This is harmless when the attention state is just
        # the alignments; confirm before using a stateful mechanism.
        next_state = seq2seq.AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            alignments=self._item_or_tuple(all_alignments),
            alignment_history=self._item_or_tuple(all_histories),
            attention_state=state.attention_state)

        if self._output_attention:
            return attention, next_state
        return cell_output, next_state
Exemplo n.º 5
0
    def call(self, inputs, state):
        """One attention-wrapped RNN step whose per-mechanism attentions are pooled.

        Order of operations:

        1. Merge ``inputs`` with the previous ``state.attention`` via
           ``self._cell_input_fn``.
        2. Run the wrapped cell on the merged input and ``state.cell_state``.
        3. Gate the cell output on ``self._batch_size_checks``, which compare
           its batch size against the memory's.
        4. For each attention mechanism, compute its attention output and
           alignments from the checked cell output via ``_compute_attention``.
           ``reuse=True`` is passed for every mechanism after the first.
           Alignments are written to the history when it is enabled.
        5. Pool the per-mechanism attention outputs. Only ``'avgpool'`` (the
           mean across mechanisms) is supported.

        Args:
            inputs: The input at this time step.
            state: ``seq2seq.AttentionWrapperState`` from the previous step.

        Returns:
            ``(output, next_state)``. ``output`` is the pooled attention when
            ``self._output_attention`` is set, otherwise the checked cell
            output.

        Raises:
            TypeError: If ``state`` is not an ``AttentionWrapperState``.
            ValueError: If ``self.pooling`` is not ``'avgpool'``.
        """
        if not isinstance(state, seq2seq.AttentionWrapperState):
            raise TypeError("Expected state to be instance of "
                            "AttentionWrapperState. "
                            "Received type %s instead." % type(state))

        # Steps 1-2: feed the previous attention back in and run the cell.
        merged_inputs = self._cell_input_fn(inputs, state.attention)
        cell_output, next_cell_state = self._cell(merged_inputs,
                                                  state.cell_state)

        # Step 3: prefer the static batch size, else fall back to the dynamic one.
        static_batch = cell_output.shape[0].value
        cell_batch_size = static_batch or array_ops.shape(cell_output)[0]
        batch_mismatch_msg = (
            "When applying AttentionWrapper %s: " % self.name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and the query (decoder output).  Are you using "
            "the BeamSearchDecoder?  You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        checks = self._batch_size_checks(cell_batch_size, batch_mismatch_msg)
        with ops.control_dependencies(checks):
            cell_output = array_ops.identity(cell_output,
                                             name="checked_cell_output")

        if self._is_multi:
            alignments_in = state.alignments
            histories_in = state.alignment_history
        else:
            alignments_in = [state.alignments]
            histories_in = [state.alignment_history]

        # Step 4: one attention computation per mechanism.
        per_mech_alignments, per_mech_histories, per_mech_attentions = [], [], []
        for idx, mechanism in enumerate(self._attention_mechanisms):
            layer = (self._attention_layers[idx]
                     if self._attention_layers else None)
            attn, align = _compute_attention(
                mechanism,
                cell_output,
                alignments_in[idx],
                layer,
                reuse=idx > 0)
            if self._alignment_history:
                history = histories_in[idx].write(state.time, align)
            else:
                history = ()

            per_mech_alignments.append(align)
            per_mech_histories.append(history)
            per_mech_attentions.append(attn)

        # Step 5: pool across mechanisms.
        if self.pooling != 'avgpool':
            raise ValueError('Unknown pooling method')
        attention = tf.reduce_mean(tf.stack(per_mech_attentions, axis=1),
                                   axis=1)

        next_state = seq2seq.AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            alignments=self._item_or_tuple(per_mech_alignments),
            alignment_history=self._item_or_tuple(per_mech_histories))
        output = attention if self._output_attention else cell_output
        return output, next_state
Exemplo n.º 6
0
    def _init_encoder(self):
        """Build the attentive encoder and run it over the encoder inputs.

        Visible steps (this view may not include the end of the method):

        * Build per-layer encoder cells and their initial state with
          ``build_rnn_layers``. Only the ``'unidirectional'`` encoder type is
          handled in the lines shown.
        * For ``cell_type == 'skip_lstm'`` with ``hparams.separate_skip_rnn``,
          run the first cell as a separate ``tf.nn.dynamic_rnn`` pass and
          store ``self.skip_infos``.
        * Wrap the last cell with attention over ``self._attended_memory`` via
          ``add_attention`` and substitute its initial state.
        * Run ``tf.nn.dynamic_rnn`` over the cells, setting
          ``self._encoder_outputs`` and ``self._encoder_final_state``. For
          non-separate skip cells, also split out ``updated_states``, store
          ``self.skip_infos`` and normalise the final state.
        * If ``hparams.write_attention_alignment`` is True, build
          ``self.attention_summary`` / ``self.attention_alignment``.

        Everything shown runs inside the ``"Encoder"`` variable scope.
        """
        with tf.variable_scope("Encoder") as scope:

            encoder_inputs = self._maybe_add_dense_layers()
            # Index 0 of these hparam pairs applies to 'video' data / 'train'
            # mode, index 1 otherwise.
            cell_type = self._hparams.cell_type[0 if self._data_type ==
                                                'video' else 1]
            batch_size = self._hparams.batch_size[0 if self._mode ==
                                                  'train' else 1]
            if self._hparams.encoder_type == 'unidirectional':
                # One cell per layer, returned as a list together with the
                # initial state.
                self._encoder_cells, initial_state = build_rnn_layers(
                    cell_type=cell_type,
                    num_units_per_layer=self._num_units_per_layer,
                    use_dropout=self._hparams.use_dropout,
                    dropout_probability=self._dropout_probability,
                    mode=self._mode,
                    as_list=True,
                    batch_size=batch_size,
                    dtype=self._hparams.dtype)

                # Disabled alternative cell construction via create_model:
                #self._encoder_cells, initial_state = create_model(model=cell_type,
                #                                     num_cells=self._num_units_per_layer,
                #                                     batch_size=batch_size,
                #                                     as_list=False if cell_type == 'multi_skip_lstm' else True,
                #                                     learn_initial_state=True if 'skip' in cell_type else False,
                #                                     use_dropout=self._hparams.use_dropout,
                #                                     dropout_probability=self._dropout_probability)

                self._encoder_cells = maybe_list(self._encoder_cells)
                print(self._num_units_per_layer)
                print('encoder_cells', self._encoder_cells)
                print('AttentiveEncoder_initial_state', initial_state)

                #### Begin disabled experiments that perturb
                #### self._attended_memory (none of them run).

                # 1. reverse mem
                # self._attended_memory = tf.reverse(self._attended_memory, axis=[1])

                # 2. append zeros
                # randval1 = tf.random.uniform(shape=[], minval=25, maxval=100, dtype=tf.int32)
                # randval2 = tf.random.uniform(shape=[], minval=25, maxval=100, dtype=tf.int32)
                # zeros_slice1 = tf.zeros([1, randval1, 256], dtype=tf.float32)  # assuming we use inference on a batch size of 1
                # zeros_slice2 = tf.zeros([1, randval2, 256], dtype=tf.float32)
                # self._attended_memory = tf.concat([zeros_slice1, self._attended_memory, zeros_slice2], axis=1)
                # self._attended_memory_length += randval1 + randval2

                # 3. blank mem
                # self._attended_memory = 0* self._attended_memory

                # 4. mix with noise
                # noise = tf.random.truncated_normal(shape=tf.shape(self._attended_memory))
                # noise = tf.random.uniform(shape=tf.shape(self._attended_memory))

                # self._attended_memory = noise

                #### End of disabled memory-perturbation experiments.
                # Separate skip-RNN pass: the first cell runs on its own over
                # the raw inputs, and the remaining cells are later fed its
                # output `h`.
                if cell_type == 'skip_lstm' and self._hparams.separate_skip_rnn:
                    skip_cell = self._encoder_cells[0]
                    self._encoder_cells = self._encoder_cells[1:]

                    skip_out = tf.nn.dynamic_rnn(
                        cell=skip_cell,
                        inputs=encoder_inputs,
                        sequence_length=self._inputs_len,
                        parallel_iterations=batch_size,
                        swap_memory=False,
                        dtype=self._hparams.dtype,
                        scope=scope,
                        initial_state=skip_cell.trainable_initial_state(
                            batch_size),
                    )
                    print('skip_out', skip_out)
                    skip_output, skip_final_state = skip_out
                    h, updated_states = skip_output
                    print('skip_encoder_updated_states', updated_states)
                    # Budget statistics: updated_states summed over axis 1,
                    # then averaged over axis 0. The budget loss is scaled by
                    # cost_per_sample; meanUpdates is not.
                    cost_per_sample = self._hparams.cost_per_sample[1]
                    budget_loss = tf.reduce_mean(
                        tf.reduce_sum(cost_per_sample * updated_states, 1), 0)
                    meanUpdates = tf.reduce_mean(
                        tf.reduce_sum(updated_states, 1), 0)
                    self.skip_infos = SkipInfoTuple(updated_states,
                                                    meanUpdates, budget_loss)
                    # Tried to remove the skipped states in the output, but this destroys the input shape
                    #updated_states_shape = tf.shape(updated_states)
                    #h_shape = tf.shape(h)
                    #updated_states = tf.reshape(updated_states, [batch_size, updated_states_shape[1]])
                    #new_h = tf.boolean_mask(h, tf.where(updated_states == 1.))
                    #new_h = tf.where(updated_states == 1., h, tf.zeros(shape=h_shape))
                    #print('new_h', new_h)
                    #new_h_shape = tf.shape(new_h)
                    #new_h = tf.reshape(new_h, [batch_size, new_h_shape[0] / batch_size, new_h_shape[1]])
                    #self._inputs_len = (batch_size, new_h_shape[0] / batch_size)

                    print('skip_encoder_layer_h', h)
                print('attended_memory', self._attended_memory)
                print(self._encoder_cells)
                # Wrap the last cell with attention over self._attended_memory.
                # NOTE(review): when separate_skip_rnn is set, the attention
                # cell is built as 'lstm' regardless of cell_type.
                attention_cells, dummy_initial_state = add_attention(
                    cell_type='lstm'
                    if self._hparams.separate_skip_rnn else cell_type,
                    cells=self._encoder_cells[-1],
                    attention_types=self._hparams.attention_type[0],
                    num_units=self._num_units_per_layer[-1],
                    memory=self._attended_memory,
                    memory_len=self._attended_memory_length,
                    mode=self._mode,
                    dtype=self._hparams.dtype,
                    initial_state=initial_state[-1] if
                    (isinstance(initial_state, tuple)
                     and not isinstance(initial_state, SkipLSTMStateTuple))
                    else initial_state,
                    # NOTE(review): tf.shape(...) is a 1-D shape tensor, not a
                    # scalar batch size -- confirm add_attention expects that.
                    batch_size=tf.shape(self._inputs_len),
                    write_attention_alignment=self._hparams.
                    write_attention_alignment,
                    fusion_type='linear_fusion',
                )
                print('AttentiveEncoder_initial_state2', initial_state)
                # Put the attention wrapper's initial state in place of the
                # last layer's state. If the state is not a per-layer tuple
                # (or is a SkipLSTMStateTuple), replace it entirely.
                if isinstance(initial_state, tuple) and not isinstance(
                        initial_state, SkipLSTMStateTuple):
                    initial_state = list(initial_state)
                    initial_state[-1] = dummy_initial_state
                    initial_state = tuple(initial_state)
                else:
                    initial_state = dummy_initial_state
                self._encoder_cells[-1] = attention_cells

                # Disabled alternative: build initial_state per cell.
                # initial_state = self._encoder_cells.get_initial_state(inputs=None, batch_size=batch_size, dtype=self._hparams.dtype)
                #initial_state = []
                #for i, cell in enumerate(self._encoder_cells):
                #    print(i, cell)
                #    if isinstance(cell, SkipLSTMCell):
                #        #pass
                #        #with tf.variable_scope(f'layer_{i}') as init_scope:
                #        initial_state.append(cell.zero_state(batch_size, dtype=self._hparams.dtype))
                #    else:
                #        with tf.variable_scope(f'layer_{i}'):
                #            initial_state.append(
                #                cell.get_initial_state(inputs=None, batch_size=batch_size, dtype=self._hparams.dtype))
                #initial_state = tuple(initial_state)

                print('AttentiveEncoder_encoder_cells', self._encoder_cells)
                self._encoder_cells = maybe_multirnn(self._encoder_cells)

                print('AttentiveEncoder_encoder_cells', self._encoder_cells)
                print('AttentiveEncoder_encoder_inputs', encoder_inputs)
                print('AttentiveEncoder_inputs_len', self._inputs_len)
                #initial_state = self._encoder_cells.get_initial_state(batch_size=batch_size, dtype=self._hparams.dtype)
                #print('AttentiveEncoder_initial_state_final', initial_state)
                # NOTE(review): `h` is only bound in the skip_lstm +
                # separate_skip_rnn branch above. With separate_skip_rnn=True
                # and any other cell_type, the `else h` below raises
                # UnboundLocalError.
                out = tf.nn.dynamic_rnn(
                    cell=self._encoder_cells,
                    inputs=encoder_inputs
                    if not self._hparams.separate_skip_rnn else h,
                    sequence_length=self._inputs_len,
                    # parallel_iterations is set to the batch size for the mode.
                    parallel_iterations=self._hparams.
                    batch_size[0 if self._mode == 'train' else 1],
                    swap_memory=False,
                    dtype=self._hparams.dtype,
                    scope=scope,
                    initial_state=initial_state,
                )

                self._encoder_outputs, self._encoder_final_state = out
                print("AttentiveEncoder_dynamic_rnn_out",
                      self._encoder_outputs)
                print("AttentiveEncoder_dynamic_rnn_fs",
                      self._encoder_final_state)

                # Skip cells in the main stack emit (outputs, updated_states).
                # Split them and compute the same budget statistics as above.
                if not self._hparams.separate_skip_rnn and 'skip' in cell_type:
                    self._encoder_outputs, updated_states = self._encoder_outputs
                    cost_per_sample = self._hparams.cost_per_sample[1]
                    budget_loss = tf.reduce_mean(
                        tf.reduce_sum(cost_per_sample * updated_states, 1), 0)
                    meanUpdates = tf.reduce_mean(
                        tf.reduce_sum(updated_states, 1), 0)
                    self.skip_infos = SkipInfoTuple(updated_states,
                                                    meanUpdates, budget_loss)

                    # Keep only the last layer's state when the final state is
                    # a per-layer tuple (an AttentionWrapperState is kept whole).
                    if isinstance(self._encoder_final_state,
                                  tuple) and not isinstance(
                                      self._encoder_final_state,
                                      seq2seq.AttentionWrapperState):
                        self._encoder_final_state = self._encoder_final_state[
                            -1]
                    print('AttentiveEncoder_final_state_inBetween',
                          self._encoder_final_state)
                    # Rebuild the attention state with its cell state(s)
                    # converted to plain LSTMStateTuple(c, h).
                    if isinstance(self._encoder_final_state,
                                  seq2seq.AttentionWrapperState):
                        cell_state = self._encoder_final_state.cell_state
                        # NOTE(review): bare `except:` catches everything,
                        # including KeyboardInterrupt. The fallback treats
                        # cell_state as a single state; narrowing this to the
                        # expected exception (likely AttributeError) is safer.
                        try:
                            cell_state = [
                                LSTMStateTuple(cs.c, cs.h) for cs in cell_state
                            ]
                        except:
                            cell_state = LSTMStateTuple(
                                cell_state.c, cell_state.h)
                        self._encoder_final_state = seq2seq.AttentionWrapperState(
                            cell_state, self._encoder_final_state.attention,
                            self._encoder_final_state.time,
                            self._encoder_final_state.alignments,
                            self._encoder_final_state.alignment_history,
                            self._encoder_final_state.attention_state)
                    print('AttentiveEncoder_final_state',
                          self._encoder_final_state)

                # Optional alignment summary built from the last (or only)
                # final state.
                if self._hparams.write_attention_alignment is True:
                    # self.weights_summary = self._encoder_final_state[-1].attention_weight_history.stack()
                    self.attention_summary, self.attention_alignment = self._create_attention_alignments_summary(
                        maybe_list(self._encoder_final_state)[-1])