Example #1
def _build_beamsearch_inference_decoder(glb_ctx, im, ans, inputs, vocab_size,
                                        num_cells, pad_token):
    vocab_size = max(vocab_size, pad_token + 1)

    # =================== create cells ========================
    lstm = BasicLSTMCell(num_cells)
    attention_cell = MultiModalAttentionCell(512, im, ans, keep_prob=1.0)

    multi_cell = MultiRNNCell([lstm, attention_cell],
                              state_is_tuple=True)
    lstm_state_sizes, att_state_size = multi_cell.state_size
    lstm_state_size = sum(lstm_state_sizes)
    state_size = lstm_state_size + att_state_size

    # =============  create state placeholders ================
    state_feed = tf.placeholder(dtype=tf.float32, shape=[None, state_size],
                                name='state_feed')
    lstm_state_feed = tf.slice(state_feed, begin=[0, 0],
                               size=[-1, lstm_state_size])
    att_state_feed = tf.slice(state_feed, begin=[0, lstm_state_size],
                              size=[-1, att_state_size])
    feed_c, feed_h = split_op(lstm_state_feed, num_splits=2, axis=1)
    state_tuple = LSTMStateTuple(feed_c, feed_h)
    multi_cell_state_feed = (state_tuple, att_state_feed)

    # ==================== create init state ==================
    # lstm init state
    init_h = slim.fully_connected(glb_ctx, num_cells, activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    batch_size = tf.shape(init_h)[0]

    init_att = tf.zeros([batch_size, num_cells], dtype=tf.float32)
    init_state = (init_c, init_h, init_att)
    concat_op(init_state, axis=1, name="initial_state")  # need to fetch

    # ==================== forward pass ========================
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(
            name="word_map",
            shape=[vocab_size, num_cells],
            initializer=tf.random_uniform_initializer(-0.08, 0.08,
                                                      dtype=tf.float32))
    word_embed = tf.nn.embedding_lookup(word_map, inputs)

    with tf.variable_scope('RNN'):
        outputs, multi_state = multi_cell(tf.squeeze(word_embed, squeeze_dims=[1]),
                                          state=multi_cell_state_feed)

    # ==================== concat states =========================
    lstm_state, att_state = multi_state
    state_c, state_h = lstm_state
    concat_op((state_c, state_h, att_state), axis=1, name="state")  # need to fetch

    # predict next word
    outputs = tf.reshape(outputs, [-1, num_cells])
    logits = slim.fully_connected(outputs, vocab_size, activation_fn=None,
                                  scope='logits')
    prob = tf.nn.softmax(logits, name="softmax")  # need to fetch
    return prob
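The tensors tagged "# need to fetch" above ("initial_state", "state" and "softmax") are meant to be consumed by an external beam-search driver that runs the decoder one word at a time. A minimal sketch of such a step function, assuming the caller keeps handles to the `inputs` and `state_feed` placeholders and that no extra name scope prefixes the fetched names:

import tensorflow as tf

def run_decode_step(sess, input_feed, state_feed, tokens, packed_state):
    # Fetch the next-word distribution and the packed (c, h, att) state.
    # The ':0' tensor names are assumptions; enclosing scopes would prefix them.
    graph = sess.graph
    softmax = graph.get_tensor_by_name('softmax:0')
    next_state = graph.get_tensor_by_name('state:0')
    # `tokens` has shape [batch, 1] to match the squeeze inside the decoder.
    return sess.run([softmax, next_state],
                    feed_dict={input_feed: tokens, state_feed: packed_state})

# Typical use: fetch 'initial_state:0' once for the start state, feed the start
# token, then keep feeding each returned state back through `state_feed`.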
def _build_beamsearch_inference_decoder(in_embed, inputs, vocab_size,
                                        num_cells, pad_token):
    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)

    # init state / image embedding
    init_h = slim.fully_connected(in_embed,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_state = LSTMStateTuple(init_c, init_h)

    # build LSTM cell and RNN
    lstm_cell = BasicLSTMCell(num_cells)
    concat_op(init_state, axis=1, name="initial_state")

    # word embedding
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(name="word_map",
                                   shape=[vocab_size, num_cells],
                                   initializer=tf.random_uniform_initializer(
                                       -0.08, 0.08, dtype=tf.float32))
    word_embed = tf.nn.embedding_lookup(word_map, inputs)

    # Placeholder for feeding a batch of concatenated states.
    state_feed = tf.placeholder(dtype=tf.float32,
                                shape=[None, sum(lstm_cell.state_size)],
                                name="state_feed")
    feed_c, feed_h = split_op(state_feed, num_splits=2, axis=1)
    state_tuple = LSTMStateTuple(feed_c, feed_h)

    # Run a single LSTM step.
    with tf.variable_scope('RNN'):
        outputs, state_tuple = lstm_cell(inputs=tf.squeeze(word_embed,
                                                           squeeze_dims=[1]),
                                         state=state_tuple)

    # Concatenate the resulting state.
    state = concat_op(state_tuple, 1, name="state")

    # Stack batches vertically.
    outputs = tf.reshape(outputs, [-1, lstm_cell.output_size])
    logits = slim.fully_connected(outputs,
                                  vocab_size,
                                  activation_fn=None,
                                  scope='logits')
    prob = tf.nn.softmax(logits, name="softmax")
    return state
def build_decoder(fused,
                  noise,
                  ans,
                  ans_len,
                  vocab_size,
                  keep_prob,
                  pad_token,
                  num_dec_cells,
                  phase='train'):
    # concatenate the fused conditioning embedding with the noise vector
    in_embed = concat_op([fused, noise], axis=1)
    with tf.variable_scope('var_vqa'):
        if phase == 'train' or phase == 'condition':
            add_loss = phase == 'train'
            inputs, targets, length = build_caption_inputs_and_targets(
                ans, ans_len)
            return _build_training_decoder(in_embed, inputs, length, targets,
                                           vocab_size, num_dec_cells,
                                           keep_prob, pad_token, add_loss)
        elif phase == 'greedy':
            return _build_greedy_inference_decoder(in_embed, vocab_size,
                                                   num_dec_cells,
                                                   _START_TOKEN_ID)
        elif phase == 'beam' or phase == 'sampling':
            return _build_tf_beam_inference_decoder(in_embed, vocab_size,
                                                    num_dec_cells,
                                                    _START_TOKEN_ID)
        else:
            return _build_beamsearch_inference_decoder(in_embed, ans,
                                                       vocab_size,
                                                       num_dec_cells,
                                                       pad_token)
Example #4
def build_decoding_attention_vaq_model(im,
                                       attr,
                                       ans_embed,
                                       quest,
                                       quest_len,
                                       vocab_size,
                                       keep_prob,
                                       pad_token,
                                       num_dec_cells,
                                       phase='train',
                                       cell_option=1):
    if attr is None:
        in_embed = ans_embed
    else:
        in_embed = concat_op(values=[attr, ans_embed], axis=1)

    with tf.variable_scope('vaq'):
        if phase == 'train' or phase == 'condition' or phase == 'evaluate':
            inputs, targets, length = _build_caption_inputs_and_targets(
                quest, quest_len)
            return build_attention_decoder(in_embed, im, ans_embed, inputs,
                                           length, targets, vocab_size,
                                           num_dec_cells, keep_prob, pad_token,
                                           cell_option)
        elif phase == 'greedy':
            return build_attention_greedy_decoder(in_embed, im, ans_embed,
                                                  vocab_size, num_dec_cells,
                                                  pad_token, cell_option)
        else:
            return build_vaq_dec_attention_predictor(in_embed, im, ans_embed,
                                                     quest, vocab_size,
                                                     num_dec_cells, pad_token,
                                                     cell_option)
def build_decoder(im,
                  ans_embed,
                  quest,
                  quest_len,
                  vocab_size,
                  keep_prob,
                  pad_token,
                  num_dec_cells,
                  phase='train'):
    # average pooling over image
    in_embed = concat_op(values=[im, ans_embed], axis=1)
    with tf.variable_scope('vaq'):
        if phase == 'train' or phase == 'condition':
            inputs, targets, length = build_caption_inputs_and_targets(
                quest, quest_len)
            return _build_training_decoder(in_embed, inputs, length, targets,
                                           vocab_size, num_dec_cells,
                                           keep_prob, pad_token)
        elif phase == 'greedy':
            return _build_greedy_inference_decoder(in_embed, vocab_size,
                                                   num_dec_cells,
                                                   _START_TOKEN_ID)
        elif phase == 'sampling':
            return _build_random_inference_decoder(in_embed, vocab_size,
                                                   num_dec_cells,
                                                   _START_TOKEN_ID)
        elif phase == 'beam':
            return _build_tf_beam_inference_decoder(in_embed, vocab_size,
                                                    num_dec_cells,
                                                    _START_TOKEN_ID)
        else:
            return _build_beamsearch_inference_decoder(in_embed, quest,
                                                       vocab_size,
                                                       num_dec_cells,
                                                       pad_token)
Example #6
def build_decoder(im, attr, ans_embed, quest, quest_len, vocab_size,
                  keep_prob, pad_token, num_dec_cells, phase='train'):
    if attr is None:
        in_embed = ans_embed
    else:
        in_embed = concat_op(values=[attr, ans_embed], axis=1)

    with tf.variable_scope('inverse_vqa'):
        if phase == 'train' or phase == 'condition' or phase == 'evaluate':
            inputs, targets, length = build_caption_inputs_and_targets(quest,
                                                                       quest_len)
            return _build_training_decoder(in_embed, im, ans_embed, inputs, length, targets,
                                           vocab_size, num_dec_cells, keep_prob,
                                           pad_token)
        elif phase == 'greedy':
            return _build_greedy_inference_decoder(in_embed, im, ans_embed,
                                                   vocab_size, num_dec_cells,
                                                   VOCAB_CONFIG.start_token_id)
        elif phase == 'beam':
            return _build_tf_beam_inference_decoder(in_embed, im, ans_embed,
                                                    vocab_size, num_dec_cells,
                                                    VOCAB_CONFIG.start_token_id,
                                                    pad_token)
        else:
            return _build_beamsearch_inference_decoder(in_embed, im, ans_embed,
                                                       quest, vocab_size, num_dec_cells,
                                                       pad_token)
    def unmerge_batch_beam(self, tensor):
        remaining_shape = tf.shape(tensor)[1:]
        res = tf.reshape(tensor,
                         concat_op([[-1, self.beam_size], remaining_shape], 0))
        res.set_shape(
            tf.TensorShape(
                (None, self.beam_size)).concatenate(tensor.get_shape()[1:]))
        return res
def batch_gather(params,
                 indices,
                 validate_indices=None,
                 batch_size=None,
                 options_size=None):
    """
    Gather slices from `params` according to `indices`, separately for each
    example in a batch.

    output[b, i, ..., j, :, ..., :] = params[b, indices[b, i, ..., j], :, ..., :]

    The arguments `batch_size` and `options_size`, if provided, are used instead
    of looking up the shape from the inputs. This may help avoid redundant
    computation (TODO: figure out if tensorflow's optimizer can do this automatically)

    Args:
      params: A `Tensor`, [batch_size, options_size, ...]
      indices: A `Tensor`, [batch_size, ...]
      validate_indices: An optional `bool`. Defaults to `True`
      batch_size: (optional) an integer or scalar tensor representing the batch size
      options_size: (optional) an integer or scalar Tensor representing the number of options to choose from
    """
    if batch_size is None:
        batch_size = params.get_shape()[0].merge_with(
            indices.get_shape()[0]).value
        if batch_size is None:
            batch_size = tf.shape(indices)[0]

    if options_size is None:
        options_size = params.get_shape()[1].value
        if options_size is None:
            options_size = tf.shape(params)[1]

    batch_size_times_options_size = batch_size * options_size

    # TODO(nikita): consider using gather_nd. However as of 1/9/2017 gather_nd
    # has no gradients implemented.
    flat_params = tf.reshape(
        params,
        concat_op([[batch_size_times_options_size],
                   tf.shape(params)[2:]],
                  axis=0))

    indices_offsets = tf.reshape(
        tf.range(batch_size) * options_size,
        [-1] + [1] * (len(indices.get_shape()) - 1))
    indices_into_flat = indices + tf.cast(indices_offsets, indices.dtype)

    return tf.gather(flat_params,
                     indices_into_flat,
                     validate_indices=validate_indices)
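A quick usage sketch for `batch_gather`, picking one option per example from a `[batch, options, depth]` tensor (values and shapes are illustrative):

import numpy as np
import tensorflow as tf

# 2 examples, 3 options each, depth 2
params = tf.constant(np.arange(12).reshape(2, 3, 2), dtype=tf.float32)
indices = tf.constant([[2], [0]])  # example 0 picks option 2, example 1 picks option 0

picked = batch_gather(params, indices)  # shape [2, 1, 2]

with tf.Session() as sess:
    print(sess.run(picked))
    # [[[ 4.  5.]]
    #  [[ 6.  7.]]]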
def show_attend_tell_attention_helper(im, part_q, keep_prob, scope=""):
    scope = scope or "ShowAttendTellCell"
    _, h, w, c = im.get_shape().as_list()
    with tf.variable_scope(scope):
        # concat im and part q
        part_q = tf.expand_dims(tf.expand_dims(part_q, 1), 2)
        part_q_tile = tf.tile(part_q, [1, h, w, 1])
        im_pq = concat_op([im, part_q_tile], axis=3)
        im_ctx = slim.conv2d(im_pq,
                             512,
                             1,
                             activation_fn=tf.nn.tanh,
                             scope='vq_fusion')
        im_ctx = slim.dropout(im_ctx, keep_prob=keep_prob)
        v, _ = _soft_attention_pool_with_map(im, im_ctx)
    return v
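`_soft_attention_pool_with_map` is not shown in this listing; the sketch below is only an assumed, typical implementation of such a pooling step (not the repository's code): score each spatial location from the fused map, softmax over the grid, and return the attention-weighted sum of the image features together with the attention map.

import tensorflow as tf
import tensorflow.contrib.slim as slim

def soft_attention_pool_sketch(im, im_ctx):
    # im:     [N, h, w, c] image features to pool
    # im_ctx: [N, h, w, d] fused image/context map used to score locations
    scores = slim.conv2d(im_ctx, 1, 1, activation_fn=None, scope='att_score')
    shape = tf.shape(scores)
    # softmax over the h*w spatial locations
    att = tf.nn.softmax(tf.reshape(scores, [shape[0], -1]))
    att = tf.reshape(att, [shape[0], shape[1], shape[2], 1])
    pooled = tf.reduce_sum(im * att, axis=[1, 2])  # [N, c]
    return pooled, att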
def build_decoder(im,
                  attr,
                  ans_embed,
                  quest,
                  quest_len,
                  rewards,
                  advantage,
                  vocab_size,
                  keep_prob,
                  pad_token,
                  num_dec_cells,
                  phase='train',
                  reuse=False,
                  xe_mask=None,
                  T=4):
    if attr is None:
        in_embed = ans_embed
    else:
        in_embed = concat_op(values=[attr, ans_embed], axis=1)

    with tf.variable_scope('inverse_vqa', reuse=reuse):
        if phase == 'train' or phase == 'condition' or phase == 'evaluate':
            inputs, targets, length = build_caption_inputs_and_targets(
                quest, quest_len)
            return _build_training_decoder(in_embed, im, ans_embed, inputs,
                                           length, targets, vocab_size,
                                           num_dec_cells, keep_prob, pad_token,
                                           rewards, advantage, xe_mask)
        elif phase == 'random':
            return _build_random_inference_decoder(in_embed, im, ans_embed,
                                                   vocab_size, num_dec_cells,
                                                   VOCAB_CONFIG.start_token_id,
                                                   pad_token)
        elif phase == 'mixer':
            return _build_mixer_inference_decoder(in_embed, im, ans_embed,
                                                  vocab_size, num_dec_cells,
                                                  quest, quest_len, pad_token,
                                                  T)
        elif phase == 'beam':
            return _build_tf_beam_inference_decoder(
                in_embed, im, ans_embed, vocab_size, num_dec_cells,
                VOCAB_CONFIG.start_token_id, pad_token)
        else:
            raise Exception('unknown option')
    def __call__(self, inputs, state, scope=None):
        """Attention cell with answer context."""
        with tf.variable_scope(scope or type(self).__name__):
            if self._state_is_tuple:
                _, h = state
            else:
                _, h = split_op(values=state, num_splits=2, axis=1)

            with tf.variable_scope('Attention'):
                v = show_attend_tell_attention_helper(self._context, h,
                                                      self._keep_prob)
            lstm_input = concat_op([v, inputs], axis=1)
            lstm_h, next_state = self._lstm_cell(lstm_input, state=state)
            # concat outputs
            h_ctx_reduct = concat_fusion(v,
                                         lstm_h,
                                         self._num_units,
                                         act_fn=None)
            output = h_ctx_reduct + inputs
        return output, next_state
    def build_answer_basis(self):
        ans_vocab_size = VOCAB_CONFIG.answer_vocab_size
        enc_lstm = create_dropout_lstm_cells(256, self._keep_prob,
                                             self._keep_prob)
        # create word embedding
        with tf.variable_scope('answer_embed'):
            ans_embed_map = tf.get_variable(
                name='word_map',
                shape=[ans_vocab_size, self._word_embed_dim],
                initializer=get_default_initializer())
            ans_word_embed = tf.nn.embedding_lookup(ans_embed_map,
                                                    self._answers)

        _, states = tf.nn.dynamic_rnn(enc_lstm,
                                      ans_word_embed,
                                      self._ans_len,
                                      dtype=tf.float32,
                                      scope='AnswerEncoder')
        # self.debug_ops.append(ans_word_embed)
        self._answer_embed = concat_op(values=states,
                                       axis=1)  # concat tuples and concat
    def __call__(self, inputs, state, scope=None):
        """Run this multi-layer cell on inputs, starting from state."""
        with tf.variable_scope(scope or "multi_rnn_cell"):
            cur_state_pos = 0
            cur_inp = inputs
            new_states = []
            for i, cell in enumerate(self._cells):
                with tf.variable_scope("cell_%d" % i):
                    if self._state_is_tuple:
                        if not nest.is_sequence(state):
                            raise ValueError(
                                "Expected state to be a tuple of length %d, but received: %s"
                                % (len(self.state_size), state))
                        cur_state = state[i]
                    else:
                        cur_state = tf.slice(
                            state, [0, cur_state_pos], [-1, cell.state_size])
                        cur_state_pos += cell.state_size
                    cur_inp, new_state = cell(cur_inp, cur_state)
                    new_states.append(new_state)
        new_states = (tuple(new_states) if self._state_is_tuple else
                      concat_op(new_states, 1))
        return cur_inp, new_states
Example #14
def build_decoder(im,
                  conditions,
                  quest,
                  quest_len,
                  vocab_size,
                  keep_prob,
                  pad_token,
                  num_dec_cells,
                  phase='train'):
    # unpack the answer embedding and noise from the conditioning inputs
    answer_embed, noise = conditions
    answer_reduct = slim.fully_connected(answer_embed,
                                         num_dec_cells,
                                         activation_fn=tf.nn.tanh,
                                         scope='answer_mask')
    z_embed = answer_reduct * noise
    in_embed = concat_op(values=[im, answer_embed], axis=1)
    with tf.variable_scope('vaq'):
        if phase == 'train' or phase == 'condition':
            inputs, targets, length = build_caption_inputs_and_targets(
                quest, quest_len)
            return _build_training_decoder(in_embed, z_embed, inputs, length,
                                           targets, vocab_size, num_dec_cells,
                                           keep_prob, pad_token)
        elif phase == 'greedy':
            return _build_greedy_inference_decoder(in_embed, vocab_size,
                                                   num_dec_cells,
                                                   _START_TOKEN_ID)
        elif phase == 'beam' or phase == 'sampling':
            return _build_tf_beam_inference_decoder(in_embed, z_embed,
                                                    vocab_size, num_dec_cells,
                                                    _START_TOKEN_ID)
        else:
            return _build_beamsearch_inference_decoder(in_embed, quest,
                                                       vocab_size,
                                                       num_dec_cells,
                                                       pad_token)
Example #15
def _build_tf_beam_inference_decoder(glb_ctx, im, ans_embed, vocab_size,
                                     num_cells, start_token_id, pad_token):
    beam_size = 3
    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)
    # init state / image embedding
    init_h = slim.fully_connected(glb_ctx,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = slim.fully_connected(glb_ctx,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_c')
    init_state = concat_op([init_c, init_h], axis=1)

    # replicate context of the attention module
    im_shape = im.get_shape().as_list()[1:]
    im = tf.expand_dims(im, 1)  # add a time dim
    im = tf.reshape(tf.tile(im, [1, beam_size, 1, 1, 1]), [-1] + im_shape)

    multi_cell = ShowAttendTellCell(num_cells, im, state_is_tuple=False)

    # word embedding
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(name="word_map",
                                   shape=[vocab_size, num_cells],
                                   initializer=tf.random_uniform_initializer(
                                       -0.08, 0.08, dtype=tf.float32))

    # apply weights for outputs
    with tf.variable_scope('logits'):
        weights = tf.get_variable('weights',
                                  shape=[num_cells, vocab_size],
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', shape=[vocab_size])

    # define helper functions
    def _tokens_to_inputs_fn(inputs):
        inputs = tf.nn.embedding_lookup(word_map, inputs)
        # inputs = tf.squeeze(inputs, [1])
        return inputs

    def _output_to_score_fn(hidden):
        batch_size = tf.shape(hidden)[0]
        beam_size = tf.shape(hidden)[1]
        hidden = tf.reshape(hidden, [batch_size * beam_size, -1])
        logits = tf.nn.xw_plus_b(hidden, weights, biases)
        logprob = tf.nn.log_softmax(logits)
        return tf.reshape(logprob, [batch_size, beam_size, -1])

    stop_token_id = VOCAB_CONFIG.end_token_id
    batch_size = tf.shape(glb_ctx)[0]
    start_tokens = tf.ones(shape=[batch_size], dtype=tf.int32) * start_token_id
    init_inputs = _tokens_to_inputs_fn(start_tokens)
    pathes, scores = beam_decoder(multi_cell,
                                  beam_size=beam_size,
                                  stop_token=stop_token_id,
                                  initial_state=init_state,
                                  initial_input=init_inputs,
                                  tokens_to_inputs_fn=_tokens_to_inputs_fn,
                                  outputs_to_score_fn=_output_to_score_fn,
                                  max_len=20,
                                  cell_transform='flatten',
                                  output_dense=True,
                                  scope='RNN')
    return scores, pathes
def _convert_to_tensor(inputs):
    from ops import concat_op
    return concat_op([tf.expand_dims(i, 1) for i in inputs], axis=1)
def build_critic(glb_ctx, im, ans, quest, quest_len, vocab_size, num_cells,
                 pad_token, rewards, xe_mask):
    # process inputs
    inputs, targets, length = build_caption_inputs_and_targets(
        quest, quest_len)

    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)

    # init state / image embedding
    ctx = concat_op([glb_ctx, im, ans], axis=1)
    with tf.variable_scope('critic'):
        init_h = slim.fully_connected(ctx,
                                      num_cells,
                                      activation_fn=tf.nn.tanh,
                                      scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_state = LSTMStateTuple(init_c, init_h)

    # word embedding
    with tf.variable_scope('inverse_vqa', reuse=True):
        with tf.variable_scope('word_embedding'):
            word_map = tf.get_variable(
                name="word_map",
                shape=[vocab_size, num_cells],
                initializer=tf.random_uniform_initializer(-0.08,
                                                          0.08,
                                                          dtype=tf.float32))
    inputs = tf.nn.embedding_lookup(word_map, inputs)
    inputs = tf.stop_gradient(inputs)

    # build LSTM cell and RNN
    lstm = BasicLSTMCell(num_cells)

    with tf.variable_scope('critic'):
        outputs, _ = tf.nn.dynamic_rnn(lstm,
                                       inputs,
                                       length,
                                       initial_state=init_state,
                                       dtype=tf.float32,
                                       scope='rnn')

        # compute critic
        values = slim.fully_connected(outputs,
                                      1,
                                      activation_fn=None,
                                      scope='value')
        values = tf.reshape(values, [-1])

    # build the loss mask: only real (non-pad) tokens in the RL portion count
    targets = tf.reshape(targets, [-1])
    valid_mask = tf.not_equal(targets, pad_token)
    rl_mask = tf.logical_not(tf.reshape(xe_mask, [-1]))
    critic_mask = tf.logical_and(rl_mask, valid_mask)
    critic_mask = tf.cast(critic_mask, tf.float32)
    rewards = tf.reshape(rewards, [-1])

    # masked mean-squared error between predicted values and rewards
    critic_loss = tf.div(
        tf.reduce_sum(tf.square(values - rewards) * critic_mask * 0.5),
        tf.reduce_sum(critic_mask),
        name='critic_loss')
    slim.losses.add_loss(critic_loss)
    return values
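The critic loss above is a masked mean-squared error: a position contributes only if it holds a real token (not padding) and lies in the RL portion of the sequence (the complement of `xe_mask`). A small numeric check of that masking with numpy (pad_token = 0 and all values assumed for illustration):

import numpy as np

targets = np.array([5, 7, 0, 9])                  # flattened targets, 0 = pad
xe_mask = np.array([True, False, False, False])   # first step trained with XE
values = np.array([0.2, 0.8, 0.0, 0.4])
rewards = np.array([1.0, 1.0, 0.0, 0.0])

valid_mask = targets != 0
rl_mask = ~xe_mask
critic_mask = (valid_mask & rl_mask).astype(np.float32)   # [0., 1., 0., 1.]

critic_loss = np.sum(0.5 * (values - rewards) ** 2 * critic_mask) / np.sum(critic_mask)
print(critic_mask, critic_loss)   # 0.05, averaged over the two RL tokens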
    def merge_batch_beam(self, tensor):
        remaining_shape = tf.shape(tensor)[2:]
        res = tf.reshape(tensor, concat_op([[-1], remaining_shape], axis=0))
        res.set_shape(
            tf.TensorShape((None, )).concatenate(tensor.get_shape()[2:]))
        return res
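`merge_batch_beam` and `unmerge_batch_beam` (shown earlier) are inverse reshapes between `[batch, beam, ...]` and `[batch * beam, ...]`: the beam dimension is folded into the batch so an ordinary RNN cell can advance all beams in one call, and unfolded again when per-beam bookkeeping is needed. A shape-only illustration with beam_size = 3:

import tensorflow as tf

beam_size = 3
x = tf.zeros([2, beam_size, 5])                    # [batch, beam, depth]

merged = tf.reshape(x, [-1, 5])                    # [batch * beam, depth]
unmerged = tf.reshape(merged, [-1, beam_size, 5])  # back to [batch, beam, depth]

print(merged.get_shape().as_list(), unmerged.get_shape().as_list())
# [6, 5] [2, 3, 5]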
    def beam_loop(self, time, cell_output, cell_state, loop_state):
        (
            past_cand_symbols,  # [batch_size, time-1]
            past_cand_logprobs,  # [batch_size]
            past_beam_symbols,  # [batch_size*beam_size, time-1], right-aligned
            past_beam_logprobs,  # [batch_size*beam_size]
        ) = loop_state

        # We don't actually use this, but emit_output is required to match the
        # cell output size specification. Otherwise we would leave this as None.
        emit_output = cell_output

        # 1. Get scores for all candidate sequences

        logprobs = self.outputs_to_score_fn(cell_output)

        try:
            num_classes = int(logprobs.get_shape()[-1])
        except:
            # Shape inference failed
            num_classes = tf.shape(logprobs)[-1]

        logprobs_batched = tf.reshape(
            logprobs + tf.expand_dims(
                tf.reshape(past_beam_logprobs,
                           [self.batch_size, self.beam_size]), 2),
            [self.batch_size, self.beam_size * num_classes])

        # 2. Determine which states to pass to next iteration

        # TODO(nikita): consider using slice+fill+concat instead of adding a mask
        nondone_mask = tf.reshape(
            tf.cast(tf.equal(tf.range(num_classes), self.stop_token),
                    tf.float32) * self.INVALID_SCORE,
            [1, 1, num_classes])  # disable the stop token slice

        nondone_mask = tf.reshape(
            tf.tile(nondone_mask, [1, self.beam_size, 1]),
            [-1, self.beam_size * num_classes])

        beam_logprobs, indices = tf.nn.top_k(logprobs_batched + nondone_mask,
                                             self.beam_size)
        min_beam_logprobs = tf.reduce_min(beam_logprobs, 1)
        beam_logprobs = tf.reshape(beam_logprobs, [-1])

        # For continuing to the next symbols
        symbols = indices % num_classes  # [batch_size, self.beam_size]
        parent_refs = indices // num_classes  # [batch_size, self.beam_size]

        symbols_history = flat_batch_gather(past_beam_symbols,
                                            parent_refs,
                                            batch_size=self.batch_size,
                                            options_size=self.beam_size)
        beam_symbols = concat_op(
            [symbols_history, tf.reshape(symbols, [-1, 1])], 1)

        # Handle the output and the cell state shuffling
        next_cell_state = nest_map(
            lambda element: batch_gather(element,
                                         parent_refs,
                                         batch_size=self.batch_size,
                                         options_size=self.beam_size),
            cell_state)

        next_input = self.tokens_to_inputs_fn(
            tf.reshape(symbols, [-1, self.beam_size]))

        # 3. Update the candidate pool to include entries that just ended with a stop token
        logprobs_done = tf.reshape(
            logprobs_batched,
            [-1, self.beam_size, num_classes])[:, :, self.stop_token]
        done_parent_refs = tf.argmax(logprobs_done, 1)
        done_symbols = flat_batch_gather(past_beam_symbols,
                                         done_parent_refs,
                                         batch_size=self.batch_size,
                                         options_size=self.beam_size)

        logprobs_done_max = tf.reduce_max(logprobs_done, 1)

        # Make sure the end token scores higher than all the top K partial captions
        update_cond = tf.logical_and(
            tf.greater(logprobs_done_max,
                       past_cand_logprobs),  # larger than previous
            tf.greater(logprobs_done_max, min_beam_logprobs))  # in top K
        cand_symbols_unpadded = select_op(
            update_cond,
            done_symbols,  # current estimate
            past_cand_symbols)
        cand_logprobs = select_op(update_cond, logprobs_done_max,
                                  past_cand_logprobs)

        # cand_symbols_unpadded = tf.select(logprobs_done_max>past_cand_logprobs,
        #                                   done_symbols,  # current estimate
        #                                   past_cand_symbols)
        # cand_logprobs = tf.maximum(logprobs_done_max, past_cand_logprobs)

        cand_symbols = concat_op([
            cand_symbols_unpadded,
            tf.fill([self.batch_size, 1], self.stop_token)
        ], 1)

        # 4. Check the stopping criteria

        if self.max_len is not None:
            elements_finished_clip = (time >= self.max_len)

        if self.score_upper_bound is not None:
            elements_finished_bound = tf.reduce_max(
                tf.reshape(beam_logprobs, [-1, self.beam_size]),
                1) < (cand_logprobs - self.score_upper_bound)

        if self.max_len is not None and self.score_upper_bound is not None:
            elements_finished = elements_finished_clip | elements_finished_bound
        elif self.score_upper_bound is not None:
            elements_finished = elements_finished_bound
        elif self.max_len is not None:
            # this broadcasts elements_finished_clip to the correct shape
            elements_finished = tf.zeros(
                [self.batch_size], dtype=tf.bool) | elements_finished_clip
        else:
            assert False, "Lack of stopping criterion should have been caught in constructor"

        # 5. Prepare return values
        # While loops require strict shape invariants, so we manually set shapes
        # in case the automatic shape inference can't calculate these. Even when
        # this is redundant, it has the benefit of helping catch shape bugs.

        for tensor in list(nest.flatten(next_input)) + list(
                nest.flatten(next_cell_state)):
            tensor.set_shape(
                tf.TensorShape(
                    (self.inferred_batch_size,
                     self.beam_size)).concatenate(tensor.get_shape()[2:]))

        for tensor in [cand_symbols, cand_logprobs, elements_finished]:
            tensor.set_shape(
                tf.TensorShape((self.inferred_batch_size, )).concatenate(
                    tensor.get_shape()[1:]))

        for tensor in [beam_symbols, beam_logprobs]:
            tensor.set_shape(
                tf.TensorShape(
                    (self.inferred_batch_size_times_beam_size, )).concatenate(
                        tensor.get_shape()[1:]))

        next_loop_state = (
            cand_symbols,
            cand_logprobs,
            beam_symbols,
            beam_logprobs,
        )

        return (elements_finished, next_input, next_cell_state, emit_output,
                next_loop_state)
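In step 2 of `beam_loop`, the per-example `[beam_size, num_classes]` score grid is flattened so a single `top_k` picks the best continuations across all beams at once; each flat index then decomposes into the chosen word (`index % num_classes`) and the beam it extends (`index // num_classes`). A small numpy check of that decomposition:

import numpy as np

beam_size, num_classes = 2, 5
# scores for one example, laid out beam-major: [beam 0 classes..., beam 1 classes...]
scores = np.array([0.1, 0.0, 0.3, 0.0, 0.0,    # beam 0
                   0.0, 0.9, 0.0, 0.0, 0.2])   # beam 1

flat_idx = np.argsort(scores)[::-1][:beam_size]   # top-k over the flat grid
symbols = flat_idx % num_classes                  # which word was chosen
parent_refs = flat_idx // num_classes             # which beam it came from

print(flat_idx, symbols, parent_refs)
# [6 2] -> symbols [1 2], parent beams [1 0]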
    def decode_sparse(self, include_stop_tokens=True):
        dense_symbols, logprobs = self.decode_dense()
        mask = tf.not_equal(dense_symbols, self.stop_token)
        if include_stop_tokens:
            mask = concat_op([tf.ones_like(mask[:, :1]), mask[:, :-1]], 1)
        return sparse_boolean_mask(dense_symbols, mask), logprobs
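In `decode_sparse`, shifting the mask right by one column keeps exactly one stop token per sequence when `include_stop_tokens` is true: the slot holding the first stop token inherits the (true) mask of the preceding real symbol, while every later stop token stays masked out. A tiny numpy illustration (stop token 0 assumed):

import numpy as np

stop_token = 0
dense = np.array([[7, 4, 0, 0],
                  [3, 0, 0, 0]])

mask = dense != stop_token
# prepend a column of True and drop the last column: shift right by one
shifted = np.concatenate([np.ones_like(mask[:, :1]), mask[:, :-1]], axis=1)

print(shifted)
# [[ True  True  True False]
#  [ True  True False False]]  -> keeps only the first stop token of each row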