import argparse
import json
import math

import mxnet as mx

# The layer helpers used below (conv, batchnorm, the *_unroll functions,
# sequence_fc, slice_symbol_to_seq_symobls, warpctc_layer) are defined
# elsewhere in this example; the stt_layer_* module names below follow the
# repo's naming and are an assumption if your layout differs.
from stt_layer_batchnorm import batchnorm
from stt_layer_conv import conv
from stt_layer_fc import sequence_fc
from stt_layer_gru import bi_gru_unroll, gru_unroll
from stt_layer_lstm import bi_lstm_unroll
from stt_layer_slice import slice_symbol_to_seq_symobls
from stt_layer_warpctc import warpctc_layer


def arch(args, seq_len=None):
    """Define the Deep Speech 2 network."""
    if isinstance(args, argparse.Namespace):
        mode = args.config.get("common", "mode")
        if mode == "train":
            channel_num = args.config.getint("arch", "channel_num")
            conv_layer1_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))

            rnn_type = args.config.get("arch", "rnn_type")
            num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
            num_hidden_rnn_list = json.loads(
                args.config.get("arch", "num_hidden_rnn_list"))

            is_batchnorm = args.config.getboolean("arch", "is_batchnorm")
            is_bucketing = args.config.getboolean("arch", "is_bucketing")

            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            num_label = args.config.getint('arch', 'max_label_length')

            num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
            num_hidden_rear_fc_list = json.loads(
                args.config.get("arch", "num_hidden_rear_fc_list"))
            act_type_rear_fc_list = json.loads(
                args.config.get("arch", "act_type_rear_fc_list"))

            # model symbol generation
            # input preparation
            data = mx.sym.Variable('data')
            label = mx.sym.Variable('label')

            # (batch, time, freq) -> (batch, 1, time, freq) for 2-D convolution
            net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer1_filter_dim,
                       stride=conv_layer1_stride,
                       no_bias=is_batchnorm,
                       name='conv1')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv1_batchnorm")

            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer2_filter_dim,
                       stride=conv_layer2_stride,
                       no_bias=is_batchnorm,
                       name='conv2')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv2_batchnorm")

            net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
            net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor((seq_len_after_conv_layer1 -
                            conv_layer2_filter_dim[0]) /
                           conv_layer2_stride[0])) + 1
            net = slice_symbol_to_seq_symobls(
                net=net, seq_len=seq_len_after_conv_layer2, axis=1)
            if rnn_type == "bilstm":
                net = bi_lstm_unroll(net=net,
                                     seq_len=seq_len_after_conv_layer2,
                                     num_hidden_lstm_list=num_hidden_rnn_list,
                                     num_lstm_layer=num_rnn_layer,
                                     dropout=0.,
                                     is_batchnorm=is_batchnorm,
                                     is_bucketing=is_bucketing)
            elif rnn_type == "gru":
                net = gru_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_gru_list=num_hidden_rnn_list,
                                 num_gru_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm,
                                 is_bucketing=is_bucketing)
            elif rnn_type == "bigru":
                net = bi_gru_unroll(net=net,
                                    seq_len=seq_len_after_conv_layer2,
                                    num_hidden_gru_list=num_hidden_rnn_list,
                                    num_gru_layer=num_rnn_layer,
                                    dropout=0.,
                                    is_batchnorm=is_batchnorm,
                                    is_bucketing=is_bucketing)
            else:
                raise Exception(
                    'rnn_type must be one of the following: bilstm, gru, bigru')

            # rear fc layers
            net = sequence_fc(net=net,
                              seq_len=seq_len_after_conv_layer2,
                              num_layer=num_rear_fc_layers,
                              prefix="rear",
                              num_hidden_list=num_hidden_rear_fc_list,
                              act_type_list=act_type_rear_fc_list,
                              is_batchnorm=is_batchnorm)

            # warpctc layer
            net = warpctc_layer(
                net=net,
                seq_len=seq_len_after_conv_layer2,
                label=label,
                num_label=num_label,
                character_classes_count=(
                    args.config.getint('arch', 'n_classes') + 1))
            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer2))
            return net
        elif mode == 'load' or mode == 'predict':
            # no symbol is built here; only recompute the post-convolution
            # sequence length
            conv_layer1_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))
            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor((seq_len_after_conv_layer1 -
                            conv_layer2_filter_dim[0]) /
                           conv_layer2_stride[0])) + 1
            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer2))
        else:
            raise Exception(
                'mode must be one of the following: train, predict, load')
    else:
        raise Exception(
            'args must be an argparse.Namespace for a fixed-length model '
            'or an integer for a variable-length model')
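# --- usage sketch (not part of the original file) ---
# `arch` only needs `args.config` to answer ConfigParser-style get/getint/
# getboolean/set calls for the [common] and [arch] options read above. The
# config file name and the wiring below are assumptions for illustration.
import argparse
import configparser


def build_training_symbol(cfg_path='deepspeech.cfg'):
    config = configparser.ConfigParser()
    config.read(cfg_path)
    config.set('common', 'mode', 'train')
    args = argparse.Namespace(config=config)
    # returns the symbol ending in the warp-CTC loss; as a side effect,
    # arch() rewrites arch/max_t_count to the post-convolution length
    return arch(args)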
def arch(args):
    mode = args.config.get("common", "mode")
    if mode == "train":
        channel_num = args.config.getint("arch", "channel_num")
        conv_layer1_filter_dim = tuple(
            json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
        conv_layer1_stride = tuple(
            json.loads(args.config.get("arch", "conv_layer1_stride")))
        conv_layer2_filter_dim = tuple(
            json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
        conv_layer2_stride = tuple(
            json.loads(args.config.get("arch", "conv_layer2_stride")))

        rnn_type = args.config.get("arch", "rnn_type")
        num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
        num_hidden_rnn_list = json.loads(
            args.config.get("arch", "num_hidden_rnn_list"))

        is_batchnorm = args.config.getboolean("arch", "is_batchnorm")

        seq_len = args.config.getint('arch', 'max_t_count')
        num_label = args.config.getint('arch', 'max_label_length')

        num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
        num_hidden_rear_fc_list = json.loads(
            args.config.get("arch", "num_hidden_rear_fc_list"))
        act_type_rear_fc_list = json.loads(
            args.config.get("arch", "act_type_rear_fc_list"))

        # model symbol generation
        # input preparation
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('label')

        net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
        net = conv(net=net,
                   channels=channel_num,
                   filter_dimension=conv_layer1_filter_dim,
                   stride=conv_layer1_stride,
                   no_bias=is_batchnorm)
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net)

        net = conv(net=net,
                   channels=channel_num,
                   filter_dimension=conv_layer2_filter_dim,
                   stride=conv_layer2_stride,
                   no_bias=is_batchnorm)
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net)

        net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
        net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
        seq_len_after_conv_layer1 = int(
            math.floor((seq_len - conv_layer1_filter_dim[0]) /
                       conv_layer1_stride[0])) + 1
        seq_len_after_conv_layer2 = int(
            math.floor((seq_len_after_conv_layer1 -
                        conv_layer2_filter_dim[0]) /
                       conv_layer2_stride[0])) + 1
        net = slice_symbol_to_seq_symobls(
            net=net, seq_len=seq_len_after_conv_layer2, axis=1)
        if rnn_type == "bilstm":
            net = bi_lstm_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_lstm_list=num_hidden_rnn_list,
                                 num_lstm_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm)
        elif rnn_type == "gru":
            net = gru_unroll(net=net,
                             seq_len=seq_len_after_conv_layer2,
                             num_hidden_gru_list=num_hidden_rnn_list,
                             num_gru_layer=num_rnn_layer,
                             dropout=0.,
                             is_batchnorm=is_batchnorm)
        elif rnn_type == "bigru":
            net = bi_gru_unroll(net=net,
                                seq_len=seq_len_after_conv_layer2,
                                num_hidden_gru_list=num_hidden_rnn_list,
                                num_gru_layer=num_rnn_layer,
                                dropout=0.,
                                is_batchnorm=is_batchnorm)
        else:
            raise Exception(
                'rnn_type must be one of the following: bilstm, gru, bigru')

        # rear fc layers
        net = sequence_fc(net=net,
                          seq_len=seq_len_after_conv_layer2,
                          num_layer=num_rear_fc_layers,
                          prefix="rear",
                          num_hidden_list=num_hidden_rear_fc_list,
                          act_type_list=act_type_rear_fc_list,
                          is_batchnorm=is_batchnorm)
        if is_batchnorm:
            hidden_all = []
            # batch norm normalizes axis 1
            for seq_index in range(seq_len_after_conv_layer2):
                hidden = net[seq_index]
                hidden = batchnorm(hidden)
                hidden_all.append(hidden)
            net = hidden_all

        # warpctc layer
        net = warpctc_layer(
            net=net,
            seq_len=seq_len_after_conv_layer2,
            label=label,
            num_label=num_label,
            character_classes_count=(
                args.config.getint('arch', 'n_classes') + 1))
        args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        return net
    else:
        # non-train modes only need the post-convolution sequence length
        conv_layer1_filter_dim = tuple(
            json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
        conv_layer1_stride = tuple(
            json.loads(args.config.get("arch", "conv_layer1_stride")))
        conv_layer2_filter_dim = tuple(
            json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
        conv_layer2_stride = tuple(
            json.loads(args.config.get("arch", "conv_layer2_stride")))
        seq_len = args.config.getint('arch', 'max_t_count')
        seq_len_after_conv_layer1 = int(
            math.floor((seq_len - conv_layer1_filter_dim[0]) /
                       conv_layer1_stride[0])) + 1
        seq_len_after_conv_layer2 = int(
            math.floor((seq_len_after_conv_layer1 -
                        conv_layer2_filter_dim[0]) /
                       conv_layer2_stride[0])) + 1
        args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
def arch(args, seq_len=None):
    """Define the Deep Speech 2 network."""
    if isinstance(args, argparse.Namespace):
        mode = args.config.get("common", "mode")
        is_bucketing = args.config.getboolean("arch", "is_bucketing")
        if mode == "train" or is_bucketing:
            channel_num = args.config.getint("arch", "channel_num")
            conv_layer1_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))

            rnn_type = args.config.get("arch", "rnn_type")
            num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
            num_hidden_rnn_list = json.loads(
                args.config.get("arch", "num_hidden_rnn_list"))

            is_batchnorm = args.config.getboolean("arch", "is_batchnorm")

            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            num_label = args.config.getint('arch', 'max_label_length')

            num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
            num_hidden_rear_fc_list = json.loads(
                args.config.get("arch", "num_hidden_rear_fc_list"))
            act_type_rear_fc_list = json.loads(
                args.config.get("arch", "act_type_rear_fc_list"))

            # model symbol generation
            # input preparation
            data = mx.sym.Variable('data')
            label = mx.sym.Variable('label')

            net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer1_filter_dim,
                       stride=conv_layer1_stride,
                       no_bias=is_batchnorm,
                       name='conv1')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv1_batchnorm")

            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer2_filter_dim,
                       stride=conv_layer2_stride,
                       no_bias=is_batchnorm,
                       name='conv2')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv2_batchnorm")

            net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
            net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor((seq_len_after_conv_layer1 -
                            conv_layer2_filter_dim[0]) /
                           conv_layer2_stride[0])) + 1
            net = slice_symbol_to_seq_symobls(
                net=net, seq_len=seq_len_after_conv_layer2, axis=1)
            if rnn_type == "bilstm":
                net = bi_lstm_unroll(net=net,
                                     seq_len=seq_len_after_conv_layer2,
                                     num_hidden_lstm_list=num_hidden_rnn_list,
                                     num_lstm_layer=num_rnn_layer,
                                     dropout=0.,
                                     is_batchnorm=is_batchnorm,
                                     is_bucketing=is_bucketing)
            elif rnn_type == "gru":
                net = gru_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_gru_list=num_hidden_rnn_list,
                                 num_gru_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm,
                                 is_bucketing=is_bucketing)
            elif rnn_type == "bigru":
                net = bi_gru_unroll(net=net,
                                    seq_len=seq_len_after_conv_layer2,
                                    num_hidden_gru_list=num_hidden_rnn_list,
                                    num_gru_layer=num_rnn_layer,
                                    dropout=0.,
                                    is_batchnorm=is_batchnorm,
                                    is_bucketing=is_bucketing)
            else:
                raise Exception(
                    'rnn_type must be one of the following: bilstm, gru, bigru')

            # rear fc layers
            net = sequence_fc(net=net,
                              seq_len=seq_len_after_conv_layer2,
                              num_layer=num_rear_fc_layers,
                              prefix="rear",
                              num_hidden_list=num_hidden_rear_fc_list,
                              act_type_list=act_type_rear_fc_list,
                              is_batchnorm=is_batchnorm)

            # warpctc layer
            net = warpctc_layer(
                net=net,
                seq_len=seq_len_after_conv_layer2,
                label=label,
                num_label=num_label,
                character_classes_count=(
                    args.config.getint('arch', 'n_classes') + 1))
            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer2))
            return net
        elif mode == 'load' or mode == 'predict':
            conv_layer1_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))
            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor((seq_len_after_conv_layer1 -
                            conv_layer2_filter_dim[0]) /
                           conv_layer2_stride[0])) + 1
            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer2))
        else:
            raise Exception(
                'mode must be one of the following: train, predict, load')
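# --- bucketing sketch (not part of the original file) ---
# The seq_len parameter lets this variant be instantiated once per bucket,
# which matches the sym_gen callback contract of mx.mod.BucketingModule:
# given a bucket key, return (symbol, data_names, label_names). The bucket
# sizes and context below are assumptions.
import mxnet as mx


def make_bucketing_module(args, buckets=(100, 200, 400, 800, 1600)):
    def sym_gen(bucket_key):
        return arch(args, seq_len=bucket_key), ('data',), ('label',)

    return mx.mod.BucketingModule(sym_gen=sym_gen,
                                  default_bucket_key=max(buckets),
                                  context=mx.cpu())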
def arch(args, seq_len=None):
    """Define the Deep Speech 2 network."""
    if isinstance(args, argparse.Namespace):
        mode = args.config.get("common", "mode")
        is_bucketing = args.config.getboolean("arch", "is_bucketing")
        if mode == "train" or is_bucketing:
            channel_num = args.config.getint("arch", "channel_num")
            conv_layer1_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))

            rnn_type = args.config.get("arch", "rnn_type")
            num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
            num_hidden_proj = args.config.getint("arch", "num_hidden_proj")
            num_hidden_rnn_list = json.loads(
                args.config.get("arch", "num_hidden_rnn_list"))

            is_batchnorm = args.config.getboolean("arch", "is_batchnorm")

            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            num_label = args.config.getint('arch', 'max_label_length')

            num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
            num_hidden_rear_fc_list = json.loads(
                args.config.get("arch", "num_hidden_rear_fc_list"))
            act_type_rear_fc_list = json.loads(
                args.config.get("arch", "act_type_rear_fc_list"))

            # model symbol generation
            # input preparation
            data = mx.sym.Variable('data')

            net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer1_filter_dim,
                       stride=conv_layer1_stride,
                       no_bias=is_batchnorm,
                       name='conv1')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv1_batchnorm")

            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer2_filter_dim,
                       stride=conv_layer2_stride,
                       no_bias=is_batchnorm,
                       name='conv2')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv2_batchnorm")

            net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
            net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor((seq_len_after_conv_layer1 -
                            conv_layer2_filter_dim[0]) /
                           conv_layer2_stride[0])) + 1
            net = slice_symbol_to_seq_symobls(
                net=net, seq_len=seq_len_after_conv_layer2, axis=1)
            if rnn_type == "bilstm":
                # only this path yields the forward/backward states needed
                # by 'server' mode below
                net, f_states, b_states = bi_lstm_unroll(
                    net=net,
                    seq_len=seq_len_after_conv_layer2,
                    num_hidden_lstm_list=num_hidden_rnn_list,
                    num_lstm_layer=num_rnn_layer,
                    dropout=0.,
                    num_hidden_proj=num_hidden_proj,
                    is_batchnorm=is_batchnorm,
                    is_bucketing=is_bucketing)
            elif rnn_type == "gru":
                net = gru_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_gru_list=num_hidden_rnn_list,
                                 num_gru_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm,
                                 is_bucketing=is_bucketing)
            elif rnn_type == "bigru":
                net = bi_gru_unroll(net=net,
                                    seq_len=seq_len_after_conv_layer2,
                                    num_hidden_gru_list=num_hidden_rnn_list,
                                    num_gru_layer=num_rnn_layer,
                                    dropout=0.,
                                    is_batchnorm=is_batchnorm,
                                    is_bucketing=is_bucketing)
            else:
                raise Exception(
                    'rnn_type must be one of the following: bilstm, gru, bigru')

            # rear fc layers
            net = sequence_fc(net=net,
                              seq_len=seq_len_after_conv_layer2,
                              num_layer=num_rear_fc_layers,
                              prefix="rear",
                              num_hidden_list=num_hidden_rear_fc_list,
                              act_type_list=act_type_rear_fc_list,
                              is_batchnorm=is_batchnorm)

            # per-step classifier sharing one weight/bias pair across all
            # time steps
            cls_weight = mx.sym.Variable("cls_weight")
            cls_bias = mx.sym.Variable("cls_bias")
            fc_seq = []
            character_classes_count = args.config.getint('arch', 'n_classes') + 1
            for seqidx in range(seq_len_after_conv_layer2):
                hidden = net[seqidx]
                hidden = mx.sym.FullyConnected(
                    data=hidden,
                    num_hidden=character_classes_count,
                    weight=cls_weight,
                    bias=cls_bias)
                fc_seq.append(hidden)
            net = mx.sym.Concat(*fc_seq, dim=0, name="warpctc_layer_concat")

            if mode == 'server':
                # return the softmax posteriors together with the LSTM
                # cell/hidden states instead of the CTC loss
                sm = mx.sym.SoftmaxActivation(data=net, name='softmax')
                output = [sm]
                for state in b_states + f_states:
                    output.append(state.c)
                    output.append(state.h)
                return mx.sym.Group(output)

            label = mx.sym.Variable('label')
            # warpctc layer
            net = warpctc_layer(net=net,
                                label=label,
                                num_label=num_label,
                                seq_len=seq_len_after_conv_layer2,
                                character_classes_count=args.config.getint(
                                    'arch', 'n_classes') + 1)
            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer2))
            return net
        elif mode == 'load' or mode == 'predict':
            conv_layer1_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))
            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor((seq_len_after_conv_layer1 -
                            conv_layer2_filter_dim[0]) /
                           conv_layer2_stride[0])) + 1
            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer2))
        else:
            raise Exception(
                'mode must be one of the following: train, predict, load')
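# --- weight-sharing sketch (not part of the original file) ---
# The classifier loop above reuses one cls_weight/cls_bias pair for every
# time step by passing the same Variables into each FullyConnected call,
# then stacks the per-step outputs along axis 0 for warp-CTC. A minimal
# standalone version of that pattern, with made-up sizes:
import mxnet as mx

steps = [mx.sym.Variable('step%d' % i) for i in range(3)]
w = mx.sym.Variable('cls_weight')
b = mx.sym.Variable('cls_bias')
outs = [mx.sym.FullyConnected(data=s, num_hidden=29, weight=w, bias=b)
        for s in steps]  # 29 = 28 characters + 1 CTC blank, assumed
stacked = mx.sym.Concat(*outs, dim=0)
# one shared parameter pair, however many time steps:
assert stacked.list_arguments().count('cls_weight') == 1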
def arch(args, seq_len=None):
    """Define the Deep Speech 2 network."""
    if isinstance(args, argparse.Namespace):
        mode = args.config.get("common", "mode")
        is_bucketing = args.config.getboolean("arch", "is_bucketing")
        if mode == "train" or is_bucketing:
            # channel_num = args.config.getint("arch", "channel_num")
            # the conv_layer* options are still read but unused by the
            # hard-coded VGG-style front end below
            conv_layer1_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))

            rnn_type = args.config.get("arch", "rnn_type")
            num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
            num_hidden_proj = args.config.getint("arch", "num_hidden_proj")
            num_hidden_rnn_list = json.loads(
                args.config.get("arch", "num_hidden_rnn_list"))

            is_batchnorm = args.config.getboolean("arch", "is_batchnorm")
            fbank = args.config.getboolean("data", "fbank")

            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            num_label = args.config.getint('arch', 'max_label_length')

            num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
            num_hidden_rear_fc_list = json.loads(
                args.config.get("arch", "num_hidden_rear_fc_list"))
            act_type_rear_fc_list = json.loads(
                args.config.get("arch", "act_type_rear_fc_list"))

            # model symbol generation
            # input preparation
            data = mx.sym.Variable('data')
            label = mx.sym.Variable('label')

            if not fbank:
                net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
            else:
                net = data

            # VGG-style front end: two blocks of padded 3x3 convolutions,
            # each followed by a 3x3 stride-2 max-pool
            net = mx.sym.Convolution(data=net, kernel=(3, 3), pad=(1, 1),
                                     num_filter=64, no_bias=is_batchnorm,
                                     name="conv%s_%s" % (1, 1))
            if is_batchnorm:
                net = mx.symbol.BatchNorm(data=net, name="bn%s_%s" % (1, 1))
            net = mx.sym.Activation(data=net, act_type="relu",
                                    name="relu%s_%s" % (1, 1))
            net = mx.sym.Convolution(data=net, kernel=(3, 3), pad=(1, 1),
                                     num_filter=64, no_bias=is_batchnorm,
                                     name="conv%s_%s" % (1, 2))
            if is_batchnorm:
                net = mx.symbol.BatchNorm(data=net, name="bn%s_%s" % (1, 2))
            net = mx.sym.Activation(data=net, act_type="relu",
                                    name="relu%s_%s" % (1, 2))
            net = mx.sym.Pooling(data=net, pool_type="max", kernel=(3, 3),
                                 stride=(2, 2), name="pool%s" % (1))

            net = mx.sym.Convolution(data=net, kernel=(3, 3), pad=(1, 1),
                                     num_filter=128, no_bias=is_batchnorm,
                                     name="conv%s_%s" % (2, 1))
            if is_batchnorm:
                net = mx.symbol.BatchNorm(data=net, name="bn%s_%s" % (2, 1))
            net = mx.sym.Activation(data=net, act_type="relu",
                                    name="relu%s_%s" % (2, 1))
            net = mx.sym.Convolution(data=net, kernel=(3, 3), pad=(1, 1),
                                     num_filter=128, no_bias=is_batchnorm,
                                     name="conv%s_%s" % (2, 2))
            if is_batchnorm:
                net = mx.symbol.BatchNorm(data=net, name="bn%s_%s" % (2, 2))
            net = mx.sym.Activation(data=net, act_type="relu",
                                    name="relu%s_%s" % (2, 2))
            net = mx.sym.Pooling(data=net, pool_type="max", kernel=(3, 3),
                                 stride=(2, 2), name="pool%s" % (2))

            # previous generic-conv front end, kept for reference:
            # net = conv(net=net,
            #            channels=channel_num,
            #            filter_dimension=conv_layer1_filter_dim,
            #            stride=conv_layer1_stride,
            #            no_bias=is_batchnorm,
            #            name='conv1')
            # if is_batchnorm:
            #     # batch norm normalizes axis 1
            #     net = batchnorm(net, name="conv1_batchnorm")
            #
            # net = conv(net=net,
            #            channels=channel_num,
            #            filter_dimension=conv_layer2_filter_dim,
            #            stride=conv_layer2_stride,
            #            no_bias=is_batchnorm,
            #            name='conv2')
            # if is_batchnorm:
            #     # batch norm normalizes axis 1
            #     net = batchnorm(net, name="conv2_batchnorm")
            #
            # net = conv(net=net,
            #            channels=96,
            #            filter_dimension=conv_layer3_filter_dim,
            #            stride=conv_layer3_stride,
            #            no_bias=is_batchnorm,
            #            name='conv3')
            # if is_batchnorm:
            #     # batch norm normalizes axis 1
            #     net = batchnorm(net, name="conv3_batchnorm")

            net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
            net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
            # only the two stride-2 max-pools shrink the time axis
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - 3) / 2)) + 1
            seq_len_after_conv = int(
                math.floor((seq_len_after_conv_layer1 - 3) / 2)) + 1
            net = slice_symbol_to_seq_symobls(net=net,
                                              seq_len=seq_len_after_conv,
                                              axis=1)
            if rnn_type == "bilstm":
                net, f_states, b_states = bi_lstm_unroll(
                    net=net,
                    seq_len=seq_len_after_conv,
                    num_hidden_lstm_list=num_hidden_rnn_list,
                    num_lstm_layer=num_rnn_layer,
                    dropout=0.,
                    num_hidden_proj=num_hidden_proj,
                    is_batchnorm=is_batchnorm,
                    is_bucketing=is_bucketing)
            elif rnn_type == "gru":
                net = gru_unroll(net=net,
                                 seq_len=seq_len_after_conv,
                                 num_hidden_gru_list=num_hidden_rnn_list,
                                 num_gru_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm,
                                 is_bucketing=is_bucketing)
            elif rnn_type == "bigru":
                net = bi_gru_unroll(net=net,
                                    seq_len=seq_len_after_conv,
                                    num_hidden_gru_list=num_hidden_rnn_list,
                                    num_gru_layer=num_rnn_layer,
                                    dropout=0.,
                                    is_batchnorm=is_batchnorm,
                                    is_bucketing=is_bucketing)
            else:
                raise Exception(
                    'rnn_type must be one of the following: bilstm, gru, bigru')

            # rear fc layers
            net = sequence_fc(net=net,
                              seq_len=seq_len_after_conv,
                              num_layer=num_rear_fc_layers,
                              prefix="rear",
                              num_hidden_list=num_hidden_rear_fc_list,
                              act_type_list=act_type_rear_fc_list,
                              is_batchnorm=is_batchnorm)

            # per-step classifier sharing one weight/bias pair across all
            # time steps
            cls_weight = mx.sym.Variable("cls_weight")
            cls_bias = mx.sym.Variable("cls_bias")
            fc_seq = []
            character_classes_count = args.config.getint('arch', 'n_classes') + 1
            for seqidx in range(seq_len_after_conv):
                hidden = net[seqidx]
                hidden = mx.sym.FullyConnected(
                    data=hidden,
                    num_hidden=character_classes_count,
                    weight=cls_weight,
                    bias=cls_bias)
                fc_seq.append(hidden)
            net = mx.sym.Concat(*fc_seq, dim=0, name="warpctc_layer_concat")

            # warpctc layer
            net = warpctc_layer(
                net=net,
                seq_len=seq_len_after_conv,
                label=label,
                num_label=num_label,
                character_classes_count=(
                    args.config.getint('arch', 'n_classes') + 1))
            args.config.set('arch', 'max_t_count', str(seq_len_after_conv))
            return net
        elif mode == 'load' or mode == 'predict':
            conv_layer1_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))
            conv_layer3_filter_dim = tuple(
                json.loads(args.config.get("arch", "conv_layer3_filter_dim")))
            conv_layer3_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer3_stride")))
            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor((seq_len_after_conv_layer1 -
                            conv_layer2_filter_dim[0]) /
                           conv_layer2_stride[0])) + 1
            seq_len_after_conv_layer3 = int(
                math.floor((seq_len_after_conv_layer2 -
                            conv_layer3_filter_dim[0]) /
                           conv_layer3_stride[0])) + 1
            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer3))
        else:
            raise Exception(
                'mode must be one of the following: train, predict, load')
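# --- worked example (not part of the original file) ---
# In the fbank/VGG-style variant the padded 3x3 convolutions keep the time
# axis unchanged, so only the two 3x3 stride-2 max-pools shrink it; hence
# the hard-coded floor((L - 3) / 2) + 1 applied twice above. The input
# length is made up for illustration.
import math


def after_pool(length):
    return int(math.floor((length - 3) / 2)) + 1


assert after_pool(1600) == 799                # after pool1
assert after_pool(after_pool(1600)) == 399    # after pool2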