def __call__(self, inputs, state):
    # Advance the language LSTM one step.
    l_outputs, l_next_state = self.language_lstm(inputs, state)
    # Flatten the [batch, height, width, depth] feature map to
    # [batch, height * width, depth].
    image_height = tf.shape(self.spatial_image_features)[1]
    image_width = tf.shape(self.spatial_image_features)[2]
    image_features = collapse_dims(self.spatial_image_features, [1, 2])
    # Pair every spatial location with the current LSTM output.
    attn_inputs = tf.concat([
        image_features,
        tile_with_new_axis(l_outputs, [image_height * image_width], [1])], 2)
    # Weight the image features by the attention distribution and pool over locations.
    attended_features = tf.reduce_sum(
        image_features * self.attn_layer(attn_inputs), [1])
    return tf.concat([attended_features, l_outputs], 1), l_next_state
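
# A minimal, self-contained sketch of the soft-attention pooling performed by the
# cell above, written with plain TF 1.x ops and made-up shapes. Here `attn_layer`
# is approximated by a Dense scoring layer followed by a softmax over locations;
# the repo's actual layer may normalize differently.
def _soft_attention_sketch():
    import tensorflow as tf
    batch, num_locations, depth, hidden = 2, 49, 64, 32
    features = tf.random_normal([batch, num_locations, depth])  # flattened CNN map
    l_outputs = tf.random_normal([batch, hidden])               # language LSTM output
    # Pair every location with the LSTM output (the role of tile_with_new_axis above).
    tiled = tf.tile(tf.expand_dims(l_outputs, 1), [1, num_locations, 1])
    attn_inputs = tf.concat([features, tiled], 2)
    scores = tf.layers.dense(attn_inputs, 1)           # [batch, locations, 1]
    weights = tf.nn.softmax(scores, axis=1)            # distribute over locations
    return tf.reduce_sum(features * weights, 1)        # [batch, depth]
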
def __call__(self, inputs, state):
    # Advance the language LSTM one step.
    l_outputs, l_next_state = self.language_lstm(inputs, state)
    # Visual sentinel: modulate tanh(c_t) by a learned gate of the previous
    # hidden state and the current inputs, then embed the result.
    sentinel_embeddings = self.sentinel_embeddings_layer(
        tf.nn.tanh(l_next_state.c) * self.sentinel_gate_layer(
            tf.concat([state.h, inputs], 1)))
    # Flatten the [batch, height, width, depth] feature map to
    # [batch, height * width, depth].
    image_height = tf.shape(self.spatial_image_features)[1]
    image_width = tf.shape(self.spatial_image_features)[2]
    image_features = collapse_dims(self.spatial_image_features, [1, 2])
    # Append the sentinel as one extra attention candidate.
    sentinel_image_features = tf.concat(
        [image_features, tf.expand_dims(sentinel_embeddings, 1)], 1)
    # Pair every candidate (spatial locations plus the sentinel) with the LSTM output.
    attn_inputs = tf.nn.tanh(tf.concat([
        sentinel_image_features,
        tile_with_new_axis(l_outputs, [image_height * image_width + 1], [1])], 2))
    # Weight the candidates by the attention distribution and pool.
    attended_sif = tf.reduce_sum(
        sentinel_image_features * self.attn_layer(attn_inputs), [1])
    return tf.concat([attended_sif, l_outputs], 1), l_next_state
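
# A hedged sketch of the visual-sentinel gate used by the cell above, following
# "Knowing When to Look" (Lu et al., 2017): s_t = g_t * tanh(c_t) with
# g_t = sigmoid(W [h_{t-1}; x_t]). The sigmoid Dense gate below is an assumption
# about what sentinel_gate_layer computes; shapes are made up.
def _visual_sentinel_sketch():
    import tensorflow as tf
    batch, hidden, embed = 2, 32, 32
    prev_h = tf.random_normal([batch, hidden])   # h_{t-1}, the previous hidden state
    inputs = tf.random_normal([batch, embed])    # x_t, the current word embedding
    next_c = tf.random_normal([batch, hidden])   # c_t, the new LSTM memory cell
    gate = tf.layers.dense(
        tf.concat([prev_h, inputs], 1), hidden, activation=tf.sigmoid)
    return gate * tf.nn.tanh(next_c)             # the sentinel s_t
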
def __call__(self, top_k_attributes, mean_image_features=None,
             mean_object_features=None, spatial_image_features=None,
             spatial_object_features=None, seq_inputs=None, lengths=None):
    assert (mean_image_features is not None
            or mean_object_features is not None
            or spatial_image_features is not None
            or spatial_object_features is not None)
    # Embed the top-k predicted attributes and average the embeddings.
    attribute_features = tf.nn.embedding_lookup(
        self.attribute_embeddings_map, top_k_attributes)
    mean_attribute_features = tf.reduce_mean(attribute_features, [1])
    # Decode with beam search when no ground-truth sequence is provided.
    use_beam_search = (seq_inputs is None or lengths is None)
    # Fuse the attribute embeddings into whichever image features were given.
    if mean_image_features is not None:
        batch_size = tf.shape(mean_image_features)[0]
        mean_image_features = tf.concat(
            [mean_image_features, mean_attribute_features], 1)
    elif mean_object_features is not None:
        batch_size = tf.shape(mean_object_features)[0]
        mean_object_features = tf.concat(
            [mean_object_features, attribute_features], 1)
    elif spatial_image_features is not None:
        batch_size = tf.shape(spatial_image_features)[0]
        spatial_image_features = collapse_dims(spatial_image_features, [1, 2])
        mean_image_features = tf.concat(
            [tf.reduce_mean(spatial_image_features, [1]),
             mean_attribute_features], 1)
        spatial_image_features = tf.concat(
            [spatial_image_features, attribute_features], 1)
    elif spatial_object_features is not None:
        batch_size = tf.shape(spatial_object_features)[0]
        spatial_object_features = collapse_dims(spatial_object_features, [2, 3])
        mean_object_features = tf.concat(
            [tf.reduce_mean(spatial_object_features, [2]),
             attribute_features], 1)
        spatial_object_features = tf.concat(
            [spatial_object_features,
             tf.expand_dims(attribute_features, 2)], 2)
    initial_state = self.image_caption_cell.zero_state(batch_size, tf.float32)
    if use_beam_search:
        # Tile the features and the initial state across the beam.
        if mean_image_features is not None:
            mean_image_features = seq2seq.tile_batch(
                mean_image_features, multiplier=self.beam_size)
            self.image_caption_cell.mean_image_features = mean_image_features
        if mean_object_features is not None:
            mean_object_features = seq2seq.tile_batch(
                mean_object_features, multiplier=self.beam_size)
            self.image_caption_cell.mean_object_features = mean_object_features
        if spatial_image_features is not None:
            spatial_image_features = seq2seq.tile_batch(
                spatial_image_features, multiplier=self.beam_size)
            self.image_caption_cell.spatial_image_features = spatial_image_features
        if spatial_object_features is not None:
            spatial_object_features = seq2seq.tile_batch(
                spatial_object_features, multiplier=self.beam_size)
            self.image_caption_cell.spatial_object_features = spatial_object_features
        initial_state = seq2seq.tile_batch(
            initial_state, multiplier=self.beam_size)
        decoder = seq2seq.BeamSearchDecoder(
            self.image_caption_cell, self.word_embeddings_map,
            tf.fill([batch_size], self.word_vocabulary.start_id),
            self.word_vocabulary.end_id, initial_state, self.beam_size,
            output_layer=self.word_logits_layer)
        outputs, state, lengths = seq2seq.dynamic_decode(
            decoder, maximum_iterations=self.maximum_iterations)
        # predicted_ids is [batch, time, beam]; reorder to [batch, beam, time]
        # and flatten the beams into the batch dimension.
        ids = tf.transpose(outputs.predicted_ids, [0, 2, 1])
        sequence_length = tf.shape(ids)[2]
        flat_ids = tf.reshape(
            ids, [batch_size * self.beam_size, sequence_length])
        # Prepend the start token so the decoded beams can be re-scored below.
        seq_inputs = tf.concat([
            tf.fill([batch_size * self.beam_size, 1],
                    self.word_vocabulary.start_id),
            flat_ids], 1)
    # Point the cell at the (possibly tiled) features for the scoring pass.
    if mean_image_features is not None:
        self.image_caption_cell.mean_image_features = mean_image_features
    if mean_object_features is not None:
        self.image_caption_cell.mean_object_features = mean_object_features
    if spatial_image_features is not None:
        self.image_caption_cell.spatial_image_features = spatial_image_features
    if spatial_object_features is not None:
        self.image_caption_cell.spatial_object_features = spatial_object_features
    # Teacher-forced pass over either the ground-truth captions or the decoded
    # beams, producing per-step logits.
    activations, _state = tf.nn.dynamic_rnn(
        self.image_caption_cell,
        tf.nn.embedding_lookup(self.word_embeddings_map, seq_inputs),
        sequence_length=tf.reshape(lengths, [-1]),
        initial_state=initial_state)
    logits = self.word_logits_layer(activations)
    if use_beam_search:
        length = tf.shape(logits)[1]
        logits = tf.reshape(
            logits, [batch_size, self.beam_size, length, self.vocab_size])
    return logits, tf.argmax(logits, axis=-1, output_type=tf.int32)
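
# A hedged, self-contained sketch of the beam-search pattern used above (TF 1.x
# with tf.contrib.seq2seq; the toy vocabulary, start id 1, and end id 2 are
# assumptions, not repo values): tile the initial state across the beam, decode,
# then reorder predicted_ids from [batch, time, beam] to [batch, beam, time]
# exactly as the scoring pass above expects.
def _beam_search_sketch():
    import tensorflow as tf
    from tensorflow.contrib import seq2seq
    batch, vocab, hidden, beam = 2, 50, 32, 3
    embeddings = tf.get_variable("embeddings", [vocab, hidden])
    cell = tf.nn.rnn_cell.LSTMCell(hidden)
    # tile_batch replicates the zero state beam-times along the batch axis.
    initial_state = seq2seq.tile_batch(
        cell.zero_state(batch, tf.float32), multiplier=beam)
    decoder = seq2seq.BeamSearchDecoder(
        cell, embeddings, tf.fill([batch], 1), 2, initial_state, beam,
        output_layer=tf.layers.Dense(vocab))
    outputs, _state, _lengths = seq2seq.dynamic_decode(
        decoder, maximum_iterations=10)
    return tf.transpose(outputs.predicted_ids, [0, 2, 1])
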