Пример #1
0
def get(attention_type,
        num_units,
        memory,
        memory_sequence_length,
        scope=None,
        reuse=None):
    """Returns attention mechanism according to the specified type.

    Args:
        attention_type: One of `U.ATT_LUONG`, `U.ATT_LUONG_SCALED`,
            `U.ATT_BAHDANAU` or `U.ATT_BAHDANAU_NORM`.
        num_units: Forwarded as `num_units` to the attention constructor.
        memory: Forwarded as `memory` to the attention constructor.
        memory_sequence_length: Forwarded as `memory_sequence_length`.
        scope: Variable scope (name or scope object) the mechanism is built
            in. When None, a uniquified scope based on "attention" is used.
        reuse: Forwarded to `tf.variable_scope`.

    Returns:
        The constructed attention mechanism.

    Raises:
        ValueError: If `attention_type` is not one of the supported types.
    """
    # `tf.variable_scope(None)` without a `default_name` raises TypeError,
    # which made the declared default `scope=None` unusable. Supplying a
    # default name keeps explicit scopes untouched and makes None work.
    with tf.variable_scope(scope, default_name="attention", reuse=reuse):
        if attention_type == U.ATT_LUONG:
            attention_mechanism = contrib_seq2seq.LuongAttention(
                num_units=num_units,
                memory=memory,
                memory_sequence_length=memory_sequence_length)
        elif attention_type == U.ATT_LUONG_SCALED:
            attention_mechanism = contrib_seq2seq.LuongAttention(
                num_units=num_units,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                scale=True)
        elif attention_type == U.ATT_BAHDANAU:
            attention_mechanism = contrib_seq2seq.BahdanauAttention(
                num_units=num_units,
                memory=memory,
                memory_sequence_length=memory_sequence_length)
        elif attention_type == U.ATT_BAHDANAU_NORM:
            attention_mechanism = contrib_seq2seq.BahdanauAttention(
                num_units=num_units,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                normalize=True)
        else:
            raise ValueError("Unknown attention type: %s" % attention_type)
    return attention_mechanism
Пример #2
0
    def _create_attention_mechanism(self, attention_type, num_units, memory,
                                    memory_sequence_length):

        if attention_type == 'bahdanau':
            attention_mechanism = seq2seq.BahdanauAttention(
                num_units=num_units,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                normalize=False)
            self._output_attention = False
        elif attention_type == 'normed_bahdanau':
            attention_mechanism = seq2seq.BahdanauAttention(
                num_units=num_units,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                normalize=True)
            self._output_attention = False
        elif attention_type == 'normed_monotonic_bahdanau':
            attention_mechanism = seq2seq.BahdanauMonotonicAttention(
                num_units=num_units,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                normalize=True,
                score_bias_init=-2.0,
                sigmoid_noise=1.0 if self._mode == 'train' else 0.0,
                mode='hard' if self._mode != 'train' else 'parallel')
            self._output_attention = False
        elif attention_type == 'luong':
            attention_mechanism = seq2seq.LuongAttention(
                num_units=num_units,
                memory=memory,
                memory_sequence_length=memory_sequence_length)
            self._output_attention = True
        elif attention_type == 'scaled_luong':
            attention_mechanism = seq2seq.LuongAttention(
                num_units=num_units,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                scale=True,
            )
            self._output_attention = True
        elif attention_type == 'scaled_monotonic_luong':
            attention_mechanism = seq2seq.LuongMonotonicAttention(
                num_units=num_units,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                scale=True,
                score_bias_init=-2.0,
                sigmoid_noise=1.0 if self._mode == 'train' else 0.0,
                mode='hard' if self._mode != 'train' else 'parallel')
            self._output_attention = True
        else:
            raise Exception('unknown attention mechanism')

        return attention_mechanism
Пример #3
0
    def build_attention_cell(self, encoder_outputs, encoder_states):
        """Build a Luong-attention decoder cell and its initial state.

        Args:
            encoder_outputs: Tensor used as the attention memory. Its first
                two axes are swapped first when `self.time_major` is set.
            encoder_states: State cloned into the wrapper's zero state as
                `cell_state`.

        Returns:
            A `(cell, decoder_initial_state)` tuple.
        """
        if self.time_major:
            # Swap the first two axes; presumably time-major -> batch-major.
            attention_memory = tf.transpose(encoder_outputs, [1, 0, 2])
        else:
            attention_memory = encoder_outputs

        # NOTE(review): the memory is masked with `iterator.target_length`
        # while the batch size below comes from `iterator.source_length`;
        # confirm the target length is really intended for encoder memory.
        luong = seq2seq.LuongAttention(
            num_units=self.hps.att_num_units,
            memory=attention_memory,
            memory_sequence_length=self.iterator.target_length)

        layers = []
        for _ in range(self.hps.stack_layers):
            layers.append(self.build_rnn_cell(FLAGS.cell_type))
        stacked_cell = rnn.MultiRNNCell(layers)

        wrapped_cell = seq2seq.AttentionWrapper(
            stacked_cell,
            luong,
            attention_layer_size=self.hps.att_num_units,
            name='attention')

        num_examples = tf.size(self.iterator.source_length)

        # NOTE(review): `dtype` is not defined in this method; presumably a
        # module-level name -- confirm.
        zero_state = wrapped_cell.zero_state(
            batch_size=num_examples, dtype=dtype)
        decoder_initial_state = zero_state.clone(cell_state=encoder_states)

        return wrapped_cell, decoder_initial_state
Пример #4
0
 def _build_attention(self,
                      enc_outputs,
                      enc_seq_len
                      ):
   """Create the attention mechanism inside the "AttentionMechanism" scope.

   Args:
     enc_outputs: Passed as `memory` to the attention constructor.
     enc_seq_len: Passed as `memory_sequence_length`.

   Returns:
     A normalized `BahdanauAttention` or a `LuongAttention`, both with
     `2 * self.cell_dim` units and a softmax probability function.

   Raises:
     ValueError: If neither branch matches the configured attention type.
   """
   with tf.variable_scope("AttentionMechanism"):
     # NOTE(review): the two branches read the attention type from different
     # places (`self.attn_Type` vs `self.params['attention_type']`); confirm
     # both are meant to hold the same setting.
     if self.attn_Type == 'bahdanau':
       mechanism_cls = seq2seq.BahdanauAttention
       extra_options = {'normalize': True}
     elif self.params['attention_type'] == 'luong':
       mechanism_cls = seq2seq.LuongAttention
       extra_options = {}
     else:
       raise ValueError('Unknown Attention Type')

     return mechanism_cls(
         num_units=2 * self.cell_dim,
         memory=enc_outputs,
         memory_sequence_length=enc_seq_len,
         probability_fn=tf.nn.softmax,
         dtype=tf.get_variable_scope().dtype,
         **extra_options)
Пример #5
0
    def decoder_cell(self, inputs, lengths):
        """Wrap a fresh `self.cell()` in scaled Luong attention.

        Args:
            inputs: Passed as the attention `memory`.
            lengths: Passed as `memory_sequence_length`.

        Returns:
            An `AttentionWrapper` whose attention layer size is
            `self.layer_size`.
        """
        scaled_luong = seq2seq.LuongAttention(
            num_units=self.layer_size,
            memory=inputs,
            memory_sequence_length=lengths,
            scale=True)

        inner_cell = self.cell()
        wrapped = seq2seq.AttentionWrapper(
            cell=inner_cell,
            attention_mechanism=scaled_luong,
            attention_layer_size=self.layer_size)
        return wrapped
Пример #6
0
    def build_decode_cell(self):
        """Build the multi-layer decoder cell with attention on the top layer.

        Stores the attention mechanism in `self.attention_mechanism` and the
        per-layer cells in `self.decoder_cell_list`.

        Returns:
            A `(MultiRNNCell, decoder_initial_state)` tuple. The initial state
            is the encoder's last state with its final entry replaced by the
            attention wrapper's zero state.
        """
        memory = self.encoder_outputs
        last_state = self.encoder_last_state
        memory_lengths = self.encoder_inputs_length

        # Bahdanau-style attention (https://arxiv.org/abs/1409.0473) is the
        # default and is always constructed first.
        # NOTE(review): when 'luong' is selected this instance is discarded,
        # but it is kept here because constructing it may affect the names of
        # graph variables created afterwards -- confirm before removing.
        self.attention_mechanism = seq2seq.BahdanauAttention(
            num_units=self.hidden_units,
            memory=memory,
            memory_sequence_length=memory_lengths
        )

        use_luong = self.attention_type.lower() == 'luong'
        if use_luong:
            self.attention_mechanism = seq2seq.LuongAttention(
                num_units=self.hidden_units,
                memory=memory,
                memory_sequence_length=memory_lengths
            )

        # One cell per decoder layer.
        self.decoder_cell_list = [
            self.build_single_cell(layer=2) for _ in range(self.depth)
        ]
        cells = self.decoder_cell_list

        def feed_attention(inputs, attention):
            # With input feeding enabled, project [inputs; attention] through
            # a dense layer (the original notes this is essential when
            # use_residual=True); otherwise pass the inputs through untouched.
            if self.attn_input_feeding:
                projection = Dense(self.hidden_units * 2,
                                   dtype=self.dtype,
                                   name='attn_input_feeding')
                return projection(tf.concat([inputs, attention], -1))
            return inputs

        # Attention is applied to the top decoder layer only.
        cells[-1] = seq2seq.AttentionWrapper(
            cell=cells[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=feed_attention,
            initial_cell_state=last_state[-1],
            alignment_history=False,
            name='Attention_wrapper'
        )

        # The top layer's state must be in the wrapper's own state form, so
        # it is taken from the wrapper's zero_state (the encoder's top-layer
        # state was handed to the wrapper as `initial_cell_state` above).
        per_layer_state = list(last_state)
        per_layer_state[-1] = cells[-1].zero_state(
            batch_size=self.batch_size, dtype=self.dtype
        )
        decoder_initial_state = tuple(per_layer_state)

        return rnn.MultiRNNCell(self.decoder_cell_list), decoder_initial_state
Пример #7
0
 def build_attention_mechanism(self):
     """Build the attention mechanism selected by `hparams.attention_type`.

     The mechanism is constructed with `hparams.hidden_units`,
     `self.feedforward_inputs` and `self.feedforward_inputs_length` as its
     first three positional arguments.

     Returns:
         The constructed `LuongAttention` or `BahdanauAttention`.

     Raises:
         ValueError: If the attention type is neither 'luong' nor 'bahdanau'.
     """
     if self.hparams.attention_type == 'luong':
         attention_mechanism = seq2seq.LuongAttention(
             self.hparams.hidden_units, self.feedforward_inputs,
             self.feedforward_inputs_length)
     elif self.hparams.attention_type == 'bahdanau':
         attention_mechanism = seq2seq.BahdanauAttention(
             self.hparams.hidden_units,
             self.feedforward_inputs,
             self.feedforward_inputs_length,
         )
     else:
         raise ValueError(
             "Currently, the only supported attention types are 'luong' and 'bahdanau'."
         )
     # Previously the mechanism was built and then dropped, so the method
     # always returned None; hand it back to the caller.
     return attention_mechanism
Пример #8
0
def create_attention_mechanism(attention_option, num_units, memory,
                               source_sequence_length):
    """
    Build the attention mechanism selected by ``attention_option``.

    :param attention_option: one of "luong", "scaled_luong", "bahdanau",
        "normed_bahdanau"
    :param num_units: forwarded as the first positional argument of the
        attention constructor
    :param memory: The memory to query; usually the output of an RNN encoder.
        This tensor should be shaped `[batch_size, max_time, ...]`.
    :param source_sequence_length: Sequence lengths for the batch entries in
        memory, forwarded as ``memory_sequence_length``.
    :return: the constructed attention mechanism
    :raises ValueError: if ``attention_option`` is not a supported name
    """
    if attention_option == "luong":
        return seq2seq.LuongAttention(
            num_units, memory, memory_sequence_length=source_sequence_length)

    if attention_option == "scaled_luong":
        return seq2seq.LuongAttention(
            num_units,
            memory,
            memory_sequence_length=source_sequence_length,
            scale=True)

    if attention_option == "bahdanau":
        return seq2seq.BahdanauAttention(
            num_units, memory, memory_sequence_length=source_sequence_length)

    if attention_option == "normed_bahdanau":
        return seq2seq.BahdanauAttention(
            num_units,
            memory,
            memory_sequence_length=source_sequence_length,
            normalize=True)

    raise ValueError("Unknown attention option %s" % attention_option)
Пример #9
0
    def _build_decoder_cell(self, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Build the Luong-attention decoder cell and its initial state.

        In PREDICT mode with a positive `hparams.beam_width`, the memory, the
        source lengths and the encoder state are tiled `beam_width` times and
        the batch size is scaled accordingly.

        Args:
            encoder_outputs: Encoder outputs used as the attention memory.
                Their first two axes are swapped when `hparams.time_major`
                is set.
            encoder_state: Encoder state; cloned into the initial state when
                `hparams.pass_hidden_state` is set.
            source_sequence_length: Lengths used to mask the memory.

        Returns:
            A `(cell, initial_state)` tuple.
        """
        beam_width = self.hparams.beam_width

        # `memory` used to be assigned only in the time-major branch, so the
        # batch-major configuration failed with UnboundLocalError below.
        if self.hparams.time_major:
            memory = tf.transpose(encoder_outputs, [1, 0, 2])
        else:
            memory = encoder_outputs

        if self.mode == PREDICT and beam_width > 0:
            memory = seq2seq.tile_batch(memory, beam_width)
            source_sequence_length = seq2seq.tile_batch(
                source_sequence_length, beam_width)
            encoder_state = seq2seq.tile_batch(encoder_state, beam_width)
            batch_size = self.batch_size * beam_width
        else:
            batch_size = self.batch_size

        # Use Attention Mechanism
        attention_mechanism = seq2seq.LuongAttention(
            num_units=self.hparams.num_units,
            memory=memory,
            memory_sequence_length=source_sequence_length)

        cell = model_helper.build_rnn_cell(
            self.hparams.unit_type,
            self.hparams.num_units,
            self.hparams.num_layers,
            self.hparams.dropout,
        )
        # Alignment history is recorded only for PREDICT with beam_width == 0.
        alignment_history = (self.mode == PREDICT and beam_width == 0)

        cell = seq2seq.AttentionWrapper(
            cell,
            attention_mechanism,
            attention_layer_size=self.hparams.num_units,
            alignment_history=alignment_history,
            name='attention')
        if self.hparams.pass_hidden_state:
            initial_state = cell.zero_state(
                batch_size, tf.float32).clone(cell_state=encoder_state)
        else:
            initial_state = cell.zero_state(batch_size, tf.float32)

        return cell, initial_state
    def wrap_att(self,
                 dec_cell,
                 lstm_size,
                 enc_output,
                 lengths,
                 alignment_history=False):
        """
        Wrap `dec_cell` in an AttentionWrapper driven by Luong attention.

        The mechanism uses `lstm_size` units over `enc_output`, masked by
        `lengths`. The wrapper has no extra attention layer
        (`attention_layer_size=None`), is built with `output_attention=False`,
        and records alignment history only when `alignment_history` is True.
        """
        luong = s2s.LuongAttention(num_units=lstm_size,
                                   memory=enc_output,
                                   memory_sequence_length=lengths,
                                   name='LuongAttention')

        wrapper_options = dict(
            cell=dec_cell,
            attention_mechanism=luong,
            attention_layer_size=None,
            output_attention=False,
            alignment_history=alignment_history)
        return s2s.AttentionWrapper(**wrapper_options)
Пример #11
0
    def build(self, is_train=True):
        # demo_len = self.demo_len
        if self.stack_subsequent_state:
            max_demo_len = self.max_demo_len - 1
            demo_len = self.demo_len - 1
            s_h = tf.stack([
                self.s_h[:, :, :max_demo_len, :, :, :], self.s_h[:, :,
                                                                 1:, :, :, :]
            ],
                           axis=-1)
            depth = self.depth * 2
        else:
            max_demo_len = self.max_demo_len
            demo_len = self.demo_len
            s_h = self.s_h
            depth = self.depth

        # s [bs, h, w, depth] -> feature [bs, v]
        # CNN
        def State_Encoder(s,
                          per,
                          batch_size,
                          scope='State_Encoder',
                          reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                _ = conv2d(s,
                           16,
                           is_train,
                           k_h=3,
                           k_w=3,
                           info=not reuse,
                           batch_norm=True,
                           name='conv1')
                _ = conv2d(_,
                           32,
                           is_train,
                           k_h=3,
                           k_w=3,
                           info=not reuse,
                           batch_norm=True,
                           name='conv2')
                _ = conv2d(_,
                           48,
                           is_train,
                           k_h=3,
                           k_w=3,
                           info=not reuse,
                           batch_norm=True,
                           name='conv3')
                if self.pixel_input:
                    _ = conv2d(_,
                               48,
                               is_train,
                               k_h=3,
                               k_w=3,
                               info=not reuse,
                               batch_norm=True,
                               name='conv4')
                    _ = conv2d(_,
                               48,
                               is_train,
                               k_h=3,
                               k_w=3,
                               info=not reuse,
                               batch_norm=True,
                               name='conv5')
                state_feature = tf.reshape(_, [batch_size, -1])
                if self.state_encoder_fc:
                    state_feature = fc(state_feature,
                                       512,
                                       is_train,
                                       info=not reuse,
                                       name='fc1')
                    state_feature = fc(state_feature,
                                       512,
                                       is_train,
                                       info=not reuse,
                                       name='fc2')
                state_feature = tf.concat([state_feature, per], axis=-1)
                if not reuse:
                    log.info('concat feature {}'.format(state_feature))
                return state_feature

        # s_h [bs, t, h, w, depth] -> feature [bs, v]
        # LSTM
        def Demo_Encoder(s_h,
                         per,
                         seq_lengths,
                         scope='Demo_Encoder',
                         reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                state_features = tf.reshape(
                    State_Encoder(tf.reshape(s_h, [-1, self.h, self.w, depth]),
                                  tf.reshape(per, [-1, self.per_dim]),
                                  self.batch_size * max_demo_len,
                                  reuse=reuse),
                    [self.batch_size, max_demo_len, -1])
                if self.encoder_rnn_type == 'bilstm':
                    fcell = rnn.BasicLSTMCell(num_units=math.ceil(
                        self.num_lstm_cell_units),
                                              state_is_tuple=True)
                    bcell = rnn.BasicLSTMCell(num_units=math.floor(
                        self.num_lstm_cell_units),
                                              state_is_tuple=True)
                    new_h, cell_state = tf.nn.bidirectional_dynamic_rnn(
                        fcell,
                        bcell,
                        state_features,
                        sequence_length=seq_lengths,
                        dtype=tf.float32)
                    new_h = tf.reduce_sum(tf.stack(new_h, axis=2), axis=2)
                    cell_state = rnn.LSTMStateTuple(
                        tf.reduce_sum(tf.stack([cs.c for cs in cell_state],
                                               axis=1),
                                      axis=1),
                        tf.reduce_sum(tf.stack([cs.h for cs in cell_state],
                                               axis=1),
                                      axis=1))
                elif self.encoder_rnn_type == 'lstm':
                    cell = rnn.BasicLSTMCell(
                        num_units=self.num_lstm_cell_units,
                        state_is_tuple=True)
                    new_h, cell_state = tf.nn.dynamic_rnn(
                        cell=cell,
                        dtype=tf.float32,
                        sequence_length=seq_lengths,
                        inputs=state_features)
                elif self.encoder_rnn_type == 'rnn':
                    cell = rnn.BasicRNNCell(num_units=self.num_lstm_cell_units)
                    new_h, cell_state = tf.nn.dynamic_rnn(
                        cell=cell,
                        dtype=tf.float32,
                        sequence_length=seq_lengths,
                        inputs=state_features)
                elif self.encoder_rnn_type == 'gru':
                    cell = rnn.GRUCell(num_units=self.num_lstm_cell_units)
                    new_h, cell_state = tf.nn.dynamic_rnn(
                        cell=cell,
                        dtype=tf.float32,
                        sequence_length=seq_lengths,
                        inputs=state_features)
                else:
                    raise ValueError('Unknown encoder rnn type')

                if self.concat_state_feature_direct_prediction:
                    all_states = tf.concat([new_h, state_features], axis=-1)
                else:
                    all_states = new_h
                return all_states, cell_state.h, cell_state.c

        # program token [bs, len] -> embedded tokens [len] list of [bs, dim]
        # tensors
        # Embedding
        def Token_Embedding(token_dim,
                            embedding_dim,
                            scope='Token_Embedding',
                            reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                # We add token_dim + 1, to use this tokens as a starting token
                # <s>
                embedding_map = tf.get_variable(
                    name="embedding_map",
                    shape=[token_dim + 1, embedding_dim],
                    initializer=tf.random_uniform_initializer(minval=-0.01,
                                                              maxval=0.01))

                def embedding_lookup(t):
                    embedding = tf.nn.embedding_lookup(embedding_map, t)
                    return embedding

                return embedding_lookup

        # program token feature [bs, u] -> program token [bs, dim_program_token]
        # MLP
        def Token_Decoder(f, token_dim, scope='Token_Decoder', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                _ = fc(f,
                       token_dim,
                       is_train,
                       info=not reuse,
                       batch_norm=False,
                       activation_fn=None,
                       name='fc1')
                return _

        # Input {{{
        # =========
        # test_k list of [bs, ac, max_demo_len - 1] tensor
        self.gt_test_actions_onehot = [
            single_test_a_h for single_test_a_h in tf.unstack(
                tf.transpose(self.test_a_h, [0, 1, 3, 2]), axis=1)
        ]
        # test_k list of [bs, max_demo_len - 1] tensor
        self.gt_test_actions_tokens = [
            single_test_a_h_token
            for single_test_a_h_token in tf.unstack(self.test_a_h_tokens,
                                                    axis=1)
        ]

        # a_h = self.a_h
        # }}}

        # Graph {{{
        # =========
        # Demo -> Demo feature
        demo_h_list = []
        demo_c_list = []
        demo_feature_history_list = []
        for i in range(self.k):
            demo_feature_history, demo_h, demo_c = \
                Demo_Encoder(s_h[:, i], self.per[:, i],
                             demo_len[:, i], reuse=i > 0)
            demo_feature_history_list.append(demo_feature_history)
            demo_h_list.append(demo_h)
            demo_c_list.append(demo_c)
            if i == 0: log.warning(demo_feature_history)
        demo_h_stack = tf.stack(demo_h_list, axis=1)  # [bs, k, v]
        demo_c_stack = tf.stack(demo_c_list, axis=1)  # [bs, k, v]
        if self.demo_aggregation == 'concat':  # [bs, k*v]
            demo_h_summary = tf.reshape(demo_h_stack, [self.batch_size, -1])
            demo_c_summary = tf.reshape(demo_c_stack, [self.batch_size, -1])
        elif self.demo_aggregation == 'avgpool':  # [bs, v]
            demo_h_summary = tf.reduce_mean(demo_h_stack, axis=1)
            demo_c_summary = tf.reduce_mean(demo_c_stack, axis=1)
        elif self.demo_aggregation == 'maxpool':  # [bs, v]
            demo_h_summary = tf.squeeze(tf.layers.max_pooling1d(
                demo_h_stack,
                demo_h_stack.get_shape().as_list()[1],
                1,
                padding='valid',
                data_format='channels_last'),
                                        axis=1)
            demo_c_summary = tf.squeeze(tf.layers.max_pooling1d(
                demo_c_stack,
                demo_c_stack.get_shape().as_list()[1],
                1,
                padding='valid',
                data_format='channels_last'),
                                        axis=1)
        else:
            raise ValueError('Unknown demo aggregation type')

        def get_DecoderHelper(embedding_lookup,
                              seq_lengths,
                              token_dim,
                              gt_tokens=None,
                              unroll_type='teacher_forcing'):
            if unroll_type == 'teacher_forcing':
                if gt_tokens is None:
                    raise ValueError('teacher_forcing requires gt_tokens')
                embedding = embedding_lookup(gt_tokens)
                helper = seq2seq.TrainingHelper(embedding, seq_lengths)
            elif unroll_type == 'scheduled_sampling':
                if gt_tokens is None:
                    raise ValueError('scheduled_sampling requires gt_tokens')
                embedding = embedding_lookup(gt_tokens)
                # sample_prob 1.0: always sample from ground truth
                # sample_prob 0.0: always sample from prediction
                helper = seq2seq.ScheduledEmbeddingTrainingHelper(
                    embedding,
                    seq_lengths,
                    embedding_lookup,
                    1.0 - self.sample_prob,
                    seed=None,
                    scheduling_seed=None)
            elif unroll_type == 'greedy':
                # during evaluation, we perform greedy unrolling.
                start_token = tf.zeros([self.batch_size],
                                       dtype=tf.int32) + token_dim
                end_token = token_dim - 1
                helper = seq2seq.GreedyEmbeddingHelper(embedding_lookup,
                                                       start_token, end_token)
            else:
                raise ValueError('Unknown unroll type')
            return helper

        def LSTM_Decoder(visual_h,
                         visual_c,
                         gt_tokens,
                         lstm_cell,
                         unroll_type='teacher_forcing',
                         seq_lengths=None,
                         max_sequence_len=10,
                         token_dim=50,
                         embedding_dim=128,
                         init_state=None,
                         scope='LSTM_Decoder',
                         reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                # augmented embedding with token_dim + 1 (<s>) token
                s_token = tf.zeros([self.batch_size, 1],
                                   dtype=gt_tokens.dtype) + token_dim + 1
                gt_tokens = tf.concat([s_token, gt_tokens[:, :-1]], axis=1)

                embedding_lookup = Token_Embedding(token_dim,
                                                   embedding_dim,
                                                   reuse=reuse)

                # dynamic_decode implementation
                helper = get_DecoderHelper(embedding_lookup,
                                           seq_lengths,
                                           token_dim,
                                           gt_tokens=gt_tokens,
                                           unroll_type=unroll_type)
                projection_layer = layers_core.Dense(token_dim,
                                                     use_bias=False,
                                                     name="output_projection")
                if init_state is None:
                    init_state = rnn.LSTMStateTuple(visual_c, visual_h)
                decoder = seq2seq.BasicDecoder(lstm_cell,
                                               helper,
                                               init_state,
                                               output_layer=projection_layer)
                # pred_length [batch_size]: length of the predicted sequence
                outputs, final_context_state, pred_length = seq2seq.dynamic_decode(
                    decoder,
                    maximum_iterations=max_sequence_len,
                    scope='dynamic_decoder')
                pred_length = tf.expand_dims(pred_length, axis=1)

                # as dynamic_decode generate variable length sequence output,
                # we pad it dynamically to match input embedding shape.
                rnn_output = outputs.rnn_output
                sz = tf.shape(rnn_output)
                dynamic_pad = tf.zeros(
                    [sz[0], max_sequence_len - sz[1], sz[2]],
                    dtype=rnn_output.dtype)
                pred_seq = tf.concat([rnn_output, dynamic_pad], axis=1)
                seq_shape = pred_seq.get_shape().as_list()
                pred_seq.set_shape(
                    [seq_shape[0], max_sequence_len, seq_shape[2]])

                pred_seq = tf.transpose(
                    tf.reshape(pred_seq,
                               [self.batch_size, max_sequence_len, -1]),
                    [0, 2, 1])  # make_dim: [bs, n, len]
                return pred_seq, pred_length, final_context_state

        if self.scheduled_sampling:
            train_unroll_type = 'scheduled_sampling'
        else:
            train_unroll_type = 'teacher_forcing'

        # Attn
        lstm_cell = rnn.BasicLSTMCell(num_units=self.num_lstm_cell_units)
        attn_mechanisms = []
        for j in range(self.test_k):
            attn_mechanisms_k = []
            for i in range(self.k):
                with tf.variable_scope('AttnMechanism', reuse=i > 0 or j > 0):
                    if self.attn_type == 'luong':
                        attn_mechanism = seq2seq.LuongAttention(
                            self.num_lstm_cell_units,
                            demo_feature_history_list[i],
                            memory_sequence_length=self.demo_len[:, i])
                    elif self.attn_type == 'luong_monotonic':
                        attn_mechanism = seq2seq.LuongMonotonicAttention(
                            self.num_lstm_cell_units,
                            demo_feature_history_list[i],
                            memory_sequence_length=self.demo_len[:, i])
                    else:
                        raise ValueError('Unknown attention type')
                attn_mechanisms_k.append(attn_mechanism)
            attn_mechanisms.append(attn_mechanisms_k)

        self.attn_cells = []
        for i in range(self.test_k):
            attn_cell = PoolingAttentionWrapper(
                lstm_cell,
                attn_mechanisms[i],
                attention_layer_size=self.num_lstm_cell_units,
                alignment_history=True,
                output_attention=True,
                pooling='avgpool')
            self.attn_cells.append(attn_cell)

        # Demo + current state -> action
        self.pred_action_list = []
        self.greedy_pred_action_list = []
        self.greedy_pred_action_len_list = []
        for i in range(self.test_k):
            attn_init_state = self.attn_cells[i].zero_state(
                self.batch_size,
                dtype=tf.float32).clone(cell_state=rnn.LSTMStateTuple(
                    demo_h_summary, demo_c_summary))
            embedding_dim = demo_h_summary.get_shape().as_list()[-1]
            pred_action, pred_action_len, action_state = LSTM_Decoder(
                demo_h_summary,
                demo_c_summary,
                self.gt_test_actions_tokens[i],
                self.attn_cells[i],
                unroll_type=train_unroll_type,
                seq_lengths=self.test_action_len[:, i],
                max_sequence_len=self.max_action_len,
                token_dim=self.action_space,
                embedding_dim=embedding_dim,
                init_state=attn_init_state,
                scope='Manipulation',
                reuse=i > 0)
            assert pred_action.get_shape() == \
                self.gt_test_actions_onehot[i].get_shape()
            self.pred_action_list.append(pred_action)

            greedy_attn_init_state = self.attn_cells[i].zero_state(
                self.batch_size,
                dtype=tf.float32).clone(cell_state=rnn.LSTMStateTuple(
                    demo_h_summary, demo_c_summary))
            greedy_pred_action, greedy_pred_action_len, \
                greedy_action_state = LSTM_Decoder(
                    demo_h_summary, demo_c_summary, self.gt_test_actions_tokens[i],
                    self.attn_cells[i], unroll_type='greedy',
                    seq_lengths=self.test_action_len[:, i],
                    max_sequence_len=self.max_action_len,
                    token_dim=self.action_space,
                    embedding_dim=embedding_dim,
                    init_state=greedy_attn_init_state,
                    scope='Manipulation', reuse=True
                )
            assert greedy_pred_action.get_shape() == \
                self.gt_test_actions_onehot[i].get_shape()
            self.greedy_pred_action_list.append(greedy_pred_action)
            self.greedy_pred_action_len_list.append(greedy_pred_action_len)
        # }}}

        # Build losses {{{
        # ================
        def Sequence_Loss(pred_sequence,
                          gt_sequence,
                          pred_sequence_lengths=None,
                          gt_sequence_lengths=None,
                          max_sequence_len=None,
                          token_dim=None,
                          sequence_type='program',
                          name=None):
            """Build masked cross-entropy and accuracy ops for one sequence.

            Both sequences are transposed with perm ``[0, 2, 1]`` and
            flattened into rows of ``token_dim`` entries before the softmax
            cross-entropy and the argmax comparisons are computed.

            Args:
                pred_sequence: logits tensor; presumably laid out as
                    [batch, token_dim, max_sequence_len] -- TODO confirm.
                gt_sequence: label tensor in the same layout as
                    ``pred_sequence``.
                pred_sequence_lengths: predicted lengths; column 0 is read.
                gt_sequence_lengths: ground-truth lengths; column 0 is read.
                max_sequence_len: ``maxlen`` for the masks and the step count
                    used when flattening.
                token_dim: size of the token axis after flattening.
                sequence_type: accepted but not used in this function.
                name: optional name scope (default name "SequenceOutput").

            Returns:
                A ``SequenceLossOutput`` carrying the ground-truth mask, the
                mask-normalised loss, ``[gt_sequence, pred_sequence]``, token
                and sequence accuracy and the per-example ``is_same_seq``
                flag.  ``syntax_acc``, ``is_correct_syntax`` and
                ``pred_tokens`` are always ``None``.
            """
            with tf.name_scope(name, "SequenceOutput") as loss_scope:
                log.warning(loss_scope)

                # Per-example longer / shorter of the two lengths.
                longer_lengths = tf.maximum(pred_sequence_lengths,
                                            gt_sequence_lengths)
                shorter_lengths = tf.minimum(pred_sequence_lengths,
                                             gt_sequence_lengths)

                # Step masks: ground-truth extent, union and overlap.
                target_mask = tf.sequence_mask(
                    gt_sequence_lengths[:, 0], max_sequence_len,
                    dtype=tf.float32, name='mask')
                union_mask = tf.sequence_mask(
                    longer_lengths[:, 0], max_sequence_len,
                    dtype=tf.float32, name='max_mask')
                overlap_mask = tf.sequence_mask(
                    shorter_lengths[:, 0], max_sequence_len,
                    dtype=tf.float32, name='min_mask')

                # One row of token_dim entries per (example, step).
                target_rows = tf.reshape(
                    tf.transpose(gt_sequence, [0, 2, 1]),
                    [self.batch_size * max_sequence_len, token_dim])
                logit_rows = tf.reshape(
                    tf.transpose(pred_sequence, [0, 2, 1]),
                    [self.batch_size * max_sequence_len, token_dim])

                # Cross-entropy per row, averaged over ground-truth steps.
                per_step_xent = tf.nn.softmax_cross_entropy_with_logits(
                    labels=target_rows, logits=logit_rows)
                masked_xent_total = tf.reduce_sum(
                    per_step_xent * tf.reshape(target_mask, [-1]))
                loss = masked_xent_total / tf.reduce_sum(target_mask)

                target_ids = tf.argmax(target_rows, axis=-1)
                predicted_ids = tf.argmax(logit_rows, axis=-1)

                # Token accuracy: matches inside the overlap of both
                # sequences, divided by the size of their union.
                token_hits = tf.to_float(tf.equal(target_ids, predicted_ids))
                num_correct_tokens = tf.reduce_sum(
                    token_hits * tf.reshape(overlap_mask, [-1]))
                token_accuracy = num_correct_tokens / tf.reduce_sum(union_mask)

                # Sequence accuracy: every ground-truth step must match and
                # the two lengths must be equal.
                masked_target_ids = tf.reshape(
                    tf.to_float(target_ids) * tf.reshape(target_mask, [-1]),
                    [self.batch_size, -1])
                masked_predicted_ids = tf.reshape(
                    tf.to_float(predicted_ids) * tf.reshape(target_mask, [-1]),
                    [self.batch_size, -1])
                tokens_match = tf.equal(masked_target_ids,
                                        masked_predicted_ids)
                lengths_match = tf.equal(gt_sequence_lengths[:, 0],
                                         pred_sequence_lengths[:, 0])
                is_same_seq = tf.logical_and(
                    tf.reduce_all(tokens_match, axis=-1), lengths_match)
                seq_accuracy = tf.reduce_sum(
                    tf.to_float(is_same_seq)) / self.batch_size

                return SequenceLossOutput(
                    mask=target_mask,
                    loss=loss,
                    output=[gt_sequence, pred_sequence],
                    token_acc=token_accuracy,
                    seq_acc=seq_accuracy,
                    syntax_acc=None,
                    is_correct_syntax=None,
                    pred_tokens=None,
                    is_same_seq=is_same_seq,
                )

        self.loss = 0
        self.output = []

        # Manipulation network loss
        avg_action_loss = 0
        avg_action_token_acc = 0
        avg_action_seq_acc = 0
        seq_match = []
        for i in range(self.test_k):
            action_stat = Sequence_Loss(
                self.pred_action_list[i],
                self.gt_test_actions_onehot[i],
                pred_sequence_lengths=tf.expand_dims(self.test_action_len[:,
                                                                          i],
                                                     axis=1),
                gt_sequence_lengths=tf.expand_dims(self.test_action_len[:, i],
                                                   axis=1),
                max_sequence_len=self.max_action_len,
                token_dim=self.action_space,
                sequence_type='action',
                name="Action_Sequence_Loss_{}".format(i))
            avg_action_loss += action_stat.loss
            avg_action_token_acc += action_stat.token_acc
            avg_action_seq_acc += action_stat.seq_acc
            seq_match.append(action_stat.is_same_seq)
            self.output.extend(action_stat.output)
        avg_action_loss /= self.test_k
        avg_action_token_acc /= self.test_k
        avg_action_seq_acc /= self.test_k
        avg_action_seq_all_acc = tf.reduce_sum(
            tf.to_float(tf.reduce_all(tf.stack(seq_match, axis=1),
                                      axis=-1))) / self.batch_size
        self.loss += avg_action_loss

        greedy_avg_action_loss = 0
        greedy_avg_action_token_acc = 0
        greedy_avg_action_seq_acc = 0
        greedy_seq_match = []
        for i in range(self.test_k):
            greedy_action_stat = Sequence_Loss(
                self.greedy_pred_action_list[i],
                self.gt_test_actions_onehot[i],
                pred_sequence_lengths=self.greedy_pred_action_len_list[i],
                gt_sequence_lengths=tf.expand_dims(self.test_action_len[:, i],
                                                   axis=1),
                max_sequence_len=self.max_action_len,
                token_dim=self.action_space,
                sequence_type='action',
                name="Greedy_Action_Sequence_Loss_{}".format(i))
            greedy_avg_action_loss += greedy_action_stat.loss
            greedy_avg_action_token_acc += greedy_action_stat.token_acc
            greedy_avg_action_seq_acc += greedy_action_stat.seq_acc
            greedy_seq_match.append(greedy_action_stat.is_same_seq)
        greedy_avg_action_loss /= self.test_k
        greedy_avg_action_token_acc /= self.test_k
        greedy_avg_action_seq_acc /= self.test_k
        greedy_avg_action_seq_all_acc = tf.reduce_sum(
            tf.to_float(
                tf.reduce_all(tf.stack(greedy_seq_match, axis=1),
                              axis=-1))) / self.batch_size
        # }}}

        # Evaluation {{{
        # ==============
        self.report_loss = {}
        self.report_accuracy = {}
        self.report_hist = {}
        self.report_loss['avg_action_loss'] = avg_action_loss
        self.report_accuracy['avg_action_token_acc'] = avg_action_token_acc
        self.report_accuracy['avg_action_seq_acc'] = avg_action_seq_acc
        self.report_accuracy['avg_action_seq_all_acc'] = avg_action_seq_all_acc
        self.report_loss['greedy_avg_action_loss'] = greedy_avg_action_loss
        self.report_accuracy['greedy_avg_action_token_acc'] = \
            greedy_avg_action_token_acc
        self.report_accuracy['greedy_avg_action_seq_acc'] = \
            greedy_avg_action_seq_acc
        self.report_accuracy['greedy_avg_action_seq_all_acc'] = \
            greedy_avg_action_seq_all_acc
        self.report_output = []
        # dummy fetch values for evaler
        self.ground_truth_program = self.program
        self.pred_program = []
        self.greedy_pred_program = []
        self.greedy_pred_program_len = []
        self.greedy_program_is_correct_syntax = []
        self.program_is_correct_syntax = []
        self.program_num_execution_correct = []
        self.program_is_correct_execution = []
        self.greedy_num_execution_correct = []
        self.greedy_is_correct_execution = []

        #

        # Tensorboard Summary {{{
        # =======================
        # Loss
        def train_test_scalar_summary(name, value):
            """Register ``value`` as a scalar summary for train and test.

            The summary is added under ``name`` to the 'train' collection and
            under ``test_<name>`` to the 'test' collection.
            """
            tagged_collections = (
                (name, 'train'),
                ("test_{}".format(name), 'test'),
            )
            for tag, collection in tagged_collections:
                tf.summary.scalar(tag, value, collections=[collection])

        train_test_scalar_summary("loss/loss", self.loss)

        if self.scheduled_sampling:
            train_test_scalar_summary("loss/sample_prob", self.sample_prob)
        train_test_scalar_summary("loss/avg_action_loss", avg_action_loss)
        train_test_scalar_summary("loss/avg_action_token_acc",
                                  avg_action_token_acc)
        train_test_scalar_summary("loss/avg_action_seq_acc",
                                  avg_action_seq_acc)
        train_test_scalar_summary("loss/avg_action_seq_all_acc",
                                  avg_action_seq_all_acc)
        tf.summary.scalar("test_loss/greedy_avg_action_loss",
                          greedy_avg_action_loss,
                          collections=['test'])
        tf.summary.scalar("test_loss/greedy_avg_action_token_acc",
                          greedy_avg_action_token_acc,
                          collections=['test'])
        tf.summary.scalar("test_loss/greedy_avg_action_seq_acc",
                          greedy_avg_action_seq_acc,
                          collections=['test'])
        tf.summary.scalar("test_loss/greedy_avg_action_seq_all_acc",
                          greedy_avg_action_seq_all_acc,
                          collections=['test'])

        def program2str(p_token, p_len):
            """Decode a batch of token arrays into strings.

            For each of the first ``self.batch_size`` entries, the argmax over
            axis 0 of ``p_token[i]`` is truncated to ``p_len[i, 0]`` items and
            converted with ``self.vocab.intseq2str``.  The results are stacked
            into a single numpy array.
            """
            decoded = [
                self.vocab.intseq2str(
                    np.argmax(p_token[row], axis=0)[:p_len[row, 0]])
                for row in range(self.batch_size)
            ]
            return np.stack(decoded, axis=0)

        tf.summary.text('program_id/id',
                        self.program_id,
                        collections=['train'])
        tf.summary.text('program/ground_truth',
                        tf.py_func(program2str,
                                   [self.program, self.program_len],
                                   tf.string),
                        collections=['train'])
        tf.summary.text('test_program_id/id',
                        self.program_id,
                        collections=['test'])
        tf.summary.text('test_program/ground_truth',
                        tf.py_func(program2str,
                                   [self.program, self.program_len],
                                   tf.string),
                        collections=['test'])

        # Visualization
        def visualized_map(pred, gt):
            """Stack prediction, ground truth and a blank map as channels.

            Returns a tensor with one extra trailing axis of size 3 holding,
            in order: the softmax of ``pred`` taken along dimension 1, ``gt``
            unchanged, and zeros shaped like ``pred``.
            """
            blank_channel = tf.expand_dims(tf.zeros_like(pred), axis=-1)
            prob_channel = tf.expand_dims(tf.nn.softmax(pred, dim=1), axis=-1)
            truth_channel = tf.expand_dims(gt, axis=-1)
            channels = [prob_channel, truth_channel, blank_channel]
            return tf.concat(channels, axis=-1)

        # Attention visualization
        def build_alignments(alignment_history):
            """Assemble the alignment histories into one image tensor.

            Each history is stacked, transposed with perm ``[1, 2, 0]``, given
            a trailing axis and scaled by 255.  A one-row strip of value 255
            follows every history as a separator.  All pieces are concatenated
            along axis 1; the last axis is then tiled ``self.k`` times and
            folded into axis 2 by the final reshape.
            """
            panels = []
            for history in alignment_history:
                panel = tf.expand_dims(
                    tf.transpose(history.stack(), [1, 2, 0]), -1) * 255
                panel_shape = tf.shape(panel)
                separator = tf.zeros(
                    [panel_shape[0], 1, panel_shape[2], 1],
                    dtype=tf.float32) + 255
                panels.extend([panel, separator])

            stacked = tf.concat(panels, axis=1)
            tiled = tf.tile(stacked, [1, 1, 1, self.k])
            # The shape of the last history processed defines the output.
            image_shape = [panel_shape[0], -1, panel_shape[2] * self.k, 1]
            return tf.reshape(tiled, image_shape)

        alignments = build_alignments(action_state.alignment_history)
        tf.summary.image("attn", alignments, collections=['train'])
        tf.summary.image("test_attn", alignments, collections=['test'])

        greedy_alignments = build_alignments(
            greedy_action_state.alignment_history)
        tf.summary.image("test_greedy_attn",
                         greedy_alignments,
                         collections=['test'])

        if self.pixel_input:
            tf.summary.image("state/initial_state",
                             self.s_h[:, 0, 0, :, :, :],
                             collections=['train'])
            tf.summary.image("state/demo_program_1",
                             self.s_h[0, 0, :, :, :, :],
                             max_outputs=self.max_demo_len,
                             collections=['train'])

        i = 0  # show only the first demo (among k)
        tf.summary.image("visualized_action/k_{}".format(i),
                         visualized_map(self.pred_action_list[i],
                                        self.gt_test_actions_onehot[i]),
                         collections=['train'])
        tf.summary.image("test_visualized_action/k_{}".format(i),
                         visualized_map(self.pred_action_list[i],
                                        self.gt_test_actions_onehot[i]),
                         collections=['test'])
        tf.summary.image("test_visualized_greedy_action/k_{}".format(i),
                         visualized_map(self.greedy_pred_action_list[i],
                                        self.gt_test_actions_onehot[i]),
                         collections=['test'])

        # Visualize demo features
        if self.debug:
            i = 0  # show only the first images
            tf.summary.image("debug/demo_feature_history/k_{}".format(i),
                             tf.image.grayscale_to_rgb(
                                 tf.expand_dims(demo_feature_history_list[i],
                                                -1)),
                             collections=['train'])
        # }}}
        print('\033[93mSuccessfully loaded the model.\033[0m')
Пример #12
0
    def build_decoder_cell(self):
        """Build the stacked attention decoder cell and its initial state.

        The decoder is ``self.num_layers`` residual GRU cells of width
        ``self.state_size * 2``; only the top cell is wrapped with attention
        over the encoder outputs.  In ``"test"`` mode the encoder tensors are
        tiled ``self.beam_depth`` times for beam search, and in ``"train"``
        mode every cell is wrapped with dropout.

        Returns:
            A ``(cell, initial_state)`` tuple: a ``MultiRNNCell`` over the
            decoder cells and a tuple with one initial state per layer.  The
            lower layers start from the matching encoder last state; the top
            layer starts from an attention-wrapper state whose cell state is
            the encoder's top-layer last state.
        """
        encoder_outputs = self.encoder_outputs
        encoder_last_states = self.encoder_last_states
        encoder_len = self.encoder_len

        # Beam search is used in test mode only.  The same flag drives both
        # the tiling here and the zero-state batch size below, so the two
        # can never disagree.
        use_beam_search = self.mode == "test"
        if use_beam_search:
            # for beam search copy the batch by beam depth times
            encoder_outputs = seq2seq.tile_batch(encoder_outputs,
                                                 multiplier=self.beam_depth)
            encoder_last_states = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, self.beam_depth),
                encoder_last_states)
            encoder_len = seq2seq.tile_batch(self.encoder_len,
                                             self.beam_depth)

        # Bahdanau attention is the default and is always constructed first.
        # NOTE(review): it is built even when Luong replaces it below; kept
        # as-is because dropping it would presumably shift auto-generated
        # variable names -- confirm before restructuring.
        self.attention_mechanism = seq2seq.BahdanauAttention(
            num_units=self.state_size * 2,
            memory=encoder_outputs,
            memory_sequence_length=encoder_len)
        # Luong attention
        if self.attention_mode == "Luong":
            self.attention_mechanism = seq2seq.LuongAttention(
                num_units=self.state_size * 2,
                memory=encoder_outputs,
                memory_sequence_length=encoder_len)

        # instantiate decoder cells (uni-directional multi GRU cell)
        decoder_cell_list = [
            tf.nn.rnn_cell.ResidualWrapper(
                tf.nn.rnn_cell.GRUCell(self.state_size * 2))
            for _ in range(self.num_layers)
        ]
        # apply dropout during training
        if self.mode == "train":
            decoder_cell_list = [
                DropoutWrapper(cell, self.dropout_keep_prob)
                for cell in decoder_cell_list
            ]

        # essential for skip connection: project [inputs, attention] back to
        # the cell width so the residual addition has matching sizes
        def atten_decoder_input_fn(inputs, attention):
            _input_layer = Dense(self.state_size * 2)
            return _input_layer(tf.concat([inputs, attention], 1))

        # attention is applied on the top decoder layer only
        decoder_cell_list[-1] = seq2seq.AttentionWrapper(
            decoder_cell_list[-1],
            self.attention_mechanism,
            self.state_size * 2,
            cell_input_fn=atten_decoder_input_fn)

        # if test mode every batch should be copied by beam depth times
        batch_size = self.batch_size
        if use_beam_search:
            batch_size = batch_size * self.beam_depth

        init_state = [encoder_last_states[i] for i in range(self.num_layers)]
        # The top layer needs an AttentionWrapperState.  zero_state() supplies
        # the attention fields and clone() puts the encoder's top-layer last
        # state into it; zero_state() alone would start that layer from zeros
        # and discard the encoder state.
        init_state[-1] = decoder_cell_list[-1].zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=init_state[-1])
        decoder_init_state = tuple(init_state)
        return tf.nn.rnn_cell.MultiRNNCell(
            decoder_cell_list), decoder_init_state
Пример #13
0
    for x in range(num_layers):
        encoder_last_state_c = tf.concat(
            [encoder_last_state_fw[x].c, encoder_last_state_bw[x].c], 1)
        encoder_last_state_h = tf.concat(
            [encoder_last_state_fw[x].h, encoder_last_state_bw[x].h], 1)
        encoder_last_state.append(
            tf.contrib.rnn.LSTMStateTuple(c=encoder_last_state_c,
                                          h=encoder_last_state_h))
    encoder_last_state = tuple(encoder_last_state)

    #batch_size = batch_size * beam_width
    ######################################################### ends building encoder
    # building training decoder, no beam search
    with tf.variable_scope('shared_attention_mechanism'):
        attention_mechanism = seq2seq.LuongAttention(
            num_units=hidden_dim * 2,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)
    global_decoder_cell = tf.contrib.rnn.MultiRNNCell([
        tf.nn.rnn_cell.BasicLSTMCell(hidden_dim * 2) for _ in range(num_layers)
    ])
    projection_layer = Dense(label_dim)

    decoder_cell = seq2seq.AttentionWrapper(
        cell=global_decoder_cell,
        #tf.nn.rnn_cell.BasicLSTMCell(hidden_dim*2),
        attention_mechanism=attention_mechanism,
        attention_layer_size=hidden_dim * 2)
    #input_vectors = tf.nn.embedding_lookup(tgt_w, decoder_inputs)
    print(decoder_inputs.shape, decoder_inputs.shape)
    #decoder training
    training_helper = seq2seq.TrainingHelper(
Пример #14
0
    def __init__(self,
                 vocab_size,
                 embed_size,
                 num_unit,
                 latent_dim,
                 emoji_dim,
                 batch_size,
                 kl_ceiling,
                 bow_ceiling,
                 decoder_layer=1,
                 start_i=1,
                 end_i=2,
                 beam_width=0,
                 maximum_iterations=50,
                 max_gradient_norm=5,
                 lr=1e-3,
                 dropout=0.2,
                 num_gpu=2,
                 cell_type=tf.nn.rnn_cell.GRUCell,
                 is_seq2seq=False):
        self.ori_sample = None
        self.rep_sample = None
        self.out_sample = None

        self.sess = None

        self.loss_weight = tf.placeholder_with_default(0., shape=())
        self.policy_weight = tf.placeholder_with_default(1., shape=())
        self.ac_vec = tf.placeholder(tf.float32,
                                     shape=[batch_size],
                                     name="accuracy_vector")
        self.ac5_vec = tf.placeholder(tf.float32,
                                      shape=[batch_size],
                                      name="top5_accuracy_vector")

        self.is_policy = tf.placeholder_with_default(False, shape=())
        shape = [batch_size, latent_dim]
        self.rdm = tf.placeholder_with_default(np.zeros(shape,
                                                        dtype=np.float32),
                                               shape=shape)
        self.q_rdm = tf.placeholder_with_default(np.zeros(shape,
                                                          dtype=np.float32),
                                                 shape=shape)

        self.end_i = end_i
        self.batch_size = batch_size
        self.num_gpu = num_gpu
        self.num_unit = num_unit
        self.dropout = tf.placeholder_with_default(dropout, (), name="dropout")
        self.beam_width = beam_width
        self.cell_type = cell_type

        self.emoji = tf.placeholder(tf.int32, shape=[batch_size], name="emoji")
        self.ori = tf.placeholder(tf.int32,
                                  shape=[None, batch_size],
                                  name="original_tweet")  # [len, batch_size]
        self.ori_len = tf.placeholder(tf.int32,
                                      shape=[batch_size],
                                      name="original_tweet_length")
        self.rep = tf.placeholder(tf.int32,
                                  shape=[None, batch_size],
                                  name="response_tweet")
        self.rep_len = tf.placeholder(tf.int32,
                                      shape=[batch_size],
                                      name="response_tweet_length")
        self.rep_input = tf.placeholder(tf.int32,
                                        shape=[None, batch_size],
                                        name="response_start_tag")
        self.rep_output = tf.placeholder(tf.int32,
                                         shape=[None, batch_size],
                                         name="response_end_tag")

        self.reward = tf.placeholder(tf.float32,
                                     shape=[batch_size],
                                     name="reward")

        self.kl_weight = tf.placeholder_with_default(1.,
                                                     shape=(),
                                                     name="kl_weight")

        self.placeholders = [
            self.emoji, self.ori, self.ori_len, self.rep, self.rep_len,
            self.rep_input, self.rep_output
        ]

        with tf.variable_scope("embeddings"):
            embedding = Embedding(vocab_size, embed_size)

            ori_emb = embedding(
                self.ori)  # [max_len, batch_size, embedding_size]
            rep_emb = embedding(self.rep)
            rep_input_emb = embedding(self.rep_input)
            emoji_emb = embedding(self.emoji)  # [batch_size, embedding_size]

        with tf.variable_scope("original_tweet_encoder"):
            ori_encoder_output, ori_encoder_state = build_bidirectional_rnn(
                num_unit,
                ori_emb,
                self.ori_len,
                cell_type,
                num_gpu,
                self.dropout,
                base_gpu=0)
            ori_encoder_state_flat = tf.concat(
                [ori_encoder_state[0], ori_encoder_state[1]], axis=1)

        emoji_vec = tf.layers.dense(emoji_emb,
                                    emoji_dim,
                                    activation=tf.nn.tanh)
        self.emoji_vec = emoji_emb
        # emoji_vec = tf.ones([batch_size, emoji_dim], tf.float32)
        condition_flat = tf.concat([ori_encoder_state_flat, emoji_vec], axis=1)

        with tf.variable_scope("response_tweet_encoder"):
            _, rep_encoder_state = build_bidirectional_rnn(num_unit,
                                                           rep_emb,
                                                           self.rep_len,
                                                           cell_type,
                                                           num_gpu,
                                                           self.dropout,
                                                           base_gpu=2)
            rep_encoder_state_flat = tf.concat(
                [rep_encoder_state[0], rep_encoder_state[1]], axis=1)

        with tf.variable_scope("representation_network"):
            rn_input = tf.concat([rep_encoder_state_flat, condition_flat],
                                 axis=1)
            # simpler representation network
            # r_hidden = rn_input
            r_hidden = tf.layers.dense(
                rn_input,
                latent_dim,
                activation=tf.nn.relu,
                name="r_net_hidden")  # int(1.6 * latent_dim)
            r_hidden_mu = tf.layers.dense(
                r_hidden, latent_dim,
                activation=tf.nn.relu)  # int(1.3 * latent_dim)
            r_hidden_var = tf.layers.dense(r_hidden,
                                           latent_dim,
                                           activation=tf.nn.relu)
            self.mu = tf.layers.dense(r_hidden_mu,
                                      latent_dim,
                                      activation=tf.nn.tanh,
                                      name="q_mean")
            self.log_var = tf.layers.dense(r_hidden_var,
                                           latent_dim,
                                           activation=tf.nn.tanh,
                                           name="q_log_var")

        with tf.variable_scope("prior_network"):
            # simpler prior network
            # p_hidden = condition_flat
            p_hidden = tf.layers.dense(condition_flat,
                                       int(0.62 * latent_dim),
                                       activation=tf.nn.relu,
                                       name="r_net_hidden")
            p_hidden_mu = tf.layers.dense(p_hidden,
                                          int(0.77 * latent_dim),
                                          activation=tf.nn.relu)
            p_hidden_var = tf.layers.dense(p_hidden,
                                           int(0.77 * latent_dim),
                                           activation=tf.nn.relu)
            self.p_mu = tf.layers.dense(p_hidden_mu,
                                        latent_dim,
                                        activation=tf.nn.tanh,
                                        name="p_mean")
            self.p_log_var = tf.layers.dense(p_hidden_var,
                                             latent_dim,
                                             activation=tf.nn.tanh,
                                             name="p_log_var")

        with tf.variable_scope("reparameterization"):
            self.normal = tf.cond(
                self.is_policy, lambda: self.rdm,
                lambda: tf.random_normal(shape=tf.shape(self.mu)))
            self.z_sample = self.mu + tf.exp(self.log_var / 2.) * self.normal

            self.q_normal = tf.cond(
                self.is_policy, lambda: self.q_rdm,
                lambda: tf.random_normal(shape=tf.shape(self.p_mu)))
            self.q_z_sample = self.p_mu + tf.exp(
                self.p_log_var / 2.) * self.q_normal

        if is_seq2seq:
            self.z_sample = self.z_sample - self.z_sample
            self.q_z_sample = self.q_z_sample - self.q_z_sample

        with tf.variable_scope("decoder_train") as decoder_scope:
            if decoder_layer == 2:
                train_decoder_init_state = (
                    tf.concat([self.z_sample, ori_encoder_state[0], emoji_vec],
                              axis=1),
                    tf.concat([self.z_sample, ori_encoder_state[1], emoji_vec],
                              axis=1))
                dim = latent_dim + num_unit + emoji_dim
                cell = tf.nn.rnn_cell.MultiRNNCell([
                    create_rnn_cell(dim, 2, cell_type, num_gpu, self.dropout),
                    create_rnn_cell(dim, 3, cell_type, num_gpu, self.dropout)
                ])
            else:
                train_decoder_init_state = tf.concat(
                    [self.z_sample, ori_encoder_state_flat, emoji_vec], axis=1)
                dim = latent_dim + 2 * num_unit + emoji_dim
                cell = create_rnn_cell(dim, 2, cell_type, num_gpu,
                                       self.dropout)

            with tf.variable_scope("attention"):
                memory = tf.concat(
                    [ori_encoder_output[0], ori_encoder_output[1]], axis=2)
                memory = tf.transpose(memory, [1, 0, 2])

                attention_mechanism = seq2seq.LuongAttention(
                    dim,
                    memory,
                    memory_sequence_length=self.ori_len,
                    scale=True)
                # attention_mechanism = seq2seq.BahdanauAttention(
                #     num_unit, memory, memory_sequence_length=self.ori_len)

            decoder_cell = seq2seq.AttentionWrapper(
                cell, attention_mechanism, attention_layer_size=dim
            )  # TODO: add_name; what atten layer size means
            # decoder_cell = cell

            helper = seq2seq.TrainingHelper(rep_input_emb,
                                            self.rep_len + 1,
                                            time_major=True)
            projection_layer = layers_core.Dense(vocab_size,
                                                 use_bias=False,
                                                 name="output_projection")
            decoder = seq2seq.BasicDecoder(
                decoder_cell,
                helper,
                decoder_cell.zero_state(
                    batch_size,
                    tf.float32).clone(cell_state=train_decoder_init_state),
                output_layer=projection_layer)
            train_outputs, _, _ = seq2seq.dynamic_decode(
                decoder,
                output_time_major=True,
                swap_memory=True,
                scope=decoder_scope)
            self.logits = train_outputs.rnn_output

        with tf.variable_scope("decoder_infer") as decoder_scope:
            # normal_sample = tf.random_normal(shape=(batch_size, latent_dim))

            if decoder_layer == 2:
                infer_decoder_init_state = (tf.concat(
                    [self.q_z_sample, ori_encoder_state[0], emoji_vec],
                    axis=1),
                                            tf.concat([
                                                self.q_z_sample,
                                                ori_encoder_state[1], emoji_vec
                                            ],
                                                      axis=1))
            else:
                infer_decoder_init_state = tf.concat(
                    [self.q_z_sample, ori_encoder_state_flat, emoji_vec],
                    axis=1)

            start_tokens = tf.fill([batch_size], start_i)
            end_token = end_i

            if beam_width > 0:
                infer_decoder_init_state = seq2seq.tile_batch(
                    infer_decoder_init_state, multiplier=beam_width)
                decoder = seq2seq.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=embedding.coder,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=decoder_cell.zero_state(
                        batch_size * beam_width,
                        tf.float32).clone(cell_state=infer_decoder_init_state),
                    beam_width=beam_width,
                    output_layer=projection_layer,
                    length_penalty_weight=0.0)
            else:
                helper = seq2seq.GreedyEmbeddingHelper(embedding.coder,
                                                       start_tokens, end_token)
                decoder = seq2seq.BasicDecoder(
                    decoder_cell,
                    helper,
                    decoder_cell.zero_state(
                        batch_size,
                        tf.float32).clone(cell_state=infer_decoder_init_state),
                    output_layer=projection_layer  # applied per timestep
                )

            # Dynamic decoding
            infer_outputs, _, infer_lengths = seq2seq.dynamic_decode(
                decoder,
                maximum_iterations=maximum_iterations,
                output_time_major=True,
                swap_memory=True,
                scope=decoder_scope)
            if beam_width > 0:
                self.result = infer_outputs.predicted_ids
            else:
                self.result = infer_outputs.sample_id
                self.result_lengths = infer_lengths

        with tf.variable_scope("loss"):
            max_time = tf.shape(self.rep_output)[0]
            with tf.variable_scope("reconstruction"):
                # TODO: use inference decoder's logits to compute recon_loss
                cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(  # ce = [len, batch_size]
                    labels=self.rep_output,
                    logits=self.logits)
                # rep: [len, batch_size]; logits: [len, batch_size, vocab_size]
                target_mask = tf.sequence_mask(self.rep_len + 1,
                                               max_time,
                                               dtype=self.logits.dtype)
                # time_major
                target_mask_t = tf.transpose(target_mask)  # max_len batch_size
                self.recon_losses = tf.reduce_sum(cross_entropy *
                                                  target_mask_t,
                                                  axis=0)
                self.recon_loss = tf.reduce_sum(
                    cross_entropy * target_mask_t) / batch_size

            with tf.variable_scope("latent"):
                # without prior network
                # self.kl_loss = 0.5 * tf.reduce_sum(tf.exp(self.log_var) + self.mu ** 2 - 1. - self.log_var, 0)
                self.kl_losses = 0.5 * tf.reduce_sum(
                    tf.exp(self.log_var - self.p_log_var) +
                    (self.mu - self.p_mu)**2 / tf.exp(self.p_log_var) - 1. -
                    self.log_var + self.p_log_var,
                    axis=1)
                self.kl_loss = tf.reduce_mean(self.kl_losses)

            with tf.variable_scope("bow"):
                # self.bow_loss = self.kl_weight * 0
                mlp_b = layers_core.Dense(vocab_size,
                                          use_bias=False,
                                          name="MLP_b")
                # is it a mistake that we only model on latent variable?
                latent_logits = mlp_b(
                    tf.concat(
                        [self.z_sample, ori_encoder_state_flat, emoji_vec],
                        axis=1))  # [batch_size, vocab_size]
                latent_logits = tf.expand_dims(
                    latent_logits, 0)  # [1, batch_size, vocab_size]
                latent_logits = tf.tile(
                    latent_logits,
                    [max_time, 1, 1])  # [max_time, batch_size, vocab_size]

                cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(  # ce = [len, batch_size]
                    labels=self.rep_output,
                    logits=latent_logits)
                self.bow_losses = tf.reduce_sum(cross_entropy * target_mask_t,
                                                axis=0)
                self.bow_loss = tf.reduce_sum(
                    cross_entropy * target_mask_t) / batch_size

            if is_seq2seq:
                self.kl_losses = self.kl_losses - self.kl_losses
                self.bow_losses = self.bow_losses - self.bow_losses
                self.kl_loss = self.kl_loss - self.kl_loss
                self.bow_loss = self.bow_loss - self.bow_loss

            self.losses = self.recon_losses + self.kl_losses * self.kl_weight * kl_ceiling + self.bow_losses * bow_ceiling
            self.loss = tf.reduce_mean(self.losses)

        # Calculate and clip gradients
        with tf.variable_scope("optimization"):
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(
                gradients, max_gradient_norm)

            # Optimization
            optimizer = tf.train.AdamOptimizer(lr)
            self.update_step = optimizer.apply_gradients(
                zip(clipped_gradients, params))

        with tf.variable_scope("policy_loss"):
            prob = tf.nn.softmax(
                infer_outputs.rnn_output)  # [max_len, batch_size, vocab_size]
            prob = tf.clip_by_value(prob, 1e-15, 1000.)
            output_prob = tf.reduce_max(tf.log(prob),
                                        axis=2)  # [max_len, batch_size]
            seq_log_prob = tf.reduce_sum(output_prob, axis=0)  # batch_size
            # reward = tf.nn.relu(self.reward)
            self.policy_losses = -self.reward * seq_log_prob
            self.policy_losses *= (0.5 - 1) * self.ac5_vec + 1

        with tf.variable_scope("policy_optimization"):
            # zero = tf.constant(0, dtype=tf.float32)
            # where = tf.cast(tf.less(self.reward, zero), tf.float32)
            # recon = tf.reduce_sum(self.recon_losses * where) / tf.reduce_sum(where)

            final_loss = self.policy_losses * (
                1 - self.ac_vec) * self.policy_weight
            final_loss += self.losses * self.loss_weight
            self.policy_loss = tf.reduce_mean(final_loss)

            # final_loss = self.losses * self.loss_weight + self.policy_losses * self.policy_weight
            # final_loss *= (1 - self.ac_vec)
            # self.policy_loss = tf.reduce_sum(final_loss) / tf.reduce_sum((1 - self.ac_vec))

            gradients = tf.gradients(self.policy_loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(
                gradients, max_gradient_norm)
            optimizer = tf.train.AdamOptimizer(lr)
            self.policy_step = optimizer.apply_gradients(
                zip(clipped_gradients, params))
Пример #15
0
    def build_decode_cell(self):
        """Build the stacked decoder cell and the initial state that matches it.

        The encoder outputs, final state and input lengths are read from
        ``self``. When ``self.use_beamsearch_decode`` is set, each of them is
        first tiled ``self.beam_with`` times with ``seq2seq.tile_batch``.
        Only the last cell of the stack is wrapped in a
        ``seq2seq.AttentionWrapper``.

        Side effects:
            Sets ``self.attention_mechanism`` and ``self.decoder_cell_list``.

        Returns:
            A pair ``(cell, initial_state)``: an ``rnn.MultiRNNCell`` built
            from ``self.decoder_cell_list`` and a tuple holding the encoder's
            per-layer final states, with the last entry replaced by the
            attention wrapper's ``zero_state``.
        """
        memory = self.encoder_outputs
        last_state = self.encoder_last_state
        memory_lengths = self.encoder_inputs_length

        # Beam search keeps `beam_with` hypotheses per example, so every
        # encoder-side tensor is repeated that many times.
        if self.use_beamsearch_decode:
            memory = seq2seq.tile_batch(self.encoder_outputs,
                                        multiplier=self.beam_with)
            last_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, self.beam_with), last_state)
            memory_lengths = seq2seq.tile_batch(self.encoder_inputs_length,
                                                multiplier=self.beam_with)

        # Bahdanau-style attention (https://arxiv.org/abs/1409.0473) is always
        # constructed first; it is replaced below when Luong is requested.
        # NOTE(review): the discarded Bahdanau mechanism may already have added
        # ops/variables to the graph - confirm checkpoint compatibility before
        # turning this into an if/else.
        attention_kwargs = dict(num_units=self.hidden_units,
                                memory=memory,
                                memory_sequence_length=memory_lengths)
        self.attention_mechanism = seq2seq.BahdanauAttention(
            **attention_kwargs)
        if self.attention_type.lower() == 'luong':
            self.attention_mechanism = seq2seq.LuongAttention(
                **attention_kwargs)

        # Create the decoder cells, one per layer.
        cells = []
        for _ in range(self.depth):
            cells.append(self.build_single_cell())
        self.decoder_cell_list = cells

        def attn_decoder_input_fn(inputs, attention):
            """Optionally mix the previous attention into the cell input."""
            if self.attn_input_feeding:
                # Essential when use_residual=True
                feed_layer = Dense(self.hidden_units,
                                   dtype=self.dtype,
                                   name='attn_input_feeding')
                return feed_layer(array_ops.concat([inputs, attention], -1))
            return inputs

        # Attention is applied on the top decoder layer only: wrap that cell
        # with the attention mechanism and put it back into the stack.
        top_cell = seq2seq.AttentionWrapper(
            cell=cells[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=last_state[-1],
            alignment_history=False,
            name='Attention_wrapper')
        self.decoder_cell_list[-1] = top_cell

        # With beam search the batch dimension seen by zero_state is
        # beam_with times the original batch size.
        batch_size = self.batch_size
        if self.use_beamsearch_decode:
            batch_size = self.batch_size * self.beam_with

        # The wrapped top layer needs its state in AttentionWrapperState form,
        # which zero_state provides; the lower layers reuse the encoder's
        # final states unchanged.
        layer_states = list(last_state)
        layer_states[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=batch_size, dtype=self.dtype)
        decoder_initial_state = tuple(layer_states)
        return rnn.MultiRNNCell(self.decoder_cell_list), decoder_initial_state
Пример #16
0
    def __init__(self, mode, vocab_size, target_vocab_size, emb_dim,
                 encoder_num_units, encoder_num_layers, decoder_num_units,
                 decoder_num_layers, dropout_emb, dropout_hidden, tgt_sos_id,
                 tgt_eos_id, learning_rate, clip_norm, attention_option,
                 beam_size, optimizer, maximum_iterations):
        """Build the attention seq2seq graph for training or greedy inference.

        A single-layer bidirectional LSTM encoder feeds a single-layer LSTM
        decoder wrapped in ``seq2seq.AttentionWrapper``. In "train" mode the
        decoder is teacher-forced and a loss, summaries and a train op are
        added; in "infer" mode the decoder is driven by
        ``seq2seq.GreedyEmbeddingHelper``.

        Args:
            mode: Either "train" or "infer".
            vocab_size: Number of rows of the embedding matrix that is shared
                by encoder and decoder inputs.
            target_vocab_size: Number of units of the output projection.
            emb_dim: Embedding width.
            encoder_num_units: LSTM size of each encoder direction; must
                equal ``decoder_num_units``.
            encoder_num_layers: Not referenced by this constructor.
            decoder_num_units: LSTM size of the decoder cell.
            decoder_num_layers: Not referenced by this constructor.
            dropout_emb: Dropout rate applied to the embedded inputs in
                "train" mode (keep probability is ``1 - dropout_emb``).
            dropout_hidden: Dropout rate applied to LSTM outputs in "train"
                mode (keep probability is ``1 - dropout_hidden``).
            tgt_sos_id: Token id used to start greedy decoding.
            tgt_eos_id: Token id that ends greedy decoding.
            learning_rate: Initial value of the non-trainable
                ``self.learning_rate`` variable.
            clip_norm: Global-norm threshold used to clip gradients.
            attention_option: One of "luong", "scaled_luong", "bahdanau" or
                "normed_bahdanau".
            beam_size: Not referenced by this constructor.
            optimizer: Not referenced by this constructor; Adam is always
                used.
            maximum_iterations: Upper bound on decoding steps in "infer" mode.

        Raises:
            AssertionError: If ``mode`` is invalid or the encoder and decoder
                sizes differ.
            ValueError: If ``attention_option`` is not recognised.
        """
        assert mode in ["train", "infer"], "invalid mode!"
        assert encoder_num_units == decoder_num_units, "encoder num_units **must** match decoder num_units"
        self.target_vocab_size = target_vocab_size

        # inputs (the batch size is later read from dimension 0 of
        # encoder_inputs)
        self.encoder_inputs = tf.placeholder(tf.int32,
                                             shape=[None, None],
                                             name='encoder_inputs')
        self.decoder_inputs = tf.placeholder(tf.int32,
                                             shape=[None, None],
                                             name='decoder_inputs')
        self.decoder_outputs = tf.placeholder(tf.int32,
                                              shape=[None, None],
                                              name='decoder_outputs')
        self.encoder_lengths = tf.placeholder(tf.int32,
                                              shape=[None],
                                              name='encoder_lengths')
        self.decoder_lengths = tf.placeholder(tf.int32,
                                              shape=[None],
                                              name='decoder_lengths')

        # cell
        def cell(num_units):
            """Return an LSTM cell, with output dropout when training."""
            cell = rnn.BasicLSTMCell(num_units=num_units)
            if mode == 'train':
                cell = rnn.DropoutWrapper(cell=cell,
                                          output_keep_prob=1 - dropout_hidden)
            return cell

        # embeddings (shared by the encoder, the decoder and the greedy
        # helper)
        self.embeddings = tf.get_variable('embeddings',
                                          shape=[vocab_size, emb_dim],
                                          dtype=tf.float32)

        # Encoder
        with tf.variable_scope('encoder'):
            # embeddings
            encoder_inputs_emb = tf.nn.embedding_lookup(
                self.embeddings, self.encoder_inputs)
            if mode == 'train':
                encoder_inputs_emb = tf.nn.dropout(encoder_inputs_emb,
                                                   1 - dropout_emb)

            # encoder_rnn_cell
            fw_encoder_cell = cell(encoder_num_units)
            bw_encoder_cell = cell(encoder_num_units)

            # bi_lstm encoder
            (encoder_outputs_fw, encoder_outputs_bw), (
                encoder_state_fw,
                encoder_state_bw) = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=fw_encoder_cell,
                    cell_bw=bw_encoder_cell,
                    inputs=encoder_inputs_emb,
                    sequence_length=self.encoder_lengths,
                    dtype=tf.float32)
            encoder_outputs = tf.concat(
                [encoder_outputs_fw, encoder_outputs_bw], 2)

            # A linear layer to reduce the encoder's final FW and BW state
            # into a single initial state for the decoder. This is needed
            # because the encoder is bidirectional but the decoder is not.
            encoder_states_c = tf.layers.dense(
                inputs=tf.concat([encoder_state_fw.c, encoder_state_bw.c],
                                 axis=-1),
                units=encoder_num_units,
                activation=None,
                use_bias=False)
            encoder_states_h = tf.layers.dense(
                inputs=tf.concat([encoder_state_fw.h, encoder_state_bw.h],
                                 axis=-1),
                units=encoder_num_units,
                activation=None,
                use_bias=False)
            encoder_states = rnn.LSTMStateTuple(encoder_states_c,
                                                encoder_states_h)

            encoder_lengths = self.encoder_lengths

        # Decoder
        with tf.variable_scope('decoder'):
            decoder_inputs_emb = tf.nn.embedding_lookup(
                self.embeddings, self.decoder_inputs)
            if mode == 'train':
                decoder_inputs_emb = tf.nn.dropout(decoder_inputs_emb,
                                                   1 - dropout_emb)
            # decoder_rnn_cell
            decoder_cell = cell(decoder_num_units)
            with tf.variable_scope('attention_mechanism'):
                # Luong variants leave the cell input untouched and emit the
                # attention vector; Bahdanau variants concatenate the previous
                # attention to the input and emit the cell output.
                if attention_option == "luong":
                    attention_mechanism = seq2seq.LuongAttention(
                        num_units=decoder_num_units,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_lengths)
                    cell_input_fn = lambda inputs, attention: inputs
                    output_attention = True
                elif attention_option == "scaled_luong":
                    attention_mechanism = seq2seq.LuongAttention(
                        num_units=decoder_num_units,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_lengths,
                        scale=True)
                    cell_input_fn = lambda inputs, attention: inputs
                    output_attention = True
                elif attention_option == "bahdanau":
                    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                        num_units=decoder_num_units,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_lengths)
                    cell_input_fn = lambda inputs, attention: tf.concat(
                        [inputs, attention], -1)
                    output_attention = False
                elif attention_option == "normed_bahdanau":
                    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                        num_units=decoder_num_units,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_lengths,
                        normalize=True)
                    cell_input_fn = lambda inputs, attention: tf.concat(
                        [inputs, attention], -1)
                    output_attention = False
                else:
                    raise ValueError("Unknown attention option %s" %
                                     attention_option)
            # # Only generate alignment in greedy INFER mode.
            # alignment_history = (mode == 'infer' and beam_size==0)
            alignment_history = False
            decoder_cell = seq2seq.AttentionWrapper(
                cell=decoder_cell,
                attention_mechanism=attention_mechanism,
                attention_layer_size=decoder_num_units,
                alignment_history=alignment_history,
                cell_input_fn=cell_input_fn,
                output_attention=output_attention)

            # Start from the wrapper's zero state, then swap in the projected
            # encoder state as the inner cell state.
            batch_size = tf.shape(self.encoder_inputs)[0]
            decoder_initial_state = decoder_cell.zero_state(
                batch_size=batch_size, dtype=tf.float32)
            decoder_initial_state = decoder_initial_state.clone(
                cell_state=encoder_states)

            projection_layer = layers_core.Dense(units=target_vocab_size,
                                                 use_bias=False)

            # train/infer
            if mode == 'train':
                # helper
                helper = seq2seq.TrainingHelper(
                    inputs=decoder_inputs_emb,
                    sequence_length=self.decoder_lengths)
                # decoder
                decoder = seq2seq.BasicDecoder(
                    cell=decoder_cell,
                    helper=helper,
                    initial_state=decoder_initial_state,
                    output_layer=projection_layer)
                # dynamic decoding
                self.final_outputs, self.final_state, self.final_sequence_lengths = seq2seq.dynamic_decode(
                    decoder=decoder, swap_memory=True)
            else:
                start_tokens = tf.fill([batch_size], tgt_sos_id)
                end_token = tgt_eos_id

                # helper
                helper = seq2seq.GreedyEmbeddingHelper(
                    embedding=self.embeddings,
                    start_tokens=start_tokens,
                    end_token=end_token)
                # decoder
                decoder = seq2seq.BasicDecoder(
                    cell=decoder_cell,
                    helper=helper,
                    initial_state=decoder_initial_state,
                    output_layer=projection_layer)

                # dynamic decoding
                self.final_outputs, self.final_state, self.final_sequence_lengths = seq2seq.dynamic_decode(
                    decoder=decoder,
                    maximum_iterations=maximum_iterations,
                    swap_memory=True)

            self.logits = self.final_outputs.rnn_output
            self.sample_id = self.final_outputs.sample_id

        if mode == 'train':
            # loss: masked token-level cross entropy, summed and divided by
            # the batch size
            with tf.variable_scope('loss'):
                cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=self.decoder_outputs, logits=self.logits)
                masks = tf.sequence_mask(lengths=self.decoder_lengths,
                                         dtype=tf.float32)
                self.loss = tf.reduce_sum(
                    cross_entropy * masks) / tf.to_float(batch_size)
                tf.summary.scalar('loss', self.loss)

            # summaries
            self.merged = tf.summary.merge_all()

            # train_op
            self.learning_rate = tf.Variable(learning_rate, trainable=False)
            # The step counter is bookkeeping, not a model parameter: keep it
            # out of the trainable-variables collection so it is not handed to
            # tf.gradients / apply_gradients below.
            self.global_step = tf.Variable(0,
                                           dtype=tf.int32,
                                           trainable=False)
            tvars = tf.trainable_variables()
            clipped_gradients, _ = tf.clip_by_global_norm(
                tf.gradients(self.loss, tvars), clip_norm=clip_norm)
            optimizer = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate)
            self.train_op = optimizer.apply_gradients(
                zip(clipped_gradients, tvars), global_step=self.global_step)
def build_train_graph(params_dict):
    """Build the training graph of the attention encoder-decoder.

    The encoder is a stack of bidirectional LSTM layers; the decoder is a
    single LSTM cell of twice the encoder width, wrapped with Luong attention
    over the encoder outputs and driven by a ``seq2seq.TrainingHelper``.
    The embedding matrix is initialised from the module-level ``glove``.

    Args:
        params_dict: Mapping that must provide
            "num_encoder_layers" (number of bidirectional layers, >= 1),
            "lstm_units" (size of each encoder direction) and
            "batch_size" (fixed batch size used for the zero states).

    Returns:
        A dict with the placeholders ("keep_prob", "document_tokens",
        "answer_masks", "encoder_lengths", "decoder_inputs",
        "decoder_labels", "decoder_lengths") and the output tensors
        ("training_logits", "prob_logits", "loss").

    Raises:
        ValueError: If "num_encoder_layers" is smaller than 1.
    """
    # Without at least one encoder layer there is no final encoder state to
    # hand to the decoder; fail early, before anything is added to the graph.
    if params_dict["num_encoder_layers"] < 1:
        raise ValueError(
            "num_encoder_layers must be >= 1, got %r" %
            (params_dict["num_encoder_layers"],))

    # Building the Embedding layer + placeholders
    keep_prob = tf.placeholder(tf.float32)
    embedding = tf.get_variable("embedding", initializer=glove, trainable=True)
    document_tokens = tf.placeholder(tf.int32,
                                     shape=[None, None],
                                     name="document_tokens")
    document_emb = tf.nn.embedding_lookup(embedding, document_tokens)
    answer_masks = tf.placeholder(tf.float32,
                                  shape=[None, None, None],
                                  name="answer_masks")
    decoder_inputs = tf.placeholder(tf.int32,
                                    shape=[None, None],
                                    name="decoder_inputs")
    decoder_labels = tf.placeholder(tf.int32,
                                    shape=[None, None],
                                    name="decoder_labels")
    decoder_lengths = tf.placeholder(tf.int32,
                                     shape=[None],
                                     name="decoder_lengths")
    encoder_lengths = tf.placeholder(tf.int32,
                                     shape=[None],
                                     name="encoder_lengths")
    decoder_emb = tf.nn.embedding_lookup(embedding, decoder_inputs)
    question_mask = tf.sequence_mask(decoder_lengths, dtype=tf.float32)
    # Output projection back onto the vocabulary (one unit per embedding row).
    projection = Dense(embedding.shape[0], use_bias=False)

    training_helper = seq2seq.TrainingHelper(inputs=decoder_emb,
                                             sequence_length=decoder_lengths,
                                             time_major=False)

    # Building the Encoder
    # The encoder input is the batched product of the answer masks and the
    # document embeddings.
    # NOTE(review): presumably this selects/pools the answer tokens out of the
    # document - confirm the mask layout against the data pipeline.
    encoder_inputs = tf.matmul(answer_masks,
                               document_emb,
                               name="encoder_inputs")

    output = encoder_inputs
    for n in range(params_dict["num_encoder_layers"]):
        cell_fw = LSTMCell(params_dict["lstm_units"],
                           forget_bias=1.0,
                           state_is_tuple=True)
        cell_bw = LSTMCell(params_dict["lstm_units"],
                           forget_bias=1.0,
                           state_is_tuple=True)
        cell_fw = DropoutWrapper(
            cell_fw,
            output_keep_prob=keep_prob,
        )
        cell_bw = DropoutWrapper(
            cell_bw,
            output_keep_prob=keep_prob,
        )

        state_fw = cell_fw.zero_state(params_dict["batch_size"], tf.float32)
        state_bw = cell_bw.zero_state(params_dict["batch_size"], tf.float32)

        # Each layer consumes the concatenated outputs of the previous one.
        (output_fw,
         output_bw), encoder_state = tf.nn.bidirectional_dynamic_rnn(
             cell_fw,
             cell_bw,
             output,
             initial_state_fw=state_fw,
             initial_state_bw=state_bw,
             sequence_length=encoder_lengths,
             dtype=tf.float32,
             scope='encoder_rnn_' + str(n))
        output = tf.concat([output_fw, output_bw], axis=2)

    encoder_final_output = output
    # Only the last layer's state is kept: forward and backward c/h are
    # concatenated so they match the decoder cell, which is twice as wide.
    encoder_state_c = tf.concat((encoder_state[0][0], encoder_state[1][0]), -1)
    encoder_state_h = tf.concat((encoder_state[0][1], encoder_state[1][1]), -1)
    encoder_final_state = LSTMStateTuple(encoder_state_c, encoder_state_h)

    # Attention mechanism
    attention_mechanism = seq2seq.LuongAttention(
        num_units=params_dict["lstm_units"] * 2,
        memory=encoder_final_output,
        memory_sequence_length=encoder_lengths)

    # Building the Decoder
    temp_cell = LSTMCell(params_dict["lstm_units"] * 2, forget_bias=1.0)
    temp_cell = DropoutWrapper(
        temp_cell,
        output_keep_prob=keep_prob,
    )
    decoder_cell = seq2seq.AttentionWrapper(
        cell=temp_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=params_dict["lstm_units"] * 2)

    # Start decoding from the wrapper's zero state with the encoder's final
    # state swapped in as the inner cell state.
    training_decoder = seq2seq.BasicDecoder(
        cell=decoder_cell,
        helper=training_helper,
        initial_state=decoder_cell.zero_state(
            params_dict["batch_size"],
            tf.float32).clone(cell_state=encoder_final_state),
        output_layer=projection)

    training_decoder_output, _, _ = seq2seq.dynamic_decode(
        decoder=training_decoder,
        impute_finished=True,
        maximum_iterations=tf.reduce_max(decoder_lengths))

    training_logits = training_decoder_output.rnn_output
    # Softmax over the vocabulary axis turns the logits into probabilities.
    prob_logits = tf.nn.softmax(training_logits, axis=-1)

    loss = seq2seq.sequence_loss(logits=training_logits,
                                 targets=decoder_labels,
                                 weights=question_mask,
                                 name="loss")

    return {
        "keep_prob": keep_prob,
        "document_tokens": document_tokens,
        "answer_masks": answer_masks,
        "encoder_lengths": encoder_lengths,
        "decoder_inputs": decoder_inputs,
        "decoder_labels": decoder_labels,
        "decoder_lengths": decoder_lengths,
        "training_logits": training_logits,
        "prob_logits": prob_logits,
        "loss": loss
    }
Пример #18
0
    def __init__(self,
                 vocab_size,
                 embedding_size,
                 lstm_size,
                 num_layer,
                 max_length_encoder,
                 max_length_decoder,
                 max_gradient_norm,
                 batch_size_num,
                 learning_rate,
                 beam_width,
                 embed=None):
        self.batch_size = batch_size_num
        self.max_length_encoder = max_length_encoder
        self.max_length_decoder = max_length_decoder
        with tf.variable_scope('g_model') as scope:
            self.encoder_input = tf.placeholder(tf.int32,
                                                [max_length_encoder, None])
            self.decoder_output = tf.placeholder(tf.int32,
                                                 [max_length_decoder, None])
            self.target_weight = tf.placeholder(
                tf.float32,
                [max_length_decoder, None])  # for pretraining or updating
            self.reward = tf.placeholder(
                tf.float32, [max_length_decoder, None])  # for updating
            self.start_tokens = tf.placeholder(tf.int32,
                                               [None])  # for partial-sampling
            self.max_inference_length = tf.placeholder(tf.int32,
                                                       [])  # for inference

            self.encoder_length = tf.placeholder(tf.int32, [None])
            self.decoder_length = tf.placeholder(tf.int32, [None])
            batch_size = tf.shape(self.encoder_length)[0]
            # batch_size = 1
            decoder_output = self.decoder_output
            # if decoder_output have 0 dimention ???
            self.decoder_input = tf.concat([
                tf.ones([1, batch_size], dtype=tf.int32) * GO_ID,
                decoder_output[:-1]
            ],
                                           axis=0)
            if embed is None:
                embedding = tf.get_variable('embedding',
                                            [vocab_size, embedding_size])
            else:
                embedding = tf.get_variable('embedding',
                                            [vocab_size, embedding_size],
                                            initializer=embed)
            encoder_embedded = tf.nn.embedding_lookup(embedding,
                                                      self.encoder_input)
            decoder_embedded = tf.nn.embedding_lookup(embedding,
                                                      self.decoder_input)

            self.cell_state = tf.placeholder(
                tf.float32,
                [2 * num_layer, None, lstm_size])  # for partial-sampling
            self.attention = tf.placeholder(tf.float32, [None, lstm_size])
            self.time = tf.placeholder(tf.int32)
            self.alignments = tf.placeholder(tf.float32,
                                             [None, max_length_encoder])

            def build_attention_state():
                cell_state = tuple([
                    tf.contrib.rnn.LSTMStateTuple(self.cell_state[i],
                                                  self.cell_state[i + 1])
                    for i in range(0, 2 * num_layer, 2)
                ])
                return tf.contrib.seq2seq.AttentionWrapperState(
                    cell_state, self.attention, self.time, self.alignments,
                    tuple([]))

            partial_decoder_state = build_attention_state()

            def single_cell():
                return tf.contrib.rnn.BasicLSTMCell(lstm_size)

            def multi_cell():
                return tf.contrib.rnn.MultiRNNCell(
                    [single_cell() for _ in range(num_layer)])

            with tf.variable_scope('encoder'):
                encoder_cell = multi_cell()
                encoder_output, encoder_state = tf.nn.dynamic_rnn(
                    encoder_cell,
                    encoder_embedded,
                    self.encoder_length,
                    time_major=True,
                    dtype=tf.float32)

            with tf.variable_scope('decoder') as decoder_scope:
                attention_state = tf.transpose(encoder_output, [1, 0, 2])
                attention_mechanism = seq2seq.LuongAttention(
                    lstm_size,
                    attention_state,
                    memory_sequence_length=self.encoder_length)
                # train or evaluate
                decoder_cell_raw = multi_cell()
                # attention wrapper
                decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                    decoder_cell_raw,
                    attention_mechanism,
                    attention_layer_size=lstm_size)
                decoder_init_state = decoder_cell.zero_state(
                    batch_size, tf.float32).clone(cell_state=encoder_state)

                helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedded,
                                                           self.decoder_length,
                                                           time_major=True)
                projection_layer = layers_core.Dense(vocab_size)  # use_bias ?
                decoder = tf.contrib.seq2seq.BasicDecoder(
                    decoder_cell,
                    helper,
                    decoder_init_state,
                    output_layer=projection_layer)

                output, decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder,
                    output_time_major=True,
                    swap_memory=True,
                    scope=decoder_scope)
                logits = output.rnn_output
                self.result_train = tf.transpose(output.sample_id)
                self.decoder_state = decoder_state
                # inference (sample)
                helper_sample = tf.contrib.seq2seq.SampleEmbeddingHelper(
                    embedding,
                    start_tokens=tf.fill([batch_size], GO_ID),
                    end_token=EOS_ID)
                decoder_sample = tf.contrib.seq2seq.BasicDecoder(
                    decoder_cell,
                    helper_sample,
                    decoder_init_state,
                    output_layer=projection_layer)
                output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder_sample,
                    swap_memory=True,
                    scope=decoder_scope,
                    maximum_iterations=self.max_inference_length)
                self.result_sample = output.sample_id

                # inference (partial-sample)
                helper_partial = tf.contrib.seq2seq.SampleEmbeddingHelper(
                    embedding,
                    start_tokens=self.start_tokens,
                    end_token=EOS_ID)
                decoder_partial = tf.contrib.seq2seq.BasicDecoder(
                    decoder_cell,
                    helper_partial,
                    partial_decoder_state,
                    output_layer=projection_layer)
                output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder_partial,
                    swap_memory=True,
                    scope=decoder_scope,
                    maximum_iterations=self.max_inference_length)
                self.result_partial = output.sample_id

                # inference (greedy)
                helper_greedy = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    embedding,
                    start_tokens=tf.fill([batch_size], GO_ID),
                    end_token=EOS_ID)
                decoder_greedy = tf.contrib.seq2seq.BasicDecoder(
                    decoder_cell,
                    helper_greedy,
                    decoder_init_state,
                    output_layer=projection_layer)
                output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder_greedy,
                    swap_memory=True,
                    scope=decoder_scope,
                    maximum_iterations=self.max_inference_length)
                self.result_greedy = output.sample_id

                # inference (beam search)
                # with tf.variable_scope('decoder', reuse=True) as decoder_scope:
                attention_state = tf.contrib.seq2seq.tile_batch(
                    attention_state, multiplier=beam_width)
                source_seq_length = tf.contrib.seq2seq.tile_batch(
                    self.encoder_length, multiplier=beam_width)
                encoder_state = tf.contrib.seq2seq.tile_batch(
                    encoder_state, multiplier=beam_width)
            with tf.variable_scope('decoder', reuse=True) as decoder_scope:
                attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                    lstm_size,
                    attention_state,
                    memory_sequence_length=source_seq_length)

                decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                    decoder_cell_raw,
                    attention_mechanism,
                    attention_layer_size=lstm_size)
                beam_search_init_state = decoder_cell.zero_state(
                    batch_size * beam_width,
                    tf.float32).clone(cell_state=encoder_state)
                decoder_beam_search = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=embedding,
                    start_tokens=tf.fill([batch_size], GO_ID),
                    end_token=EOS_ID,
                    initial_state=beam_search_init_state,
                    beam_width=beam_width,
                    output_layer=projection_layer,
                    length_penalty_weight=0.0)
                output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder_beam_search,
                    swap_memory=True,
                    scope=decoder_scope,
                    maximum_iterations=self.max_inference_length)
                self.result_beam_search = tf.transpose(output.predicted_ids,
                                                       [0, 2, 1])

            dim = tf.shape(logits)[0]
            decoder_output = tf.split(decoder_output,
                                      [dim, max_length_decoder - dim])[0]
            target_weight = tf.split(self.target_weight,
                                     [dim, max_length_decoder - dim])[0]
            reward = tf.split(self.reward, [dim, max_length_decoder - dim])[0]

            params = scope.trainable_variables()
            # update for pretraining
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=decoder_output, logits=logits)  # max_len * batch
            self.loss_pretrain = tf.reduce_sum(
                target_weight * cross_entropy) / tf.cast(
                    batch_size, tf.float32)
            self.perplexity = tf.exp(
                tf.reduce_sum(target_weight * cross_entropy) /
                tf.reduce_sum(target_weight))
            gradient_pretrain = tf.gradients(self.loss_pretrain, params)
            gradient_pretrain, _ = tf.clip_by_global_norm(
                gradient_pretrain, max_gradient_norm)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            self.opt_pretrain = optimizer.apply_gradients(
                zip(gradient_pretrain, params))

            # update for GAN
            one_hot = tf.one_hot(decoder_output, vocab_size)
            self.prob = tf.reduce_sum(one_hot * tf.nn.softmax(logits), axis=2)
            self.loss_generator = tf.reduce_sum(
                -tf.log(tf.maximum(self.prob, 1e-5)) * reward *
                target_weight) / tf.cast(batch_size, tf.float32)
            gradient_generator = tf.gradients(self.loss_generator, params)
            gradient_generator, _ = tf.clip_by_global_norm(
                gradient_generator, max_gradient_norm)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            self.opt_update = optimizer.apply_gradients(
                zip(gradient_generator, params))
Пример #19
0
def seq2seq_rnn(inputs,
                hidden_size,
                scope,
                use_xavier=True,
                stddev=1e-3,
                weight_decay=None,
                activation_fn=tf.nn.relu,
                bn=False,
                bn_decay=None,
                is_training=None):
    """Summarise each point's time series with an attention encoder-decoder.

    The T steps of every (batch, point) pair are encoded by an LSTM. A second
    LSTM, wrapped in Luong attention over the encoder outputs and started from
    a zero state, then runs for a single step whose input is the encoder's
    final hidden state. The hidden state of that LSTM after the step is the
    feature returned for the point.

    Args:
      inputs: 4-D tensor variable BxNxTxI. All four dimensions must be
        statically known: they are read with ``get_shape()`` and used as
        Python integers.
      hidden_size: int, number of units of both LSTM cells and of the
        attention layer.
      scope: string. Currently unused; variables are created under the fixed
        'encoder' and 'decoder' scopes.
      use_xavier: unused, kept for interface compatibility.
      stddev: unused, kept for interface compatibility.
      weight_decay: unused, kept for interface compatibility.
      activation_fn: function applied to the result, or None to skip it.
      bn: bool, whether to use batch norm
      bn_decay: float or float tensor variable in [0,1]
      is_training: bool Tensor variable
    Returns:
      Tensor of shape BxNxhidden_size.
    """
    static_shape = inputs.get_shape()
    batch_size = static_shape[0].value
    npoint = static_shape[1].value
    nstep = static_shape[2].value
    in_size = static_shape[3].value
    # Every (batch, point) pair becomes one independent sequence.
    num_sequences = batch_size * npoint
    reshaped_inputs = tf.reshape(inputs, (-1, nstep, in_size))

    with tf.variable_scope('encoder'):
        encoder_cell = tf.nn.rnn_cell.LSTMCell(hidden_size)
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
            encoder_cell,
            reshaped_inputs,
            # All sequences span the whole time axis, so the length is the
            # real step count rather than a fixed number.
            sequence_length=tf.fill([num_sequences], nstep),
            dtype=tf.float32,
            time_major=False)

    with tf.variable_scope('decoder'):
        decoder_cell = tf.nn.rnn_cell.LSTMCell(hidden_size)
        # The decoder consumes a single input step: the encoder's final
        # hidden state.
        decoder_inputs = tf.reshape(encoder_state.h,
                                    [num_sequences, 1, hidden_size])

        # 'Luong' style attention: https://arxiv.org/abs/1508.04025
        # ('Bahdanau' style, https://arxiv.org/abs/1409.0473, would be
        # seq2seq.BahdanauAttention with the same arguments.)
        attention_mechanism = seq2seq.LuongAttention(num_units=hidden_size,
                                                     memory=encoder_outputs)
        # AttentionWrapper wraps RNNCell with the attention_mechanism
        decoder_cell = seq2seq.AttentionWrapper(
            cell=decoder_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=hidden_size)

        # Helper that feeds the prepared dense vectors as decoder inputs,
        # exactly one step per sequence.
        train_helper = seq2seq.TrainingHelper(
            inputs=decoder_inputs,
            sequence_length=tf.fill([num_sequences], 1),
            time_major=False)
        decoder_initial_state = decoder_cell.zero_state(
            batch_size=num_sequences, dtype=tf.float32)
        train_decoder = seq2seq.BasicDecoder(
            cell=decoder_cell,
            helper=train_helper,
            initial_state=decoder_initial_state,
            output_layer=None)
        # Only the final state is needed; the per-step outputs and the
        # decoded lengths are discarded.
        _, decoder_last_state_train, _ = seq2seq.dynamic_decode(
            decoder=train_decoder,
            output_time_major=False,
            impute_finished=True)

    # Hidden state of the wrapped LSTM after the single decoding step.
    outputs = tf.reshape(decoder_last_state_train.cell_state.h,
                         (-1, npoint, hidden_size))
    if bn:
        outputs = batch_norm_for_fc(outputs, is_training, bn_decay, 'bn')

    if activation_fn is not None:
        outputs = activation_fn(outputs)
    return outputs
def build_inference_graph(params_dict):
    """Build the greedy-decoding inference graph of the attention seq2seq model.

    A stack of bidirectional LSTM layers encodes the answer-masked document
    embeddings; an LSTM decoder with Luong attention over the encoder outputs
    then decodes greedily for at most 16 steps.

    Relies on the module-level names ``glove`` (initial embedding matrix),
    ``START_TOKEN`` and ``END_TOKEN``.

    Args:
        params_dict: mapping that provides
            "num_encoder_layers": number of stacked bidirectional layers
                (at least 1),
            "lstm_units": units per encoder direction; the decoder and the
                attention layer use twice as many,
            "batch_size": fixed batch size the graph is built for.

    Returns:
        dict with the feed placeholders "keep_prob", "document_tokens",
        "answer_masks" and "encoder_lengths", plus "decoder_outputs" (the
        projected decoder outputs) and "prob_logits" (their softmax over the
        last axis).

    Raises:
        ValueError: if "num_encoder_layers" is smaller than 1.
    """
    num_encoder_layers = params_dict["num_encoder_layers"]
    lstm_units = params_dict["lstm_units"]
    batch_size = params_dict["batch_size"]
    if num_encoder_layers < 1:
        # Without at least one layer there is no encoder state to start the
        # decoder from.
        raise ValueError(
            "num_encoder_layers must be >= 1, got %r" % (num_encoder_layers,))
    # Forward and backward encoder halves are concatenated, so everything
    # downstream of the encoder is twice as wide.
    decoder_units = lstm_units * 2

    # Todo: Check if load the glove is faster than import it from embedding
    # glove = np.load('GloVe/glove.npy')

    # Building the Embedding layer + placeholders
    keep_prob = tf.placeholder(tf.float32)
    embedding = tf.get_variable("embedding", initializer=glove, trainable=True)
    document_tokens = tf.placeholder(tf.int32, shape=[None, None], name="document_tokens")
    document_emb = tf.nn.embedding_lookup(embedding, document_tokens)
    answer_masks = tf.placeholder(tf.float32, shape=[None, None, None], name="answer_masks")
    encoder_lengths = tf.placeholder(tf.int32, shape=[None], name="encoder_lengths")
    projection = Dense(embedding.shape[0], use_bias=False)

    # The start tokens must have the same batch size as the decoder's initial
    # state, so both are taken from params_dict.
    helper = seq2seq.GreedyEmbeddingHelper(
        embedding, tf.fill([batch_size], START_TOKEN), END_TOKEN)

    # Building the Encoder
    encoder_inputs = tf.matmul(answer_masks, document_emb, name="encoder_inputs")

    output = encoder_inputs
    for n in range(num_encoder_layers):
        cell_fw = LSTMCell(lstm_units, forget_bias=1.0, state_is_tuple=True)
        cell_bw = LSTMCell(lstm_units, forget_bias=1.0, state_is_tuple=True)
        cell_fw = DropoutWrapper(cell_fw, output_keep_prob=keep_prob)
        cell_bw = DropoutWrapper(cell_bw, output_keep_prob=keep_prob)

        state_fw = cell_fw.zero_state(batch_size, tf.float32)
        state_bw = cell_bw.zero_state(batch_size, tf.float32)

        (output_fw, output_bw), encoder_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw, cell_bw, output,
            initial_state_fw=state_fw,
            initial_state_bw=state_bw,
            sequence_length=encoder_lengths,
            dtype=tf.float32,
            scope='encoder_rnn_' + str(n))
        # Each layer consumes the concatenated outputs of the previous one.
        output = tf.concat([output_fw, output_bw], axis=2)

    encoder_final_output = output
    # Merge the last layer's forward/backward states into one state that
    # matches the decoder width.
    final_state_fw, final_state_bw = encoder_state
    encoder_state_c = tf.concat((final_state_fw.c, final_state_bw.c), -1)
    encoder_state_h = tf.concat((final_state_fw.h, final_state_bw.h), -1)
    encoder_final_state = LSTMStateTuple(encoder_state_c, encoder_state_h)

    # Attention mechanism
    attention_mechanism = seq2seq.LuongAttention(
        num_units=decoder_units,
        memory=encoder_final_output,
        memory_sequence_length=encoder_lengths)

    # Building the Decoder
    temp_cell = LSTMCell(decoder_units, forget_bias=1.0)
    temp_cell = DropoutWrapper(temp_cell, output_keep_prob=keep_prob)
    decoder_cell = seq2seq.AttentionWrapper(
        cell=temp_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=decoder_units)

    decoder_initial_state = decoder_cell.zero_state(
        batch_size, tf.float32).clone(cell_state=encoder_final_state)
    decoder = seq2seq.BasicDecoder(
        cell=decoder_cell,
        helper=helper,
        initial_state=decoder_initial_state,
        output_layer=projection)

    decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder, maximum_iterations=16)
    decoder_outputs = decoder_outputs.rnn_output

    # Softmax turns the projected outputs into values in [0, 1] that sum to
    # one along the last axis.
    prob_logits = tf.nn.softmax(decoder_outputs, axis=-1)

    return {
        "keep_prob": keep_prob,
        "document_tokens": document_tokens,
        "answer_masks": answer_masks,
        "encoder_lengths": encoder_lengths,
        "decoder_outputs": decoder_outputs,
        "prob_logits": prob_logits
    }
Пример #21
0
    def _build_main_graph(self, xs, xlens, ys, ylens):
        """Build the word model: bidirectional encoder plus attention decoder.

        Args:
            xs: source token ids, looked up in the embedding table.
            xlens: source sequence lengths, used by the encoder and as the
                attention memory lengths.
            ys: target token ids; only read when the mode is not 'infer'.
            ylens: target sequence lengths; only read when the mode is not
                'infer'.

        Returns:
            A ``(logits, ids, lengths)`` tuple: the projected decoder outputs,
            the sampled ids and the final sequence lengths reported by
            ``seq2seq.dynamic_decode``.
        """
        reuse = self._reuse_vars
        hidden_size = self._word_model_rnn_hidden_size
        num_layers = self._word_model_rnn_layers

        with tf.variable_scope('word_model', reuse=reuse):
            embeds = self._variable(
                'embeddings',
                dtype=tf.float32,
                shape=[self._word_symbols, self._word_embedding_size])

            with tf.variable_scope('encoder', reuse=reuse):
                # Each direction gets half of the configured layers.
                fw_cells = self._rnn_cells(hidden_size, num_layers // 2)
                bw_cells = self._rnn_cells(hidden_size, num_layers // 2)

                source_embeds = tf.nn.embedding_lookup(embeds, xs)
                rnn_out, rnn_state = tf.nn.bidirectional_dynamic_rnn(
                    fw_cells, bw_cells, source_embeds, xlens,
                    dtype=tf.float32)

            with tf.variable_scope('decoder', reuse=reuse):
                # Attention only consumes encoder outputs; the decoder itself
                # starts from a zero state.
                memory = tf.concat(rnn_out, -1)
                attention = seq2seq.LuongAttention(
                    self._decoder_attention_size, memory, xlens)
                cells = seq2seq.AttentionWrapper(
                    self._rnn_cells(hidden_size, num_layers), attention)
                decode_init_state = cells.zero_state(self._batch_size,
                                                     tf.float32)

                def dropout_when_training(v):
                    return (tf.nn.dropout(v, KEEP)
                            if self._mode == 'train' else v)

                # This layer sits just before softmax. The original author
                # observed that the network does not converge well when an
                # activation is placed here.
                # NOTE(review): the dropout function is installed as the
                # kernel *regularizer*, not applied to activations - confirm
                # this is intended.
                final_projection = tf.layers.Dense(
                    self._word_symbols,
                    kernel_regularizer=dropout_when_training,
                    use_bias=False)

                if self._mode != 'infer':
                    # Feed the embedded targets as decoder inputs; decoding
                    # length is driven by ylens, so no iteration cap.
                    helper = seq2seq.TrainingHelper(
                        tf.nn.embedding_lookup(embeds, ys), ylens)
                    max_iters = None
                else:
                    # Greedy decoding from the start token, capped at twice
                    # the longest source sequence in the batch.
                    start_tokens = tf.tile([self._start_token],
                                           [self._batch_size])
                    helper = seq2seq.GreedyEmbeddingHelper(
                        embeds, start_tokens, self._end_token)
                    max_iters = tf.reduce_max(xlens) * 2

                decoder = seq2seq.BasicDecoder(cells, helper,
                                               decode_init_state,
                                               final_projection)
                decoded, _, lengths = seq2seq.dynamic_decode(
                    decoder, maximum_iterations=max_iters)
                return decoded.rnn_output, decoded.sample_id, lengths
Пример #22
0
    def inference(self):
        """Build the attention seq2seq graph and return its outputs.

        A single-layer bidirectional LSTM encodes the embedded input. Its
        forward and backward final states initialise the two layers of the
        decoder, which attends (Luong attention) over the concatenated
        encoder outputs. When ``self.is_train`` is true the decoder is fed
        the embedded decoder inputs; otherwise it samples its own next token
        for at most 10 steps.

        Returns:
            A ``(logits, pred)`` tuple: the projected decoder outputs and the
            sampled token ids.
        """
        num_units = 256

        def dropout_lstm():
            # One LSTM layer with dropout on its outputs.
            return rnn.DropoutWrapper(rnn.BasicLSTMCell(num_units),
                                      output_keep_prob=self.keep_prob)

        with tf.variable_scope("embedding"):
            embedding = tf.get_variable(
                "embedding",
                shape=[self.vocab_size, self.embedding_size],
                initializer=tf.truncated_normal_initializer(stddev=0.1,
                                                            dtype=tf.float32))
            encoder_input_data_embedding = tf.nn.embedding_lookup(
                embedding, self.encoder_input_data)
            decoder_input_data_embedding = tf.nn.embedding_lookup(
                embedding, self.decoder_input_data)

        with tf.variable_scope("encoder"):
            fw_layer, bw_layer = dropout_lstm(), dropout_lstm()
            encoder_cell_fw = rnn.MultiRNNCell([fw_layer])
            encoder_cell_bw = rnn.MultiRNNCell([bw_layer])
        # NOTE: the RNN itself is unrolled outside the "encoder" scope block;
        # this placement is kept as in the original.
        bi_encoder_outputs, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn(
            encoder_cell_fw,
            encoder_cell_bw,
            encoder_input_data_embedding,
            sequence_length=self.input_seq_len,
            dtype=tf.float32)
        encoder_outputs = tf.concat(bi_encoder_outputs, -1)
        # Layer 0 of the forward and of the backward cell, in that order; the
        # pair lines up with the two decoder layers below.
        fw_states, bw_states = bi_encoder_state[0], bi_encoder_state[1]
        encoder_state = (fw_states[0], bw_states[0])

        with tf.variable_scope("decoder"):
            decoder_layers = [dropout_lstm(), dropout_lstm()]
            attention_mechanism = seq2seq.LuongAttention(
                num_units, encoder_outputs, self.input_seq_len)
            decoder_cell = seq2seq.AttentionWrapper(
                rnn.MultiRNNCell(decoder_layers), attention_mechanism,
                num_units)
            zero_state = decoder_cell.zero_state(self.batch_size,
                                                 dtype=tf.float32)
            decoder_initial_state = zero_state.clone(cell_state=encoder_state)

            output_projection = Dense(self.vocab_size,
                                      name="output_projection")
            if self.is_train:
                helper = seq2seq.TrainingHelper(decoder_input_data_embedding,
                                                self.output_seq_len)
                max_steps = None
            else:
                # Sampling-based decoding. seq2seq.GreedyEmbeddingHelper
                # takes the same arguments if argmax decoding is wanted
                # instead.
                helper = seq2seq.SampleEmbeddingHelper(
                    embedding,
                    start_tokens=[input_data.GO_ID] * self.batch_size,
                    end_token=input_data.EOS_ID)
                max_steps = 10

            decoder = seq2seq.BasicDecoder(decoder_cell,
                                           helper,
                                           decoder_initial_state,
                                           output_layer=output_projection)
            decoder_outputs, _, _ = seq2seq.dynamic_decode(
                decoder, maximum_iterations=max_steps)
            return decoder_outputs.rnn_output, decoder_outputs.sample_id
Пример #23
0
def create_model(embeddings, hypothesis_max_length, sentence_max_length, seed):
    """Build a TF1 graph that classifies a (topic, sentence) pair.

    Both inputs are embedded, encoded by separate bidirectional LSTMs and
    combined through `_inter_atten`.  A dense layer on top yields the class
    logits, which are trained with softmax cross-entropy and Adam.

    Args:
        embeddings: Initial value of the embedding matrices.  The topic and
            the sentence side each get their own variable initialised from it.
        hypothesis_max_length: Width of the topic id placeholder `X_t`.
        sentence_max_length: Width of the sentence id placeholder `X_s`.
        seed: Graph-level random seed passed to `tf.set_random_seed`.

    Returns:
        Tuple `(X_t, X_s, L_t, L_s, Y, prediction, train_op)`: the four input
        placeholders, the target placeholder, the softmax output tensor
        (named "output") and the training op (named "train").
    """
    tf.set_random_seed(seed)

    # Based on https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/bidirectional_rnn.py
    num_classes = 2
    dims = 100
    learning_rate = 0.001
    X_t = tf.placeholder(tf.int32, [None, hypothesis_max_length], name="topic_input")
    X_s = tf.placeholder(tf.int32, [None, sentence_max_length], name="sentence_input")
    L_t = tf.placeholder(tf.int32, [None, ], name="topic_length")
    L_s = tf.placeholder(tf.int32, [None, ], name="sentence_length")
    Y = tf.placeholder(tf.float32, [None, num_classes], name="target")

    def BiRNN(x, layer):
        """Run a bidirectional LSTM over `x` inside scope `encoder_<layer>`.

        Returns the forward and backward outputs of every timestep,
        concatenated along the last axis.
        """
        with tf.variable_scope('encoder_{}'.format(layer), reuse=False):
            lstm_fw_cell = rnn.BasicLSTMCell(dims, forget_bias=1.0)
            lstm_bw_cell = rnn.BasicLSTMCell(dims, forget_bias=1.0)

            # The final states are not needed, only the per-step outputs.
            (fw_outputs, bw_outputs), _ = tf.nn.bidirectional_dynamic_rnn(
                lstm_fw_cell,
                lstm_bw_cell,
                x,
                dtype=tf.float32)
            return tf.concat([fw_outputs, bw_outputs], axis=2)

    def BiRNNAtt(x, attention, layer):
        """Variant of `BiRNN` whose cells are wrapped in an AttentionWrapper.

        Currently not called by `create_model`; kept for the disabled
        attention-wrapped sentence encoder below.
        """
        with tf.variable_scope('encoder_sentence_{}'.format(layer), reuse=False):
            lstm_fw_cell = rnn.BasicLSTMCell(dims, forget_bias=1.0)
            lstm_fw_att = seq2seq.AttentionWrapper(lstm_fw_cell, attention)
            lstm_bw_cell = rnn.BasicLSTMCell(dims, forget_bias=1.0)
            lstm_bw_att = seq2seq.AttentionWrapper(lstm_bw_cell, attention)

            (fw_outputs, bw_outputs), _ = tf.nn.bidirectional_dynamic_rnn(
                lstm_fw_att,
                lstm_bw_att,
                x,
                dtype=tf.float32)
            return tf.concat([fw_outputs, bw_outputs], axis=2)

    # Topic and sentence use independent embedding variables that start from
    # the same initial matrix.
    topic_word_embeddings = tf.Variable(embeddings, dtype=tf.float32, name="topic_embeddings")
    topic_embedded_word_id = tf.nn.embedding_lookup(topic_word_embeddings, X_t)

    sentence_word_embeddings = tf.Variable(embeddings, dtype=tf.float32, name="sentence_embeddings")
    sentence_embedded_word_id = tf.nn.embedding_lookup(sentence_word_embeddings, X_s)

    topic_bilstm_out = BiRNN(topic_embedded_word_id, "topic")

    # Only consumed by the disabled BiRNNAtt path; still constructed so the
    # graph stays as it was.
    # NOTE(review): confirm whether existing checkpoints depend on it.
    attention_mechanism = seq2seq.LuongAttention(dims, topic_bilstm_out, L_t)
    # sentence_bilstm_out = BiRNNAtt(sentence_embedded_word_id, attention_mechanism, "sentence")
    sentence_bilstm_out = BiRNN(sentence_embedded_word_id, "sentence")

    sentence_attention, topic_attention = _inter_atten(topic_bilstm_out, sentence_bilstm_out, L_t, L_s)

    # Weigh the encoder outputs by attention.  The encoder output enters the
    # product twice, exactly as in the original formulation.
    # NOTE(review): confirm the squared term is intended.
    topic_att_weighted = tf.multiply(topic_bilstm_out, tf.multiply(topic_bilstm_out, topic_attention))
    sentence_att_weighted = tf.multiply(sentence_bilstm_out, tf.multiply(sentence_bilstm_out, sentence_attention))

    # Attention differences: built but not part of the final feature vector.
    topic_att_diff = tf.subtract(topic_bilstm_out, topic_attention)
    sentence_att_diff = tf.subtract(sentence_bilstm_out, sentence_attention)

    # Sum the attention-weighted outputs over the time axis and append them to
    # the last-timestep outputs of both encoders.
    attention_output = tf.reduce_sum(tf.concat((topic_att_weighted, sentence_att_weighted), axis=1), axis=1)
    output = tf.concat((topic_bilstm_out[:, -1], sentence_bilstm_out[:, -1], attention_output), axis=1)

    # The number of logits must match the width of the target placeholder.
    logits = tf.layers.dense(output, num_classes)
    prediction = tf.nn.softmax(logits, name="output")

    # Define loss and optimizer
    loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.minimize(loss_op, name="train")

    # The initializer op is added to the graph but not returned; callers have
    # to obtain their own initializer.
    init = tf.global_variables_initializer()
    return X_t, X_s, L_t, L_s, Y, prediction, train_op
Пример #24
0
    def _create_attention_mechanism(self,
                                    attention_type,
                                    num_units,
                                    memory,
                                    memory_sequence_length):
        r"""
        Instantiates a seq2seq attention mechanism, also setting the _output_attention flag accordingly.

        Warning: if different types of mechanisms are used within the same decoder, this function needs
        to be refactored to return the right output_attention flag for each `AttentionWrapper` object.
        Args:
            attention_type: `String`, one of `bahdanau`, `normed_bahdanau`, `normed_monotonic_bahdanau`,
              `luong`, `scaled_luong` or `scaled_monotonic_luong`.
            num_units: `int`, depth of the query mechanism. See downstream documentation.
            memory: A 3D Tensor [batch_size, Ts, num_features], the attended memory
            memory_sequence_length: A 1D Tensor [batch_size] holding the true sequence lengths
        Returns:
            The constructed attention mechanism. `self._output_attention` is set to `False` for the
            Bahdanau family and to `True` for the Luong family.
        Raises:
            ValueError: if `attention_type` is not one of the supported names.
        """
        bahdanau_types = ('bahdanau', 'normed_bahdanau', 'normed_monotonic_bahdanau')
        luong_types = ('luong', 'scaled_luong', 'scaled_monotonic_luong')
        monotonic_types = ('normed_monotonic_bahdanau', 'scaled_monotonic_luong')

        # Validate before touching any instance state so that a bad name fails
        # with a message that says what was passed and what is accepted.
        if attention_type not in bahdanau_types + luong_types:
            raise ValueError(
                'unknown attention mechanism: {!r} (expected one of: {})'.format(
                    attention_type, ', '.join(bahdanau_types + luong_types)))

        is_bahdanau = attention_type in bahdanau_types
        is_monotonic = attention_type in monotonic_types

        # Arguments every mechanism receives.
        kwargs = dict(
            num_units=num_units,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            dtype=self._hparams.dtype,
        )

        if is_monotonic:
            # Monotonic attention: noisy 'parallel' mode while training,
            # noise-free 'hard' mode otherwise.
            training = self._mode == 'train'
            kwargs.update(
                score_bias_init=-2.0,
                sigmoid_noise=1.0 if training else 0.0,
                mode='parallel' if training else 'hard',
            )

        if is_bahdanau:
            kwargs['normalize'] = attention_type != 'bahdanau'
            if is_monotonic:
                mechanism_cls = seq2seq.BahdanauMonotonicAttention
            else:
                mechanism_cls = seq2seq.BahdanauAttention
        else:
            # Plain 'luong' keeps the library default for `scale`.
            if attention_type != 'luong':
                kwargs['scale'] = True
            if is_monotonic:
                mechanism_cls = seq2seq.LuongMonotonicAttention
            else:
                mechanism_cls = seq2seq.LuongAttention

        attention_mechanism = mechanism_cls(**kwargs)
        # The flag is only updated once construction succeeded.
        self._output_attention = not is_bahdanau

        return attention_mechanism
Пример #25
0
    def __init__(self, config, batch_size, embedding, encoder_input, input_len, is_training=True, ru=False):
        """Build the bidirectional RNN encoder and the latent variable graph.

        Args:
            config: Settings object; the attributes read here are `RNN_CELL`,
                `ENC_RNN_SIZE`, `DROPOUT_KEEP`, `ENC_FUNC` and
                `LATENT_VARIABLE_SIZE`.
            batch_size: First dimension of the sampled noise `epsilon`.
            embedding: Embedding parameters given to `tf.nn.embedding_lookup`.
            encoder_input: Ids looked up in `embedding`.
            input_len: Handed to the attention mechanisms (only used by the
                `attn*` encoder functions).
            is_training: Enables output dropout on the RNN cells and the
                noise term of the latent variable.
            ru: Not used by this constructor.

        Raises:
            ValueError: if `config.RNN_CELL` or `config.ENC_FUNC` names an
                unsupported option.
        """
        self.config = config
        with tf.variable_scope("encoder_input"):
            self.embedding = embedding
            self.encoder_input = encoder_input
            self.input_len = input_len
            self.batch_size = batch_size

            self.is_training = is_training

        with tf.variable_scope("encoder_rnn"):
            encoder_emb_inputs = tf.nn.embedding_lookup(self.embedding, self.encoder_input)

            def create_cell():
                """Return one RNN cell of the configured type (with dropout when training)."""
                if self.config.RNN_CELL == 'lnlstm':
                    cell = rnn.LayerNormBasicLSTMCell(self.config.ENC_RNN_SIZE)
                elif self.config.RNN_CELL == 'lstm':
                    cell = rnn.BasicLSTMCell(self.config.ENC_RNN_SIZE)
                elif self.config.RNN_CELL == 'gru':
                    cell = rnn.GRUCell(self.config.ENC_RNN_SIZE)
                else:
                    # Fail here instead of hitting an unbound `cell` below.
                    message = 'rnn_cell {} not supported'.format(self.config.RNN_CELL)
                    logger.error(message)
                    raise ValueError(message)
                if self.is_training:
                    cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=self.config.DROPOUT_KEEP)
                return cell

            cell_fw = create_cell()
            cell_bw = create_cell()

            output = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, encoder_emb_inputs, dtype=tf.float32)
            encoder_outputs, encoder_state = output

            def get_last_hidden():
                """Concatenate the final forward and backward hidden states."""
                if self.config.RNN_CELL == 'gru':
                    return tf.concat([encoder_state[0], encoder_state[1]], -1)
                else:
                    # For the LSTM cells the per-direction state is a pair;
                    # element 1 is taken as the hidden state.
                    return tf.concat([encoder_state[0][1], encoder_state[1][1]], -1)

            # Reduce the encoder to one vector per example.
            if self.config.ENC_FUNC == 'mean':
                encoder_rnn_output = tf.reduce_mean(tf.concat(encoder_outputs, -1), 1)
            elif self.config.ENC_FUNC == 'lasth':
                encoder_rnn_output = get_last_hidden()
            elif self.config.ENC_FUNC in ['attn', 'attn_scale']:
                attn = seq2seq.LuongAttention(self.config.ENC_RNN_SIZE * 2,
                                              tf.concat(encoder_outputs, -1),
                                              self.input_len, scale=self.config.ENC_FUNC == 'attn_scale')
                encoder_rnn_output, _ = _compute_attention(attn, get_last_hidden(), None, None)
            elif self.config.ENC_FUNC in ['attn_ba', 'attn_ba_norm']:
                attn = seq2seq.BahdanauAttention(self.config.ENC_RNN_SIZE,
                                                 tf.concat(encoder_outputs, -1),
                                                 self.input_len, normalize=self.config.ENC_FUNC == 'attn_ba_norm')
                encoder_rnn_output, _ = _compute_attention(attn, get_last_hidden(), None, None)
            else:
                # Fail here instead of hitting an unbound `encoder_rnn_output` below.
                message = 'enc_func {} not supported'.format(self.config.ENC_FUNC)
                logger.error(message)
                raise ValueError(message)

        with tf.name_scope("mu"):
            mu = tf.layers.dense(encoder_rnn_output, self.config.LATENT_VARIABLE_SIZE, activation=tf.nn.tanh)
            self.mu = tf.layers.dense(mu, self.config.LATENT_VARIABLE_SIZE, activation=None)
        with tf.name_scope("log_var"):
            logvar = tf.layers.dense(encoder_rnn_output, self.config.LATENT_VARIABLE_SIZE, activation=tf.nn.tanh)
            self.logvar = tf.layers.dense(logvar, self.config.LATENT_VARIABLE_SIZE, activation=None)

        with tf.name_scope("epsilon"):
            epsilon = tf.random_normal((self.batch_size, self.config.LATENT_VARIABLE_SIZE), mean=0.0, stddev=1.0)

        with tf.name_scope("latent_variables"):
            if self.is_training:
                # z = mu + sigma * epsilon, with sigma = exp(0.5 * logvar).
                self.latent_variables = self.mu + (tf.exp(0.5 * self.logvar) * epsilon)
            else:
                # Outside training the noise term is multiplied by zero.
                self.latent_variables = self.mu + (tf.exp(0.5 * self.logvar) * 0)
Пример #26
0
    def build_attention_graph_tensor(
        self,
        target,
        batch_size,
        trainable: bool = True,
    ):
        """Build the furniture encoder / attention decoder graph.

        The encoder is a multi-layer LSTM over the concatenated furniture and
        room features; the decoder is a multi-layer LSTM wrapped with scaled
        Luong attention over the encoder outputs.  When `trainable` is true
        the training decoder, loss, accuracy and train op are added and
        stored on `self` (`loss`, `pred`, `logits`, `attention_acc`,
        `attention_train_op`).  When it is false only the encoder, the
        attention cell and the output layer are built.

        Args:
            target: Target tensor; it is reshaped to `[-1, self.max_length]`
                and used as decoder targets, as loss mask and (summed per
                row) as sequence length.
            batch_size: Not used; the batch size is derived from `target`.
            trainable: Whether to add the training part of the graph.
        """

        # ================================= 1. Define the model inputs
        with tf.variable_scope("get_data", reuse=tf.AUTO_REUSE):
            self.room_context = self.middle_state_cnn_feature
            self.furniture_fea = self.furniture_cnn_feature

            self.y = target
            self.target = target
            self.decoder_targets = tf.reshape(self.target,
                                              [-1, self.max_length])

            self.cnn_out = self.cnn_output_distribute

            # Sequence length; in a layout, input and output have the same
            # length.  The reshaped target doubles as the mask and its row
            # sum as the length.
            # NOTE(review): this presumes `target` is a 0/1 mask - confirm.
            self.seq_length = tf.reshape(self.target, [-1, self.max_length])
            self.mask = self.seq_length
            self.seq_length = tf.reduce_sum(self.seq_length, axis=-1)
            self.encoder_inputs_length = self.seq_length

            # batch_size (static first dimension of the length vector)
            self.batch_size = self.seq_length.shape[0]
            print("---debug: self.batch_size:", self.batch_size)
            self.batch_size = tf.Print(self.batch_size, [self.batch_size],
                                       message="--debug: self.batch_size")

        # Encoder
        with tf.variable_scope("encoder"):

            embedding = tf.get_variable('embedding',
                                        [self.label_size, self.ebd_size])

            # Furniture features
            with tf.variable_scope("furniture_fea", reuse=tf.AUTO_REUSE):

                with tf.variable_scope("furniture_concat_process",
                                       reuse=tf.AUTO_REUSE):
                    if self.use_cnn_predict_encoder:
                        self.enc = tf.concat([
                            self.furniture_fea, self.cnn_out, self.room_context
                        ],
                                             axis=-1)
                    else:
                        self.enc = tf.concat(
                            [self.furniture_fea, self.room_context], axis=-1)

                    # Reshape to [batch, steps, features]; the single-item
                    # mode uses one step, otherwise `max_length` steps.
                    steps = 1 if self.use_single else self.max_length
                    feature_dim = self.enc.shape.as_list()[-1]
                    self.enc = tf.reshape(self.enc,
                                          shape=[-1, steps, feature_dim])

            with tf.variable_scope("furniture_encoder", reuse=tf.AUTO_REUSE):

                def single_rnn_cell():
                    # Create one cell.  Every layer must come from its own
                    # call to this factory: putting the same cell object into
                    # MultiRNNCell several times makes the model fail.
                    single_cell = tf.contrib.rnn.LSTMCell(hp.hidden_units)
                    # Add dropout
                    cell = tf.contrib.rnn.DropoutWrapper(single_cell,
                                                         output_keep_prob=1 -
                                                         self.dropout_rate)
                    return cell

                encoder_cell = rnn.MultiRNNCell(
                    [single_rnn_cell()
                     for _ in range(self.num_blocks * self.num_blocks)])

                encoder_output, encoder_state = tf.nn.dynamic_rnn(
                    encoder_cell,
                    inputs=self.enc,
                    dtype=self.enc.dtype,
                    sequence_length=self.encoder_inputs_length)

        # ================================= 2. Define the decoder part of the model
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            encoder_inputs_length = self.encoder_inputs_length
            # BahdanauAttention and LuongAttention differ mainly in the
            # alignment function: for the score at position i the former uses
            # s_{i-1} and h_j, the latter s_i and h_j.  Luong is used here; a
            # normalized BahdanauAttention would be the drop-in alternative.
            #
            # The memory length is passed once, by keyword.  Supplying it both
            # positionally and by keyword raises a TypeError.
            attention_mechanism = seq2seq.LuongAttention(
                self.hidden_units,
                encoder_output,
                memory_sequence_length=encoder_inputs_length,
                scale=True)

            # The `batch_size` argument is replaced by the value derived from
            # the target tensor above.
            batch_size = self.batch_size
            decoder_cell = rnn.MultiRNNCell(
                [single_rnn_cell()
                 for _ in range(self.num_blocks * self.num_blocks)])
            decoder_cell = seq2seq.AttentionWrapper(
                decoder_cell,
                attention_mechanism,
                attention_layer_size=self.hidden_units,
                name="Attention_Wrapper")

            # Initial decoder state: copied straight from the final state of
            # the encoder.
            decoder_initial_state = decoder_cell.zero_state(
                batch_size, tf.float32).clone(cell_state=encoder_state)

            output_layer = tf.layers.Dense(
                self.label_size,
                kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                                   stddev=0.1))

            if trainable:
                self.y = tf.reshape(self.y, [-1, self.max_length])
                # Decoder inputs are the targets shifted right by one step,
                # with id 0 in the first position.
                self.decoder_inputs = tf.concat(
                    (tf.ones_like(self.y[:, :1]) * 0, self.y[:, :-1]), -1)
                decoder_inputs_embedded = tf.nn.embedding_lookup(
                    embedding, self.decoder_inputs)

                training_helper = seq2seq.TrainingHelper(
                    inputs=decoder_inputs_embedded,
                    sequence_length=self.encoder_inputs_length,
                    time_major=False)

                training_decoder = seq2seq.BasicDecoder(
                    decoder_cell,
                    training_helper,
                    decoder_initial_state,
                    output_layer=output_layer)

                decoder_outputs, _, _ = seq2seq.dynamic_decode(
                    decoder=training_decoder,
                    impute_finished=True,
                    maximum_iterations=tf.convert_to_tensor(self.max_length,
                                                            dtype=tf.int32))

                # Compute loss and gradients from the decoder output and
                # define the AdamOptimizer / train_op that applies them.
                self.decoder_logits_train = tf.identity(
                    decoder_outputs.rnn_output)
                self.logits = self.decoder_logits_train
                self.logits = tf.reshape(self.logits,
                                         shape=[-1, self.label_size])
                self.decoder_predict_train = tf.argmax(
                    self.decoder_logits_train,
                    axis=-1,
                    name='decoder_pred_train')
                self.pred = tf.reshape(self.decoder_predict_train, shape=[-1])
                # sequence_loss needs the mask defined above as weights.
                # NOTE(review): the mask shares the dtype of `target`; confirm
                # it is a float type as sequence_loss weights usually are.
                self.loss = tf.contrib.seq2seq.sequence_loss(
                    logits=self.decoder_logits_train,
                    targets=self.decoder_targets,
                    weights=self.mask)

                # Masked accuracy: correct positions weighted by the mask and
                # normalised by the mask total.
                correct = tf.cast(
                    tf.equal(self.decoder_targets, self.decoder_predict_train),
                    "float") * self.mask / (tf.reduce_sum(self.mask))
                accuracy = tf.reduce_sum(correct, name="cnn_accuracy")

                optimizer = tf.train.AdamOptimizer(self.learing_rate)
                trainable_params = tf.trainable_variables()
                gradients = tf.gradients(self.loss, trainable_params)
                clip_gradients, _ = tf.clip_by_global_norm(
                    gradients, self.max_gradient_norm)
                train_op = optimizer.apply_gradients(
                    zip(clip_gradients, trainable_params))
                self.attention_train_op = train_op
                self.attention_acc = accuracy
Пример #27
0
        def create_decoder(mode, index):
            """Build the `self.k` attention decoders for one output stream.

            `mode == "action"` decodes the action tokens; any other value
            decodes channel `index` of the perception tokens (3 classes).
            For each of the `self.k` demonstrations the decoder is unrolled
            twice over shared variables: once with `train_unroll_type` and
            once greedily.

            Returns:
                Tuple `(pred_list, greedy_pred_list, greedy_pred_len_list)`,
                each a list with one entry per demonstration.
            """
            # Mode-specific tokens, vocabulary size and variable scope names.
            if mode == "action":
                gt_tokens = self.gt_actions_tokens
                token_dim = self.action_space
                scope_attention = 'AttnMechanismAction'
                scope_decoder = 'DecoderAction'
            else:
                gt_tokens = self.gt_per
                token_dim = 3
                scope_attention = 'AttnMechanismPerception_'+str(index)
                scope_decoder = 'DecoderPerception_'+str(index)

            # One LSTM cell object is shared by all attention wrappers.
            lstm_cell = rnn.BasicLSTMCell(num_units=self.num_lstm_cell_units)

            # One attention mechanism per demonstration; variables are created
            # on the first pass and reused afterwards.
            attn_mechanisms = []
            for demo_idx in range(self.k):
                with tf.variable_scope(scope_attention, reuse=demo_idx > 0):
                    if self.attn_type == 'luong':
                        attention_cls = seq2seq.LuongAttention
                    elif self.attn_type == 'luong_monotonic':
                        attention_cls = seq2seq.LuongMonotonicAttention
                    else:
                        raise ValueError('Unknown attention type')
                    attn_mechanisms.append(attention_cls(
                        self.num_lstm_cell_units,
                        demo_feature_history_list[demo_idx],
                        memory_sequence_length=self.demo_len[:, demo_idx]))

            attn_cells = [
                seq2seq.AttentionWrapper(
                    lstm_cell, mechanism,
                    attention_layer_size=self.num_lstm_cell_units,
                    alignment_history=True,
                    output_attention=True)
                for mechanism in attn_mechanisms
            ]

            def initial_state(demo_idx):
                # Zero attention state whose cell state is the demo summary.
                zero_state = attn_cells[demo_idx].zero_state(
                    self.batch_size, dtype=tf.float32)
                return zero_state.clone(
                    cell_state=rnn.LSTMStateTuple(demo_h_list[demo_idx],
                                                  demo_c_list[demo_idx]))

            def unroll(demo_idx, tokens, embedding_dim, unroll_type,
                       init_state, reuse):
                # demo summaries: [bs,v] or [bs,k*v]
                return LSTM_Decoder(
                    demo_h_list[demo_idx], demo_c_list[demo_idx], tokens,
                    attn_cells[demo_idx], unroll_type=unroll_type,
                    seq_lengths=self.action_len[:, demo_idx],
                    max_sequence_len=self.max_action_len,
                    token_dim=token_dim,
                    embedding_dim=embedding_dim,
                    init_state=init_state,
                    sequence_type=mode,
                    scope=scope_decoder, reuse=reuse)

            pred_list = []
            greedy_pred_list = []
            greedy_pred_len_list = []

            for demo_idx in range(self.k):
                train_init_state = initial_state(demo_idx)
                embedding_dim = demo_h_list[demo_idx].get_shape().as_list()[-1]

                # Perception tokens carry several channels; pick `index`.
                tokens = gt_tokens[demo_idx]
                if mode != 'action':
                    tokens = tokens[:, :, index]

                # Training unroll: creates the decoder variables on the first
                # demonstration and reuses them afterwards.
                pred, _, _ = unroll(demo_idx, tokens, embedding_dim,
                                    train_unroll_type, train_init_state,
                                    demo_idx > 0)
                pred_list.append(pred)

                # Greedy unroll always reuses the variables created above.
                greedy_init_state = initial_state(demo_idx)
                greedy_pred, greedy_pred_len, _ = unroll(
                    demo_idx, tokens, embedding_dim, 'greedy',
                    greedy_init_state, True)
                greedy_pred_list.append(greedy_pred)
                greedy_pred_len_list.append(greedy_pred_len)

            return pred_list, greedy_pred_list, greedy_pred_len_list