def __call__(self, inputs, state):
    # Advance the language LSTM one step.
    l_outputs, l_next_state = self.language_lstm(inputs, state)
    # Flatten the [batch, height, width, depth] feature map into
    # [batch, height * width, depth] so attention runs over locations.
    image_height = tf.shape(self.spatial_image_features)[1]
    image_width = tf.shape(self.spatial_image_features)[2]
    image_features = collapse_dims(self.spatial_image_features, [1, 2])
    # Pair every spatial location with a copy of the LSTM output, score
    # each location, and take the attention-weighted sum of the features.
    attn_inputs = tf.concat([
        image_features,
        tile_with_new_axis(l_outputs, [image_height * image_width], [1])], 2)
    attended_features = tf.reduce_sum(
        image_features * self.attn_layer(attn_inputs), [1])
    return tf.concat([attended_features, l_outputs], 1), l_next_state
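The helpers collapse_dims and tile_with_new_axis are utilities from the snippet's own repository and are not shown here. Below is a minimal sketch of what they plausibly do, assuming collapse_dims merges a run of contiguous axes into one and tile_with_new_axis inserts fresh axes and tiles along them; both names and semantics are inferred from the call sites above.

import tensorflow as tf

def collapse_dims(tensor, axes):
    # Merge the contiguous axes in `axes` into a single axis, e.g. turn a
    # [batch, height, width, depth] map into [batch, height * width, depth].
    shape = tf.shape(tensor)
    first, last = axes[0], axes[-1]
    merged = tf.reduce_prod(shape[first:last + 1], keepdims=True)
    new_shape = tf.concat([shape[:first], merged, shape[last + 1:]], 0)
    return tf.reshape(tensor, new_shape)

def tile_with_new_axis(tensor, repeats, axes):
    # Insert a new axis at each position in `axes` and tile it the matching
    # number of times, e.g. broadcast a [batch, depth] query across
    # height * width locations to get [batch, height * width, depth].
    for repeat, axis in sorted(zip(repeats, axes), key=lambda pair: pair[1]):
        tensor = tf.expand_dims(tensor, axis)
        multiples = [1] * len(tensor.shape)
        multiples[axis] = repeat
        tensor = tf.tile(tensor, tf.stack(multiples))
    return tensor

With these definitions, the attn_inputs concat above pairs each of the height * width feature vectors with a copy of l_outputs along the depth axis.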
def __call__(self, inputs, state):
    # Advance the language LSTM one step.
    l_outputs, l_next_state = self.language_lstm(inputs, state)
    # Visual sentinel: a gated view of the new memory cell that lets the
    # model attend to "nothing visual" when the next word is non-visual.
    sentinel_embeddings = self.sentinel_embeddings_layer(
        tf.nn.tanh(l_next_state.c) *
        self.sentinel_gate_layer(tf.concat([state.h, inputs], 1)))
    # Flatten the spatial feature map to [batch, height * width, depth].
    image_height = tf.shape(self.spatial_image_features)[1]
    image_width = tf.shape(self.spatial_image_features)[2]
    image_features = collapse_dims(self.spatial_image_features, [1, 2])
    # Append the sentinel as one extra attention candidate.
    sentinel_image_features = tf.concat(
        [image_features, tf.expand_dims(sentinel_embeddings, 1)], 1)
    # Score all height * width + 1 candidates against the LSTM output.
    attn_inputs = tf.nn.tanh(
        tf.concat([
            sentinel_image_features,
            tile_with_new_axis(l_outputs, [image_height * image_width + 1], [1])
        ], 2))
    attended_sif = tf.reduce_sum(
        sentinel_image_features * self.attn_layer(attn_inputs), [1])
    return tf.concat([attended_sif, l_outputs], 1), l_next_state
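The extra candidate row here is a visual sentinel in the style of adaptive attention captioning ("Knowing When to Look", Lu et al. 2017): the softmax weight assigned to the sentinel measures how much the decoder ignores the image at this step. To make the shapes concrete, here is a self-contained toy version of the sentinel-augmented attention step. It assumes the scoring layer is a Dense(1) followed by a softmax over candidates; the snippet's attn_layer presumably normalizes internally, which is not shown.

import tensorflow as tf

# Toy shapes: 49 spatial locations plus 1 sentinel candidate.
batch, locations, depth = 2, 49, 8
image_features = tf.random_normal([batch, locations, depth])
sentinel = tf.random_normal([batch, 1, depth])
l_outputs = tf.random_normal([batch, depth])

# Candidates: every spatial location plus the sentinel.
candidates = tf.concat([image_features, sentinel], 1)
# Broadcast the LSTM output across all locations + 1 candidates.
query = tf.tile(tf.expand_dims(l_outputs, 1), [1, locations + 1, 1])
# Score each candidate and normalize; the sentinel's weight is how much
# the model chooses to ignore the image at this step.
scores = tf.layers.dense(tf.nn.tanh(tf.concat([candidates, query], 2)), 1)
weights = tf.nn.softmax(scores, axis=1)
attended = tf.reduce_sum(candidates * weights, 1)  # [batch, depth]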
Example #3
def __call__(self,
        top_k_attributes,
        mean_image_features=None,
        mean_object_features=None,
        spatial_image_features=None,
        spatial_object_features=None,
        seq_inputs=None, lengths=None):
    # At least one kind of visual feature must be provided.
    assert (mean_image_features is not None or mean_object_features is not None or
        spatial_image_features is not None or spatial_object_features is not None)
    # Embed the top-k predicted attributes and pool them into one vector.
    attribute_features = tf.nn.embedding_lookup(
        self.attribute_embeddings_map, top_k_attributes)
    mean_attribute_features = tf.reduce_mean(attribute_features, [1])
    # Without ground-truth inputs and lengths, decode with beam search.
    use_beam_search = (seq_inputs is None or lengths is None)
    # Fuse the attribute embeddings into whichever feature type was given,
    # collapsing spatial axes first where necessary.
    if mean_image_features is not None:
        batch_size = tf.shape(mean_image_features)[0]
        mean_image_features = tf.concat(
            [mean_image_features, mean_attribute_features], 1)
    elif mean_object_features is not None:
        batch_size = tf.shape(mean_object_features)[0]
        mean_object_features = tf.concat(
            [mean_object_features, attribute_features], 1)
    elif spatial_image_features is not None:
        batch_size = tf.shape(spatial_image_features)[0]
        spatial_image_features = collapse_dims(spatial_image_features, [1, 2])
        mean_image_features = tf.concat(
            [tf.reduce_mean(spatial_image_features, [1]),
             mean_attribute_features], 1)
        spatial_image_features = tf.concat(
            [spatial_image_features, attribute_features], 1)
    elif spatial_object_features is not None:
        batch_size = tf.shape(spatial_object_features)[0]
        spatial_object_features = collapse_dims(spatial_object_features, [2, 3])
        mean_object_features = tf.concat(
            [tf.reduce_mean(spatial_object_features, [2]),
             attribute_features], 1)
        spatial_object_features = tf.concat(
            [spatial_object_features,
             tf.expand_dims(attribute_features, 2)], 2)
    initial_state = self.image_caption_cell.zero_state(batch_size, tf.float32)
    if use_beam_search:
        # Tile every feature tensor and the initial state across the beam.
        if mean_image_features is not None:
            mean_image_features = seq2seq.tile_batch(
                mean_image_features, multiplier=self.beam_size)
            self.image_caption_cell.mean_image_features = mean_image_features
        if mean_object_features is not None:
            mean_object_features = seq2seq.tile_batch(
                mean_object_features, multiplier=self.beam_size)
            self.image_caption_cell.mean_object_features = mean_object_features
        if spatial_image_features is not None:
            spatial_image_features = seq2seq.tile_batch(
                spatial_image_features, multiplier=self.beam_size)
            self.image_caption_cell.spatial_image_features = spatial_image_features
        if spatial_object_features is not None:
            spatial_object_features = seq2seq.tile_batch(
                spatial_object_features, multiplier=self.beam_size)
            self.image_caption_cell.spatial_object_features = spatial_object_features
        initial_state = seq2seq.tile_batch(initial_state, multiplier=self.beam_size)
        decoder = seq2seq.BeamSearchDecoder(
            self.image_caption_cell, self.word_embeddings_map,
            tf.fill([batch_size], self.word_vocabulary.start_id),
            self.word_vocabulary.end_id, initial_state, self.beam_size,
            output_layer=self.word_logits_layer)
        outputs, state, lengths = seq2seq.dynamic_decode(
            decoder, maximum_iterations=self.maximum_iterations)
        # predicted_ids is [batch, time, beam]; flatten the beams into the
        # batch axis and prepend the start token so the forward pass below
        # sees the same inputs a teacher-forced run would.
        ids = tf.transpose(outputs.predicted_ids, [0, 2, 1])
        sequence_length = tf.shape(ids)[2]
        flat_ids = tf.reshape(ids, [batch_size * self.beam_size, sequence_length])
        seq_inputs = tf.concat([
            tf.fill([batch_size * self.beam_size, 1],
                    self.word_vocabulary.start_id), flat_ids], 1)
    # Attach the (possibly tiled) features to the cell, then run a full
    # forward pass over seq_inputs to obtain per-step logits.
    if mean_image_features is not None:
        self.image_caption_cell.mean_image_features = mean_image_features
    if mean_object_features is not None:
        self.image_caption_cell.mean_object_features = mean_object_features
    if spatial_image_features is not None:
        self.image_caption_cell.spatial_image_features = spatial_image_features
    if spatial_object_features is not None:
        self.image_caption_cell.spatial_object_features = spatial_object_features
    activations, _state = tf.nn.dynamic_rnn(
        self.image_caption_cell,
        tf.nn.embedding_lookup(self.word_embeddings_map, seq_inputs),
        sequence_length=tf.reshape(lengths, [-1]), initial_state=initial_state)
    logits = self.word_logits_layer(activations)
    if use_beam_search:
        # Un-flatten the beam axis: [batch, beam, time, vocab].
        length = tf.shape(logits)[1]
        logits = tf.reshape(
            logits, [batch_size, self.beam_size, length, self.vocab_size])
    return logits, tf.argmax(logits, axis=-1, output_type=tf.int32)
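The beam search plumbing above follows the standard tf.contrib.seq2seq pattern. Here is a self-contained sketch of that pattern with a plain LSTMCell standing in for image_caption_cell; all sizes and the start/end ids are made-up illustration values, not taken from the snippet.

import tensorflow as tf
from tensorflow.contrib import seq2seq

# Made-up illustration values.
batch_size, beam_size, vocab_size, embed_dim = 4, 3, 100, 32
start_id, end_id = 1, 2

embeddings = tf.get_variable("embeddings", [vocab_size, embed_dim])
cell = tf.nn.rnn_cell.LSTMCell(64)
output_layer = tf.layers.Dense(vocab_size)

# The decoder state (and any tensors the cell reads) must be tiled
# beam_size times along the batch axis.
initial_state = seq2seq.tile_batch(
    cell.zero_state(batch_size, tf.float32), multiplier=beam_size)

decoder = seq2seq.BeamSearchDecoder(
    cell, embeddings, tf.fill([batch_size], start_id), end_id,
    initial_state, beam_size, output_layer=output_layer)
outputs, state, lengths = seq2seq.dynamic_decode(
    decoder, maximum_iterations=20)

# predicted_ids is [batch, time, beam]; reorder to [batch, beam, time]
# so each row is one candidate caption.
ids = tf.transpose(outputs.predicted_ids, [0, 2, 1])

The snippet above goes one step further: it flattens the beam axis into the batch axis and re-runs the cell with tf.nn.dynamic_rnn over the decoded ids, recovering per-step logits for every beam.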