Example 1
    def BiRNNAtt(x, attention, layer):
        with tf.variable_scope('encoder_sentence_{}'.format(layer),reuse=False):
            # Prepare data shape to match `rnn` function requirements
            # Current data input shape: (batch_size, timesteps, n_input)
            # Required shape: 'timesteps' tensors list of shape (batch_size, num_input)

            # Unstack to get a list of 'timesteps' tensors of shape (batch_size, num_input)
            # x = tf.unstack(x, max_length, 1)

            # Define lstm cells with tensorflow
            # Forward direction cell
            lstm_fw_cell = rnn.BasicLSTMCell(dims, forget_bias=1.0)
            lstm_fw_att = seq2seq.AttentionWrapper(lstm_fw_cell, attention)
            # Backward direction cell
            lstm_bw_cell = rnn.BasicLSTMCell(dims, forget_bias=1.0)
            lstm_bw_att = seq2seq.AttentionWrapper(lstm_bw_cell, attention)


            ((fw_outputs, bw_outputs), (fw_states, bw_states)) = tf.nn.bidirectional_dynamic_rnn(lstm_fw_att,
                                                                                                 lstm_bw_att,
                                                                                                 x,
                                                                                                 dtype=tf.float32)
            outputs = tf.concat([fw_outputs, bw_outputs], axis=2)


            # print("BiLSTM lengths: ", len(outputs))
            # Linear activation, using rnn inner loop last output
            return outputs
Example 2
    def getBeamSearchDecoderCell(self, encoder_outputs, encoder_final_states):
        basic_cells = [self.get_basicLSTMCell() for i in range(layer_num)]
        basic_cell = tf.nn.rnn_cell.MultiRNNCell(basic_cells)
        tiled_encoder_outputs = seq2seq.tile_batch(encoder_outputs,
                                                   multiplier=beam_size)
        tiled_encoder_final_states = [
            seq2seq.tile_batch(state, multiplier=beam_size)
            for state in encoder_final_states
        ]
        tiled_sequence_length = seq2seq.tile_batch(self.enc_len,
                                                   multiplier=beam_size)
        initial_state = tuple(tiled_encoder_final_states)
        #attention
        attention_mechanism = seq2seq.BahdanauAttention(
            num_units=num_units,
            memory=tiled_encoder_outputs,
            memory_sequence_length=tiled_sequence_length)
        att_cell = seq2seq.AttentionWrapper(
            basic_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=num_units,
            alignment_history=False,
            cell_input_fn=None,
            initial_cell_state=initial_state)

        initial_state = att_cell.zero_state(
            batch_size=tf.shape(self.enc_in)[0] * beam_size, dtype=tf.float32)
        #            att_state.clone(cell_state=encoder_final_state)

        return att_cell, initial_state
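The cell/state pair returned above is meant to be handed to a beam-search decoder. Below is a minimal, hypothetical follow-up sketch (assuming TF 1.x with `tf.contrib.seq2seq` imported as `seq2seq`); the embedding matrix, token ids, sizes and output layer are illustrative placeholders, not names from this code.

    # Hypothetical follow-up, not part of the original example: feed the
    # attention-wrapped cell and its tiled initial state into a BeamSearchDecoder.
    import tensorflow as tf
    from tensorflow.contrib import seq2seq

    def beam_search_decode(att_cell, initial_state, embedding_matrix, go_id,
                           eos_id, beam_size, vocab_size, batch_size,
                           max_iterations=50):
        output_layer = tf.layers.Dense(vocab_size, use_bias=False)
        decoder = seq2seq.BeamSearchDecoder(
            cell=att_cell,
            embedding=embedding_matrix,
            start_tokens=tf.fill([batch_size], go_id),
            end_token=eos_id,
            initial_state=initial_state,   # sized batch_size * beam_size
            beam_width=beam_size,
            output_layer=output_layer)
        outputs, _, _ = seq2seq.dynamic_decode(decoder,
                                               maximum_iterations=max_iterations)
        return outputs.predicted_ids       # [batch, time, beam_width]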
Example 3
def create_decoder_cell(agenda,
                        base_sent_embeds,
                        mev_st,
                        mev_ts,
                        base_length,
                        iw_length,
                        dw_length,
                        attn_dim,
                        hidden_dim,
                        num_layer,
                        enable_alignment_history=False,
                        enable_dropout=False,
                        dropout_keep=0.1,
                        no_insert_delete_attn=False):
    base_attn = seq2seq.BahdanauAttention(attn_dim,
                                          base_sent_embeds,
                                          base_length,
                                          name='src_attn')

    cnx_src, micro_evs_st = mev_st
    mev_st_attn = seq2seq.BahdanauAttention(attn_dim,
                                            cnx_src,
                                            iw_length,
                                            name='mev_st_attn')
    mev_st_attn._values = micro_evs_st

    attns = [base_attn, mev_st_attn]

    if not no_insert_delete_attn:
        cnx_tgt, micro_evs_ts = mev_ts
        mev_ts_attn = seq2seq.BahdanauAttention(attn_dim,
                                                cnx_tgt,
                                                dw_length,
                                                name='mev_ts_attn')
        mev_ts_attn._values = micro_evs_ts

        attns += [mev_ts_attn]

    bottom_cell = tf_rnn.LSTMCell(hidden_dim, name='bottom_cell')
    bottom_attn_cell = seq2seq.AttentionWrapper(
        bottom_cell,
        tuple(attns),
        output_attention=False,
        alignment_history=enable_alignment_history,
        name='att_bottom_cell')

    all_cells = [bottom_attn_cell]

    num_layer -= 1
    for i in range(num_layer):
        cell = tf_rnn.LSTMCell(hidden_dim, name='layer_%s' % (i + 1))
        if enable_dropout and dropout_keep < 1.:
            cell = tf_rnn.DropoutWrapper(cell, output_keep_prob=dropout_keep)

        all_cells.append(cell)

    decoder_cell = AttentionAugmentRNNCell(all_cells)
    decoder_cell.set_agenda(agenda)

    return decoder_cell
Example 4
def create_rnn_cell(unit_type,
                    num_units,
                    num_layers,
                    num_residual_layers,
                    forget_bias,
                    dropout,
                    mode,
                    attention_mechanism=None,
                    attention_num_heads=1,
                    attention_layer_size=None,
                    output_attention=False,
                    single_cell_fn=None,
                    trainable=True):
  """Returns an instance of an RNN cell.

  Args:
    unit_type: A string that specifies the type of the recurrent unit. Must be
      one of {"lstm", "gru", "lstm_norm", "nas"}.
    num_units: An integer for the number of units per layer.
    num_layers: An integer for the number of recurrent layers.
    num_residual_layers: An integer for the number of residual layers.
    forget_bias: A float for the forget bias in LSTM cells.
    dropout: A float for the recurrent dropout rate.
    mode: TRAIN | EVAL | PREDICT
    attention_mechanism: An instance of tf.contrib.seq2seq.AttentionMechanism.
    attention_num_heads: An integer for the number of attention heads.
    attention_layer_size: Optional integer for the size of the attention layer.
    output_attention: A boolean indicating whether RNN cell outputs attention.
    single_cell_fn: A function for building a single RNN cell.
    trainable: A boolean indicating whether the cell weights are trainable.

  Returns:
    An RNNCell instance.
  """
  cell_list = _cell_list(
      unit_type=unit_type,
      num_units=num_units,
      num_layers=num_layers,
      forget_bias=forget_bias,
      dropout=dropout,
      mode=mode,
      num_residual_layers=num_residual_layers,
      single_cell_fn=single_cell_fn,
      trainable=trainable)

  if len(cell_list) == 1:  # Single layer.
    cell = cell_list[0]
  else:  # Multiple layers.
    cell = contrib_rnn.MultiRNNCell(cell_list)

  # Wrap with attention, if necessary.
  if attention_mechanism is not None:
    cell = contrib_seq2seq.AttentionWrapper(
        cell, [attention_mechanism] * attention_num_heads,
        attention_layer_size=[attention_layer_size] * attention_num_heads,
        alignment_history=False,
        output_attention=output_attention,
        name="attention")

  return cell
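When `attention_num_heads > 1`, `create_rnn_cell` passes a list of attention mechanisms together with a matching list of `attention_layer_size` values. A minimal self-contained sketch of that multi-head wrapping pattern (TF 1.x; the 128-unit sizes, placeholders and two-layer cell are assumptions for illustration, not values from this code):

    import tensorflow as tf
    from tensorflow.contrib import rnn as contrib_rnn
    from tensorflow.contrib import seq2seq as contrib_seq2seq

    # Illustrative shapes; not taken from the surrounding code.
    memory = tf.placeholder(tf.float32, [None, None, 128])   # encoder outputs
    memory_len = tf.placeholder(tf.int32, [None])             # source lengths
    num_heads = 2

    # One mechanism per head, all reading the same memory.
    mechanisms = [
        contrib_seq2seq.LuongAttention(128, memory,
                                       memory_sequence_length=memory_len)
        for _ in range(num_heads)
    ]
    cell = contrib_rnn.MultiRNNCell(
        [tf.nn.rnn_cell.LSTMCell(128) for _ in range(2)])
    attn_cell = contrib_seq2seq.AttentionWrapper(
        cell,
        mechanisms,                              # list -> multi-head attention
        attention_layer_size=[128] * num_heads,  # one projection per head
        output_attention=False,
        name="attention")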
Example 5
    def decode(self, dec_cell, enc_outputs, ctx_outputs):
        with tf.variable_scope("decode"):
            batch_size = self._batch_size

            attn_mech = seq2seq.BahdanauAttention(self._memory_size,
                                                  enc_outputs,
                                                  self.input_lengths)
            dec_cell = CondWrapper(dec_cell, ctx_outputs)
            dec_cell = seq2seq.AttentionWrapper(dec_cell, attn_mech,
                                                self._memory_size)
            dec_initial_state = dec_cell.zero_state(batch_size=batch_size,
                                                    dtype=tf.float32)
            helper_build_fn = self._infer_helper if self._infer else self._train_helper

            output_layer = layers_core.Dense(self._vocab_size,
                                             use_bias=True,
                                             activation=None)
            decoder = seq2seq.BasicDecoder(cell=dec_cell,
                                           helper=helper_build_fn(),
                                           initial_state=dec_initial_state,
                                           output_layer=output_layer)
            dec_output, dec_state, _ = seq2seq.dynamic_decode(
                decoder,
                impute_finished=True,
                maximum_iterations=self._max_seq_length)
            rnn_output = dec_output.rnn_output
            sample_id = dec_output.sample_id
        return rnn_output, sample_id, dec_state
Example 6
    def build_attention_cell(self, encoder_outputs, encoder_states):
        memory = encoder_outputs

        if self.time_major:
            memory = tf.transpose(encoder_outputs, [1, 0, 2])

        attention_mechanism = seq2seq.LuongAttention(
            num_units=self.hps.att_num_units,
            memory=memory,
            memory_sequence_length=self.iterator.target_length)

        cell = rnn.MultiRNNCell([
            self.build_rnn_cell(FLAGS.cell_type)
            for _ in range(self.hps.stack_layers)
        ])

        cell = seq2seq.AttentionWrapper(
            cell,
            attention_mechanism,
            attention_layer_size=self.hps.att_num_units,
            name='attention')

        batch_size = tf.size(self.iterator.source_length)

        decoder_initial_state = cell.zero_state(
            batch_size=batch_size,
            dtype=dtype).clone(cell_state=encoder_states)

        return cell, decoder_initial_state
Example 7
    def inference_decode_layer(self, start_token, dec_cell, end_token,
                               output_layer):
        start_tokens = tf.tile(tf.constant([start_token], dtype=tf.int32),
                               [self.batch_size],
                               name='start_token')
        tiled_enc_output = seq2seq.tile_batch(self.enc_output,
                                              multiplier=self.Beam_width)
        tiled_enc_state = seq2seq.tile_batch(self.enc_state,
                                             multiplier=self.Beam_width)
        tiled_source_len = seq2seq.tile_batch(self.source_len,
                                              multiplier=self.Beam_width)
        atten_mech = seq2seq.BahdanauAttention(self.hidden_dim * 2,
                                               tiled_enc_output,
                                               tiled_source_len,
                                               normalize=True)
        decoder_att = seq2seq.AttentionWrapper(dec_cell, atten_mech,
                                               self.hidden_dim * 2)
        initial_state = decoder_att.zero_state(
            self.batch_size * self.Beam_width,
            tf.float32).clone(cell_state=tiled_enc_state)
        decoder = seq2seq.BeamSearchDecoder(decoder_att,
                                            self.embeddings,
                                            start_tokens,
                                            end_token,
                                            initial_state,
                                            beam_width=self.Beam_width,
                                            output_layer=output_layer)
        infer_logits, _, _ = seq2seq.dynamic_decode(decoder, False, False,
                                                    self.max_target_len)
        return infer_logits
Example 8
    def train_decode_layer(self, dec_embedding_input, dec_cell, output_layer):
        atten_mech = seq2seq.BahdanauAttention(
            num_units=self.hidden_dim * 2,
            memory=self.enc_output,
            memory_sequence_length=self.target_len,
            normalize=True,
            name='BahdanauAttention')
        dec_cell = seq2seq.AttentionWrapper(dec_cell,
                                            atten_mech,
                                            self.hidden_dim * 2,
                                            name='dec_attention_cell')

        initial_state = dec_cell.zero_state(
            batch_size=self.batch_size,
            dtype=tf.float32).clone(cell_state=self.enc_state)

        train_helper = seq2seq.TrainingHelper(dec_embedding_input,
                                              self.target_len)
        training_decoder = seq2seq.BasicDecoder(dec_cell,
                                                train_helper,
                                                initial_state=initial_state,
                                                output_layer=output_layer)
        train_logits, _, _ = seq2seq.dynamic_decode(
            training_decoder,
            output_time_major=False,
            impute_finished=False,
            maximum_iterations=self.max_target_len)
        return train_logits
Example 9
def add_attention(
    cells,
    attention_types,
    num_units,
    memory,
    memory_len,
    mode,
    batch_size,
    dtype,
    beam_search=False,
    beam_width=None,
    initial_state=None,
    write_attention_alignment=False,
    fusion_type='linear_fusion',
):
    r"""
    Wraps the decoder_cells with an AttentionWrapper
    Args:
        cells: instances of `RNNCell`
        beam_search: `bool` flag for beam search decoders
        batch_size: `Tensor` containing the batch size. Needed to build the initial state.

    Returns:
        attention_cells: the Attention wrapped decoder cells
        initial_state: a proper initial state to be used with the returned cells
    """
    attention_mechanisms, attention_layers, attention_layer_sizes, output_attention = create_attention_mechanisms(
        beam_search=beam_search,
        beam_width=beam_width,
        memory=memory,
        memory_len=memory_len,
        num_units=num_units,
        attention_types=attention_types,
        fusion_type=fusion_type,
        mode=mode,
        dtype=dtype)

    if beam_search is True:
        initial_state = seq2seq.tile_batch(initial_state,
                                           multiplier=beam_width)

    attention_cells = seq2seq.AttentionWrapper(
        cell=cells,
        attention_mechanism=attention_mechanisms,
        attention_layer_size=attention_layer_sizes,
        # initial_cell_state=decoder_initial_state,
        alignment_history=write_attention_alignment,
        output_attention=output_attention,
        attention_layer=attention_layers,
    )

    attn_zero = attention_cells.zero_state(
        dtype=dtype,
        batch_size=batch_size *
        beam_width if beam_search is True else batch_size)

    if initial_state is not None:
        initial_state = attn_zero.clone(cell_state=initial_state)

    return attention_cells, initial_state
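The beam-search branch above uses the standard tiling recipe: tile the memory, its lengths and the encoder state by `beam_width`, build `zero_state` for `batch_size * beam_width`, then clone the encoder state in. A minimal sketch of that recipe (TF 1.x; the shapes, the GRU cell and the 256-unit size are assumptions, not values from this code):

    import tensorflow as tf
    from tensorflow.contrib import seq2seq

    beam_width = 4
    memory = tf.placeholder(tf.float32, [None, None, 256])    # encoder outputs
    memory_len = tf.placeholder(tf.int32, [None])
    encoder_state = tf.placeholder(tf.float32, [None, 256])   # e.g. a GRU final state
    batch_size = tf.shape(memory)[0]

    # Everything the decoder attends to (or starts from) is tiled beam_width times.
    tiled_memory = seq2seq.tile_batch(memory, multiplier=beam_width)
    tiled_len = seq2seq.tile_batch(memory_len, multiplier=beam_width)
    tiled_state = seq2seq.tile_batch(encoder_state, multiplier=beam_width)

    mechanism = seq2seq.BahdanauAttention(256, tiled_memory,
                                          memory_sequence_length=tiled_len)
    cell = seq2seq.AttentionWrapper(tf.nn.rnn_cell.GRUCell(256), mechanism,
                                    attention_layer_size=256)
    # zero_state must be built for batch_size * beam_width, then the tiled
    # encoder state is cloned into it.
    initial_state = cell.zero_state(batch_size * beam_width, tf.float32).clone(
        cell_state=tiled_state)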
Example 10
    def decoding_layer(self, dec_embed_input, embeddings, enc_output, enc_state,
                       vocab_size, text_len, summary_len, max_sum_len):

        lstm = rnn.LSTMCell(self.hidden_dim * 2,
                            initializer=tf.random_normal_initializer(-0.1, 0.1, seed=2))

        dec_cell = rnn.DropoutWrapper(lstm, input_keep_prob=self.keep_prob)

        output_layer = tf.layers.Dense(vocab_size,
                                       kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))

        attn_mech = seq2seq.BahdanauAttention(self.hidden_dim * 2,
                                              enc_output,
                                              text_len,
                                              normalize=False,
                                              name='BahdanauAttention')

        dec_cell = seq2seq.AttentionWrapper(dec_cell, attn_mech,
                                            attention_layer_size=self.hidden_dim * 2)

        # initial_state = seq2seq.AttentionWrapperState(enc_state[0],
        #     _zero_state_tensors(self.hidden_dim, batch_size, tf.float32))
        initial_state = dec_cell.zero_state(self.batch_size, tf.float32).clone(
            cell_state=LSTMStateTuple(*enc_state))

        with tf.variable_scope('decode'):
            training_logits = self.training_decoding_layer(dec_embed_input, summary_len,
                                                           dec_cell, initial_state,
                                                           output_layer, max_sum_len)

        with tf.variable_scope('decode', reuse=True):
            inference_logits = self.inference_decoding_layer(embeddings,
                                                             self.vocab_to_int['<GO>'],
                                                             self.vocab_to_int['<EOS>'],
                                                             dec_cell, initial_state,
                                                             output_layer, max_sum_len)

        return training_logits, inference_logits
Example 11
    def _build_decoder(self):
        """ Decode keyword and context into a sequence of vectors. """
        self.sequence_decoder = tf.placeholder(
            dtype=tf.float32,
            shape=[_BATCH_SIZE, None, CHAR_VEC_DIM],
            name='context')
        self.length_decoder = tf.placeholder(dtype=tf.int32,
                                             shape=[_BATCH_SIZE],
                                             name='length_keywords')
        attention = seq2seq.BahdanauAttention(
            _NUM_UNITS,
            memory=self.encoder_outputs,
            memory_sequence_length=self.context_length,
            name="BahdanauAttention")
        cell_attention = tf.contrib.rnn.GRUCell(_NUM_UNITS)
        attention_wrapper = seq2seq.AttentionWrapper(cell_attention, attention)

        self.initial_decode_state = attention_wrapper.zero_state(
            _BATCH_SIZE,
            dtype=tf.float32).clone(cell_state=self.states_keywords)

        self.decoder_outputs, self.decoder_final_state = tf.nn.dynamic_rnn(
            attention_wrapper,
            self.sequence_decoder,
            sequence_length=self.length_decoder,
            initial_state=self.initial_decode_state,
            dtype=tf.float32,
            time_major=False)
Example 12
    def _build_model(self,
                     batch_size,
                     helper_build_fn,
                     decoder_maxiters=None,
                     alignment_history=False):
        # embed input_data into a one-hot representation
        inputs = tf.one_hot(self.input_data,
                            self._input_size,
                            dtype=self._dtype)
        inputs_len = self.input_lengths

        with tf.name_scope('conv-encoder'):
            W = tf.Variable(tf.truncated_normal(
                [3, self._input_size, self._enc_size], stddev=0.1),
                            name="conv-weights")
            b = tf.Variable(tf.truncated_normal([self._enc_size], stddev=0.1),
                            name="conv-bias")
            enc_out = tf.nn.elu(
                tf.nn.conv1d(inputs, W, stride=1, padding='SAME') + b)

        with tf.name_scope('attn-decoder'):
            dec_cell_in1 = rnn.GRUCell(self._dec_size)
            dec_cell_in2 = rnn.GRUCell(self._dec_size)
            memory = enc_out
            attn_mech = seq2seq.LuongMonotonicAttention(
                self._enc_size,
                memory,
                memory_sequence_length=inputs_len,
                sigmoid_noise=0.5,
                score_bias_init=-4.,
                mode='recursive',
                scale=True)
            dec_cell_attn = rnn.MultiRNNCell(
                [rnn.GRUCell(self._dec_size),
                 rnn.GRUCell(self._enc_size)],
                state_is_tuple=True)
            dec_cell_attn = seq2seq.AttentionWrapper(
                dec_cell_attn,
                attn_mech,
                attention_layer_size=self._enc_size,
                alignment_history=alignment_history)
            dec_cell_out = rnn.GRUCell(self._output_size)
            dec_cell = rnn.MultiRNNCell(
                [dec_cell_in1, dec_cell_in2, dec_cell_attn, dec_cell_out],
                state_is_tuple=True)

            dec = seq2seq.BasicDecoder(
                dec_cell, helper_build_fn(),
                dec_cell.zero_state(batch_size, self._dtype))

            dec_out, dec_state, _ = seq2seq.dynamic_decode(
                dec,
                output_time_major=False,
                maximum_iterations=decoder_maxiters,
                impute_finished=True)

        self.outputs = dec_out.rnn_output
        self.output_ids = dec_out.sample_id
        self.final_state = dec_state
Example 13
    def _build_decoder_beam_search(self):

        batch_size, _ = tf.unstack(tf.shape(self._labels))

        attention_mechanisms, layer_sizes = self._create_attention_mechanisms(
            beam_search=True)

        decoder_initial_state_tiled = seq2seq.tile_batch(
            self._decoder_initial_state, multiplier=self._hparams.beam_width)

        if self._hparams.enable_attention is True:

            attention_cells = seq2seq.AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanisms,
                attention_layer_size=layer_sizes,
                initial_cell_state=decoder_initial_state_tiled,
                alignment_history=self._hparams.write_attention_alignment,
                output_attention=self._output_attention)

            initial_state = attention_cells.zero_state(
                dtype=self._hparams.dtype,
                batch_size=batch_size * self._hparams.beam_width)

            initial_state = initial_state.clone(
                cell_state=decoder_initial_state_tiled)

            cells = attention_cells
        else:
            cells = self._decoder_cells
            initial_state = decoder_initial_state_tiled

        self._decoder_inference = seq2seq.BeamSearchDecoder(
            cell=cells,
            embedding=self._embedding_matrix,
            start_tokens=array_ops.fill([batch_size], self._GO_ID),
            end_token=self._EOS_ID,
            initial_state=initial_state,
            beam_width=self._hparams.beam_width,
            output_layer=self._dense_layer,
            length_penalty_weight=0.5,
        )

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=False,
            maximum_iterations=self._hparams.max_label_length,
            swap_memory=False)

        if self._hparams.write_attention_alignment is True:
            self.attention_summary = self._create_attention_alignments_summary(
                states)

        self.inference_outputs = outputs.beam_search_decoder_output
        self.inference_predicted_ids = outputs.predicted_ids[:, :,
                                                             0]  # return the first beam
        self.inference_predicted_beam = outputs.predicted_ids
        self.beam_search_output = outputs.beam_search_decoder_output
Example 14
    def _build_model(self,
                     batch_size,
                     helper_build_fn,
                     decoder_maxiters=None,
                     alignment_history=False):
        # embed input_data into a one-hot representation
        inputs = tf.one_hot(self.input_data,
                            self._input_size,
                            dtype=self._dtype)
        inputs_len = self.input_lengths

        with tf.name_scope('bidir-encoder'):
            fw_cell = rnn.MultiRNNCell(
                [rnn.BasicRNNCell(self._enc_rnn_size) for i in range(3)],
                state_is_tuple=True)
            bw_cell = rnn.MultiRNNCell(
                [rnn.BasicRNNCell(self._enc_rnn_size) for i in range(3)],
                state_is_tuple=True)
            fw_cell_zero = fw_cell.zero_state(batch_size, self._dtype)
            bw_cell_zero = bw_cell.zero_state(batch_size, self._dtype)

            enc_out, _ = tf.nn.bidirectional_dynamic_rnn(
                fw_cell,
                bw_cell,
                inputs,
                sequence_length=inputs_len,
                initial_state_fw=fw_cell_zero,
                initial_state_bw=bw_cell_zero)

        with tf.name_scope('attn-decoder'):
            dec_cell_in = rnn.GRUCell(self._dec_rnn_size)
            attn_values = tf.concat(enc_out, 2)
            attn_mech = seq2seq.BahdanauAttention(self._enc_rnn_size * 2,
                                                  attn_values, inputs_len)
            dec_cell_attn = rnn.GRUCell(self._enc_rnn_size * 2)
            dec_cell_attn = seq2seq.AttentionWrapper(
                dec_cell_attn,
                attn_mech,
                self._enc_rnn_size * 2,
                alignment_history=alignment_history)
            dec_cell_out = rnn.GRUCell(self._output_size)
            dec_cell = rnn.MultiRNNCell(
                [dec_cell_in, dec_cell_attn, dec_cell_out],
                state_is_tuple=True)

            dec = seq2seq.BasicDecoder(
                dec_cell, helper_build_fn(),
                dec_cell.zero_state(batch_size, self._dtype))

            dec_out, dec_state, _ = seq2seq.dynamic_decode(
                dec,
                output_time_major=False,
                maximum_iterations=decoder_maxiters,
                impute_finished=True)

        self.outputs = dec_out.rnn_output
        self.output_ids = dec_out.sample_id
        self.final_state = dec_state
Example 15
    def build_decoder_cell(self):

        if self.use_beamsearch_decode:
            encoder_outputs = tf.contrib.seq2seq.tile_batch(
                self.encoder_outputs, multiplier=self.beam_width)
            encoder_last_state = tf.contrib.seq2seq.tile_batch(
                self.encoder_last_state, multiplier=self.beam_width)
            encoder_inputs_length = tf.contrib.seq2seq.tile_batch(
                self.encoder_inputs_length, multiplier=self.beam_width)
        else:
            encoder_outputs = self.encoder_outputs
            encoder_last_state = self.encoder_last_state
            encoder_inputs_length = self.encoder_inputs_length

        self.attention_mechanism = seq2seq.BahdanauAttention(
            num_units=self.decoder_hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)

        self.decoder_cell_list = [
            self.build_single_cell(self.decoder_hidden_units)
            for _ in range(self.depth)
        ]

        # NOTE(sdsuo): Not sure what this does yet
        def attn_decoder_input_fn(inputs, attention):
            if not self.attn_input_feeding:
                return inputs

            # Essential when use_residual=True
            _input_layer = Dense(self.decoder_hidden_units,
                                 dtype=self.dtype,
                                 name='attn_input_feeding')
            return _input_layer(rnn.array_ops.concat([inputs, attention], -1))

        # Attention mechanism is applied to the whole stacked decoder (all layers wrapped together)
        self.decoder_cell_list = seq2seq.AttentionWrapper(
            cell=rnn.MultiRNNCell(self.decoder_cell_list),
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.decoder_hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=encoder_last_state,
            alignment_history=False,
            name='attention_wrapper')

        if self.use_beamsearch_decode:
            batch_size = self.batch_size * self.beam_width
        else:
            batch_size = self.batch_size

        # add by Meng
        decoder_initial_state = self.decoder_cell_list.zero_state(
            batch_size=batch_size,
            dtype=self.dtype).clone(cell_state=encoder_last_state)

        return self.decoder_cell_list, decoder_initial_state
Example 16
    def build_decoder_cell(self):
        # TODO(sdsuo): Read up and decide whether to use beam search
        self.attention_mechanism = seq2seq.BahdanauAttention(
            num_units=self.decoder_hidden_units,
            memory=self.encoder_outputs,
            memory_sequence_length=self.encoder_inputs_length
        )

        self.decoder_cell_list = [
            self.build_single_cell(self.decoder_hidden_units) for _ in range(self.depth)
        ]

        # NOTE(sdsuo): Not sure what this does yet
        def attn_decoder_input_fn(inputs, attention):
            if not self.attn_input_feeding:
                return inputs

            # Essential when use_residual=True
            _input_layer = Dense(self.decoder_hidden_units, dtype=self.dtype,
                                 name='attn_input_feeding')
            return _input_layer(rnn.array_ops.concat([inputs, attention], -1))

        # NOTE(sdsuo): Attention mechanism is implemented only on the top decoder layer
        self.decoder_cell_list[-1] = seq2seq.AttentionWrapper(
            cell=self.decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.decoder_hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=self.encoder_last_state[-1],
            alignment_history=False,
            name='attention_wrapper'
        )

        # NOTE(sdsuo): Not sure why this is necessary
        # To be compatible with AttentionWrapper, the encoder last state
        # of the top layer should be converted into the AttentionWrapperState form
        # We can easily do this by calling AttentionWrapper.zero_state

        # Also if beamsearch decoding is used, the batch_size argument in .zero_state
        # should be ${decoder_beam_width} times the original batch_size
        if self.use_beamsearch_decode:
            batch_size = self.batch_size * self.beam_width
        else:
            batch_size = self.batch_size
        # NOTE(vera): important dimension here
        # embed()
        initial_state = [state for state in self.encoder_last_state]
        initial_state[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=batch_size,
            dtype=self.dtype
        )
        decoder_initial_state = tuple(initial_state)


        return rnn.MultiRNNCell(self.decoder_cell_list), decoder_initial_state
Example 17
    def _build_decoder(self, encoder_states, target_sequence, keep_prob,
                       sampling_prob, attention_mechanism):
        """Define decoder architecture.
        """
        # connect each layer sequentially, building a graph that resembles a
        # feed-forward network made of recurrent units
        decoder_cell = self._multi_cell(num_units=self.num_units,
                                        num_layers=self.num_layers,
                                        keep_prob=keep_prob)

        # connect attention to decoder
        attention_layer_size = self.num_units
        decoder = seq2seq.AttentionWrapper(
            cell=decoder_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_layer_size)

        # decoder start symbol
        decoder_raw_seq = target_sequence[:, :-1]
        prefix = tf.fill([tf.shape(target_sequence)[0], 1, self.target_depth],
                         0.0)
        decoder_input_seq = tf.concat([prefix, decoder_raw_seq], axis=1)

        # the model is using fixed lengths of target sequences so tile the defined
        # length in the batch dimension
        decoder_sequence_length = tf.tile([self.target_length],
                                          [tf.shape(target_sequence)[0]])

        # decoder sampling scheduler feeds decoder output to next time input
        # instead of using ground-truth target vals during training
        helper = seq2seq.ScheduledOutputTrainingHelper(
            inputs=decoder_input_seq,
            sequence_length=decoder_sequence_length,
            sampling_probability=sampling_prob)

        # output layer
        projection_layer = Dense(units=self.target_depth, use_bias=True)

        # clone encoder state
        initial_state = decoder.zero_state(
            tf.shape(target_sequence)[0], tf.float32)
        initial_state = initial_state.clone(cell_state=encoder_states)

        # wrapper for decoder
        decoder = seq2seq.BasicDecoder(cell=decoder,
                                       helper=helper,
                                       initial_state=initial_state,
                                       output_layer=projection_layer)

        # build the unrolled graph of the recurrent neural network
        outputs, decoder_state, _sequence_lengths = seq2seq.dynamic_decode(
            decoder=decoder, maximum_iterations=self.target_length)

        return (outputs, decoder_state)
Example 18
    def decoder_cell(self, inputs, lengths):
        attention_mechanism = seq2seq.LuongAttention(
            num_units=self.layer_size,
            memory=inputs,
            memory_sequence_length=lengths,
            scale=True)

        return seq2seq.AttentionWrapper(
            cell=self.cell(),
            attention_mechanism=attention_mechanism,
            attention_layer_size=self.layer_size)
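As a short hypothetical usage sketch (TF 1.x), a cell like the one returned above can drive a teacher-forced training decode; `decoder_inputs`, `decoder_lengths`, the sizes and the vocabulary below are illustrative placeholders rather than attributes of this class:

    import tensorflow as tf
    from tensorflow.contrib import seq2seq

    layer_size, vocab_size = 256, 10000
    memory = tf.placeholder(tf.float32, [None, None, layer_size])          # encoder outputs
    memory_len = tf.placeholder(tf.int32, [None])
    decoder_inputs = tf.placeholder(tf.float32, [None, None, layer_size])  # embedded targets
    decoder_lengths = tf.placeholder(tf.int32, [None])

    mechanism = seq2seq.LuongAttention(layer_size, memory,
                                       memory_sequence_length=memory_len,
                                       scale=True)
    cell = seq2seq.AttentionWrapper(tf.nn.rnn_cell.LSTMCell(layer_size),
                                    mechanism,
                                    attention_layer_size=layer_size)

    helper = seq2seq.TrainingHelper(decoder_inputs, decoder_lengths)
    decoder = seq2seq.BasicDecoder(
        cell, helper,
        initial_state=cell.zero_state(tf.shape(memory)[0], tf.float32),
        output_layer=tf.layers.Dense(vocab_size))
    outputs, _, _ = seq2seq.dynamic_decode(decoder, impute_finished=True)
    logits = outputs.rnn_output   # [batch, time, vocab_size]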
Example 19
    def build_decode_cell(self):
        encoder_outputs = self.encoder_outputs
        encoder_last_state = self.encoder_last_state
        encoder_inputs_length = self.encoder_inputs_length
        # Building attention mechanism: Default Bahdanau
        # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473
        self.attention_mechanism = seq2seq.BahdanauAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length
        )

        if self.attention_type.lower() == 'luong':
            self.attention_mechanism = seq2seq.LuongAttention(
                num_units=self.hidden_units,
                memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_length
            )

        # decoder_cell
        self.decoder_cell_list = [self.build_single_cell(layer=2) for _ in range(self.depth)]

        def attn_decoder_input_fn(inputs, attention):
            if not self.attn_input_feeding:
                return inputs
            # Essential when use_residual=True
            _input_layer = Dense(self.hidden_units * 2, dtype=self.dtype, name='attn_input_feeding')
            return _input_layer(tf.concat([inputs, attention], -1))

        # AttentionWrapper wraps RNNCell with the attention_mechanism
        # Note: We implement Attention mechanism only on the top decoder layer
        self.decoder_cell_list[-1] = seq2seq.AttentionWrapper(
            cell=self.decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=encoder_last_state[-1],
            alignment_history=False,
            name='Attention_wrapper'
        )

        # To be compatible with AttentionWrapper, the encoder last state
        # of the top layer should be converted into the AttentionWrapperState form
        # We can easily do this by calling AttentionWrapper.zero_state

        batch_size = self.batch_size
        initial_state = [state for state in encoder_last_state]
        initial_state[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=batch_size, dtype=self.dtype
        )
        decoder_initial_state = tuple(initial_state)
        return rnn.MultiRNNCell(self.decoder_cell_list), decoder_initial_state
Example 20
    def _build_decoder_train(self):

        self._labels_embedded = tf.nn.embedding_lookup(self._embedding_matrix,
                                                       self._labels_padded_GO)

        self._helper_train = seq2seq.ScheduledEmbeddingTrainingHelper(
            inputs=self._labels_embedded,
            sequence_length=self._labels_len,
            embedding=self._embedding_matrix,
            sampling_probability=self._sampling_probability_outputs,
        )

        if self._hparams.enable_attention is True:
            attention_mechanisms, layer_sizes = self._create_attention_mechanisms(
            )

            attention_cells = seq2seq.AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanisms,
                attention_layer_size=layer_sizes,
                initial_cell_state=self._decoder_initial_state,
                alignment_history=False,
                output_attention=self._output_attention,
            )
            batch_size, _ = tf.unstack(tf.shape(self._labels))

            attn_zero = attention_cells.zero_state(dtype=self._hparams.dtype,
                                                   batch_size=batch_size)
            initial_state = attn_zero.clone(
                cell_state=self._decoder_initial_state)

            cells = attention_cells
        else:
            cells = self._decoder_cells
            initial_state = self._decoder_initial_state

        self._decoder_train = seq2seq.BasicDecoder(
            cell=cells,
            helper=self._helper_train,
            initial_state=initial_state,
            output_layer=self._dense_layer,
        )

        self._basic_decoder_train_outputs, self._final_states, self._final_seq_lens = seq2seq.dynamic_decode(
            self._decoder_train,
            output_time_major=False,
            impute_finished=True,
            swap_memory=False,
        )

        self._logits = self._basic_decoder_train_outputs.rnn_output
Example 21
    def _init_encoder(self):
        with tf.variable_scope("Encoder") as scope:

            encoder_inputs = self._maybe_add_dense_layers()

            if self._hparams.encoder_type == 'unidirectional':
                self._encoder_cells = build_rnn_layers(
                    cell_type=self._hparams.cell_type,
                    num_units_per_layer=self._num_units_per_layer,
                    use_dropout=self._hparams.use_dropout,
                    dropout_probability=self._hparams.dropout_probability,
                    mode=self._mode,
                    as_list=True,
                    dtype=self._hparams.dtype)

                attention_mechanism, output_attention = create_attention_mechanism(
                    attention_type=self._hparams.attention_type[0][0],
                    num_units=self._num_units_per_layer[-1],
                    memory=self._attended_memory,
                    memory_sequence_length=self._attended_memory_length,
                    mode=self._mode,
                    dtype=self._hparams.dtype)

                attention_cells = seq2seq.AttentionWrapper(
                    cell=self._encoder_cells[-1],
                    attention_mechanism=attention_mechanism,
                    attention_layer_size=self._hparams.
                    decoder_units_per_layer[-1],
                    alignment_history=self._hparams.write_attention_alignment,
                    output_attention=output_attention,
                )

                self._encoder_cells[-1] = attention_cells

                self._encoder_outputs, self._encoder_final_state = tf.nn.dynamic_rnn(
                    cell=MultiRNNCell(self._encoder_cells),
                    inputs=encoder_inputs,
                    sequence_length=self._inputs_len,
                    parallel_iterations=self._hparams.
                    batch_size[0 if self._mode == 'train' else 1],
                    swap_memory=False,
                    dtype=self._hparams.dtype,
                    scope=scope,
                )

                if self._hparams.write_attention_alignment is True:
                    self.attention_summary = self._create_attention_alignments_summary(
                        self._encoder_final_state[-1])
Example 22
    def _build_decoder_greedy(self):

        batch_size, _ = tf.unstack(tf.shape(self._labels))
        self._helper_greedy = seq2seq.GreedyEmbeddingHelper(
            embedding=self._embedding_matrix,
            start_tokens=tf.tile([self._GO_ID], [batch_size]),
            end_token=self._EOS_ID)

        if self._hparams.enable_attention is True:
            attention_mechanisms, layer_sizes = self._create_attention_mechanisms()

            attention_cells = seq2seq.AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanisms,
                attention_layer_size=layer_sizes,
                initial_cell_state=self._decoder_initial_state,
                alignment_history=self._hparams.write_attention_alignment,
                output_attention=self._output_attention
            )
            attn_zero = attention_cells.zero_state(
                dtype=self._hparams.dtype, batch_size=batch_size
            )
            initial_state = attn_zero.clone(
                cell_state=self._decoder_initial_state
            )
            cells = attention_cells
        else:
            cells = self._decoder_cells
            initial_state = self._decoder_initial_state

        self._decoder_inference = seq2seq.BasicDecoder(
            cell=cells,
            helper=self._helper_greedy,
            initial_state=initial_state,
            output_layer=self._dense_layer)

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=True,
            swap_memory=False,
            maximum_iterations=self._hparams.max_label_length)

        # self._result = outputs, states, lengths
        self.inference_outputs = outputs.rnn_output
        self.inference_predicted_ids = outputs.sample_id

        if self._hparams.write_attention_alignment is True:
            self.attention_summary = self._create_attention_alignments_summary(states, )
Example 23
    def _decoder_cell(self):
        batch_size, _ = tf.unstack(tf.shape(self._targets))

        attention = seq2seq.BahdanauAttention(
            num_units=2 * self.CELL_SIZE,
            memory=self._targets_encoder_outputs,
            memory_sequence_length=self._targets_length)

        attentive_cell = seq2seq.AttentionWrapper(
            cell=rnn.GRUCell(2 * self.CELL_SIZE, activation=tf.nn.tanh),
            attention_mechanism=attention,
            attention_layer_size=2 * self.CELL_SIZE,
            initial_cell_state=self._targets_encoder_state)

        return (
            attentive_cell,
            attentive_cell.zero_state(batch_size, tf.float32),
        )
Example 24
    def _build_decoder_cell(self, encoder_outputs, encoder_state,
                            source_sequence_length):
        beam_width = self.hparams.beam_width
        memory = encoder_outputs
        if self.hparams.time_major:
            memory = tf.transpose(encoder_outputs, [1, 0, 2])

        if self.mode == PREDICT and beam_width > 0:
            memory = seq2seq.tile_batch(memory, beam_width)
            source_sequence_length = seq2seq.tile_batch(
                source_sequence_length, beam_width)
            encoder_state = seq2seq.tile_batch(encoder_state, beam_width)
            batch_size = self.batch_size * beam_width
        else:
            batch_size = self.batch_size

        # Use Attention Mechanism
        attention_mechanism = seq2seq.LuongAttention(
            num_units=self.hparams.num_units,
            memory=memory,
            memory_sequence_length=source_sequence_length)

        cell = model_helper.build_rnn_cell(
            self.hparams.unit_type,
            self.hparams.num_units,
            self.hparams.num_layers,
            self.hparams.dropout,
        )
        alignment_history = (self.mode == PREDICT and beam_width == 0)

        cell = seq2seq.AttentionWrapper(
            cell,
            attention_mechanism,
            attention_layer_size=self.hparams.num_units,
            alignment_history=alignment_history,
            name='attention')
        if self.hparams.pass_hidden_state:
            initial_state = cell.zero_state(
                batch_size, tf.float32).clone(cell_state=encoder_state)
        else:
            initial_state = cell.zero_state(batch_size, tf.float32)

        return cell, initial_state
Example 25
    def _decoder(self, keep_prob, encoder_output, encoder_state, batch_size, scope, helper, reuse=None):
        with tf.variable_scope(scope, reuse=reuse):
            attention_states = encoder_output
            cell = rnn.MultiRNNCell([self._cell(keep_prob) for _ in range(self.lstm_dims)])
            attention_mechanism = seq2seq.BahdanauAttention(self.hidden_size, attention_states)  # attention
            decoder_cell = seq2seq.AttentionWrapper(cell, attention_mechanism,
                                                    attention_layer_size=self.hidden_size // 2)
            decoder_cell = rnn.OutputProjectionWrapper(decoder_cell, self.hidden_size, reuse=reuse,
                                                       activation=tf.nn.leaky_relu)
            decoder_initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=encoder_state)

            output_layer = tf.layers.Dense(self.num_words,
                                           kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                           activation=tf.nn.leaky_relu)
            decoder = seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state, output_layer=output_layer)
            output, _, _ = seq2seq.dynamic_decode(decoder, maximum_iterations=self.max_sentence_length,
                                                  impute_finished=True)

            # tf.summary.histogram('decoder', output)
        return output
Example 26
def decoding_layer(dec_embed_input, embeddings, enc_output, enc_state,
                   vocab_size, text_length, summary_length, max_summary_length,
                   rnn_size, vocab_to_int, keep_prob, batch_size, num_layers):

    for layer in range(num_layers):
        with tf.variable_scope('decoder_{}'.format(layer)):
            lstm = tf.nn.rnn_cell.LSTMCell(
                rnn_size,
                initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
            dec_cell = tf.nn.rnn_cell.DropoutWrapper(lstm,
                                                     input_keep_prob=keep_prob)
    # Fully connected output layer
    output_layer = Dense(vocab_size,
                         kernel_initializer=tf.truncated_normal_initializer(
                             mean=0.0, stddev=0.1))

    attn_mech = seq.BahdanauAttention(rnn_size,
                                      enc_output,
                                      text_length,
                                      normalize=False,
                                      name='BahdanauAttention')

    dec_cell = seq.AttentionWrapper(cell=dec_cell,
                                    attention_mechanism=attn_mech,
                                    attention_layer_size=rnn_size)

    # Introduce the attention mechanism
    initial_state = seq.AttentionWrapperState(
        enc_state[0], _zero_state_tensors(rnn_size, batch_size, tf.float32))

    with tf.variable_scope("decode"):
        training_logits = training_decoding_layer(dec_embed_input,
                                                  summary_length, dec_cell,
                                                  initial_state, output_layer,
                                                  vocab_size,
                                                  max_summary_length)
    with tf.variable_scope("decode", reuse=True):
        inference_logits = inference_decoding_layer(
            embeddings, vocab_to_int['<GO>'], vocab_to_int['<EOS>'], dec_cell,
            initial_state, output_layer, max_summary_length, batch_size)
    return training_logits, inference_logits
Example 27
    def decoding_layer_train(self, num_units, max_time, batch_size, char2numY,
                             data_output_embed, encoder_output, last_state,
                             bidirectional):
        if not bidirectional:
            decoder_cell = rnn.LSTMCell(num_units)
        else:
            decoder_cell = rnn.LSTMCell(2 * num_units)
        training_helper = seq2seq.TrainingHelper(inputs=data_output_embed,
                                                 sequence_length=[max_time] *
                                                 batch_size,
                                                 time_major=False)

        attention_mechanism = seq2seq.BahdanauAttention(
            num_units=num_units,
            memory=encoder_output,
            memory_sequence_length=[max_time] * batch_size)

        attention_cell = seq2seq.AttentionWrapper(
            cell=decoder_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=num_units)

        decoder_initial_state = attention_cell.zero_state(
            batch_size=batch_size,
            dtype=tf.float32).clone(cell_state=last_state)
        output_layer = tf.layers.Dense(
            len(char2numY) - 2,
            kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                               stddev=0.1))
        training_decoder = seq2seq.BasicDecoder(
            cell=attention_cell,
            helper=training_helper,
            initial_state=decoder_initial_state,
            output_layer=output_layer)

        train_outputs, _, _ = seq2seq.dynamic_decode(
            decoder=training_decoder,
            impute_finished=True,
            maximum_iterations=max_time)

        return train_outputs
Example 28
    def wrap_att(self,
                 dec_cell,
                 lstm_size,
                 enc_output,
                 lengths,
                 alignment_history=False):
        """
        Wrap a decoder cell in an attention cell, using global Luong attention as in the paper.
        """
        attention_mechanism = s2s.LuongAttention(
            num_units=lstm_size,
            memory=enc_output,
            memory_sequence_length=lengths,
            name='LuongAttention')

        # wrap as a seq2seq AttentionWrapper
        return s2s.AttentionWrapper(cell=dec_cell,
                                    attention_mechanism=attention_mechanism,
                                    attention_layer_size=None,
                                    output_attention=False,
                                    alignment_history=alignment_history)
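With `attention_layer_size=None` and `output_attention=False`, the wrapper above emits the raw decoder-cell output and keeps the unprojected context vector in its state rather than in the output. A minimal sketch of reading that state (TF 1.x; the placeholder names and the 128-unit size are assumptions for illustration):

    import tensorflow as tf
    from tensorflow.contrib import seq2seq as s2s

    lstm_size = 128
    enc_output = tf.placeholder(tf.float32, [None, None, lstm_size])
    lengths = tf.placeholder(tf.int32, [None])
    dec_inputs = tf.placeholder(tf.float32, [None, None, lstm_size])

    mech = s2s.LuongAttention(lstm_size, enc_output,
                              memory_sequence_length=lengths)
    cell = s2s.AttentionWrapper(tf.nn.rnn_cell.LSTMCell(lstm_size), mech,
                                attention_layer_size=None,
                                output_attention=False)

    # outputs are the plain LSTM outputs; the context vector for the last step
    # is available on the AttentionWrapperState.
    outputs, final_state = tf.nn.dynamic_rnn(cell, dec_inputs, dtype=tf.float32)
    last_context = final_state.attention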
Example 29
def create_decoder_cell(agenda, base_sent_embeds, insert_word_embeds, delete_word_embeds,
                        base_length, iw_length, dw_length,
                        attn_dim, hidden_dim, num_layer,
                        enable_alignment_history=False, enable_dropout=False, dropout_keep=0.1,
                        no_insert_delete_attn=False):
    base_attn = seq2seq.BahdanauAttention(attn_dim, base_sent_embeds, base_length, name='src_attn')
    attns = [base_attn]
    if not no_insert_delete_attn:
        insert_attn = seq2seq.BahdanauAttention(attn_dim, insert_word_embeds, iw_length, name='insert_attn')
        delete_attn = seq2seq.BahdanauAttention(attn_dim, delete_word_embeds, dw_length, name='delete_attn')
        attns += [insert_attn, delete_attn]

    if no_insert_delete_attn:
        assert len(attns) == 1
    else:
        assert len(attns) == 3

    bottom_cell = tf_rnn.LSTMCell(hidden_dim, name='bottom_cell')
    bottom_attn_cell = seq2seq.AttentionWrapper(
        bottom_cell,
        tuple(attns),
        output_attention=False,
        alignment_history=enable_alignment_history,
        name='att_bottom_cell'
    )

    all_cells = [bottom_attn_cell]

    num_layer -= 1
    for i in range(num_layer):
        cell = tf_rnn.LSTMCell(hidden_dim, name='layer_%s' % (i + 1))
        if enable_dropout and dropout_keep < 1.:
            cell = tf_rnn.DropoutWrapper(cell, output_keep_prob=dropout_keep)

        all_cells.append(cell)

    decoder_cell = AttentionAugmentRNNCell(all_cells)
    decoder_cell.set_agenda(agenda)

    return decoder_cell
Example 30
    def build_model(self):
        encoder = self.encoder
        inputs = self.inputs
        with tf.variable_scope('encoder'):
            t_sequence = tf.unstack(inputs, axis=1, name='TimeMajorInputs')
            outputs, _, _ = tf.nn.static_bidirectional_rnn(cell_fw=encoder,
                                                           cell_bw=encoder,
                                                           inputs=t_sequence,
                                                           dtype=inputs.dtype)
        with tf.variable_scope('decoder'):
            with tf.name_scope('attention'):
                memory = tf.stack(outputs,
                                  axis=1,
                                  name='BatchMajorAnnotations')
                self.bahdanau = seq2seq.BahdanauAttention(self.attention_size,
                                                          memory=memory)

            raw_decoder = self.decoder
            decoder_cell = seq2seq.AttentionWrapper(raw_decoder,
                                                    self.bahdanau,
                                                    output_attention=False)
            self.decoder_cell = decoder_cell