def bluche(input_size, d_model, learning_rate):
    """
    Build the Gated Convolutional Recurrent Neural Network of Bluche et al.

    Reference:
        Bluche, T., Messina, R.:
        Gated convolutional recurrent neural networks for multilingual
        handwriting recognition.
        In: Document Analysis and Recognition (ICDAR), 2017 14th IAPR
        International Conference on, vol. 1, pp. 646–651, 2017.
        URL: https://ieeexplore.ieee.org/document/8270042

    Args:
        input_size: Shape handed to the ``Input`` layer. Entries 0, 1 and 2
            are read to derive the first reshape (presumably two spatial
            axes followed by channels -- confirm against the caller).
        d_model: Number of units of the final softmax ``Dense`` layer.
        learning_rate: Learning rate for ``RMSprop``; ``None`` selects 4e-4.

    Returns:
        Tuple ``(input_data, output_data, optimizer)``: the input tensor,
        the softmax output tensor and the ``RMSprop`` optimizer instance.
        No model object is assembled or compiled here.
    """
    # Keyword sets shared by the two kinds of convolution used below.
    keep_size = dict(kernel_size=(3, 3), strides=(1, 1), padding="same")
    shrink = dict(kernel_size=(2, 4), strides=(2, 4), padding="same")

    input_data = Input(name="input", shape=input_size)

    # Halve the first two axes and multiply the third by four.
    folded_shape = (input_size[0] // 2, input_size[1] // 2, input_size[2] * 4)
    features = Reshape(folded_shape)(input_data)

    # Convolutional encoder: tanh convolutions interleaved with gated ones.
    features = Conv2D(filters=8, activation="tanh", **keep_size)(features)

    features = Conv2D(filters=16, activation="tanh", **shrink)(features)
    features = GatedConv2D(filters=16, **keep_size)(features)

    features = Conv2D(filters=32, activation="tanh", **keep_size)(features)
    features = GatedConv2D(filters=32, **keep_size)(features)

    features = Conv2D(filters=64, activation="tanh", **shrink)(features)
    features = GatedConv2D(filters=64, **keep_size)(features)

    features = Conv2D(filters=128, activation="tanh", **keep_size)(features)
    features = MaxPooling2D(pool_size=(1, 4), strides=(1, 4), padding="valid")(features)

    # Keep axis 1 as the sequence axis and merge axes 2 and 3 into one
    # feature axis for the recurrent layers.
    encoded_shape = features.get_shape()
    sequence = Reshape((encoded_shape[1], encoded_shape[2] * encoded_shape[3]))(features)

    # Recurrent decoder: two bidirectional LSTMs with a tanh projection between.
    sequence = Bidirectional(LSTM(units=128, return_sequences=True))(sequence)
    sequence = Dense(units=128, activation="tanh")(sequence)

    sequence = Bidirectional(LSTM(units=128, return_sequences=True))(sequence)
    output_data = Dense(units=d_model, activation="softmax")(sequence)

    effective_rate = 4e-4 if learning_rate is None else learning_rate
    optimizer = RMSprop(learning_rate=effective_rate)

    return (input_data, output_data, optimizer)
def bluche(input_size, output_size, learning_rate=4e-4):
    """
    Gated Convolutional Recurrent Neural Network by Bluche et al.

    Reference:
        Bluche, T., Messina, R.:
        Gated convolutional recurrent neural networks for multilingual
        handwriting recognition.
        In: Document Analysis and Recognition (ICDAR), 2017 14th IAPR
        International Conference on, vol. 1, pp. 646–651, 2017.
        URL: https://ieeexplore.ieee.org/document/8270042

        Moysset, B. and Messina, R.:
        Are 2D-LSTM really dead for offline text recognition?
        In: International Journal on Document Analysis and Recognition
        (IJDAR), Springer Science and Business Media LLC.
        URL: http://dx.doi.org/10.1007/s10032-019-00325-0

    Args:
        input_size: Shape handed to the ``Input`` layer. Entries 0, 1 and 2
            are read to derive the first reshape (presumably two spatial
            axes followed by channels -- confirm against the caller).
        output_size: Number of units of the final ``Dense`` layer that
            feeds the softmax activation.
        learning_rate: Learning rate for ``RMSprop``. Defaults to 4e-4;
            an explicit ``None`` also selects 4e-4.

    Returns:
        Tuple ``(input_data, output_data, optimizer)``: the input tensor,
        the softmax output tensor and the ``RMSprop`` optimizer instance.
        No model object is assembled or compiled here.
    """
    # Accept an explicit None as "use the default" so callers that forward
    # an unset learning rate do not hand None to the optimizer.
    if learning_rate is None:
        learning_rate = 4e-4

    input_data = Input(name="input", shape=input_size)

    # Halve the first two axes and multiply the third by four.
    cnn = Reshape((input_size[0] // 2, input_size[1] // 2, input_size[2] * 4))(input_data)

    # Convolutional encoder: each Conv2D is followed by a separate tanh
    # Activation layer; gated convolutions are interleaved.
    cnn = Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding="same")(cnn)
    cnn = Activation(activation="tanh")(cnn)

    cnn = Conv2D(filters=16, kernel_size=(2, 4), strides=(2, 4), padding="same")(cnn)
    cnn = Activation(activation="tanh")(cnn)
    cnn = GatedConv2D(filters=16, kernel_size=(3, 3), strides=(1, 1), padding="same")(cnn)

    cnn = Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding="same")(cnn)
    cnn = Activation(activation="tanh")(cnn)
    cnn = GatedConv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding="same")(cnn)

    cnn = Conv2D(filters=64, kernel_size=(2, 4), strides=(2, 4), padding="same")(cnn)
    cnn = Activation(activation="tanh")(cnn)
    cnn = GatedConv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding="same")(cnn)

    cnn = Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding="same")(cnn)
    cnn = Activation(activation="tanh")(cnn)

    cnn = MaxPooling2D(pool_size=(1, 4), strides=(1, 4), padding="valid")(cnn)

    # Keep axis 1 as the sequence axis and merge axes 2 and 3 into one
    # feature axis for the recurrent layers.
    shape = cnn.get_shape()
    blstm = Reshape((shape[1], shape[2] * shape[3]))(cnn)

    # Recurrent decoder: two bidirectional LSTMs (dropout=0.5) with a tanh
    # projection in between.
    blstm = Bidirectional(LSTM(units=128, return_sequences=True, dropout=0.5))(blstm)
    blstm = Dense(units=128)(blstm)
    blstm = Activation(activation="tanh")(blstm)

    blstm = Bidirectional(LSTM(units=128, return_sequences=True, dropout=0.5))(blstm)
    blstm = Dense(units=output_size)(blstm)
    output_data = Activation(activation="softmax")(blstm)

    optimizer = RMSprop(learning_rate=learning_rate)

    return (input_data, output_data, optimizer)