def _build_beamsearch_inference_decoder(glb_ctx, im, ans, inputs, vocab_size,
                                        num_cells, pad_token):
    vocab_size = max(vocab_size, pad_token + 1)

    # =================== create cells ========================
    lstm = BasicLSTMCell(num_cells)
    attention_cell = MultiModalAttentionCell(512, im, ans, keep_prob=1.0)
    multi_cell = MultiRNNCell([lstm, attention_cell], state_is_tuple=True)

    lstm_state_sizes, att_state_size = multi_cell.state_size
    lstm_state_size = sum(lstm_state_sizes)
    state_size = lstm_state_size + att_state_size

    # ============= create state placeholders ================
    state_feed = tf.placeholder(dtype=tf.float32, shape=[None, state_size],
                                name='state_feed')
    lstm_state_feed = tf.slice(state_feed, begin=[0, 0],
                               size=[-1, lstm_state_size])
    att_state_feed = tf.slice(state_feed, begin=[0, lstm_state_size],
                              size=[-1, att_state_size])
    feed_c, feed_h = split_op(lstm_state_feed, num_splits=2, axis=1)
    state_tuple = LSTMStateTuple(feed_c, feed_h)
    multi_cell_state_feed = (state_tuple, att_state_feed)

    # ==================== create init state ==================
    # lstm init state
    init_h = slim.fully_connected(glb_ctx, num_cells,
                                  activation_fn=tf.nn.tanh, scope='init_h')
    init_c = tf.zeros_like(init_h)

    batch_size = tf.shape(init_h)[0]
    init_att = tf.zeros([batch_size, num_cells], dtype=tf.float32)
    init_state = (init_c, init_h, init_att)
    concat_op(init_state, axis=1, name="initial_state")  # need to fetch

    # ==================== forward pass ========================
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(
            name="word_map",
            shape=[vocab_size, num_cells],
            initializer=tf.random_uniform_initializer(-0.08, 0.08,
                                                      dtype=tf.float32))
        word_embed = tf.nn.embedding_lookup(word_map, inputs)
    with tf.variable_scope('RNN'):
        outputs, multi_state = multi_cell(
            tf.squeeze(word_embed, squeeze_dims=[1]),
            state=multi_cell_state_feed)

    # ==================== concat states =========================
    lstm_state, att_state = multi_state
    state_c, state_h = lstm_state
    concat_op((state_c, state_h, att_state), axis=1, name="state")  # need to fetch

    # predict next word
    outputs = tf.reshape(outputs, [-1, num_cells])
    logits = slim.fully_connected(outputs, vocab_size, activation_fn=None,
                                  scope='logits')
    prob = tf.nn.softmax(logits, name="softmax")  # need to fetch
    return prob
def _build_beamsearch_inference_decoder(in_embed, inputs, vocab_size,
                                        num_cells, pad_token):
    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)

    # init state / image embedding
    init_h = slim.fully_connected(in_embed, num_cells,
                                  activation_fn=tf.nn.tanh, scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_state = LSTMStateTuple(init_c, init_h)

    # build LSTM cell and RNN
    lstm_cell = BasicLSTMCell(num_cells)
    concat_op(init_state, axis=1, name="initial_state")  # need to fetch

    # word embedding
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(
            name="word_map",
            shape=[vocab_size, num_cells],
            initializer=tf.random_uniform_initializer(-0.08, 0.08,
                                                      dtype=tf.float32))
        word_embed = tf.nn.embedding_lookup(word_map, inputs)

    # Placeholder for feeding a batch of concatenated states.
    state_feed = tf.placeholder(dtype=tf.float32,
                                shape=[None, sum(lstm_cell.state_size)],
                                name="state_feed")
    feed_c, feed_h = split_op(state_feed, num_splits=2, axis=1)
    state_tuple = LSTMStateTuple(feed_c, feed_h)

    # Run a single LSTM step.
    with tf.variable_scope('RNN'):
        outputs, state_tuple = lstm_cell(
            inputs=tf.squeeze(word_embed, squeeze_dims=[1]),
            state=state_tuple)

    # Concatenate the resulting state.
    state = concat_op(state_tuple, 1, name="state")  # need to fetch

    # Stack batches vertically.
    outputs = tf.reshape(outputs, [-1, lstm_cell.output_size])
    logits = slim.fully_connected(outputs, vocab_size, activation_fn=None,
                                  scope='logits')
    prob = tf.nn.softmax(logits, name="softmax")  # need to fetch
    return state
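# Illustrative driver for the single-step decoders above (not part of the
# original code). Decoding runs outside the graph: fetch "initial_state"
# once, then repeatedly feed the previous state into "state_feed" and fetch
# the tensors named "softmax" and "state". The placeholder name
# "input_feed:0" for the word-id input is an assumption; the `inputs`
# argument is created outside these functions.
def _decode_step(sess, word_ids, state):
    # word_ids: [batch, 1] int token ids; state: [batch, state_size] from
    # the previous step (or the fetched "initial_state:0").
    prob, next_state = sess.run(
        ['softmax:0', 'state:0'],
        feed_dict={'input_feed:0': word_ids, 'state_feed:0': state})
    return prob, next_state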
def build_decoder(fused, noise, ans, ans_len, vocab_size, keep_prob,
                  pad_token, num_dec_cells, phase='train'):
    # fuse the (question, image) context with the noise vector
    in_embed = concat_op([fused, noise], axis=1)
    with tf.variable_scope('var_vqa'):
        if phase == 'train' or phase == 'condition':
            add_loss = phase == 'train'
            inputs, targets, length = build_caption_inputs_and_targets(
                ans, ans_len)
            return _build_training_decoder(in_embed, inputs, length, targets,
                                           vocab_size, num_dec_cells,
                                           keep_prob, pad_token, add_loss)
        elif phase == 'greedy':
            return _build_greedy_inference_decoder(in_embed, vocab_size,
                                                   num_dec_cells,
                                                   _START_TOKEN_ID)
        elif phase == 'beam' or phase == 'sampling':
            return _build_tf_beam_inference_decoder(in_embed, vocab_size,
                                                    num_dec_cells,
                                                    _START_TOKEN_ID)
        else:
            # this decoder generates answers, so feed `ans` (the original
            # passed an undefined `quest` here)
            return _build_beamsearch_inference_decoder(in_embed, ans,
                                                       vocab_size,
                                                       num_dec_cells,
                                                       pad_token)
def build_decoding_attention_vaq_model(im, attr, ans_embed, quest, quest_len,
                                       vocab_size, keep_prob, pad_token,
                                       num_dec_cells, phase='train',
                                       cell_option=1):
    if attr is None:
        in_embed = ans_embed
    else:
        in_embed = concat_op(values=[attr, ans_embed], axis=1)
    with tf.variable_scope('vaq'):
        if phase == 'train' or phase == 'condition' or phase == 'evaluate':
            inputs, targets, length = _build_caption_inputs_and_targets(
                quest, quest_len)
            return build_attention_decoder(in_embed, im, ans_embed, inputs,
                                           length, targets, vocab_size,
                                           num_dec_cells, keep_prob,
                                           pad_token, cell_option)
        elif phase == 'greedy':
            return build_attention_greedy_decoder(in_embed, im, ans_embed,
                                                  vocab_size, num_dec_cells,
                                                  pad_token, cell_option)
        else:
            return build_vaq_dec_attention_predictor(in_embed, im, ans_embed,
                                                     quest, vocab_size,
                                                     num_dec_cells, pad_token,
                                                     cell_option)
def build_decoder(im, ans_embed, quest, quest_len, vocab_size, keep_prob,
                  pad_token, num_dec_cells, phase='train'):
    # concatenate the (pooled) image feature with the answer embedding
    in_embed = concat_op(values=[im, ans_embed], axis=1)
    with tf.variable_scope('vaq'):
        if phase == 'train' or phase == 'condition':
            inputs, targets, length = build_caption_inputs_and_targets(
                quest, quest_len)
            return _build_training_decoder(in_embed, inputs, length, targets,
                                           vocab_size, num_dec_cells,
                                           keep_prob, pad_token)
        elif phase == 'greedy':
            return _build_greedy_inference_decoder(in_embed, vocab_size,
                                                   num_dec_cells,
                                                   _START_TOKEN_ID)
        elif phase == 'sampling':
            return _build_random_inference_decoder(in_embed, vocab_size,
                                                   num_dec_cells,
                                                   _START_TOKEN_ID)
        elif phase == 'beam':
            return _build_tf_beam_inference_decoder(in_embed, vocab_size,
                                                    num_dec_cells,
                                                    _START_TOKEN_ID)
        else:
            return _build_beamsearch_inference_decoder(in_embed, quest,
                                                       vocab_size,
                                                       num_dec_cells,
                                                       pad_token)
def build_decoder(im, attr, ans_embed, quest, quest_len, vocab_size,
                  keep_prob, pad_token, num_dec_cells, phase='train'):
    if attr is None:
        in_embed = ans_embed
    else:
        in_embed = concat_op(values=[attr, ans_embed], axis=1)
    with tf.variable_scope('inverse_vqa'):
        if phase == 'train' or phase == 'condition' or phase == 'evaluate':
            inputs, targets, length = build_caption_inputs_and_targets(
                quest, quest_len)
            return _build_training_decoder(in_embed, im, ans_embed, inputs,
                                           length, targets, vocab_size,
                                           num_dec_cells, keep_prob,
                                           pad_token)
        elif phase == 'greedy':
            return _build_greedy_inference_decoder(in_embed, im, ans_embed,
                                                   vocab_size, num_dec_cells,
                                                   VOCAB_CONFIG.start_token_id)
        elif phase == 'beam':
            return _build_tf_beam_inference_decoder(in_embed, im, ans_embed,
                                                    vocab_size, num_dec_cells,
                                                    VOCAB_CONFIG.start_token_id,
                                                    pad_token)
        else:
            return _build_beamsearch_inference_decoder(in_embed, im,
                                                       ans_embed, quest,
                                                       vocab_size,
                                                       num_dec_cells,
                                                       pad_token)
def unmerge_batch_beam(self, tensor):
    remaining_shape = tf.shape(tensor)[1:]
    res = tf.reshape(tensor,
                     concat_op([[-1, self.beam_size], remaining_shape], 0))
    res.set_shape(
        tf.TensorShape((None, self.beam_size)).concatenate(
            tensor.get_shape()[1:]))
    return res
def batch_gather(params, indices, validate_indices=None, batch_size=None,
                 options_size=None):
    """Gather slices from `params` according to `indices`, separately for
    each example in a batch.

        output[b, i, ..., j, :, ..., :] = params[b, indices[b, i, ..., j], :, ..., :]

    The arguments `batch_size` and `options_size`, if provided, are used
    instead of looking up the shape from the inputs. This may help avoid
    redundant computation (TODO: figure out if tensorflow's optimizer can do
    this automatically)

    Args:
        params: A `Tensor`, [batch_size, options_size, ...]
        indices: A `Tensor`, [batch_size, ...]
        validate_indices: An optional `bool`. Defaults to `True`
        batch_size: (optional) an integer or scalar tensor representing the
            batch size
        options_size: (optional) an integer or scalar Tensor representing
            the number of options to choose from
    """
    if batch_size is None:
        batch_size = params.get_shape()[0].merge_with(
            indices.get_shape()[0]).value
        if batch_size is None:
            batch_size = tf.shape(indices)[0]
    if options_size is None:
        options_size = params.get_shape()[1].value
        if options_size is None:
            options_size = tf.shape(params)[1]
    batch_size_times_options_size = batch_size * options_size

    # TODO(nikita): consider using gather_nd. However as of 1/9/2017
    # gather_nd has no gradients implemented.
    flat_params = tf.reshape(
        params,
        concat_op([[batch_size_times_options_size], tf.shape(params)[2:]],
                  axis=0))

    indices_offsets = tf.reshape(
        tf.range(batch_size) * options_size,
        [-1] + [1] * (len(indices.get_shape()) - 1))
    indices_into_flat = indices + tf.cast(indices_offsets, indices.dtype)

    return tf.gather(flat_params, indices_into_flat,
                     validate_indices=validate_indices)
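# Worked example for batch_gather (illustrative, not from the original
# repo): select one option row per batch element, the same way beam_loop
# below reorders beams by their parent indices. With params [2, 3, 2] and
# indices [2, 2], output[b, i] == params[b, indices[b, i]].
def _batch_gather_demo():
    params = tf.constant([[[0., 0.], [1., 1.], [2., 2.]],
                          [[3., 3.], [4., 4.], [5., 5.]]])  # [2, 3, 2]
    indices = tf.constant([[2, 0], [1, 1]])                 # [2, 2]
    # expected: [[[2., 2.], [0., 0.]], [[4., 4.], [4., 4.]]]
    return batch_gather(params, indices)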
def show_attend_tell_attention_helper(im, part_q, keep_prob, scope=""):
    scope = scope or "ShowAttendTellCell"
    _, h, w, c = im.get_shape().as_list()
    with tf.variable_scope(scope):
        # concat im and part q
        part_q = tf.expand_dims(tf.expand_dims(part_q, 1), 2)
        part_q_tile = tf.tile(part_q, [1, h, w, 1])
        im_pq = concat_op([im, part_q_tile], axis=3)
        im_ctx = slim.conv2d(im_pq, 512, 1, activation_fn=tf.nn.tanh,
                             scope='vq_fusion')
        im_ctx = slim.dropout(im_ctx, keep_prob=keep_prob)
        v, _ = _soft_attention_pool_with_map(im, im_ctx)
    return v
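# Shape walkthrough for the helper above (illustrative): im is [B, h, w, c]
# and part_q is [B, d]; part_q is broadcast to every spatial location via
# expand_dims + tile, fused with im by a 1x1 conv ('vq_fusion'), and the
# fused map drives _soft_attention_pool_with_map, which pools im back down
# to a single [B, c] context vector v.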
def build_decoder(im, attr, ans_embed, quest, quest_len, rewards, advantage,
                  vocab_size, keep_prob, pad_token, num_dec_cells,
                  phase='train', reuse=False, xe_mask=None, T=4):
    if attr is None:
        in_embed = ans_embed
    else:
        in_embed = concat_op(values=[attr, ans_embed], axis=1)
    with tf.variable_scope('inverse_vqa', reuse=reuse):
        if phase == 'train' or phase == 'condition' or phase == 'evaluate':
            inputs, targets, length = build_caption_inputs_and_targets(
                quest, quest_len)
            return _build_training_decoder(in_embed, im, ans_embed, inputs,
                                           length, targets, vocab_size,
                                           num_dec_cells, keep_prob,
                                           pad_token, rewards, advantage,
                                           xe_mask)
        elif phase == 'random':
            return _build_random_inference_decoder(
                in_embed, im, ans_embed, vocab_size, num_dec_cells,
                VOCAB_CONFIG.start_token_id, pad_token)
        elif phase == 'mixer':
            return _build_mixer_inference_decoder(in_embed, im, ans_embed,
                                                  vocab_size, num_dec_cells,
                                                  quest, quest_len,
                                                  pad_token, T)
        elif phase == 'beam':
            return _build_tf_beam_inference_decoder(
                in_embed, im, ans_embed, vocab_size, num_dec_cells,
                VOCAB_CONFIG.start_token_id, pad_token)
        else:
            raise Exception('unknown option')
def __call__(self, inputs, state, scope=None): """Attention cell with answer context.""" with tf.variable_scope(scope or type(self).__name__): if self._state_is_tuple: _, h = state else: _, h = split_op(values=state, num_splits=2, axis=1) with tf.variable_scope('Attention'): v = show_attend_tell_attention_helper(self._context, h, self._keep_prob) lstm_input = concat_op([v, inputs], axis=1) lstm_h, next_state = self._lstm_cell(lstm_input, state=state) # concat outputs h_ctx_reduct = concat_fusion(v, lstm_h, self._num_units, act_fn=None) output = h_ctx_reduct + inputs return output, next_state
def build_answer_basis(self):
    ans_vocab_size = VOCAB_CONFIG.answer_vocab_size
    enc_lstm = create_dropout_lstm_cells(256, self._keep_prob,
                                         self._keep_prob)
    # create word embedding
    with tf.variable_scope('answer_embed'):
        ans_embed_map = tf.get_variable(
            name='word_map',
            shape=[ans_vocab_size, self._word_embed_dim],
            initializer=get_default_initializer())
        ans_word_embed = tf.nn.embedding_lookup(ans_embed_map, self._answers)
    _, states = tf.nn.dynamic_rnn(enc_lstm, ans_word_embed, self._ans_len,
                                  dtype=tf.float32, scope='AnswerEncoder')
    # self.debug_ops.append(ans_word_embed)
    # concatenate the final (c, h) state tuple into one embedding
    self._answer_embed = concat_op(values=states, axis=1)
def __call__(self, inputs, state, scope=None): """Run this multi-layer cell on inputs, starting from state.""" with tf.variable_scope(scope or "multi_rnn_cell"): cur_state_pos = 0 cur_inp = inputs new_states = [] for i, cell in enumerate(self._cells): with tf.variable_scope("cell_%d" % i): if self._state_is_tuple: if not nest.is_sequence(state): raise ValueError( "Expected state to be a tuple of length %d, but received: %s" % (len(self.state_size), state)) cur_state = state[i] else: cur_state = tf.slice( state, [0, cur_state_pos], [-1, cell.state_size]) cur_state_pos += cell.state_size cur_inp, new_state = cell(cur_inp, cur_state) new_states.append(new_state) new_states = (tuple(new_states) if self._state_is_tuple else concat_op(new_states, 1)) return cur_inp, new_states
def build_decoder(im, conditions, quest, quest_len, vocab_size, keep_prob,
                  pad_token, num_dec_cells, phase='train'):
    answer_embed, noise = conditions
    # mask the noise vector with a tanh projection of the answer embedding
    answer_reduct = slim.fully_connected(answer_embed, num_dec_cells,
                                         activation_fn=tf.nn.tanh,
                                         scope='answer_mask')
    z_embed = answer_reduct * noise
    in_embed = concat_op(values=[im, answer_embed], axis=1)
    with tf.variable_scope('vaq'):
        if phase == 'train' or phase == 'condition':
            inputs, targets, length = build_caption_inputs_and_targets(
                quest, quest_len)
            return _build_training_decoder(in_embed, z_embed, inputs, length,
                                           targets, vocab_size,
                                           num_dec_cells, keep_prob,
                                           pad_token)
        elif phase == 'greedy':
            return _build_greedy_inference_decoder(in_embed, vocab_size,
                                                   num_dec_cells,
                                                   _START_TOKEN_ID)
        elif phase == 'beam' or phase == 'sampling':
            return _build_tf_beam_inference_decoder(in_embed, z_embed,
                                                    vocab_size,
                                                    num_dec_cells,
                                                    _START_TOKEN_ID)
        else:
            return _build_beamsearch_inference_decoder(in_embed, quest,
                                                       vocab_size,
                                                       num_dec_cells,
                                                       pad_token)
def _build_tf_beam_inference_decoder(glb_ctx, im, ans_embed, vocab_size,
                                     num_cells, start_token_id, pad_token):
    beam_size = 3
    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)

    # init state / image embedding
    init_h = slim.fully_connected(glb_ctx, num_cells,
                                  activation_fn=tf.nn.tanh, scope='init_h')
    init_c = slim.fully_connected(glb_ctx, num_cells,
                                  activation_fn=tf.nn.tanh, scope='init_c')
    init_state = concat_op([init_c, init_h], axis=1)

    # replicate context of the attention module: every beam hypothesis gets
    # its own copy of the image
    im_shape = im.get_shape().as_list()[1:]
    im = tf.expand_dims(im, 1)  # add a beam dim
    im = tf.reshape(tf.tile(im, [1, beam_size, 1, 1, 1]), [-1] + im_shape)
    multi_cell = ShowAttendTellCell(num_cells, im, state_is_tuple=False)

    # word embedding
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(
            name="word_map",
            shape=[vocab_size, num_cells],
            initializer=tf.random_uniform_initializer(-0.08, 0.08,
                                                      dtype=tf.float32))

    # apply weights for outputs
    with tf.variable_scope('logits'):
        weights = tf.get_variable('weights', shape=[num_cells, vocab_size],
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', shape=[vocab_size])

    # define helper functions
    def _tokens_to_inputs_fn(inputs):
        inputs = tf.nn.embedding_lookup(word_map, inputs)
        # inputs = tf.squeeze(inputs, [1])
        return inputs

    def _output_to_score_fn(hidden):
        batch_size = tf.shape(hidden)[0]
        beam_size = tf.shape(hidden)[1]
        hidden = tf.reshape(hidden, [batch_size * beam_size, -1])
        logits = tf.nn.xw_plus_b(hidden, weights, biases)
        logprob = tf.nn.log_softmax(logits)
        return tf.reshape(logprob, [batch_size, beam_size, -1])

    stop_token_id = VOCAB_CONFIG.end_token_id
    batch_size = tf.shape(glb_ctx)[0]
    start_tokens = tf.ones(shape=[batch_size],
                           dtype=tf.int32) * start_token_id
    init_inputs = _tokens_to_inputs_fn(start_tokens)
    paths, scores = beam_decoder(multi_cell, beam_size=beam_size,
                                 stop_token=stop_token_id,
                                 initial_state=init_state,
                                 initial_input=init_inputs,
                                 tokens_to_inputs_fn=_tokens_to_inputs_fn,
                                 outputs_to_score_fn=_output_to_score_fn,
                                 max_len=20, cell_transform='flatten',
                                 output_dense=True, scope='RNN')
    return scores, paths
def _convert_to_tensor(inputs):
    from ops import concat_op
    return concat_op([tf.expand_dims(i, 1) for i in inputs], axis=1)
def build_critic(glb_ctx, im, ans, quest, quest_len, vocab_size, num_cells,
                 pad_token, rewards, xe_mask):
    # process inputs
    inputs, targets, length = build_caption_inputs_and_targets(
        quest, quest_len)

    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)

    # init state / image embedding
    ctx = concat_op([glb_ctx, im, ans], axis=1)
    with tf.variable_scope('critic'):
        init_h = slim.fully_connected(ctx, num_cells,
                                      activation_fn=tf.nn.tanh,
                                      scope='init_h')
        init_c = tf.zeros_like(init_h)
        init_state = LSTMStateTuple(init_c, init_h)

    # word embedding, reused from the generator and frozen for the critic
    with tf.variable_scope('inverse_vqa', reuse=True):
        with tf.variable_scope('word_embedding'):
            word_map = tf.get_variable(
                name="word_map",
                shape=[vocab_size, num_cells],
                initializer=tf.random_uniform_initializer(-0.08, 0.08,
                                                          dtype=tf.float32))
            inputs = tf.nn.embedding_lookup(word_map, inputs)
            inputs = tf.stop_gradient(inputs)

    # build LSTM cell and RNN
    lstm = BasicLSTMCell(num_cells)
    with tf.variable_scope('critic'):
        outputs, _ = tf.nn.dynamic_rnn(lstm, inputs, length,
                                       initial_state=init_state,
                                       dtype=tf.float32, scope='rnn')
        # compute critic values
        values = slim.fully_connected(outputs, 1, activation_fn=None,
                                      scope='value')
        values = tf.reshape(values, [-1])

    # build the loss mask: only valid, non-cross-entropy tokens contribute
    targets = tf.reshape(targets, [-1])
    valid_mask = tf.not_equal(targets, pad_token)
    rl_mask = tf.logical_not(tf.reshape(xe_mask, [-1]))
    critic_mask = tf.logical_and(rl_mask, valid_mask)
    critic_mask = tf.cast(critic_mask, tf.float32)
    rewards = tf.reshape(rewards, [-1])

    # compute loss
    critic_loss = tf.div(
        tf.reduce_sum(tf.square(values - rewards) * critic_mask * 0.5),
        tf.reduce_sum(critic_mask),
        name='critic_loss')
    slim.losses.add_loss(critic_loss)
    return values
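# Illustrative reading of the loss above: a masked mean of squared errors,
#   critic_loss = sum_t 0.5 * (values_t - rewards_t)^2 * mask_t / sum_t mask_t,
# where mask_t keeps only tokens that are (a) not padding and (b) generated
# in the RL portion of the sequence (xe_mask marks the cross-entropy steps).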
def merge_batch_beam(self, tensor):
    remaining_shape = tf.shape(tensor)[2:]
    res = tf.reshape(tensor, concat_op([[-1], remaining_shape], axis=0))
    res.set_shape(
        tf.TensorShape((None,)).concatenate(tensor.get_shape()[2:]))
    return res
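# Illustrative round-trip for the two reshape helpers above, assuming
# beam_size = 3: merge_batch_beam flattens [batch, beam, ...] to
# [batch*beam, ...] so per-beam tensors can pass through ops that only
# understand a flat batch axis, and unmerge_batch_beam restores the beam
# axis afterwards.
def _merge_unmerge_demo():
    beam_size = 3
    t = tf.zeros([2, beam_size, 5])              # [batch, beam, d]
    flat = tf.reshape(t, [-1, 5])                # what merge_batch_beam does
    back = tf.reshape(flat, [-1, beam_size, 5])  # what unmerge_batch_beam does
    return back  # same values and shape as t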
def beam_loop(self, time, cell_output, cell_state, loop_state):
    (
        past_cand_symbols,   # [batch_size, time-1]
        past_cand_logprobs,  # [batch_size]
        past_beam_symbols,   # [batch_size*beam_size, time-1], right-aligned
        past_beam_logprobs,  # [batch_size*beam_size]
    ) = loop_state

    # We don't actually use this, but emit_output is required to match the
    # cell output size specification. Otherwise we would leave this as None.
    emit_output = cell_output

    # 1. Get scores for all candidate sequences
    logprobs = self.outputs_to_score_fn(cell_output)
    try:
        num_classes = int(logprobs.get_shape()[-1])
    except:
        # Shape inference failed
        num_classes = tf.shape(logprobs)[-1]

    logprobs_batched = tf.reshape(
        logprobs + tf.expand_dims(
            tf.reshape(past_beam_logprobs,
                       [self.batch_size, self.beam_size]), 2),
        [self.batch_size, self.beam_size * num_classes])

    # 2. Determine which states to pass to next iteration
    # TODO(nikita): consider using slice+fill+concat instead of adding a mask
    nondone_mask = tf.reshape(
        tf.cast(tf.equal(tf.range(num_classes), self.stop_token),
                tf.float32) * self.INVALID_SCORE,
        [1, 1, num_classes])  # disable the stop token slice
    nondone_mask = tf.reshape(tf.tile(nondone_mask, [1, self.beam_size, 1]),
                              [-1, self.beam_size * num_classes])

    beam_logprobs, indices = tf.nn.top_k(logprobs_batched + nondone_mask,
                                         self.beam_size)
    min_beam_logprobs = tf.reduce_min(beam_logprobs, 1)
    beam_logprobs = tf.reshape(beam_logprobs, [-1])

    # For continuing to the next symbols
    symbols = indices % num_classes       # [batch_size, self.beam_size]
    parent_refs = indices // num_classes  # [batch_size, self.beam_size]

    symbols_history = flat_batch_gather(past_beam_symbols, parent_refs,
                                        batch_size=self.batch_size,
                                        options_size=self.beam_size)
    beam_symbols = concat_op([symbols_history,
                              tf.reshape(symbols, [-1, 1])], 1)

    # Handle the output and the cell state shuffling
    next_cell_state = nest_map(
        lambda element: batch_gather(element, parent_refs,
                                     batch_size=self.batch_size,
                                     options_size=self.beam_size),
        cell_state)
    next_input = self.tokens_to_inputs_fn(
        tf.reshape(symbols, [-1, self.beam_size]))

    # 3. Update the candidate pool to include entries that just ended with
    # a stop token
    logprobs_done = tf.reshape(
        logprobs_batched,
        [-1, self.beam_size, num_classes])[:, :, self.stop_token]
    done_parent_refs = tf.argmax(logprobs_done, 1)
    done_symbols = flat_batch_gather(past_beam_symbols, done_parent_refs,
                                     batch_size=self.batch_size,
                                     options_size=self.beam_size)
    logprobs_done_max = tf.reduce_max(logprobs_done, 1)

    # Make sure the end token scores higher than all the top K partial
    # captions
    update_cond = tf.logical_and(
        tf.greater(logprobs_done_max, past_cand_logprobs),  # larger than previous
        tf.greater(logprobs_done_max, min_beam_logprobs))   # in top K
    cand_symbols_unpadded = select_op(update_cond,
                                      done_symbols,  # current estimate
                                      past_cand_symbols)
    cand_logprobs = select_op(update_cond, logprobs_done_max,
                              past_cand_logprobs)
    # cand_symbols_unpadded = tf.select(logprobs_done_max > past_cand_logprobs,
    #                                   done_symbols,  # current estimate
    #                                   past_cand_symbols)
    # cand_logprobs = tf.maximum(logprobs_done_max, past_cand_logprobs)

    cand_symbols = concat_op([cand_symbols_unpadded,
                              tf.fill([self.batch_size, 1],
                                      self.stop_token)], 1)

    # 4. Check the stopping criteria
    if self.max_len is not None:
        elements_finished_clip = (time >= self.max_len)
    if self.score_upper_bound is not None:
        elements_finished_bound = tf.reduce_max(
            tf.reshape(beam_logprobs, [-1, self.beam_size]),
            1) < (cand_logprobs - self.score_upper_bound)

    if self.max_len is not None and self.score_upper_bound is not None:
        elements_finished = elements_finished_clip | elements_finished_bound
    elif self.score_upper_bound is not None:
        elements_finished = elements_finished_bound
    elif self.max_len is not None:
        # this broadcasts elements_finished_clip to the correct shape
        elements_finished = tf.zeros([self.batch_size],
                                     dtype=tf.bool) | elements_finished_clip
    else:
        assert False, ("Lack of stopping criterion should have been caught "
                       "in constructor")

    # 5. Prepare return values
    # While loops require strict shape invariants, so we manually set shapes
    # in case the automatic shape inference can't calculate these. Even when
    # this is redundant, it has the benefit of helping catch shape bugs.
    for tensor in list(nest.flatten(next_input)) + list(
            nest.flatten(next_cell_state)):
        tensor.set_shape(
            tf.TensorShape((self.inferred_batch_size,
                            self.beam_size)).concatenate(
                tensor.get_shape()[2:]))

    for tensor in [cand_symbols, cand_logprobs, elements_finished]:
        tensor.set_shape(
            tf.TensorShape((self.inferred_batch_size,)).concatenate(
                tensor.get_shape()[1:]))

    for tensor in [beam_symbols, beam_logprobs]:
        tensor.set_shape(
            tf.TensorShape(
                (self.inferred_batch_size_times_beam_size,)).concatenate(
                tensor.get_shape()[1:]))

    next_loop_state = (
        cand_symbols,
        cand_logprobs,
        beam_symbols,
        beam_logprobs,
    )

    return (elements_finished, next_input, next_cell_state, emit_output,
            next_loop_state)
def decode_sparse(self, include_stop_tokens=True):
    dense_symbols, logprobs = self.decode_dense()
    mask = tf.not_equal(dense_symbols, self.stop_token)
    if include_stop_tokens:
        mask = concat_op([tf.ones_like(mask[:, :1]), mask[:, :-1]], 1)
    return sparse_boolean_mask(dense_symbols, mask), logprobs
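# Illustrative note on decode_sparse above: decode_dense returns dense,
# stop-token-padded symbol matrices with per-sequence log-probs, and
# decode_sparse masks out everything from the first stop token onward.
# When include_stop_tokens=True the mask is shifted right by one step
# (prepending an all-ones column), so the first stop token itself survives
# in the sparse output.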