コード例 #1
0
ファイル: model.py プロジェクト: Jeksonic/tf-crnn
def get_crnn_output(input_images, parameters: Params = None) -> tf.Tensor:
    """
    Creates the CRNN network and returns its output.

    Passes the `input_images` through the convolutional blocks, rearranges the
    resulting feature map into a sequence, feeds that sequence through the
    bidirectional LSTM layers and finally applies a dense layer and a softmax.

    :param input_images: images to process (B, H, W, C)
    :param parameters: parameters of the model (``Params``); required, the
        ``None`` default only exists for signature compatibility
    :return: the output of the CRNN model (softmax over
        ``parameters.alphabet.n_classes`` values)
    :raises ValueError: if ``parameters`` is ``None``
    """
    # Fail early with an explicit message instead of an AttributeError on the
    # first attribute access below.
    if parameters is None:
        raise ValueError(
            "get_crnn_output() requires `parameters`; got None")

    # CNN layers: one ConvBlock per configuration entry.
    # NOTE: zip() stops at the shortest list, so per-layer settings of unequal
    # length are silently truncated.
    cnn_params = zip(parameters.cnn_features_list,
                     parameters.cnn_kernel_size,
                     parameters.cnn_stride_size,
                     parameters.cnn_pool_size,
                     parameters.cnn_batch_norm)
    conv_layers = [
        ConvBlock(ft, ks, ss, 'same', psz, bn)
        for ft, ks, ss, psz, bn in cnn_params
    ]

    # Apply the blocks in order; an empty configuration leaves the input as is.
    x = input_images
    for conv in conv_layers:
        x = conv(x)

    # Permutation and reshape: swap the first two non-batch axes, then merge
    # the last two axes so every position along the new first axis is one
    # time step.
    x = Permute((2, 1, 3))(x)
    shape = x.get_shape().as_list()
    # NOTE(review): needs statically known sizes for shape[2] and shape[3];
    # the multiplication fails if either is None.
    x = Reshape((shape[1], shape[2] * shape[3]))(x)  # [B, W, H*C]

    # RNN layers: one bidirectional LSTM per entry of rnn_units, each
    # returning the full sequence so the layers can be stacked.
    rnn_layers = [
        Bidirectional(
            LSTM(ru, dropout=0.5, return_sequences=True, time_major=False))
        for ru in parameters.rnn_units
    ]
    for rnn in rnn_layers:
        x = rnn(x)

    # Dense and softmax
    x = Dense(parameters.alphabet.n_classes)(x)
    net_output = Softmax()(x)

    return net_output
コード例 #2
0
ファイル: utils.py プロジェクト: mtn/keras-i-revnet
    def forward(self, inp):
        """
        Split the input along one axis and fold each piece into the depth axis.

        The first non-batch axis of ``inp`` is moved last, the tensor is split
        along axis 2, every piece is reshaped to (batch, d_height, d_depth),
        the pieces are stacked and the result is permuted back so that the
        depth axis comes right after the batch axis.

        NOTE(review): the local names suggest ``inp`` is channels-first
        (B, C, H, W) and that this is a space-to-depth rearrangement with
        ``self.block_size`` as the block edge - confirm against the callers.

        :param inp: 4-D input tensor
        :return: tensor with axes ordered (batch, d_depth, d_height, pieces)
        """
        output = Permute((2, 3, 1))(inp)
        batch_size, s_height, _, s_depth = output.get_shape().as_list()
        d_depth = s_depth * self.block_size_sq
        d_height = s_height // self.block_size
        # The static batch size is None for symbolic inputs, which tf.reshape
        # cannot take; -1 lets it infer the batch axis instead. A known batch
        # size is passed through unchanged.
        batch_dim = -1 if batch_size is None else batch_size

        # Split sizes come from compute_block_size_shapes (defined elsewhere).
        pieces = tf.split(
            output, compute_block_size_shapes(output.shape[2], self.block_size), axis=2
        )
        stack = [
            tf.reshape(piece, (batch_dim, d_height, d_depth)) for piece in pieces
        ]
        output = tf.stack(stack, axis=1)  # (batch, pieces, d_height, d_depth)
        output = Permute((2, 1, 3))(output)  # (batch, d_height, pieces, d_depth)
        output = Permute((3, 1, 2))(output)  # (batch, d_depth, d_height, pieces)

        return output