def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx,
                 is_batchnorm=False, gamma=None, beta=None, name=None):
    """LSTM Cell symbol (plain LSTM: no dropout, no peepholes, no projection).

    Builds one unrolled timestep of an LSTM layer.

    Args:
        num_hidden: hidden size; the fused gate projection is 4 * num_hidden.
        indata: input symbol for this timestep.
        prev_state: previous state with ``.c`` and ``.h`` symbols.
        param: parameter holder exposing ``i2h_weight``, ``i2h_bias``,
            ``h2h_weight`` and ``h2h_bias``.
        seqidx, layeridx: timestep / layer indices used in operator names.
        is_batchnorm: if True, batch-normalize the input projection.
        gamma, beta: batchnorm scale / shift symbols forwarded to ``batchnorm``.
        name: optional prefix for the batchnorm operator name.

    Returns:
        ``LSTMState(c=next_c, h=next_h)``.
    """
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    if is_batchnorm:
        # Only forward a ``name`` when the caller supplied one.
        bn_kwargs = {} if name is None else {"name": "%s_batchnorm" % name}
        i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, **bn_kwargs)

    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))

    slice_gates = mx.sym.SliceChannel(i2h + h2h, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    # Gate order in the fused projection: input, transform, forget, output.
    in_gate, in_transform, forget_gate, out_gate = [
        mx.sym.Activation(slice_gates[idx], act_type=activation)
        for idx, activation in enumerate(("sigmoid", "tanh", "sigmoid", "sigmoid"))
    ]

    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)
def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.,
         num_hidden_proj=0, is_batchnorm=False, gamma=None, beta=None, name=None):
    """LSTM Cell symbol with peephole terms and an optional output projection.

    The input and forget gates add ``broadcast_mul(param.c2i_bias, prev_state.c)``
    and ``broadcast_mul(param.c2f_bias, prev_state.c)`` to their pre-activations.
    The output gate adds ``broadcast_mul(param.c2o_bias, next_c)``. The
    hidden-to-hidden projection has no bias.

    Args:
        num_hidden: cell size; the fused gate projection is 4 * num_hidden.
        indata: input symbol for this timestep.
        prev_state: previous state with ``.c`` and ``.h`` symbols.
        param: parameter holder (``i2h_weight``, ``i2h_bias``, ``h2h_weight``,
            ``c2i_bias``, ``c2f_bias``, ``c2o_bias`` and, when projecting,
            ``ph2h_weight``).
        seqidx, layeridx: timestep / layer indices used in operator names.
        dropout: dropout probability applied to ``indata`` when > 0.
        num_hidden_proj: if > 0, project ``h`` to this size (no bias).
        is_batchnorm: if True, batch-normalize the input projection.
        gamma, beta: batchnorm scale / shift symbols.
        name: optional prefix for the batchnorm operator name.

    Returns:
        ``LSTMState(c=next_c, h=h)``, where ``h`` is the projected hidden state
        when ``num_hidden_proj > 0``.
    """
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)

    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    if is_batchnorm:
        bn_kwargs = {} if name is None else {"name": "%s_batchnorm" % name}
        i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, **bn_kwargs)

    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                no_bias=True,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))

    slice_gates = mx.sym.SliceChannel(i2h + h2h, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))

    # Input gate: peephole on the previous cell state.
    in_gate = mx.sym.Activation(
        mx.sym.broadcast_mul(param.c2i_bias, prev_state.c) + slice_gates[0],
        act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    # Forget gate: peephole on the previous cell state.
    forget_gate = mx.sym.Activation(
        mx.sym.broadcast_mul(param.c2f_bias, prev_state.c) + slice_gates[2],
        act_type="sigmoid")

    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)

    # Output gate: peephole on the *new* cell state.
    out_gate = mx.sym.Activation(
        mx.sym.broadcast_mul(param.c2o_bias, next_c) + slice_gates[3],
        act_type="sigmoid")
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")

    if num_hidden_proj <= 0:
        return LSTMState(c=next_c, h=next_h)

    proj_next_h = mx.sym.FullyConnected(data=next_h,
                                        weight=param.ph2h_weight,
                                        no_bias=True,
                                        num_hidden=num_hidden_proj,
                                        name="t%d_l%d_ph2h" % (seqidx, layeridx))
    return LSTMState(c=next_c, h=proj_next_h)
def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.,
        is_batchnorm=False, gamma=None, beta=None, name=None):
    """
    GRU Cell symbol

    Builds one unrolled timestep of a GRU layer and returns ``GRUState(h=...)``.
    The update and reset gates come from a fused 2 * num_hidden projection. The
    candidate state is ``tanh(W x + U (r * h_prev))``, and the new state is
    ``h_prev + z * (candidate - h_prev)``.

    Args:
        num_hidden: hidden size.
        indata: input symbol for this timestep (dropped out when dropout > 0).
        prev_state: previous state with an ``.h`` symbol.
        param: parameter holder (``gates_*`` and ``trans_*`` weights / biases).
        seqidx, layeridx: timestep / layer indices used in operator names.
        dropout: dropout probability applied to ``indata`` when > 0.
        is_batchnorm: if True, batch-normalize the gate input projection.
        gamma, beta: batchnorm scale / shift symbols.
        name: optional prefix for the batchnorm operator name.

    Reference:
    * Chung, Junyoung, et al. "Empirical evaluation of gated recurrent neural
      networks on sequence modeling." arXiv preprint arXiv:1412.3555 (2014).
    """
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)

    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.gates_i2h_weight,
                                bias=param.gates_i2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_i2h" % (seqidx, layeridx))
    if is_batchnorm:
        bn_kwargs = {} if name is None else {"name": "%s_batchnorm" % name}
        i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, **bn_kwargs)

    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.gates_h2h_weight,
                                bias=param.gates_h2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_h2h" % (seqidx, layeridx))

    slice_gates = mx.sym.SliceChannel(i2h + h2h, num_outputs=2,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    update_gate, reset_gate = [
        mx.sym.Activation(slice_gates[idx], act_type="sigmoid") for idx in (0, 1)
    ]

    # Candidate state: the recurrent term sees the reset-gated previous state.
    htrans_i2h = mx.sym.FullyConnected(data=indata,
                                       weight=param.trans_i2h_weight,
                                       bias=param.trans_i2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_i2h" % (seqidx, layeridx))
    h_after_reset = prev_state.h * reset_gate
    htrans_h2h = mx.sym.FullyConnected(data=h_after_reset,
                                       weight=param.trans_h2h_weight,
                                       bias=param.trans_h2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_h2h" % (seqidx, layeridx))
    h_trans_active = mx.sym.Activation(htrans_i2h + htrans_h2h, act_type="tanh")

    # Interpolate between the previous and the candidate state.
    next_h = prev_state.h + update_gate * (h_trans_active - prev_state.h)
    return GRUState(h=next_h)
def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.,
        is_batchnorm=False, gamma=None, beta=None, name=None):
    """
    GRU Cell symbol

    Builds one unrolled timestep of a GRU layer and returns ``GRUState(h=...)``.
    The update and reset gates come from a fused 2 * num_hidden projection. The
    candidate state is ``tanh(W x + U (r * h_prev))``, and the new state is
    ``h_prev + z * (candidate - h_prev)``.

    Args:
        num_hidden: hidden size.
        indata: input symbol for this timestep (dropped out when dropout > 0).
        prev_state: previous state with an ``.h`` symbol.
        param: parameter holder (``gates_*`` and ``trans_*`` weights / biases).
        seqidx, layeridx: timestep / layer indices used in operator names.
        dropout: dropout probability applied to ``indata`` when > 0.
        is_batchnorm: if True, batch-normalize the gate input projection.
        gamma, beta: batchnorm scale / shift symbols.
        name: optional prefix for the batchnorm operator name. Accepted so
            callers that pass ``name=`` work with this definition too.

    Reference:
    * Chung, Junyoung, et al. "Empirical evaluation of gated recurrent neural
      networks on sequence modeling." arXiv preprint arXiv:1412.3555 (2014).
    """
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.gates_i2h_weight,
                                bias=param.gates_i2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_i2h" % (seqidx, layeridx))
    if is_batchnorm:
        if name is not None:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
        else:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.gates_h2h_weight,
                                bias=param.gates_h2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=2,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    update_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    reset_gate = mx.sym.Activation(slice_gates[1], act_type="sigmoid")
    # The transform part of GRU is a little magic
    htrans_i2h = mx.sym.FullyConnected(data=indata,
                                       weight=param.trans_i2h_weight,
                                       bias=param.trans_i2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_i2h" % (seqidx, layeridx))
    h_after_reset = prev_state.h * reset_gate
    # Distinct name: previously this reused "trans_i2h", giving two operators
    # in the same graph the same name.
    htrans_h2h = mx.sym.FullyConnected(data=h_after_reset,
                                       weight=param.trans_h2h_weight,
                                       bias=param.trans_h2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_h2h" % (seqidx, layeridx))
    h_trans = htrans_i2h + htrans_h2h
    h_trans_active = mx.sym.Activation(h_trans, act_type="tanh")
    next_h = prev_state.h + update_gate * (h_trans_active - prev_state.h)
    return GRUState(h=next_h)
def arch(args, seq_len=None):
    """
    define deep speech 2 network

    All hyper-parameters come from ``args.config``.

    * ``common:mode == "train"``: builds the full symbol and returns it. The
      symbol is two conv layers, a recurrent stack (bilstm / gru / bigru), rear
      fully-connected layers and ``warpctc_layer``.
    * ``common:mode`` in ``("load", "predict")``: builds nothing and returns
      ``None``.

    In both of those modes ``arch:max_t_count`` is overwritten with the number
    of timesteps remaining after the two convolutions.

    Args:
        args: ``argparse.Namespace`` with a ``config`` attribute.
        seq_len: input sequence length; read from ``arch:max_t_count`` when None.

    Raises:
        Exception: if ``args`` is not an ``argparse.Namespace``, or the mode or
            rnn_type is unknown.
    """
    if not isinstance(args, argparse.Namespace):
        raise Exception(
            'type of args should be one of the argparse.' +
            'Namespace for fixed length model or integer for variable length model'
        )
    config = args.config

    def read_tuple(option):
        return tuple(json.loads(config.get("arch", option)))

    def length_after_conv(length, filter_dim, stride):
        # Valid (unpadded) convolution along the first axis of filter/stride.
        return int(math.floor((length - filter_dim[0]) / stride[0])) + 1

    mode = config.get("common", "mode")
    if mode == "train":
        channel_num = config.getint("arch", "channel_num")
        conv_layer1_filter_dim = read_tuple("conv_layer1_filter_dim")
        conv_layer1_stride = read_tuple("conv_layer1_stride")
        conv_layer2_filter_dim = read_tuple("conv_layer2_filter_dim")
        conv_layer2_stride = read_tuple("conv_layer2_stride")
        rnn_type = config.get("arch", "rnn_type")
        num_rnn_layer = config.getint("arch", "num_rnn_layer")
        num_hidden_rnn_list = json.loads(config.get("arch", "num_hidden_rnn_list"))
        is_batchnorm = config.getboolean("arch", "is_batchnorm")
        is_bucketing = config.getboolean("arch", "is_bucketing")
        if seq_len is None:
            seq_len = config.getint('arch', 'max_t_count')
        num_label = config.getint('arch', 'max_label_length')
        num_rear_fc_layers = config.getint("arch", "num_rear_fc_layers")
        num_hidden_rear_fc_list = json.loads(config.get("arch", "num_hidden_rear_fc_list"))
        act_type_rear_fc_list = json.loads(config.get("arch", "act_type_rear_fc_list"))

        data = mx.sym.Variable('data')
        label = mx.sym.Variable('label')
        # Insert a singleton axis after axis 0: (d0, d1, d2) -> (d0, 1, d1, d2).
        net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
        conv_specs = (
            ('conv1', "conv1_batchnorm", conv_layer1_filter_dim, conv_layer1_stride),
            ('conv2', "conv2_batchnorm", conv_layer2_filter_dim, conv_layer2_stride),
        )
        for conv_name, bn_name, filter_dim, stride in conv_specs:
            net = conv(net=net, channels=channel_num, filter_dimension=filter_dim,
                       stride=stride, no_bias=is_batchnorm, name=conv_name)
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name=bn_name)
        net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
        net = mx.sym.Reshape(data=net, shape=(0, 0, -3))

        seq_len_after_conv_layer1 = length_after_conv(
            seq_len, conv_layer1_filter_dim, conv_layer1_stride)
        seq_len_after_conv_layer2 = length_after_conv(
            seq_len_after_conv_layer1, conv_layer2_filter_dim, conv_layer2_stride)
        net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1)

        rnn_common = dict(seq_len=seq_len_after_conv_layer2, dropout=0.,
                          is_batchnorm=is_batchnorm, is_bucketing=is_bucketing)
        if rnn_type == "bilstm":
            net = bi_lstm_unroll(net=net, num_hidden_lstm_list=num_hidden_rnn_list,
                                 num_lstm_layer=num_rnn_layer, **rnn_common)
        elif rnn_type == "gru":
            net = gru_unroll(net=net, num_hidden_gru_list=num_hidden_rnn_list,
                             num_gru_layer=num_rnn_layer, **rnn_common)
        elif rnn_type == "bigru":
            net = bi_gru_unroll(net=net, num_hidden_gru_list=num_hidden_rnn_list,
                                num_gru_layer=num_rnn_layer, **rnn_common)
        else:
            raise Exception(
                'rnn_type should be one of the followings, bilstm,gru,bigru'
            )

        # rear fc layers
        net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2,
                          num_layer=num_rear_fc_layers, prefix="rear",
                          num_hidden_list=num_hidden_rear_fc_list,
                          act_type_list=act_type_rear_fc_list,
                          is_batchnorm=is_batchnorm)
        # warpctc layer
        net = warpctc_layer(
            net=net, seq_len=seq_len_after_conv_layer2, label=label,
            num_label=num_label,
            character_classes_count=(config.getint('arch', 'n_classes') + 1))
        config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        return net

    if mode not in ('load', 'predict'):
        raise Exception(
            'mode must be the one of the followings - train,predict,load')

    # load / predict: only the post-convolution sequence length is needed.
    conv_layer1_filter_dim = read_tuple("conv_layer1_filter_dim")
    conv_layer1_stride = read_tuple("conv_layer1_stride")
    conv_layer2_filter_dim = read_tuple("conv_layer2_filter_dim")
    conv_layer2_stride = read_tuple("conv_layer2_stride")
    if seq_len is None:
        seq_len = config.getint('arch', 'max_t_count')
    seq_len_after_conv_layer1 = length_after_conv(
        seq_len, conv_layer1_filter_dim, conv_layer1_stride)
    seq_len_after_conv_layer2 = length_after_conv(
        seq_len_after_conv_layer1, conv_layer2_filter_dim, conv_layer2_stride)
    config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
    return None
def arch(args):
    """Define the Deep Speech 2 network (fixed-length variant).

    Hyper-parameters come from ``args.config``. When ``common:mode`` is
    ``"train"``, the training symbol is built and returned. For any other mode
    nothing is built and ``None`` is returned. In both cases
    ``arch:max_t_count`` is overwritten with the number of timesteps remaining
    after the two convolution layers.

    Raises:
        Exception: in train mode, if ``arch:rnn_type`` is not bilstm, gru or bigru.
    """
    config = args.config

    def tuple_option(option):
        return tuple(json.loads(config.get("arch", option)))

    def length_after(length, filter_dim, stride):
        # Valid (unpadded) convolution along the first axis of filter/stride.
        return int(math.floor((length - filter_dim[0]) / stride[0])) + 1

    if config.get("common", "mode") != "train":
        # Non-train modes only need the post-convolution sequence length.
        filter1 = tuple_option("conv_layer1_filter_dim")
        stride1 = tuple_option("conv_layer1_stride")
        filter2 = tuple_option("conv_layer2_filter_dim")
        stride2 = tuple_option("conv_layer2_stride")
        t_in = config.getint('arch', 'max_t_count')
        t_out = length_after(length_after(t_in, filter1, stride1), filter2, stride2)
        config.set('arch', 'max_t_count', str(t_out))
        return None

    channel_num = config.getint("arch", "channel_num")
    filter1 = tuple_option("conv_layer1_filter_dim")
    stride1 = tuple_option("conv_layer1_stride")
    filter2 = tuple_option("conv_layer2_filter_dim")
    stride2 = tuple_option("conv_layer2_stride")
    rnn_type = config.get("arch", "rnn_type")
    num_rnn_layer = config.getint("arch", "num_rnn_layer")
    num_hidden_rnn_list = json.loads(config.get("arch", "num_hidden_rnn_list"))
    is_batchnorm = config.getboolean("arch", "is_batchnorm")
    seq_len = config.getint('arch', 'max_t_count')
    num_label = config.getint('arch', 'max_label_length')
    num_rear_fc_layers = config.getint("arch", "num_rear_fc_layers")
    num_hidden_rear_fc_list = json.loads(config.get("arch", "num_hidden_rear_fc_list"))
    act_type_rear_fc_list = json.loads(config.get("arch", "act_type_rear_fc_list"))

    # model symbol generation
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    # Insert a singleton axis after axis 0: (d0, d1, d2) -> (d0, 1, d1, d2).
    net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
    for filter_dim, stride in ((filter1, stride1), (filter2, stride2)):
        net = conv(net=net, channels=channel_num, filter_dimension=filter_dim,
                   stride=stride, no_bias=is_batchnorm)
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net)
    net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
    net = mx.sym.Reshape(data=net, shape=(0, 0, -3))

    t_out = length_after(length_after(seq_len, filter1, stride1), filter2, stride2)
    net = slice_symbol_to_seq_symobls(net=net, seq_len=t_out, axis=1)

    if rnn_type == "bilstm":
        net = bi_lstm_unroll(net=net, seq_len=t_out,
                             num_hidden_lstm_list=num_hidden_rnn_list,
                             num_lstm_layer=num_rnn_layer, dropout=0.,
                             is_batchnorm=is_batchnorm)
    elif rnn_type == "gru":
        net = gru_unroll(net=net, seq_len=t_out,
                         num_hidden_gru_list=num_hidden_rnn_list,
                         num_gru_layer=num_rnn_layer, dropout=0.,
                         is_batchnorm=is_batchnorm)
    elif rnn_type == "bigru":
        net = bi_gru_unroll(net=net, seq_len=t_out,
                            num_hidden_gru_list=num_hidden_rnn_list,
                            num_gru_layer=num_rnn_layer, dropout=0.,
                            is_batchnorm=is_batchnorm)
    else:
        raise Exception('rnn_type should be one of the followings, bilstm,gru,bigru')

    # rear fc layers
    net = sequence_fc(net=net, seq_len=t_out, num_layer=num_rear_fc_layers,
                      prefix="rear", num_hidden_list=num_hidden_rear_fc_list,
                      act_type_list=act_type_rear_fc_list,
                      is_batchnorm=is_batchnorm)
    if is_batchnorm:
        # batch norm normalizes axis 1; one un-named batchnorm per timestep
        net = [batchnorm(net[step]) for step in range(t_out)]

    # warpctc layer
    net = warpctc_layer(net=net, seq_len=t_out, label=label, num_label=num_label,
                        character_classes_count=(config.getint('arch', 'n_classes') + 1))
    config.set('arch', 'max_t_count', str(t_out))
    return net
def sequence_fc(
        net,
        seq_len,
        num_layer,
        prefix,
        num_hidden_list=None,
        act_type_list=None,
        is_batchnorm=False,
        dropout_rate=0,
):
    """Apply a stack of fully-connected layers independently at every timestep.

    Parameters are created once per layer and reused at every timestep:
    ``<prefix>_sequence_fc<k>_weight`` plus either ``..._bias`` (without
    batchnorm) or ``..._gamma`` / ``..._beta`` (with batchnorm).

    Args:
        net: an ``mx.symbol.Symbol``, sliced into ``seq_len`` pieces along
            axis 1, or a list of per-timestep ``mx.symbol.Symbol``.
        seq_len: number of timesteps to process.
        num_layer: number of layers; 0 returns ``net`` unchanged.
        prefix: prefix for parameter and operator names.
        num_hidden_list: output size of each layer (length ``num_layer``).
        act_type_list: activation type of each layer (length ``num_layer``).
        is_batchnorm: if True, every layer is FC (no bias) -> batchnorm ->
            activation; otherwise FC with bias and activation via ``fc``.
        dropout_rate: dropout probability applied before each layer when > 0.

    Returns:
        A list of per-timestep output symbols, or ``net`` unchanged when
        ``num_layer`` is 0.

    Raises:
        Exception: if ``num_layer`` disagrees with the list lengths, or ``net``
            is neither a Symbol nor a list of Symbols.
    """
    # ``None`` defaults instead of ``[]``: a mutable default would be shared by
    # every call that omits the argument.
    if num_hidden_list is None:
        num_hidden_list = []
    if act_type_list is None:
        act_type_list = []

    if not num_layer == len(num_hidden_list) == len(act_type_list):
        raise Exception(
            "length doesn't match - num_layer: %s, len(num_hidden_list): %s, "
            "len(act_type_list): %s"
            % (num_layer, len(num_hidden_list), len(act_type_list)))
    if num_layer <= 0:
        return net

    weight_list = []
    bias_list = []
    for layer_index in range(num_layer):
        weight_list.append(
            mx.sym.Variable(name='%s_sequence_fc%d_weight' % (prefix, layer_index)))
        # if you use batchnorm bias do not have any effect
        if not is_batchnorm:
            bias_list.append(
                mx.sym.Variable(name='%s_sequence_fc%d_bias' % (prefix, layer_index)))

    # batch normalization parameters
    gamma_list = []
    beta_list = []
    if is_batchnorm:
        for layer_index in range(num_layer):
            gamma_list.append(
                mx.sym.Variable(name='%s_sequence_fc%d_gamma' % (prefix, layer_index)))
            beta_list.append(
                mx.sym.Variable(name='%s_sequence_fc%d_beta' % (prefix, layer_index)))

    if isinstance(net, mx.symbol.Symbol):
        net = mx.sym.SliceChannel(data=net, num_outputs=seq_len, axis=1, squeeze_axis=1)
    elif isinstance(net, list):
        for net_index, one_net in enumerate(net):
            if not isinstance(one_net, mx.symbol.Symbol):
                raise Exception(
                    '%d th elements of the net should be mx.symbol.Symbol' % net_index)
    else:
        raise Exception(
            'type of net should be whether mx.symbol.Symbol or list of mx.symbol.Symbol'
        )

    hidden_all = []
    for seq_index in range(seq_len):
        hidden = net[seq_index]
        for layer_index in range(num_layer):
            if dropout_rate > 0:
                hidden = mx.sym.Dropout(data=hidden, p=dropout_rate)
            if is_batchnorm:
                # Every layer (including the last) is FC -> batchnorm -> activation;
                # the FC has no bias because batchnorm's beta supersedes it.
                hidden = fc(net=hidden,
                            num_hidden=num_hidden_list[layer_index],
                            act_type=None,
                            weight=weight_list[layer_index],
                            no_bias=is_batchnorm,
                            name="%s_t%d_l%d_fc" % (prefix, seq_index, layer_index))
                hidden = batchnorm(net=hidden,
                                   gamma=gamma_list[layer_index],
                                   beta=beta_list[layer_index],
                                   name="%s_t%d_l%d_batchnorm" % (prefix, seq_index, layer_index))
                hidden = mx.sym.Activation(
                    data=hidden,
                    act_type=act_type_list[layer_index],
                    name="%s_t%d_l%d_activation" % (prefix, seq_index, layer_index))
            else:
                hidden = fc(net=hidden,
                            num_hidden=num_hidden_list[layer_index],
                            act_type=act_type_list[layer_index],
                            weight=weight_list[layer_index],
                            bias=bias_list[layer_index])
        hidden_all.append(hidden)
    return hidden_all
def sequence_fc(net,
                seq_len,
                num_layer,
                prefix,
                num_hidden_list=None,
                act_type_list=None,
                is_batchnorm=False,
                dropout_rate=0,
                ):
    """Apply a stack of fully-connected layers independently at every timestep.

    Parameters are created once per layer and reused at every timestep:
    ``<prefix>_sequence_fc<k>_weight`` plus either ``..._bias`` (without
    batchnorm) or ``..._gamma`` / ``..._beta`` (with batchnorm).

    Args:
        net: an ``mx.symbol.Symbol``, sliced into ``seq_len`` pieces along
            axis 1, or a list of per-timestep ``mx.symbol.Symbol``.
        seq_len: number of timesteps to process.
        num_layer: number of layers; 0 returns ``net`` unchanged.
        prefix: prefix for parameter and operator names.
        num_hidden_list: output size of each layer (length ``num_layer``).
        act_type_list: activation type of each layer (length ``num_layer``).
        is_batchnorm: if True, every layer is FC (no bias) -> batchnorm ->
            activation; otherwise FC with bias and activation via ``fc``.
        dropout_rate: dropout probability applied before each layer when > 0.

    Returns:
        A list of per-timestep output symbols, or ``net`` unchanged when
        ``num_layer`` is 0.

    Raises:
        Exception: if ``num_layer`` disagrees with the list lengths, or ``net``
            is neither a Symbol nor a list of Symbols.
    """
    # ``None`` defaults instead of ``[]``: a mutable default would be shared by
    # every call that omits the argument.
    if num_hidden_list is None:
        num_hidden_list = []
    if act_type_list is None:
        act_type_list = []

    if not num_layer == len(num_hidden_list) == len(act_type_list):
        raise Exception(
            "length doesn't match - num_layer: %s, len(num_hidden_list): %s, "
            "len(act_type_list): %s"
            % (num_layer, len(num_hidden_list), len(act_type_list)))
    if num_layer <= 0:
        return net

    weight_list = []
    bias_list = []
    for layer_index in range(num_layer):
        weight_list.append(
            mx.sym.Variable(name='%s_sequence_fc%d_weight' % (prefix, layer_index)))
        # if you use batchnorm bias do not have any effect
        if not is_batchnorm:
            bias_list.append(
                mx.sym.Variable(name='%s_sequence_fc%d_bias' % (prefix, layer_index)))

    # batch normalization parameters
    gamma_list = []
    beta_list = []
    if is_batchnorm:
        for layer_index in range(num_layer):
            gamma_list.append(
                mx.sym.Variable(name='%s_sequence_fc%d_gamma' % (prefix, layer_index)))
            beta_list.append(
                mx.sym.Variable(name='%s_sequence_fc%d_beta' % (prefix, layer_index)))

    if isinstance(net, mx.symbol.Symbol):
        net = mx.sym.SliceChannel(data=net, num_outputs=seq_len, axis=1, squeeze_axis=1)
    elif isinstance(net, list):
        for net_index, one_net in enumerate(net):
            if not isinstance(one_net, mx.symbol.Symbol):
                raise Exception(
                    '%d th elements of the net should be mx.symbol.Symbol' % net_index)
    else:
        raise Exception(
            'type of net should be whether mx.symbol.Symbol or list of mx.symbol.Symbol')

    hidden_all = []
    for seq_index in range(seq_len):
        hidden = net[seq_index]
        for layer_index in range(num_layer):
            if dropout_rate > 0:
                hidden = mx.sym.Dropout(data=hidden, p=dropout_rate)
            if is_batchnorm:
                # Every layer (including the last) is FC -> batchnorm -> activation;
                # the FC has no bias because batchnorm's beta supersedes it.
                hidden = fc(net=hidden,
                            num_hidden=num_hidden_list[layer_index],
                            act_type=None,
                            weight=weight_list[layer_index],
                            no_bias=is_batchnorm,
                            name="%s_t%d_l%d_fc" % (prefix, seq_index, layer_index))
                hidden = batchnorm(net=hidden,
                                   gamma=gamma_list[layer_index],
                                   beta=beta_list[layer_index],
                                   name="%s_t%d_l%d_batchnorm" % (prefix, seq_index, layer_index))
                hidden = mx.sym.Activation(
                    data=hidden,
                    act_type=act_type_list[layer_index],
                    name="%s_t%d_l%d_activation" % (prefix, seq_index, layer_index))
            else:
                hidden = fc(net=hidden,
                            num_hidden=num_hidden_list[layer_index],
                            act_type=act_type_list[layer_index],
                            weight=weight_list[layer_index],
                            bias=bias_list[layer_index])
        hidden_all.append(hidden)
    return hidden_all
def arch(args, seq_len=None):
    """
    define deep speech 2 network

    All hyper-parameters come from ``args.config``.

    * ``common:mode == "train"`` or ``arch:is_bucketing`` true: builds and
      returns the full symbol. The symbol is two conv layers, a recurrent stack
      (bilstm / gru / bigru), rear fully-connected layers and ``warpctc_layer``.
    * otherwise, ``common:mode`` in ``("load", "predict")``: builds nothing and
      returns ``None``.

    In both of those cases ``arch:max_t_count`` is overwritten with the number
    of timesteps remaining after the two convolutions.

    Args:
        args: ``argparse.Namespace`` with a ``config`` attribute.
        seq_len: input sequence length; read from ``arch:max_t_count`` when None.

    Raises:
        Exception: if ``args`` is not an ``argparse.Namespace`` (previously this
            silently returned ``None``), or the mode or rnn_type is unknown.
    """
    if not isinstance(args, argparse.Namespace):
        raise Exception(
            'type of args should be one of the argparse.' +
            'Namespace for fixed length model or integer for variable length model'
        )
    config = args.config

    def read_tuple(option):
        return tuple(json.loads(config.get("arch", option)))

    def length_after_conv(length, filter_dim, stride):
        # Valid (unpadded) convolution along the first axis of filter/stride.
        return int(math.floor((length - filter_dim[0]) / stride[0])) + 1

    mode = config.get("common", "mode")
    is_bucketing = config.getboolean("arch", "is_bucketing")
    if mode == "train" or is_bucketing:
        channel_num = config.getint("arch", "channel_num")
        conv_layer1_filter_dim = read_tuple("conv_layer1_filter_dim")
        conv_layer1_stride = read_tuple("conv_layer1_stride")
        conv_layer2_filter_dim = read_tuple("conv_layer2_filter_dim")
        conv_layer2_stride = read_tuple("conv_layer2_stride")
        rnn_type = config.get("arch", "rnn_type")
        num_rnn_layer = config.getint("arch", "num_rnn_layer")
        num_hidden_rnn_list = json.loads(config.get("arch", "num_hidden_rnn_list"))
        is_batchnorm = config.getboolean("arch", "is_batchnorm")
        if seq_len is None:
            seq_len = config.getint('arch', 'max_t_count')
        num_label = config.getint('arch', 'max_label_length')
        num_rear_fc_layers = config.getint("arch", "num_rear_fc_layers")
        num_hidden_rear_fc_list = json.loads(config.get("arch", "num_hidden_rear_fc_list"))
        act_type_rear_fc_list = json.loads(config.get("arch", "act_type_rear_fc_list"))

        # model symbol generation
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('label')
        # Insert a singleton axis after axis 0: (d0, d1, d2) -> (d0, 1, d1, d2).
        net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
        net = conv(net=net, channels=channel_num, filter_dimension=conv_layer1_filter_dim,
                   stride=conv_layer1_stride, no_bias=is_batchnorm, name='conv1')
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net, name="conv1_batchnorm")
        net = conv(net=net, channels=channel_num, filter_dimension=conv_layer2_filter_dim,
                   stride=conv_layer2_stride, no_bias=is_batchnorm, name='conv2')
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net, name="conv2_batchnorm")
        net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
        net = mx.sym.Reshape(data=net, shape=(0, 0, -3))

        seq_len_after_conv_layer1 = length_after_conv(
            seq_len, conv_layer1_filter_dim, conv_layer1_stride)
        seq_len_after_conv_layer2 = length_after_conv(
            seq_len_after_conv_layer1, conv_layer2_filter_dim, conv_layer2_stride)
        net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1)

        if rnn_type == "bilstm":
            net = bi_lstm_unroll(net=net, seq_len=seq_len_after_conv_layer2,
                                 num_hidden_lstm_list=num_hidden_rnn_list,
                                 num_lstm_layer=num_rnn_layer, dropout=0.,
                                 is_batchnorm=is_batchnorm, is_bucketing=is_bucketing)
        elif rnn_type == "gru":
            net = gru_unroll(net=net, seq_len=seq_len_after_conv_layer2,
                             num_hidden_gru_list=num_hidden_rnn_list,
                             num_gru_layer=num_rnn_layer, dropout=0.,
                             is_batchnorm=is_batchnorm, is_bucketing=is_bucketing)
        elif rnn_type == "bigru":
            net = bi_gru_unroll(net=net, seq_len=seq_len_after_conv_layer2,
                                num_hidden_gru_list=num_hidden_rnn_list,
                                num_gru_layer=num_rnn_layer, dropout=0.,
                                is_batchnorm=is_batchnorm, is_bucketing=is_bucketing)
        else:
            raise Exception('rnn_type should be one of the followings, bilstm,gru,bigru')

        # rear fc layers
        net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2,
                          num_layer=num_rear_fc_layers, prefix="rear",
                          num_hidden_list=num_hidden_rear_fc_list,
                          act_type_list=act_type_rear_fc_list,
                          is_batchnorm=is_batchnorm)
        # warpctc layer
        net = warpctc_layer(net=net, seq_len=seq_len_after_conv_layer2, label=label,
                            num_label=num_label,
                            character_classes_count=(config.getint('arch', 'n_classes') + 1))
        config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        return net

    if mode == 'load' or mode == 'predict':
        conv_layer1_filter_dim = read_tuple("conv_layer1_filter_dim")
        conv_layer1_stride = read_tuple("conv_layer1_stride")
        conv_layer2_filter_dim = read_tuple("conv_layer2_filter_dim")
        conv_layer2_stride = read_tuple("conv_layer2_stride")
        if seq_len is None:
            seq_len = config.getint('arch', 'max_t_count')
        seq_len_after_conv_layer1 = length_after_conv(
            seq_len, conv_layer1_filter_dim, conv_layer1_stride)
        seq_len_after_conv_layer2 = length_after_conv(
            seq_len_after_conv_layer1, conv_layer2_filter_dim, conv_layer2_stride)
        config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        return None

    raise Exception('mode must be the one of the followings - train,predict,load')
def arch(args, seq_len=None):
    """
    define deep speech 2 network

    All hyper-parameters come from ``args.config``.

    * ``common:mode`` is ``"train"`` or ``"server"``, or ``arch:is_bucketing``
      is true: the network is built. The layers are two conv layers, a
      recurrent stack (bilstm / gru / bigru), rear fully-connected layers and a
      shared classification FC whose per-timestep outputs are concatenated on
      dim 0.
        - ``"server"`` returns ``mx.sym.Group([softmax, *states])``, where the
          states are the ``c`` and ``h`` of every backward and then forward
          bilstm state. Only ``rnn_type == "bilstm"`` exposes those states.
        - otherwise the concatenated output is fed to ``warpctc_layer`` and
          that symbol is returned.
    * otherwise, ``common:mode`` in ``("load", "predict")``: builds nothing and
      returns ``None``.

    Except in server mode, ``arch:max_t_count`` is overwritten with the number
    of timesteps remaining after the two convolutions.

    Args:
        args: ``argparse.Namespace`` with a ``config`` attribute.
        seq_len: input sequence length; read from ``arch:max_t_count`` when None.

    Raises:
        Exception: if ``args`` is not an ``argparse.Namespace``, the mode or
            rnn_type is unknown, or server mode is requested with a
            non-bilstm rnn_type.
    """
    if not isinstance(args, argparse.Namespace):
        # Previously fell through and silently returned None.
        raise Exception(
            'type of args should be one of the argparse.' +
            'Namespace for fixed length model or integer for variable length model'
        )
    config = args.config

    def read_tuple(option):
        return tuple(json.loads(config.get("arch", option)))

    def length_after_conv(length, filter_dim, stride):
        # Valid (unpadded) convolution along the first axis of filter/stride.
        return int(math.floor((length - filter_dim[0]) / stride[0])) + 1

    mode = config.get("common", "mode")
    is_bucketing = config.getboolean("arch", "is_bucketing")
    # "server" is listed explicitly: its output path lives in this branch, so
    # without it server mode only worked when bucketing happened to be enabled.
    if mode == "train" or mode == "server" or is_bucketing:
        channel_num = config.getint("arch", "channel_num")
        conv_layer1_filter_dim = read_tuple("conv_layer1_filter_dim")
        conv_layer1_stride = read_tuple("conv_layer1_stride")
        conv_layer2_filter_dim = read_tuple("conv_layer2_filter_dim")
        conv_layer2_stride = read_tuple("conv_layer2_stride")
        rnn_type = config.get("arch", "rnn_type")
        num_rnn_layer = config.getint("arch", "num_rnn_layer")
        num_hidden_proj = config.getint("arch", "num_hidden_proj")
        num_hidden_rnn_list = json.loads(config.get("arch", "num_hidden_rnn_list"))
        is_batchnorm = config.getboolean("arch", "is_batchnorm")
        if seq_len is None:
            seq_len = config.getint('arch', 'max_t_count')
        num_label = config.getint('arch', 'max_label_length')
        num_rear_fc_layers = config.getint("arch", "num_rear_fc_layers")
        num_hidden_rear_fc_list = json.loads(config.get("arch", "num_hidden_rear_fc_list"))
        act_type_rear_fc_list = json.loads(config.get("arch", "act_type_rear_fc_list"))

        if mode == 'server' and rnn_type != "bilstm":
            # Only bi_lstm_unroll returns recurrent states; previously this
            # failed later with a NameError on b_states.
            raise Exception('server mode requires rnn_type bilstm, got %s' % rnn_type)

        # model symbol generation
        data = mx.sym.Variable('data')
        # Insert a singleton axis after axis 0: (d0, d1, d2) -> (d0, 1, d1, d2).
        net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
        net = conv(net=net, channels=channel_num, filter_dimension=conv_layer1_filter_dim,
                   stride=conv_layer1_stride, no_bias=is_batchnorm, name='conv1')
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net, name="conv1_batchnorm")
        net = conv(net=net, channels=channel_num, filter_dimension=conv_layer2_filter_dim,
                   stride=conv_layer2_stride, no_bias=is_batchnorm, name='conv2')
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net, name="conv2_batchnorm")
        net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
        net = mx.sym.Reshape(data=net, shape=(0, 0, -3))

        seq_len_after_conv_layer1 = length_after_conv(
            seq_len, conv_layer1_filter_dim, conv_layer1_stride)
        seq_len_after_conv_layer2 = length_after_conv(
            seq_len_after_conv_layer1, conv_layer2_filter_dim, conv_layer2_stride)
        net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1)

        if rnn_type == "bilstm":
            net, f_states, b_states = bi_lstm_unroll(
                net=net, seq_len=seq_len_after_conv_layer2,
                num_hidden_lstm_list=num_hidden_rnn_list,
                num_lstm_layer=num_rnn_layer, dropout=0.,
                num_hidden_proj=num_hidden_proj,
                is_batchnorm=is_batchnorm, is_bucketing=is_bucketing)
        elif rnn_type == "gru":
            net = gru_unroll(net=net, seq_len=seq_len_after_conv_layer2,
                             num_hidden_gru_list=num_hidden_rnn_list,
                             num_gru_layer=num_rnn_layer, dropout=0.,
                             is_batchnorm=is_batchnorm, is_bucketing=is_bucketing)
        elif rnn_type == "bigru":
            net = bi_gru_unroll(net=net, seq_len=seq_len_after_conv_layer2,
                                num_hidden_gru_list=num_hidden_rnn_list,
                                num_gru_layer=num_rnn_layer, dropout=0.,
                                is_batchnorm=is_batchnorm, is_bucketing=is_bucketing)
        else:
            raise Exception(
                'rnn_type should be one of the followings, bilstm,gru,bigru'
            )

        # rear fc layers
        net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2,
                          num_layer=num_rear_fc_layers, prefix="rear",
                          num_hidden_list=num_hidden_rear_fc_list,
                          act_type_list=act_type_rear_fc_list,
                          is_batchnorm=is_batchnorm)

        # Classification FC shared across timesteps, outputs stacked on dim 0.
        cls_weight = mx.sym.Variable("cls_weight")
        cls_bias = mx.sym.Variable("cls_bias")
        character_classes_count = config.getint('arch', 'n_classes') + 1
        fc_seq = []
        for seqidx in range(seq_len_after_conv_layer2):
            hidden = net[seqidx]
            hidden = mx.sym.FullyConnected(data=hidden,
                                           num_hidden=character_classes_count,
                                           weight=cls_weight,
                                           bias=cls_bias)
            fc_seq.append(hidden)
        net = mx.sym.Concat(*fc_seq, dim=0, name="warpctc_layer_concat")

        if mode == 'server':
            sm = mx.sym.SoftmaxActivation(data=net, name='softmax')
            output = [sm]
            for state in b_states + f_states:
                output.append(state.c)
                output.append(state.h)
            return mx.sym.Group(output)

        label = mx.sym.Variable('label')
        # warpctc layer
        net = warpctc_layer(net=net,
                            label=label,
                            num_label=num_label,
                            seq_len=seq_len_after_conv_layer2,
                            character_classes_count=character_classes_count)
        config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        return net

    if mode == 'load' or mode == 'predict':
        conv_layer1_filter_dim = read_tuple("conv_layer1_filter_dim")
        conv_layer1_stride = read_tuple("conv_layer1_stride")
        conv_layer2_filter_dim = read_tuple("conv_layer2_filter_dim")
        conv_layer2_stride = read_tuple("conv_layer2_stride")
        if seq_len is None:
            seq_len = config.getint('arch', 'max_t_count')
        seq_len_after_conv_layer1 = length_after_conv(
            seq_len, conv_layer1_filter_dim, conv_layer1_stride)
        seq_len_after_conv_layer2 = length_after_conv(
            seq_len_after_conv_layer1, conv_layer2_filter_dim, conv_layer2_stride)
        config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        return None

    raise Exception(
        'mode must be the one of the followings - train,predict,load,server')
def arch(args):
    """Define the Deep Speech 2 network (fixed-length variant).

    Hyper-parameters come from ``args.config``. When ``common:mode`` is
    ``"train"``, the training symbol is built and returned. For any other mode
    nothing is built and ``None`` is returned. In both cases
    ``arch:max_t_count`` is overwritten with the number of timesteps remaining
    after the two convolution layers.

    Raises:
        Exception: in train mode, if ``arch:rnn_type`` is not bilstm, gru or bigru.
    """
    config = args.config

    def tuple_option(option):
        return tuple(json.loads(config.get("arch", option)))

    def length_after(length, filter_dim, stride):
        # Valid (unpadded) convolution along the first axis of filter/stride.
        return int(math.floor((length - filter_dim[0]) / stride[0])) + 1

    if config.get("common", "mode") != "train":
        # Non-train modes only need the post-convolution sequence length.
        filter1 = tuple_option("conv_layer1_filter_dim")
        stride1 = tuple_option("conv_layer1_stride")
        filter2 = tuple_option("conv_layer2_filter_dim")
        stride2 = tuple_option("conv_layer2_stride")
        t_in = config.getint('arch', 'max_t_count')
        t_out = length_after(length_after(t_in, filter1, stride1), filter2, stride2)
        config.set('arch', 'max_t_count', str(t_out))
        return None

    channel_num = config.getint("arch", "channel_num")
    filter1 = tuple_option("conv_layer1_filter_dim")
    stride1 = tuple_option("conv_layer1_stride")
    filter2 = tuple_option("conv_layer2_filter_dim")
    stride2 = tuple_option("conv_layer2_stride")
    rnn_type = config.get("arch", "rnn_type")
    num_rnn_layer = config.getint("arch", "num_rnn_layer")
    num_hidden_rnn_list = json.loads(config.get("arch", "num_hidden_rnn_list"))
    is_batchnorm = config.getboolean("arch", "is_batchnorm")
    seq_len = config.getint('arch', 'max_t_count')
    num_label = config.getint('arch', 'max_label_length')
    num_rear_fc_layers = config.getint("arch", "num_rear_fc_layers")
    num_hidden_rear_fc_list = json.loads(config.get("arch", "num_hidden_rear_fc_list"))
    act_type_rear_fc_list = json.loads(config.get("arch", "act_type_rear_fc_list"))

    # model symbol generation
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    # Insert a singleton axis after axis 0: (d0, d1, d2) -> (d0, 1, d1, d2).
    net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
    for filter_dim, stride in ((filter1, stride1), (filter2, stride2)):
        net = conv(net=net, channels=channel_num, filter_dimension=filter_dim,
                   stride=stride, no_bias=is_batchnorm)
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net)
    net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
    net = mx.sym.Reshape(data=net, shape=(0, 0, -3))

    t_out = length_after(length_after(seq_len, filter1, stride1), filter2, stride2)
    net = slice_symbol_to_seq_symobls(net=net, seq_len=t_out, axis=1)

    if rnn_type == "bilstm":
        net = bi_lstm_unroll(net=net, seq_len=t_out,
                             num_hidden_lstm_list=num_hidden_rnn_list,
                             num_lstm_layer=num_rnn_layer, dropout=0.,
                             is_batchnorm=is_batchnorm)
    elif rnn_type == "gru":
        net = gru_unroll(net=net, seq_len=t_out,
                         num_hidden_gru_list=num_hidden_rnn_list,
                         num_gru_layer=num_rnn_layer, dropout=0.,
                         is_batchnorm=is_batchnorm)
    elif rnn_type == "bigru":
        net = bi_gru_unroll(net=net, seq_len=t_out,
                            num_hidden_gru_list=num_hidden_rnn_list,
                            num_gru_layer=num_rnn_layer, dropout=0.,
                            is_batchnorm=is_batchnorm)
    else:
        raise Exception(
            'rnn_type should be one of the followings, bilstm,gru,bigru')

    # rear fc layers
    net = sequence_fc(net=net, seq_len=t_out, num_layer=num_rear_fc_layers,
                      prefix="rear", num_hidden_list=num_hidden_rear_fc_list,
                      act_type_list=act_type_rear_fc_list,
                      is_batchnorm=is_batchnorm)
    if is_batchnorm:
        # batch norm normalizes axis 1; one un-named batchnorm per timestep
        net = [batchnorm(net[step]) for step in range(t_out)]

    # warpctc layer
    net = warpctc_layer(net=net, seq_len=t_out, label=label, num_label=num_label,
                        character_classes_count=(config.getint('arch', 'n_classes') + 1))
    config.set('arch', 'max_t_count', str(t_out))
    return net