def get_logits(self, image, is_train, **kwargs):
    """Build recognition logits via the BCNN → AON → FG → LSTM pipeline.

    Args:
        image: input batch, reshaped here to fixed 100x100 single-channel.
        is_train: training-mode flag forwarded to the sub-networks.
        **kwargs: must contain 'label' (ground-truth sequence fed to the
            attention decoder).

    Returns:
        (logits, sequence_length) — sequence_length is always None for
        this model.
    """
    x = tf.reshape(image, [-1, 100, 100, 1])

    # BCNN: basal convolutional feature extractor.
    x = self._bcnn(x, is_train)
    assert x.get_shape()[1:] == (26, 26, 256)

    # AON: yields four directional feature sequences plus a per-direction
    # clue (weighting) map.
    x, clue = self._aon(x, is_train)
    assert x.get_shape()[1:] == (4, 23, 512)
    assert clue.get_shape()[1:] == (4, 23, 1)

    # FG (filter gate): clue-weighted fusion across the four directions.
    fused = tf.nn.tanh(tf.reduce_sum(x * clue, axis=1))
    assert fused.get_shape()[1:] == (23, 512)

    # LSTM over the fused sequence (time-major), then attention decoding.
    seq = tf.transpose(fused, [1, 0, 2], name='time_major')
    seq = rnn_layer(seq, None, self.rnn_size, 'lstm')
    logits, _ = attention_decoder(seq,
                                  kwargs['label'],
                                  len(self.out_charset),
                                  self.rnn_size,
                                  is_train,
                                  self.FLAGS.label_maxlen)

    return logits, None
def get_logits(self, image, is_train, **kwargs):
    """Build logits from convnet features via an RNN encoder and an
    attention decoder.

    Args:
        image: input image batch.
        is_train: training-mode flag forwarded to the sub-networks.
        **kwargs: must contain 'label' (ground-truth sequence fed to the
            attention decoder).

    Returns:
        (logits, sequence_length) as produced by the decoder / convnet.
    """
    # No per-sample widths are provided, so every image is treated as
    # spanning the full tensor width.
    full_width = tf.shape(image)[2]
    widths = full_width * tf.ones(tf.shape(image)[0], dtype=tf.int32)

    conv_out, sequence_length = self._convnet_layers(image, widths, is_train)

    # Time-major layout for the recurrent encoder.
    time_major = tf.transpose(conv_out, perm=[1, 0, 2], name='time_major')
    attention_states = rnn_layer(time_major,
                                 sequence_length,
                                 self.rnn_size,
                                 scope="rnn")

    logits, _ = attention_decoder(attention_states,
                                  kwargs['label'],
                                  len(self.out_charset),
                                  self.rnn_size,
                                  is_train,
                                  self.FLAGS.label_maxlen)

    return logits, sequence_length
def get_logits(self, image, is_train, **kwargs):
    """Build logits from convnet features via stacked projected RNNs, a
    dense attention-state projection, and a GRU attention decoder.

    Args:
        image: input image batch.
        is_train: training-mode flag forwarded to the sub-networks.
        **kwargs: must contain 'label' (ground-truth sequence fed to the
            attention decoder).

    Returns:
        (logits, sequence_length) as produced by the decoder / convnet.
    """
    # No per-sample widths are provided, so every image is treated as
    # spanning the full tensor width.
    num_images = tf.shape(image)[0]
    widths = tf.ones(num_images, dtype=tf.int32) * tf.shape(image)[2]

    conv_out, sequence_length = self._convnet_layers(image, widths, is_train)

    # Recurrent encoder with output projection, then a dense layer to
    # shape the attention states.
    states = rnn_layers(conv_out,
                        sequence_length,
                        self.rnn_size,
                        use_projection=True)
    states = dense_layer(states, self.rnn_size, name='att_state_dense')

    logits, _ = attention_decoder(states,
                                  kwargs['label'],
                                  len(self.out_charset),
                                  self.rnn_size,
                                  is_train,
                                  self.FLAGS.label_maxlen,
                                  cell_type='gru')

    return logits, sequence_length