コード例 #1
0
ファイル: lstm_bucketing.py プロジェクト: 1132520084/mxnet
 def sym_gen(seq_len):
     """Create the LSTM symbol unrolled over ``seq_len`` time steps.

     ``len(vocab)`` is used both as the input size and as ``num_label``.
     Returns ``(symbol, data_names, label_names)`` where the data names are
     ``'data'`` followed by the entries of ``state_names``.
     """
     vocab_size = len(vocab)
     unrolled = lstm_unroll(num_lstm_layer, seq_len, vocab_size,
                            num_hidden=num_hidden,
                            num_embed=num_embed,
                            num_label=vocab_size)
     inputs = ['data'] + state_names
     return unrolled, inputs, ['softmax_label']
コード例 #2
0
ファイル: lstm_bucketing.py プロジェクト: BawCOS/mxnet
 def sym_gen(seq_len):
     """Create the LSTM symbol unrolled over ``seq_len`` time steps.

     Returns ``(symbol, data_names, label_names)`` with one ``data/<t>`` and
     one ``label/<t>`` name per time step ``t``; the data names are followed
     by the entries of ``state_names``.
     """
     vocab_size = len(vocab)
     unrolled = lstm_unroll(num_lstm_layer, seq_len, vocab_size,
                            num_hidden=num_hidden,
                            num_embed=num_embed,
                            num_label=vocab_size)
     steps = range(seq_len)
     per_step_data = ['data/%d' % step for step in steps]
     per_step_labels = ['label/%d' % step for step in steps]
     return unrolled, per_step_data + state_names, per_step_labels
コード例 #3
0
 def sym_gen(seq_len):
     """Return the LSTM symbol unrolled over ``seq_len`` time steps.

     ``len(vocab)`` is used both as the input size and as ``num_label``.
     """
     vocab_size = len(vocab)
     unrolled = lstm_unroll(num_lstm_layer, seq_len, vocab_size,
                            num_hidden=num_hidden, num_embed=num_embed,
                            num_label=vocab_size)
     return unrolled
コード例 #4
0
 def sym_gen(seq_len):
     """Return the LSTM symbol unrolled over ``seq_len`` time steps.

     Layer sizes, label count and dropout come from the module-level
     ``num_hidden``, ``num_embed``, ``num_label`` and ``dropout`` settings.
     """
     layer_options = dict(num_hidden=num_hidden,
                          num_embed=num_embed,
                          num_label=num_label,
                          dropout=dropout)
     return lstm_unroll(num_lstm_layer, seq_len, input_len, **layer_options)
コード例 #5
0
 def sym_gen(seq_len):
     """Build the unrolled LSTM and its per-time-step input/label names.

     Returns ``(symbol, data_names, label_names)``: ``data/0 .. data/<n-1>``
     followed by ``state_names``, and ``label/0 .. label/<n-1>``, where
     ``n == seq_len``.
     """
     n_tokens = len(vocab)
     net = lstm_unroll(num_lstm_layer,
                       seq_len,
                       n_tokens,
                       num_hidden=num_hidden,
                       num_embed=num_embed,
                       num_label=n_tokens)
     data_names = []
     label_names = []
     for t in range(seq_len):
         data_names.append('data/%d' % t)
         label_names.append('label/%d' % t)
     return (net, data_names + state_names, label_names)
コード例 #6
0
 def sym_gen(seq_len):
     """Build the unrolled LSTM symbol together with its input/label names.

     Returns ``(symbol, ['data'] + state_names, ['softmax_label'])``.
     """
     n_tokens = len(vocab)
     net = lstm_unroll(num_lstm_layer,
                       seq_len,
                       n_tokens,
                       num_hidden=num_hidden,
                       num_embed=num_embed,
                       num_label=n_tokens)
     return (net, ['data'] + state_names, ['softmax_label'])
コード例 #7
0
ファイル: Seq2Seq.py プロジェクト: JianboTang/mxnet-seq2seq
    def build_lstm_decoder(self):
        """Build, initialize and bind the unrolled LSTM decoder.

        Unrolls the LSTM with this object's size settings, infers argument
        shapes for a ``(batch_size, seq_len)`` input/label and per-layer
        ``(batch_size, hidden_size)`` initial states, allocates zeroed
        argument and gradient arrays on ``self.mx_ctx``, initializes the
        trainable arguments with ``self.initializer``, and binds the symbol.
        Every argument that received a gradient buffer is appended to
        ``self.params_blocks`` as ``(index, name, arg_array, grad_array)``.

        Returns
        -------
        The executor returned by ``dec_lstm.bind``.
        """
        dec_lstm = lstm_unroll(num_lstm_layer=self.num_layers,
                               seq_len=self.seq_len,
                               input_size=self.input_size,
                               num_hidden=self.hidden_size,
                               num_embed=self.embed_size,
                               num_label=self.output_size)

        init_c = [('l%d_init_c' % l, (self.batch_size, self.hidden_size))
                  for l in range(self.num_layers)]
        init_h = [('l%d_init_h' % l, (self.batch_size, self.hidden_size))
                  for l in range(self.num_layers)]
        init_states = init_c + init_h
        input_data = [('data', (self.batch_size, self.seq_len))]
        provide_data = input_data + init_states
        provide_label = [('softmax_label', (self.batch_size, self.seq_len))]
        provide_args = dict(provide_data + provide_label)
        arg_shape, _, _ = dec_lstm.infer_shape(**provide_args)
        arg_names = dec_lstm.list_arguments()

        args = {}
        args_grad = {}
        grad_req = {}
        for shape, name in zip(arg_shape, arg_names):
            args[name] = mx.nd.zeros(shape, self.mx_ctx)
            # Inputs, labels and the initial cell states get no gradient
            # buffer (init_h does get one).
            if name in ['softmax_label', 'data'] or name.endswith('init_c'):
                continue
            args_grad[name] = mx.nd.zeros(shape, self.mx_ctx)
            grad_req[name] = 'write'

        for name in arg_names:
            if name in ['data', 'softmax_label'] or \
                    name.endswith('init_h') or name.endswith('init_c'):
                continue
            # Initialize the weight array itself, not its gradient buffer;
            # otherwise every trainable weight would stay at zero.
            self.initializer(name, args[name])

        dec_lstm_exe = dec_lstm.bind(ctx=self.mx_ctx,
                                     args=args,
                                     args_grad=args_grad,
                                     grad_req=grad_req)

        for name in args_grad:
            self.params_blocks.append((len(self.params_blocks) + 1, name,
                                       args[name], args_grad[name]))
        print(self.params_blocks)

        return dec_lstm_exe
コード例 #8
0
def build_lstm(seq_len=129, num_embed=256, num_lstm_layer=3, num_hidden=512,
               dropout=0.2):
    """Build and return the unrolled character-level LSTM symbol.

    All parameters are optional; the defaults match the previously
    hard-coded values.

    Parameters
    ----------
    seq_len : int
        number of unrolled time steps
    num_embed : int
        embedding dimension; each character is mapped to a vector of this size
    num_lstm_layer : int
        number of stacked LSTM layers
    num_hidden : int
        hidden units in each LSTM cell
    dropout : float
        dropout value forwarded to ``lstm.lstm_unroll``

    Returns
    -------
    The symbol produced by ``lstm.lstm_unroll``.
    """
    # Input size and label count are len(vocab) + 1 (the module-level vocab);
    # presumably one extra index is reserved, e.g. for padding -- confirm.
    vocab_size = len(vocab) + 1
    symbol = lstm.lstm_unroll(
        num_lstm_layer,
        seq_len,
        vocab_size,
        num_hidden=num_hidden,
        num_embed=num_embed,
        num_label=vocab_size,
        dropout=dropout)
    return symbol
コード例 #9
0
def get_cnn_rnn_attention(num_cls,
                          for_training,
                          rnn_dropout,
                          rnn_hidden,
                          rnn_window,
                          fix_till_relu7=False):
    """Get model with CNN + RNN + attention.

    Parameters
    ----------------------------
    num_cls: int
        number of classes
    for_training: bool
        when True, also build the per-time-step ``att_gesture`` loss branch;
        the flag is also forwarded to ``get_branch``
    rnn_dropout: float
        RNN dropout probability
    rnn_hidden: int
        number of hidden units of each RNN unit
    rnn_window: int
        number of timesteps; must be >= 1
    fix_till_relu7: bool
        whether to fix CNN feature extracting part

    Return
    -----------------------------
    A single symbol: the attention (``gesture``) loss alone when
    ``for_training`` is False, otherwise
    ``mx.sym.Group([gesture loss, att_gesture loss])``.

    Raises
    -----------------------------
    ValueError
        if ``rnn_window`` < 1

    Required from DataIter:
        data
        gesture_softmax_label
        att_gesture_softmax_label
    """
    # The attention loop below must run at least once: its loop variable is
    # reused afterwards to name two symbols.
    if rnn_window < 1:
        raise ValueError('rnn_window must be >= 1, got %r' % (rnn_window, ))

    # input
    net = mx.symbol.Variable(
        name="data")  # (batch_size, rnn_windows * c, h, w)
    net = mx.symbol.Reshape(
        net,
        shape=(
            0,
            -1,  # (batch_size, rnn_windows, c, h, w)
            my_constant.INPUT_CHANNEL,
            my_constant.INPUT_SIDE,
            my_constant.INPUT_SIDE))
    net = mx.symbol.SwapAxis(
        net,  # (rnn_windows, batch_size, c, h, w)
        dim1=0,
        dim2=1)
    net = mx.symbol.Reshape(
        net,
        shape=(
            -1,  # (rnn_windows * batch_size, c, h, w)
            my_constant.INPUT_CHANNEL,
            my_constant.INPUT_SIDE,
            my_constant.INPUT_SIDE))

    # ---- CNN module ----
    feature = get_vgg16(data=net,
                        num_classes=num_cls,
                        fix_till_relu7=fix_till_relu7)["relu7"]

    loss = []

    # ---- RNN module ----
    # split into time steps
    feature = mx.symbol.Reshape(
        feature,
        shape=(rnn_window, -1, my_constant.FEATURE_DIM))  # (32, 1, 4096)

    feature = mx.symbol.SwapAxis(feature, dim1=0, dim2=1)  # (1, 32, 4096)

    # f_h(X) (output of LSTM); a sequence indexed per time step below
    feature = lstm_unroll(prefix='',
                          data=feature,
                          num_rnn_layer=1,
                          seq_len=rnn_window,
                          num_hidden=rnn_hidden,
                          dropout=rnn_dropout,
                          bn=True)

    concat_feature = mx.sym.Reshape(data=mx.sym.Concat(*feature, dim=1),
                                    shape=(-1, rnn_window,
                                           rnn_hidden))  # (1, 32, 512)

    # ---- attention module ----
    M = []
    # weight_v = w^T
    weight_v = mx.sym.Variable('atten_v_bias', shape=(rnn_hidden, 1))
    # weight_u = W_h
    weight_u = mx.sym.Variable('atten_u_weight',
                               shape=(rnn_hidden, rnn_hidden))
    for i in range(rnn_window):
        # feature[i] = h_t
        tmp = mx.sym.dot(feature[i], weight_u, name='atten_u_%d' % i)
        # M_t
        tmp = mx.sym.Activation(tmp, act_type='tanh')
        tmp = mx.sym.dot(tmp, weight_v, name='atten_v_%d' % i)
        M.append(tmp)

    M = mx.sym.Concat(*M, dim=1)  # (1, 32)

    # NOTE: `i` below is the last loop index (rnn_window - 1); it is kept
    # only so the generated symbol names stay unchanged.
    # alphas
    a = mx.symbol.SoftmaxActivation(name='atten_softmax_%d' % i, data=M)
    a = mx.sym.Reshape(data=a, shape=(-1, rnn_window, 1))  # (1, 32, 1)

    # r
    r = mx.symbol.broadcast_mul(name='atten_r_%d' % i,
                                lhs=a,
                                rhs=concat_feature)  # (1, 32, 512)

    z = mx.sym.sum(data=r, axis=1)  # (1, 512)

    # loss_target is used only in training
    if for_training:
        feature = mx.symbol.Concat(*feature, dim=0)  # (32, 512)

        # loss_target
        gesture_branch_kargs = {}
        gesture_label = mx.symbol.Variable(
            name='att_gesture_softmax_label')  # m
        # Repeat the label once per time step, concatenated along axis 0 like
        # the features above (presumably so rows line up -- confirm).
        gesture_label = mx.symbol.Reshape(
            mx.symbol.Concat(*[
                mx.symbol.Reshape(gesture_label, shape=(0, 1))
                for _ in range(rnn_window)
            ], dim=0),
            shape=(-1, ))

        gesture_branch_kargs['label'] = gesture_label
        # Float division: `1 / rnn_window` is 0 on Python 2 for rnn_window > 1
        # and would silently zero this branch's gradient.
        gesture_branch_kargs['grad_scale'] = 1.0 / rnn_window
        gesture_softmax, _ = get_branch(
            for_training=for_training,
            name='att_gesture',  # m
            data=feature,
            num_class=num_cls,
            return_fc=True,
            use_ignore=True,  # ???
            **gesture_branch_kargs)

        loss.append(gesture_softmax)

    # loss_attention is used in both training and testing
    att_gesture_branch_kargs = {}
    att_gesture_label = mx.symbol.Variable(name='gesture_softmax_label')  # m

    att_gesture_label = mx.symbol.Reshape(
        mx.symbol.Concat(*[
            mx.symbol.Reshape(att_gesture_label, shape=(0, 1))
            for _ in range(1)
        ], dim=0),
        shape=(-1, ))  # (1,)

    att_gesture_branch_kargs['label'] = att_gesture_label
    att_gesture_branch_kargs['grad_scale'] = 0.1 / rnn_window

    att_gesture_softmax, _ = get_branch(  # (1, 200)
        for_training=for_training,
        name='gesture',  # m
        data=z,  # (1, 512)
        num_class=num_cls,
        return_fc=True,
        use_ignore=True,  # ???
        **att_gesture_branch_kargs)

    loss.insert(0, att_gesture_softmax)

    net = loss[0] if len(loss) == 1 else mx.sym.Group(loss)
    return net
コード例 #10
0
def main():
    """Program entry point: train the LSTM OCR model with a CTC-style loss.

    Validates ``args.loss``, starts the multiprocess captcha generator, builds
    the unrolled LSTM symbol and fits it with SGD, checkpointing with
    ``args.prefix`` after every epoch. A KeyboardInterrupt is reported and
    swallowed, and the captcha generator is reset on every exit path.

    Raises
    ------
    ValueError
        if ``args.loss`` is neither ``'ctc'`` nor ``'warpctc'``.
    """
    args = parse_args()
    if args.loss not in ('ctc', 'warpctc'):
        raise ValueError(
            "Invalid loss '{}' (must be 'ctc' or 'warpctc')".format(args.loss))

    hp = Hyperparams()

    # Start a multiprocessor captcha image generator
    mp_captcha = MPDigitCaptcha(font_paths=get_fonts(args.font_path),
                                h=hp.seq_length,
                                w=30,
                                num_digit_min=3,
                                num_digit_max=4,
                                num_processes=args.num_proc,
                                max_queue_size=hp.batch_size * 2)
    try:
        # Must call start() before any call to mxnet module (https://github.com/apache/incubator-mxnet/issues/9213)
        mp_captcha.start()

        if args.gpu:
            contexts = [mx.context.gpu(i) for i in range(args.gpu)]
        else:
            contexts = [mx.context.cpu(i) for i in range(args.cpu)]

        init_states = lstm.init_states(hp.batch_size, hp.num_lstm_layer,
                                       hp.num_hidden)

        data_train = OCRIter(hp.train_epoch_size // hp.batch_size,
                             hp.batch_size,
                             init_states,
                             captcha=mp_captcha,
                             name='train')
        data_val = OCRIter(hp.eval_epoch_size // hp.batch_size,
                           hp.batch_size,
                           init_states,
                           captcha=mp_captcha,
                           name='val')

        symbol = lstm.lstm_unroll(num_lstm_layer=hp.num_lstm_layer,
                                  seq_len=hp.seq_length,
                                  num_hidden=hp.num_hidden,
                                  num_label=hp.num_label,
                                  loss_type=args.loss)

        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(level=logging.DEBUG, format=head)

        # Derive the LSTM state input names from init_states instead of
        # hard-coding two layers, so the Module follows hp.num_lstm_layer.
        # Assumes init_states is a sequence of (name, shape) pairs -- confirm
        # against lstm.init_states.
        data_names = ['data'] + [name for name, _ in init_states]

        module = mx.mod.Module(symbol,
                               data_names=data_names,
                               label_names=['label'],
                               context=contexts)

        metrics = CtcMetrics(hp.seq_length)
        module.fit(
            train_data=data_train,
            eval_data=data_val,
            # use metrics.accuracy or metrics.accuracy_lcs
            eval_metric=mx.gluon.metric.np(metrics.accuracy,
                                           allow_extra_outputs=True),
            optimizer='sgd',
            optimizer_params={
                'learning_rate': hp.learning_rate,
                'momentum': hp.momentum,
                'wd': 0.00001,
            },
            initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
            num_epoch=hp.num_epoch,
            batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
            epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
        )
    except KeyboardInterrupt:
        print("W: interrupt received, stopping...")
    finally:
        # Reset multiprocessing captcha generator to stop processes
        mp_captcha.reset()
コード例 #11
0
 def sym_gen(seq_len):
     """Return the LSTM symbol unrolled over ``seq_len`` time steps.

     The vocabulary size ``len(vocab)`` doubles as input size and label count.
     """
     n_tokens = len(vocab)
     return lstm_unroll(num_lstm_layer,
                        seq_len,
                        n_tokens,
                        num_hidden=num_hidden,
                        num_embed=num_embed,
                        num_label=n_tokens)
コード例 #12
0
def main():
    """Program entry point: train the LSTM OCR model with a CTC-style loss.

    Validates ``args.loss``, starts the multiprocess captcha generator, builds
    the unrolled LSTM symbol and fits it with SGD, checkpointing with
    ``args.prefix`` after every epoch. A KeyboardInterrupt is reported and
    swallowed, and the captcha generator is reset on every exit path.

    Raises
    ------
    ValueError
        if ``args.loss`` is neither ``'ctc'`` nor ``'warpctc'``.
    """
    args = parse_args()
    if args.loss not in ('ctc', 'warpctc'):
        raise ValueError("Invalid loss '{}' (must be 'ctc' or 'warpctc')".format(args.loss))

    hp = Hyperparams()

    # Start a multiprocessor captcha image generator
    mp_captcha = MPDigitCaptcha(
        font_paths=get_fonts(args.font_path), h=hp.seq_length, w=30,
        num_digit_min=3, num_digit_max=4, num_processes=args.num_proc, max_queue_size=hp.batch_size * 2)
    try:
        # Must call start() before any call to mxnet module (https://github.com/apache/incubator-mxnet/issues/9213)
        mp_captcha.start()

        if args.gpu:
            contexts = [mx.context.gpu(i) for i in range(args.gpu)]
        else:
            contexts = [mx.context.cpu(i) for i in range(args.cpu)]

        init_states = lstm.init_states(hp.batch_size, hp.num_lstm_layer, hp.num_hidden)

        data_train = OCRIter(
            hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, name='train')
        data_val = OCRIter(
            hp.eval_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, name='val')

        symbol = lstm.lstm_unroll(
            num_lstm_layer=hp.num_lstm_layer,
            seq_len=hp.seq_length,
            num_hidden=hp.num_hidden,
            num_label=hp.num_label,
            loss_type=args.loss)

        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(level=logging.DEBUG, format=head)

        # Derive the LSTM state input names from init_states instead of
        # hard-coding two layers, so the Module follows hp.num_lstm_layer.
        # Assumes init_states is a sequence of (name, shape) pairs -- confirm
        # against lstm.init_states.
        data_names = ['data'] + [name for name, _ in init_states]

        module = mx.mod.Module(
            symbol,
            data_names=data_names,
            label_names=['label'],
            context=contexts)

        metrics = CtcMetrics(hp.seq_length)
        module.fit(train_data=data_train,
                   eval_data=data_val,
                   # use metrics.accuracy or metrics.accuracy_lcs
                   eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
                   optimizer='sgd',
                   optimizer_params={'learning_rate': hp.learning_rate,
                                     'momentum': hp.momentum,
                                     'wd': 0.00001,
                                     },
                   initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
                   num_epoch=hp.num_epoch,
                   batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
                   epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
                   )
    except KeyboardInterrupt:
        print("W: interrupt received, stopping...")
    finally:
        # Reset multiprocessing captcha generator to stop processes
        mp_captcha.reset()
コード例 #13
0
ファイル: lstm_ocr.py プロジェクト: ChisBread/mxnet
 def sym_gen(seq_len):
     """Return the LSTM symbol unrolled over ``seq_len`` time steps.

     Uses the module-level ``num_lstm_layer``, ``num_hidden`` and
     ``num_label`` settings.
     """
     options = {'num_hidden': num_hidden, 'num_label': num_label}
     return lstm_unroll(num_lstm_layer, seq_len, **options)
コード例 #14
0
 def lstm_gen(seq_len):
     """Build the unrolled LSTM symbol for ``seq_len`` steps plus its I/O names.

     Returns ``(symbol, ['data'] + state_names, ['label'])``.
     """
     vocab_size = len(vocab)
     net = lstm.lstm_unroll(nlayer, seq_len, vocab_size, nhidden, nembed,
                            nlabels, dropout)
     return net, ['data'] + state_names, ['label']
コード例 #15
0
ファイル: test.py プロジェクト: jinfagang/Char-Generator-LSTM
if __name__ == '__main__':
    # Preview the first 1000 characters of the corpus. Open read-only: the
    # file is never written, so 'r+' needlessly required write permission.
    with open('obama.txt', 'r') as f:
        print(f.read(1000))

    vocab = build_vocab('obama.txt')
    print('vocab size = ', len(vocab))

    # Model hyper-parameters.
    seq_len = 129
    num_embed = 256
    num_lstm_layer = 3
    num_hidden = 512

    # Input size and label count are len(vocab) + 1 (presumably one extra
    # index is reserved, e.g. for padding -- confirm against build_vocab).
    symbol = lstm.lstm_unroll(num_lstm_layer,
                              seq_len,
                              len(vocab) + 1,
                              num_hidden=num_hidden,
                              num_embed=num_embed,
                              num_label=len(vocab) + 1,
                              dropout=0.2)

    batch_size = 32

    # (name, shape) descriptors for each layer's initial cell and hidden state.
    init_c = [('l%d_init_c' % l, (batch_size, num_hidden))
              for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h' % l, (batch_size, num_hidden))
              for l in range(num_lstm_layer)]
    init_states = init_c + init_h

    print(init_c)
    data_train = bucket_io.BucketSentenceIter("./obama.txt",
                                              vocab, [seq_len],
コード例 #16
0
ファイル: train.py プロジェクト: MrPig/mx-char-rnn
 def sym_gen(seq_len):
     """Return the LSTM symbol unrolled over ``seq_len`` time steps.

     Depth and layer sizes come from the parsed ``args``; ``vocab_size`` is
     used both as the input size and as the label count; dropout is 0.5.
     """
     layer_config = dict(num_lstm_layer=args.num_lstm_layer,
                         num_hidden=args.num_hidden,
                         num_embed=args.num_embed,
                         dropout=0.5)
     return lstm_unroll(seq_len=seq_len,
                        input_size=vocab_size,
                        num_label=vocab_size,
                        **layer_config)
コード例 #17
0
 def sym_gen(seq_len):
     """Return the LSTM symbol unrolled over ``seq_len`` time steps.

     A fixed size of 10000 is used for both the input size and the label
     count (presumably the vocabulary size -- confirm).
     """
     vocab_size = 10000
     return lstm_unroll(num_lstm_layer,
                        seq_len,
                        vocab_size,
                        num_hidden=num_hidden,
                        num_embed=num_embed,
                        num_label=vocab_size)
コード例 #18
0
ファイル: train_lstm.py プロジェクト: Answeror/mxnet
 def sym_gen(seq_len):
     """Build the unrolled LSTM symbol for ``seq_len`` steps plus its I/O names.

     ``feat_dim`` is passed as the input size and ``label_dim`` as
     ``num_label``. Returns ``(symbol, ['data'] + state_names,
     ['softmax_label'])``.
     """
     net = lstm_unroll(num_lstm_layer,
                       seq_len,
                       feat_dim,
                       num_hidden=num_hidden,
                       num_label=label_dim)
     return net, ['data'] + state_names, ['softmax_label']
コード例 #19
0
ファイル: lstm_ocr.py プロジェクト: xcgoner/dist-mxnet-udp
 def sym_gen(seq_len):
     """Return the LSTM symbol unrolled over ``seq_len`` time steps.

     Uses the module-level ``num_lstm_layer``, ``num_hidden`` and
     ``num_label`` settings.
     """
     sym = lstm_unroll(num_lstm_layer, seq_len,
                       num_hidden=num_hidden, num_label=num_label)
     return sym
コード例 #20
0
ファイル: data.py プロジェクト: qinjian623/dlnotes
print('vocab size = %d' % (len(vocab)))


# Each line contains at most 129 chars.
seq_len = 129
# embedding dimension, which maps a character to a 256-dimension vector
num_embed = 256
# number of lstm layers
num_lstm_layer = 3
# hidden unit in LSTM cell
num_hidden = 512

symbol = lstm.lstm_unroll(
    num_lstm_layer,
    seq_len,
    len(vocab) + 1,
    num_hidden=num_hidden,
    num_embed=num_embed,
    num_label=len(vocab) + 1,
    dropout=0.2)

# test_seq_len: verify that no corpus line exceeds seq_len characters
# (not counting the trailing newline, so an unterminated last line is held to
# the same limit). Raise instead of assert so the check survives `python -O`,
# and use `with` so the file is closed even when the check fails.
with open("./obama.txt") as data_file:
    for line in data_file:
        content = line.rstrip('\n')
        if len(content) > seq_len:
            raise ValueError(
                "seq_len is smaller than maximum line length. "
                "Current line length is %d. Line content is: %s"
                % (len(content), content))

# The batch size for training
batch_size = 32