Ejemplo n.º 1
0
 def _build_helper(self, batch_size, embeddings, inputs, inputs_length,
                   mode, hparams, decoder_hparams):
     """Selects the seq2seq helper that feeds a BasicDecoder.

     The decision is taken in this order: auxiliary decoding
     (`decoder_hparams.auxiliary`), continuous decoding
     (`hparams.decoder_continuous`), then standard decoding. Within the
     last two, scheduled variants are only used when `mode` is TRAIN and
     `hparams.scheduled_training` is set.

     Args:
       batch_size: Length of the start-token vector (auxiliary mode only).
       embeddings: Embedding handed to the sampling helpers.
       inputs: Decoder inputs consumed by the training helpers.
       inputs_length: Sequence lengths paired with `inputs`.
       mode: A `tf.estimator.ModeKeys` value.
       hparams: Model hyperparameters.
       decoder_hparams: Hyperparameters of this decoder.

     Returns:
       The helper instance for the selected decoding regime.
     """
     if decoder_hparams.auxiliary:
         # Auxiliary decoding samples its own inputs, starting from PAD.
         # A fixed-length alternative (helpers.FixedContinuousEmbeddingHelper
         # with num_steps=hparams.aux_decode_length) is not used here.
         pad_starts = tf.fill([batch_size], text_encoder.PAD_ID)
         return contrib_seq2seq.SampleEmbeddingHelper(
             embedding=embeddings,
             start_tokens=pad_starts,
             end_token=text_encoder.EOS_ID,
             softmax_temperature=None)

     is_train = mode == tf.estimator.ModeKeys.TRAIN

     if hparams.decoder_continuous:
         if is_train and hparams.scheduled_training:
             # Scheduled mixing.
             return helpers.ScheduledContinuousEmbeddingTrainingHelper(
                 inputs=inputs,
                 sequence_length=inputs_length,
                 mixing_concentration=hparams.scheduled_mixing_concentration)
         if is_train:
             # Pure continuous decoding (hard to train!).
             return helpers.ContinuousEmbeddingTrainingHelper(
                 inputs=inputs, sequence_length=inputs_length)
         # EVAL and PREDICT expect teacher forcing behavior.
         return contrib_seq2seq.TrainingHelper(
             inputs=inputs, sequence_length=inputs_length)

     if is_train and hparams.scheduled_training:
         # Scheduled sampling.
         return contrib_seq2seq.ScheduledEmbeddingTrainingHelper(
             inputs=inputs,
             sequence_length=inputs_length,
             embedding=embeddings,
             sampling_probability=hparams.scheduled_sampling_probability)

     # Teacher forcing (also for EVAL and PREDICT).
     return contrib_seq2seq.TrainingHelper(
         inputs=inputs, sequence_length=inputs_length)
Ejemplo n.º 2
0
def language_decoder(inputs,
                     embed_seq,
                     seq_len,
                     embedding_lookup,
                     dim,
                     start_tokens,
                     end_token,
                     max_seq_len,
                     unroll_type='teacher_forcing',
                     output_layer=None,
                     is_train=True,
                     scope='language_decoder',
                     reuse=tf.AUTO_REUSE):
    """Unrolls an LSTM decoder whose initial state is projected from `inputs`.

    Args:
        inputs: tensor passed through two linear layers ('Linear_c' and
            'Linear_h') to produce the initial LSTM cell and hidden state
        embed_seq: pre-embedded sequence of tokens for teacher forcing
        seq_len: lengths paired with `embed_seq` for teacher forcing
        embedding_lookup: embedding lookup function for greedy unrolling
        dim: width of the linear projections and number of LSTM units
        start_tokens: tensor of start tokens, [<s>] * bs
        end_token: integer for end token <e>
        max_seq_len: maximum number of decoding iterations
        unroll_type: 'teacher_forcing' or 'greedy'
        output_layer: optional layer given to the BasicDecoder
        is_train: forwarded to the linear layers as `is_training`
        scope: variable scope name
        reuse: variable reuse setting for the scope and the linear layers

    Returns:
        A tuple `(output, pred, pred_length)`: the decoder's `rnn_output`,
        its `sample_id`, and the third value returned by `dynamic_decode`.

    Raises:
        ValueError: if `unroll_type` is not one of the supported values.
    """
    with tf.variable_scope(scope, reuse=reuse) as var_scope:

        def project(layer_scope):
            # Plain affine map: no batch norm, no activation.
            return fc_layer(inputs,
                            dim,
                            use_bias=True,
                            use_bn=False,
                            activation_fn=None,
                            is_training=is_train,
                            scope=layer_scope,
                            reuse=reuse)

        # Arguments are evaluated left to right, so the cell-state projection
        # is created before the hidden-state projection.
        init_state = rnn.LSTMStateTuple(project('Linear_c'),
                                        project('Linear_h'))
        log.warning(var_scope.name)

        if unroll_type == 'teacher_forcing':
            helper = seq2seq.TrainingHelper(embed_seq, seq_len)
        elif unroll_type == 'greedy':
            helper = seq2seq.GreedyEmbeddingHelper(embedding_lookup,
                                                   start_tokens, end_token)
        else:
            raise ValueError('Unknown unroll_type')

        lstm_cell = rnn.BasicLSTMCell(num_units=dim, state_is_tuple=True)
        decoder = seq2seq.BasicDecoder(lstm_cell,
                                       helper,
                                       init_state,
                                       output_layer=output_layer)
        decoded, _, pred_length = seq2seq.dynamic_decode(
            decoder, maximum_iterations=max_seq_len, scope='dynamic_decoder')

        return decoded.rnn_output, decoded.sample_id, pred_length
Ejemplo n.º 3
0
    def train_decode_layer(self, dec_embeddig_input, dec_cell, output_layer):
        """Builds the teacher-forced training decoder with Bahdanau attention.

        Args:
            dec_embeddig_input: Embedded decoder inputs fed by the
                TrainingHelper.
            dec_cell: RNN cell that gets wrapped with attention over
                `self.enc_output`.
            output_layer: Layer handed to the BasicDecoder.

        Returns:
            The first element returned by `seq2seq.dynamic_decode`.
        """
        attention_units = self.hidden_dim * 2

        # NOTE(review): the attention memory is the encoder output, yet it is
        # masked with self.target_len -- confirm that a source-side length is
        # not what was intended.
        attention = seq2seq.BahdanauAttention(
            num_units=attention_units,
            memory=self.enc_output,
            memory_sequence_length=self.target_len,
            normalize=True,
            name='BahadanauAttention')
        attentive_cell = seq2seq.AttentionWrapper(dec_cell,
                                                  attention,
                                                  attention_units,
                                                  name='dec_attention_cell')

        # Start from the wrapper's zero state, but seed the inner cell with
        # the encoder's final state.
        zero_state = attentive_cell.zero_state(batch_size=self.batch_size,
                                               dtype=tf.float32)
        start_state = zero_state.clone(cell_state=self.enc_state)

        teacher_forcing = seq2seq.TrainingHelper(dec_embeddig_input,
                                                 self.target_len)
        decoder = seq2seq.BasicDecoder(attentive_cell,
                                       teacher_forcing,
                                       initial_state=start_state,
                                       output_layer=output_layer)
        final_outputs, _, _ = seq2seq.dynamic_decode(
            decoder,
            output_time_major=False,
            impute_finished=False,
            maximum_iterations=self.max_target_len)
        return final_outputs
Ejemplo n.º 4
0
def seq2seq_rnn_no_attention(inputs,
                             hidden_size,
                             scope,
                             use_xavier=True,
                             stddev=1e-3,
                             weight_decay=None,
                             activation_fn=tf.nn.relu,
                             bn=False,
                             bn_decay=None,
                             is_training=None):
    """Encodes each point's step sequence with an LSTM encoder/decoder pair.

    The rank-4 input is flattened to `(batch * npoint, nstep, in_size)`,
    encoded with an LSTM, and the encoder's final hidden state is fed as a
    single-step input to a second LSTM. The decoder's final cell state,
    reshaped to `(-1, npoint, hidden_size)`, is the result.

    Args:
        inputs: Tensor with a fully defined static shape of rank 4,
            read as (batch_size, npoint, nstep, in_size).
        hidden_size: Number of units in both LSTM cells.
        scope: Not read by this function.
        use_xavier: Not read by this function.
        stddev: Not read by this function.
        weight_decay: Not read by this function.
        activation_fn: Optional activation applied to the result.
        bn: Whether to apply `batch_norm_for_fc` to the result.
        bn_decay: Decay forwarded to `batch_norm_for_fc`.
        is_training: Training flag forwarded to `batch_norm_for_fc`.

    Returns:
        Tensor of shape `(-1, npoint, hidden_size)`.
    """
    static_shape = inputs.get_shape()
    batch_size, npoint, nstep, in_size = (static_shape[axis].value
                                          for axis in range(4))
    flat_inputs = tf.reshape(inputs, (-1, nstep, in_size))

    with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
        encoder_cell = tf.nn.rnn_cell.LSTMCell(hidden_size)
        flat_batch = batch_size * npoint
        # NOTE(review): the sequence length is hard-coded to 4 rather than
        # derived from nstep -- confirm that nstep is always 4.
        _, encoder_state = tf.nn.dynamic_rnn(
            encoder_cell,
            flat_inputs,
            sequence_length=tf.fill([flat_batch], 4),
            dtype=tf.float32,
            time_major=False)

    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        decoder_cell = tf.nn.rnn_cell.LSTMCell(hidden_size)
        # The encoder summary becomes a one-step decoder input sequence.
        single_step_input = tf.reshape(encoder_state.h,
                                       [flat_batch, 1, hidden_size])
        helper = seq2seq.TrainingHelper(
            inputs=single_step_input,
            sequence_length=tf.fill([flat_batch], 1),
            time_major=False)
        zero_state = decoder_cell.zero_state(batch_size=flat_batch,
                                             dtype=tf.float32)
        decoder = seq2seq.BasicDecoder(cell=decoder_cell,
                                       helper=helper,
                                       initial_state=zero_state,
                                       output_layer=None)
        _, decoder_final_state, _ = seq2seq.dynamic_decode(
            decoder=decoder, output_time_major=False, impute_finished=True)

    # The result is taken from the decoder's final cell state (`.c`).
    result = tf.reshape(decoder_final_state.c, (-1, npoint, hidden_size))
    if bn:
        result = batch_norm_for_fc(result, is_training, bn_decay, 'bn')
    if activation_fn is not None:
        result = activation_fn(result)
    return result
Ejemplo n.º 5
0
    def _decoder(self, thought, labels):
        """Builds the decoder RNN over right-shifted labels.

        Args:
            thought: Initial state given to the BasicDecoder on the
                non-cuDNN path.
            labels: Label ids; the time axis is 0 when `self._time_major`
                is set, otherwise 1.

        Returns:
            The decoder RNN outputs.
        """
        # Prepend token id 2 as the start-of-string marker and drop the last
        # step, so the decoder input is the label sequence shifted right.
        if self._time_major:
            sos_values, time_axis = [[2] * self._batch_size], 0
        else:
            sos_values, time_axis = [[2]] * self._batch_size, 1
        sos_tokens = tf.constant(sos_values, dtype=tf.int64)
        previous_labels = labels[:-1] if self._time_major else labels[:, :-1]
        shifted_labels = tf.concat([sos_tokens, previous_labels], time_axis)

        decoder_in = self._get_embeddings(shifted_labels)

        if self._cuda:
            # NOTE(review): `thought` is not passed to the cuDNN GRU here, so
            # this path does not start from it -- confirm this is intended.
            cudnn_gru = tf.contrib.cudnn_rnn.CudnnGRU(
                1, self._output_size, direction='unidirectional')
            return cudnn_gru(decoder_in)[0]

        gru_cell = tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell(
            self._output_size)
        # Every sequence is decoded for the full maximum length.
        full_lengths = tf.constant([self._max_sequence_length] *
                                   self._batch_size)
        helper = seq2seq.TrainingHelper(decoder_in,
                                        full_lengths,
                                        time_major=self._time_major)
        decoder = seq2seq.BasicDecoder(gru_cell, helper, thought)
        decoded = seq2seq.dynamic_decode(
            decoder, output_time_major=self._time_major)
        return decoded[0].rnn_output
Ejemplo n.º 6
0
        def get_DecoderHelper(embedding_lookup, seq_lengths, token_dim,
                              gt_tokens=None, sequence_type='program',
                              unroll_type='teacher_forcing'):
            """Returns the seq2seq helper for the requested unrolling mode.

            Args:
                embedding_lookup: Callable mapping token ids to embeddings.
                seq_lengths: Sequence lengths for the training helpers.
                token_dim: Added to a zero vector to form the greedy start
                    tokens; also used to derive the 'action' end token.
                gt_tokens: Ground-truth token ids; required for
                    'teacher_forcing' and 'scheduled_sampling'.
                sequence_type: 'program', 'action', or anything else;
                    selects the end token for greedy unrolling.
                unroll_type: 'teacher_forcing', 'scheduled_sampling' or
                    'greedy'.

            Returns:
                The helper for the chosen `unroll_type`.

            Raises:
                ValueError: if `gt_tokens` is missing where required, or if
                    `unroll_type` is unknown.
            """
            if unroll_type == 'teacher_forcing':
                if gt_tokens is None:
                    raise ValueError('teacher_forcing requires gt_tokens')
                return seq2seq.TrainingHelper(embedding_lookup(gt_tokens),
                                              seq_lengths)

            if unroll_type == 'scheduled_sampling':
                if gt_tokens is None:
                    raise ValueError('scheduled_sampling requires gt_tokens')
                gt_embedding = embedding_lookup(gt_tokens)
                # self.sample_prob is the probability of feeding ground truth
                # (1.0: always ground truth, 0.0: always the prediction); the
                # helper is given its complement.
                return seq2seq.ScheduledEmbeddingTrainingHelper(
                    gt_embedding, seq_lengths, embedding_lookup,
                    1.0 - self.sample_prob, seed=None,
                    scheduling_seed=None)

            if unroll_type != 'greedy':
                raise ValueError('Unknown unroll type')

            # Greedy unrolling is what evaluation uses.
            start_token = tf.zeros([self.batch_size], dtype=tf.int32) + \
                          token_dim
            if sequence_type == 'program':
                end_token = self.vocab.token2int['m)']
            elif sequence_type == 'action':
                end_token = token_dim - 1
            else:
                # Hack to have no end token: per the original author, 11 is
                # greater than the number of perceptions.
                end_token = 11
            return seq2seq.GreedyEmbeddingHelper(
                embedding_lookup, start_token, end_token)
Ejemplo n.º 7
0
 def _build_decoder(self, decoder_cell, batch_size, lstm_holistic_features,
                    cnn_fmap):
     """Creates the training decoder or the beam-search inference decoder.

     Args:
       decoder_cell: RNN cell driven by the decoder.
       batch_size: Number of sequences; sizes the beam-search start tokens.
       lstm_holistic_features: Initial decoder state, indexed as two LSTM
         states (`[0]` and `[1]`, each exposing `.c` and `.h`).
       cnn_fmap: Feature map handed to `Compute2dAttentionLayer`.

     Returns:
       A `seq2seq.BasicDecoder` when `self._is_training` is true, otherwise
       a `seq2seq.BeamSearchDecoder`.
     """
     embedding_fn = functools.partial(tf.one_hot, depth=self.num_classes)
     output_layer = Compute2dAttentionLayer(512, self.num_classes, cnn_fmap)
     if self._is_training:
         train_helper = seq2seq.TrainingHelper(
             embedding_fn(self._groundtruth_dict['decoder_inputs']),
             sequence_length=self._groundtruth_dict['decoder_lengths'],
             time_major=False)
         decoder = seq2seq.BasicDecoder(
             cell=decoder_cell,
             helper=train_helper,
             initial_state=lstm_holistic_features,
             output_layer=output_layer)
     else:
         # BeamSearchDecoder reshapes the initial state from
         # [batch * beam, ...] to [batch, beam, ...], so the copies of each
         # example must be contiguous. tile_batch produces that layout;
         # tf.tile would interleave examples and mix states across the batch
         # whenever batch_size > 1 (both agree for batch_size == 1).
         # NOTE(review): cnn_fmap reaches the output layer untiled -- confirm
         # that Compute2dAttentionLayer accounts for the beam dimension.
         tiled_states = tuple(
             tf.nn.rnn_cell.LSTMStateTuple(
                 seq2seq.tile_batch(state.c, multiplier=self._beam_width),
                 seq2seq.tile_batch(state.h, multiplier=self._beam_width))
             for state in (lstm_holistic_features[0],
                           lstm_holistic_features[1]))
         decoder = seq2seq.BeamSearchDecoder(
             cell=decoder_cell,
             embedding=embedding_fn,
             start_tokens=tf.fill([batch_size], self.start_label),
             end_token=self.end_label,
             initial_state=tiled_states,
             beam_width=self._beam_width,
             output_layer=output_layer,
             length_penalty_weight=0.0)
     return decoder
Ejemplo n.º 8
0
def train_decoder(agenda, embeddings,
                  dec_inputs, base_sent_hiddens, insert_word_embeds, delete_word_embeds,
                  dec_input_lengths, base_length, iw_length, dw_length,
                  attn_dim, hidden_dim, num_layer, swap_memory, enable_dropout=False, dropout_keep=1.,
                  no_insert_delete_attn=False):
    """Builds and runs the teacher-forced training decoder.

    Args:
        agenda: Forwarded to `create_decoder_cell`.
        embeddings: Embedding table used both to embed `dec_inputs` and to
            construct the `DecoderOutputLayer`.
        dec_inputs: Decoder input token ids.
        base_sent_hiddens: Forwarded to `create_decoder_cell`; its first
            dimension is taken as the batch size.
        insert_word_embeds: Forwarded to `create_decoder_cell`.
        delete_word_embeds: Forwarded to `create_decoder_cell`.
        dec_input_lengths: Sequence lengths for the training helper.
        base_length: Forwarded to `create_decoder_cell`.
        iw_length: Forwarded to `create_decoder_cell`.
        dw_length: Forwarded to `create_decoder_cell`.
        attn_dim: Forwarded to `create_decoder_cell`.
        hidden_dim: Forwarded to `create_decoder_cell`.
        num_layer: Forwarded to `create_decoder_cell`.
        swap_memory: Forwarded to `seq2seq.dynamic_decode`.
        enable_dropout: Forwarded to `create_decoder_cell`.
        dropout_keep: Forwarded to `create_decoder_cell`.
        no_insert_delete_attn: Forwarded to `create_decoder_cell`.

    Returns:
        The `(outputs, state, length)` triple from `seq2seq.dynamic_decode`.
    """
    with tf.variable_scope(OPS_NAME, 'decoder', []):
        batch_size = tf.shape(base_sent_hiddens)[0]

        embedded_inputs = tf.nn.embedding_lookup(embeddings, dec_inputs)
        teacher_forcing = seq2seq.TrainingHelper(
            embedded_inputs, dec_input_lengths, name='train_helper')

        decoder_cell = create_decoder_cell(
            agenda,
            base_sent_hiddens, insert_word_embeds, delete_word_embeds,
            base_length, iw_length, dw_length,
            attn_dim, hidden_dim, num_layer,
            enable_dropout=enable_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn
        )

        projection = DecoderOutputLayer(embeddings)
        initial_state = create_trainable_zero_state(decoder_cell, batch_size)

        decoder = seq2seq.BasicDecoder(
            decoder_cell, teacher_forcing, initial_state, projection)

        final_outputs, final_state, final_lengths = seq2seq.dynamic_decode(
            decoder, swap_memory=swap_memory)

        return final_outputs, final_state, final_lengths
Ejemplo n.º 9
0
    def _create_encoder(self, args):
        """Builds the LSTM encoder and the latent distribution parameters.

        Sets `self.full_lstm`, `self.lstm_state`, `self.final_state`,
        `self.z_mean` and `self.z_logstd`.

        Args:
            args: Configuration object providing `encoder_size`,
                `num_encoder_layers`, `batch_size` and `z_dim`.
        """
        # Stacked LSTM with a zero initial state.
        layer_cells = [
            rnn.LSTMCell(args.encoder_size,
                         state_is_tuple=True,
                         initializer=initializers.xavier_initializer())
            for _ in range(args.num_encoder_layers)
        ]
        self.full_lstm = rnn.MultiRNNCell(layer_cells, state_is_tuple=True)
        self.lstm_state = self.full_lstm.zero_state(args.batch_size,
                                                    tf.float32)

        # States and actions are concatenated on axis 1 and given a
        # length-1 axis at position 1, i.e. a single decoding step.
        joint_features = tf.concat([self.states_encode, self.actions_encode],
                                   1)
        single_step_input = tf.expand_dims(joint_features, 1)
        helper = seq2seq.TrainingHelper(
            single_step_input, sequence_length=[1] * args.batch_size)
        decoder = seq2seq.BasicDecoder(cell=self.full_lstm,
                                       helper=helper,
                                       initial_state=self.lstm_state)
        decoded, self.final_state, _ = seq2seq.dynamic_decode(
            decoder, scope='latent_encoder')
        # Element 0 of the decoder output is the RNN output; keep the last
        # step.
        last_output = decoded[0][:, -1, :]

        # One affine layer yields both halves of the latent parameters.
        latent_w = tf.get_variable(
            "latent_w", [args.encoder_size, 2 * args.z_dim],
            initializer=initializers.xavier_initializer())
        latent_b = tf.get_variable("latent_b", [2 * args.z_dim])
        latent_params = tf.nn.xw_plus_b(last_output, latent_w, latent_b)

        # Split into mean and log standard deviation.
        self.z_mean, self.z_logstd = tf.split(latent_params, 2, 1)
Ejemplo n.º 10
0
 def get_DecoderHelper(embedding_lookup,
                       seq_lengths,
                       token_dim,
                       gt_tokens=None,
                       unroll_type='teacher_forcing'):
     """Returns the seq2seq helper for the requested unrolling mode.

     Args:
         embedding_lookup: Callable mapping token ids to embeddings.
         seq_lengths: Sequence lengths for the training helpers.
         token_dim: Added to a zero vector to form the greedy start tokens;
             `token_dim - 1` is the greedy end token.
         gt_tokens: Ground-truth token ids; required for 'teacher_forcing'
             and 'scheduled_sampling'.
         unroll_type: 'teacher_forcing', 'scheduled_sampling' or 'greedy'.

     Returns:
         The helper for the chosen `unroll_type`.

     Raises:
         ValueError: if `gt_tokens` is missing where required, or if
             `unroll_type` is unknown.
     """
     if unroll_type == 'teacher_forcing':
         if gt_tokens is None:
             raise ValueError('teacher_forcing requires gt_tokens')
         return seq2seq.TrainingHelper(embedding_lookup(gt_tokens),
                                       seq_lengths)

     if unroll_type == 'scheduled_sampling':
         if gt_tokens is None:
             raise ValueError('scheduled_sampling requires gt_tokens')
         gt_embedding = embedding_lookup(gt_tokens)
         # self.sample_prob is the probability of feeding ground truth
         # (1.0: always ground truth, 0.0: always the prediction); the
         # helper is given its complement.
         return seq2seq.ScheduledEmbeddingTrainingHelper(
             gt_embedding,
             seq_lengths,
             embedding_lookup,
             1.0 - self.sample_prob,
             seed=None,
             scheduling_seed=None)

     if unroll_type != 'greedy':
         raise ValueError('Unknown unroll type')

     # Greedy unrolling is what evaluation uses.
     start_token = tf.zeros([self.batch_size],
                            dtype=tf.int32) + token_dim
     end_token = token_dim - 1
     return seq2seq.GreedyEmbeddingHelper(embedding_lookup,
                                          start_token, end_token)
Ejemplo n.º 11
0
    def model(self):
        """Builds the encoder-decoder graph for `self.mode`.

        In "train" mode this sets `self.mask`, `self.decoder_logits_train`,
        `self.decoder_predict_train`, `self.loss_op` and `self.train_op`.
        In "test" mode it sets `self.decoder_predict_decode` from the first
        beam of a beam-search decoder. Token id 2 is used as the start token
        and 3 as the end token.
        """
        is_train = self.mode == "train"
        is_test = self.mode == "test"

        with tf.variable_scope("encoder"):
            encoder_cell = self._create_rnn_cell()
            source_embedding = tf.get_variable(name="source_embedding",
                                               shape=[self.source_vocab_size, self.embedding_size],
                                               initializer=tf.initializers.truncated_normal())
            encoder_embedding_inputs = tf.nn.embedding_lookup(source_embedding, self.source_input)
            encoder_outputs, encoder_states = tf.nn.dynamic_rnn(cell=encoder_cell,
                                                                inputs=encoder_embedding_inputs,
                                                                dtype=tf.float32)
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            if is_test:
                # Beam search needs each state replicated beam_size times.
                encoder_states = seq2seq.tile_batch(encoder_states, self.beam_size)
            decoder_cell = self._create_rnn_cell()
            # Dropout is only applied while training; a keep probability of
            # 1.0 keeps inference deterministic instead of decoding with half
            # of the cell outputs dropped.
            decoder_cell = rnn.DropoutWrapper(
                decoder_cell, output_keep_prob=0.5 if is_train else 1.0)

            output_layer = tf.layers.Dense(units=self.target_vocab_size,
                                           kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
            target_embedding = tf.get_variable(name="target_embedding",
                                               shape=[self.target_vocab_size, self.embedding_size])
            if is_train:
                self.mask = tf.sequence_mask(self.target_length, self.max_target_length, dtype=tf.float32)
                # Decoder input = start token followed by the target without
                # its last step.
                del_end = tf.strided_slice(self.target_input, [0, 0], [self.batch_size, -1], [1, 1])
                decoder_input = tf.concat([tf.fill([self.batch_size, 1], 2), del_end], axis=1)
                decoder_input_embedding = tf.nn.embedding_lookup(target_embedding, decoder_input)
                training_helper = seq2seq.TrainingHelper(inputs=decoder_input_embedding,
                                                         sequence_length=tf.fill([self.batch_size], self.max_target_length))
                training_decoder = seq2seq.BasicDecoder(cell=decoder_cell,
                                                        helper=training_helper,
                                                        initial_state=encoder_states,
                                                        output_layer=output_layer)
                decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder=training_decoder, output_time_major=False,
                                                               impute_finished=True,
                                                               maximum_iterations=self.max_target_length)
                self.decoder_logits_train = tf.identity(decoder_outputs.rnn_output)
                self.decoder_predict_train = tf.argmax(self.decoder_logits_train, axis=-1)
                # Positions beyond each target length are weighted by the
                # zero entries of the mask.
                self.loss_op = tf.reduce_mean(tf.losses.softmax_cross_entropy(
                    onehot_labels=tf.one_hot(self.target_input, depth=self.target_vocab_size),
                    logits=self.decoder_logits_train, weights=self.mask))
                optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
                trainable_params = tf.trainable_variables()
                gradients = tf.gradients(self.loss_op, trainable_params)
                clip_gradients, _ = tf.clip_by_global_norm(gradients, self.max_gradient_norm)
                self.train_op = optimizer.apply_gradients(zip(clip_gradients, trainable_params))
            elif is_test:
                start_tokens = tf.fill([self.batch_size], value=2)
                end_token = 3
                inference_decoder = seq2seq.BeamSearchDecoder(cell=decoder_cell, embedding=target_embedding,
                                                              start_tokens=start_tokens, end_token=end_token,
                                                              initial_state=encoder_states,
                                                              beam_width=self.beam_size, output_layer=output_layer)
                decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder=inference_decoder, maximum_iterations=self.max_target_length)
                print(decoder_outputs.predicted_ids.get_shape().as_list())
                # Keep only the first beam for every example.
                self.decoder_predict_decode = decoder_outputs.predicted_ids[:, :, 0]
Ejemplo n.º 12
0
 def _build_decoder(self, decoder_cell, batch_size):
     """Creates the training decoder or the beam-search inference decoder.

     Tokens are embedded as one-hot vectors of depth `self.num_classes`, and
     a dense layer maps cell outputs to `self.num_classes` values. Both
     decoders start from the cell's zero state.

     Args:
       decoder_cell: RNN cell driven by the decoder.
       batch_size: Number of sequences in the batch.

     Returns:
       A `seq2seq.BasicDecoder` when `self._is_training` is true, otherwise
       a `seq2seq.BeamSearchDecoder`.
     """
     one_hot_embed = functools.partial(tf.one_hot, depth=self.num_classes)
     class_projection = tf.layers.Dense(
         self.num_classes,
         activation=None,
         use_bias=True,
         kernel_initializer=tf.variance_scaling_initializer(),
         bias_initializer=tf.zeros_initializer())

     if self._is_training:
         groundtruth = self._groundtruth_dict
         teacher_forcing = seq2seq.TrainingHelper(
             one_hot_embed(groundtruth['decoder_inputs']),
             sequence_length=groundtruth['decoder_lengths'],
             time_major=False)
         return seq2seq.BasicDecoder(
             cell=decoder_cell,
             helper=teacher_forcing,
             initial_state=decoder_cell.zero_state(batch_size, tf.float32),
             output_layer=class_projection)

     # Beam search keeps beam_width hypotheses per example, so the zero
     # state is created for batch_size * beam_width rows.
     return seq2seq.BeamSearchDecoder(
         cell=decoder_cell,
         embedding=one_hot_embed,
         start_tokens=tf.fill([batch_size], self.start_label),
         end_token=self.end_label,
         initial_state=decoder_cell.zero_state(
             batch_size * self._beam_width, tf.float32),
         beam_width=self._beam_width,
         output_layer=class_projection,
         length_penalty_weight=0.0)
Ejemplo n.º 13
0
def get_helper(encoder_output, input_emb, input_len, embedding, mode, params):
    """Returns a training helper in TRAIN mode, otherwise a greedy helper.

    Args:
        encoder_output: Object exposing `.output` (its first dimension is
            the batch size) and `.state`.
        input_emb: Embedded decoder inputs; per the original comment,
            batch_size * max_len-1 * emb_size.
        input_len: Sequence lengths; one is subtracted to exclude the last
            token.
        embedding: Passed through `embedding_func` for greedy decoding.
        mode: A `tf.estimator.ModeKeys` value.
        params: Dict with 'conditional', 'start_token' and 'end_token'.

    Returns:
        A `seq2seq.TrainingHelper` or a `seq2seq.GreedyEmbeddingHelper`.
    """
    batch_size = tf.shape(encoder_output.output)[0]

    if mode != tf.estimator.ModeKeys.TRAIN:
        return seq2seq.GreedyEmbeddingHelper(
            embedding=embedding_func(embedding),
            start_tokens=tf.fill([batch_size], params['start_token']),
            end_token=params['end_token'])

    if params['conditional']:
        # Conditional training: the encoder state is repeated along the time
        # axis and concatenated to every decoder input step.
        num_steps = tf.shape(input_emb)[1]
        state_dims = tf.shape(encoder_output.state)
        # NOTE(review): the first two axes are swapped with tf.reshape, not
        # tf.transpose; that only matches a transpose when one of those axes
        # has size 1 -- confirm the state's leading dimension.
        swapped_state = tf.reshape(
            encoder_output.state,
            [state_dims[1], state_dims[0], state_dims[2]])
        state_per_step = tf.tile(swapped_state, [1, num_steps, 1])
        input_emb = tf.concat([state_per_step, input_emb], axis=-1)

    return seq2seq.TrainingHelper(
        inputs=input_emb,
        sequence_length=input_len - 1,  # exclude last token
        time_major=False,
        name='training_helper')
Ejemplo n.º 14
0
def seq_net(name, inputs, targets, sl, n_items, n_cates, cate_list, u_emb, rank, is_training, reuse):
  """Runs a teacher-forced decoder over `inputs` and returns its loss.

  Args:
    name: Prefix of the variable scope (`name + '-rnn'`).
    inputs: Decoder inputs; the first dimension is the batch size.
    targets: Target ids for the sequence loss.
    sl: Sequence lengths; also used to build the loss mask.
    n_items: The output projection is built with `n_items + 2`.
    n_cates: The output projection is built with `n_cates + 1`.
    cate_list: Forwarded to the output projection.
    u_emb: Forwarded to `build_decoder_cell`.
    rank: Forwarded to `build_decoder_cell`.
    is_training: Not read by this function.
    reuse: Not read by this function.

  Returns:
    A tuple `(loss, tf.shape(output), tf.shape(targets))`; the loss is
    averaged across timesteps but not across the batch.
  """
  with tf.variable_scope(name + '-rnn'):
    projection = Dense(n_items+2, n_cates+1, cate_list, activation=None, name='output_projection')
    teacher_forcing = seq2seq.TrainingHelper(
        inputs=inputs, sequence_length=sl, time_major=False)

    decoder_cell, start_state = build_decoder_cell(rank, u_emb, tf.shape(inputs)[0])
    decoder = seq2seq.BasicDecoder(
        cell=decoder_cell,
        helper=teacher_forcing,
        initial_state=start_state,
        output_layer=projection)

    # Decode no further than the longest sequence in the batch.
    longest = tf.reduce_max(sl)
    decoded, _, _ = seq2seq.dynamic_decode(
        decoder=decoder,
        output_time_major=False,
        impute_finished=True,
        maximum_iterations=longest)

    logits = tf.identity(decoded.rnn_output)
    # Weights are 1 inside each sequence and 0 on the padding.
    step_weights = tf.sequence_mask(
        lengths=sl, maxlen=longest, dtype=tf.float32)
    loss = seq2seq.sequence_loss(
        logits=logits,
        targets=targets,
        weights=step_weights,
        average_across_timesteps=True,
        average_across_batch=False)
  return loss, tf.shape(logits), tf.shape(targets)
Ejemplo n.º 15
0
    def reconstruction_loss(self, x_input, x_target, x_length, z=None):
        """Computes the per-sequence reconstruction loss.

        Teacher forcing is used when the sampling probability is statically
        known to be 0.0; otherwise scheduled sampling is used.

        Args:
            x_input: Batch of decoder input sequences for teacher forcing,
                sized `[batch_size, max(x_length), output_depth]`.
            x_target: Batch of expected output sequences to compute the loss
                against, sized `[batch_size, max(x_length), output_depth]`.
            x_length: Length of input/output sequences, sized `[batch_size]`.
            z: (Optional) Latent vectors. Required if the model is
                conditional. Sized `[n, z_size]`.

        Returns:
            r_loss: The reconstruction loss for each sequence in the batch.
            metric_map: Map from metric name to tf.metrics return values for
                logging.
            truths: Ground truth labels.
            predictions: Predicted labels.
            final_state: The final states of the decoder, or None if using
                Cudnn.
        """
        batch_size = x_input.shape[0].value

        has_z = z is not None
        if not has_z:
            # A zero-width latent keeps the tiling/concat below uniform.
            z = tf.zeros([batch_size, 0])
        # Copy z to every time step of the input.
        z_per_step = tf.tile(tf.expand_dims(z, axis=1),
                             [1, tf.shape(x_input)[1], 1])

        static_sampling_prob = tensor_util.constant_value(
            self._sampling_probability)
        if static_sampling_prob == 0.0:
            # Use teacher forcing.
            x_input = tf.concat([x_input, z_per_step], axis=2)
            helper = seq2seq.TrainingHelper(x_input, x_length)
        else:
            # Use scheduled sampling.
            helper = seq2seq.ScheduledOutputTrainingHelper(
                inputs=x_input,
                sequence_length=x_length,
                auxiliary_inputs=z_per_step if has_z else None,
                sampling_probability=self._sampling_probability,
                next_inputs_fn=self._sample)

        decoder_outputs, final_state = self._decode(z,
                                                    helper=helper,
                                                    x_input=x_input)
        flat_x_target = flatten_maybe_padded_sequences(x_target, x_length)
        flat_rnn_output = flatten_maybe_padded_sequences(
            decoder_outputs.rnn_output, x_length)
        flat_loss, metric_map, truths, predictions = (
            self._flat_reconstruction_loss(flat_x_target, flat_rnn_output))

        # Sum the flat loss over each sequence: sequence i occupies
        # [bounds[i], bounds[i + 1]) in the flattened tensor.
        bounds = tf.concat([(0, ), tf.cumsum(x_length)], axis=0)
        r_loss = tf.stack([
            tf.reduce_sum(flat_loss[bounds[i]:bounds[i + 1]])
            for i in range(batch_size)
        ])

        return r_loss, metric_map, truths, predictions, final_state
Ejemplo n.º 16
0
    def _init_decoder(self):
        """Builds the attention decoder for training and greedy inference.

        Sets `self.decoder_cell`, `self.decoder_outputs_train`,
        `self.decoder_logits_train`, `self.decoder_prediction_train`,
        `self.decoder_outputs_inference` and
        `self.decoder_prediction_inference`. Both decoders share the same
        cell, initial state and output layer.
        """
        dropout_lstm = tf.contrib.rnn.DropoutWrapper(
            tf.contrib.rnn.LSTMCell(self.rnn_size),
            output_keep_prob=self.keep_prob)
        luong_attention = tf.contrib.seq2seq.LuongAttention(
            self.rnn_size, self.encoder_outputs, name='LuongAttention')
        self.decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            dropout_lstm,
            luong_attention,
            attention_layer_size=self.rnn_size,
            name="AttentionWrapper")

        # Zero attention state, with the inner cell seeded from the encoder.
        batch_size = tf.shape(self.encoder_inputs)[0]
        zero_attention_state = self.decoder_cell.zero_state(
            batch_size=batch_size, dtype=tf.float32)
        init_state = zero_attention_state.clone(
            cell_state=self.encoder_final_state)

        # TODO: document why dynamic_decode is used here instead of
        # dynamic_rnn (open question from the original author).
        with tf.variable_scope("decoder") as scope:
            vocab_projection = layers_core.Dense(
                units=self.effective_vocab_size, activation=None)

            # Train decoding: the projection is applied after decoding, so
            # the decoder itself is built without an output layer.
            teacher_forcing = seq2seq.TrainingHelper(
                inputs=self.decoder_train_inputs_embedded,
                sequence_length=self.decoder_train_length,
                time_major=False)
            train_decoder = seq2seq.BasicDecoder(cell=self.decoder_cell,
                                                 helper=teacher_forcing,
                                                 initial_state=init_state)
            self.decoder_outputs_train, _, _ = seq2seq.dynamic_decode(
                decoder=train_decoder,
                output_time_major=False,
                impute_finished=True,
                maximum_iterations=self.max_decoder_seq_length,
                scope=scope)
            self.decoder_logits_train = vocab_projection.apply(
                self.decoder_outputs_train.rnn_output)
            self.decoder_prediction_train = tf.argmax(
                self.decoder_logits_train, 2)

            # Greedy decoding: reuse the variables created above.
            scope.reuse_variables()
            greedy_feed = seq2seq.GreedyEmbeddingHelper(
                embedding=self.embedding_matrix,
                start_tokens=self.decoder_start_tokens,
                end_token=self.eos)
            greedy_decoder = seq2seq.BasicDecoder(
                cell=self.decoder_cell,
                helper=greedy_feed,
                initial_state=init_state,
                output_layer=vocab_projection)
            # TODO: figure out a better way of setting maximum_iterations.
            self.decoder_outputs_inference, _, _ = seq2seq.dynamic_decode(
                decoder=greedy_decoder,
                output_time_major=False,
                impute_finished=True,
                maximum_iterations=self.max_decoder_seq_length,
                scope=scope)
            self.decoder_prediction_inference = tf.argmax(
                self.decoder_outputs_inference.rnn_output, 2)
Ejemplo n.º 17
0
    def _train_helper(self):
        """Creates the teacher-forcing helper used for training.

        The decoder inputs are `self.output_data` with a column of
        `self._start_token_id` prepended on axis 1, looked up in
        `self.embedding`.

        Returns:
            A `seq2seq.TrainingHelper` over the embedded inputs, with
            `self.output_lengths` as sequence lengths.
        """
        start_column = tf.fill([self._batch_size, 1], self._start_token_id)
        input_ids = tf.concat([start_column, self.output_data], 1)
        embedded_inputs = tf.nn.embedding_lookup(self.embedding, input_ids)
        return seq2seq.TrainingHelper(
            inputs=embedded_inputs, sequence_length=self.output_lengths)
Ejemplo n.º 18
0
 def _helper(self, train_test_predict, embeded_inputs, sequences_lengths,
             start_tokens, end_token):
     """Returns the decoding helper for the given run mode.

     Args:
         train_test_predict: One of 'train', 'test' or 'predict'.
         embeded_inputs: Embedded decoder inputs (train/test only).
         sequences_lengths: Sequence lengths (train/test only).
         start_tokens: Start tokens for greedy decoding (predict only).
         end_token: End token for greedy decoding (predict only).

     Returns:
         A `seq2seq.TrainingHelper` for 'train' and 'test', or a
         `seq2seq.GreedyEmbeddingHelper` over `self.embedding_vector` for
         'predict'.

     Raises:
         TypeError: if `train_test_predict` is none of the three modes.
     """
     if train_test_predict in ('train', 'test'):
         # Both training and testing are teacher-forced.
         return seq2seq.TrainingHelper(embeded_inputs, sequences_lengths)
     if train_test_predict == 'predict':
         return seq2seq.GreedyEmbeddingHelper(self.embedding_vector,
                                              start_tokens, end_token)
     raise TypeError(
         'train_test_predict should equals train, test, or predict')
Ejemplo n.º 19
0
def train_decoder(agenda,
                  embeddings,
                  extended_base_words,
                  oov,
                  dec_inputs,
                  dec_extended_inputs,
                  base_sent_hiddens,
                  insert_word_embeds,
                  delete_word_embeds,
                  dec_input_lengths,
                  base_length,
                  iw_length,
                  dw_length,
                  vocab_size,
                  attn_dim,
                  hidden_dim,
                  num_layer,
                  swap_memory,
                  enable_dropout=False,
                  dropout_keep=1.,
                  no_insert_delete_attn=False):
    """Run the teacher-forced decoder and return the raw decode results.

    The per-step cell input is the embedding of ``dec_inputs`` (via
    ``vocab.embed_tokens``) concatenated on the last axis with
    ``dec_extended_inputs`` cast to float32. The cell and its initial state
    come from ``create_decoder_cell``, which receives the remaining
    arguments unchanged.

    Note: ``embeddings`` is accepted but not referenced in this function.

    Args:
      dec_inputs: Token ids embedded with ``vocab.embed_tokens``.
      dec_extended_inputs: Ids appended (as a float feature) to every
        embedded step.
      dec_input_lengths: Sequence lengths handed to the ``TrainingHelper``.
      swap_memory: Forwarded to ``seq2seq.dynamic_decode``.
      (all other arguments): Forwarded to ``create_decoder_cell``.

    Returns:
      The ``(outputs, state, length)`` tuple produced by
      ``seq2seq.dynamic_decode``.
    """
    with tf.variable_scope(OPS_NAME, 'decoder'):
        token_embeds = vocab.embed_tokens(dec_inputs)
        # Append the extended id of each step as one extra float feature.
        extended_ids = tf.cast(
            tf.expand_dims(dec_extended_inputs, 2), tf.float32)
        step_inputs = tf.concat([token_embeds, extended_ids], axis=2)

        teacher_helper = seq2seq.TrainingHelper(
            step_inputs, dec_input_lengths, name='train_helper')

        decoder_cell, initial_states = create_decoder_cell(
            agenda,
            extended_base_words,
            oov,
            base_sent_hiddens,
            insert_word_embeds,
            delete_word_embeds,
            base_length,
            iw_length,
            dw_length,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_layer,
            enable_dropout=enable_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

        basic_decoder = seq2seq.BasicDecoder(
            decoder_cell, teacher_helper, initial_states)
        decoded, final_state, final_lengths = seq2seq.dynamic_decode(
            basic_decoder, swap_memory=swap_memory)

        return decoded, final_state, final_lengths
Ejemplo n.º 20
0
    def train_decoding_layer(self, inputs, decoder_cell, initial_state):
        """Decode ``inputs`` with teacher forcing.

        Args:
          inputs: Batch-major decoder inputs for the ``TrainingHelper``.
          decoder_cell: RNN cell wrapped by the ``BasicDecoder``.
          initial_state: Initial state of ``decoder_cell``.

        Returns:
          The first element of the ``dynamic_decode`` result (the decoder
          outputs), decoded for at most ``self.out_max_length`` steps.
        """
        teacher_helper = seq2seq.TrainingHelper(
            inputs, self.out_length, time_major=False)
        basic_decoder = seq2seq.BasicDecoder(
            decoder_cell, teacher_helper, initial_state)

        decoded, _, _ = seq2seq.dynamic_decode(
            basic_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=self.out_max_length)
        return decoded
Ejemplo n.º 21
0
 def decoder(self, thought, labels):
     """Decode ``labels`` with a GRU, teacher-forced from ``thought``.

     The last column of ``labels`` is dropped and a column filled with the
     constant id 2 is prepended (presumably the start-of-sequence id --
     verify against the vocabulary). Every sequence is treated as having
     length ``self.maxlen``.

     Args:
       thought: Initial state of the GRU cell.
       labels: Label ids of shape ``[self.batch_size, time]``.

     Returns:
       The ``rnn_output`` of the decoded sequence.
     """
     without_last = tf.strided_slice(
         labels, [0, 0], [self.batch_size, -1], [1, 1])
     start_column = tf.fill([self.batch_size, 1], 2)
     teacher_ids = tf.concat([start_column, without_last], 1)
     teacher_embeds = self.get_embedding(teacher_ids)
     gru_cell = tf.nn.rnn_cell.GRUCell(self.output_size)
     full_lengths = tf.fill([self.batch_size], self.maxlen)
     teacher_helper = seq2seq.TrainingHelper(
         teacher_embeds, full_lengths, time_major=False)
     basic_decoder = seq2seq.BasicDecoder(gru_cell, teacher_helper, thought)
     decoded = seq2seq.dynamic_decode(basic_decoder)[0]
     return decoded.rnn_output
Ejemplo n.º 22
0
    def _build_decoder(self):
        """Build the dialog decoder and store its logits and sample ids.

        In inference mode a ``BeamSearchDecoder`` is run for at most
        ``config.infer_max_len`` steps; ``self.logits`` is set to a no-op and
        ``self.sample_id`` to the predicted ids. In every other mode the
        embedded ``self.target_input`` is decoded with teacher forcing;
        ``self.logits`` is the decoder's ``rnn_output`` and
        ``self.sample_id`` its ``sample_id``.
        """
        with tf.variable_scope("dialog_decoder"):
            with tf.variable_scope("decoder_output_projection"):
                projection = layers_core.Dense(
                    self.config.vocab_size,
                    use_bias=False,
                    name="output_projection")

            with tf.variable_scope("decoder_rnn"):
                cell, init_state = self._build_decoder_cell(
                    enc_outputs=self.encoder_outputs,
                    enc_state=self.encoder_state)

                if self.mode == ModelMode.infer:
                    # Inference: beam search over the decoder vocabulary.
                    cfg = self.config
                    max_steps = tf.to_int32(cfg.infer_max_len)
                    sos_batch = tf.fill([self.batch_size], cfg.sos_idx)

                    beam_decoder = tc_seq2seq.BeamSearchDecoder(
                        cell=cell,
                        embedding=self.decoder_embeddings,
                        start_tokens=sos_batch,
                        end_token=cfg.eos_idx,
                        initial_state=init_state,
                        beam_width=cfg.beam_size,
                        output_layer=projection,
                        length_penalty_weight=cfg.length_penalty_weight)

                    beam_outputs, _, _ = tc_seq2seq.dynamic_decode(
                        beam_decoder,
                        maximum_iterations=max_steps,
                    )
                    # No logits are exposed for beam search.
                    self.logits = tf.no_op()
                    self.sample_id = beam_outputs.predicted_ids
                else:
                    # Training or eval: teacher forcing on the target input.
                    target_embeds = tf.nn.embedding_lookup(
                        self.decoder_embeddings, self.target_input)
                    teacher_helper = tc_seq2seq.TrainingHelper(
                        target_embeds, self.target_length)
                    basic_decoder = tc_seq2seq.BasicDecoder(
                        cell=cell,
                        helper=teacher_helper,
                        initial_state=init_state,
                        output_layer=projection
                    )

                    train_outputs, _, _ = tc_seq2seq.dynamic_decode(
                        basic_decoder)
                    self.logits = train_outputs.rnn_output
                    self.sample_id = train_outputs.sample_id
Ejemplo n.º 23
0
def seq2seq_without_attention(inputs,
        hidden_size,
        scope,
        activation_fn=tf.nn.relu,
        bn=False,
        bn_decay=None,
        is_training=None):
    """ sequence model without attention.

    Every (batch, point) pair is treated as an independent sequence of T
    steps. An LSTM encoder consumes the whole sequence, its final hidden
    state is fed as a single-step input to an LSTM decoder, and the decoder's
    final cell state is returned.

    Args:
      inputs: 4-D tensor variable BxNxTxD with statically known shape
      hidden_size: int, number of units of the encoder and decoder LSTMs
      scope: name of the enclosing variable scope
      activation_fn: function applied to the result, or None to skip it
      bn: bool, whether to use batch norm
      bn_decay: float or float tensor variable in [0,1]
      is_training: bool Tensor variable
    Return:
      Variable Tensor BxNxhidden_size
    """
    with tf.variable_scope(scope):
        static_shape = inputs.get_shape()
        batch_size = static_shape[0].value
        npoint = static_shape[1].value
        nstep = static_shape[2].value
        in_size = static_shape[3].value
        # Fold the batch and point axes together: one sequence per point.
        num_sequences = batch_size * npoint
        reshaped_inputs = tf.reshape(inputs, (-1, nstep, in_size))

        with tf.variable_scope('encoder'):
            # build encoder
            encoder_cell = tf.nn.rnn_cell.LSTMCell(hidden_size)
            # Every sequence spans the full time axis. The length used to be
            # hard-coded to 4, which silently truncated longer inputs.
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, reshaped_inputs,
                                                               sequence_length=tf.fill([num_sequences], nstep),
                                                               dtype=tf.float32, time_major=False)
        with tf.variable_scope('decoder'):
            decoder_cell = tf.nn.rnn_cell.LSTMCell(hidden_size)
            # The decoder sees exactly one step: the encoder's final hidden state.
            decoder_inputs = tf.reshape(encoder_state.h, [num_sequences, 1, hidden_size])

            # Helper to feed inputs for training: read inputs from dense ground truth vectors
            train_helper = seq2seq.TrainingHelper(inputs=decoder_inputs, sequence_length=tf.fill([num_sequences], 1),
                                                  time_major=False)
            decoder_initial_state = decoder_cell.zero_state(batch_size=num_sequences, dtype=tf.float32)
            train_decoder = seq2seq.BasicDecoder(cell=decoder_cell, helper=train_helper,
                                                 initial_state=decoder_initial_state, output_layer=None)
            decoder_outputs_train, decoder_last_state_train, decoder_outputs_length_train = seq2seq.dynamic_decode(
                decoder=train_decoder, output_time_major=False, impute_finished=True)
        # The decoder's final cell state is the per-point feature.
        outputs = tf.reshape(decoder_last_state_train.c, (-1, npoint, hidden_size))
        if bn:
            outputs = batch_norm_for_fc(outputs, is_training, bn_decay, 'bn')

        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return outputs
Ejemplo n.º 24
0
    def training_decoding_layer(self, dec_embed_input, summary_len, dec_cell,
                                initial_state, output_layer, max_summary_len):
        """Run the teacher-forced decoding pass used for training.

        Args:
          dec_embed_input: Batch-major embedded decoder inputs.
          summary_len: Sequence lengths handed to the ``TrainingHelper``.
          dec_cell: RNN cell wrapped by the ``BasicDecoder``.
          initial_state: Initial state of ``dec_cell``.
          output_layer: Layer applied by the ``BasicDecoder`` to cell outputs.
          max_summary_len: Upper bound on the number of decoding steps.

        Returns:
          The first element of the ``dynamic_decode`` result (the decoder
          outputs).
        """
        teacher_helper = seq2seq.TrainingHelper(
            inputs=dec_embed_input,
            sequence_length=summary_len,
            time_major=False)
        basic_decoder = seq2seq.BasicDecoder(
            dec_cell, teacher_helper, initial_state, output_layer)
        decoded = seq2seq.dynamic_decode(
            basic_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=max_summary_len)
        return decoded[0]
Ejemplo n.º 25
0
    def _init_decoder(self):
        """Build the training and greedy-inference decoders on one LSTM cell.

        Both decoders share ``self.decoder_cell``, the same ``Dense`` output
        layer, the ``"decoder"`` variable scope and the encoder's final state
        as initial state. Results are stored on ``self``:

          * ``decoder_outputs_train`` / ``decoder_logits_train`` /
            ``decoder_prediction_train`` for the teacher-forced pass;
          * ``decoder_outputs_inference`` / ``decoder_prediction_inference``
            for the greedy pass.
        """
        self.decoder_cell = tf.contrib.rnn.BasicLSTMCell(self.rnn_size)
        with tf.variable_scope(
                "decoder"
        ) as scope:  # TODO: understand why we aren't using the dynamic_rnn method here
            # Projection from cell outputs to vocabulary logits (no activation).
            output_layer = layers_core.Dense(units=self.effective_vocab_size,
                                             activation=None)

            # Train decoding
            train_helper = seq2seq.TrainingHelper(
                inputs=self.decoder_train_inputs_embedded,
                sequence_length=self.decoder_train_length,
                time_major=False)
            # No output_layer here: the projection is applied to the whole
            # decoded sequence afterwards (see output_layer.apply below).
            train_decoder = seq2seq.BasicDecoder(
                cell=self.decoder_cell,
                helper=train_helper,
                initial_state=self.encoder_final_state)
            self.decoder_outputs_train, _, _ = seq2seq.dynamic_decode(
                decoder=train_decoder,
                output_time_major=False,
                impute_finished=True,
                maximum_iterations=self.max_decoder_seq_length,
                scope=scope)
            self.decoder_logits_train = output_layer.apply(
                self.decoder_outputs_train.rnn_output)
            # Predicted ids: arg-max over axis 2 of the logits.
            self.decoder_prediction_train = tf.argmax(
                self.decoder_logits_train, 2)

            # Greedy decoding
            # Reuse the variables created by the training decoder above.
            scope.reuse_variables()
            greedy_helper = seq2seq.GreedyEmbeddingHelper(
                embedding=self.embedding_matrix,
                start_tokens=self.decoder_start_tokens,
                end_token=self.eos)
            # Here the projection is passed to the decoder itself, so
            # rnn_output already holds the projected values.
            greedy_decoder = seq2seq.BasicDecoder(
                cell=self.decoder_cell,
                helper=greedy_helper,
                initial_state=self.encoder_final_state,
                output_layer=output_layer)
            self.decoder_outputs_inference, _, _ = seq2seq.dynamic_decode(
                decoder=greedy_decoder,
                output_time_major=False,
                impute_finished=True,
                maximum_iterations=self.
                max_decoder_seq_length,  # TODO: figure out a better way of setting this
                scope=scope)
            self.decoder_prediction_inference = tf.argmax(
                self.decoder_outputs_inference.rnn_output, 2)
Ejemplo n.º 26
0
    def _model(self, embed):
        """Build the seq2seq training/prediction graph in a fresh ``tf.Graph``.

        Placeholders created (by name): ``learning_rate``, ``x_input``,
        ``x_length``, ``y_input``, ``y_length``, ``batch_size``, ``keep_prob``.
        The encoder comes from ``self._encoder``; ``self._decoder`` is called
        twice under the same ``'decode'`` name, once with a teacher-forcing
        helper and once with a greedy helper.

        Args:
          embed: Initial value of the non-trainable ``embedding`` variable.

        Returns:
          A tuple ``(graph, loss, train_op, predicting_logits)``.
        """
        graph = tf.Graph()
        with graph.as_default():
            embedding = tf.Variable(embed, trainable=False, name='embedding')  # word embeddings
            lr = tf.placeholder(tf.float32, [], name='learning_rate')
            # Input data
            x_input = tf.placeholder(tf.int32, [None, None], name='x_input')  # input data X
            x_sequence_length = tf.placeholder(tf.int32, [None], name='x_length')  # length of each input sequence
            x_embedding = tf.nn.embedding_lookup(embedding, x_input)  # look up an embedding vector for every input id
            y_input = tf.placeholder(tf.int32, [None, None], name='y_input')  # input data Y
            y_sequence_length = tf.placeholder(tf.int32, [None], name='y_length')  # length of each Y sequence
            y_embedding = tf.nn.embedding_lookup(embedding, y_input)  # embed Y
            batch_size = tf.placeholder(tf.int32, [], name='batch_size')
            keep_prob = tf.placeholder(tf.float32, [], name='keep_prob')

            encoder_output, encoder_state = self._encoder(keep_prob, x_embedding, x_sequence_length, batch_size)

            # Teacher forcing for training; greedy decoding from 'GO' until 'EOS' for prediction.
            training_helper = seq2seq.TrainingHelper(inputs=y_embedding, sequence_length=y_sequence_length)
            predict_helper = seq2seq.GreedyEmbeddingHelper(embedding, tf.fill([batch_size], self.word2index['GO']),
                                                           self.word2index['EOS'])
            train_output = self._decoder(keep_prob, encoder_output, encoder_state, batch_size, 'decode',
                                         training_helper)
            # The trailing True is presumably a variable-reuse flag -- confirm against _decoder.
            predict_output = self._decoder(keep_prob, encoder_output, encoder_state, batch_size, 'decode',
                                           predict_helper, True)

            # loss function
            training_logits = tf.identity(train_output.rnn_output, name='training_logits')
            predicting_logits = tf.identity(predict_output.rnn_output, name='predicting')

            # target = tf.slice(y_input, [0, 1], [-1, -1])
            # target = tf.concat([tf.fill([batch_size, 1], self.word2index['GO']), y_input], 1)
            # NOTE(review): the target is the same y_input that feeds the TrainingHelper above,
            # so decoder input and target are not offset here -- confirm the data pipeline
            # accounts for this.
            target = y_input

            # Weight 1.0 for positions inside each sequence, 0.0 for padding.
            masks = tf.sequence_mask(y_sequence_length, dtype=tf.float32, name='mask')

            loss = seq2seq.sequence_loss(training_logits, target, masks)
            optimizer = tf.train.AdamOptimizer(lr)
            gradients = optimizer.compute_gradients(loss)
            # Clip every gradient element to [-5, 5]; variables without a gradient are skipped.
            capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if
                                grad is not None]
            train_op = optimizer.apply_gradients(capped_gradients)
            # predicting_logits = tf.nn.softmax(predicting_logits, axis=1)
            tf.summary.scalar('loss', loss)
            tf.summary.scalar('learning rate', lr)
            # tf.summary.tensor_summary('learning rate', lr)

        return graph, loss, train_op, predicting_logits
    def _get_state(self, inputs, lengths=None, initial_state=None):
        """Computes the state of the RNN-NADE (NADE bias parameters and RNN state).

        Args:
          inputs: A batch of sequences to compute the state from, sized
              `[batch_size, max(lengths), num_dims]` or `[batch_size, num_dims]`.
          lengths: The length of each sequence, sized `[batch_size]`. When
              omitted, the size of axis 1 of `inputs` is used for every row.
          initial_state: An RnnNadeStateTuple, the initial state of the
              RNN-NADE, or None if the zero state should be used.

        Returns:
          final_state: An RnnNadeStateTuple, the final state of the RNN-NADE.
        """
        num_rows = inputs.shape[0].value

        if lengths is None:
            # Default: every sequence spans the full second axis.
            lengths = tf.tile(tf.shape(inputs)[1:2], [num_rows])

        start_rnn_state = (
            self._get_rnn_zero_state(num_rows)
            if initial_state is None else initial_state.rnn_state)

        teacher_helper = contrib_seq2seq.TrainingHelper(
            inputs=inputs, sequence_length=lengths)
        basic_decoder = contrib_seq2seq.BasicDecoder(
            cell=self._rnn_cell,
            helper=teacher_helper,
            initial_state=start_rnn_state,
            output_layer=self._fc_layer)

        decoded, end_rnn_state, _ = contrib_seq2seq.dynamic_decode(
            basic_decoder)

        # Flatten time dimension.
        flat_outputs = magenta.common.flatten_maybe_padded_sequences(
            decoded.rnn_output, lengths)

        # The projected output holds the encoder biases followed by the
        # decoder biases.
        split_sizes = [self._nade.num_hidden, self._nade.num_dims]
        b_enc, b_dec = tf.split(flat_outputs, split_sizes, axis=1)

        return RnnNadeStateTuple(b_enc, b_dec, end_rnn_state)
Ejemplo n.º 28
0
    def _create_lstm_policy(self, args):
        """Build the LSTM policy head and store its tensors on ``self``.

        Sets ``policy_lstm``, ``policy_state``, ``z_samples``,
        ``final_policy_state``, ``a_mean`` and ``a_logstd``.

        Args:
          args: Configuration object; reads ``policy_size``,
            ``num_policy_layers``, ``batch_size``, ``sample_size``,
            ``z_dim``, ``state_dim`` and ``action_dim``.
        """
        # Create LSTM portion of network
        policy_layers = []
        for _ in range(args.num_policy_layers):
            policy_layers.append(
                rnn.LSTMCell(args.policy_size,
                             state_is_tuple=True,
                             initializer=initializers.xavier_initializer()))
        self.policy_lstm = rnn.MultiRNNCell(policy_layers,
                                            state_is_tuple=True)
        self.policy_state = self.policy_lstm.zero_state(
            args.batch_size * args.sample_size, tf.float32)

        # Get samples from standard normal distribution, transform to match
        # z-distribution
        unit_noise = tf.random_normal(
            [args.sample_size, args.batch_size, args.z_dim], name="z_samples")
        scaled_noise = unit_noise * tf.exp(self.z_logstd) + self.z_mean
        # Swap the first two axes: batch first, samples second.
        self.z_samples = tf.transpose(scaled_noise, [1, 0, 2])

        # Construct policy input: one single-step sequence per row.
        state_and_z = tf.concat([self.states, self.z_samples], 2)
        policy_input = tf.reshape(
            state_and_z, [-1, 1, args.z_dim + args.state_dim])

        # Forward pass
        single_step_lengths = [1] * args.batch_size * args.sample_size
        step_helper = seq2seq.TrainingHelper(
            policy_input, sequence_length=single_step_lengths)
        policy_decoder = seq2seq.BasicDecoder(cell=self.policy_lstm,
                                              helper=step_helper,
                                              initial_state=self.policy_state)
        decoded, self.final_policy_state, _ = seq2seq.dynamic_decode(
            policy_decoder, scope='policy_cell')
        # Keep only the last time step of the first decode output.
        last_step = decoded[0][:, -1, :]

        # Fully connected layer to latent variable distribution parameters
        weights = tf.get_variable(
            "lstm_w", [args.policy_size, args.action_dim],
            initializer=initializers.xavier_initializer())
        bias = tf.get_variable("lstm_b", [args.action_dim])
        flat_mean = tf.nn.xw_plus_b(last_step, weights, bias)
        self.a_mean = tf.reshape(
            flat_mean, [args.batch_size, args.sample_size, args.action_dim],
            name="a_mean")

        # Initialize logstd
        self.a_logstd = tf.Variable(np.zeros(args.action_dim),
                                    name="a_logstd",
                                    dtype=tf.float32)
Ejemplo n.º 29
0
def residual_decoder(agenda,
                     dec_inputs,
                     dec_input_lengths,
                     hidden_dim,
                     num_layer,
                     swap_memory,
                     enable_dropout=False,
                     dropout_keep=1.,
                     name=None):
    """Teacher-forced multi-layer decoder conditioned on an agenda vector.

    The agenda is repeated along the time axis and concatenated to the
    embedded decoder inputs, so every step sees ``[word_embed; agenda]``.

    Args:
      agenda: Per-example conditioning vector, ``[batch x agenda_dim]``.
      dec_inputs: Decoder input token ids, ``[batch x max_len]``.
      dec_input_lengths: Sequence lengths handed to the ``TrainingHelper``.
      hidden_dim: Each layer is created with ``hidden_dim // 2`` units.
      num_layer: Number of stacked RNN layers.
      swap_memory: Forwarded to ``seq2seq.dynamic_decode``.
      enable_dropout: Forwarded to ``create_rnn_layer``.
      dropout_keep: Forwarded to ``create_rnn_layer``.
      name: Optional variable scope name; defaults to 'residual_decoder'.

    Returns:
      The ``(outputs, state, length)`` tuple produced by
      ``seq2seq.dynamic_decode``.
    """
    with tf.variable_scope(name, 'residual_decoder', []):
        batch_size = tf.shape(dec_inputs)[0]
        embeddings = vocab.get_embeddings()

        # [batch x max_len x word_dim]
        word_embeds = tf.nn.embedding_lookup(embeddings, dec_inputs)
        max_len = tf.shape(word_embeds)[1]

        # [batch x 1 x agenda_dim] -> [batch x max_len x agenda_dim]
        agenda_per_step = tf.tile(
            tf.expand_dims(agenda, axis=1), [1, max_len, 1])

        # [batch x max_len x word_dim+agenda_dim]
        step_inputs = tf.concat([word_embeds, agenda_per_step], axis=2)

        teacher_helper = seq2seq.TrainingHelper(
            step_inputs, dec_input_lengths, name='train_helper')

        layers = [
            create_rnn_layer(idx, hidden_dim // 2, enable_dropout,
                             dropout_keep)
            for idx in range(num_layer)
        ]
        stacked_cell = tf_rnn.MultiRNNCell(layers)
        initial_states = create_trainable_initial_states(
            batch_size, stacked_cell)

        projection = DecoderOutputLayer(embeddings)
        basic_decoder = seq2seq.BasicDecoder(
            stacked_cell, teacher_helper, initial_states, projection)

        decoded, final_state, final_lengths = seq2seq.dynamic_decode(
            basic_decoder, swap_memory=swap_memory)

        return decoded, final_state, final_lengths
Ejemplo n.º 30
0
def decode_L(inputs, dim, embed_map, start_token,
             unroll_type='teacher_forcing', seq=None, seq_len=None,
             end_token=None, max_seq_len=None, output_layer=None,
             is_train=True, scope='decode_L', reuse=tf.AUTO_REUSE):
    """Decode a token sequence with an LSTM initialised from ``inputs``.

    The LSTM's initial cell and hidden states are two separate tanh
    fully-connected projections of ``inputs``.

    Args:
      inputs: Conditioning features; axis 0 is the batch axis.
      dim: Number of LSTM units and size of both state projections.
      embed_map: Embedding table used to look up token ids.
      start_token: Id fed as the first decoder input for every example.
      unroll_type: 'teacher_forcing' (feed ``seq`` shifted right by one) or
        'greedy' (feed back the previous prediction).
      seq: Ground-truth token ids; required for 'teacher_forcing'.
      seq_len: Sequence lengths; required for 'teacher_forcing'.
      end_token: Id that stops greedy decoding; required for 'greedy'.
      max_seq_len: Maximum number of decoding iterations.
      output_layer: Optional layer applied to the cell outputs.
      is_train: Forwarded to ``fc_layer`` as ``is_training``.
      scope: Variable scope name.
      reuse: Variable reuse setting for the scope and the projections.

    Returns:
      A tuple ``(output, pred, pred_length)``: the decoder's ``rnn_output``,
      its ``sample_id`` and the final sequence lengths.

    Raises:
      ValueError: If an argument required by ``unroll_type`` is None, or if
        ``unroll_type`` is unknown.
    """
    with tf.variable_scope(scope, reuse=reuse) as var_scope:
        cell_init = fc_layer(inputs, dim, use_bias=True, use_bn=False,
                             activation_fn=tf.nn.tanh, is_training=is_train,
                             scope='Linear_c', reuse=reuse)
        hidden_init = fc_layer(inputs, dim, use_bias=True, use_bn=False,
                               activation_fn=tf.nn.tanh, is_training=is_train,
                               scope='Linear_h', reuse=reuse)
        init_state = rnn.LSTMStateTuple(cell_init, hidden_init)
        log.warning(var_scope.name)

        # One start token per example in the batch.
        batch_zeros = tf.zeros([tf.shape(inputs)[0]], dtype=tf.int32)
        start_tokens = batch_zeros + start_token

        if unroll_type == 'teacher_forcing':
            if seq is None:
                raise ValueError('seq is None')
            if seq_len is None:
                raise ValueError('seq_len is None')
            # Shift right: prepend the start token, drop the last column.
            start_column = tf.expand_dims(start_tokens, axis=1)
            shifted_seq = tf.concat([start_column, seq[:, :-1]], axis=1)
            helper = seq2seq.TrainingHelper(
                tf.nn.embedding_lookup(embed_map, shifted_seq), seq_len)
        elif unroll_type == 'greedy':
            if end_token is None:
                raise ValueError('end_token is None')

            def embed_ids(ids):
                return tf.nn.embedding_lookup(embed_map, ids)

            helper = seq2seq.GreedyEmbeddingHelper(
                embed_ids, start_tokens, end_token)
        else:
            raise ValueError('Unknown unroll_type')

        lstm_cell = rnn.BasicLSTMCell(num_units=dim, state_is_tuple=True)
        basic_decoder = seq2seq.BasicDecoder(
            lstm_cell, helper, init_state, output_layer=output_layer)
        decoded, _, pred_length = seq2seq.dynamic_decode(
            basic_decoder,
            maximum_iterations=max_seq_len,
            scope='dynamic_decoder')

        return decoded.rnn_output, decoded.sample_id, pred_length