    def build_decoder_rnn(self, first_step):
        """
        This function build a decoder
        if first_step is true, the state is initialized by mean context
        if first_step is not true, the states are placeholder, and should be assigned.
        """
        with tf.variable_scope("rnnlm"):
            flattened_ctx = tf.reshape(self.context, [self.batch_size, 196, 512])
            ctx_mean = tf.reduce_mean(flattened_ctx, 1)

            self.decoder_prev_word = tf.placeholder(tf.int32, [None])            
            if first_step:
                rnn_input = tf.nn.embedding_lookup(self.Wemb, tf.zeros([self.batch_size], tf.int32))
            else:
                rnn_input = tf.nn.embedding_lookup(self.Wemb, self.decoder_prev_word)

            tf.get_variable_scope().reuse_variables()
            if not first_step:
                initial_state = utils.get_placeholder_state(self.cell.state_size)
                self.decoder_flattened_state = utils.flatten_state(initial_state)
            else:
                initial_state = utils.get_initial_state(ctx_mean, self.cell.state_size)

            outputs, state = tf.contrib.legacy_seq2seq.attention_decoder([rnn_input], initial_state, flattened_ctx, self.cell, initial_state_attention = not first_step)
            logits = slim.fully_connected(outputs[0], self.vocab_size + 1, activation_fn = None, scope = 'logit')
            decoder_probs = tf.reshape(tf.nn.softmax(logits), [self.batch_size, self.vocab_size + 1])
            decoder_state = utils.flatten_state(state)

        # Return the probabilities and the flattened state for the next time step
        return [decoder_probs, decoder_state]
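    # A minimal usage sketch of the step-by-step decoder above (illustrative only;
    # assumes a tf.Session `sess`, an image batch `img_batch`, that the graph was built
    # twice as `first_probs_state = build_decoder_rnn(True)` and
    # `next_probs_state = build_decoder_rnn(False)`, and that utils.flatten_state
    # returns a list of tensors matching the placeholders in self.decoder_flattened_state):
    #
    #   probs, state = sess.run(first_probs_state, {self.images: img_batch})
    #   # pick the next word(s) from `probs`, then advance one step, feeding the
    #   # flattened state placeholders with the state returned by the previous step:
    #   feed = {self.images: img_batch, self.decoder_prev_word: prev_words}
    #   feed.update(dict(zip(self.decoder_flattened_state, state)))
    #   probs, state = sess.run(next_probs_state, feed)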
    def build_generator(self):
        """
        Generator for generating captions
        Support sample max or sample from distribution
        No Beam search here; beam search is in decoder
        """
        # Variables for the sample setting
        self.sample_max = tf.Variable(True, trainable = False, name = "sample_max")
        self.sample_temperature = tf.Variable(1.0, trainable = False, name = "temperature")

        self.generator = []
        with tf.variable_scope("rnnlm"):
            flattened_ctx = tf.reshape(self.context, [self.batch_size, 196, 512])
            ctx_mean = tf.reduce_mean(flattened_ctx, 1)

            tf.get_variable_scope().reuse_variables()

            initial_state = utils.get_initial_state(ctx_mean, self.cell.state_size)

            # Project the context for the attention module; done outside the loop
            # to avoid redundant computation at every time step.
            if self.att_hid_size == 0:
                pctx = slim.fully_connected(flattened_ctx, 1, activation_fn = None, scope = 'ctx_att') # (batch) * 196 * 1
            else:
                pctx = slim.fully_connected(flattened_ctx, self.att_hid_size, activation_fn = None, scope = 'ctx_att') # (batch) * 196 * att_hid_size

            rnn_input = tf.nn.embedding_lookup(self.Wemb, tf.zeros([self.batch_size], tf.int32))

            prev_h = utils.last_hidden_vec(initial_state)

            self.g_alphas = []
            outputs = []
            state = initial_state
            for ind in range(MAX_STEPS):

                with tf.variable_scope("attention"):
                    alpha = self.get_alpha(prev_h, pctx)
                    self.g_alphas.append(alpha)
                    weighted_context = tf.reduce_sum(flattened_ctx * tf.expand_dims(alpha, 2), 1)

                output, state = self.cell(tf.concat(axis=1, values=[weighted_context, rnn_input]), state)
                outputs.append(output)
                prev_h = output

                # Get the input of next timestep
                prev_logit = slim.fully_connected(prev_h, self.vocab_size + 1, activation_fn = None, scope = 'logit')
                prev_symbol = tf.stop_gradient(tf.cond(self.sample_max,
                    lambda: tf.argmax(prev_logit, 1), # pick the word with largest probability as the input of next time step
                    lambda: tf.squeeze(
                        tf.multinomial(tf.nn.log_softmax(prev_logit) / self.sample_temperature, 1), 1))) # Sample from the distribution
                self.generator.append(prev_symbol)
                rnn_input = tf.nn.embedding_lookup(self.Wemb, prev_symbol)
            
            self.g_output = output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, self.rnn_size]) # flatten the per-step outputs to [batch_size * MAX_STEPS, rnn_size]
            self.g_logits = logits = slim.fully_connected(output, self.vocab_size + 1, activation_fn = None, scope = 'logit')
            self.g_probs = probs = tf.reshape(tf.nn.softmax(logits), [self.batch_size, MAX_STEPS, self.vocab_size + 1])

        self.generator = tf.transpose(tf.reshape(tf.concat(axis=0, values=self.generator), [MAX_STEPS, -1]))
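    # Sampling behaviour is controlled by the two non-trainable variables created above.
    # A hedged run-time sketch (assumes a tf.Session `sess`; the assign ops are
    # illustrative and would normally be created once and reused):
    #
    #   sess.run([tf.assign(self.sample_max, False),
    #             tf.assign(self.sample_temperature, 0.7)])
    #   captions = sess.run(self.generator, {self.images: img_batch})  # shape [batch_size, MAX_STEPS]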
    def build_model(self):
        with tf.name_scope("batch_size"):
            # Get batch_size from the first dimension of self.images
            self.batch_size = tf.shape(self.images)[0]
        with tf.variable_scope("rnnlm"):
            flattened_ctx = tf.reshape(self.context, [self.batch_size, 196, 512])
            ctx_mean = tf.reduce_mean(flattened_ctx, 1)
            
            # Initialize the first hidden state with the mean context
            initial_state = utils.get_initial_state(ctx_mean, self.cell.state_size)
            # Replicate self.seq_per_img times for each state and image embedding
            self.initial_state = initial_state = utils.expand_feat(initial_state, self.seq_per_img)
            self.flattened_ctx = flattened_ctx = tf.reshape(tf.tile(tf.expand_dims(flattened_ctx, 1), [1, self.seq_per_img, 1, 1]), 
                [self.batch_size * self.seq_per_img, 196, 512])

            rnn_inputs = tf.split(axis=1, num_or_size_splits=self.seq_length + 1, value=tf.nn.embedding_lookup(self.Wemb, self.labels[:,:self.seq_length + 1]))
            rnn_inputs = [tf.squeeze(input_, [1]) for input_ in rnn_inputs]

            outputs, last_state = tf.contrib.legacy_seq2seq.attention_decoder(rnn_inputs, initial_state, flattened_ctx, self.cell, loop_function=None)
            outputs = tf.concat(axis=0, values=outputs)

            self.logits = slim.fully_connected(outputs, self.vocab_size + 1, activation_fn = None, scope = 'logit')
            self.logits = tf.split(axis=0, num_or_size_splits=len(rnn_inputs), value=self.logits)

        with tf.variable_scope("loss"):
            loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(self.logits,
                    [tf.squeeze(label, [1]) for label in tf.split(axis=1, num_or_size_splits=self.seq_length + 1, value=self.labels[:, 1:])], # self.labels[:,1:] is the target
                    [tf.squeeze(mask, [1]) for mask in tf.split(axis=1, num_or_size_splits=self.seq_length + 1, value=self.masks[:, 1:])])
            self.cost = tf.reduce_mean(loss)

        self.final_state = last_state
        self.lr = tf.Variable(0.0, trainable=False)
        self.cnn_lr = tf.Variable(0.0, trainable=False)

        # Collect the rnn variables, and create the optimizer of rnn
        tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='rnnlm')
        grads = utils.clip_by_value(tf.gradients(self.cost, tvars), -self.opt.grad_clip, self.opt.grad_clip)
        #grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
        #        self.opt.grad_clip)
        optimizer = utils.get_optimizer(self.opt, self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

        # Collect the cnn variables, and create the optimizer of cnn
        cnn_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='cnn')
        cnn_grads = utils.clip_by_value(tf.gradients(self.cost, cnn_tvars), -self.opt.grad_clip, self.opt.grad_clip)
        #cnn_grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, cnn_tvars),
        #        self.opt.grad_clip)
        cnn_optimizer = utils.get_cnn_optimizer(self.opt, self.cnn_lr)
        self.cnn_train_op = cnn_optimizer.apply_gradients(zip(cnn_grads, cnn_tvars))

        tf.summary.scalar('training loss', self.cost)
        tf.summary.scalar('learning rate', self.lr)
        tf.summary.scalar('cnn learning rate', self.cnn_lr)
        self.summaries = tf.summary.merge_all()
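    # A hedged sketch of a single training step with the graph built above (assumes a
    # tf.Session `sess` and that self.images, self.labels and self.masks are fed
    # directly; `img_batch`, `label_batch`, `mask_batch` are illustrative names):
    #
    #   feed = {self.images: img_batch, self.labels: label_batch, self.masks: mask_batch}
    #   _, _, train_loss = sess.run([self.train_op, self.cnn_train_op, self.cost], feed)
    #   summary = sess.run(self.summaries, feed)  # for TensorBoard logging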
    def build_decoder_rnn(self, first_step):
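        """
        Single-step decoder built with the explicit attention module (get_alpha)
        instead of tf.contrib.legacy_seq2seq.attention_decoder.
        If first_step is True, the state is initialized from the mean context;
        otherwise the state comes from placeholders that must be fed externally.
        """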
        with tf.variable_scope("rnnlm"):
            flattened_ctx = tf.reshape(self.context, [self.batch_size, 196, 512])
            ctx_mean = tf.reduce_mean(flattened_ctx, 1)

            tf.get_variable_scope().reuse_variables()

            if not first_step:
                initial_state = utils.get_placeholder_state(self.cell.state_size)
                self.decoder_flattened_state = utils.flatten_state(initial_state)
            else:
                initial_state = utils.get_initial_state(ctx_mean, self.cell.state_size)

            self.decoder_prev_word = tf.placeholder(tf.int32, [None])

            if first_step:
                rnn_input = tf.nn.embedding_lookup(self.Wemb, tf.zeros([self.batch_size], tf.int32))
            else:
                rnn_input = tf.nn.embedding_lookup(self.Wemb, self.decoder_prev_word)

            # Project the context for the attention module; done outside the loop
            # to avoid redundant computation at every time step.
            if self.att_hid_size == 0:
                pctx = slim.fully_connected(flattened_ctx, 1, activation_fn = None, scope = 'ctx_att') # (batch * seq_per_img) * 196 * 1
            else:
                pctx = slim.fully_connected(flattened_ctx, self.att_hid_size, activation_fn = None, scope = 'ctx_att') # (batch * seq_per_img) * 196 * att_hid_size

            prev_h = utils.last_hidden_vec(initial_state)

            alphas = []
            outputs = []

            with tf.variable_scope("attention"):
                alpha = self.get_alpha(prev_h, pctx)
                alphas.append(alpha)
                weighted_context = tf.reduce_sum(flattened_ctx * tf.expand_dims(alpha, 2), 1)

            output, state = self.cell(tf.concat(axis=1, values=[weighted_context, rnn_input]), initial_state)
            logits = slim.fully_connected(output, self.vocab_size + 1, activation_fn = None, scope = 'logit')
            decoder_probs = tf.reshape(tf.nn.softmax(logits), [self.batch_size, self.vocab_size + 1])
            decoder_state = utils.flatten_state(state)
        return [decoder_probs, decoder_state]
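    # Note: get_alpha is defined elsewhere in this class. From the way it is used here
    # (pctx of shape [batch, 196, att_hid_size] and an alpha that is broadcast over the
    # 196 regions), it presumably implements standard soft attention in the spirit of
    # Show, Attend and Tell, i.e. roughly
    #   e_i = w^T tanh(pctx_i + W_h * prev_h),   alpha = softmax(e)
    # but the exact form is an assumption, not shown in this section.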
    def build_generator(self):
        """
        Generator for generating captions
        Support sample max or sample from distribution
        No Beam search here; beam search is in decoder
        """
        # Variables for the sample setting
        self.sample_max = tf.Variable(True, trainable = False, name = "sample_max")
        self.sample_temperature = tf.Variable(1.0, trainable = False, name = "temperature")

        self.generator = []
        with tf.variable_scope("rnnlm") as rnnlm_scope:
            flattened_ctx = tf.reshape(self.context, [self.batch_size, 196, 512])
            ctx_mean = tf.reduce_mean(flattened_ctx, 1)

            tf.get_variable_scope().reuse_variables()

            initial_state = utils.get_initial_state(ctx_mean, self.cell.state_size)

            rnn_inputs = [tf.nn.embedding_lookup(self.Wemb, tf.zeros([self.batch_size], tf.int32))] + [0] * (MAX_STEPS - 1)

            # Feed the previous prediction back as the next input: greedy argmax when
            # sample_max is True, otherwise a sample from the softmax distribution
            def loop(prev, i):
                with tf.variable_scope(rnnlm_scope):
                    prev = slim.fully_connected(prev, self.vocab_size + 1, activation_fn = None, scope = 'logit')                
                    prev_symbol = tf.stop_gradient(tf.cond(self.sample_max,
                        lambda: tf.argmax(prev, 1), # pick the word with largest probability as the input of next time step
                        lambda: tf.squeeze(
                            tf.multinomial(tf.nn.log_softmax(prev) / self.sample_temperature, 1), 1))) # Sample from the distribution
                    self.generator.append(prev_symbol)
                    return tf.nn.embedding_lookup(self.Wemb, prev_symbol)

            outputs, last_state = tf.contrib.legacy_seq2seq.attention_decoder(rnn_inputs, initial_state, flattened_ctx, self.cell, loop_function=loop)
            self.g_outputs = outputs = tf.reshape(tf.concat(axis=1, values=outputs), [-1, self.rnn_size]) 
            self.g_logits = logits = slim.fully_connected(outputs, self.vocab_size + 1, activation_fn = None, scope = 'logit')
            self.g_probs = probs = tf.reshape(tf.nn.softmax(logits), [self.batch_size, MAX_STEPS, self.vocab_size + 1])

        self.generator = tf.transpose(tf.reshape(tf.concat(axis=0, values=self.generator), [MAX_STEPS - 1, -1]))
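    # Note on the reshape to [MAX_STEPS - 1, -1]: attention_decoder only calls the
    # loop_function from the second time step onwards (the first input is the embedding
    # of the start token, index 0), so `loop` runs MAX_STEPS - 1 times and self.generator
    # collects MAX_STEPS - 1 symbols per example. The `0` entries in rnn_inputs are
    # placeholders that are never read once a loop_function is supplied.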
    def build_model(self):
        with tf.name_scope("batch_size"):
            # Get batch_size from the first dimension of self.images
            self.batch_size = tf.shape(self.images)[0]
        with tf.variable_scope("rnnlm"):
            # Flatten the context
            flattened_ctx = tf.reshape(self.context,
                                       [self.batch_size, 196, 512])

            # Initialize the first hidden state with the mean context
            initial_state = utils.get_initial_state(self.fc7,
                                                    self.cell.state_size)
            # Replicate self.seq_per_img times for each state and image embedding
            self.initial_state = initial_state = utils.expand_feat(
                initial_state, self.seq_per_img)
            self.flattened_ctx = flattened_ctx = tf.reshape(
                tf.tile(tf.expand_dims(flattened_ctx, 1),
                        [1, self.seq_per_img, 1, 1]),
                [self.batch_size * self.seq_per_img, 196, 512])

            # Project the context for the attention module; done outside the loop
            # to avoid redundant computation at every time step.
            if self.att_hid_size == 0:
                pctx = slim.fully_connected(
                    self.flattened_ctx, 1, activation_fn=None,
                    scope='ctx_att')  # (batch * seq_per_img) * 196 * 1
            else:
                pctx = slim.fully_connected(
                    self.flattened_ctx,
                    self.att_hid_size,
                    activation_fn=None,
                    scope='ctx_att'
                )  # (batch * seq_per_img) * 196 * att_hid_size

            rnn_inputs = tf.split(axis=1,
                                  num_or_size_splits=self.seq_length + 1,
                                  value=tf.nn.embedding_lookup(
                                      self.Wemb,
                                      self.labels[:, :self.seq_length + 1]))
            rnn_inputs = [tf.squeeze(input_, [1]) for input_ in rnn_inputs]

            prev_h = utils.last_hidden_vec(initial_state)

            self.alphas = []
            self.logits = []
            outputs = []
            state = initial_state
            for ind in range(self.seq_length + 1):
                if ind > 0:
                    # Reuse the variables after the first timestep.
                    tf.get_variable_scope().reuse_variables()

                with tf.variable_scope("attention"):
                    alpha = self.get_alpha(prev_h, pctx)
                    self.alphas.append(alpha)
                    weighted_context = tf.reduce_sum(
                        flattened_ctx * tf.expand_dims(alpha, 2), 1)

                output, state = self.cell(
                    tf.concat(axis=1,
                              values=[weighted_context, rnn_inputs[ind]]),
                    state)
                # Save the current output for next time step attention
                prev_h = output
                # Get the score of each word in the vocabulary; index 0 is the end token.
                self.logits.append(
                    slim.fully_connected(output,
                                         self.vocab_size + 1,
                                         activation_fn=None,
                                         scope='logit'))

        with tf.variable_scope("loss"):
            loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
                self.logits,
                [
                    tf.squeeze(label, [1])
                    for label in tf.split(axis=1,
                                          num_or_size_splits=self.seq_length +
                                          1,
                                          value=self.labels[:, 1:])
                ],  # self.labels[:,1:] is the target; ignore the first start token
                [
                    tf.squeeze(mask, [1])
                    for mask in tf.split(axis=1,
                                         num_or_size_splits=self.seq_length +
                                         1,
                                         value=self.masks[:, 1:])
                ])
            self.cost = tf.reduce_mean(loss)

        self.final_state = state
        self.lr = tf.Variable(0.0, trainable=False)
        self.cnn_lr = tf.Variable(0.0, trainable=False)

        # Collect the rnn variables, and create the optimizer of rnn
        tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='rnnlm')
        grads = utils.clip_by_value(tf.gradients(self.cost, tvars),
                                    -self.opt.grad_clip, self.opt.grad_clip)
        optimizer = utils.get_optimizer(self.opt, self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

        # Collect the cnn variables, and create the optimizer of cnn
        cnn_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      scope='cnn')
        cnn_grads = utils.clip_by_value(tf.gradients(self.cost, cnn_tvars),
                                        -self.opt.grad_clip,
                                        self.opt.grad_clip)
        cnn_optimizer = utils.get_cnn_optimizer(self.opt, self.cnn_lr)
        self.cnn_train_op = cnn_optimizer.apply_gradients(
            zip(cnn_grads, cnn_tvars))

        tf.summary.scalar('training loss', self.cost)
        tf.summary.scalar('learning rate', self.lr)
        tf.summary.scalar('cnn learning rate', self.cnn_lr)
        self.summaries = tf.summary.merge_all()
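    # How the loss lines up with the data (a hedged example, assuming the usual
    # neuraltalk-style convention where column 0 of self.labels is the start token and
    # self.masks is 1 up to and including the end token):
    #
    #   labels[k] = [<start>, w1, w2, w3, <end>, 0, ...]   (seq_length + 2 columns)
    #   masks[k]  = [1,       1,  1,  1,  1,     0, ...]
    #
    # The inputs are labels[:, :seq_length + 1] and the targets are labels[:, 1:], so the
    # model is trained to predict w1 .. <end>; masks[:, 1:] zeroes out the loss on the
    # padding positions inside sequence_loss_by_example.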
    def build_model(self):
        with tf.name_scope("batch_size"):
            # Get batch_size from the first dimension of self.images
            self.batch_size = tf.shape(self.images)[0]
        with tf.variable_scope("rnnlm"):
            # Flatten the context
            flattened_ctx = tf.reshape(self.context, [self.batch_size, 196, 512])
            ctx_mean = tf.reduce_mean(flattened_ctx, 1)
            
            # Initialize the first hidden state with the mean context
            initial_state = utils.get_initial_state(ctx_mean, self.cell.state_size)
            # Replicate self.seq_per_img times for each state and image embedding
            self.initial_state = initial_state = utils.expand_feat(initial_state, self.seq_per_img)
            self.flattened_ctx = flattened_ctx = tf.reshape(tf.tile(tf.expand_dims(flattened_ctx, 1), [1, self.seq_per_img, 1, 1]), 
                [self.batch_size * self.seq_per_img, 196, 512])

            # Project the context for the attention module; done outside the loop
            # to avoid redundant computation at every time step.
            if self.att_hid_size == 0:
                pctx = slim.fully_connected(self.flattened_ctx, 1, activation_fn = None, scope = 'ctx_att') # (batch * seq_per_img) * 196 * 1
            else:
                pctx = slim.fully_connected(self.flattened_ctx, self.att_hid_size, activation_fn = None, scope = 'ctx_att') # (batch * seq_per_img) * 196 * att_hid_size

            rnn_inputs = tf.split(axis=1, num_or_size_splits=self.seq_length + 1, value=tf.nn.embedding_lookup(self.Wemb, self.labels[:,:self.seq_length + 1]))
            rnn_inputs = [tf.squeeze(input_, [1]) for input_ in rnn_inputs]

            prev_h = utils.last_hidden_vec(initial_state)

            self.alphas = []
            self.logits = []
            outputs = []
            state = initial_state
            for ind in range(self.seq_length + 1):
                if ind > 0:
                    # Reuse the variables after the first timestep.
                    tf.get_variable_scope().reuse_variables()

                with tf.variable_scope("attention"):
                    alpha = self.get_alpha(prev_h, pctx)
                    self.alphas.append(alpha)
                    weighted_context = tf.reduce_sum(flattened_ctx * tf.expand_dims(alpha, 2), 1)
                    
                output, state = self.cell(tf.concat(axis=1, values=[weighted_context, rnn_inputs[ind]]), state)
                # Save the current output for next time step attention
                prev_h = output
                # Get the score of each word in the vocabulary; index 0 is the end token.
                self.logits.append(slim.fully_connected(output, self.vocab_size + 1, activation_fn = None, scope = 'logit'))
                
        with tf.variable_scope("loss"):
            loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
                    self.logits,
                    [tf.squeeze(label, [1]) for label in tf.split(axis=1, num_or_size_splits=self.seq_length + 1, value=self.labels[:, 1:])], # self.labels[:,1:] is the target; ignore the first start token
                    [tf.squeeze(mask, [1]) for mask in tf.split(axis=1, num_or_size_splits=self.seq_length + 1, value=self.masks[:, 1:])])
            self.cost = tf.reduce_mean(loss)

        self.final_state = state
        self.lr = tf.Variable(0.0, trainable=False)
        self.cnn_lr = tf.Variable(0.0, trainable=False)

        # Collect the rnn variables, and create the optimizer of rnn
        tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='rnnlm')
        grads = utils.clip_by_value(tf.gradients(self.cost, tvars), -self.opt.grad_clip, self.opt.grad_clip)
        optimizer = utils.get_optimizer(self.opt, self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

        # Collect the cnn variables, and create the optimizer of cnn
        cnn_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='cnn')
        cnn_grads = utils.clip_by_value(tf.gradients(self.cost, cnn_tvars), -self.opt.grad_clip, self.opt.grad_clip)
        cnn_optimizer = utils.get_cnn_optimizer(self.opt, self.cnn_lr) 
        self.cnn_train_op = cnn_optimizer.apply_gradients(zip(cnn_grads, cnn_tvars))

        tf.summary.scalar('training loss', self.cost)
        tf.summary.scalar('learning rate', self.lr)
        tf.summary.scalar('cnn learning rate', self.cnn_lr)
        self.summaries = tf.summary.merge_all()
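    # Note: utils.clip_by_value presumably clips each gradient tensor element-wise to
    # [-grad_clip, grad_clip] (e.g. [tf.clip_by_value(g, -c, c) for g in grads]); this
    # differs from the global-norm clipping (tf.clip_by_global_norm) that appears
    # commented out in the attention_decoder-based build_model above.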