# NOTE: the imports below are an assumption; these snippets target TensorFlow 1.x,
# and the helpers create_drop_lstm_cell, create_greedy_decoder,
# MultiModalAttentionCell and RerankAttentionCell are project-specific modules
# that are not shown on this page.
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.rnn import (BasicLSTMCell, DropoutWrapper,
                                    LSTMStateTuple, MultiRNNCell)


def _build_greedy_inference_decoder(in_embed, vocab_size, num_cells,
                                    start_token_id, keep_prob, epsilon):
    # reserve one extra vocabulary slot (always evaluates to vocab_size + 1)
    vocab_size = max(vocab_size, vocab_size + 1)
    # init state / image embedding
    init_h = slim.fully_connected(in_embed,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_state = LSTMStateTuple(init_c, init_h)

    # build LSTM cell and RNN
    lstm_cell = create_drop_lstm_cell(num_cells,
                                      input_keep_prob=keep_prob,
                                      output_keep_prob=keep_prob,
                                      cell_fn=BasicLSTMCell)

    # word embedding
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(name="word_map",
                                   shape=[vocab_size, num_cells],
                                   initializer=tf.random_uniform_initializer(
                                       -0.08, 0.08, dtype=tf.float32))

    # apply weights for outputs
    with tf.variable_scope('logits'):
        weights = tf.get_variable('weights',
                                  shape=[num_cells, vocab_size],
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', shape=[vocab_size])
        softmax_params = [weights, biases]

    return create_greedy_decoder(init_state, lstm_cell, word_map,
                                 softmax_params, start_token_id, epsilon)
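
create_greedy_decoder and create_drop_lstm_cell are helpers defined elsewhere in the repository. As a hedged illustration of what the greedy decoder is assumed to do with the variables built above (word_map, the softmax weights and biases, the start token), the NumPy-only sketch below runs an argmax feedback loop; step_fn, max_steps and the toy cell are invented for the example, and the epsilon argument is ignored.

# Illustrative sketch only -- not the repository's create_greedy_decoder.
import numpy as np

def greedy_decode(step_fn, init_state, word_map, weights, biases,
                  start_token_id, max_steps=20):
    """Greedy (argmax) decoding: feed each prediction back as the next input."""
    token, state, path = start_token_id, init_state, []
    for _ in range(max_steps):
        x = word_map[token]                 # embed the current token
        output, state = step_fn(x, state)   # one recurrent step
        logits = output @ weights + biases  # project onto the vocabulary
        token = int(np.argmax(logits))      # greedy choice
        path.append(token)
    return path

# toy usage with a stand-in "cell" just to show the data flow
rng = np.random.RandomState(0)
vocab_size, num_cells = 10, 8
word_map = rng.randn(vocab_size, num_cells)
weights, biases = rng.randn(num_cells, vocab_size), np.zeros(vocab_size)
step = lambda x, h: (np.tanh(x + h), np.tanh(x + h))
print(greedy_decode(step, np.zeros(num_cells), word_map, weights, biases,
                    start_token_id=1, max_steps=5))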
def _build_training_decoder(glb_ctx, im, ans, inputs, length, targets, vocab_size,
                            num_cells, keep_prob, pad_token):
    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)
    # init state / image embedding
    init_h = slim.fully_connected(glb_ctx, num_cells, activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_state_lstm = LSTMStateTuple(init_c, init_h)
    batch_size = tf.shape(init_h)[0]
    att_zero_state = tf.zeros([batch_size, num_cells], dtype=tf.float32)
    init_state = (init_state_lstm, att_zero_state)

    # word embedding
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(
            name="word_map",
            shape=[vocab_size, num_cells],
            initializer=tf.random_uniform_initializer(-0.08, 0.08,
                                                      dtype=tf.float32))
    inputs = tf.nn.embedding_lookup(word_map, inputs)

    # build LSTM cell and RNN
    lstm = create_drop_lstm_cell(num_cells, input_keep_prob=keep_prob,
                                 output_keep_prob=keep_prob,
                                 cell_fn=BasicLSTMCell)

    attention_cell = RerankAttentionCell(512, im, ans, keep_prob=keep_prob)
    attention_cell = DropoutWrapper(attention_cell, input_keep_prob=1.0,
                                    output_keep_prob=keep_prob)

    multi_cell = MultiRNNCell([lstm, attention_cell], state_is_tuple=True)
    outputs, final_states = tf.nn.dynamic_rnn(multi_cell, inputs, length,
                                              initial_state=init_state,
                                              dtype=tf.float32, scope='RNN')
    final_att_state = final_states[1]
    rerank_logits = slim.fully_connected(final_att_state, 1, activation_fn=None,
                                         scope='Softmax')
    # predict next word
    outputs = tf.reshape(outputs, [-1, num_cells])
    logits = slim.fully_connected(outputs, vocab_size, activation_fn=None,
                                  scope='logits')
    # compute loss
    batch_size = tf.shape(targets)[0]
    targets = tf.reshape(targets, [-1])
    mask = tf.cast(tf.not_equal(targets, pad_token), tf.float32)
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                            labels=targets)

    mask = tf.reshape(mask, [batch_size, -1])
    losses = tf.reshape(losses, [batch_size, -1]) * mask

    # loss = tf.div(tf.reduce_sum(losses * mask), tf.reduce_sum(mask),
    #               name='dec_loss')
    # slim.losses.add_loss(loss)
    # return rerank_logits
    return rerank_logits, losses, mask
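
The loss block above computes a per-token cross entropy and zeroes out padded positions before returning it together with the mask. Below is a minimal NumPy sketch of that masking pattern; the function name and shapes are assumptions made for the illustration.

# Illustrative sketch only: masked per-token cross entropy, mirroring the
# sparse_softmax_cross_entropy_with_logits + pad-mask pattern above.
import numpy as np

def masked_xent(logits, targets, pad_token):
    # logits: [batch * time, vocab], targets: [batch * time]
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]   # per-token loss
    mask = (targets != pad_token).astype(np.float32)     # 0 at padded positions
    return nll * mask, mask

losses, mask = masked_xent(np.random.randn(6, 5),
                           np.array([1, 2, 0, 3, 0, 0]), pad_token=0)
print(losses, mask)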
def _build_training_decoder(in_embed, inputs, length, targets, vocab_size,
                            num_cells, keep_prob, pad_token, add_loss):
    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)
    # init state / image embedding
    init_h = slim.fully_connected(in_embed,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_state = LSTMStateTuple(init_c, init_h)

    # word embedding
    with tf.variable_scope('word_embedding', reuse=True):
        word_map = tf.get_variable(name="word_map",
                                   shape=[vocab_size, num_cells],
                                   initializer=tf.random_uniform_initializer(
                                       -0.08, 0.08, dtype=tf.float32))
    print('WordEmbed')
    print(word_map)
    inputs = tf.nn.embedding_lookup(word_map, inputs)

    # build LSTM cell and RNN
    lstm = create_drop_lstm_cell(num_cells,
                                 input_keep_prob=keep_prob,
                                 output_keep_prob=keep_prob,
                                 cell_fn=BasicLSTMCell)
    outputs, states = tf.nn.dynamic_rnn(lstm,
                                        inputs,
                                        length,
                                        initial_state=init_state,
                                        dtype=tf.float32,
                                        scope='RNN')

    # predict next word
    outputs = tf.reshape(outputs, [-1, num_cells])
    logits = slim.fully_connected(outputs,
                                  vocab_size,
                                  activation_fn=None,
                                  scope='logits')
    # compute loss
    batch_size = tf.shape(targets)[0]
    targets = tf.reshape(targets, [-1])
    mask = tf.cast(tf.not_equal(targets, pad_token), tf.float32)
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                            labels=targets)
    if add_loss:
        loss = tf.div(tf.reduce_sum(losses * mask),
                      tf.reduce_sum(mask),
                      name='dec_loss')
        tf.losses.add_loss(loss)
        tf.summary.scalar('nll', loss)
    logits_norm = 1.0 / tf.reshape(tf.cast(length, tf.float32),
                                   [-1, 1])  # normalise by length
    # return tf.reshape(losses * mask, [batch_size, -1])
    return tf.reshape(losses * mask, [batch_size, -1]) * logits_norm
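
The returned tensor scales each sequence's masked per-token losses by 1 / length, so short and long sequences contribute on a comparable scale. A small NumPy sketch of that normalisation with made-up numbers:

# Illustrative sketch only: per-sequence length normalisation of masked losses.
import numpy as np

losses = np.array([[0.5, 0.7, 0.0],            # [batch, time], already masked
                   [0.2, 0.4, 0.6]])
length = np.array([2, 3], dtype=np.float32)
normalised = losses / length.reshape(-1, 1)    # same as multiplying by 1/length
print(normalised.sum(axis=1))                  # mean loss per sequence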
def _build_random_inference_decoder(glb_ctx, im, ans_embed, vocab_size,
                                    num_cells, start_token_id, pad_token):
    keep_prob = 0.7
    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)
    # init state / image embedding
    init_h = slim.fully_connected(glb_ctx,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_lstm_state = LSTMStateTuple(init_c, init_h)
    batch_size = tf.shape(init_h)[0]
    att_zero_state = tf.zeros([batch_size, num_cells], dtype=tf.float32)
    init_state = (init_lstm_state, att_zero_state)

    # build LSTM cell and RNN
    lstm_cell = create_drop_lstm_cell(num_cells,
                                      input_keep_prob=keep_prob,
                                      output_keep_prob=keep_prob,
                                      cell_fn=BasicLSTMCell)
    attention_cell = MultiModalAttentionCell(512, im, ans_embed, keep_prob=1.0)
    attention_cell = DropoutWrapper(attention_cell,
                                    input_keep_prob=1.0,
                                    output_keep_prob=keep_prob)
    multi_cell = MultiRNNCell([lstm_cell, attention_cell], state_is_tuple=True)

    # word embedding
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(name="word_map",
                                   shape=[vocab_size, num_cells],
                                   initializer=tf.random_uniform_initializer(
                                       -0.08, 0.08, dtype=tf.float32))

    # apply weights for outputs
    with tf.variable_scope('logits'):
        weights = tf.get_variable('weights',
                                  shape=[num_cells, vocab_size],
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', shape=[vocab_size])
        softmax_params = [weights, biases]

    return create_greedy_decoder(init_state, multi_cell, word_map,
                                 softmax_params, start_token_id)
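
Unlike the training decoders, this inference decoder keeps dropout switched on (keep_prob = 0.7), so repeated greedy decodes of the same input can follow different paths. The tiny NumPy sketch below shows why dropout noise can change the argmax from run to run; all names and sizes are illustrative.

# Illustrative sketch only: with dropout left on at inference time, the same
# hidden state can yield different argmax decisions on different runs.
import numpy as np

rng = np.random.RandomState(0)
hidden = rng.randn(8)
weights = rng.randn(8, 10)
keep_prob = 0.7
for run in range(3):
    drop_mask = (rng.rand(8) < keep_prob) / keep_prob   # inverted dropout
    logits = (hidden * drop_mask) @ weights
    print('run', run, '-> token', int(np.argmax(logits)))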
def _build_actor_network(in_embed, inputs, length, targets, advantages,
                         vocab_size, num_cells, keep_prob, pad_token):
    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)
    # init state / image embedding
    init_h = slim.fully_connected(in_embed,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_state = LSTMStateTuple(init_c, init_h)

    # word embedding
    with tf.variable_scope('word_embedding', reuse=True):
        word_map = tf.get_variable(name="word_map",
                                   shape=[vocab_size, num_cells],
                                   initializer=tf.random_uniform_initializer(
                                       -0.08, 0.08, dtype=tf.float32))
    inputs = tf.nn.embedding_lookup(word_map, inputs)

    # build LSTM cell and RNN
    lstm = create_drop_lstm_cell(num_cells,
                                 input_keep_prob=keep_prob,
                                 output_keep_prob=keep_prob,
                                 cell_fn=BasicLSTMCell)
    outputs, states = tf.nn.dynamic_rnn(lstm,
                                        inputs,
                                        length,
                                        initial_state=init_state,
                                        dtype=tf.float32,
                                        scope='RNN')

    # predict next word
    outputs = tf.reshape(outputs, [-1, num_cells])
    logits = slim.fully_connected(outputs,
                                  vocab_size,
                                  activation_fn=None,
                                  scope='logits')
    # compute loss
    targets = tf.reshape(targets, [-1])
    advantages = tf.reshape(advantages, [-1])
    mask = tf.cast(tf.not_equal(targets, pad_token), tf.float32)
    valid = tf.div(tf.reduce_sum(mask), tf.cast(tf.shape(mask)[0], tf.float32))
    tf.summary.scalar('mask_cnt', valid)
    tf.summary.scalar('advantage', tf.reduce_mean(advantages))
    #
    logprobs = -tf.nn.log_softmax(logits)
    one_hot_targets = tf.one_hot(targets, vocab_size, dtype=tf.float32)
    action_log_probs = tf.reduce_sum(logprobs * one_hot_targets, 1)
    loss = tf.div(tf.reduce_sum(action_log_probs * advantages * mask),
                  tf.reduce_sum(mask),
                  name='pg_loss')
    tf.losses.add_loss(loss)

    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                            labels=targets)
    loss2 = tf.div(tf.reduce_sum(losses * mask * advantages),
                   tf.reduce_sum(mask),
                   name='pg_loss_obv')

    tf.summary.scalar('pg_loss', loss)
    tf.summary.scalar('pg_loss_obv', loss2)
    return loss
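
The pg_loss above is a REINFORCE-style objective: the negative log-probability of each target token is weighted by its advantage and averaged over the non-padded positions. A hedged NumPy sketch of the same computation (function name and shapes are assumptions):

# Illustrative sketch only: masked policy-gradient (REINFORCE) loss.
import numpy as np

def pg_loss(logits, targets, advantages, pad_token):
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    neg_logp = -log_probs[np.arange(len(targets)), targets]  # -log pi(a_t)
    mask = (targets != pad_token).astype(np.float32)
    return (neg_logp * advantages * mask).sum() / mask.sum()

print(pg_loss(np.random.randn(4, 6),
              np.array([1, 3, 2, 0]),
              np.array([0.5, -0.2, 1.0, 0.0]),
              pad_token=0))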
def _build_training_decoder(glb_ctx, im, ans, inputs, length, targets,
                            vocab_size, num_cells, keep_prob, pad_token,
                            rewards, advantage, xe_mask):
    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)
    # init state / image embedding
    init_h = slim.fully_connected(glb_ctx,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_state_lstm = LSTMStateTuple(init_c, init_h)
    batch_size = tf.shape(init_h)[0]
    att_zero_state = tf.zeros([batch_size, num_cells], dtype=tf.float32)
    init_state = (init_state_lstm, att_zero_state)

    # word embedding
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(name="word_map",
                                   shape=[vocab_size, num_cells],
                                   initializer=tf.random_uniform_initializer(
                                       -0.08, 0.08, dtype=tf.float32))
    inputs = tf.nn.embedding_lookup(word_map, inputs)

    # build LSTM cell and RNN
    lstm = create_drop_lstm_cell(num_cells,
                                 input_keep_prob=keep_prob,
                                 output_keep_prob=keep_prob,
                                 cell_fn=BasicLSTMCell)

    attention_cell = MultiModalAttentionCell(512, im, ans, keep_prob=keep_prob)
    attention_cell = DropoutWrapper(attention_cell,
                                    input_keep_prob=1.0,
                                    output_keep_prob=keep_prob)

    multi_cell = MultiRNNCell([lstm, attention_cell], state_is_tuple=True)
    outputs, _ = tf.nn.dynamic_rnn(multi_cell,
                                   inputs,
                                   length,
                                   initial_state=init_state,
                                   dtype=tf.float32,
                                   scope='RNN')

    # predict next word
    outputs = tf.reshape(outputs, [-1, num_cells])
    logits = slim.fully_connected(outputs,
                                  vocab_size,
                                  activation_fn=None,
                                  scope='logits')

    # compute loss
    batch_size = tf.shape(targets)[0]
    targets = tf.reshape(targets, [-1])
    valid_mask = tf.not_equal(targets, pad_token)
    valid_xe_mask = tf.cast(tf.logical_and(xe_mask, valid_mask), tf.float32)
    valid_rl_mask = tf.cast(
        tf.logical_and(tf.logical_not(xe_mask), valid_mask), tf.float32)

    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                            labels=targets)

    # compute cross entropy loss
    xe_loss = tf.div(tf.reduce_sum(losses * valid_xe_mask),
                     tf.reduce_sum(valid_xe_mask),
                     name='xe_loss')
    slim.losses.add_loss(xe_loss)

    # compute reinforce loss
    advantage = tf.reshape(advantage, [-1])
    actor_loss = tf.div(tf.reduce_sum(losses * valid_rl_mask * advantage),
                        tf.reduce_sum(valid_rl_mask),
                        name='actor_loss')
    slim.losses.add_loss(actor_loss)
    return tf.reshape(losses * tf.cast(valid_mask, tf.float32),
                      [batch_size, -1])
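
Every token position is routed to exactly one objective here: positions where xe_mask holds are trained with plain cross entropy, the remaining non-padded positions with the advantage-weighted actor loss. A small NumPy sketch of how the two masks partition a sequence (values are made up):

# Illustrative sketch only: splitting token positions between the XE and RL
# objectives, as in the valid_xe_mask / valid_rl_mask logic above.
import numpy as np

targets = np.array([4, 7, 2, 9, 0, 0])            # 0 is the pad token here
xe_mask = np.array([True, True, True, False, False, False])
valid = targets != 0
valid_xe = (xe_mask & valid).astype(np.float32)   # teacher-forced positions
valid_rl = (~xe_mask & valid).astype(np.float32)  # free-running positions
print(valid_xe)   # -> [1. 1. 1. 0. 0. 0.]
print(valid_rl)   # -> [0. 0. 0. 1. 0. 0.]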
def _build_mixer_inference_decoder(glb_ctx, im, ans_embed, vocab_size,
                                   num_cells, capt, capt_len, pad_token,
                                   n_free_run_steps):
    keep_prob = 0.7
    # avoid out of range error
    vocab_size = max(vocab_size, pad_token + 1)
    # init state / image embedding
    init_h = slim.fully_connected(glb_ctx,
                                  num_cells,
                                  activation_fn=tf.nn.tanh,
                                  scope='init_h')
    init_c = tf.zeros_like(init_h)
    init_lstm_state = LSTMStateTuple(init_c, init_h)

    # zero attention cell state
    batch_size = tf.shape(init_h)[0]
    att_zero_state = tf.zeros([batch_size, num_cells], dtype=tf.float32)

    # build LSTM cell and RNN
    lstm_cell = create_drop_lstm_cell(num_cells,
                                      input_keep_prob=keep_prob,
                                      output_keep_prob=keep_prob,
                                      cell_fn=BasicLSTMCell)
    attention_cell = MultiModalAttentionCell(512, im, ans_embed, keep_prob=1.0)
    attention_cell = DropoutWrapper(attention_cell,
                                    input_keep_prob=1.0,
                                    output_keep_prob=keep_prob)
    multi_cell = MultiRNNCell([lstm_cell, attention_cell], state_is_tuple=True)

    # word embedding
    with tf.variable_scope('word_embedding'):
        word_map = tf.get_variable(name="word_map",
                                   shape=[vocab_size, num_cells],
                                   initializer=tf.random_uniform_initializer(
                                       -0.08, 0.08, dtype=tf.float32))

    # run teacher forcing steps
    inputs = tf.nn.embedding_lookup(word_map, capt)
    if not isinstance(n_free_run_steps, tf.Tensor):
        n_free_run_steps = tf.constant(n_free_run_steps, dtype=tf.int32)

    n_teacher_forcing_steps = tf.maximum(capt_len - n_free_run_steps,
                                         tf.constant(0, dtype=tf.int32))

    def gather_by_col(arr, cols):
        batch_size = tf.shape(arr)[0]
        num_cols = tf.shape(arr)[1]
        index = tf.range(batch_size) * num_cols + cols
        arr = tf.reshape(arr, [-1])
        return tf.reshape(tf.gather(arr, index), [batch_size])
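    # gather_by_col picks one element per row: with arr = [[3, 7, 9], [2, 4, 6]]
    # and cols = [1, 2] it returns [7, 6]. Below it fetches, for each caption,
    # the token sitting at the position where free running begins.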

    free_run_start_tokens = gather_by_col(capt, n_teacher_forcing_steps)

    with tf.variable_scope('RNN'):
        with tf.variable_scope('multi_rnn_cell'):
            with tf.variable_scope('cell_0') as sc:
                _, lstm_state = tf.nn.dynamic_rnn(
                    lstm_cell,
                    inputs,
                    n_teacher_forcing_steps,
                    initial_state=init_lstm_state,
                    dtype=tf.float32,
                    scope=sc)
                dummy_hidden = lstm_state[0]
            with tf.variable_scope('cell_1'):
                attention_cell(dummy_hidden, att_zero_state)
    init_state = (lstm_state, att_zero_state)

    # fetch start tokens

    # apply weights for outputs
    with tf.variable_scope('logits'):
        weights = tf.get_variable('weights',
                                  shape=[num_cells, vocab_size],
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', shape=[vocab_size])
        softmax_params = [weights, biases]

    path = create_greedy_decoder(init_state, multi_cell, word_map,
                                 softmax_params, free_run_start_tokens)
    return path, free_run_start_tokens
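
The schedule above teacher-forces the first capt_len - n_free_run_steps tokens and lets the decoder free-run for the rest, seeding the free-running phase with the token at the switch-over position. A short NumPy sketch of that bookkeeping, with made-up values:

# Illustrative sketch only: computing the teacher-forcing length and the token
# that seeds the free-running phase, per sequence.
import numpy as np

capt = np.array([[5, 8, 3, 9, 0],       # [batch, time] token ids
                 [7, 2, 6, 1, 4]])
capt_len = np.array([4, 5])
n_free_run_steps = 2
n_teacher_forcing = np.maximum(capt_len - n_free_run_steps, 0)   # -> [2 3]
free_run_start = capt[np.arange(len(capt)), n_teacher_forcing]   # -> [3 1]
print(n_teacher_forcing, free_run_start)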