コード例 #1
0
    def get_logits(self, image, is_train, **kwargs):
        """Build the AON recognition graph and return the decoder logits.

        Stages, in order: reshape -> BCNN -> AON -> FG (clue-weighted sum
        followed by tanh) -> LSTM encoder -> attention decoder.

        Args:
            image: Input image tensor. It is reshaped to
                ``[-1, 100, 100, 1]`` before the backbone runs.
            is_train: Training-mode flag forwarded to ``self._bcnn``,
                ``self._aon`` and ``attention_decoder``.
            **kwargs: Must contain ``'label'``, which is forwarded to
                ``attention_decoder``.

        Returns:
            A ``(logits, sequence_length)`` tuple. ``logits`` is the first
            output of ``attention_decoder``; ``sequence_length`` is always
            ``None`` for this model.

        Raises:
            KeyError: If ``'label'`` is missing from ``kwargs``.
            AssertionError: (only when assertions are enabled) if an
                intermediate static shape differs from the expected one.
        """
        net = tf.reshape(image, [-1, 100, 100, 1])

        # BCNN backbone; static shape after the batch axis must be 26x26x256.
        net = self._bcnn(net, is_train)
        assert net.get_shape()[1:] == (26, 26, 256)

        # AON yields the feature tensor plus a matching clue tensor.
        net, clue = self._aon(net, is_train)
        assert net.get_shape()[1:] == (4, 23, 512)
        assert clue.get_shape()[1:] == (4, 23, 1)

        # FG: weight the features by the clue, sum over axis 1 (size 4),
        # then squash with tanh.
        gated = tf.nn.tanh(tf.reduce_sum(net * clue, axis=1))
        assert gated.get_shape()[1:] == (23, 512)

        # Swap axes 0 and 1 (named 'time_major') before the LSTM encoder.
        time_major = tf.transpose(gated, [1, 0, 2], name='time_major')
        encoded = rnn_layer(time_major, None, self.rnn_size, 'lstm')

        # Only the logits are needed; the second decoder output is ignored.
        logits, _ = attention_decoder(encoded,
                                      kwargs['label'],
                                      len(self.out_charset),
                                      self.rnn_size,
                                      is_train,
                                      self.FLAGS.label_maxlen)
        return logits, None
コード例 #2
0
ファイル: FAN.py プロジェクト: EuphoriaYan/SATRN
    def get_logits(self, image, is_train, **kwargs):
        """Build the FAN recognition graph and return the decoder logits.

        Args:
            image: Batched input image tensor. Axis 0 of its dynamic shape
                is used as the batch size and axis 2 as the width given to
                every sample.
            is_train: Training-mode flag forwarded to
                ``self._convnet_layers`` and ``attention_decoder``.
            **kwargs: Must contain ``'label'``, which is forwarded to
                ``attention_decoder``.

        Returns:
            A ``(logits, sequence_length)`` tuple: the first output of
            ``attention_decoder`` and the sequence lengths produced by
            ``self._convnet_layers``.

        Raises:
            KeyError: If ``'label'`` is missing from ``kwargs``.
        """
        dynamic_shape = tf.shape(image)
        # Every sample in the batch is assigned the full width (axis 2).
        batch_ones = tf.ones(dynamic_shape[0], dtype=tf.int32)
        widths = batch_ones * dynamic_shape[2]

        conv_out, seq_len = self._convnet_layers(image, widths, is_train)

        # Swap axes 0 and 1 (named 'time_major') before the RNN encoder.
        conv_out = tf.transpose(conv_out, perm=[1, 0, 2], name='time_major')
        states = rnn_layer(conv_out, seq_len, self.rnn_size, scope="rnn")

        # Only the logits are needed; the second decoder output is ignored.
        logits, _ = attention_decoder(states,
                                      kwargs['label'],
                                      len(self.out_charset),
                                      self.rnn_size,
                                      is_train,
                                      self.FLAGS.label_maxlen)
        return logits, seq_len
コード例 #3
0
    def get_logits(self, image, is_train, **kwargs):
        """Build the convnet + RNN + GRU-attention graph and return logits.

        Args:
            image: Batched input image tensor. Axis 0 of its dynamic shape
                is used as the batch size and axis 2 as the width given to
                every sample.
            is_train: Training-mode flag forwarded to
                ``self._convnet_layers`` and ``attention_decoder``.
            **kwargs: Must contain ``'label'``, which is forwarded to
                ``attention_decoder``.

        Returns:
            A ``(logits, sequence_length)`` tuple: the first output of
            ``attention_decoder`` and the sequence lengths produced by
            ``self._convnet_layers``.

        Raises:
            KeyError: If ``'label'`` is missing from ``kwargs``.
        """
        dynamic_shape = tf.shape(image)
        # Every sample in the batch is assigned the full width (axis 2).
        batch_ones = tf.ones(dynamic_shape[0], dtype=tf.int32)
        widths = batch_ones * dynamic_shape[2]

        conv_out, seq_len = self._convnet_layers(image, widths, is_train)

        # Recurrent encoder (with projection), then a dense layer of
        # rnn_size units whose output is handed to the attention decoder.
        states = rnn_layers(conv_out,
                            seq_len,
                            self.rnn_size,
                            use_projection=True)
        states = dense_layer(states, self.rnn_size, name='att_state_dense')

        # Only the logits are needed; the second decoder output is ignored.
        logits, _ = attention_decoder(states,
                                      kwargs['label'],
                                      len(self.out_charset),
                                      self.rnn_size,
                                      is_train,
                                      self.FLAGS.label_maxlen,
                                      cell_type='gru')
        return logits, seq_len