Ejemplo n.º 1
0
def _build_greedy_inference_decoder(glb_ctx, im, ans_embed, vocab_size,
                                    num_cells, start_token_id):
    """Build the greedy-decoding inference graph for the LSTM + attention decoder.

    The LSTM hidden state is initialised from ``glb_ctx`` through a tanh
    fully-connected layer (scope ``init_h``). The LSTM cell state and the
    attention state start at zero. Decoding itself is delegated to
    ``create_greedy_decoder``.

    Args:
        glb_ctx: global context tensor; its batch dimension (axis 0) sizes the
            initial states.
        im: context tensor handed to ``MultiModalAttentionCell``.
        ans_embed: answer embedding handed to ``MultiModalAttentionCell``.
        vocab_size: vocabulary size. One extra slot is added to both the
            embedding table and the output projection (presumably for a
            special token — TODO confirm).
        num_cells: width of the LSTM state, attention state and embeddings.
        start_token_id: token id used to start decoding.

    Returns:
        Whatever ``create_greedy_decoder`` returns for this graph.
    """
    n_words = vocab_size + 1

    # Initial recurrent state: h from the global context, c and attention zeroed.
    h0 = slim.fully_connected(glb_ctx, num_cells, activation_fn=tf.nn.tanh,
                              scope='init_h')
    c0 = tf.zeros_like(h0)
    lstm_state0 = LSTMStateTuple(c0, h0)
    att_state0 = tf.zeros([tf.shape(h0)[0], num_cells], dtype=tf.float32)
    state0 = (lstm_state0, att_state0)

    # Stacked cell: plain LSTM followed by multi-modal attention (no dropout).
    decoder_cell = MultiRNNCell(
        [BasicLSTMCell(num_cells),
         MultiModalAttentionCell(512, im, ans_embed, keep_prob=1.0)],
        state_is_tuple=True)

    with tf.variable_scope('word_embedding'):
        embedding_table = tf.get_variable(
            name="word_map",
            shape=[n_words, num_cells],
            initializer=tf.random_uniform_initializer(-0.08, 0.08,
                                                      dtype=tf.float32))

    # Output projection, applied by the greedy decoder as (weights, biases).
    with tf.variable_scope('logits'):
        proj_w = tf.get_variable('weights', shape=[num_cells, n_words],
                                 dtype=tf.float32)
        proj_b = tf.get_variable('biases', shape=[n_words])

    return create_greedy_decoder(state0, decoder_cell, embedding_table,
                                 [proj_w, proj_b], start_token_id)
Ejemplo n.º 2
0
def _build_beamsearch_inference_decoder(glb_ctx, im, ans, inputs, vocab_size,
                                        num_cells, pad_token):
    """Build a one-step decoder graph meant to be driven from outside the graph.

    The recurrent state crosses the graph boundary as one flat float tensor
    laid out as ``[lstm_c, lstm_h, attention_state]`` along axis 1. It is
    fed through the placeholder named ``state_feed``. The graph exposes
    three tensors for fetching by name:

    * ``initial_state``: the flat initial state built from ``glb_ctx``;
    * ``state``: the flat state after one step on ``inputs``;
    * ``softmax``: next-token probabilities after that step.

    Args:
        glb_ctx: global context; the initial LSTM h comes from a tanh
            fully-connected layer over it (scope ``init_h``).
        im: context tensor for ``MultiModalAttentionCell``.
        ans: answer tensor for ``MultiModalAttentionCell``.
        inputs: token ids. The embedded result is squeezed on axis 1, so this
            presumably has shape (batch, 1) — TODO confirm.
        vocab_size: vocabulary size; enlarged so that ``pad_token`` is valid.
        num_cells: width of the LSTM / attention states.
        pad_token: padding id (only used to size the vocabulary).

    Returns:
        The ``softmax`` probability tensor.
    """
    n_words = max(vocab_size, pad_token + 1)

    # ---------------------------- cells ------------------------------------
    lstm_cell = BasicLSTMCell(num_cells)
    att_cell = MultiModalAttentionCell(512, im, ans, keep_prob=1.0)
    cell = MultiRNNCell([lstm_cell, att_cell], state_is_tuple=True)

    lstm_sizes, att_size = cell.state_size
    lstm_size = sum(lstm_sizes)

    # ----- flat state placeholder, unpacked into the cell's nested state -----
    state_feed = tf.placeholder(dtype=tf.float32,
                                shape=[None, lstm_size + att_size],
                                name='state_feed')
    lstm_part = tf.slice(state_feed, begin=[0, 0], size=[-1, lstm_size])
    att_part = tf.slice(state_feed, begin=[0, lstm_size],
                        size=[-1, att_size])
    c_in, h_in = split_op(lstm_part, num_splits=2, axis=1)
    fed_state = (LSTMStateTuple(c_in, h_in), att_part)

    # ---------- initial state, flattened for fetching by name -----------
    h_init = slim.fully_connected(glb_ctx, num_cells,
                                  activation_fn=tf.nn.tanh, scope='init_h')
    c_init = tf.zeros_like(h_init)
    att_init = tf.zeros([tf.shape(h_init)[0], num_cells], dtype=tf.float32)
    concat_op((c_init, h_init, att_init), axis=1, name="initial_state")

    # ------------------------- one decoding step ---------------------------
    with tf.variable_scope('word_embedding'):
        embedding_table = tf.get_variable(
            name="word_map",
            shape=[n_words, num_cells],
            initializer=tf.random_uniform_initializer(-0.08, 0.08,
                                                      dtype=tf.float32))
    embedded = tf.nn.embedding_lookup(embedding_table, inputs)

    with tf.variable_scope('RNN'):
        step_out, (lstm_out, att_out) = cell(
            tf.squeeze(embedded, squeeze_dims=[1]), state=fed_state)

    # Flatten the new state in the same [c, h, att] layout as the feed.
    c_out, h_out = lstm_out
    concat_op((c_out, h_out, att_out), axis=1, name="state")

    # Next-token distribution.
    logits = slim.fully_connected(tf.reshape(step_out, [-1, num_cells]),
                                  n_words, activation_fn=None,
                                  scope='logits')
    return tf.nn.softmax(logits, name="softmax")
Ejemplo n.º 3
0
def _build_training_decoder(glb_ctx,
                            im,
                            ans,
                            inputs,
                            length,
                            targets,
                            vocab_size,
                            num_cells,
                            keep_prob,
                            pad_token,
                            add_loss=True):
    """Build the teacher-forced training graph for the LSTM + attention decoder.

    Token ids in ``inputs`` are embedded and run through
    ``MultiRNNCell([dropout LSTM, dropout attention])`` with
    ``tf.nn.dynamic_rnn`` (scope ``RNN``). Each output is then projected to
    vocabulary logits (scope ``logits``). The per-token cross-entropy
    against ``targets`` ignores positions equal to ``pad_token``.

    Args:
        glb_ctx: global context; the initial LSTM h comes from a tanh
            fully-connected layer over it (scope ``init_h``).
        im: context tensor for ``MultiModalAttentionCell``.
        ans: answer tensor for ``MultiModalAttentionCell``.
        inputs: input token ids.
        length: per-example sequence lengths for ``dynamic_rnn``.
        targets: target token ids; axis 0 is the batch dimension.
        vocab_size: vocabulary size; enlarged so that ``pad_token`` is valid.
        num_cells: width of the LSTM / attention states.
        keep_prob: dropout keep probability.
        pad_token: id whose target positions are masked out of the loss.
        add_loss: whether to register the averaged loss with slim.

    Returns:
        If ``add_loss``: the masked per-token losses reshaped to
        (batch, -1). The mean loss over non-pad tokens is also registered as
        ``dec_loss`` via ``slim.losses.add_loss``.
        Otherwise: a ``(masked_losses, mask)`` pair, both shaped (batch, -1).
    """
    # Make sure pad_token is a valid row of the embedding / logits.
    n_words = max(vocab_size, pad_token + 1)

    # Initial recurrent state: h from the global context, c and attention zeroed.
    h0 = slim.fully_connected(glb_ctx,
                              num_cells,
                              activation_fn=tf.nn.tanh,
                              scope='init_h')
    c0 = tf.zeros_like(h0)
    state0 = (LSTMStateTuple(c0, h0),
              tf.zeros([tf.shape(h0)[0], num_cells], dtype=tf.float32))

    with tf.variable_scope('word_embedding'):
        embedding_table = tf.get_variable(
            name="word_map",
            shape=[n_words, num_cells],
            initializer=tf.random_uniform_initializer(-0.08, 0.08,
                                                      dtype=tf.float32))
    embedded = tf.nn.embedding_lookup(embedding_table, inputs)

    # Dropout on LSTM input/output and on the attention output only.
    lstm_cell = create_drop_lstm_cell(num_cells,
                                      input_keep_prob=keep_prob,
                                      output_keep_prob=keep_prob,
                                      cell_fn=BasicLSTMCell)
    att_cell = DropoutWrapper(
        MultiModalAttentionCell(512, im, ans, keep_prob=keep_prob),
        input_keep_prob=1.0,
        output_keep_prob=keep_prob)

    rnn_out, _ = tf.nn.dynamic_rnn(
        MultiRNNCell([lstm_cell, att_cell], state_is_tuple=True),
        embedded,
        length,
        initial_state=state0,
        dtype=tf.float32,
        scope='RNN')

    flat_logits = slim.fully_connected(tf.reshape(rnn_out, [-1, num_cells]),
                                       n_words,
                                       activation_fn=None,
                                       scope='logits')

    # Per-token cross-entropy, with pad targets masked out.
    n_rows = tf.shape(targets)[0]
    flat_targets = tf.reshape(targets, [-1])
    token_mask = tf.cast(tf.not_equal(flat_targets, pad_token), tf.float32)
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=flat_logits,
                                                          labels=flat_targets)

    if not add_loss:
        mask_2d = tf.reshape(token_mask, [n_rows, -1])
        return tf.reshape(xent, [n_rows, -1]) * mask_2d, mask_2d

    # Mean over non-pad tokens only.
    loss = tf.div(tf.reduce_sum(xent * token_mask),
                  tf.reduce_sum(token_mask),
                  name='dec_loss')
    slim.losses.add_loss(loss)
    return tf.reshape(xent * token_mask, [n_rows, -1])
Ejemplo n.º 4
0
def _build_tf_beam_inference_decoder(glb_ctx, im, ans_embed, vocab_size,
                                     num_cells, start_token_id, pad_token,
                                     beam_size=3, max_len=20):
    """Build an in-graph beam-search inference decoder using ``beam_decoder``.

    Both cells use flat (non-tuple) states. The initial state is therefore
    the concatenation ``[c, h, attention]`` along axis 1: h comes from a tanh
    fully-connected layer over ``glb_ctx`` (scope ``init_h``), and c and the
    attention state start at zero. The attention context (``im``,
    ``ans_embed``) is replicated ``beam_size`` times per example so that it
    lines up with the flattened (batch * beam) rows.

    Args:
        glb_ctx: global context tensor; axis 0 is the batch dimension.
        im: attention context tensor. Its non-batch static dims must be fully
            known, because they are used in a reshape. Any rank >= 2 is
            accepted (the tile multiples are derived from it).
        ans_embed: answer embedding of rank 2 (batch, dim).
        vocab_size: vocabulary size; enlarged so that ``pad_token`` is valid.
        num_cells: width of the LSTM / attention states.
        start_token_id: token id fed as the first input of every sequence.
        pad_token: padding id (only used to size the vocabulary).
        beam_size: Python int, number of beams kept per example (default 3,
            the previously hard-coded value).
        max_len: maximum number of decoding steps (default 20, the previously
            hard-coded value).

    Returns:
        ``(scores, pathes)``: ``beam_decoder``'s outputs, returned in swapped
        order.
    """
    # avoid out of range error: pad_token must index a valid row
    vocab_size = max(vocab_size, pad_token + 1)

    # Initial flat state [c, h, att].
    init_h = slim.fully_connected(glb_ctx,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_lstm_state = concat_op([init_c, init_h], axis=1)
    batch_size = tf.shape(init_h)[0]
    att_zero_state = tf.zeros([batch_size, num_cells], dtype=tf.float32)
    init_state = concat_op([init_lstm_state, att_zero_state], axis=1)

    lstm_cell = BasicLSTMCell(num_cells, state_is_tuple=False)

    # Replicate the attention context once per beam:
    # (batch, ...) -> (batch, beam, ...) -> (batch * beam, ...),
    # with the beams of one example kept adjacent.
    im_shape = im.get_shape().as_list()[1:]
    im = tf.expand_dims(im, 1)  # add a beam dim
    im_multiples = [1, beam_size] + [1] * len(im_shape)
    im = tf.reshape(tf.tile(im, im_multiples), [-1] + im_shape)
    ans_embed = tf.expand_dims(ans_embed, 1)
    ans_embed_dim = ans_embed.get_shape().as_list()[-1]
    ans_embed = tf.reshape(tf.tile(ans_embed, [1, beam_size, 1]),
                           [-1, ans_embed_dim])

    attention_cell = MultiModalAttentionCell(512, im, ans_embed, keep_prob=1.0)
    multi_cell = MultiRNNCell([lstm_cell, attention_cell],
                              state_is_tuple=False)

    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(name="word_map",
                                   shape=[vocab_size, num_cells],
                                   initializer=tf.random_uniform_initializer(
                                       -0.08, 0.08, dtype=tf.float32))

    # Output projection used by the scoring function below.
    with tf.variable_scope('logits'):
        weights = tf.get_variable('weights',
                                  shape=[num_cells, vocab_size],
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', shape=[vocab_size])

    def _tokens_to_inputs_fn(inputs):
        """Embed token ids with the shared word map."""
        return tf.nn.embedding_lookup(word_map, inputs)

    def _output_to_score_fn(hidden):
        """Map cell outputs to per-token log-probabilities.

        ``hidden`` is treated as (batch, beam, ...). It is flattened to 2-D
        for the projection and the result is reshaped back to
        (batch, beam, vocab).
        """
        n_batch = tf.shape(hidden)[0]
        n_beam = tf.shape(hidden)[1]
        flat_hidden = tf.reshape(hidden, [n_batch * n_beam, -1])
        logits = tf.nn.xw_plus_b(flat_hidden, weights, biases)
        logprob = tf.nn.log_softmax(logits)
        return tf.reshape(logprob, [n_batch, n_beam, -1])

    stop_token_id = VOCAB_CONFIG.end_token_id
    batch_size = tf.shape(glb_ctx)[0]
    start_tokens = tf.ones(shape=[batch_size], dtype=tf.int32) * start_token_id
    init_inputs = _tokens_to_inputs_fn(start_tokens)
    pathes, scores = beam_decoder(multi_cell,
                                  beam_size=beam_size,
                                  stop_token=stop_token_id,
                                  initial_state=init_state,
                                  initial_input=init_inputs,
                                  tokens_to_inputs_fn=_tokens_to_inputs_fn,
                                  outputs_to_score_fn=_output_to_score_fn,
                                  max_len=max_len,
                                  output_dense=True,
                                  scope='RNN')
    return scores, pathes
def _build_mixer_inference_decoder(glb_ctx, im, ans_embed, vocab_size,
                                   num_cells, capt, capt_len, pad_token,
                                   n_free_run_steps, keep_prob=0.7):
    """Build a mixed teacher-forcing / free-running decoder graph.

    For each example, the LSTM (alone, without the attention cell) is first
    run with teacher forcing over the first
    ``max(capt_len - n_free_run_steps, 0)`` caption tokens. From the
    resulting LSTM state and a zeroed attention state, ``create_greedy_decoder``
    then decodes freely. Its start token is the caption token at that
    position.

    Args:
        glb_ctx: global context; the initial LSTM h comes from a tanh
            fully-connected layer over it (scope ``init_h``).
        im: context tensor for ``MultiModalAttentionCell``.
        ans_embed: answer embedding for ``MultiModalAttentionCell``.
        vocab_size: vocabulary size; enlarged so that ``pad_token`` is valid.
        num_cells: width of the LSTM / attention states.
        capt: caption token ids, a 2-D (batch, time) int32 tensor.
        capt_len: per-example caption lengths (int32).
        pad_token: padding id (only used to size the vocabulary).
        n_free_run_steps: number of trailing steps to free-run. May be a
            Python int, an int32 Tensor, or anything else
            ``tf.convert_to_tensor`` accepts.
        keep_prob: dropout keep probability for the LSTM input/output and the
            attention output (default 0.7, the previously hard-coded value).
            NOTE(review): dropout is active in this graph; presumably
            intentional for MIXER-style sampling — confirm.

    Returns:
        ``(path, free_run_start_tokens)``: the output of
        ``create_greedy_decoder`` and the per-example start tokens it was
        given.
    """
    # avoid out of range error: pad_token must index a valid row
    vocab_size = max(vocab_size, pad_token + 1)

    # init state / image embedding
    init_h = slim.fully_connected(glb_ctx,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_lstm_state = LSTMStateTuple(init_c, init_h)

    # zero attention cell state
    batch_size = tf.shape(init_h)[0]
    att_zero_state = tf.zeros([batch_size, num_cells], dtype=tf.float32)

    # build LSTM cell and RNN
    lstm_cell = create_drop_lstm_cell(num_cells,
                                      input_keep_prob=keep_prob,
                                      output_keep_prob=keep_prob,
                                      cell_fn=BasicLSTMCell)
    attention_cell = MultiModalAttentionCell(512, im, ans_embed, keep_prob=1.0)
    attention_cell = DropoutWrapper(attention_cell,
                                    input_keep_prob=1.0,
                                    output_keep_prob=keep_prob)
    multi_cell = MultiRNNCell([lstm_cell, attention_cell], state_is_tuple=True)

    # word embedding
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(name="word_map",
                                   shape=[vocab_size, num_cells],
                                   initializer=tf.random_uniform_initializer(
                                       -0.08, 0.08, dtype=tf.float32))

    # run teacher forcing steps
    inputs = tf.nn.embedding_lookup(word_map, capt)
    # Accepts Python ints, numpy scalars, Tensors (incl. subclasses) and
    # Variables; the old `type(x) != tf.Tensor` test sent the latter two
    # through tf.constant, which rejects them.
    n_free_run_steps = tf.convert_to_tensor(n_free_run_steps, dtype=tf.int32)

    n_teacher_forcing_steps = tf.maximum(capt_len - n_free_run_steps,
                                         tf.constant(0, dtype=tf.int32))

    def gather_by_col(arr, cols):
        """Return ``arr[i, cols[i]]`` for every row ``i`` of a 2-D tensor."""
        n_rows = tf.shape(arr)[0]
        n_cols = tf.shape(arr)[1]
        # A column index equal to n_cols (e.g. no free-run steps and a
        # caption filling every column) would otherwise read the next row's
        # first token, or run past the end of the flattened array on the
        # last row. Clamp it to stay inside the example's own row.
        cols = tf.minimum(cols, n_cols - 1)
        index = tf.range(n_rows) * n_cols + cols
        flat = tf.reshape(arr, [-1])
        return tf.reshape(tf.gather(flat, index), [n_rows])

    free_run_start_tokens = gather_by_col(capt, n_teacher_forcing_steps)

    # Use the same RNN/multi_rnn_cell/cell_{0,1} scopes that MultiRNNCell
    # uses, presumably so the greedy decoder below reuses these variables
    # — confirm against create_greedy_decoder.
    with tf.variable_scope('RNN'):
        with tf.variable_scope('multi_rnn_cell'):
            with tf.variable_scope('cell_0') as sc:
                _, lstm_state = tf.nn.dynamic_rnn(
                    lstm_cell,
                    inputs,
                    n_teacher_forcing_steps,
                    initial_state=init_lstm_state,
                    dtype=tf.float32,
                    scope=sc)
                # Only used to call the attention cell once (creating its
                # variables); its output is discarded.
                dummy_hidden = lstm_state[0]
            with tf.variable_scope('cell_1'):
                attention_cell(dummy_hidden, att_zero_state)
    init_state = (lstm_state, att_zero_state)

    # apply weights for outputs
    with tf.variable_scope('logits'):
        weights = tf.get_variable('weights',
                                  shape=[num_cells, vocab_size],
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', shape=[vocab_size])
        softmax_params = [weights, biases]

    path = create_greedy_decoder(init_state, multi_cell, word_map,
                                 softmax_params, free_run_start_tokens)
    return path, free_run_start_tokens