# Graph inputs fed at run time.
#   input_x   -- int32, shape [None, data_reader.sequence_length]
#   labels_ph -- float32, shape [None, data_reader.num_classes]; note that it
#                is registered in the graph under the name "input_y"
# An alternative int64 labels placeholder of shape [None] is kept below,
# commented out.
# labels_ph = tf.placeholder(tf.int64, [None])
input_x = tf.placeholder(
    dtype=tf.int32,
    shape=[None, data_reader.sequence_length],
    name="input_x",
)
labels_ph = tf.placeholder(
    dtype=tf.float32,
    shape=[None, data_reader.num_classes],
    name="input_y",
)

# Embedding stage: look up a vector for every id in `input_x`, then append a
# trailing axis to the result.
with tf.variable_scope('embedding') as scope:
    # Embedding table "W" of shape [vocab_size, embedding_size], initialised
    # from a uniform distribution bounded by -1.0 and 1.0.
    lookup_embedding = tf.Variable(
        initial_value=tf.random_uniform(
            shape=[data_reader.vocab_size, data_reader.embedding_size],
            minval=-1.0,
            maxval=1.0,
        ),
        name="W",
    )
    embedded_chars = tf.nn.embedding_lookup(
        params=lookup_embedding,
        ids=input_x,
    )
    # Add a size-1 last axis; the tensor is stored as `images_ph`.
    images_ph = tf.expand_dims(embedded_chars, -1)

# Build the aux nets.
# Each helper network is constructed inside its own variable scope
# ('glimpse_net' and 'loc_net'). GlimpseNet receives the config together with
# the `images_ph` tensor; LocNet receives only the config.
with tf.variable_scope('glimpse_net'):
    gl = GlimpseNet(config, images_ph=images_ph)
with tf.variable_scope('loc_net'):
    loc_net = LocNet(config)

# Batch size (number of examples), read dynamically from the first dimension
# of `input_x`.
N = tf.shape(input_x)[0]

# One random starting location per example: an (N, 2) tensor sampled
# uniformly between -1 and 1, passed to the glimpse network to obtain the
# first glimpse.
init_loc = tf.random_uniform(shape=(N, 2), minval=-1, maxval=1)
init_glimpse = gl(init_loc)

# Core network: an LSTM cell with `config.cell_size` units and an all-zero
# initial state for N examples.
lstm_cell = rnn_cell.LSTMCell(config.cell_size, state_is_tuple=True)
init_state = lstm_cell.zero_state(N, tf.float32)

# Decoder inputs: the initial glimpse followed by `config.num_glimpses`
# filler zeros (presumably replaced by a loop function inside the decoder
# call that follows -- TODO confirm).
inputs = [init_glimpse] + [0] * (config.num_glimpses)
outputs, _ = seq2seq.rnn_decoder(inputs,
                                 init_state,
                                 lstm_cell,

# Placeholders fed at run time.
#   images_ph -- float32, shape [None, original_size * original_size *
#                num_channels] (one flat vector per example)
#   labels_ph -- int64, shape [None]
images_ph = tf.placeholder(
    dtype=tf.float32,
    shape=[
        None,
        config.original_size * config.original_size * config.num_channels,
    ],
)
labels_ph = tf.placeholder(dtype=tf.int64, shape=[None])

# Monte Carlo sampling, see Eqn (2): repeat the batch `config.M` times along
# the first axis. Images and labels are tiled with the same factor.
images_expanded = tf.tile(input=images_ph, multiples=[config.M, 1])
labels_expanded = tf.tile(input=labels_ph, multiples=[config.M])

# Build the aux nets.
# Each helper network is constructed inside its own variable scope
# ('glimpse_net' and 'loc_net'). GlimpseNet is given the tiled tensor
# `images_expanded`; the variant taking the untiled `images_ph` is kept
# commented out. LocNet receives only the config.
with tf.variable_scope('glimpse_net'):
    # gl = GlimpseNet(config, images_ph)
    gl = GlimpseNet(config, images_expanded)
with tf.variable_scope('loc_net'):
    loc_net = LocNet(config)

# Batch size (number of examples), read dynamically from the first dimension
# of the tiled images; the untiled variant is kept commented out.
# N = tf.shape(images_ph)[0]
N = tf.shape(images_expanded)[0]

# One random starting location per example: an (N, 2) tensor sampled
# uniformly between -1 and 1, passed to the glimpse network to obtain the
# first glimpse.
init_loc = tf.random_uniform(shape=(N, 2), minval=-1, maxval=1)
init_glimpse = gl(init_loc)

# Core network: an LSTM cell with `config.cell_size` units and an all-zero
# initial state for N examples.
lstm_cell = rnn_cell.LSTMCell(config.cell_size, state_is_tuple=True)
init_state = lstm_cell.zero_state(N, tf.float32)

# Decoder inputs: the initial glimpse followed by `config.num_glimpses`
# filler zeros (presumably replaced by a loop function inside the decoder
# call that follows -- TODO confirm).
inputs = [init_glimpse] + [0] * (config.num_glimpses)
outputs, _ = seq2seq.rnn_decoder(inputs,
                                 init_state,