示例#1
0
文件: test_model.py 项目: huhoo/neon
def test_model_predict_rnn(backend):
    """Check that a small recurrent model can run ``predict`` on PTB text.

    The model is a 150-unit ``Recurrent`` layer with ``Logistic`` activation,
    followed by an ``Affine`` readout with one output per vocabulary entry
    and ``Rectlin`` activation. Weights and bias use ``Constant(0.08)``.
    The test asserts that the prediction array has shape
    ``(ndata, nclass)`` of the dataset.

    Args:
        backend: pytest fixture. It is not referenced in the body. Presumably
            it sets up the neon compute backend as a side effect; confirm
            against the fixture definition.

    NOTE(review): recurrent models emit one output per time step, so the
    predictions may carry an extra time-step axis. Confirm the expected
    shape against ``Model.predict``.
    """
    # Validation split of the Penn Treebank corpus, cut into 50-step sequences.
    ptb_valid = Text(time_steps=50, path=load_text('ptb-valid'))

    # The same initializer instance is reused for every weight and the bias.
    weight_init = Constant(0.08)

    hidden = Recurrent(150, weight_init, Logistic())
    vocab_size = len(ptb_valid.vocab)
    readout = Affine(vocab_size, weight_init, bias=weight_init,
                     activation=Rectlin())

    rnn = Model(layers=[hidden, readout])
    predictions = rnn.predict(ptb_valid)

    expected_shape = (ptb_valid.ndata, ptb_valid.nclass)
    assert predictions.shape == expected_shape