def _build_model(self):
    """Build the CNN -> recurrent -> dense projection graph.

    The CNN selected by ``self.network`` is applied to ``self.inputs``; its
    output is passed to the recurrent builder selected by ``self.recurrent``.
    The recurrent output is projected to ``NUM_CLASSES`` values per step
    through a weight/bias pair created under the ``'output'`` variable scope.

    Side effects:
        Sets ``self.seq_len`` (one entry per batch element, each equal to the
        size of axis 1 of the CNN output) and ``self.predict`` (the logits
        transposed with permutation ``(1, 0, 2)``).

    Returns:
        None.

    Note:
        An unsupported ``self.network`` or ``self.recurrent`` value prints a
        message and terminates the process through ``sys.exit(-1)``.
    """
    if self.network == CNNNetwork.CNN5:
        x = CNN5(inputs=self.inputs, utils=self.utils).build()
    elif self.network == CNNNetwork.ResNet:
        x = ResNet50(inputs=self.inputs, utils=self.utils).build()
    else:
        print('This cnn neural network is not supported at this time.')
        sys.exit(-1)

    # Prefer the static size of axis 1; when it is unknown (None) tf.fill
    # cannot use it as a fill value, so fall back to the runtime shape.
    time_steps = x.get_shape().as_list()[1]
    if time_steps is None:
        time_steps = tf.shape(x)[1]
    self.seq_len = tf.fill([tf.shape(x)[0]], time_steps, name="seq_len")

    if self.recurrent == RecurrentNetwork.LSTM:
        recurrent_network_builder = LSTM(self.utils, x, self.seq_len)
    elif self.recurrent == RecurrentNetwork.BLSTM:
        recurrent_network_builder = BLSTM(self.utils, x, self.seq_len)
    elif self.recurrent == RecurrentNetwork.GRU:
        recurrent_network_builder = GRU(x, self.seq_len)
    elif self.recurrent == RecurrentNetwork.SRU:
        recurrent_network_builder = SRU(x, self.seq_len)
    elif self.recurrent == RecurrentNetwork.BSRU:
        recurrent_network_builder = BSRU(self.utils, x, self.seq_len)
    else:
        print('This recurrent neural network is not supported at this time.')
        sys.exit(-1)

    outputs = recurrent_network_builder.build()

    # Flatten to two dimensions so the same weights are applied at every
    # time step.
    outputs = tf.reshape(outputs, [-1, NUM_HIDDEN * 2])
    with tf.variable_scope('output'):
        weight_out = tf.get_variable(
            name='weight',
            shape=[
                # The reshape above already fixes this dimension to
                # NUM_HIDDEN * 2, so both branches describe the same size.
                outputs.get_shape()[1] if self.network == CNNNetwork.ResNet else NUM_HIDDEN * 2,
                NUM_CLASSES,
            ],
            dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=0.1),
        )
        biases_out = tf.get_variable(
            name='biases',
            shape=[NUM_CLASSES],
            dtype=tf.float32,
            initializer=tf.constant_initializer(value=0, dtype=tf.float32),
        )
        # [batch_size * max_timesteps, num_classes]
        logits = tf.matmul(outputs, weight_out) + biases_out
        # Restore the leading batch dimension.
        logits = tf.reshape(logits, [tf.shape(x)[0], -1, NUM_CLASSES])
        # Time major: swap the first two axes.
        self.predict = tf.transpose(logits, (1, 0, 2), "predict")
def _build_model(self):
    """Build the network: CNN backbone, optional recurrent layer, output layer.

    Reads ``self.network``, ``self.recurrent``, ``self.model_conf``,
    ``self.inputs``, ``self.utils`` and ``self.mode``. Sets ``self.seq_len``,
    ``self.recurrent_network_builder`` and ``self.outputs``.

    Returns:
        The result of the output-layer builder's ``build()``, which is also
        stored in ``self.outputs``.

    Raises:
        ValueError: If ``self.network`` or ``self.recurrent`` is not one of
            the supported options, if a recurrent network is configured
            together with a loss function other than CTC, or if the
            configured loss function has no matching output layer.
    """
    # Select the convolutional backbone.
    if self.network == CNNNetwork.CNN5:
        x = CNN5(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()
    elif self.network == CNNNetwork.CNNX:
        x = CNNX(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()
    elif self.network == CNNNetwork.ResNetTiny:
        x = ResNetTiny(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()
    elif self.network == CNNNetwork.ResNet50:
        x = ResNet50(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()
    elif self.network == CNNNetwork.DenseNet:
        x = DenseNet(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()
    else:
        raise ValueError('This cnn neural network is not supported at this time.')

    # Select the recurrent network.
    # time_major = True: [max_time_step, batch_size, num_classes]
    tf.compat.v1.logging.info("CNN Output: {}".format(x.get_shape()))
    # One entry per batch element, each equal to the runtime size of axis 1
    # of the CNN output.
    self.seq_len = tf.fill([tf.shape(x)[0]], tf.shape(x)[1], name="seq_len")

    if self.recurrent == RecurrentNetwork.NoRecurrent:
        self.recurrent_network_builder = None
    elif self.recurrent == RecurrentNetwork.LSTM:
        self.recurrent_network_builder = LSTM(model_conf=self.model_conf, inputs=x, utils=self.utils)
    elif self.recurrent == RecurrentNetwork.BiLSTM:
        self.recurrent_network_builder = BiLSTM(model_conf=self.model_conf, inputs=x, utils=self.utils)
    elif self.recurrent == RecurrentNetwork.GRU:
        self.recurrent_network_builder = GRU(model_conf=self.model_conf, inputs=x, utils=self.utils)
    elif self.recurrent == RecurrentNetwork.BiGRU:
        self.recurrent_network_builder = BiGRU(model_conf=self.model_conf, inputs=x, utils=self.utils)
    elif self.recurrent == RecurrentNetwork.LSTMcuDNN:
        self.recurrent_network_builder = LSTMcuDNN(model_conf=self.model_conf, inputs=x, utils=self.utils)
    elif self.recurrent == RecurrentNetwork.BiLSTMcuDNN:
        self.recurrent_network_builder = BiLSTMcuDNN(model_conf=self.model_conf, inputs=x, utils=self.utils)
    elif self.recurrent == RecurrentNetwork.GRUcuDNN:
        self.recurrent_network_builder = GRUcuDNN(model_conf=self.model_conf, inputs=x, utils=self.utils)
    else:
        raise ValueError('This recurrent neural network is not supported at this time.')

    # Without a recurrent layer the CNN output goes straight to the output
    # layer.
    logits = self.recurrent_network_builder.build() if self.recurrent_network_builder else x

    # The guard rejects a recurrent network combined with a non-CTC loss, so
    # the message states exactly that condition.
    if self.recurrent_network_builder and self.model_conf.loss_func != LossFunction.CTC:
        raise ValueError('Recurrent neural network must use CTC loss.')

    # Output layer, chosen by the configured loss function.
    with tf.keras.backend.name_scope('output'):
        if self.model_conf.loss_func == LossFunction.CTC:
            self.outputs = FullConnectedRNN(model_conf=self.model_conf, mode=self.mode, outputs=logits).build()
        elif self.model_conf.loss_func == LossFunction.CrossEntropy:
            self.outputs = FullConnectedCNN(model_conf=self.model_conf, mode=self.mode, outputs=logits).build()
        else:
            # Fail explicitly rather than returning an unset or stale
            # self.outputs.
            raise ValueError('This loss function is not supported at this time.')
    return self.outputs