Example #1
 def inference_decode_layer(self, start_token, dec_cell, end_token,
                            output_layer):
     start_tokens = tf.tile(tf.constant([start_token], dtype=tf.int32),
                            [self.batch_size],
                            name='start_token')
     tiled_enc_output = seq2seq.tile_batch(self.enc_output,
                                           multiplier=self.Beam_width)
     tiled_enc_state = seq2seq.tile_batch(self.enc_state,
                                          multiplier=self.Beam_width)
     tiled_source_len = seq2seq.tile_batch(self.source_len,
                                           multiplier=self.Beam_width)
     atten_mech = seq2seq.BahdanauAttention(self.hidden_dim * 2,
                                            tiled_enc_output,
                                            tiled_source_len,
                                            normalize=True)
     decoder_att = seq2seq.AttentionWrapper(dec_cell, atten_mech,
                                            self.hidden_dim * 2)
     initial_state = decoder_att.zero_state(
         self.batch_size * self.Beam_width,
         tf.float32).clone(cell_state=tiled_enc_state)
     decoder = seq2seq.BeamSearchDecoder(decoder_att,
                                         self.embeddings,
                                         start_tokens,
                                         end_token,
                                         initial_state,
                                         beam_width=self.Beam_width,
                                         output_layer=output_layer)
     infer_logits, _, _ = seq2seq.dynamic_decode(
         decoder,
         output_time_major=False,
         impute_finished=False,
         maximum_iterations=self.max_target_len)
     return infer_logits
Example #2
    def build_decoder_cell(self):

        encoder_outputs = self.encoder_outputs
        encoder_last_state = self.encoder_last_state
        encoder_inputs_length = self.encoder_inputs_length

        if self.use_beamsearch_decode:
            print ("use beamsearch decoding..")
            encoder_outputs = seq2seq.tile_batch(
                self.encoder_outputs, multiplier=self.beam_width)
            encoder_last_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, self.beam_width), self.encoder_last_state)
            encoder_inputs_length = seq2seq.tile_batch(
                self.encoder_inputs_length, multiplier=self.beam_width)

        # Building attention mechanism: default is Bahdanau
        # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473
        # 'Luong' style attention: https://arxiv.org/abs/1508.04025
        if self.attention_type.lower() == 'luong':
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_units, memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_length)
        else:
            self.attention_mechanism = attention_wrapper.BahdanauAttention(
                num_units=self.hidden_units, memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_length)

        # Building decoder_cell
        self.decoder_cell_list = [
            self.build_single_cell() for i in range(self.depth)]
        decoder_initial_state = encoder_last_state

        def attn_decoder_input_fn(inputs, attention):
            if not self.attn_input_feeding:
                return inputs

            # Essential when use_residual=True
            _input_layer = Dense(self.hidden_units, dtype=self.dtype,
                                 name='attn_input_feeding')
            return _input_layer(array_ops.concat([inputs, attention], -1))

        # AttentionWrapper wraps RNNCell with the attention_mechanism
        # Note: We implement Attention mechanism only on the top decoder layer
        self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
            cell=self.decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=encoder_last_state[-1],
            alignment_history=False,
            name='Attention_Wrapper')

        batch_size = self.batch_size if not self.use_beamsearch_decode \
                     else self.batch_size * self.beam_width
        initial_state = [state for state in encoder_last_state]

        initial_state[-1] = self.decoder_cell_list[-1].zero_state(
          batch_size=batch_size, dtype=self.dtype)
        decoder_initial_state = tuple(initial_state)

        return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
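
A note on the attn_input_feeding pattern above: concatenating the attention vector onto the decoder input doubles its width, so a Dense projection maps it back to hidden_units, which is what keeps dimensions compatible when use_residual=True. A minimal sketch of just that projection (TensorFlow 1.x assumed; all shapes are hypothetical):

import tensorflow as tf

hidden_units = 8
inputs = tf.random_normal([2, hidden_units])     # hypothetical decoder inputs
attention = tf.random_normal([2, hidden_units])  # hypothetical attention vector

# Project [inputs; attention] back to hidden_units so residual connections
# keep matching dimensions.
projection = tf.layers.Dense(hidden_units, name='attn_input_feeding')
fed_inputs = projection(tf.concat([inputs, attention], axis=-1))  # [2, hidden_units]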
Example #3
    def _build_decoder_cell(self, enc_outputs, enc_state):
        beam_size = self.config.beam_size
        context_length = self.source_length
        memory = enc_outputs

        if self.mode == ModelMode.infer and beam_size > 0:
            enc_state = tc_seq2seq.tile_batch(enc_state,
                                              multiplier=beam_size)

            memory = tc_seq2seq.tile_batch(memory,
                                           multiplier=beam_size)

            context_length = tc_seq2seq.tile_batch(context_length,
                                                   multiplier=beam_size)

            batch_size = self.batch_size * beam_size

        else:
            batch_size = self.batch_size

        dec_cell = get_rnn_cell(self.config.unit_type,
                                hidden_size=self.config.dec_hidden_size,
                                num_layers=self.config.num_layers,
                                dropout_keep_prob=self.dropout_keep_prob)

        return dec_cell, enc_state
Example #4
    def _create_attention_mechanisms(self, beam_search=False):
        r"""
        Creates a list of attention mechanisms (e.g. seq2seq.BahdanauAttention)
        and also a list of ints holding the attention projection layer size
        Args:
            beam_search: `bool`, whether the beam-search decoding algorithm is used or not
        """
        mechanisms = []
        layer_sizes = []

        if beam_search is True:
            encoder_memory = seq2seq.tile_batch(
                self._encoder_memory, multiplier=self._hparams.beam_width)

            encoder_features_len = seq2seq.tile_batch(
                self._encoder_features_len, multiplier=self._hparams.beam_width)

        else:
            encoder_memory = self._encoder_memory
            encoder_features_len = self._encoder_features_len

        for attention_type in self._hparams.attention_type[0]:

            attention = self._create_attention_mechanism(
                num_units=self._hparams.decoder_units_per_layer[-1],
                memory=encoder_memory,
                memory_sequence_length=encoder_features_len,
                attention_type=attention_type
            )
            mechanisms.append(attention)
            layer_sizes.append(self._hparams.decoder_units_per_layer[-1])

        return mechanisms, layer_sizes
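
For context, tf.contrib.seq2seq's AttentionWrapper accepts a list of attention mechanisms together with a matching list of attention layer sizes, which is what the mechanisms/layer_sizes pair returned above is for. A minimal self-contained sketch of that usage (TensorFlow 1.x assumed; all shapes are hypothetical):

import tensorflow as tf
from tensorflow.contrib import seq2seq

units, beam_width = 8, 4
memory = tf.random_normal([2, 6, units])  # hypothetical encoder memory
memory_len = tf.constant([6, 4])

# Tile the memory for beam search, as in the example above.
tiled_mem = seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_len = seq2seq.tile_batch(memory_len, multiplier=beam_width)

mechanisms = [
    seq2seq.BahdanauAttention(units, tiled_mem, memory_sequence_length=tiled_len),
    seq2seq.LuongAttention(units, tiled_mem, memory_sequence_length=tiled_len),
]
layer_sizes = [units, units]

# One wrapper, several mechanisms: attention_layer_size must be a list of
# the same length as attention_mechanism.
cell = seq2seq.AttentionWrapper(tf.nn.rnn_cell.GRUCell(units),
                                attention_mechanism=mechanisms,
                                attention_layer_size=layer_sizes)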
Example #5
    def _build_decoder_cell(self, enc_outputs, enc_state):
        beam_size = self.config.beam_size  # e.g. 5; the beam width for beam search
        context_length = self.source_length
        memory = enc_outputs

        # Beam search is only performed at inference time
        if self.mode == ModelMode.infer and beam_size > 0:
            # tile_batch is a helper from the beam_search_decoder module
            enc_state = tc_seq2seq.tile_batch(enc_state, multiplier=beam_size)

            memory = tc_seq2seq.tile_batch(memory, multiplier=beam_size)

            context_length = tc_seq2seq.tile_batch(context_length,
                                                   multiplier=beam_size)

        else:
            batch_size = self.batch_size

        dec_cell = get_rnn_cell(
            self.config.unit_type,  # lstm
            hidden_size=self.config.dec_hidden_size,  # 300
            num_layers=1,  # 1
            dropout_keep_prob=self.dropout_keep_prob)

        return dec_cell, enc_state
Example #6
    def build_decoder_cell(self):
        encoder_inputs_length = self.encoder_inputs_length
        if self.beam_search:
            print("use beamsearch decoding..")
            self.encoder_outputs = tile_batch(self.encoder_outputs,
                                              multiplier=self.beam_size)
            self.encoder_state = nest.map_structure(
                lambda s: tile_batch(s, self.beam_size), self.encoder_state)
            encoder_inputs_length = tile_batch(encoder_inputs_length,
                                               multiplier=self.beam_size)

        # Define the attention mechanism to use.
        attention_mechanism = BahdanauAttention(
            num_units=self.rnn_size,
            memory=self.encoder_outputs,
            memory_sequence_length=encoder_inputs_length)

        # Define the RNNCell for the decoder stage, then wrap it with an attention wrapper
        decoder_cell = self.create_rnn_cell()
        decoder_cell = AttentionWrapper(
            cell=decoder_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=self.rnn_size,
            name='Attention_Wrapper')

        batch_size = self.batch_size if not self.beam_search else self.batch_size * self.beam_size

        decoder_initial_state = decoder_cell.zero_state(
            batch_size=batch_size,
            dtype=tf.float32).clone(cell_state=self.encoder_state)

        return decoder_cell, decoder_initial_state
Example #7
    def getBeamSearchDecoderCell(self, encoder_outputs, encoder_final_states):
        basic_cells = [self.get_basicLSTMCell() for i in range(layer_num)]
        basic_cell = tf.nn.rnn_cell.MultiRNNCell(basic_cells)
        tiled_encoder_outputs = seq2seq.tile_batch(encoder_outputs,
                                                   multiplier=beam_size)
        tiled_encoder_final_states = [
            seq2seq.tile_batch(state, multiplier=beam_size)
            for state in encoder_final_states
        ]
        tiled_sequence_length = seq2seq.tile_batch(self.enc_len,
                                                   multiplier=beam_size)
        initial_state = tuple(tiled_encoder_final_states)
        # attention
        attention_mechanism = seq2seq.BahdanauAttention(
            num_units=num_units,
            memory=tiled_encoder_outputs,
            memory_sequence_length=tiled_sequence_length)
        att_cell = seq2seq.AttentionWrapper(
            basic_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=num_units,
            alignment_history=False,
            cell_input_fn=None,
            initial_cell_state=initial_state)

        initial_state = att_cell.zero_state(
            batch_size=tf.shape(self.enc_in)[0] * beam_size, dtype=tf.float32)
        #            att_state.clone(cell_state=encoder_final_state)

        return att_cell, initial_state
Example #8
    def _get_beam_search_cell(self, beam_width):
        """Returns the RNN cell for beam search decoding.
        """
        with tf.variable_scope(self.variable_scope, reuse=True):
            attn_kwargs = copy.copy(self._attn_kwargs)

            memory = attn_kwargs['memory']
            attn_kwargs['memory'] = tile_batch(memory, multiplier=beam_width)

            memory_seq_length = attn_kwargs['memory_sequence_length']
            if memory_seq_length is not None:
                attn_kwargs['memory_sequence_length'] = tile_batch(
                    memory_seq_length, beam_width)

            attn_modules = ['tensorflow.contrib.seq2seq', 'texar.tf.custom']
            bs_attention_mechanism = utils.check_or_get_instance(
                self._hparams.attention.type,
                attn_kwargs,
                attn_modules,
                classtype=tf.contrib.seq2seq.AttentionMechanism)

            bs_attn_cell = AttentionWrapper(self._cell._cell,
                                            bs_attention_mechanism,
                                            cell_input_fn=self._cell_input_fn,
                                            **self._attn_cell_kwargs)

            self._beam_search_cell = bs_attn_cell

            return bs_attn_cell
Example #9
def beam_eval_decoder(agenda,
                      embeddings,
                      extended_base_words,
                      oov,
                      start_token_id,
                      stop_token_id,
                      base_sent_hiddens,
                      base_length,
                      vocab_size,
                      attn_dim,
                      hidden_dim,
                      num_layer,
                      max_sentence_length,
                      beam_width,
                      swap_memory,
                      enable_dropout=False,
                      dropout_keep=1.,
                      no_insert_delete_attn=False):
    with tf.variable_scope(OPS_NAME, 'decoder', reuse=True):
        true_batch_size = tf.shape(base_sent_hiddens)[0]

        tiled_agenda = seq2seq.tile_batch(agenda, beam_width)
        tiled_extended_base_words = seq2seq.tile_batch(extended_base_words,
                                                       beam_width)
        tiled_oov = seq2seq.tile_batch(oov, beam_width)

        tiled_base_sent = seq2seq.tile_batch(base_sent_hiddens, beam_width)
        tiled_base_lengths = seq2seq.tile_batch(base_length, beam_width)

        start_token_id = tf.cast(start_token_id, tf.int32)
        stop_token_id = tf.cast(stop_token_id, tf.int32)

        cell, zero_states = create_decoder_cell(
            tiled_agenda,
            tiled_extended_base_words,
            tiled_oov,
            tiled_base_sent,
            tiled_base_lengths,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_layer,
            enable_dropout=enable_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn,
            beam_width=beam_width)

        decoder = seq2seq.BeamSearchDecoder(cell,
                                            create_embedding_fn(vocab_size),
                                            tf.fill([true_batch_size],
                                                    start_token_id),
                                            stop_token_id,
                                            zero_states,
                                            beam_width=beam_width,
                                            length_penalty_weight=0.0)

        return seq2seq.dynamic_decode(decoder,
                                      maximum_iterations=max_sentence_length,
                                      swap_memory=swap_memory)
Example #10
    def _build_decoder_cell(self):
        # no beam
        encoder_outputs = self.encoder_outputs
        encoder_last_state = self.encoder_last_state
        encoder_inputs_length = self.encoder_inputs_length

        def attn_decoder_input_fn(inputs, attention):
            if not self.attn_input_feeding:
                return inputs
            _input_layer = Dense(self.hidden_units, dtype=self.dtype, name="attn_input_feeding")
            return _input_layer(array_ops.concat([inputs, attention], -1))
        
        # attention mechanism: 'luong'
        with tf.variable_scope('shared_attention_mechanism'):
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_units,
                memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_length)
        # build decoder cell
        self.init_decoder_cell_list = [self._build_single_cell() for i in range(self.depth)]
        decoder_initial_state = encoder_last_state
        
        self.decoder_cell_list = self.init_decoder_cell_list[:-1] + [
            attention_wrapper.AttentionWrapper(
                cell=self.init_decoder_cell_list[-1],
                attention_mechanism=self.attention_mechanism,
                attention_layer_size=self.hidden_units,
                cell_input_fn=attn_decoder_input_fn,
                initial_cell_state=encoder_last_state[-1],
                alignment_history=False)]
        batch_size = self.batch_size
        initial_state = [state for state in encoder_last_state]
        initial_state[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=batch_size, dtype=self.dtype)
        decoder_initial_state = tuple(initial_state)
        
        # beam
        beam_encoder_outputs = seq2seq.tile_batch(
            self.encoder_outputs, multiplier=self.beam_width)
        beam_encoder_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.beam_width),
            self.encoder_last_state)
        beam_encoder_inputs_length = seq2seq.tile_batch(
            self.encoder_inputs_length, multiplier=self.beam_width)

        with tf.variable_scope('shared_attention_mechanism', reuse=True):
            self.beam_attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_units,
                memory=beam_encoder_outputs,
                memory_sequence_length=beam_encoder_inputs_length)

        beam_decoder_initial_state = beam_encoder_last_state
        self.beam_decoder_cell_list = self.init_decoder_cell_list[:-1] + [
            attention_wrapper.AttentionWrapper(
                cell=self.init_decoder_cell_list[-1],
                attention_mechanism=self.beam_attention_mechanism,
                attention_layer_size=self.hidden_units,
                cell_input_fn=attn_decoder_input_fn,
                initial_cell_state=beam_encoder_last_state[-1],
                alignment_history=False)]
            
        beam_batch_size = self.batch_size * self.beam_width
        beam_initial_state = [state for state in beam_encoder_last_state]
        beam_initial_state[-1] = self.beam_decoder_cell_list[-1].zero_state(
            batch_size=beam_batch_size, dtype=self.dtype)
        beam_decoder_initial_state = tuple(beam_initial_state)
        
        return MultiRNNCell(self.decoder_cell_list), decoder_initial_state, \
               MultiRNNCell(self.beam_decoder_cell_list), beam_decoder_initial_state
Example #11
def decoder(x, decoder_inputs, keep_prob, sequence_length, memory,
            memory_length, first_attention):
    with tf.variable_scope("Decoder") as scope:
        label_embeddings = tf.get_variable(name="embeddings",
                                           shape=[n_classes, embedding_size],
                                           dtype=tf.float32)
        train_inputs_embedded = tf.nn.embedding_lookup(label_embeddings,
                                                       decoder_inputs)
        lstm = rnn.LayerNormBasicLSTMCell(n_hidden,
                                          dropout_keep_prob=keep_prob)
        output_l = layers_core.Dense(n_classes, use_bias=True)
        encoder_state = rnn.LSTMStateTuple(x, x)
        attention_mechanism = BahdanauAttention(
            embedding_size,
            memory=memory,
            memory_sequence_length=memory_length)
        cell = AttentionWrapper(lstm,
                                attention_mechanism,
                                output_attention=False)
        cell_state = cell.zero_state(dtype=tf.float32,
                                     batch_size=train_batch_size)
        cell_state = cell_state.clone(cell_state=encoder_state,
                                      attention=first_attention)
        train_helper = TrainingHelper(train_inputs_embedded, sequence_length)
        train_decoder = BasicDecoder(cell,
                                     train_helper,
                                     cell_state,
                                     output_layer=output_l)
        decoder_outputs_train, decoder_state_train, decoder_seq_train = dynamic_decode(
            train_decoder, impute_finished=True)
        tiled_inputs = tile_batch(memory, multiplier=beam_width)
        tiled_sequence_length = tile_batch(memory_length,
                                           multiplier=beam_width)
        tiled_first_attention = tile_batch(first_attention,
                                           multiplier=beam_width)
        attention_mechanism = BahdanauAttention(
            embedding_size,
            memory=tiled_inputs,
            memory_sequence_length=tiled_sequence_length)
        x2 = tile_batch(x, beam_width)
        encoder_state2 = rnn.LSTMStateTuple(x2, x2)
        cell = AttentionWrapper(lstm,
                                attention_mechanism,
                                output_attention=False)
        cell_state = cell.zero_state(dtype=tf.float32,
                                     batch_size=test_batch_size * beam_width)
        cell_state = cell_state.clone(cell_state=encoder_state2,
                                      attention=tiled_first_attention)
        infer_decoder = BeamSearchDecoder(cell,
                                          embedding=label_embeddings,
                                          start_tokens=[GO] * test_len,
                                          end_token=EOS,
                                          initial_state=cell_state,
                                          beam_width=beam_width,
                                          output_layer=output_l)
        decoder_outputs_infer, decoder_state_infer, decoder_seq_infer = dynamic_decode(
            infer_decoder, maximum_iterations=4)
        return decoder_outputs_train, decoder_outputs_infer, decoder_state_infer
Example #12
    def build_dec_cell(self, hidden_size):
        enc_outputs = self.enc_outputs
        enc_last_state = self.enc_last_state
        enc_inputs_length = self.enc_inp_len

        if self.use_beam_search:
            self.logger.info("using beam search decoding")
            enc_outputs = seq2seq.tile_batch(self.enc_outputs,
                                             multiplier=self.p.beam_width)
            enc_last_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, self.p.beam_width),
                self.enc_last_state)
            enc_inputs_length = seq2seq.tile_batch(self.enc_inp_len,
                                                   self.p.beam_width)

        if self.p.attention_type.lower() == 'luong':
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=hidden_size,
                memory=enc_outputs,
                memory_sequence_length=enc_inputs_length)
        else:
            self.attention_mechanism = attention_wrapper.BahdanauAttention(
                num_units=hidden_size,
                memory=enc_outputs,
                memory_sequence_length=enc_inputs_length)

        def attn_dec_input_fn(inputs, attention):
            if not self.p.attn_input_feeding:
                return inputs
            else:
                _input_layer = Dense(hidden_size,
                                     dtype=self.p.dtype,
                                     name='attn_input_feeding')
                return _input_layer(tf.concat([inputs, attention], -1))

        self.dec_cell_list = [
            self.build_single_cell(hidden_size) for _ in range(self.p.depth)
        ]

        if self.p.use_attn:
            self.dec_cell_list[-1] = attention_wrapper.AttentionWrapper(
                cell=self.dec_cell_list[-1],
                attention_mechanism=self.attention_mechanism,
                attention_layer_size=hidden_size,
                cell_input_fn=attn_dec_input_fn,
                initial_cell_state=enc_last_state[-1],
                alignment_history=False,
                name='attention_wrapper')

        batch_size = self.p.batch_size if not self.use_beam_search else self.p.batch_size * self.p.beam_width
        initial_state = [state for state in enc_last_state]
        if self.p.use_attn:
            initial_state[-1] = self.dec_cell_list[-1].zero_state(
                batch_size=batch_size, dtype=self.p.dtype)
        dec_initial_state = tuple(initial_state)

        return MultiRNNCell(self.dec_cell_list), dec_initial_state
Example #13
    def setup_decoder_cell(self, config, keep_prob, use_beam_search,
                           init_state, attention_states, attention_lengths):
        batch_size = get_state_shape(init_state)[0]
        if use_beam_search:
            attention_states = tile_batch(attention_states,
                                          multiplier=self.beam_width)
            init_state = nest.map_structure(
                lambda s: tile_batch(s, self.beam_width), init_state)
            attention_lengths = tile_batch(attention_lengths,
                                           multiplier=self.beam_width)
            batch_size = batch_size * self.beam_width

        attention_size = shape(attention_states, -1)
        attention = getattr(tf.contrib.seq2seq, config.attention_type)(
            attention_size,
            attention_states,
            memory_sequence_length=attention_lengths)

        def cell_input_fn(inputs, attention):
            # define cell input function to keep input/output dimension same
            if not config.use_attention_input_feeding:
                return inputs
            attn_project = tf.layers.Dense(config.hidden_size,
                                           dtype=tf.float32,
                                           name='attn_input_feeding',
                                           activation=self.activation)
            return attn_project(tf.concat([inputs, attention], axis=-1))

        cells = _setup_decoder_cell(config, keep_prob)
        if config.top_attention:  # apply attention mechanism only on the top decoder layer
            cells[-1] = AttentionWrapper(
                cells[-1],
                attention_mechanism=attention,
                name="AttentionWrapper",
                attention_layer_size=config.hidden_size,
                alignment_history=use_beam_search,
                initial_cell_state=init_state[-1],
                cell_input_fn=cell_input_fn)
            init_state = [state for state in init_state]
            init_state[-1] = cells[-1].zero_state(batch_size=batch_size,
                                                  dtype=tf.float32)
            init_state = tuple(init_state)
            cells = MultiRNNCell(cells)
        else:
            cells = MultiRNNCell(cells)
            cells = AttentionWrapper(cells,
                                     attention_mechanism=attention,
                                     name="AttentionWrapper",
                                     attention_layer_size=config.hidden_size,
                                     alignment_history=use_beam_search,
                                     initial_cell_state=init_state,
                                     cell_input_fn=cell_input_fn)
            init_state = cells.zero_state(batch_size=batch_size, dtype=tf.float32) \
                              .clone(cell_state=init_state)
        return cells, init_state
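
The top_attention branch above is the common pattern of attending only from the top layer of a decoder stack. A stripped-down sketch of just that wiring (TensorFlow 1.x assumed; shapes are hypothetical):

import tensorflow as tf
from tensorflow.contrib import seq2seq

units = 8
memory = tf.random_normal([2, 5, units])  # hypothetical encoder states

# Wrap only the top cell with attention, then stack the cells.
cells = [tf.nn.rnn_cell.GRUCell(units) for _ in range(3)]
attention = seq2seq.LuongAttention(units, memory)
cells[-1] = seq2seq.AttentionWrapper(cells[-1], attention,
                                     attention_layer_size=units)
decoder_cell = tf.nn.rnn_cell.MultiRNNCell(cells)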
Example #14
    def build_decoder_cell(self):
        encoder_inputs_length = self.encoder_inputs_length  # encoder input lengths
        if self.beam_search:  # whether beam search is used
            print("use beamsearch decoding..")
            # When beam search is used, the encoder outputs must be passed through tile_batch.
            # tile_batch replicates its first argument `multiplier` times; here, beam_size times.
            self.encoder_outputs = tile_batch(self.encoder_outputs,
                                              multiplier=self.beam_size)
            # The lambda expression here acts as a function of s.
            # nest.map_structure(func, structure) applies func to every element of structure.
            # An LSTM state has two components, c and h, so nest.map_structure() is needed.
            self.encoder_state = nest.map_structure(
                lambda s: tile_batch(s, self.beam_size), self.encoder_state)
            encoder_inputs_length = tile_batch(encoder_inputs_length,
                                               multiplier=self.beam_size)

        # Define the attention mechanism to use.
        # Bahdanau attention is used here; for details of this mechanism, see the paper:
        # Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
        # "Neural Machine Translation by Jointly Learning to Align and Translate."
        # ICLR 2015. https://arxiv.org/abs/1409.0473
        # A normalized variant also exists; to use it in TensorFlow, pass normalize=True.
        # For details of the normalization, see the paper:
        # Tim Salimans, Diederik P. Kingma.
        # "Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks."
        # https://arxiv.org/abs/1602.07868
        attention_mechanism = BahdanauAttention(
            num_units=self.rnn_size,  # dimensionality of the hidden layer
            memory=self.encoder_outputs,  # usually the encoder outputs
            # mask over the memory: positions beyond each length are excluded from attention
            memory_sequence_length=encoder_inputs_length)

        # Define the RNNCell for the decoder stage, then wrap it with an attention wrapper
        decoder_cell = self.create_rnn_cell()  # the RNNCell used in the decoder stage
        decoder_cell = AttentionWrapper(  # AttentionWrapper wraps an RNN with an attention mechanism
            cell=decoder_cell,  # the RNN cell to be wrapped
            attention_mechanism=attention_mechanism,  # an AttentionMechanism instance
            # TODO: is attention_layer_size the state size of the attention-wrapped RNN?
            attention_layer_size=self.rnn_size,
            name='Attention_Wrapper'  # the name of this AttentionWrapper
        )

        # With beam search, batch_size = self.batch_size * self.beam_size
        batch_size = self.batch_size if not self.beam_search else self.batch_size * self.beam_size

        # AttentionWrapper.zero_state() zero-initializes the AttentionWrapper state.
        # After zero-initialization, .clone() copies the given state into the AttentionWrapperState.
        # Here the encoder's final hidden state is used to initialize the decoder state.
        decoder_initial_state = decoder_cell.zero_state(
            batch_size=batch_size,
            dtype=tf.float32).clone(cell_state=self.encoder_state)

        return decoder_cell, decoder_initial_state
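
As the comments above say, tile_batch replicates each batch entry multiplier times along the batch axis. A tiny runnable check (TensorFlow 1.x assumed) that makes the replication order visible:

import tensorflow as tf
from tensorflow.contrib import seq2seq

enc_output = tf.reshape(tf.range(6, dtype=tf.float32), [2, 3])  # batch=2
tiled = seq2seq.tile_batch(enc_output, multiplier=2)            # batch=4

with tf.Session() as sess:
    print(sess.run(tiled))
    # [[0. 1. 2.]
    #  [0. 1. 2.]
    #  [3. 4. 5.]
    #  [3. 4. 5.]]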
Example #15
    def __graph__(self):

        # encoder
        encoder_outputs, encoder_state = self.encoder()

        # decoder
        with tf.variable_scope('decoder'):
            encoder_inputs_length = self.encoder_inputs_length
            if self.beam_search:
                # With beam search, the encoder outputs must go through tile_batch, i.e. be replicated beam_size times.
                print("use beamsearch decoding..")
                encoder_outputs = tile_batch(encoder_outputs, multiplier=self.beam_size)
                encoder_state = nest.map_structure(lambda s: tf.contrib.seq2seq.tile_batch(s, self.beam_size), encoder_state)
                encoder_inputs_length = tile_batch(encoder_inputs_length, multiplier=self.beam_size)

            # Define the attention mechanism to use.
            attention_mechanism = BahdanauAttention(num_units=self.rnn_size,
                                                    memory=encoder_outputs,
                                                    memory_sequence_length=encoder_inputs_length)
            # Define the RNNCell for the decoder stage, then wrap it with an attention wrapper
            decoder_cell = self.create_rnn_cell()
            decoder_cell = AttentionWrapper(cell=decoder_cell,
                                            attention_mechanism=attention_mechanism,
                                            attention_layer_size=self.rnn_size,
                                            name='Attention_Wrapper')
            # With beam search, batch_size = self.batch_size * self.beam_size
            batch_size = self.batch_size if not self.beam_search else self.batch_size * self.beam_size

            # Initialize the decoder state directly from the encoder's final hidden state
            decoder_initial_state = decoder_cell.zero_state(batch_size=batch_size,
                                                            dtype=tf.float32).clone(cell_state=encoder_state)

            output_layer = tf.layers.Dense(
                self.vocab_size,
                kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))

            if self.mode == 'train':
                self.decoder_outputs = self.decoder_train(decoder_cell, decoder_initial_state, output_layer)
                # loss
                self.loss = sequence_loss(logits=self.decoder_outputs, targets=self.decoder_targets, weights=self.mask)

                # summary
                tf.summary.scalar('loss', self.loss)
                self.summary_op = tf.summary.merge_all()

                # optimizer
                optimizer = tf.train.AdamOptimizer(self.learing_rate)
                trainable_params = tf.trainable_variables()
                gradients = tf.gradients(self.loss, trainable_params)
                clip_gradients, _ = tf.clip_by_global_norm(gradients, self.max_gradient_norm)
                self.train_op = optimizer.apply_gradients(zip(clip_gradients, trainable_params))
            elif self.mode == 'decode':
                self.decoder_predict_decode = self.decoder_decode(decoder_cell, decoder_initial_state, output_layer)
Example #16
    def _create_attention_mechanisms(self, beam_search=False):

        mechanisms = []
        layer_sizes = []

        if self._video_memory is not None:

            if beam_search is True:
                #  TODO potentially broken, please re-check
                self._video_memory = seq2seq.tile_batch(
                    self._video_memory, multiplier=self._hparams.beam_width)

                self._video_features_len = seq2seq.tile_batch(
                    self._video_features_len,
                    multiplier=self._hparams.beam_width)

            for attention_type in self._hparams.attention_type[0]:

                attention_video = self._create_attention_mechanism(
                    num_units=self._hparams.decoder_units_per_layer[-1],
                    memory=self._video_memory,
                    memory_sequence_length=self._video_features_len,
                    attention_type=attention_type)
                mechanisms.append(attention_video)
                # integer division: attention layer sizes must be ints
                layer_sizes.append(
                    self._hparams.decoder_units_per_layer[-1] // 2)

        if self._audio_memory is not None:

            if beam_search is True:
                #  TODO potentially broken, please re-check
                self._audio_memory = seq2seq.tile_batch(
                    self._audio_memory, multiplier=self._hparams.beam_width)

                self._audio_features_len = seq2seq.tile_batch(
                    self._audio_features_len,
                    multiplier=self._hparams.beam_width)

            for attention_type in self._hparams.attention_type[1]:
                attention_audio = self._create_attention_mechanism(
                    num_units=self._hparams.decoder_units_per_layer[-1],
                    memory=self._audio_memory,
                    memory_sequence_length=self._audio_features_len,
                    attention_type=attention_type)
                mechanisms.append(attention_audio)
                # integer division: attention layer sizes must be ints
                layer_sizes.append(
                    self._hparams.decoder_units_per_layer[-1] // 2)

        return mechanisms, layer_sizes
Example #17
def create_attention_mechanisms(num_units,
                                attention_types,
                                mode,
                                dtype,
                                beam_search=False,
                                beam_width=None,
                                memory=None,
                                memory_len=None,
                                fusion_type=None):
    r"""
    Creates a list of attention mechanisms (e.g. seq2seq.BahdanauAttention)
    and also a list of ints holding the attention projection layer size
    Args:
        beam_search: `bool`, whether the beam-search decoding algorithm is used or not
    """
    mechanisms = []
    output_attention = None

    if beam_search is True:
        memory = seq2seq.tile_batch(memory, multiplier=beam_width)

        memory_len = seq2seq.tile_batch(memory_len, multiplier=beam_width)

    for attention_type in attention_types:
        attention, output_attention = create_attention_mechanism(
            num_units=num_units,  # has to match decoder's state(query) size
            memory=memory,
            memory_sequence_length=memory_len,
            attention_type=attention_type,
            mode=mode,
            dtype=dtype,
        )
        mechanisms.append(attention)

    N = len(attention_types)
    if fusion_type == 'deep_fusion':
        attention_layer_sizes = None
        attention_layers = [
            AttentionLayers(units=num_units, dtype=dtype) for _ in range(N)
        ]
    elif fusion_type == 'linear_fusion':
        attention_layer_sizes = [num_units] * N
        attention_layers = None
    else:
        raise ValueError('Unknown fusion type: {}'.format(fusion_type))

    return mechanisms, attention_layers, attention_layer_sizes, output_attention
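
A design note on the two fusion branches above: tf.contrib.seq2seq's AttentionWrapper treats attention_layer_size and attention_layer as mutually exclusive ways of post-processing the attention context, which is why exactly one of attention_layer_sizes and attention_layers is left as None here. (AttentionLayers appears to be a project-local layer class, not part of TensorFlow.)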
Example #18
    def _create_decoder_cell(self, encoder_outputs, encoder_state,
                             source_sequence_length):
        """Build an RNN cell that can be used by decoder."""

        # We only make use of encoder_outputs in attention-based models
        if self.attention_option:
            raise ValueError("BasicModel doesn't support attention.")

        cell = model_helper.create_rnn_cell(
            unit_type=self.unit_type,
            num_units=self.num_units,
            num_layers=self.num_decoder_layers,
            num_residual_layers=self.num_decoder_residual_layers,
            forget_bias=self.forget_bias,
            dropout=self.dropout,
            mode=self.mode)

        if self.mode == ModeKeys.INFER and self.beam_width > 0:
            # For beam search, we need to replicate encoder state `beam_width` times
            decoder_initial_state = seq2seq.tile_batch(
                encoder_state, multiplier=self.beam_width)
        else:
            decoder_initial_state = encoder_state

        return cell, decoder_initial_state
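
Note that tile_batch accepts nested structures, so the whole encoder state (e.g. an LSTMStateTuple, or a tuple of them) can be replicated in one call, which is why the example above needs no nest.map_structure. A minimal sketch (TensorFlow 1.x assumed):

import tensorflow as tf
from tensorflow.contrib import seq2seq

state = tf.nn.rnn_cell.LSTMStateTuple(c=tf.zeros([2, 4]), h=tf.ones([2, 4]))
tiled_state = seq2seq.tile_batch(state, multiplier=3)  # c and h become [6, 4]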
Example #19
    def _build_decoder_test_beam_search(self):
        r"""
        Builds a beam search test decoder
        """
        if self._hparams.enable_attention is True:
            cells, initial_state = self._add_attention(self._decoder_cells, beam_search=True)
        else:  # does the non-attentive beam decoder need tile_batch ?
            cells = self._decoder_cells

            decoder_initial_state_tiled = seq2seq.tile_batch(  # guess so ? it compiles without it too
                self._decoder_initial_state, multiplier=self._hparams.beam_width)
            initial_state = decoder_initial_state_tiled

        self._decoder_inference = seq2seq.BeamSearchDecoder(
            cell=cells,
            embedding=self._embedding_matrix,
            start_tokens=array_ops.fill([self._batch_size], self._GO_ID),
            end_token=self._EOS_ID,
            initial_state=initial_state,
            beam_width=self._hparams.beam_width,
            output_layer=self._dense_layer,
            length_penalty_weight=0.6,
        )

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=False,
            maximum_iterations=self._hparams.max_label_length,
            swap_memory=False)

        self.inference_outputs = outputs.beam_search_decoder_output
        self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # return the first beam
        self.inference_predicted_beam = outputs.predicted_ids
Example #20
def add_attention(
    cells,
    attention_types,
    num_units,
    memory,
    memory_len,
    mode,
    batch_size,
    dtype,
    beam_search=False,
    beam_width=None,
    initial_state=None,
    write_attention_alignment=False,
    fusion_type='linear_fusion',
):
    r"""
    Wraps the decoder_cells with an AttentionWrapper
    Args:
        cells: instances of `RNNCell`
        beam_search: `bool` flag for beam search decoders
        batch_size: `Tensor` containing the batch size. Necessary to the initialisation of the initial state

    Returns:
        attention_cells: the Attention wrapped decoder cells
        initial_state: a proper initial state to be used with the returned cells
    """
    attention_mechanisms, attention_layers, attention_layer_sizes, output_attention = create_attention_mechanisms(
        beam_search=beam_search,
        beam_width=beam_width,
        memory=memory,
        memory_len=memory_len,
        num_units=num_units,
        attention_types=attention_types,
        fusion_type=fusion_type,
        mode=mode,
        dtype=dtype)

    if beam_search is True:
        initial_state = seq2seq.tile_batch(initial_state,
                                           multiplier=beam_width)

    attention_cells = seq2seq.AttentionWrapper(
        cell=cells,
        attention_mechanism=attention_mechanisms,
        attention_layer_size=attention_layer_sizes,
        # initial_cell_state=decoder_initial_state,
        alignment_history=write_attention_alignment,
        output_attention=output_attention,
        attention_layer=attention_layers,
    )

    attn_zero = attention_cells.zero_state(
        dtype=dtype,
        batch_size=batch_size * beam_width if beam_search else batch_size)

    if initial_state is not None:
        initial_state = attn_zero.clone(cell_state=initial_state)

    return attention_cells, initial_state
Example #21
 def _build_single_attention_mechanism(memory):
     if not self._is_training:
         memory = seq2seq.tile_batch(memory,
                                     multiplier=self._beam_width)
     return seq2seq.BahdanauAttention(self._num_attention_units,
                                      memory,
                                      memory_sequence_length=None)
Example #22
def _get_initial_state(initial_state,
                       tiled_initial_state,
                       cell,
                       batch_size,
                       beam_width,
                       dtype):
    if tiled_initial_state is None:
        if isinstance(initial_state, AttentionWrapperState):
            raise ValueError(
                '`initial_state` must not be an AttentionWrapperState. Use '
                'a plain cell state instead, which will be wrapped into an '
                'AttentionWrapperState automatically.')
        if initial_state is None:
            tiled_initial_state = cell.zero_state(batch_size * beam_width,
                                                  dtype)
        else:
            tiled_initial_state = tile_batch(initial_state,
                                             multiplier=beam_width)

    if isinstance(cell, AttentionWrapper) and \
            not isinstance(tiled_initial_state, AttentionWrapperState):
        zero_state = cell.zero_state(batch_size * beam_width, dtype)
        tiled_initial_state = zero_state.clone(cell_state=tiled_initial_state)

    return tiled_initial_state
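
The function above encodes the standard recipe for a beam-search initial state with attention: tile the plain cell state, build a zero AttentionWrapperState for the tiled batch, then clone the tiled state into its cell_state slot. A self-contained sketch of that recipe (TensorFlow 1.x assumed; all shapes are hypothetical):

import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, beam_width, units = 2, 3, 8
memory = tf.random_normal([batch_size, 5, units])
enc_state = tf.nn.rnn_cell.LSTMStateTuple(c=tf.zeros([batch_size, units]),
                                          h=tf.zeros([batch_size, units]))

attention = seq2seq.BahdanauAttention(units,
                                      seq2seq.tile_batch(memory, beam_width))
cell = seq2seq.AttentionWrapper(tf.nn.rnn_cell.LSTMCell(units), attention,
                                attention_layer_size=units)

# Tile the plain encoder state, then wrap it into an AttentionWrapperState.
tiled_state = seq2seq.tile_batch(enc_state, multiplier=beam_width)
zero_state = cell.zero_state(batch_size * beam_width, tf.float32)
initial_state = zero_state.clone(cell_state=tiled_state)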
Example #23
 def infer(
     self,
     cause_encoder,
 ):
     batch_size = tf.shape(self._initial_state)[0]
     tiled_initial_state = tile_batch(self._initial_state,
                                      multiplier=self._beam_width)
     tiled_initial_state = LSTMStateTuple(
         tiled_initial_state,
         tiled_initial_state,
         last_choice=array_ops.fill([batch_size * self._beam_width],
                                    self._SOS))
     infer_decoder = MyBeamSearchDecoder(self._lstm_cell,
                                         embedding=cause_encoder,
                                         start_tokens=tf.fill([batch_size],
                                                              self._SOS),
                                         end_token=self._EOS,
                                         initial_state=tiled_initial_state,
                                         beam_width=self._beam_width,
                                         output_layer=self._project_dense,
                                         lookup_table=self._cause_table,
                                         length_penalty_weight=0.7,
                                         hie=self._hie)
     cause_output_infer, cause_state_infer, cause_length_infer = dynamic_decode(
         infer_decoder,
         parallel_iterations=128,
         maximum_iterations=self._max_cause_length - 1,
         scope='decoder')
     return cause_output_infer, cause_state_infer, cause_length_infer
Example #24
    def model(self):
        with tf.variable_scope("encoder"):
            encoder_cell = self._create_rnn_cell()
            source_embedding = tf.get_variable(name="source_embedding",
                                               shape=[self.source_vocab_size, self.embedding_size],
                                               initializer=tf.initializers.truncated_normal())
            encoder_embedding_inputs = tf.nn.embedding_lookup(source_embedding, self.source_input)
            encoder_outputs, encoder_states = tf.nn.dynamic_rnn(cell=encoder_cell,
                                                                inputs=encoder_embedding_inputs,
                                                                dtype=tf.float32)
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            if self.mode=="test":
                encoder_states = seq2seq.tile_batch(encoder_states, self.beam_size)
            decoder_cell = self._create_rnn_cell()
            decoder_cell = rnn.DropoutWrapper(decoder_cell,output_keep_prob=0.5)

            if self.mode=="test":
                batch_size = self.batch_size*self.beam_size
            else:
                batch_size = self.batch_size
            #decoder_initial_state = decoder_cell.zero_state(batch_size=batch_size,dtype=tf.float32)

            output_layer = tf.layers.Dense(units=self.target_vocab_size,
                                           kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
            target_embedding = tf.get_variable(name="target_embedding",
                                               shape=[self.target_vocab_size, self.embedding_size])
            if self.mode == "train":
                self.mask = tf.sequence_mask(self.target_length,self.max_target_length,dtype=tf.float32)
                del_end = tf.strided_slice(self.target_input,[0,0],[self.batch_size,-1],[1,1])
                decoder_input = tf.concat([tf.fill([self.batch_size, 1],2),del_end],axis=1)
                decoder_input_embedding = tf.nn.embedding_lookup(target_embedding,decoder_input)
                training_helper = seq2seq.TrainingHelper(inputs=decoder_input_embedding,
                                                         sequence_length=tf.fill([self.batch_size],self.max_target_length))
                training_decoder = seq2seq.BasicDecoder(cell=decoder_cell,
                                                        helper=training_helper,
                                                        initial_state=encoder_states,
                                                        output_layer=output_layer)
                decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder=training_decoder,
                                                               output_time_major=False,
                                                               impute_finished=True,
                                                               maximum_iterations=self.max_target_length)
                self.decoder_logits_train = tf.identity(decoder_outputs.rnn_output)
                self.decoder_predict_train = tf.argmax(self.decoder_logits_train, axis=-1)
                self.loss_op = tf.reduce_mean(tf.losses.softmax_cross_entropy(
                    onehot_labels=tf.one_hot(self.target_input, depth=self.target_vocab_size),
                    logits=self.decoder_logits_train, weights=self.mask))
                optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
                trainable_params = tf.trainable_variables()
                gradients = tf.gradients(self.loss_op, trainable_params)
                clip_gradients, _ = tf.clip_by_global_norm(gradients, self.max_gradient_norm)
                self.train_op = optimizer.apply_gradients(zip(clip_gradients, trainable_params))
            elif self.mode == "test":
                start_tokens = tf.fill([self.batch_size], value=2)
                end_token = 3
                inference_decoder = seq2seq.BeamSearchDecoder(cell=decoder_cell, embedding=target_embedding,
                                                              start_tokens=start_tokens, end_token=end_token,
                                                              initial_state=encoder_states,
                                                              beam_width=self.beam_size, output_layer=output_layer)
                decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder=inference_decoder, maximum_iterations=self.max_target_length)
                print(decoder_outputs.predicted_ids.get_shape().as_list())
                self.decoder_predict_decode = decoder_outputs.predicted_ids[:, :, 0]
Example #25
 def build_attention_wrapper(self, final_cell):
     self.feedforward_inputs = tf.cond(
         self.beam_search_decoding, lambda: seq2seq.tile_batch(
             self.features["inputs"], multiplier=self.hparams.beam_width),
         lambda: self.features["inputs"])
     self.feedforward_inputs_length = tf.cond(
         self.beam_search_decoding, lambda: seq2seq.tile_batch(
             self.features["length"], multiplier=self.hparams.beam_width),
         lambda: self.features["length"])
     attention_mechanism = self.build_attention_mechanism()
     return AttentionWrapper(cell=final_cell,
                             attention_mechanism=attention_mechanism,
                             attention_layer_size=self.hparams.hidden_units,
                             cell_input_fn=self._attention_input_feeding,
                             initial_cell_state=(self.initial_state[-1]
                                                 if self.hparams.depth > 1
                                                 else self.initial_state))
Example #26
    def test_attention_decoder_given_initial_state(self):
        """Tests beam search with RNNAttentionDecoder given initial state.
        """
        seq_length = np.random.randint(
            self._max_time, size=[self._batch_size]) + 1
        encoder_values_length = tf.constant(seq_length)
        hparams = {
            "attention": {
                "kwargs": {
                    "num_units": self._attention_dim
                }
            },
            "rnn_cell": {
                "kwargs": {
                    "num_units": self._cell_dim
                }
            }
        }
        decoder = tx.modules.AttentionRNNDecoder(
            vocab_size=self._vocab_size,
            memory=self._encoder_output,
            memory_sequence_length=encoder_values_length,
            hparams=hparams)

        state = decoder.cell.zero_state(self._batch_size, tf.float32)

        cell_state = state.cell_state
        self._test_beam_search(decoder, initial_state=cell_state)

        tiled_cell_state = tile_batch(cell_state, multiplier=self._beam_width)
        self._test_beam_search(decoder,
                               tiled_initial_state=tiled_cell_state,
                               initiated=True)
Example #27
    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Build an RNN cell that can be used by decoder."""
        # We only make use of encoder_outputs in attention-based models
        if hparams.attention:
            raise ValueError("BasicModel doesn't support attention.")

        cell = model_helper.create_rnn_cell(
            unit_type=hparams.unit_type,
            num_units=hparams.num_units,
            num_layers=self.num_decoder_layers,
            num_residual_layers=self.num_decoder_residual_layers,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=self.num_gpus,
            mode=self.mode,
            single_cell_fn=self.single_cell_fn)

        # For beam search, we need to replicate encoder infos beam_width times
        if self.mode == tf.contrib.learn.ModeKeys.INFER and \
                hparams.beam_width > 0:
            decoder_initial_state = seq2seq.tile_batch(
                encoder_state, multiplier=hparams.beam_width)
        else:
            decoder_initial_state = encoder_state

        return cell, decoder_initial_state
Example #28
        def create_decoder_cell():
            cell = tf.contrib.rnn.MultiRNNCell([
                utils.make_cell(self.args.enc_num_units,
                                utils.get_device_str(self.args.num_gpus))
                for _ in range(self.args.dec_layers)
            ])

            if self.args.beam_width > 0 and self.mode == "Infer":
                dec_start_state = seq2seq.tile_batch(self.encoder_state,
                                                     self.beam_width)
                enc_outputs = seq2seq.tile_batch(self.encoder_outputs,
                                                 self.beam_width)
                enc_lengths = seq2seq.tile_batch(self.encoder_inputs_length,
                                                 self.beam_width)
            else:
                dec_start_state = self.encoder_state
                enc_outputs = self.encoder_outputs
                enc_lengths = self.encoder_inputs_length

            if self.args.attention:
                attention_states = enc_outputs

                attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                    self.args.dec_num_units,
                    attention_states,
                    memory_sequence_length=enc_lengths)

                decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                    cell,
                    attention_mechanism,
                    attention_layer_size=self.args.dec_num_units)

                if self.args.beam_width > 0 and self.mode == "Infer":
                    initial_state = decoder_cell.zero_state(
                        self.batch_size * self.beam_width, tf.float32)
                else:
                    initial_state = decoder_cell.zero_state(
                        self.batch_size, tf.float32)

                initial_state = initial_state.clone(cell_state=dec_start_state)
            else:

                decoder_cell = cell
                initial_state = dec_start_state

            return decoder_cell, initial_state
Example #29
 def _build_infer(self, config):
   # infer_decoder/beam_search 
   # skip for flat_baseline  
   tiled_inputs = tile_batch(self.xx_context, multiplier=config.beam_width)
   tiled_sequence_length = tile_batch(self.x_seq_length, multiplier=config.beam_width)
   tiled_first_attention = tile_batch(self.first_attention, multiplier=config.beam_width)
   attention_mechanism = BahdanauAttention(config.decode_size,
                                           memory=tiled_inputs,
                                           memory_sequence_length=tiled_sequence_length)
   tiled_xx_final = tile_batch(self.xx_final, config.beam_width)
   encoder_state2 = rnn.LSTMStateTuple(tiled_xx_final, tiled_xx_final)
   cell = AttentionWrapper(self.lstm, attention_mechanism, output_attention=False)
   cell_state = cell.zero_state(dtype=tf.float32,
                                batch_size=config.test_batch_size * config.beam_width)
   cell_state = cell_state.clone(cell_state=encoder_state2,
                                 attention=tiled_first_attention)
   infer_decoder = BeamSearchDecoder(cell,
                                     embedding=self.label_embeddings,
                                     start_tokens=[config.GO] * config.test_batch_size,
                                     end_token=config.EOS,
                                     initial_state=cell_state,
                                     beam_width=config.beam_width,
                                     output_layer=self.output_l)
   decoder_outputs_infer, decoder_state_infer, decoder_seq_infer = dynamic_decode(
       infer_decoder, maximum_iterations=config.max_seq_length)
   self.preds = decoder_outputs_infer.predicted_ids
   self.scores = decoder_state_infer.log_probs
Example #30
    def _build_decoder_beam_search(self):

        batch_size, _ = tf.unstack(tf.shape(self._labels))

        attention_mechanisms, layer_sizes = self._create_attention_mechanisms(
            beam_search=True)

        decoder_initial_state_tiled = seq2seq.tile_batch(
            self._decoder_initial_state, multiplier=self._hparams.beam_width)

        if self._hparams.enable_attention is True:

            attention_cells = seq2seq.AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanisms,
                attention_layer_size=layer_sizes,
                initial_cell_state=decoder_initial_state_tiled,
                alignment_history=self._hparams.write_attention_alignment,
                output_attention=self._output_attention)

            initial_state = attention_cells.zero_state(
                dtype=self._hparams.dtype,
                batch_size=batch_size * self._hparams.beam_width)

            initial_state = initial_state.clone(
                cell_state=decoder_initial_state_tiled)

            cells = attention_cells
        else:
            cells = self._decoder_cells
            initial_state = decoder_initial_state_tiled

        self._decoder_inference = seq2seq.BeamSearchDecoder(
            cell=cells,
            embedding=self._embedding_matrix,
            start_tokens=array_ops.fill([batch_size], self._GO_ID),
            end_token=self._EOS_ID,
            initial_state=initial_state,
            beam_width=self._hparams.beam_width,
            output_layer=self._dense_layer,
            length_penalty_weight=0.5,
        )

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=False,
            maximum_iterations=self._hparams.max_label_length,
            swap_memory=False)

        if self._hparams.write_attention_alignment:
            self.attention_summary = self._create_attention_alignments_summary(
                states)

        self.inference_outputs = outputs.beam_search_decoder_output
        self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # first (best) beam
        self.inference_predicted_beam = outputs.predicted_ids
        self.beam_search_output = outputs.beam_search_decoder_output
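The length_penalty_weight=0.5 passed above rescores beams with the GNMT length penalty: a hypothesis's cumulative log-probability is divided by ((5 + len) / 6) ** alpha, so alpha = 0 leaves scores untouched and larger alpha favours longer outputs. A plain-Python restatement, for illustration only:

def gnmt_length_penalty(seq_len, alpha):
    # alpha = 0 disables the penalty; larger alpha favours longer beams.
    return ((5.0 + float(seq_len)) / 6.0) ** alpha

def beam_score(cum_log_prob, seq_len, alpha=0.5):
    # The score beam search maximises when ranking hypotheses.
    return cum_log_prob / gnmt_length_penalty(seq_len, alpha)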
Example No. 31
  def sample(self, n, max_length=None, z=None, temperature=None,
             start_inputs=None, beam_width=None, end_token=None):
    """Overrides BaseLstmDecoder `sample` method to add optional beam search.

    Args:
      n: Scalar number of samples to return.
      max_length: (Optional) Scalar maximum sample length to return. Required if
        data representation does not include end tokens.
      z: (Optional) Latent vectors to sample from. Required if model is
        conditional. Sized `[n, z_size]`.
      temperature: (Optional) The softmax temperature to use when not doing beam
        search. Defaults to 1.0. Ignored when `beam_width` is provided.
      start_inputs: (Optional) Initial inputs to use for batch.
        Sized `[n, output_depth]`.
      beam_width: (Optional) Width of beam to use for beam search. Beam search
        is disabled if not provided.
      end_token: (Optional) Scalar token signaling the end of the sequence to
        use for early stopping.
    Returns:
      samples: Sampled sequences. Sized `[n, max_length, output_depth]`.
    Raises:
      ValueError: If `z` is provided and its first dimension does not equal `n`.
    """
    if beam_width is None:
      end_fn = (None if end_token is None else
                lambda x: tf.equal(tf.argmax(x, axis=-1), end_token))
      return super(CategoricalLstmDecoder, self).sample(
          n, max_length, z, temperature, start_inputs, end_fn)

    # If `end_token` is not given, use an impossible value.
    end_token = self._output_depth if end_token is None else end_token
    if z is not None and z.shape[0].value != n:
      raise ValueError(
          '`z` must have a first dimension that equals `n` when given. '
          'Got: %d vs %d' % (z.shape[0].value, n))

    if temperature is not None:
      tf.logging.warning('`temperature` is ignored when using beam search.')
    # Use a dummy Z in unconditional case.
    z = tf.zeros((n, 0), tf.float32) if z is None else z

    # If not given, start with dummy `-1` token and replace with zero vectors in
    # `embedding_fn`.
    start_tokens = (
        tf.argmax(start_inputs, axis=-1, output_type=tf.int32)
        if start_inputs is not None else
        -1 * tf.ones([n], dtype=tf.int32))

    initial_state = initial_cell_state_from_embedding(
        self._dec_cell, z, name='decoder/z_to_initial_state')
    beam_initial_state = seq2seq.tile_batch(
        initial_state, multiplier=beam_width)

    # Tile `z` across beams.
    beam_z = tf.tile(tf.expand_dims(z, 1), [1, beam_width, 1])

    def embedding_fn(tokens):
      # If tokens are the start_tokens (negative), replace with zero vectors.
      next_inputs = tf.cond(
          tf.less(tokens[0, 0], 0),
          lambda: tf.zeros([n, beam_width, self._output_depth]),
          lambda: tf.one_hot(tokens, self._output_depth))

      # Concatenate `z` to next inputs.
      next_inputs = tf.concat([next_inputs, beam_z], axis=-1)
      return next_inputs

    decoder = seq2seq.BeamSearchDecoder(
        self._dec_cell,
        embedding_fn,
        start_tokens,
        end_token,
        beam_initial_state,
        beam_width,
        output_layer=self._output_layer,
        length_penalty_weight=0.0)

    final_output, _, _ = seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=max_length,
        swap_memory=True,
        scope='decoder')

    return tf.one_hot(
        final_output.predicted_ids[:, :, 0],
        self._output_depth)
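Unlike the earlier examples, this one passes a callable as the `embedding` argument. BeamSearchDecoder accepts either form: given a weight matrix it wraps it in an embedding lookup itself; given a callable, the callable receives int32 token ids shaped [batch_size, beam_width]. A minimal callable equivalent to passing a matrix directly (the matrix name is assumed, not from the snippet):

def embedding_fn(tokens):
    # tokens: int32 [batch_size, beam_width]
    return tf.nn.embedding_lookup(embedding_matrix, tokens)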
Example No. 32
  def decode_model(self, train_mode=True):
    with tf.variable_scope('decoder') as scope:
      encoder_outputs = self.encoder_outputs
      source_lengths = self.source_lengths
      encoder_state = self.encoder_state
      global_step = self.global_step
      projection_layer = self.projection_layer
      tgt_sos_id = self.tgt_sos_id
      tgt_eos_id = self.tgt_eos_id
      target = self.target
      target_lookup = self.tgt_lookup
      target_lengths = self.target_lengths
      learning_rate = self.learning_rate
      target_embed = self.tgt_embed

      # Share variables with the training graph when not training
      # (the original set reuse = False in both branches).
      reuse = not train_mode

      # Prepare
      b_size = tf.size(source_lengths)
      if not train_mode:
        b_size_t = tf.to_int32(b_size * self.beam_width)
        encoder_outputs = tile_batch(encoder_outputs, self.beam_width)
        encoder_state = tile_batch(encoder_state, self.beam_width)
        source_lengths = tile_batch(source_lengths, self.beam_width)
      else:
        b_size_t = b_size

      # Create decoder_cell
      rnn_layer = [self.get_cell() for i in range(self.rnn_layer_depth)]
      
      # Create attention cell (top of rnn layer)
      multi_rnn = self.wrap_multi_rnn(rnn_layer)
      attention = self.wrap_attention(multi_rnn, encoder_outputs, source_lengths)
      attention = tf.contrib.rnn.DeviceWrapper(attention, '/cpu:0')

      # Sync the decoder cell state with the encoder's final state
      decoder_state = attention.zero_state(batch_size=b_size_t, dtype=tf.float32).clone(
          cell_state=encoder_state)

      if train_mode:
        # define decoder
        decode_helper = tf.contrib.seq2seq.TrainingHelper(
            target_lookup, target_lengths)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            attention, decode_helper, decoder_state, output_layer=projection_layer)
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, scope=scope)
        logits = tf.identity(outputs.rnn_output)

        # Train Loss
        # Weight 1.0 on every step up to and including the first EOS
        # (the first step is always kept), 0.0 afterwards.
        weights = tf.to_float(tf.concat([
            tf.expand_dims(tf.fill([b_size], tf.constant(True)), 1),
            tf.not_equal(target[:, :-1], tgt_eos_id)], 1))
        loss = tf.contrib.seq2seq.sequence_loss(logits, target, weights=weights)

        # Optimize
        params = tf.trainable_variables()
        gradients = tf.gradients(loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        opt = optimizer.apply_gradients(zip(clipped_gradients, params), global_step=global_step)
        return opt, loss, global_step

      else:
        infer_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            attention, target_embed, tf.fill([b_size], tf.to_int32(tgt_sos_id)), tf.to_int32(tgt_eos_id),
            decoder_state, self.beam_width, output_layer=projection_layer)
        infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
            infer_decoder, scope=scope,
            maximum_iterations=tf.to_int32(tf.reduce_max(source_lengths) * 2))
        infer_result = infer_outputs.predicted_ids
        return infer_result, target, global_step, target
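The weight construction above keeps every step up to and including the first EOS. Assuming target_lengths counts that terminating EOS, an equivalent and arguably clearer mask (a sketch, not the snippet's own code) uses tf.sequence_mask:

max_time = tf.shape(target)[1]
weights = tf.sequence_mask(target_lengths, max_time, dtype=tf.float32)
loss = tf.contrib.seq2seq.sequence_loss(logits, target, weights=weights)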
Example No. 33
def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
                   target_dict_dim, is_generating, beam_size,
                   max_generation_length):
    src_word_idx = tf.placeholder(tf.int32, shape=[None, None])
    src_sequence_length = tf.placeholder(tf.int32, shape=[None, ])

    src_embedding_weights = tf.get_variable("source_word_embeddings",
                                            [source_dict_dim, embedding_dim])
    src_embedding = tf.nn.embedding_lookup(src_embedding_weights, src_word_idx)

    src_forward_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    src_reversed_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    # no peephole
    encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=src_forward_cell,
        cell_bw=src_reversed_cell,
        inputs=src_embedding,
        sequence_length=src_sequence_length,
        dtype=tf.float32)

    # concat the forward outputs and backward outputs
    encoded_vec = tf.concat(encoder_outputs, axis=2)

    # project the encoder outputs to the size of the decoder lstm;
    # the bidirectional concat has width 2 * encoder_size
    encoded_proj = tf.contrib.layers.fully_connected(
        inputs=tf.reshape(
            encoded_vec, shape=[-1, encoder_size * 2]),
        num_outputs=decoder_size,
        activation_fn=None,
        biases_initializer=None)
    encoded_proj_reshape = tf.reshape(
        encoded_proj, shape=[-1, tf.shape(encoded_vec)[1], decoder_size])

    # get init state for decoder lstm's H from the backward pass's first output
    backward_first = tf.slice(encoder_outputs[1], [0, 0, 0], [-1, 1, -1])
    decoder_boot = tf.contrib.layers.fully_connected(
        inputs=tf.reshape(
            backward_first, shape=[-1, encoder_size]),
        num_outputs=decoder_size,
        activation_fn=tf.nn.tanh,
        biases_initializer=None)

    # prepare the initial state for decoder lstm
    cell_init = tf.zeros(tf.shape(decoder_boot), tf.float32)
    initial_state = LSTMStateTuple(cell_init, decoder_boot)

    # create decoder lstm cell
    decoder_cell = LSTMCellWithSimpleAttention(
        decoder_size,
        encoded_vec
        if not is_generating else seq2seq.tile_batch(encoded_vec, beam_size),
        encoded_proj_reshape if not is_generating else
        seq2seq.tile_batch(encoded_proj_reshape, beam_size),
        src_sequence_length if not is_generating else
        seq2seq.tile_batch(src_sequence_length, beam_size),
        forget_bias=0.0)

    output_layer = Dense(target_dict_dim, name='output_projection')

    if not is_generating:
        trg_word_idx = tf.placeholder(tf.int32, shape=[None, None])
        trg_sequence_length = tf.placeholder(tf.int32, shape=[None, ])
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])
        trg_embedding = tf.nn.embedding_lookup(trg_embedding_weights,
                                               trg_word_idx)

        training_helper = seq2seq.TrainingHelper(
            inputs=trg_embedding,
            sequence_length=trg_sequence_length,
            time_major=False,
            name='training_helper')

        training_decoder = seq2seq.BasicDecoder(
            cell=decoder_cell,
            helper=training_helper,
            initial_state=initial_state,
            output_layer=output_layer)

        # get the max length of target sequence
        max_decoder_length = tf.reduce_max(trg_sequence_length)

        decoder_outputs_train, _, _ = seq2seq.dynamic_decode(
            decoder=training_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=max_decoder_length)

        decoder_logits_train = tf.identity(decoder_outputs_train.rnn_output)
        decoder_pred_train = tf.argmax(
            decoder_logits_train, axis=-1, name='decoder_pred_train')
        masks = tf.sequence_mask(
            lengths=trg_sequence_length,
            maxlen=max_decoder_length,
            dtype=tf.float32,
            name='masks')

        # place holder of label sequence
        lbl_word_idx = tf.placeholder(tf.int32, shape=[None, None])

        # compute the loss
        loss = seq2seq.sequence_loss(
            logits=decoder_logits_train,
            targets=lbl_word_idx,
            weights=masks,
            average_across_timesteps=True,
            average_across_batch=True)

        # return feeding list and loss operator
        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length,
            'trg_word_idx': trg_word_idx,
            'trg_sequence_length': trg_sequence_length,
            'lbl_word_idx': lbl_word_idx
        }, loss
    else:
        start_tokens = tf.ones([tf.shape(src_word_idx)[0], ],
                               tf.int32) * START_TOKEN_IDX
        # share the same embedding weights with target word
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])

        inference_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=lambda tokens: tf.nn.embedding_lookup(trg_embedding_weights, tokens),
            start_tokens=start_tokens,
            end_token=END_TOKEN_IDX,
            initial_state=tf.nn.rnn_cell.LSTMStateTuple(
                tf.contrib.seq2seq.tile_batch(initial_state[0], beam_size),
                tf.contrib.seq2seq.tile_batch(initial_state[1], beam_size)),
            beam_width=beam_size,
            output_layer=output_layer)

        decoder_outputs_decode, _, _ = seq2seq.dynamic_decode(
            decoder=inference_decoder,
            output_time_major=False,
            # impute_finished=True is not supported with beam search
            maximum_iterations=max_generation_length)

        predicted_ids = decoder_outputs_decode.predicted_ids

        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length
        }, predicted_ids
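A hedged usage sketch for the generating branch above; the hyperparameter values, src_batch, and src_lengths are hypothetical inputs, not taken from the snippet:

feeds, predicted_ids = seq_to_seq_net(
    embedding_dim=512, encoder_size=512, decoder_size=512,
    source_dict_dim=30000, target_dict_dim=30000,
    is_generating=True, beam_size=4, max_generation_length=100)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ids = sess.run(predicted_ids, feed_dict={
        feeds['src_word_idx']: src_batch,           # [batch, src_time]
        feeds['src_sequence_length']: src_lengths,  # [batch]
    })  # ids: [batch, time, beam_size]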