Example #1
def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, is_batchnorm=False, gamma=None, beta=None, name=None):
    """LSTM Cell symbol"""
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    if is_batchnorm:
        if name is not None:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
        else:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)
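
A note on the arguments: prev_state and param are assumed to be the simple namedtuple containers used throughout the MXNet RNN examples. A minimal sketch of those containers and a single cell invocation (symbol names such as "l0_i2h_weight" are illustrative only):

from collections import namedtuple

import mxnet as mx

# Minimal sketch, assuming the usual namedtuple state/parameter containers.
LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])

param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                  i2h_bias=mx.sym.Variable("l0_i2h_bias"),
                  h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                  h2h_bias=mx.sym.Variable("l0_h2h_bias"))
state = LSTMState(c=mx.sym.Variable("l0_init_c"),
                  h=mx.sym.Variable("l0_init_h"))

# One time step: feed the current input symbol and the previous state,
# then thread the returned LSTMState into the next step of the unroll.
x_t = mx.sym.Variable("x_t0")
state = vanilla_lstm(num_hidden=256, indata=x_t, prev_state=state,
                     param=param, seqidx=0, layeridx=0)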
Example #3
def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., num_hidden_proj=0, is_batchnorm=False,
         gamma=None, beta=None, name=None):
    """LSTM Cell symbol"""
    # dropout input
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)

    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    if is_batchnorm:
        if name is not None:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
        else:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)

    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                # bias=param.h2h_bias,
                                no_bias=True,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))

    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))

    Wcidc = mx.sym.broadcast_mul(param.c2i_bias, prev_state.c) + slice_gates[0]
    in_gate = mx.sym.Activation(Wcidc, act_type="sigmoid")

    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")

    Wcfdc = mx.sym.broadcast_mul(param.c2f_bias, prev_state.c) + slice_gates[2]
    forget_gate = mx.sym.Activation(Wcfdc, act_type="sigmoid")

    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)

    Wcoct = mx.sym.broadcast_mul(param.c2o_bias, next_c) + slice_gates[3]
    out_gate = mx.sym.Activation(Wcoct, act_type="sigmoid")

    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")

    if num_hidden_proj > 0:
        proj_next_h = mx.sym.FullyConnected(data=next_h,
                                            weight=param.ph2h_weight,
                                            no_bias=True,
                                            num_hidden=num_hidden_proj,
                                            name="t%d_l%d_ph2h" % (seqidx, layeridx))

        return LSTMState(c=next_c, h=proj_next_h)
    else:
        return LSTMState(c=next_c, h=next_h)
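
Compared with the vanilla cell, this variant adds peephole terms (c2i_bias, c2f_bias, c2o_bias, broadcast-multiplied with the cell state) and an optional output projection (ph2h_weight), so its parameter container carries extra fields. A hedged sketch of that container; the field names follow the attributes referenced above, while the symbol names and the hidden size of 256 are illustrative:

from collections import namedtuple

import mxnet as mx

LSTMParam = namedtuple("LSTMParam",
                       ["i2h_weight", "i2h_bias",
                        "h2h_weight", "h2h_bias",
                        "ph2h_weight",                        # projection, used when num_hidden_proj > 0
                        "c2i_bias", "c2f_bias", "c2o_bias"])  # peephole weights on the cell state

num_hidden = 256
param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                  i2h_bias=mx.sym.Variable("l0_i2h_bias"),
                  h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                  h2h_bias=mx.sym.Variable("l0_h2h_bias"),
                  ph2h_weight=mx.sym.Variable("l0_ph2h_weight"),
                  # shape (1, num_hidden) so broadcast_mul spreads over the batch axis
                  c2i_bias=mx.sym.Variable("l0_c2i_bias", shape=(1, num_hidden)),
                  c2f_bias=mx.sym.Variable("l0_c2f_bias", shape=(1, num_hidden)),
                  c2o_bias=mx.sym.Variable("l0_c2o_bias", shape=(1, num_hidden)))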
Example #5
def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., is_batchnorm=False, gamma=None, beta=None, name=None):
    """
    GRU Cell symbol
    Reference:
    * Chung, Junyoung, et al. "Empirical evaluation of gated recurrent neural
        networks on sequence modeling." arXiv preprint arXiv:1412.3555 (2014).
    """
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.gates_i2h_weight,
                                bias=param.gates_i2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_i2h" % (seqidx, layeridx))

    if is_batchnorm:
        if name is not None:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
        else:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.gates_h2h_weight,
                                bias=param.gates_h2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=2,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    update_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    reset_gate = mx.sym.Activation(slice_gates[1], act_type="sigmoid")
    # The transform part of GRU is a little magic
    htrans_i2h = mx.sym.FullyConnected(data=indata,
                                       weight=param.trans_i2h_weight,
                                       bias=param.trans_i2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_i2h" % (seqidx, layeridx))
    h_after_reset = prev_state.h * reset_gate
    htrans_h2h = mx.sym.FullyConnected(data=h_after_reset,
                                       weight=param.trans_h2h_weight,
                                       bias=param.trans_h2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_h2h" % (seqidx, layeridx))
    h_trans = htrans_i2h + htrans_h2h
    h_trans_active = mx.sym.Activation(h_trans, act_type="tanh")
    next_h = prev_state.h + update_gate * (h_trans_active - prev_state.h)
    return GRUState(h=next_h)
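
The last line is the standard GRU interpolation written incrementally: with update gate z it is equivalent to next_h = (1 - z) * prev_h + z * h_trans_active. As with the LSTM cells, prev_state and param are assumed to be plain namedtuple containers; a hedged sketch (field names mirror the attributes used above, symbol names are illustrative):

from collections import namedtuple

import mxnet as mx

GRUState = namedtuple("GRUState", ["h"])
GRUParam = namedtuple("GRUParam",
                      ["gates_i2h_weight", "gates_i2h_bias",
                       "gates_h2h_weight", "gates_h2h_bias",
                       "trans_i2h_weight", "trans_i2h_bias",
                       "trans_h2h_weight", "trans_h2h_bias"])

param = GRUParam(**{field: mx.sym.Variable("l0_%s" % field)
                    for field in GRUParam._fields})
state = GRUState(h=mx.sym.Variable("l0_init_h"))
out = gru(num_hidden=256, indata=mx.sym.Variable("x_t0"),
          prev_state=state, param=param, seqidx=0, layeridx=0)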
Example #6
def gru(num_hidden,
        indata,
        prev_state,
        param,
        seqidx,
        layeridx,
        dropout=0.,
        is_batchnorm=False,
        gamma=None,
        beta=None):
    """
    GRU Cell symbol
    Reference:
    * Chung, Junyoung, et al. "Empirical evaluation of gated recurrent neural
        networks on sequence modeling." arXiv preprint arXiv:1412.3555 (2014).
    """
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.gates_i2h_weight,
                                bias=param.gates_i2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_i2h" % (seqidx, layeridx))

    if is_batchnorm:
        i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.gates_h2h_weight,
                                bias=param.gates_h2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates,
                                      num_outputs=2,
                                      name="t%d_l%d_slice" %
                                      (seqidx, layeridx))
    update_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    reset_gate = mx.sym.Activation(slice_gates[1], act_type="sigmoid")
    # The transform part of GRU is a little magic
    htrans_i2h = mx.sym.FullyConnected(data=indata,
                                       weight=param.trans_i2h_weight,
                                       bias=param.trans_i2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_i2h" %
                                       (seqidx, layeridx))
    h_after_reset = prev_state.h * reset_gate
    htrans_h2h = mx.sym.FullyConnected(data=h_after_reset,
                                       weight=param.trans_h2h_weight,
                                       bias=param.trans_h2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_i2h" %
                                       (seqidx, layeridx))
    h_trans = htrans_i2h + htrans_h2h
    h_trans_active = mx.sym.Activation(h_trans, act_type="tanh")
    next_h = prev_state.h + update_gate * (h_trans_active - prev_state.h)
    return GRUState(h=next_h)
Example #7
def arch(args, seq_len=None):
    """
    define deep speech 2 network
    """
    if isinstance(args, argparse.Namespace):
        mode = args.config.get("common", "mode")
        if mode == "train":
            channel_num = args.config.getint("arch", "channel_num")
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))

            rnn_type = args.config.get("arch", "rnn_type")
            num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
            num_hidden_rnn_list = json.loads(
                args.config.get("arch", "num_hidden_rnn_list"))

            is_batchnorm = args.config.getboolean("arch", "is_batchnorm")
            is_bucketing = args.config.getboolean("arch", "is_bucketing")

            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')

            num_label = args.config.getint('arch', 'max_label_length')

            num_rear_fc_layers = args.config.getint("arch",
                                                    "num_rear_fc_layers")
            num_hidden_rear_fc_list = json.loads(
                args.config.get("arch", "num_hidden_rear_fc_list"))
            act_type_rear_fc_list = json.loads(
                args.config.get("arch", "act_type_rear_fc_list"))
            # model symbol generation
            # input preparation
            data = mx.sym.Variable('data')
            label = mx.sym.Variable('label')

            net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer1_filter_dim,
                       stride=conv_layer1_stride,
                       no_bias=is_batchnorm,
                       name='conv1')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv1_batchnorm")

            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer2_filter_dim,
                       stride=conv_layer2_stride,
                       no_bias=is_batchnorm,
                       name='conv2')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv2_batchnorm")

            net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
            net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor(
                    (seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) /
                    conv_layer2_stride[0])) + 1
            net = slice_symbol_to_seq_symobls(
                net=net, seq_len=seq_len_after_conv_layer2, axis=1)
            if rnn_type == "bilstm":
                net = bi_lstm_unroll(net=net,
                                     seq_len=seq_len_after_conv_layer2,
                                     num_hidden_lstm_list=num_hidden_rnn_list,
                                     num_lstm_layer=num_rnn_layer,
                                     dropout=0.,
                                     is_batchnorm=is_batchnorm,
                                     is_bucketing=is_bucketing)
            elif rnn_type == "gru":
                net = gru_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_gru_list=num_hidden_rnn_list,
                                 num_gru_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm,
                                 is_bucketing=is_bucketing)
            elif rnn_type == "bigru":
                net = bi_gru_unroll(net=net,
                                    seq_len=seq_len_after_conv_layer2,
                                    num_hidden_gru_list=num_hidden_rnn_list,
                                    num_gru_layer=num_rnn_layer,
                                    dropout=0.,
                                    is_batchnorm=is_batchnorm,
                                    is_bucketing=is_bucketing)
            else:
                raise Exception(
                    'rnn_type should be one of the followings, bilstm,gru,bigru'
                )

            # rear fc layers
            net = sequence_fc(net=net,
                              seq_len=seq_len_after_conv_layer2,
                              num_layer=num_rear_fc_layers,
                              prefix="rear",
                              num_hidden_list=num_hidden_rear_fc_list,
                              act_type_list=act_type_rear_fc_list,
                              is_batchnorm=is_batchnorm)
            # warpctc layer
            net = warpctc_layer(
                net=net,
                seq_len=seq_len_after_conv_layer2,
                label=label,
                num_label=num_label,
                character_classes_count=(
                    args.config.getint('arch', 'n_classes') + 1))
            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer2))
            return net
        elif mode == 'load' or mode == 'predict':
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))
            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor(
                    (seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) /
                    conv_layer2_stride[0])) + 1

            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer2))
        else:
            raise Exception(
                'mode must be the one of the followings - train,predict,load')
    else:
        raise Exception(
            'type of args should be one of the argparse.' +
            'Namespace for fixed length model or integer for variable length model'
        )
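
The two seq_len_after_conv_layer* expressions are the valid-convolution output-length formula floor((L - k) / s) + 1 applied along the time axis for each of the two convolutions. A stand-alone check of that arithmetic, with hypothetical filter and stride values:

import math

def time_steps_after_conv(seq_len, kernel_t, stride_t):
    # floor((L - k) / s) + 1: output length of a valid (unpadded) convolution
    return int(math.floor((seq_len - kernel_t) / stride_t)) + 1

# Hypothetical DeepSpeech2-style settings: time-axis kernels 41 and 21,
# both with stride 2 along the time axis.
seq_len = 1000
after_conv1 = time_steps_after_conv(seq_len, kernel_t=41, stride_t=2)       # 480
after_conv2 = time_steps_after_conv(after_conv1, kernel_t=21, stride_t=2)   # 230
print(after_conv1, after_conv2)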
Example #8
def arch(args):
    mode = args.config.get("common", "mode")
    if mode == "train":
        channel_num = args.config.getint("arch", "channel_num")
        conv_layer1_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
        conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
        conv_layer2_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
        conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))

        rnn_type = args.config.get("arch", "rnn_type")
        num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
        num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))

        is_batchnorm = args.config.getboolean("arch", "is_batchnorm")

        seq_len = args.config.getint('arch', 'max_t_count')
        num_label = args.config.getint('arch', 'max_label_length')

        num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
        num_hidden_rear_fc_list = json.loads(args.config.get("arch", "num_hidden_rear_fc_list"))
        act_type_rear_fc_list = json.loads(args.config.get("arch", "act_type_rear_fc_list"))
        # model symbol generation
        # input preparation
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('label')

        net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
        net = conv(net=net,
                   channels=channel_num,
                   filter_dimension=conv_layer1_filter_dim,
                   stride=conv_layer1_stride,
                   no_bias=is_batchnorm
                   )
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net)

        net = conv(net=net,
                   channels=channel_num,
                   filter_dimension=conv_layer2_filter_dim,
                   stride=conv_layer2_stride,
                   no_bias=is_batchnorm
                   )
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net)
        net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
        net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
        seq_len_after_conv_layer1 = int(
            math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
        seq_len_after_conv_layer2 = int(
            math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) / conv_layer2_stride[0])) + 1
        net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1)
        if rnn_type == "bilstm":
            net = bi_lstm_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_lstm_list=num_hidden_rnn_list,
                                 num_lstm_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm)
        elif rnn_type == "gru":
            net = gru_unroll(net=net,
                             seq_len=seq_len_after_conv_layer2,
                             num_hidden_gru_list=num_hidden_rnn_list,
                             num_gru_layer=num_rnn_layer,
                             dropout=0.,
                             is_batchnorm=is_batchnorm)
        elif rnn_type == "bigru":
            net = bi_gru_unroll(net=net,
                                seq_len=seq_len_after_conv_layer2,
                                num_hidden_gru_list=num_hidden_rnn_list,
                                num_gru_layer=num_rnn_layer,
                                dropout=0.,
                                is_batchnorm=is_batchnorm)
        else:
            raise Exception('rnn_type should be one of the followings, bilstm,gru,bigru')

        # rear fc layers
        net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2, num_layer=num_rear_fc_layers, prefix="rear",
                          num_hidden_list=num_hidden_rear_fc_list, act_type_list=act_type_rear_fc_list,
                          is_batchnorm=is_batchnorm)
        if is_batchnorm:
            hidden_all = []
            # batch norm normalizes axis 1
            for seq_index in range(seq_len_after_conv_layer2):
                hidden = net[seq_index]
                hidden = batchnorm(hidden)
                hidden_all.append(hidden)
            net = hidden_all

        # warpctc layer
        net = warpctc_layer(net=net,
                            seq_len=seq_len_after_conv_layer2,
                            label=label,
                            num_label=num_label,
                            character_classes_count=(args.config.getint('arch', 'n_classes') + 1)
                            )
        args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        return net
    else:
        conv_layer1_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
        conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
        conv_layer2_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
        conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
        seq_len = args.config.getint('arch', 'max_t_count')
        seq_len_after_conv_layer1 = int(
            math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
        seq_len_after_conv_layer2 = int(
            math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) / conv_layer2_stride[0])) + 1
        args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
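
All of the hyperparameters above come from the [common] and [arch] sections of a ConfigParser-style configuration (args.config). A minimal, illustrative configuration covering the keys this version of arch() reads; the values are placeholders, not the project's defaults:

import configparser

cfg = configparser.ConfigParser()
cfg["common"] = {"mode": "train"}
cfg["arch"] = {
    "channel_num": "32",
    "conv_layer1_filter_dim": "[41, 11]",   # parsed with json.loads, then tuple()
    "conv_layer1_stride": "[2, 2]",
    "conv_layer2_filter_dim": "[21, 11]",
    "conv_layer2_stride": "[2, 1]",
    "rnn_type": "bigru",                    # one of bilstm / gru / bigru
    "num_rnn_layer": "3",
    "num_hidden_rnn_list": "[1760, 1760, 1760]",
    "is_batchnorm": "True",
    "max_t_count": "1500",
    "max_label_length": "100",
    "num_rear_fc_layers": "0",
    "num_hidden_rear_fc_list": "[]",
    "act_type_rear_fc_list": "[]",
    "n_classes": "29",
}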
Example #9
def sequence_fc(
    net,
    seq_len,
    num_layer,
    prefix,
    num_hidden_list=[],
    act_type_list=[],
    is_batchnorm=False,
    dropout_rate=0,
):
    if num_layer == len(num_hidden_list) == len(act_type_list):
        if num_layer > 0:
            weight_list = []
            bias_list = []

            for layer_index in range(num_layer):
                weight_list.append(
                    mx.sym.Variable(name='%s_sequence_fc%d_weight' %
                                    (prefix, layer_index)))
                # if you use batchnorm bias do not have any effect
                if not is_batchnorm:
                    bias_list.append(
                        mx.sym.Variable(name='%s_sequence_fc%d_bias' %
                                        (prefix, layer_index)))
            # batch normalization parameters
            gamma_list = []
            beta_list = []
            if is_batchnorm:
                for layer_index in range(num_layer):
                    gamma_list.append(
                        mx.sym.Variable(name='%s_sequence_fc%d_gamma' %
                                        (prefix, layer_index)))
                    beta_list.append(
                        mx.sym.Variable(name='%s_sequence_fc%d_beta' %
                                        (prefix, layer_index)))
            # batch normalization parameters ends
            if type(net) is mx.symbol.Symbol:
                net = mx.sym.SliceChannel(data=net,
                                          num_outputs=seq_len,
                                          axis=1,
                                          squeeze_axis=1)
            elif type(net) is list:
                for net_index, one_net in enumerate(net):
                    if type(one_net) is not mx.symbol.Symbol:
                        raise Exception(
                            '%d th elements of the net should be mx.symbol.Symbol'
                            % net_index)
            else:
                raise Exception(
                    'type of net should be whether mx.symbol.Symbol or list of mx.symbol.Symbol'
                )
            hidden_all = []
            for seq_index in range(seq_len):
                hidden = net[seq_index]
                for layer_index in range(num_layer):
                    if dropout_rate > 0:
                        hidden = mx.sym.Dropout(data=hidden, p=dropout_rate)

                    if is_batchnorm:
                        hidden = fc(net=hidden,
                                    num_hidden=num_hidden_list[layer_index],
                                    act_type=None,
                                    weight=weight_list[layer_index],
                                    no_bias=is_batchnorm,
                                    name="%s_t%d_l%d_fc" %
                                    (prefix, seq_index, layer_index))
                        # last layer doesn't have batchnorm
                        hidden = batchnorm(net=hidden,
                                           gamma=gamma_list[layer_index],
                                           beta=beta_list[layer_index],
                                           name="%s_t%d_l%d_batchnorm" %
                                           (prefix, seq_index, layer_index))
                        hidden = mx.sym.Activation(
                            data=hidden,
                            act_type=act_type_list[layer_index],
                            name="%s_t%d_l%d_activation" %
                            (prefix, seq_index, layer_index))
                    else:
                        hidden = fc(net=hidden,
                                    num_hidden=num_hidden_list[layer_index],
                                    act_type=act_type_list[layer_index],
                                    weight=weight_list[layer_index],
                                    bias=bias_list[layer_index])
                hidden_all.append(hidden)
            net = hidden_all
        return net
    else:
        raise Exception("length doesn't met - num_layer:",
                        num_layer, ",len(num_hidden_list):",
                        len(num_hidden_list), ",len(act_type_list):",
                        len(act_type_list))
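
Because the weight (and bias / gamma / beta) Variables are created once per layer and reused for every seq_index, the same fully connected stack is applied to each time step with shared parameters. A hedged usage sketch, assuming the module's fc() and batchnorm() helpers are available and the input is a list of per-time-step symbols:

import mxnet as mx

seq_len = 50
# A list of per-time-step symbols, e.g. as produced by slice_symbol_to_seq_symobls.
step_symbols = [mx.sym.Variable("step%d" % t) for t in range(seq_len)]

net = sequence_fc(net=step_symbols,
                  seq_len=seq_len,
                  num_layer=2,
                  prefix="rear",
                  num_hidden_list=[1024, 1024],
                  act_type_list=["relu", "relu"],
                  is_batchnorm=True,
                  dropout_rate=0.1)
# net is again a list of seq_len symbols, one per time step.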
Example #10
def sequence_fc(net,
                seq_len,
                num_layer,
                prefix,
                num_hidden_list=[],
                act_type_list=[],
                is_batchnorm=False,
                dropout_rate=0,
                ):
    if num_layer == len(num_hidden_list) == len(act_type_list):
        if num_layer > 0:
            weight_list = []
            bias_list = []

            for layer_index in range(num_layer):
                weight_list.append(mx.sym.Variable(name='%s_sequence_fc%d_weight' % (prefix, layer_index)))
                # if you use batchnorm bias do not have any effect
                if not is_batchnorm:
                    bias_list.append(mx.sym.Variable(name='%s_sequence_fc%d_bias' % (prefix, layer_index)))
            # batch normalization parameters
            gamma_list = []
            beta_list = []
            if is_batchnorm:
                for layer_index in range(num_layer):
                    gamma_list.append(mx.sym.Variable(name='%s_sequence_fc%d_gamma' % (prefix, layer_index)))
                    beta_list.append(mx.sym.Variable(name='%s_sequence_fc%d_beta' % (prefix, layer_index)))
            # batch normalization parameters ends
            if type(net) is mx.symbol.Symbol:
                net = mx.sym.SliceChannel(data=net, num_outputs=seq_len, axis=1, squeeze_axis=1)
            elif type(net) is list:
                for net_index, one_net in enumerate(net):
                    if type(one_net) is not mx.symbol.Symbol:
                        raise Exception('%d th elements of the net should be mx.symbol.Symbol' % net_index)
            else:
                raise Exception('type of net should be whether mx.symbol.Symbol or list of mx.symbol.Symbol')
            hidden_all = []
            for seq_index in range(seq_len):
                hidden = net[seq_index]
                for layer_index in range(num_layer):
                    if dropout_rate > 0:
                        hidden = mx.sym.Dropout(data=hidden, p=dropout_rate)

                    if is_batchnorm:
                        hidden = fc(net=hidden,
                                    num_hidden=num_hidden_list[layer_index],
                                    act_type=None,
                                    weight=weight_list[layer_index],
                                    no_bias=is_batchnorm,
                                    name="%s_t%d_l%d_fc" % (prefix, seq_index, layer_index)
                                    )
                        # last layer doesn't have batchnorm
                        hidden = batchnorm(net=hidden,
                                           gamma=gamma_list[layer_index],
                                           beta=beta_list[layer_index],
                                           name="%s_t%d_l%d_batchnorm" % (prefix, seq_index, layer_index))
                        hidden = mx.sym.Activation(data=hidden, act_type=act_type_list[layer_index],
                                                   name="%s_t%d_l%d_activation" % (prefix, seq_index, layer_index))
                    else:
                        hidden = fc(net=hidden,
                                    num_hidden=num_hidden_list[layer_index],
                                    act_type=act_type_list[layer_index],
                                    weight=weight_list[layer_index],
                                    bias=bias_list[layer_index]
                                    )
                hidden_all.append(hidden)
            net = hidden_all
        return net
    else:
        raise Exception("length doesn't met - num_layer:",
                        num_layer, ",len(num_hidden_list):",
                        len(num_hidden_list),
                        ",len(act_type_list):",
                        len(act_type_list)
                        )
Example #11
def arch(args, seq_len=None):
    """
    define deep speech 2 network
    """
    if isinstance(args, argparse.Namespace):
        mode = args.config.get("common", "mode")
        is_bucketing = args.config.getboolean("arch", "is_bucketing")
        if mode == "train" or is_bucketing:
            channel_num = args.config.getint("arch", "channel_num")
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))

            rnn_type = args.config.get("arch", "rnn_type")
            num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
            num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))

            is_batchnorm = args.config.getboolean("arch", "is_batchnorm")

            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')

            num_label = args.config.getint('arch', 'max_label_length')

            num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
            num_hidden_rear_fc_list = json.loads(args.config.get("arch", "num_hidden_rear_fc_list"))
            act_type_rear_fc_list = json.loads(args.config.get("arch", "act_type_rear_fc_list"))
            # model symbol generation
            # input preparation
            data = mx.sym.Variable('data')
            label = mx.sym.Variable('label')

            net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer1_filter_dim,
                       stride=conv_layer1_stride,
                       no_bias=is_batchnorm,
                       name='conv1')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv1_batchnorm")

            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer2_filter_dim,
                       stride=conv_layer2_stride,
                       no_bias=is_batchnorm,
                       name='conv2')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv2_batchnorm")

            net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
            net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0])
                           / conv_layer2_stride[0])) + 1
            net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1)
            if rnn_type == "bilstm":
                net = bi_lstm_unroll(net=net,
                                     seq_len=seq_len_after_conv_layer2,
                                     num_hidden_lstm_list=num_hidden_rnn_list,
                                     num_lstm_layer=num_rnn_layer,
                                     dropout=0.,
                                     is_batchnorm=is_batchnorm,
                                     is_bucketing=is_bucketing)
            elif rnn_type == "gru":
                net = gru_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_gru_list=num_hidden_rnn_list,
                                 num_gru_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm,
                                 is_bucketing=is_bucketing)
            elif rnn_type == "bigru":
                net = bi_gru_unroll(net=net,
                                    seq_len=seq_len_after_conv_layer2,
                                    num_hidden_gru_list=num_hidden_rnn_list,
                                    num_gru_layer=num_rnn_layer,
                                    dropout=0.,
                                    is_batchnorm=is_batchnorm,
                                    is_bucketing=is_bucketing)
            else:
                raise Exception('rnn_type should be one of the followings, bilstm,gru,bigru')

            # rear fc layers
            net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2,
                              num_layer=num_rear_fc_layers, prefix="rear",
                              num_hidden_list=num_hidden_rear_fc_list,
                              act_type_list=act_type_rear_fc_list,
                              is_batchnorm=is_batchnorm)
            # warpctc layer
            net = warpctc_layer(net=net,
                                seq_len=seq_len_after_conv_layer2,
                                label=label,
                                num_label=num_label,
                                character_classes_count=
                                (args.config.getint('arch', 'n_classes') + 1))
            args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
            return net
        elif mode == 'load' or mode == 'predict':
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0])
                           / conv_layer2_stride[0])) + 1

            args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        else:
            raise Exception('mode must be the one of the followings - train,predict,load')
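
With is_bucketing enabled, arch(args, seq_len) can be called once per bucket, which is what a BucketingModule expects from its symbol generator. A hedged sketch of how that might be wired up; the bucket key and the data/label names are illustrative, and args is assumed to be the parsed configuration object used above:

import mxnet as mx

def sym_gen(bucket_seq_len):
    # Assumes `args` (the argparse.Namespace carrying .config) is in scope.
    sym = arch(args, seq_len=bucket_seq_len)
    return sym, ('data',), ('label',)

module = mx.mod.BucketingModule(sym_gen=sym_gen,
                                default_bucket_key=1600,
                                context=mx.cpu())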
Example #12
def arch(args, seq_len=None):
    """
    define deep speech 2 network
    """
    if isinstance(args, argparse.Namespace):
        mode = args.config.get("common", "mode")
        is_bucketing = args.config.getboolean("arch", "is_bucketing")
        if mode == "train" or is_bucketing:
            channel_num = args.config.getint("arch", "channel_num")
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))

            rnn_type = args.config.get("arch", "rnn_type")
            num_rnn_layer = args.config.getint("arch", "num_rnn_layer")

            num_hidden_proj = args.config.getint("arch", "num_hidden_proj")

            num_hidden_rnn_list = json.loads(
                args.config.get("arch", "num_hidden_rnn_list"))

            is_batchnorm = args.config.getboolean("arch", "is_batchnorm")

            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')

            num_label = args.config.getint('arch', 'max_label_length')

            num_rear_fc_layers = args.config.getint("arch",
                                                    "num_rear_fc_layers")
            num_hidden_rear_fc_list = json.loads(
                args.config.get("arch", "num_hidden_rear_fc_list"))
            act_type_rear_fc_list = json.loads(
                args.config.get("arch", "act_type_rear_fc_list"))
            # model symbol generation
            # input preparation
            data = mx.sym.Variable('data')

            net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer1_filter_dim,
                       stride=conv_layer1_stride,
                       no_bias=is_batchnorm,
                       name='conv1')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv1_batchnorm")

            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer2_filter_dim,
                       stride=conv_layer2_stride,
                       no_bias=is_batchnorm,
                       name='conv2')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv2_batchnorm")

            net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
            net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor(
                    (seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) /
                    conv_layer2_stride[0])) + 1
            net = slice_symbol_to_seq_symobls(
                net=net, seq_len=seq_len_after_conv_layer2, axis=1)
            if rnn_type == "bilstm":
                net, f_states, b_states = bi_lstm_unroll(
                    net=net,
                    seq_len=seq_len_after_conv_layer2,
                    num_hidden_lstm_list=num_hidden_rnn_list,
                    num_lstm_layer=num_rnn_layer,
                    dropout=0.,
                    num_hidden_proj=num_hidden_proj,
                    is_batchnorm=is_batchnorm,
                    is_bucketing=is_bucketing)
            elif rnn_type == "gru":
                net = gru_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_gru_list=num_hidden_rnn_list,
                                 num_gru_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm,
                                 is_bucketing=is_bucketing)
            elif rnn_type == "bigru":
                net = bi_gru_unroll(net=net,
                                    seq_len=seq_len_after_conv_layer2,
                                    num_hidden_gru_list=num_hidden_rnn_list,
                                    num_gru_layer=num_rnn_layer,
                                    dropout=0.,
                                    is_batchnorm=is_batchnorm,
                                    is_bucketing=is_bucketing)
            else:
                raise Exception(
                    'rnn_type should be one of the followings, bilstm,gru,bigru'
                )

            # rear fc layers
            net = sequence_fc(net=net,
                              seq_len=seq_len_after_conv_layer2,
                              num_layer=num_rear_fc_layers,
                              prefix="rear",
                              num_hidden_list=num_hidden_rear_fc_list,
                              act_type_list=act_type_rear_fc_list,
                              is_batchnorm=is_batchnorm)

            cls_weight = mx.sym.Variable("cls_weight")
            cls_bias = mx.sym.Variable("cls_bias")
            fc_seq = []
            character_classes_count = args.config.getint('arch',
                                                         'n_classes') + 1
            for seqidx in range(seq_len_after_conv_layer2):
                hidden = net[seqidx]
                hidden = mx.sym.FullyConnected(
                    data=hidden,
                    num_hidden=character_classes_count,
                    weight=cls_weight,
                    bias=cls_bias)
                fc_seq.append(hidden)
            net = mx.sym.Concat(*fc_seq, dim=0, name="warpctc_layer_concat")
            if mode == 'server':
                sm = mx.sym.SoftmaxActivation(data=net, name='softmax')
                output = [sm]
                for state in b_states + f_states:
                    output.append(state.c)
                    output.append(state.h)
                return mx.sym.Group(output)
            label = mx.sym.Variable('label')
            # warpctc layer
            net = warpctc_layer(net=net,
                                label=label,
                                num_label=num_label,
                                seq_len=seq_len_after_conv_layer2,
                                character_classes_count=args.config.getint(
                                    'arch', 'n_classes') + 1)
            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer2))
            return net
        elif mode == 'load' or mode == 'predict':
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(
                json.loads(args.config.get("arch", "conv_layer2_stride")))
            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            seq_len_after_conv_layer1 = int(
                math.floor((seq_len - conv_layer1_filter_dim[0]) /
                           conv_layer1_stride[0])) + 1
            seq_len_after_conv_layer2 = int(
                math.floor(
                    (seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) /
                    conv_layer2_stride[0])) + 1

            args.config.set('arch', 'max_t_count',
                            str(seq_len_after_conv_layer2))
        else:
            raise Exception(
                'mode must be the one of the followings - train,predict,load')
Example #13
def arch(args):
    mode = args.config.get("common", "mode")
    if mode == "train":
        channel_num = args.config.getint("arch", "channel_num")
        conv_layer1_filter_dim = tuple(
            json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
        conv_layer1_stride = tuple(
            json.loads(args.config.get("arch", "conv_layer1_stride")))
        conv_layer2_filter_dim = tuple(
            json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
        conv_layer2_stride = tuple(
            json.loads(args.config.get("arch", "conv_layer2_stride")))

        rnn_type = args.config.get("arch", "rnn_type")
        num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
        num_hidden_rnn_list = json.loads(
            args.config.get("arch", "num_hidden_rnn_list"))

        is_batchnorm = args.config.getboolean("arch", "is_batchnorm")

        seq_len = args.config.getint('arch', 'max_t_count')
        num_label = args.config.getint('arch', 'max_label_length')

        num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
        num_hidden_rear_fc_list = json.loads(
            args.config.get("arch", "num_hidden_rear_fc_list"))
        act_type_rear_fc_list = json.loads(
            args.config.get("arch", "act_type_rear_fc_list"))
        # model symbol generation
        # input preparation
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('label')

        net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
        net = conv(net=net,
                   channels=channel_num,
                   filter_dimension=conv_layer1_filter_dim,
                   stride=conv_layer1_stride,
                   no_bias=is_batchnorm)
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net)

        net = conv(net=net,
                   channels=channel_num,
                   filter_dimension=conv_layer2_filter_dim,
                   stride=conv_layer2_stride,
                   no_bias=is_batchnorm)
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net)
        net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
        net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
        seq_len_after_conv_layer1 = int(
            math.floor((seq_len - conv_layer1_filter_dim[0]) /
                       conv_layer1_stride[0])) + 1
        seq_len_after_conv_layer2 = int(
            math.floor(
                (seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) /
                conv_layer2_stride[0])) + 1
        net = slice_symbol_to_seq_symobls(net=net,
                                          seq_len=seq_len_after_conv_layer2,
                                          axis=1)
        if rnn_type == "bilstm":
            net = bi_lstm_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_lstm_list=num_hidden_rnn_list,
                                 num_lstm_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm)
        elif rnn_type == "gru":
            net = gru_unroll(net=net,
                             seq_len=seq_len_after_conv_layer2,
                             num_hidden_gru_list=num_hidden_rnn_list,
                             num_gru_layer=num_rnn_layer,
                             dropout=0.,
                             is_batchnorm=is_batchnorm)
        elif rnn_type == "bigru":
            net = bi_gru_unroll(net=net,
                                seq_len=seq_len_after_conv_layer2,
                                num_hidden_gru_list=num_hidden_rnn_list,
                                num_gru_layer=num_rnn_layer,
                                dropout=0.,
                                is_batchnorm=is_batchnorm)
        else:
            raise Exception(
                'rnn_type should be one of the followings, bilstm,gru,bigru')

        # rear fc layers
        net = sequence_fc(net=net,
                          seq_len=seq_len_after_conv_layer2,
                          num_layer=num_rear_fc_layers,
                          prefix="rear",
                          num_hidden_list=num_hidden_rear_fc_list,
                          act_type_list=act_type_rear_fc_list,
                          is_batchnorm=is_batchnorm)
        if is_batchnorm:
            hidden_all = []
            # batch norm normalizes axis 1
            for seq_index in range(seq_len_after_conv_layer2):
                hidden = net[seq_index]
                hidden = batchnorm(hidden)
                hidden_all.append(hidden)
            net = hidden_all

        # warpctc layer
        net = warpctc_layer(
            net=net,
            seq_len=seq_len_after_conv_layer2,
            label=label,
            num_label=num_label,
            character_classes_count=(args.config.getint('arch', 'n_classes') +
                                     1))
        args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        return net
    else:
        conv_layer1_filter_dim = tuple(
            json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
        conv_layer1_stride = tuple(
            json.loads(args.config.get("arch", "conv_layer1_stride")))
        conv_layer2_filter_dim = tuple(
            json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
        conv_layer2_stride = tuple(
            json.loads(args.config.get("arch", "conv_layer2_stride")))
        seq_len = args.config.getint('arch', 'max_t_count')
        seq_len_after_conv_layer1 = int(
            math.floor((seq_len - conv_layer1_filter_dim[0]) /
                       conv_layer1_stride[0])) + 1
        seq_len_after_conv_layer2 = int(
            math.floor(
                (seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) /
                conv_layer2_stride[0])) + 1
        args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
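
A note on the reshape shapes used throughout arch(): MXNet's Reshape accepts special codes, where 0 copies an existing axis, -1 infers one axis, -3 merges two consecutive axes, and -4 splits one axis into the two values that follow it. So shape=(-4, -1, 1, 0, 0) turns a (batch, time, frequency) input into (batch, 1, time, frequency) for the convolutions, and shape=(0, 0, -3) later flattens channels and frequency into a single feature axis. A quick NDArray check, with illustrative dimensions:

import mxnet as mx

x = mx.nd.zeros((8, 200, 161))       # (batch, time, frequency)
y = x.reshape((-4, -1, 1, 0, 0))     # -4 splits the batch axis into (-1, 1); each 0 copies an axis
print(y.shape)                       # (8, 1, 200, 161)

z = mx.nd.zeros((8, 100, 32, 81))    # (batch, time', channels, freq') after the transpose
print(z.reshape((0, 0, -3)).shape)   # (8, 100, 2592): -3 merges the last two axes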