def _build_greedy_inference_decoder(in_embed, vocab_size, num_cells, start_token_id):
    """Assemble a single-layer LSTM greedy decoder conditioned on ``in_embed``.

    The initial hidden state comes from a tanh fully-connected projection of
    ``in_embed``. The initial cell state is all zeros.

    Args:
        in_embed: Conditioning input, projected (variable scope ``'init_h'``)
            to ``num_cells`` units to form the LSTM's initial hidden state.
        vocab_size: Vocabulary size. The embedding table and the output
            projection are both sized ``vocab_size + 1``.
        num_cells: Number of LSTM units; also used as the embedding width.
        start_token_id: Token id forwarded to ``create_greedy_decoder``.

    Returns:
        Whatever ``create_greedy_decoder`` returns for these components.
    """
    # NOTE(review): the extra class presumably reserves one id beyond the
    # vocabulary (e.g. a start/end token) -- confirm against the data pipeline.
    num_classes = vocab_size + 1

    # Initial LSTM state: h0 = tanh(FC(in_embed)), c0 = zeros shaped like h0.
    initial_hidden = slim.fully_connected(
        in_embed,
        num_cells,
        activation_fn=tf.nn.tanh,
        scope='init_h',
    )
    initial_cell = tf.zeros_like(initial_hidden)
    decoder_state0 = LSTMStateTuple(initial_cell, initial_hidden)

    recurrent_cell = BasicLSTMCell(num_cells)

    # Token embedding table, initialised uniformly between -0.08 and 0.08.
    with tf.variable_scope('word_embedding'):
        embedding_matrix = tf.get_variable(
            name="word_map",
            shape=[num_classes, num_cells],
            initializer=tf.random_uniform_initializer(
                minval=-0.08, maxval=0.08, dtype=tf.float32),
        )

    # Output projection (num_cells -> num_classes), handed over as the
    # decoder's softmax parameters [weights, biases].
    with tf.variable_scope('logits'):
        proj_w = tf.get_variable(
            'weights', shape=[num_cells, num_classes], dtype=tf.float32)
        proj_b = tf.get_variable('biases', shape=[num_classes])

    return create_greedy_decoder(
        decoder_state0,
        recurrent_cell,
        embedding_matrix,
        [proj_w, proj_b],
        start_token_id,
    )
def _build_greedy_inference_decoder(glb_ctx, im, ans_embed, vocab_size, num_cells, start_token_id):
    """Assemble a greedy decoder over a stacked LSTM + multimodal-attention cell.

    The recurrent cell is ``MultiRNNCell([BasicLSTMCell, MultiModalAttentionCell])``.
    Its initial state is a tuple ``(lstm_state, attention_state)``:
      * the LSTM hidden state is a tanh projection of ``glb_ctx`` and the LSTM
        cell state is zeros;
      * the attention state is a zero tensor of shape ``[batch, num_cells]``.

    Args:
        glb_ctx: Conditioning input, projected (variable scope ``'init_h'``)
            to ``num_cells`` units to form the LSTM's initial hidden state.
            Its leading dimension defines the batch size.
        im: Forwarded to ``MultiModalAttentionCell`` (presumably image
            features -- confirm).
        ans_embed: Forwarded to ``MultiModalAttentionCell`` (presumably an
            answer embedding -- confirm).
        vocab_size: Vocabulary size. The embedding table and the output
            projection are both sized ``vocab_size + 1``.
        num_cells: Number of LSTM units; also used as the embedding width and
            as the width of the attention zero state.
        start_token_id: Token id forwarded to ``create_greedy_decoder``.

    Returns:
        Whatever ``create_greedy_decoder`` returns for these components.
    """
    # NOTE(review): the extra class presumably reserves one id beyond the
    # vocabulary (e.g. a start/end token) -- confirm against the data pipeline.
    num_classes = vocab_size + 1

    # LSTM part of the initial state: h0 = tanh(FC(glb_ctx)), c0 = zeros.
    initial_hidden = slim.fully_connected(
        glb_ctx,
        num_cells,
        activation_fn=tf.nn.tanh,
        scope='init_h',
    )
    initial_cell = tf.zeros_like(initial_hidden)
    lstm_state0 = LSTMStateTuple(initial_cell, initial_hidden)

    # Attention part of the initial state: zeros, one row per batch element.
    n_batch = tf.shape(initial_hidden)[0]
    attention_state0 = tf.zeros([n_batch, num_cells], dtype=tf.float32)

    # Tuple order mirrors the cell order of the MultiRNNCell built below.
    decoder_state0 = (lstm_state0, attention_state0)

    lstm_layer = BasicLSTMCell(num_cells)
    # NOTE(review): 512 is hard-coded here, while the attention zero state
    # above and the logits projection below are sized by num_cells. If this
    # argument sets the attention cell's state/output width, the shapes only
    # agree when num_cells == 512 -- confirm against MultiModalAttentionCell.
    # keep_prob=1.0 presumably disables dropout for inference.
    attention_layer = MultiModalAttentionCell(512, im, ans_embed, keep_prob=1.0)
    stacked_cell = MultiRNNCell([lstm_layer, attention_layer], state_is_tuple=True)

    # Token embedding table, initialised uniformly between -0.08 and 0.08.
    with tf.variable_scope('word_embedding'):
        embedding_matrix = tf.get_variable(
            name="word_map",
            shape=[num_classes, num_cells],
            initializer=tf.random_uniform_initializer(
                minval=-0.08, maxval=0.08, dtype=tf.float32),
        )

    # Output projection (num_cells -> num_classes), handed over as the
    # decoder's softmax parameters [weights, biases].
    with tf.variable_scope('logits'):
        proj_w = tf.get_variable(
            'weights', shape=[num_cells, num_classes], dtype=tf.float32)
        proj_b = tf.get_variable('biases', shape=[num_classes])

    return create_greedy_decoder(
        decoder_state0,
        stacked_cell,
        embedding_matrix,
        [proj_w, proj_b],
        start_token_id,
    )