Example #1
  def __call__(self, name, max_len, reuse=False):
    component_scope = self._variable_scope + name
    with tf.variable_scope(component_scope) as scope:
      if reuse:
        scope.reuse_variables()
      max_len = max_len + 2 if self._word_delimiters else max_len
      word_lens_source = tf.placeholder(dtype=tf.int32, shape=[None], name='source/word_lens')
      word_lens_target = tf.placeholder(dtype=tf.int32, shape=[None], name='target/word_lens')

      chars_p_s = tf.placeholder(dtype=tf.int32, shape=[max_len, None],
                                 name='source/char_sequence%i' % max_len)
      chars_p_t = tf.placeholder(dtype=tf.int32, shape=[max_len, None],
                                 name='target/char_sequence%i' % max_len)

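      # Character ids arrive time-major as [max_len, batch]; each lookup
      # yields [max_len, batch, embed], the transposes make the tensors
      # batch-major, and source/target embeddings are concatenated along
      # the feature axis before being fed to the RNNs below.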
      chars_s = tf.nn.embedding_lookup(self.Wchar_s, chars_p_s)
      chars_s = tf.transpose(chars_s, [1, 0, 2])
      chars_t = tf.nn.embedding_lookup(self.Wchar_t, chars_p_t)
      chars_t = tf.transpose(chars_t, [1, 0, 2])
      chars = tf.concat([chars_s, chars_t], 2)

      enc_output_infer, enc_state_infer = tf.nn.dynamic_rnn(
          self.char_rnn_cell_infer, chars, dtype=tf.float32,
          sequence_length=tf.cast(word_lens_source, dtype=tf.int64),
          swap_memory=True, scope='encoder')
      attn_keys, attn_values, attn_score_fn, attn_construct_fn = \
          attention_decoder_fn.prepare_attention(
              enc_output_infer, 'luong', self.char_rnn_cell_infer.output_size)
      dec_fn_inf = attention_decoder_fn.attention_decoder_fn_train(
          enc_state_infer, attn_keys, attn_values, attn_score_fn,
          attn_construct_fn)
      outputs_infer, _, _ = seq2seq.dynamic_rnn_decoder(
          self.char_rnn_cell_infer, dec_fn_inf, inputs=chars,
          sequence_length=word_lens_target, swap_memory=True, scope='decoder')

      scope.reuse_variables()
      enc_output, enc_state = tf.nn.dynamic_rnn(
          self.char_rnn_cell_train, chars, dtype=tf.float32,
          sequence_length=tf.cast(word_lens_source, dtype=tf.int64),
          swap_memory=True, scope='encoder')
      attn_keys, attn_values, attn_score_fn, attn_construct_fn = \
          attention_decoder_fn.prepare_attention(
              enc_output, 'luong', self.char_rnn_cell_infer.output_size)
      dec_fn = attention_decoder_fn.attention_decoder_fn_train(
          enc_state, attn_keys, attn_values, attn_score_fn, attn_construct_fn)
      outputs, _, _ = seq2seq.dynamic_rnn_decoder(
          self.char_rnn_cell_train, dec_fn, inputs=chars,
          sequence_length=word_lens_target, swap_memory=True, scope='decoder')

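      # _get_last_state_dyn presumably gathers each sequence's decoder output
      # at its last valid timestep (word_lens_target); the helper itself is
      # not shown in this snippet.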
      output = _get_last_state_dyn(self.max_norm, word_lens_target, outputs)
      output_infer = _get_last_state_dyn(self.max_norm, word_lens_target, outputs_infer)
      inputs = [
        chars_p_s, word_lens_source,
        chars_p_t, word_lens_target
      ]

      char_feature_extractor1 = CharLevelInputExtraction(
          self._char_vocab_source, max_len, component_scope + 'source/')
      char_feature_extractor2 = CharLevelInputExtraction(
          self._char_vocab_target, max_len, component_scope + 'target/')
      char_feature_extractor = CombineFeatureExtraction(char_feature_extractor1, char_feature_extractor2)
      return Component(inputs, output,
                       output_infer=output_infer,
                       feature_extractor=char_feature_extractor,
                       name='c_rnn_joint')
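# A minimal usage sketch for the component above (hypothetical: the enclosing
# builder object, its constructor, and the feed arrays are assumptions, not
# part of the original code):
#
#   component = char_rnn_builder('char_joint/', max_len=20)  # invokes __call__
#   feed = dict(zip(component.inputs,
#                   [src_chars, src_lens, tgt_chars, tgt_lens]))
#   last_state = sess.run(component.output, feed_dict=feed)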
Example #2
    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 num_layers,
                 is_train,
                 vocab=None,
                 embed=None,
                 learning_rate=0.1,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 num_samples=512,
                 max_length=30,
                 use_lstm=True):

        self.posts_1 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_2 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_3 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_4 = tf.placeholder(tf.string, shape=(None, None))

        self.entity_1 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_2 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_3 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_4 = tf.placeholder(tf.string, shape=(None, None, None, 3))

        self.entity_mask_1 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))
        self.entity_mask_2 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))
        self.entity_mask_3 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))
        self.entity_mask_4 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))

        self.posts_length_1 = tf.placeholder(tf.int32, shape=(None,))
        self.posts_length_2 = tf.placeholder(tf.int32, shape=(None,))
        self.posts_length_3 = tf.placeholder(tf.int32, shape=(None,))
        self.posts_length_4 = tf.placeholder(tf.int32, shape=(None,))

        self.responses = tf.placeholder(tf.string, shape=(None, None))
        self.responses_length = tf.placeholder(tf.int32, shape=(None,))

        self.epoch = tf.Variable(0, trainable=False, name='epoch')
        self.epoch_add_op = self.epoch.assign(self.epoch + 1)

        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            self.symbols = tf.Variable(np.array(['.'] * num_symbols),
                                       name="symbols")

        self.symbol2index = HashTable(KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(np.arange(num_symbols, dtype=np.int32),
                        trainable=False)),
                                      default_value=UNK_ID,
                                      name="symbol2index")
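        # symbol2index maps vocabulary strings to int32 ids, falling back to
        # UNK_ID for out-of-vocabulary tokens; every string placeholder below
        # is converted to id tensors through this single table.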

        self.posts_input_1 = self.symbol2index.lookup(self.posts_1)

        self.posts_2_target = self.posts_2_embed = self.symbol2index.lookup(
            self.posts_2)
        self.posts_3_target = self.posts_3_embed = self.symbol2index.lookup(
            self.posts_3)
        self.posts_4_target = self.posts_4_embed = self.symbol2index.lookup(
            self.posts_4)

        self.responses_target = self.symbol2index.lookup(self.responses)

        batch_size, decoder_len = tf.shape(self.posts_1)[0], tf.shape(
            self.responses)[1]

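        # Teacher-forcing shift: drop the last column of the targets and
        # prepend GO_ID so that decoder input step t predicts target step t.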
        self.posts_input_2 = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_2_embed, [tf.shape(self.posts_2)[1] - 1, 1],
                     1)[0]
        ], 1)
        self.posts_input_3 = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_3_embed, [tf.shape(self.posts_3)[1] - 1, 1],
                     1)[0]
        ], 1)
        self.posts_input_4 = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_4_embed, [tf.shape(self.posts_4)[1] - 1, 1],
                     1)[0]
        ], 1)

        self.responses_input = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)

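        # one_hot(length - 1) marks each sequence's last valid step; a
        # reversed cumsum turns that into 1s for positions < length and 0s
        # after, e.g. length=3, width=5 -> [1., 1., 1., 0., 0.].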
        self.encoder_2_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.posts_length_2 - 1,
                                 tf.shape(self.posts_2)[1]),
                      reverse=True,
                      axis=1), [-1, tf.shape(self.posts_2)[1]])
        self.encoder_3_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.posts_length_3 - 1,
                                 tf.shape(self.posts_3)[1]),
                      reverse=True,
                      axis=1), [-1, tf.shape(self.posts_3)[1]])
        self.encoder_4_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.posts_length_4 - 1,
                                 tf.shape(self.posts_4)[1]),
                      reverse=True,
                      axis=1), [-1, tf.shape(self.posts_4)[1]])

        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        if embed is None:
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input_1 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_1)
        self.encoder_input_2 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_2)
        self.encoder_input_3 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_3)
        self.encoder_input_4 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_4)

        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        entity_embedding_1 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_1)),
            [
                batch_size,
                tf.shape(self.entity_1)[1],
                tf.shape(self.entity_1)[2], 3 * num_embed_units
            ])
        entity_embedding_2 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_2)),
            [
                batch_size,
                tf.shape(self.entity_2)[1],
                tf.shape(self.entity_2)[2], 3 * num_embed_units
            ])
        entity_embedding_3 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_3)),
            [
                batch_size,
                tf.shape(self.entity_3)[1],
                tf.shape(self.entity_3)[2], 3 * num_embed_units
            ])
        entity_embedding_4 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_4)),
            [
                batch_size,
                tf.shape(self.entity_4)[1],
                tf.shape(self.entity_4)[2], 3 * num_embed_units
            ])

        head_1, relation_1, tail_1 = tf.split(entity_embedding_1,
                                              [num_embed_units] * 3,
                                              axis=3)
        head_2, relation_2, tail_2 = tf.split(entity_embedding_2,
                                              [num_embed_units] * 3,
                                              axis=3)
        head_3, relation_3, tail_3 = tf.split(entity_embedding_3,
                                              [num_embed_units] * 3,
                                              axis=3)
        head_4, relation_4, tail_4 = tf.split(entity_embedding_4,
                                              [num_embed_units] * 3,
                                              axis=3)

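        # Graph attention over (head, relation, tail) triples: each
        # (head, tail) pair is scored against its transformed relation,
        # softmax over the triple axis gives alpha_weight, and the masked,
        # weighted sum collapses the triples into one embedding per position.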
        with tf.variable_scope('graph_attention'):
            #[batch_size, max_response_length, max_triple_num, 2*embed_units]
            head_tail_1 = tf.concat([head_1, tail_1], axis=3)
            #[batch_size, max_response_length, max_triple_num, embed_units]
            head_tail_transformed_1 = tf.layers.dense(
                head_tail_1,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            #[batch_size, max_response_length, max_triple_num, embed_units]
            relation_transformed_1 = tf.layers.dense(relation_1,
                                                     num_embed_units,
                                                     name='relation_transform')
            #[batch_size, max_response_length, max_triple_num]
            e_weight_1 = tf.reduce_sum(relation_transformed_1 *
                                       head_tail_transformed_1,
                                       axis=3)
            #[batch_size, max_response_length, max_triple_num]
            alpha_weight_1 = tf.nn.softmax(e_weight_1)
            #[batch_size, max_response_length, embed_units]
            graph_embed_1 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_1, 3) *
                (tf.expand_dims(self.entity_mask_1, 3) * head_tail_1),
                axis=2)

        with tf.variable_scope('graph_attention', reuse=True):
            head_tail_2 = tf.concat([head_2, tail_2], axis=3)
            head_tail_transformed_2 = tf.layers.dense(
                head_tail_2,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            relation_transformed_2 = tf.layers.dense(relation_2,
                                                     num_embed_units,
                                                     name='relation_transform')
            e_weight_2 = tf.reduce_sum(relation_transformed_2 *
                                       head_tail_transformed_2,
                                       axis=3)
            alpha_weight_2 = tf.nn.softmax(e_weight_2)
            graph_embed_2 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_2, 3) *
                (tf.expand_dims(self.entity_mask_2, 3) * head_tail_2),
                axis=2)

        with tf.variable_scope('graph_attention', reuse=True):
            head_tail_3 = tf.concat([head_3, tail_3], axis=3)
            head_tail_transformed_3 = tf.layers.dense(
                head_tail_3,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            relation_transformed_3 = tf.layers.dense(relation_3,
                                                     num_embed_units,
                                                     name='relation_transform')
            e_weight_3 = tf.reduce_sum(relation_transformed_3 *
                                       head_tail_transformed_3,
                                       axis=3)
            alpha_weight_3 = tf.nn.softmax(e_weight_3)
            graph_embed_3 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_3, 3) *
                (tf.expand_dims(self.entity_mask_3, 3) * head_tail_3),
                axis=2)

        with tf.variable_scope('graph_attention', reuse=True):
            head_tail_4 = tf.concat([head_4, tail_4], axis=3)
            head_tail_transformed_4 = tf.layers.dense(
                head_tail_4,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            relation_transformed_4 = tf.layers.dense(relation_4,
                                                     num_embed_units,
                                                     name='relation_transform')
            e_weight_4 = tf.reduce_sum(relation_transformed_4 *
                                       head_tail_transformed_4,
                                       axis=3)
            alpha_weight_4 = tf.nn.softmax(e_weight_4)
            graph_embed_4 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_4, 3) *
                (tf.expand_dims(self.entity_mask_4, 3) * head_tail_4),
                axis=2)

        if use_lstm:
            # one fresh cell per layer; `[cell] * num_layers` would reuse a
            # single cell object (and its weights) across all layers
            cell = MultiRNNCell(
                [LSTMCell(num_units) for _ in range(num_layers)])
        else:
            cell = MultiRNNCell(
                [GRUCell(num_units) for _ in range(num_layers)])

        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        encoder_output_1, encoder_state_1 = dynamic_rnn(cell,
                                                        self.encoder_input_1,
                                                        self.posts_length_1,
                                                        dtype=tf.float32,
                                                        scope="encoder")

        attention_keys_1, attention_values_1, attention_score_fn_1, attention_construct_fn_1 \
                = attention_decoder_fn.prepare_attention(graph_embed_1, encoder_output_1, 'luong', num_units)
        decoder_fn_train_1 = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state_1,
            attention_keys_1,
            attention_values_1,
            attention_score_fn_1,
            attention_construct_fn_1,
            max_length=tf.reduce_max(self.posts_length_2))
        encoder_output_2, encoder_state_2, alignments_ta_2 = dynamic_rnn_decoder(
            cell,
            decoder_fn_train_1,
            self.encoder_input_2,
            self.posts_length_2,
            scope="decoder")
        self.alignments_2 = tf.transpose(alignments_ta_2.stack(),
                                         perm=[1, 0, 2])

        self.decoder_loss_2 = sampled_sequence_loss(encoder_output_2,
                                                    self.posts_2_target,
                                                    self.encoder_2_mask)
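        # Posts 2-4 form a chain: each is decoded while attending over the
        # previous post's outputs plus its graph embedding, and each
        # contributes its own sampled-softmax loss to the final objective.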

        with variable_scope.variable_scope('', reuse=True):
            attention_keys_2, attention_values_2, attention_score_fn_2, attention_construct_fn_2 \
                    = attention_decoder_fn.prepare_attention(graph_embed_2, encoder_output_2, 'luong', num_units)
            decoder_fn_train_2 = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_2,
                attention_keys_2,
                attention_values_2,
                attention_score_fn_2,
                attention_construct_fn_2,
                max_length=tf.reduce_max(self.posts_length_3))
            encoder_output_3, encoder_state_3, alignments_ta_3 = dynamic_rnn_decoder(
                cell,
                decoder_fn_train_2,
                self.encoder_input_3,
                self.posts_length_3,
                scope="decoder")
            self.alignments_3 = tf.transpose(alignments_ta_3.stack(),
                                             perm=[1, 0, 2])

            self.decoder_loss_3 = sampled_sequence_loss(
                encoder_output_3, self.posts_3_target, self.encoder_3_mask)

            attention_keys_3, attention_values_3, attention_score_fn_3, attention_construct_fn_3 \
                    = attention_decoder_fn.prepare_attention(graph_embed_3, encoder_output_3, 'luong', num_units)
            decoder_fn_train_3 = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_3,
                attention_keys_3,
                attention_values_3,
                attention_score_fn_3,
                attention_construct_fn_3,
                max_length=tf.reduce_max(self.posts_length_4))
            encoder_output_4, encoder_state_4, alignments_ta_4 = dynamic_rnn_decoder(
                cell,
                decoder_fn_train_3,
                self.encoder_input_4,
                self.posts_length_4,
                scope="decoder")
            self.alignments_4 = tf.transpose(alignments_ta_4.stack(),
                                             perm=[1, 0, 2])

            self.decoder_loss_4 = sampled_sequence_loss(
                encoder_output_4, self.posts_4_target, self.encoder_4_mask)

            attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                    = attention_decoder_fn.prepare_attention(graph_embed_4, encoder_output_4, 'luong', num_units)

        if is_train:
            with variable_scope.variable_scope('', reuse=True):
                decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                    encoder_state_4,
                    attention_keys,
                    attention_values,
                    attention_score_fn,
                    attention_construct_fn,
                    max_length=tf.reduce_max(self.responses_length))
                self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
                    cell,
                    decoder_fn_train,
                    self.decoder_input,
                    self.responses_length,
                    scope="decoder")
                self.alignments = tf.transpose(alignments_ta.stack(),
                                               perm=[1, 0, 2])

                self.decoder_loss = sampled_sequence_loss(
                    self.decoder_output, self.responses_target,
                    self.decoder_mask)

            self.params = tf.trainable_variables()

            self.learning_rate = tf.Variable(float(learning_rate),
                                             trainable=False,
                                             dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * learning_rate_decay_factor)
            self.global_step = tf.Variable(0, trainable=False)

            #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)

            gradients = tf.gradients(
                self.decoder_loss + self.decoder_loss_2 + self.decoder_loss_3 +
                self.decoder_loss_4, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
                gradients, max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients,
                                                  self.params),
                                              global_step=self.global_step)

        else:
            with variable_scope.variable_scope('', reuse=True):
                decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
                    output_fn, encoder_state_4, attention_keys,
                    attention_values, attention_score_fn,
                    attention_construct_fn, self.embed, GO_ID, EOS_ID,
                    max_length, num_symbols)
                self.decoder_distribution, _, alignments_ta = dynamic_rnn_decoder(
                    cell, decoder_fn_inference, scope="decoder")
                output_len = tf.shape(self.decoder_distribution)[1]
                self.alignments = tf.transpose(
                    alignments_ta.gather(tf.range(output_len)), [1, 0, 2])

            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2],
                         2)[1], 2) + 2  # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols,
                                                     self.generation_index,
                                                     name="generation")

            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=10,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
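# A minimal training-step sketch (hypothetical: the enclosing class name,
# constructor arguments, and `batch` dict are assumptions, not shown above):
#
#   model = Model(num_symbols, 300, 512, num_layers=2, is_train=True,
#                 vocab=vocab, embed=pretrained_embed)
#   feed = {model.posts_1: batch['posts_1'], ...,
#           model.responses: batch['responses'],
#           model.responses_length: batch['responses_length']}
#   _, loss = sess.run([model.update, model.decoder_loss], feed_dict=feed)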
Example #3
  def test_dynamic_rnn_decoder_time_major(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)) as varscope:
        # Define inputs/outputs to model
        batch_size = 2
        encoder_embedding_size = 3
        decoder_embedding_size = 4
        encoder_hidden_size = 5
        decoder_hidden_size = encoder_hidden_size
        input_sequence_length = 6
        decoder_sequence_length = 7
        num_decoder_symbols = 20
        start_of_sequence_id = end_of_sequence_id = 1
        decoder_embeddings = variable_scope.get_variable(
            "decoder_embeddings", [num_decoder_symbols, decoder_embedding_size],
            initializer=init_ops.random_normal_initializer(stddev=0.1))
        inputs = constant_op.constant(
            0.5,
            shape=[input_sequence_length, batch_size, encoder_embedding_size])
        decoder_inputs = constant_op.constant(
            0.4,
            shape=[decoder_sequence_length, batch_size, decoder_embedding_size])
        decoder_length = constant_op.constant(
            decoder_sequence_length, dtype=dtypes.int32, shape=[batch_size,])
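        # All tensors here are time-major ([time, batch, ...]), matching
        # time_major=True in the encoder and decoder calls below.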
        with variable_scope.variable_scope("rnn") as scope:
          # setting up weights for computing the final output
          output_fn = lambda x: layers.linear(x, num_decoder_symbols,
                                              scope=scope)

          # Define model
          encoder_outputs, encoder_state = rnn.dynamic_rnn(
              cell=core_rnn_cell_impl.GRUCell(encoder_hidden_size),
              inputs=inputs,
              dtype=dtypes.float32,
              time_major=True,
              scope=scope)

        with variable_scope.variable_scope("decoder") as scope:
          # Train decoder
          decoder_cell = core_rnn_cell_impl.GRUCell(decoder_hidden_size)
          decoder_fn_train = Seq2SeqTest._decoder_fn_with_context_state(
              decoder_fn_lib.simple_decoder_fn_train(
                  encoder_state=encoder_state))
          (decoder_outputs_train, decoder_state_train,
           decoder_context_state_train) = (seq2seq.dynamic_rnn_decoder(
               cell=decoder_cell,
               decoder_fn=decoder_fn_train,
               inputs=decoder_inputs,
               sequence_length=decoder_length,
               time_major=True,
               scope=scope))
          decoder_outputs_train = output_fn(decoder_outputs_train)

          # Setup variable reuse
          scope.reuse_variables()

          # Inference decoder
          decoder_fn_inference = Seq2SeqTest._decoder_fn_with_context_state(
              decoder_fn_lib.simple_decoder_fn_inference(
                  output_fn=output_fn,
                  encoder_state=encoder_state,
                  embeddings=decoder_embeddings,
                  start_of_sequence_id=start_of_sequence_id,
                  end_of_sequence_id=end_of_sequence_id,
                  #TODO: find out why it goes to +1
                  maximum_length=decoder_sequence_length - 1,
                  num_decoder_symbols=num_decoder_symbols,
                  dtype=dtypes.int32))
          (decoder_outputs_inference, decoder_state_inference,
           decoder_context_state_inference) = (seq2seq.dynamic_rnn_decoder(
               cell=decoder_cell,
               decoder_fn=decoder_fn_inference,
               time_major=True,
               scope=scope))

        # Run model
        variables.global_variables_initializer().run()
        (decoder_outputs_train_res, decoder_state_train_res,
         decoder_context_state_train_res) = sess.run([
             decoder_outputs_train, decoder_state_train,
             decoder_context_state_train
         ])
        (decoder_outputs_inference_res, decoder_state_inference_res,
         decoder_context_state_inference_res) = sess.run([
             decoder_outputs_inference, decoder_state_inference,
             decoder_context_state_inference
         ])

        # Assert outputs
        self.assertEqual((decoder_sequence_length, batch_size,
                          num_decoder_symbols), decoder_outputs_train_res.shape)
        self.assertEqual((batch_size, num_decoder_symbols),
                         decoder_outputs_inference_res.shape[1:3])
        self.assertEqual(decoder_sequence_length,
                         decoder_context_state_inference_res)
        self.assertEqual((batch_size, decoder_hidden_size),
                         decoder_state_train_res.shape)
        self.assertEqual((batch_size, decoder_hidden_size),
                         decoder_state_inference_res.shape)
        self.assertEqual(decoder_sequence_length,
                         decoder_context_state_train_res)
        # The dynamic decoder might end earlier than `maximum_length`
        # under inference
        self.assertGreaterEqual(decoder_sequence_length,
                                decoder_state_inference_res.shape[0])
Example #4
  def test_attention(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        # Define inputs/outputs to model
        batch_size = 2
        encoder_embedding_size = 3
        decoder_embedding_size = 4
        encoder_hidden_size = 5
        decoder_hidden_size = encoder_hidden_size
        input_sequence_length = 6
        decoder_sequence_length = 7
        num_decoder_symbols = 20
        start_of_sequence_id = end_of_sequence_id = 1
        decoder_embeddings = variable_scope.get_variable(
            "decoder_embeddings", [num_decoder_symbols, decoder_embedding_size],
            initializer=init_ops.random_normal_initializer(stddev=0.1))
        inputs = constant_op.constant(
            0.5,
            shape=[input_sequence_length, batch_size, encoder_embedding_size])
        decoder_inputs = constant_op.constant(
            0.4,
            shape=[decoder_sequence_length, batch_size, decoder_embedding_size])
        decoder_length = constant_op.constant(
            decoder_sequence_length, dtype=dtypes.int32, shape=[batch_size,])

        # attention
        attention_option = "luong"  # can be "bahdanau"

        with variable_scope.variable_scope("rnn") as scope:
          # Define model
          encoder_outputs, encoder_state = rnn.dynamic_rnn(
              cell=core_rnn_cell_impl.GRUCell(encoder_hidden_size),
              inputs=inputs,
              dtype=dtypes.float32,
              time_major=True,
              scope=scope)

          # attention_states: size [batch_size, max_time, num_units]
          attention_states = array_ops.transpose(encoder_outputs, [1, 0, 2])

        with variable_scope.variable_scope("decoder") as scope:
          # Prepare attention
          (attention_keys, attention_values, attention_score_fn,
           attention_construct_fn) = (attention_decoder_fn.prepare_attention(
               attention_states, attention_option, decoder_hidden_size))
          decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
              encoder_state=encoder_state,
              attention_keys=attention_keys,
              attention_values=attention_values,
              attention_score_fn=attention_score_fn,
              attention_construct_fn=attention_construct_fn)

          # setting up weights for computing the final output
          def create_output_fn():

            def output_fn(x):
              return layers.linear(x, num_decoder_symbols, scope=scope)

            return output_fn

          output_fn = create_output_fn()

          # Train decoder
          decoder_cell = core_rnn_cell_impl.GRUCell(decoder_hidden_size)
          (decoder_outputs_train, decoder_state_train, _) = (
              seq2seq.dynamic_rnn_decoder(
                  cell=decoder_cell,
                  decoder_fn=decoder_fn_train,
                  inputs=decoder_inputs,
                  sequence_length=decoder_length,
                  time_major=True,
                  scope=scope))
          decoder_outputs_train = output_fn(decoder_outputs_train)
          # Setup variable reuse
          scope.reuse_variables()

          # Inference decoder
          decoder_fn_inference = (
              attention_decoder_fn.attention_decoder_fn_inference(
                  output_fn=output_fn,
                  encoder_state=encoder_state,
                  attention_keys=attention_keys,
                  attention_values=attention_values,
                  attention_score_fn=attention_score_fn,
                  attention_construct_fn=attention_construct_fn,
                  embeddings=decoder_embeddings,
                  start_of_sequence_id=start_of_sequence_id,
                  end_of_sequence_id=end_of_sequence_id,
                  maximum_length=decoder_sequence_length - 1,
                  num_decoder_symbols=num_decoder_symbols,
                  dtype=dtypes.int32))
          (decoder_outputs_inference, decoder_state_inference, _) = (
              seq2seq.dynamic_rnn_decoder(
                  cell=decoder_cell,
                  decoder_fn=decoder_fn_inference,
                  time_major=True,
                  scope=scope))

        # Run model
        variables.global_variables_initializer().run()
        (decoder_outputs_train_res, decoder_state_train_res) = sess.run(
            [decoder_outputs_train, decoder_state_train])
        (decoder_outputs_inference_res, decoder_state_inference_res) = sess.run(
            [decoder_outputs_inference, decoder_state_inference])

        # Assert outputs
        self.assertEqual((decoder_sequence_length, batch_size,
                          num_decoder_symbols), decoder_outputs_train_res.shape)
        self.assertEqual((batch_size, num_decoder_symbols),
                         decoder_outputs_inference_res.shape[1:3])
        self.assertEqual((batch_size, decoder_hidden_size),
                         decoder_state_train_res.shape)
        self.assertEqual((batch_size, decoder_hidden_size),
                         decoder_state_inference_res.shape)
        # The dynamic decoder might end earlier than `maximum_length`
        # under inference
        self.assertGreaterEqual(decoder_sequence_length,
                                decoder_state_inference_res.shape[0])
Example #7
    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 num_layers,
                 beam_size,
                 embed,
                 learning_rate=0.5,
                 remove_unk=False,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 num_samples=512,
                 max_length=8,
                 use_lstm=False):

        self.posts = tf.placeholder(tf.string, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None,),
                                           'enc_lens')  # batch
        self.responses = tf.placeholder(tf.string, (None, None),
                                        'dec_inps')  # batch*len
        self.responses_length = tf.placeholder(tf.int32, (None,),
                                               'dec_lens')  # batch

        # initialize the training process
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        self.symbol2index = MutableHashTable(key_dtype=tf.string,
                                             value_dtype=tf.int64,
                                             default_value=UNK_ID,
                                             shared_name="in_table",
                                             name="in_table",
                                             checkpoint=True)
        self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value='_UNK',
                                             shared_name="out_table",
                                             name="out_table",
                                             checkpoint=True)
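        # The two mutable tables give a string<->id mapping in both
        # directions that is filled at run time and, with checkpoint=True,
        # saved and restored together with the model.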
        # build the vocab table (string to index)

        self.posts_input = self.symbol2index.lookup(self.posts)  # batch*len
        self.responses_target = self.symbol2index.lookup(
            self.responses)  #batch*len

        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
            self.responses)[1]
        self.responses_input = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # batch*len
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  #batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        if use_lstm:
            # one fresh cell per layer, as above
            cell = MultiRNNCell(
                [LSTMCell(num_units) for _ in range(num_layers)])
        else:
            cell = MultiRNNCell(
                [GRUCell(num_units) for _ in range(num_layers)])

        # rnn encoder
        encoder_output, encoder_state = dynamic_rnn(cell,
                                                    self.encoder_input,
                                                    self.posts_length,
                                                    dtype=tf.float32,
                                                    scope="encoder")

        # get output projection function
        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        # get attention function
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                = attention_decoder_fn.prepare_attention(encoder_output, 'luong', num_units)

        with tf.variable_scope('decoder'):
            decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn)
            self.decoder_output, _, _ = dynamic_rnn_decoder(
                cell,
                decoder_fn_train,
                self.decoder_input,
                self.responses_length,
                scope="decoder_rnn")
            self.decoder_loss = sampled_sequence_loss(self.decoder_output,
                                                      self.responses_target,
                                                      self.decoder_mask)

        with tf.variable_scope('decoder', reuse=True):
            decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
                output_fn, encoder_state, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn, self.embed, GO_ID,
                EOS_ID, max_length, num_symbols)

            self.decoder_distribution, _, _ = dynamic_rnn_decoder(
                cell, decoder_fn_inference, scope="decoder_rnn")
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2],
                         2)[1], 2) + 2  # for removing UNK
            self.generation = self.index2symbol.lookup(self.generation_index,
                                                       name='generation')

        with tf.variable_scope('decoder', reuse=True):
            decoder_fn_beam_inference = attention_decoder_fn_beam_inference(
                output_fn, encoder_state, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn, self.embed, GO_ID,
                EOS_ID, max_length, num_symbols, beam_size, remove_unk)
            _, _, self.context_state = dynamic_rnn_decoder(
                cell, decoder_fn_beam_inference, scope="decoder_rnn")
            (log_beam_probs, beam_parents, beam_symbols, result_probs,
             result_parents, result_symbols) = self.context_state

            self.beam_parents = tf.transpose(tf.reshape(
                beam_parents.stack(), [max_length + 1, -1, beam_size]),
                                             [1, 0, 2],
                                             name='beam_parents')
            self.beam_symbols = tf.transpose(
                tf.reshape(beam_symbols.stack(),
                           [max_length + 1, -1, beam_size]), [1, 0, 2])
            self.beam_symbols = self.index2symbol.lookup(tf.cast(
                self.beam_symbols, tf.int64),
                                                         name="beam_symbols")

            self.result_probs = tf.transpose(tf.reshape(
                result_probs.stack(), [max_length + 1, -1, beam_size * 2]),
                                             [1, 0, 2],
                                             name='result_probs')
            self.result_symbols = tf.transpose(
                tf.reshape(result_symbols.stack(),
                           [max_length + 1, -1, beam_size * 2]), [1, 0, 2])
            self.result_parents = tf.transpose(tf.reshape(
                result_parents.stack(), [max_length + 1, -1, beam_size * 2]),
                                               [1, 0, 2],
                                               name='result_parents')
            self.result_symbols = self.index2symbol.lookup(
                tf.cast(self.result_symbols, tf.int64), name='result_symbols')
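            # Beam-search bookkeeping: the stacked TensorArrays are reshaped
            # to [max_length + 1, batch, beam_size] (beam_size * 2 for the
            # result_* tensors, presumably the finished hypotheses) and
            # transposed to batch-major; ids are mapped back to strings
            # through index2symbol.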

        self.params = tf.trainable_variables()

        # calculate the gradient of parameters
        opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                    max_to_keep=3,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)

        # Exporter for serving
        self.model_exporter = exporter.Exporter(self.saver)
        inputs = {"enc_inps:0": self.posts, "enc_lens:0": self.posts_length}
        outputs = {
            "beam_symbols": self.beam_symbols,
            "beam_parents": self.beam_parents,
            "result_probs": self.result_probs,
            "result_symbols": self.result_symbols,
            "result_parents": self.result_parents
        }
        self.model_exporter.init(tf.get_default_graph().as_graph_def(),
                                 named_graph_signatures={
                                     "inputs":
                                     exporter.generic_signature(inputs),
                                     "outputs":
                                     exporter.generic_signature(outputs)
                                 })

    def __init__(self,
            num_symbols,
            num_qwords, #modify
            num_embed_units,
            num_units,
            num_layers,
            is_train,
            vocab=None,
            embed=None,
            question_data=True,
            learning_rate=0.5,
            learning_rate_decay_factor=0.95,
            max_gradient_norm=5.0,
            num_samples=512,
            max_length=30,
            use_lstm=False):

        self.posts = tf.placeholder(tf.string, shape=(None, None))  # batch * len
        self.posts_length = tf.placeholder(tf.int32, shape=(None,))  # batch
        self.responses = tf.placeholder(tf.string, shape=(None, None))  # batch * len
        self.responses_length = tf.placeholder(tf.int32, shape=(None,))  # batch
        self.keyword_tensor = tf.placeholder(tf.float32, shape=(None, 3, None))  # (batch * len) * 3 * num_symbols
        self.word_type = tf.placeholder(tf.int32, shape=(None,))  # (batch * len)

        # build the vocab table (string to index)
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            self.symbols = tf.Variable(np.array(['.']*num_symbols), name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(self.symbols,
            tf.Variable(np.arange(num_symbols, dtype=np.int32), trainable=False)),
            default_value=UNK_ID, name="symbol2index")
        self.posts_input = self.symbol2index.lookup(self.posts)   # batch*len
        self.responses_target = self.symbol2index.lookup(self.responses)   #batch*len
        
        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
        self.responses_input = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32)*GO_ID,
            tf.split(self.responses_target, [decoder_len-1, 1], 1)[0]], 1)   # batch*len
        # delete the last column of responses_target and add GO_ID at the front of it
        self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.responses_length-1,
            decoder_len), reverse=True, axis=1), [-1, decoder_len])  # batch * len
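        # e.g. responses_length = [3], decoder_len = 5:
        #   tf.one_hot      -> [[0, 0, 1, 0, 0]]
        #   reverse cumsum  -> [[1, 1, 1, 0, 0]]  (1 on real steps, 0 on padding)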

        print "embedding..."
        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
        else:
            print(len(vocab), len(embed), len(embed[0]))
            print(embed)
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_input) #batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)

        print "embedding finished"

        if use_lstm:
            # build a fresh cell per layer: reusing a single cell object for
            # every layer would make the layers share (or clash on) variables
            cell = MultiRNNCell([LSTMCell(num_units) for _ in range(num_layers)])
        else:
            cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

        # rnn encoder
        encoder_output, encoder_state = dynamic_rnn(cell, self.encoder_input,
                self.posts_length, dtype=tf.float32, scope="encoder")
        # get output projection function
        output_fn, sampled_sequence_loss = output_projection_layer(num_units,
                num_symbols, num_qwords, num_samples, question_data)

        print "encoder_output.shape:", encoder_output.get_shape()

        # get attention function
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
              = attention_decoder_fn.prepare_attention(encoder_output, 'luong', num_units)
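        # in the legacy contrib API, attention_keys are the projected encoder
        # states and attention_values the raw ones; attention_score_fn scores
        # the decoder state against the keys, and attention_construct_fn folds
        # the resulting context vector back into the cell output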

        # get decoding loop function
        decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(encoder_state,
                attention_keys, attention_values, attention_score_fn, attention_construct_fn)
        decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(output_fn,
                self.keyword_tensor,
                encoder_state, attention_keys, attention_values, attention_score_fn,
                attention_construct_fn, self.embed, GO_ID, EOS_ID, max_length, num_symbols)

        if is_train:
            # rnn decoder
            self.decoder_output, _, _ = dynamic_rnn_decoder(cell, decoder_fn_train,
                    self.decoder_input, self.responses_length, scope="decoder")
            # calculate the loss of decoder
            # self.decoder_output = tf.Print(self.decoder_output, [self.decoder_output])
            self.decoder_loss, self.log_perplexity = sampled_sequence_loss(self.decoder_output,
                    self.responses_target, self.decoder_mask, self.keyword_tensor, self.word_type)

            # building graph finished and get all parameters
            self.params = tf.trainable_variables()

            for item in tf.trainable_variables():
                print(item.name, item.get_shape())

            # initialize the training process
            self.learning_rate = tf.Variable(float(learning_rate), trainable=False,
                    dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                    self.learning_rate * learning_rate_decay_factor)

            self.global_step = tf.Variable(0, trainable=False)

            # calculate the gradient of parameters

            opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            gradients = tf.gradients(self.decoder_loss, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients,
                    max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                    global_step=self.global_step)

        else:
            # rnn decoder
            self.decoder_distribution, _, _ = dynamic_rnn_decoder(cell, decoder_fn_inference,
                    scope="decoder")
            print("self.decoder_distribution.shape():",self.decoder_distribution.get_shape())
            self.decoder_distribution = tf.Print(self.decoder_distribution, ["distribution.shape()", tf.reduce_sum(self.decoder_distribution)])
            # generating the response
            self.generation_index = tf.argmax(tf.split(self.decoder_distribution,
                [2, num_symbols-2], 2)[1], 2) + 2 # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols, self.generation_index)

            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2,
                max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
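
Everything this constructor consumes at run time arrives through its placeholders. A minimal training-step sketch, assuming the __init__ above belongs to a model class (called Model here purely for illustration) and that the batch arrays and their length vectors are prepared elsewhere:

import tensorflow as tf

with tf.Session() as sess:
    model = Model(num_symbols, num_qwords, num_embed_units,
                  num_units, num_layers, is_train=True, vocab=vocab)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())  # the symbol2index HashTable needs this

    feed = {model.posts: post_batch, model.posts_length: post_lens,
            model.responses: response_batch, model.responses_length: response_lens,
            model.keyword_tensor: keyword_batch, model.word_type: word_types}
    _, loss, ppl = sess.run([model.update, model.decoder_loss,
                             model.log_perplexity], feed)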
Example #9
0
def test_dynamic_rnn_decoder():
    with tf.Session() as sess:
        with tf.variable_scope(
                "root", initializer=tf.constant_initializer(0.5)) as varscope:
            batch_size = 2
            encoder_embedding_size = 3
            decoder_embedding_size = 4
            encoder_hidden_size = 5
            decoder_hidden_size = encoder_hidden_size
            input_sequence_length = 6
            decoder_sequence_length = 7
            num_decoder_symbols = 20
            start_of_sequence_id = end_of_sequence_id = 1

            decoder_embeddings = tf.get_variable(
                "decoder_embeddings",
                [num_decoder_symbols, decoder_embedding_size],
                initializer=tf.random_normal_initializer(stddev=0.1))

            inputs = tf.constant(0.5,
                                 shape=[
                                     input_sequence_length, batch_size,
                                     encoder_embedding_size
                                 ])

            decoder_inputs = tf.constant(0.4,
                                         shape=[
                                             decoder_sequence_length,
                                             batch_size, decoder_embedding_size
                                         ])

            decoder_length = tf.constant(decoder_sequence_length,
                                         dtype=dtypes.int32,
                                         shape=[
                                             batch_size,
                                         ])

            with tf.variable_scope("rnn") as scope:
                # setting up weights for computing the final output
                output_fn = lambda x: layers.linear(
                    x, num_decoder_symbols, scope=scope)

                # Define model
                encoder_outputs, encoder_state = rnn.dynamic_rnn(
                    cell=core_rnn_cell_impl.GRUCell(encoder_hidden_size),
                    inputs=inputs,
                    dtype=dtypes.float32,
                    time_major=True,
                    scope=scope)

            with tf.variable_scope("decoder") as scope:
                # Train decoder
                decoder_cell = core_rnn_cell_impl.GRUCell(decoder_hidden_size)

                decoder_fn_train = _decoder_fn_with_context_state(
                    decoder_fn_lib.simple_decoder_fn_train(
                        encoder_state=encoder_state))

                (decoder_outputs_train, decoder_state_train,
                 decoder_context_state_train) = seq2seq.dynamic_rnn_decoder(
                     cell=decoder_cell,
                     decoder_fn=decoder_fn_train,
                     inputs=decoder_inputs,
                     sequence_length=decoder_length,
                     time_major=True,
                     scope=scope)

                decoder_outputs_train = output_fn(decoder_outputs_train)

                # Setup variable reuse
                scope.reuse_variables()

                # Inference decoder
                decoder_fn_inference = _decoder_fn_with_context_state(
                    decoder_fn_lib.simple_decoder_fn_inference(
                        output_fn=output_fn,
                        encoder_state=encoder_state,
                        embeddings=decoder_embeddings,
                        start_of_sequence_id=start_of_sequence_id,
                        end_of_sequence_id=end_of_sequence_id,
                        maximum_length=decoder_sequence_length - 1,
                        num_decoder_symbols=num_decoder_symbols,
                        dtype=dtypes.int32))

                (decoder_outputs_inference, decoder_state_inference,
                 decoder_context_state_inference) = (
                     seq2seq.dynamic_rnn_decoder(
                         cell=decoder_cell,
                         decoder_fn=decoder_fn_inference,
                         time_major=True,
                         scope=scope))

                output_train = tf.argmax(decoder_outputs_train, axis=2)
                output_inference = tf.argmax(decoder_outputs_inference, axis=2)

                tf.global_variables_initializer().run()
                (decoder_outputs_train_res, decoder_state_train_res,
                 decoder_context_state_train_res) = sess.run([
                     decoder_outputs_train, decoder_state_train,
                     decoder_context_state_train
                 ])

                (decoder_outputs_inference_res, decoder_state_inference_res,
                 decoder_context_state_inference_res) = sess.run([
                     decoder_outputs_inference, decoder_state_inference,
                     decoder_context_state_inference
                 ])

                print(np.shape(decoder_outputs_train_res))
                print(np.shape(decoder_outputs_inference_res))
                output_train, output_inference = sess.run(
                    [output_train, output_inference])
                print(output_train)
                print(output_inference)
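
The _decoder_fn_with_context_state helper is not shown above; conceptually it wraps a plain decoder_fn so that extra bookkeeping flows through the context_state slot that dynamic_rnn_decoder threads along. A minimal sketch of such a wrapper, assuming the legacy contract decoder_fn(time, cell_state, cell_input, cell_output, context_state) -> (done, next_state, next_input, emit_output, next_context_state); the TensorArray bookkeeping here is illustrative, not the real helper's behavior:

import tensorflow as tf

def decoder_fn_with_context_state(inner_decoder_fn):
    # wrap an existing decoder_fn; record each visited time step in a
    # TensorArray carried through context_state (illustrative only)
    def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
        done, next_state, next_input, emit_output, next_context_state = (
            inner_decoder_fn(time, cell_state, cell_input, cell_output,
                             context_state))
        if next_context_state is None:
            next_context_state = context_state
        if next_context_state is None:  # very first call
            next_context_state = tf.TensorArray(tf.int32, size=0,
                                                dynamic_size=True)
        next_context_state = next_context_state.write(time, time)
        return done, next_state, next_input, emit_output, next_context_state
    return decoder_fn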
Example #10
0
    def _build_graph(self):

        # build the graph
        self.graph = tf.Graph()

        with self.graph.as_default():
            tf.set_random_seed(self.random_seed)

            # DATASET PLACEHOLDERS

            # (batch, time)
            source = tf.placeholder(tf.int32)
            source_mask = tf.placeholder(tf.float32)
            target = tf.placeholder(tf.int32)
            target_mask = tf.placeholder(tf.float32)
            output = tf.placeholder(tf.int32)
            output_mask = tf.placeholder(tf.float32)

            # TODO: add factored contexts (POS, NER, ETC...)
            # ner_context = tf.placeholder(tf.int32)

            # sets the probability of dropping out
            dropout_prob = tf.placeholder(tf.float32)

            with tf.name_scope('embeddings'):
                source_embeddings = tf.get_variable(
                    "source_embeddings",
                    [self.src_vocab_size, self.config['embedding_size']],
                    trainable=True)
                # TODO: support factors for source and target inputs
                # ner_embeddings = tf.get_variable("ner_embeddings", [self.meta['num_ner_tags'], self.meta['ner_embedding_size']],
                #                                   trainable=True)

                # default: just embed the tokens in the source context
                source_embed = tf.nn.embedding_lookup(source_embeddings,
                                                      source)

                if self.use_ner_embeddings:
                    pass
                    # TODO: support factors for source input
                    # ner_embed = tf.nn.embedding_lookup(ner_embeddings, ner_context)
                    # context_embed = tf.concat([context_embed, ner_embed], 2)
                    # context_embed.set_shape([None, None, self.meta['embedding_size'] + self.meta['ner_embedding_size']])
                else:
                    # this is to fix shape inference bug in rnn.py -- see this issue: https://github.com/tensorflow/tensorflow/issues/2938
                    source_embed.set_shape(
                        [None, None, self.config['embedding_size']])

                # TODO: switch this to target language embeddings
                # TODO: support target language factors (POS, NER, etc...)
                target_embeddings = tf.get_variable(
                    "target_embeddings",
                    [self.trg_vocab_size, self.config['embedding_size']])

                # target embeddings - these are the _inputs_ to the decoder
                target_embed = tf.nn.embedding_lookup(target_embeddings,
                                                      target)
                target_embed.set_shape(
                    [None, None, self.config['embedding_size']])

            # Construct input representation that we'll put attention over
            # Note: dropout is turned on/off by `dropout_prob`
            with tf.name_scope('input_representation'):
                lstm_cells = [
                    tf.contrib.rnn.DropoutWrapper(
                        tf.contrib.rnn.LSTMCell(
                            self.config['encoder_hidden_size'],
                            use_peepholes=True,
                            state_is_tuple=True),
                        input_keep_prob=dropout_prob,
                        output_keep_prob=dropout_prob)
                    for _ in range(self.config['lstm_stack_size'])
                ]

                cell = tf.contrib.rnn.MultiRNNCell(lstm_cells,
                                                   state_is_tuple=True)

                # use the description mask to get the sequence lengths
                source_sequence_length = tf.cast(tf.reduce_sum(source_mask, 1),
                                                 tf.int64)

                # BIDIRECTIONAL RNNs
                # Bidir outputs are (output_fw, output_bw)
                bidir_outputs, bidir_state = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=cell,
                    cell_bw=cell,
                    inputs=source_embed,
                    sequence_length=source_sequence_length,
                    dtype=tf.float32)
                l_to_r_states, r_to_l_states = bidir_state

                # Transpose to be time-major
                # TODO: do we need to transpose?
                # attention_states = tf.transpose(tf.concat(bidir_outputs, 2), [1, 0, 2])
                attention_states = tf.concat(bidir_outputs, 2)

                # Note: encoder is bidirectional, so we reduce dimensionality by 1/2 to make decoder initial state
                init_state_transformation = tf.get_variable(
                    'decoder_init_transform',
                    (self.config['encoder_hidden_size'] * 2,
                     self.config['decoder_hidden_size']))
                initialization_state = tf.matmul(
                    tf.concat([r_to_l_states[-1][1], l_to_r_states[-1][1]], 1),
                    init_state_transformation)
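                # each element of l_to_r_states / r_to_l_states is an
                # LSTMStateTuple, so [-1][1] picks the hidden state h of the
                # top layer; concatenating the fw and bw h gives a
                # [batch, 2 * encoder_hidden_size] matrix for the projection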

                # alternatively just use the final l_to_r state
                # initialization_state = l_to_r_states[-1][1]

                # TODO: try with simple L-->R GRU
                # encoder_outputs, encoder_state = rnn.dynamic_rnn(
                #     cell=core_rnn_cell_impl.GRUCell(encoder_hidden_size),
                #     inputs=inputs,
                #     dtype=dtypes.float32,
                #     time_major=False,
                #     scope=scope)

            with tf.name_scope('target_representation'):
                target_lstm_cells = [
                    tf.contrib.rnn.DropoutWrapper(
                        tf.contrib.rnn.LSTMCell(
                            self.config['encoder_hidden_size'],
                            use_peepholes=True,
                            state_is_tuple=True),
                        input_keep_prob=dropout_prob,
                        output_keep_prob=dropout_prob)
                    for _ in range(self.config['lstm_stack_size'])
                ]

                target_cell = tf.contrib.rnn.MultiRNNCell(target_lstm_cells,
                                                          state_is_tuple=True)
                # bidirectional target representation
                target_lengths = tf.cast(tf.reduce_sum(target_mask, axis=1),
                                         dtype=tf.int32)
                target_bidir_outputs, target_bidir_state = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=target_cell,
                    cell_bw=target_cell,
                    inputs=target_embed,
                    sequence_length=target_lengths,
                    dtype=tf.float32,
                    scope='target_bidir_rnn')
                target_l_to_r_states, target_r_to_l_states = target_bidir_state
                target_representation = tf.concat(target_bidir_outputs, 2)

            # Now construct the decoder
            decoder_hidden_size = self.config['decoder_hidden_size']
            # attention
            attention_option = "bahdanau"  # can be "luong"

            with variable_scope.variable_scope("decoder") as scope:

                # Prepare attention
                (attention_keys, attention_values, attention_score_fn,
                 attention_construct_fn) = (
                     attention_decoder_fn.prepare_attention(
                         attention_states, attention_option,
                         decoder_hidden_size))

                decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                    encoder_state=initialization_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn)

                # Note: this is different from the "normal" seq2seq encoder-decoder model, because we have different
                # input and output vocabularies for the decoder (target vocab vs. QE symbols)
                # num_decoder_symbols = self.output_vocab_size
                # decoder vocab is characters or sub-words? -- either way, we need to learn the vocab over the entity set
                # setting up weights for computing the final output
                # def create_output_fn():
                #     def output_fn(x):
                #         return layers.linear(x, num_decoder_symbols, scope=scope)
                #     return output_fn

                # output_fn = create_output_fn()

                intermediate_dim = 512
                output_transformation_1 = tf.Variable(
                    tf.random_normal([
                        self.config['decoder_hidden_size'] +
                        self.config['encoder_hidden_size'] * 2,
                        intermediate_dim
                    ]),
                    name='output_transformation_1')
                output_biases_1 = tf.Variable(tf.zeros([intermediate_dim]),
                                              name='output_biases_1')

                output_transformation_2 = tf.Variable(
                    tf.random_normal(
                        [intermediate_dim, self.output_vocab_size]),
                    name='output_transformation_2')
                output_biases_2 = tf.Variable(tf.zeros(
                    [self.output_vocab_size]),
                                              name='output_biases_2')

                # Train decoder
                decoder_cell = core_rnn_cell_impl.GRUCell(decoder_hidden_size)

                (decoder_outputs_train, decoder_state_train,
                 _) = (seq2seq.dynamic_rnn_decoder(
                     cell=decoder_cell,
                     decoder_fn=decoder_fn_train,
                     inputs=target_embed,
                     sequence_length=target_lengths,
                     time_major=False,
                     scope=scope))

                # TODO: for attentive QE, we don't need to separate train and inference decoders
                # TODO: we can directly use train decoder output at both training and prediction time

                # concat with target lm representation
                decoder_outputs_train = tf.concat(
                    [decoder_outputs_train, target_representation], 2)
                decoder_outputs_train = tf.nn.elu(decoder_outputs_train)
                decoder_outputs_train = tf.nn.dropout(decoder_outputs_train,
                                                      keep_prob=dropout_prob)

                output_shape = tf.shape(decoder_outputs_train)

                decoder_outputs_train = tf.matmul(
                    tf.reshape(decoder_outputs_train,
                               [output_shape[0] * output_shape[1], -1]),
                    output_transformation_1)
                decoder_outputs_train += output_biases_1
                decoder_outputs_train = tf.nn.elu(decoder_outputs_train)
                decoder_outputs_train = tf.nn.dropout(decoder_outputs_train,
                                                      keep_prob=dropout_prob)

                # one more linear layer
                decoder_outputs_train = tf.matmul(decoder_outputs_train,
                                                  output_transformation_2)
                decoder_outputs_train += output_biases_2

                decoder_outputs_train = tf.reshape(
                    decoder_outputs_train,
                    [output_shape[0], output_shape[1], -1])

                # DEBUGGING: dump these
                # self.decoder_outputs_train = decoder_outputs_train

            with tf.name_scope('predictions'):
                prediction_logits = decoder_outputs_train
                logit_histo = tf.summary.histogram('prediction_logits',
                                                   prediction_logits)

                predictions = tf.nn.softmax(prediction_logits)
                self.predictions = predictions

                # correct_predictions = tf.equal(tf.cast(tf.argmax(predictions, 1), tf.int32), entity)
                # accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
                # accuracy_summary = tf.summary.scalar('accuracy', accuracy)

            with tf.name_scope('xent'):
                # Note: set output and output_mask shape because they're needed here:
                # https://github.com/tensorflow/tensorflow/blob/r1.0/tensorflow/contrib/seq2seq/python/ops/loss.py#L65-L70
                output.set_shape([None, None])
                output_mask.set_shape([None, None])
                costs = tf.contrib.seq2seq.sequence_loss(
                    logits=decoder_outputs_train,
                    targets=output,
                    weights=output_mask,
                    average_across_timesteps=True)
                cost = tf.reduce_mean(costs)
                cost_summary = tf.summary.scalar('minibatch_cost', cost)

            # expose placeholders and ops on the class
            self.source = source
            self.source_mask = source_mask
            self.target = target
            self.target_mask = target_mask
            self.output = output
            self.output_mask = output_mask
            self.predictions = predictions
            self.cost = cost
            self.dropout_prob = dropout_prob

            # TODO: expose embeddings so that they can be visualized?

            optimizer = tf.train.AdamOptimizer()
            with tf.name_scope('train'):
                gradients = optimizer.compute_gradients(
                    cost, tf.trainable_variables())
                if self.config['max_gradient_norm'] is not None:
                    gradients, variables = zip(*gradients)
                    clipped_gradients, _ = clip_ops.clip_by_global_norm(
                        gradients, self.config['max_gradient_norm'])
                    gradients = list(zip(clipped_gradients, variables))

                for gradient, variable in gradients:
                    if isinstance(gradient, ops.IndexedSlices):
                        grad_values = gradient.values
                    else:
                        grad_values = gradient
                    tf.summary.histogram(variable.name, variable)
                    tf.summary.histogram(variable.name + '/gradients',
                                         grad_values)
                    tf.summary.histogram(variable.name + '/gradient_norm',
                                         clip_ops.global_norm([grad_values]))

                self.full_graph_optimizer = optimizer.apply_gradients(
                    gradients)

                # Optimizer #2 -- updates entity representations only
                # entity_representation_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                #                                                      "representation/entity_lookup")
                # self.entity_representation_optimizer = optimizer.minimize(cost,
                #                                                           var_list=entity_representation_train_vars)

            self.saver = tf.train.Saver()

            # self.accuracy = accuracy
            self.merged = tf.summary.merge_all()

            logger.info('Finished building model graph')
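
Once _build_graph has run, a training step presumably feeds the placeholders exposed on the class and runs the optimizer op. A minimal sketch, where model is an instance of the surrounding class and the six batch arrays plus the keep probability are assumptions:

import tensorflow as tf

feed = {
    model.source: src_ids,        # int32   [batch, time]
    model.source_mask: src_mask,  # float32 [batch, time]
    model.target: trg_ids,
    model.target_mask: trg_mask,
    model.output: out_ids,
    model.output_mask: out_mask,
    model.dropout_prob: 0.8,      # keep probability while training
}
with tf.Session(graph=model.graph) as sess:
    sess.run(tf.global_variables_initializer())
    _, batch_cost, summaries = sess.run(
        [model.full_graph_optimizer, model.cost, model.merged], feed)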
Example #11
0
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.topic_vocab = api.topic_vocab
        self.topic_vocab_size = len(self.topic_vocab)
        self.da_vocab = api.dialog_act_vocab
        self.da_vocab_size = len(self.da_vocab)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.input_contexts = tf.placeholder(dtype=tf.int32, shape=(None, None, self.max_utt_len), name="dialog_context")
            self.floors = tf.placeholder(dtype=tf.int32, shape=(None, None), name="floor")
            self.context_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="context_lens")
            self.topics = tf.placeholder(dtype=tf.int32, shape=(None,), name="topics")
            self.my_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="my_profile")
            self.ot_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="ot_profile")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32, shape=(None, None), name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_lens")
            self.output_das = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_dialog_acts")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr), trainable=False, name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_dialog_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("topicEmbedding"):
            t_embedding = tf.get_variable("embedding", [self.topic_vocab_size, config.topic_embed_size], dtype=tf.float32)
            topic_embedding = embedding_ops.embedding_lookup(t_embedding, self.topics)

        if config.use_hcf:
            with variable_scope.variable_scope("dialogActEmbedding"):
                d_embedding = tf.get_variable("embedding", [self.da_vocab_size, config.da_embed_size], dtype=tf.float32)
                da_embedding = embedding_ops.embedding_lookup(d_embedding, self.output_das)

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable("embedding", [self.vocab_size, config.embed_size], dtype=tf.float32)
            embedding_mask = tf.constant([0 if i == 0 else 1 for i in range(self.vocab_size)], dtype=tf.float32,
                                         shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            input_embedding = embedding_ops.embedding_lookup(embedding, tf.reshape(self.input_contexts, [-1]))
            input_embedding = tf.reshape(input_embedding, [-1, self.max_utt_len, config.embed_size])
            output_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)

            if config.sent_type == "bow":
                input_embedding, sent_size = get_bow(input_embedding)
                output_embedding, _ = get_bow(output_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size, config.keep_prob, 1)
                input_embedding, sent_size = get_rnn_encode(input_embedding, sent_cell, scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding, sent_cell, self.output_lens,
                                                     scope="sent_rnn", reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
                input_embedding, sent_size = get_bi_rnn_encode(input_embedding, fwd_sent_cell, bwd_sent_cell, scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding, fwd_sent_cell, bwd_sent_cell, self.output_lens, scope="sent_bi_rnn", reuse=True)
            else:
                raise ValueError("Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # reshape input into dialogs
            input_embedding = tf.reshape(input_embedding, [-1, max_dialog_len, sent_size])
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)

            # convert floors into 1 hot
            floor_one_hot = tf.one_hot(tf.reshape(self.floors, [-1]), depth=2, dtype=tf.float32)
            floor_one_hot = tf.reshape(floor_one_hot, [-1, max_dialog_len, 2])

            joint_embedding = tf.concat([input_embedding, floor_one_hot], 2, "joint_embedding")

        with variable_scope.variable_scope("contextRNN"):
            enc_cell = self.get_rnncell(config.cell_type, self.context_cell_size, keep_prob=1.0, num_layer=config.num_layer)
            # outputs past each sequence_length are zeroed out,
            # and enc_last_state will be the same as the true last state
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                joint_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                enc_last_state = tf.concat(enc_last_state, 1)

        # combine with other attributes
        if config.use_hcf:
            attribute_embedding = da_embedding
            attribute_fc1 = layers.fully_connected(attribute_embedding, 30, activation_fn=tf.tanh, scope="attribute_fc1")

        cond_list = [topic_embedding, self.my_profile, self.ot_profile, enc_last_state]
        cond_embedding = tf.concat(cond_list, 1)

        with variable_scope.variable_scope("recognitionNetwork"):
            if config.use_hcf:
                recog_input = tf.concat([cond_embedding, output_embedding, attribute_fc1], 1)
            else:
                recog_input = tf.concat([cond_embedding, output_embedding], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(recog_input, config.latent_size * 2, activation_fn=None, scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            # P(XYZ)=P(Z|X)P(X)P(Y|X,Z)
            prior_fc1 = layers.fully_connected(cond_embedding, np.maximum(config.latent_size * 2, 100),
                                               activation_fn=tf.tanh, scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1, config.latent_size * 2, activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(self.use_prior,
                                    lambda: sample_gaussian(prior_mu, prior_logvar),
                                    lambda: sample_gaussian(recog_mu, recog_logvar))
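            # sample_gaussian (defined elsewhere) is presumably the usual
            # reparameterization trick, i.e. roughly:
            #   eps = tf.random_normal(tf.shape(mu))
            #   z = mu + tf.exp(0.5 * logvar) * eps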

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs, 400, activation_fn=tf.tanh, scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1, self.vocab_size, activation_fn=None, scope="bow_project")

            # Y loss
            if config.use_hcf:
                meta_fc1 = layers.fully_connected(gen_inputs, 400, activation_fn=tf.tanh, scope="meta_fc1")
                if config.keep_prob < 1.0:
                    meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
                self.da_logits = layers.fully_connected(meta_fc1, self.da_vocab_size, scope="da_project")
                da_prob = tf.nn.softmax(self.da_logits)
                pred_attribute_embedding = tf.matmul(da_prob, d_embedding)
                if forward:
                    selected_attribute_embedding = pred_attribute_embedding
                else:
                    selected_attribute_embedding = attribute_embedding
                dec_inputs = tf.concat([gen_inputs, selected_attribute_embedding], 1)
            else:
                self.da_logits = tf.zeros((batch_size, self.da_vocab_size))
                dec_inputs = gen_inputs

            # Decoder
            if config.num_layer > 1:
                dec_init_state = [layers.fully_connected(dec_inputs, self.dec_cell_size, activation_fn=None,
                                                        scope="init_state-%d" % i) for i in range(config.num_layer)]
                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs, self.dec_cell_size, activation_fn=None, scope="init_state")

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size, config.keep_prob, config.num_layer)
            dec_cell = rnn_cell.OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(None, dec_init_state, embedding,
                                                                        start_of_sequence_id=self.go_id,
                                                                        end_of_sequence_id=self.eos_id,
                                                                        maximum_length=self.max_utt_len,
                                                                        num_decoder_symbols=self.vocab_size,
                                                                        context_vector=selected_attribute_embedding)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(tf.random_uniform((batch_size, max_out_len-1), minval=0.0, maxval=1.0),
                                              config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(dec_input_embedding, [-1, max_out_len-1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(dec_cell, loop_func, inputs=dec_input_embedding, sequence_length=dec_seq_lens)
            if final_context_state is not None:
                final_context_state = final_context_state[:, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):
                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))
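                # 0 is assumed to be the PAD id, so e.g.
                # labels = [[12, 7, 0, 0]] -> label_mask = [[1., 1., 0., 0.]]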

                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask, reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                # used only for perplexity calculation. Not used for optimization
                self.rc_ppl = tf.exp(tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

                """ as n-trial multimodal distribution. """
                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1), [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)

                # reconstruct the meta info about X
                if config.use_hcf:
                    da_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.da_logits, labels=self.output_das)
                    self.avg_da_loss = tf.reduce_mean(da_loss)
                else:
                    self.avg_da_loss = 0.0

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
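                # gaussian_kld (defined elsewhere) is presumably the closed
                # form KL between two diagonal Gaussians:
                #   0.5 * sum(prior_logvar - recog_logvar - 1
                #             + exp(recog_logvar) / exp(prior_logvar)
                #             + (recog_mu - prior_mu)^2 / exp(prior_logvar))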
                if log_dir is not None:
                    kl_weights = tf.minimum(tf.to_float(self.global_t)/config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo

                tf.summary.scalar("da_loss", self.avg_da_loss)
                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu, prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu, recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss - self.log_p_z + self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
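
The KL term above is annealed linearly through global_t and config.full_kl_step. A tiny standalone sketch of the same schedule in plain Python (full_kl_step is whatever the config sets):

def kl_weight(global_t, full_kl_step):
    # linear KL annealing, matching kl_weights in the graph above
    return min(float(global_t) / full_kl_step, 1.0)

# e.g. with full_kl_step = 10000:
#   kl_weight(0, 10000)     -> 0.0
#   kl_weight(5000, 10000)  -> 0.5
#   kl_weight(20000, 10000) -> 1.0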