예제 #1
0
def gen_inputs(batch_size, ilen, idim, olen, odim):
    """Create random input arrays plus input lengths and target sequences.

    Args:
        batch_size: Forwarded to ``sequence_utils.gen_random_sequence`` for
            both the input-side and the output-side call.
        ilen: Forwarded as the second argument of the input-side call
            (presumably the maximum input length -- confirm against
            ``sequence_utils``).
        idim: Forwarded as the third argument of the input-side call and
            also used as the second dimension of every generated array.
        olen: Forwarded as the second argument of the output-side call.
        odim: Forwarded as the third argument of the output-side call.

    Returns:
        A tuple ``(xs, ilens, ys)``: ``xs`` is a list holding one float32
        array of shape ``(length, idim)`` for each ``length`` in ``ilens``;
        ``ilens`` and ``ys`` are returned exactly as produced by
        ``sequence_utils.gen_random_sequence``.
    """
    # Only the lengths are used from the input-side call; the first element
    # of its result is discarded.
    _, ilens = sequence_utils.gen_random_sequence(batch_size, ilen, idim)

    # Arrays are drawn in the iteration order of ``ilens`` so the sequence
    # of calls into NumPy's global RNG stays the same.
    xs = [
        np.random.rand(length, idim).astype(np.float32)
        for length in ilens
    ]

    # Here the second element of the result is the one that is discarded.
    ys, _ = sequence_utils.gen_random_sequence(batch_size, olen, odim)
    return xs, ilens, ys
예제 #2
0
def main():
    """Generate a test case for ``MyLSTM`` from seeded random inputs.

    Seeds NumPy's global RNG, builds the model, creates variable-length
    float32 inputs together with zero-filled ``h``/``c`` arrays and a
    length mask, and passes everything to ``testtools.generate_testcase``.
    """
    np.random.seed(314)

    batch_size, sequence_length = 3, 4
    num_vocabs, num_hidden = 10, 5

    # The model is constructed before any input is drawn; this ordering is
    # kept in case construction itself consumes random numbers.
    model_fn = MyLSTM(num_hidden, batch_size, sequence_length)

    # Only the lengths are used; the first element of the result is dropped.
    _, lengths = sequence_utils.gen_random_sequence(
        batch_size, sequence_length, num_vocabs)
    xs = [
        np.random.rand(length, num_hidden).astype(np.float32)
        for length in lengths
    ]

    state_shape = (batch_size, num_hidden)
    h = np.zeros(state_shape, dtype=np.float32)
    c = np.zeros(state_shape, dtype=np.float32)

    # Broadcast a row of step indices against a column of lengths: an entry
    # is 1.0 where the step index is smaller than that row's length and 0.0
    # otherwise (assumes ``lengths`` is one-dimensional -- TODO confirm).
    step_index = np.expand_dims(np.arange(sequence_length), 0)
    length_column = np.expand_dims(lengths, 1)
    mask = (step_index < length_column).astype(np.float32)

    testtools.generate_testcase(model_fn, [xs, h, c, mask])
import ch2o

if __name__ == '__main__':
    import numpy as np
    np.random.seed(314)

    eprojs = 3
    dunits = 4
    att_dim = 5
    batch_size = 3
    sequence_length = 4
    num_vocabs = 10

    model_fn = lambda: AttDot(eprojs, dunits, att_dim)
    labels, ilens = sequence_utils.gen_random_sequence(batch_size,
                                                       sequence_length,
                                                       num_vocabs)
    xs = []
    for l in ilens:
        xs.append(np.random.rand(l, eprojs).astype(np.float32))

    # Check if our modification is valid.
    expected = model_fn().original(xs, None, None)
    actual = model_fn().forward(xs, None, None)
    for e, a in zip(expected, actual):
        assert np.allclose(e.array, a.array)

    ch2o.generate_testcase(model_fn, [xs, None, None])

    z = np.random.rand(batch_size, dunits).astype(np.float32)
    ch2o.generate_testcase(lambda: AttDotBackprop(eprojs, dunits, att_dim),